Source code for enerzyme.tasks.metrics

from typing import Dict, Callable, Tuple, List, Union, Optional
import numpy as np
from ..data import is_atomic, get_tensor_rank
from ..utils.base_logger import logger


def build_single_metric(metric_str: str) -> Callable[[Dict[str, Union[List, np.ndarray]], Dict[str, Union[List, np.ndarray]], str], Optional[float]]:
    """Build a scoring function for a single named metric.

    Supported names are ``"rmse"`` (root mean squared error) and ``"mae"``
    (mean absolute error). Both delegate to scikit-learn, which is imported
    lazily here so the dependency is only needed when a metric is built.

    Args:
        metric_str: Name of the metric, ``"rmse"`` or ``"mae"``.

    Returns:
        A function ``metric(label, prediction, target_name)`` that scores
        ``prediction[target_name]`` against ``label[target_name]``.

    Raises:
        ValueError: If ``metric_str`` is not a supported metric name.
    """
    if metric_str == "rmse":
        try:
            from sklearn.metrics import root_mean_squared_error
        except ImportError:
            # Older scikit-learn releases lack root_mean_squared_error; fall
            # back to mean_squared_error's ``squared=False`` flag instead.
            from sklearn.metrics import mean_squared_error
            metric_func = lambda x, y: mean_squared_error(x, y, squared=False)
        else:
            metric_func = root_mean_squared_error
    elif metric_str == "mae":
        from sklearn.metrics import mean_absolute_error
        metric_func = mean_absolute_error
    else:
        raise ValueError(f"Unknown metric: {metric_str}")

    def metric(label: Dict[str, Union[List, np.ndarray]], prediction: Dict[str, Union[List, np.ndarray]], target_name: str) -> Optional[float]:
        """Score ``prediction[target_name]`` against ``label[target_name]``.

        Returns 0 when ``label`` has no entry, a ``None`` entry, or an empty
        entry for ``target_name``. For targets where ``is_atomic`` or
        ``get_tensor_rank`` is truthy, the per-sample entries are concatenated
        before scoring. If the prediction has exactly one more dimension than
        the label, its trailing axis is averaged out first (presumably an
        ensemble axis — TODO confirm against the predictor).
        """
        y_true = label.get(target_name, [])
        # Use len() rather than truthiness: ``not arr`` raises ValueError for
        # numpy arrays with more than one element.
        if y_true is None or len(y_true) == 0:
            return 0
        y_pred = prediction[target_name]
        if is_atomic(target_name) or get_tensor_rank(target_name):
            # Entries are stored per sample (presumably ragged per-atom or
            # per-component arrays); stitch them into one array for scoring.
            y_trues, y_preds = np.concatenate(y_true), np.concatenate(y_pred)
        else:
            y_trues, y_preds = np.array(y_true), np.array(y_pred)
        if y_preds.ndim == y_trues.ndim + 1:
            y_preds = np.mean(y_preds, axis=-1)
        return metric_func(y_trues, y_preds)

    return metric
class Metrics(object):
    """Weighted collection of per-target metrics with early-stopping support.

    ``metric_config`` maps a target name to a ``{metric_name: weight}`` dict.
    Each pair is stored under the key ``f"{target}_{metric}"``. The weighted
    sum of those raw scores (ignoring ``None``/zero weights) is the
    "judge score" used to decide whether a model improved.
    """

    def __init__(self, metric_config: Optional[Dict] = None) -> None:
        """Register one scoring function per distinct metric name.

        Args:
            metric_config: ``{target: {metric: weight}}``. ``None`` (the
                default) is treated as an empty configuration.
        """
        self.metric_config = dict()
        self.metrics_register = dict()
        for target, metrics in (metric_config or {}).items():
            for metric, weight in metrics.items():
                self.metric_config[f"{target}_{metric}"] = weight
                if metric not in self.metrics_register:
                    self.metrics_register[metric] = build_single_metric(metric)

    def __str__(self) -> str:
        """Render the judge score as a human-readable weighted sum."""
        terms = []
        for target_metric, weight in self.metric_config.items():
            if weight == 1:
                terms.append(target_metric)
            elif weight is not None and weight != 0:
                terms.append(f"{weight:.2f} * {target_metric}")
        return " + ".join(terms)

    def cal_single_metric(self, label: Dict[str, Union[List, np.ndarray]], prediction: Dict[str, Union[List, np.ndarray]], target_name: str, metric_name: str) -> float:
        """Score one target with one registered metric."""
        return self.metrics_register[metric_name](label, prediction, target_name)

    def cal_judge_score(self, raw_metric_score: Dict[str, float]) -> float:
        """Return the weighted sum of raw scores; ``None``/zero weights are skipped."""
        judge_score = 0
        for target_metric, weight in self.metric_config.items():
            if weight is not None and weight != 0:
                judge_score += weight * raw_metric_score[target_metric]
        return judge_score

    def cal_metric(self, label: Dict[str, Union[List, np.ndarray]], predict: Dict[str, Union[List, np.ndarray]]) -> Dict[str, float]:
        """Compute every configured ``{target}_{metric}`` score plus ``_judge_score``."""
        raw_metric_score = dict()
        for target_metric in self.metric_config:
            # Split on the LAST underscore only: target names may themselves
            # contain "_", while the supported metric names ("rmse", "mae") do not.
            target_name, metric_name = target_metric.rsplit("_", 1)
            raw_metric_score[target_metric] = self.cal_single_metric(label, predict, target_name, metric_name)
        raw_metric_score["_judge_score"] = self.cal_judge_score(raw_metric_score)
        return raw_metric_score

    def _early_stop_choice(self, wait: int, best_score: float, metric_score: Dict[str, float], save_handle: Callable, patience: int, epoch: int) -> Tuple[bool, float, int, bool]:
        """Update early-stopping state from a ``cal_metric``-style score dict.

        Uses the precomputed ``_judge_score`` when present; only otherwise is
        it recomputed from the raw scores (which must then all be present).

        Returns:
            ``(is_early_stop, best_score, wait, saved)``.
        """
        if "_judge_score" in metric_score:
            judge_score = metric_score["_judge_score"]
        else:
            judge_score = self.cal_judge_score(metric_score)
        return self._judge_early_stop_decrease(wait, judge_score, best_score, save_handle, patience, epoch)

    def _judge_early_stop_decrease(self, wait: int, score: float, min_score: float, save_handle: Callable, patience: int, epoch: int) -> Tuple[bool, float, int, bool]:
        """Early-stopping step for a lower-is-better score.

        A score ``<= min_score`` is an improvement: the best score is updated,
        ``wait`` resets to 0 and ``save_handle(best_score=..., best_epoch=...,
        epoch=...)`` is called. Anything else (including NaN) counts as a
        non-improving epoch and increments ``wait``; reaching ``patience``
        flags early stopping.

        Returns:
            ``(is_early_stop, min_score, wait, saved)``.
        """
        is_early_stop = False
        saved = False
        if score <= min_score:
            min_score = score
            wait = 0
            save_handle(best_score=score, best_epoch=epoch, epoch=epoch)
            saved = True
        else:
            # Plain ``else`` so NaN scores (which fail every comparison) still
            # count toward patience instead of being silently ignored.
            wait += 1
            if wait == patience:
                # epoch + 1 in the message: presumably ``epoch`` is 0-based.
                logger.warning(f'Early stopping at epoch: {epoch+1}')
                is_early_stop = True
        return is_early_stop, min_score, wait, saved