Source code for mit_d3m.metrics

# -*- coding: utf-8 -*-

import warnings

from sklearn import metrics
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.preprocessing import LabelBinarizer

# sklearn raises UndefinedMetricWarning when a metric is ill-defined for the
# given inputs (for example, F1 for a label that is never predicted).
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)


def accuracy(ground_truth, predicted):
    return metrics.accuracy_score(ground_truth, predicted)


def f1(ground_truth, predicted):
    return metrics.f1_score(ground_truth, predicted)


def f1_micro(ground_truth, predicted):
    return metrics.f1_score(ground_truth, predicted, average='micro')


def f1_macro(ground_truth, predicted):
    return metrics.f1_score(ground_truth, predicted, average='macro')


def roc_auc(ground_truth, predicted):
    return metrics.roc_auc_score(ground_truth, predicted)


def roc_auc_micro(ground_truth, predicted):
    ground_truth, predicted = _binarize(ground_truth, predicted)
    return metrics.roc_auc_score(ground_truth, predicted, average='micro')


def roc_auc_macro(ground_truth, predicted):
    ground_truth, predicted = _binarize(ground_truth, predicted)
    return metrics.roc_auc_score(ground_truth, predicted, average='macro')


def l2(ground_truth, predicted):
    # Root mean squared error (RMSE).
    return (metrics.mean_squared_error(ground_truth, predicted))**0.5


def avg_l2(ground_truth_l, predicted_l):
    # Average RMSE over a list of (ground_truth, predicted) pairs.
    l2_sum = 0.0
    count = 0
    for pair in zip(ground_truth_l, predicted_l):
        l2_sum += l2(pair[0], pair[1])
        count += 1

    return l2_sum / count


def l1(ground_truth, predicted):
    return metrics.mean_absolute_error(ground_truth, predicted)


def r2(ground_truth, predicted):
    return metrics.r2_score(ground_truth, predicted)


def norm_mut_info(ground_truth, predicted):
    return metrics.normalized_mutual_info_score(ground_truth, predicted)


def jacc_sim(ground_truth, predicted):
    # jaccard_similarity_score was deprecated in scikit-learn 0.21 in favor
    # of jaccard_score and removed in 0.23, so this call requires an older
    # scikit-learn.
    return metrics.jaccard_similarity_score(ground_truth, predicted)


def mean_se(ground_truth, predicted):
    return metrics.mean_squared_error(ground_truth, predicted)
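

# Helper used by roc_auc_micro and roc_auc_macro: converts both label sets
# into binary indicator arrays (fitting on the ground truth) so that
# multi-class labels can be passed to roc_auc_score.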
def _binarize(ground, pred):
    label_binarizer = LabelBinarizer()
    return label_binarizer.fit_transform(ground), label_binarizer.transform(pred)


# MIT LL defined these strings here:
# https://gitlab.datadrivendiscovery.org/MIT-LL/d3m_data_supply/blob/shared/documentation/problemSchema.md#performance-metrics
METRICS_DICT = {
    'accuracy': accuracy,
    'f1': f1,
    'f1Micro': f1_micro,
    'f1Macro': f1_macro,
    'rocAuc': roc_auc,
    'rocAucMicro': roc_auc_micro,
    'rocAucMacro': roc_auc_macro,
    'meanSquaredError': mean_se,
    'rootMeanSquaredError': l2,
    'rootMeanSquaredErrorAvg': avg_l2,
    'meanAbsoluteError': l1,
    'rSquared': r2,
    'normalizedMutualInformation': norm_mut_info,
    'jaccardSimilarityScore': jacc_sim
}
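
A minimal usage sketch (not part of the module itself): scorers can be looked up in METRICS_DICT by their D3M metric name and called on ground truth and predicted labels. The label values below are illustrative.

from mit_d3m.metrics import METRICS_DICT

ground_truth = [0, 1, 1, 0, 1]
predicted = [0, 1, 0, 0, 1]

# 'f1Macro' is one of the D3M metric names listed in the problem schema.
scorer = METRICS_DICT['f1Macro']
print(scorer(ground_truth, predicted))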