Source code for enerzyme.tasks.picker

from typing import List, Callable, Literal, Dict, Any
import numpy as np
from tqdm import tqdm
from ..utils import logger


def _require_ensemble_axis(Fas, criterion: str) -> None:
    """Raise ``ValueError`` unless the first force array has a trailing ensemble axis."""
    # Only the first sample is inspected; the remaining samples are assumed to
    # share the same layout.
    if Fas[0].ndim != 3:
        raise ValueError(
            f"Criterion '{criterion}' needs Fa of shape (N, 3, ensemble_size), "
            f"got an array with {Fas[0].ndim} dimension(s)"
        )


def _max_norm_deviation(Fa) -> float:
    """Return the largest per-atom mean deviation norm of one sample.

    ``Fa`` has shape (N, 3, ensemble_size).
    """
    Fa_mean = np.mean(Fa, axis=-1, keepdims=True)  # (N, 3, 1)
    Fa_norm_dev = np.linalg.norm(Fa - Fa_mean, axis=1)  # (N, ensemble_size)
    # Mean (not root-mean-square) of the deviation norms over ensemble members.
    Fa_norm_std = np.mean(Fa_norm_dev, axis=-1)  # (N, )
    return np.max(Fa_norm_std)


def build_Fa_picking(criterion: Literal["std_mean", "norm_std_max"]) -> Callable:
    """Build a function that picks samples by an estimated atomic-force error.

    Args:
        criterion: How the per-sample error is estimated.
            ``"std_mean"`` averages the per-component standard deviation
            (``ddof=1``) over the ensemble axis; in ``"single"`` mode, if the
            predictions contain ``"Fa_var"``, the mean of ``sqrt(Fa_var)`` is
            used instead.
            ``"norm_std_max"`` takes, for every atom, the ensemble-mean norm
            of the deviation from the ensemble-mean force, and then the
            maximum over atoms.

    Returns:
        ``picking_func(y_preds, error_lower_bound, error_upper_bound, mode,
        stat_only)``, which returns a dict with ``"picked_indices"`` (indices
        of samples whose estimated error lies within
        ``[error_lower_bound, error_upper_bound]``) and
        ``"estimated_error_mean"`` (mean estimated error over all samples).

    Raises:
        ValueError: If ``criterion`` is not one of the supported values.
    """
    if criterion not in ("std_mean", "norm_std_max"):
        # Fail fast: otherwise the problem would only surface on the first
        # call of the returned function, as an unbound local variable.
        raise ValueError(f"Unknown criterion: {criterion}")

    def picking_func(
        y_preds,
        error_lower_bound: float = 0.0,
        error_upper_bound: float = float("inf"),
        mode: Literal["single", "committee"] = "single",
        stat_only: bool = False,
    ) -> Dict[str, Any]:
        """Estimate a per-sample force error and pick the samples within bounds.

        Args:
            y_preds: In ``"single"`` mode, a mapping with an ``"Fa"`` sequence
                of per-sample force arrays (and optionally ``"Fa_var"``). In
                ``"committee"`` mode, a sequence of such mappings, one per
                committee member.
            error_lower_bound: Samples with a smaller error are not picked.
            error_upper_bound: Samples with a larger error are not picked.
            mode: ``"single"`` or ``"committee"``.
            stat_only: If true, the picked/lower/upper counts are not logged;
                the return value is unaffected.

        Raises:
            ValueError: If ``mode`` is unknown, or if the force arrays lack
                the ensemble axis that the criterion requires.
        """
        if mode == "single":
            Fas = y_preds["Fa"]
        elif mode == "committee":
            # Stack each sample's member predictions along a new last axis,
            # giving one (N, 3, ensemble_size) array per sample.
            Fas = [
                np.stack(member_Fas, axis=-1)
                for member_Fas in zip(*[y_pred_single["Fa"] for y_pred_single in y_preds])
            ]
        else:
            raise ValueError(f"Unknown mode: {mode}")

        sample_size = len(Fas)
        if sample_size == 0:
            # Nothing to estimate; avoid indexing Fas[0] and NumPy's
            # empty-mean warnings.
            logger.info("No samples to estimate the error from!")
            return {
                "picked_indices": [],
                "estimated_error_mean": float("nan"),
            }

        if criterion == "std_mean":
            if mode == "single" and "Fa_var" in y_preds:
                # shape of Fa_var: (N, 3)
                est_error = np.array([np.mean(np.sqrt(Fa_var)) for Fa_var in y_preds["Fa_var"]])
            else:
                # shape of Fa should be: (N, 3, ensemble_size)
                _require_ensemble_axis(Fas, criterion)
                est_error = np.array([np.mean(np.std(Fa, axis=-1, ddof=1)) for Fa in Fas])
        else:
            # criterion == "norm_std_max", guaranteed by the build-time check.
            _require_ensemble_axis(Fas, criterion)
            est_error = np.array([_max_norm_deviation(Fa) for Fa in Fas])

        # NaN errors compare False against both bounds and so count as picked.
        upper_bool = est_error > error_upper_bound
        lower_bool = est_error < error_lower_bound
        picked = np.where(~(upper_bool | lower_bool))[0].tolist()
        upper = np.where(upper_bool)[0].tolist()
        lower = np.where(lower_bool)[0].tolist()

        logger.info(f"Estimated error: {np.mean(est_error):.4f} +/- {np.std(est_error):.4f}")
        if not stat_only:
            logger.info(f"({len(picked)} / {sample_size}) picked, {len(lower)} lower, {len(upper)} upper!")
        return {
            "picked_indices": picked,
            "estimated_error_mean": np.mean(est_error),
        }

    return picking_func
def random_picking(y_preds) -> List[int]:
    """Return the indices of all samples in the first prediction's ``"Fa"`` entry.

    Args:
        y_preds: An indexable collection whose first element maps ``"Fa"`` to
            a sized collection with one entry per sample.

    Returns:
        The list ``[0, 1, ..., sample_size - 1]``.
    """
    # NOTE(review): despite the name, no random sub-sampling happens here;
    # every index is returned. Confirm whether sampling is done by the caller.
    first_prediction = y_preds[0]
    sample_size = len(first_prediction["Fa"])
    picked = [*range(sample_size)]
    logger.info(f"({len(picked)} / {sample_size}) picked!")
    return picked
# Registry mapping a picking-method name to the callable implementing it.
PICKING_REGISTER = dict(
    max_Fa_norm_std=build_Fa_picking("norm_std_max"),
    mean_Fa_std=build_Fa_picking("std_mean"),
    random=random_picking,
)