Source code for easy_tpp.utils.metrics

from collections import defaultdict

import numpy as np

from easy_tpp.utils.log_utils import default_logger as logger


class MetricsHelper:
    """Registry of named metric functions and their optimisation direction.

    Metrics are added with the :meth:`register` decorator and looked up by
    name. Each registry entry is a ``(func, direction)`` tuple, where
    ``direction`` is typically :attr:`MAXIMIZE` or :attr:`MINIMIZE`.
    """

    MAXIMIZE = 'maximize'
    MINIMIZE = 'minimize'
    # Maps metric name -> (metric function, direction). This is a class
    # attribute, so the registry is shared across the whole process.
    _registry_center = defaultdict(tuple)

    @staticmethod
    def get_metric_function(name):
        """Return the metric function registered under ``name``.

        Args:
            name (str): registered metric name.

        Returns:
            callable or None: the metric function, or None (after logging a
            warning) when nothing is registered under ``name``.
        """
        # Membership test first: indexing the defaultdict on a miss would
        # silently insert an empty tuple.
        if name in MetricsHelper._registry_center:
            return MetricsHelper._registry_center[name][0]
        logger.warning(f'Metric is not found: {name}')
        return None

    @staticmethod
    def get_metric_direction(name):
        """Return the direction registered for metric ``name``.

        Args:
            name (str): registered metric name.

        Returns:
            str or None: the direction given at registration time, or None if
            the metric is not registered.
        """
        if name in MetricsHelper._registry_center:
            return MetricsHelper._registry_center[name][1]
        return None

    @staticmethod
    def get_all_registered_metric():
        """Return every registry entry.

        Returns:
            list: ``(func, direction)`` tuples, one per registered metric.
        """
        # Previously returned the unbound-call ``values`` method object
        # instead of the values themselves.
        return list(MetricsHelper._registry_center.values())

    @staticmethod
    def register(name, direction, overwrite=True):
        """Build a decorator that registers a metric function under ``name``.

        Args:
            name (str): key to register the metric under.
            direction (str): optimisation direction, e.g. ``MetricsHelper.MAXIMIZE``.
            overwrite (bool): if False and ``name`` is already registered, the
                existing entry is kept and a warning is logged.

        Returns:
            callable: a decorator that registers its argument and returns it
            unchanged.
        """
        registry_center = MetricsHelper._registry_center

        def _add_metric_to_registry(func):
            if name in registry_center and not overwrite:
                logger.warning(f'The metric {name} is already registered, and cannot be overwritten!')
            else:
                registry_center[name] = (func, direction)
            # Always hand back the original function so decoration is transparent.
            return func

        return _add_metric_to_registry

    @staticmethod
    def metrics_dict_to_str(metrics_dict):
        """Convert metrics to a string to show in console.

        Args:
            metrics_dict (dict): metric name -> value.

        Returns:
            str: ``'k1 is v1, k2 is v2'``; empty string for an empty dict.
        """
        return ', '.join('{0} is {1}'.format(k, v) for k, v in metrics_dict.items())

    @staticmethod
    def get_metrics_callback_from_names(metric_names):
        """Build a callback that evaluates the named metrics.

        Names that are not registered are skipped; ``get_metric_function``
        logs a warning for each of them.

        Args:
            metric_names (iterable of str): names of registered metrics.

        Returns:
            callable: ``metrics(preds, labels, **kwargs) -> dict`` mapping each
            lower-cased metric name to that metric's result.
        """
        metric_functions = []
        metric_names_ = []
        for name in metric_names:
            metric = MetricsHelper.get_metric_function(name)
            if metric is not None:
                metric_functions.append(metric)
                metric_names_.append(name)

        def metrics(preds, labels, **kwargs):
            """Call each resolved metric function and collect the results."""
            return {
                metric_name.lower(): metric_func(preds, labels, **kwargs)
                for metric_name, metric_func in zip(metric_names_, metric_functions)
            }

        return metrics
class MetricsTracker:
    """Track and record the best value seen so far for supported metrics.

    ``'loglike'`` is maximised and ``'distance'`` is minimised; their initial
    values in ``current_best`` are the worst finite float for that direction.
    """

    def __init__(self):
        self.current_best = {
            'loglike': np.finfo(float).min,
            'distance': np.finfo(float).max
        }
        # Epoch of the most recent improvement of any tracked key; the sentinel
        # string means no update has happened yet.
        self.episode_best = 'NeverUpdated'

    def update_best(self, key, value, epoch):
        """Update the recorder for the best metrics.

        Args:
            key (str): metrics key, ``'loglike'`` (higher is better) or
                ``'distance'`` (lower is better).
            value (float): metrics value.
            epoch (int): num of epoch.

        Raises:
            NotImplementedError: for keys other than 'loglike' and 'distance'.

        Returns:
            bool: whether the recorder has been updated.
        """
        if key == 'loglike':
            improved = value > self.current_best[key]
        elif key == 'distance':
            improved = value < self.current_best[key]
        else:
            raise NotImplementedError(f'Tracking the best value is not implemented for metric: {key}')

        # Coerce so numpy scalar comparisons still return a plain Python bool.
        improved = bool(improved)
        if improved:
            self.current_best[key] = value
            self.episode_best = epoch
        return improved