Source code for enerzyme.models.modelhub

import os.path as osp
from functools import partial, cmp_to_key
from glob import glob
from collections import defaultdict
from typing import Literal, Optional
from ..utils import logger
from ..data.datahub import DataHub
from ..tasks.trainer import Trainer
from .ff import FF_single, FF_committee


[docs] def get_model_str(model_id, model_params): return f"{model_id}-{model_params.architecture}" + (f"-{model_params.suffix}" if model_params.suffix is not None else "")
[docs] def get_all_possible_paths(pretrain_path: str, prefix: str): return glob(osp.join(pretrain_path, f"{prefix}_best-v*.pth")) + glob(osp.join(pretrain_path, f"{prefix}_last-v*.pth")) + glob(osp.join(pretrain_path, f"{prefix}_best.pth")) + glob(osp.join(pretrain_path, f"{prefix}_last.pth"))
[docs] def get_all_possible_paths_with_rank(pretrain_path: str, model_rank: Optional[int]=None): if model_rank is None: return get_all_possible_paths(pretrain_path, "model") + get_all_possible_paths(pretrain_path, "model0") else: return get_all_possible_paths(pretrain_path, f"model{model_rank}")
[docs] def path_resolve(path, target_preference: Literal["best", "last"]="best", target_model_rank: Optional[int]=None): basename = osp.basename(path).split(".")[0] model_prefix, model_suffix = basename.split("_") if len(model_prefix) == 5: if target_model_rank is None: model_rank = 1 else: model_rank = -1 else: model_rank = int(model_prefix[5:]) version_split = model_suffix.split("-") if len(version_split) == 1: version = 0 else: version = int(version_split[1][1:]) return { "model_rank": model_rank, "preference": int(target_preference == version_split[0]), "version": version, }
[docs] def compare_path(path1, path2, preference: Literal["best", "last"]="best", model_rank: Optional[int]=None) -> int: ''' Compare two paths in the order of - preference: if the preference is "best", then the path of the best checkpoint is prioritized; otherwise, the path of the last checkpoint is prioritized. - model_rank: explicit model rank is prioritized. - version: latest version is prioritized. Params: ---------- path1: str The first path to compare. path2: str The second path to compare. preference: Literal["best", "last"] The preference between the best checkpoint and the last checkpoint. model_rank: Optional[int] The model rank. Only used for deep ensemble to make sure that the model with the correct rank is selected. Returns: ---------- int The difference between the two paths. When the difference is positive, path1 prioritized over path2. ''' path1_info = path_resolve(path1, preference, model_rank) path2_info = path_resolve(path2, preference, model_rank) if path1_info["preference"] != path2_info["preference"]: return path1_info["preference"] - path2_info["preference"] elif path1_info["model_rank"] != path2_info["model_rank"]: return path1_info["model_rank"] - path2_info["model_rank"] else: return path1_info["version"] - path2_info["version"]
[docs] def get_pretrain_path(pretrain_path: Optional[str]=None, preference: Literal["best", "last"]="best", model_rank: Optional[int]=None): if pretrain_path is not None: if osp.isfile(pretrain_path): return pretrain_path elif osp.isdir(pretrain_path): if model_rank == None: model_rank = '' all_possible_paths = sorted( get_all_possible_paths_with_rank(pretrain_path, model_rank), key=cmp_to_key(partial( compare_path, preference=preference, model_rank=model_rank )), reverse=True ) if len(all_possible_paths) == 0: raise FileNotFoundError(f"Pretrained model{' ' if model_rank is None else 'ranked ' + str(model_rank)} not found in {pretrain_path}") return all_possible_paths[0] else: raise FileNotFoundError(f"Pretrained model not found at {pretrain_path}") else: return None
[docs] class ModelHub:
[docs] def __init__(self, datahub: DataHub, trainer: Trainer, **params) -> None: self.datahub = datahub self.models = defaultdict(dict) self.default_trainer = trainer self.default_metrics = trainer.metrics self.out_dir = self.default_trainer.out_dir self._init_models(**params)
def _init_models(self, **params) -> None: for model_id, model_params in {**params.get('internal_FFs', {}), **params.get('external_FFs', {})}.items(): model_str = get_model_str(model_id, model_params) new_trainer_config = model_params.get('Trainer', None) new_metric_config = model_params.get('Metric', None) if model_params['active']: if new_trainer_config is None and new_metric_config is None: trainer = self.default_trainer else: trainer_config = self.default_trainer.config.copy() if new_metric_config is not None: logger.info("init {} custom metrics".format(model_str)) metric_config = new_metric_config else: metric_config = self.default_trainer.metric_config.copy() if new_trainer_config is not None: logger.info("init {} custom train parameters".format(model_str)) for k, v in new_trainer_config.items(): trainer_config[k] = v trainer = Trainer(self.out_dir, metric_config=metric_config, **trainer_config) if trainer.committee_size > 1: logger.info(f"Initiate {model_str} Force Field Committee of {trainer.committee_size}") self.models['FF'][model_str] = FF_committee(trainer.committee_size, self.datahub, trainer, model_str, **model_params) else: logger.info(f"Initiate {model_str} Force Field") self.models['FF'][model_str] = FF_single(self.datahub, trainer, model_str, **model_params)