Source code for enerzyme.tasks.calculator

import importlib
import importlib.util
import os
import pathlib
import sys
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple

import ase
import numpy as np
import torch
from ase.calculators.calculator import Calculator, all_changes
from ase.units import Hartree
from torch.nn import Module

from ..data.neighbor_list import full_neighbor_list
from ..data.transform import Transform
from .trainer import _decorate_batch_output, _decorate_batch_input, _to_device


def _init_patch_module(patch_path: str) -> None:
    p = pathlib.Path(patch_path).expanduser().resolve()
    if not p.exists():
        raise FileNotFoundError(p)
    if p.suffix != ".py":
        raise ValueError(f"Plugin must be a .py file, got: {p}")

    # Hash the patch file to get a unique module name
    module_name = f"plugin_{p.stem}_{abs(hash(str(p)))}"
    spec = importlib.util.spec_from_file_location(module_name, str(p))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create module spec for {patch_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

class ASECalculator(Calculator):
    """ASE calculator backed by an enerzyme force-field model.

    The model's energies and forces are multiplied by ``Hartree / Hartree_in_E``
    (``ase.units.Hartree`` is expressed in eV). They can optionally be mixed
    linearly with an external ASE calculator, and optionally augmented with an
    uncertainty-driven bias (currently only ``"UDD"`` is supported).
    """

    implemented_properties = ["energy", "forces", "dipole", "charges"]

    def __init__(
        self,
        model: Module,
        restart: Optional[str] = None,
        label: Optional[str] = None,
        atoms: Optional[ase.Atoms] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        transform: Optional[Transform] = None,
        neighbor_list_type: Optional[str] = "full",
        Hartree_in_E: float = 1,
        internal_calculator_weight: float = 1.0,
        uncertainty_calculator_config: Optional[Dict[str, Any]] = None,
        external_calculator: Optional[Calculator] = None,
        external_calculator_config: Optional[Dict[str, Any]] = None,
        **params: Dict[str, Any],
    ) -> None:
        """Create the calculator.

        Args:
            model: Trained network, called on the batched input built in
                ``_calculate_internal``.
            restart, label, atoms, **params: Forwarded to
                ``ase.calculators.calculator.Calculator.__init__``.
            device, dtype: Device and floating dtype used for model inputs.
            transform: Target transform whose ``inverse_transform`` is applied
                to the decoded model output. ``None`` skips this step.
            neighbor_list_type: ``"full"`` builds an all-pairs neighbor list
                with ``full_neighbor_list``. Other values add no pair indices.
            Hartree_in_E: Value of one Hartree in the model's energy unit.
            internal_calculator_weight: Weight of the model's contribution.
            uncertainty_calculator_config: ``{"name": ..., "params": {...}}``.
                ``None`` disables the uncertainty bias.
            external_calculator: Optional ASE calculator mixed into energy
                and forces.
            external_calculator_config: Its ``"weight"`` entry (default
                ``0.0``) scales the external contribution. May be ``None``.
        """
        Calculator.__init__(
            self, restart=restart, label=label, atoms=atoms, **params
        )
        self.model = model
        self.positions = None
        self.device = device
        self.pbc = np.array([False])
        self.cell = None
        self.cell_offsets = None
        self.dtype = dtype
        self.neighbor_list_type = neighbor_list_type
        self.transform = transform
        self.Hartree_in_E = Hartree_in_E
        # Multiplying a model energy by this factor gives ASE units (eV).
        self.E_conversion_factor = Hartree / self.Hartree_in_E
        self.internal_calculator_weight = internal_calculator_weight
        self.external_calculator = external_calculator
        if self.external_calculator is None:
            self.external_calculator_weight = None
            self.use_external_calculator = False
        else:
            # Treat a missing config as empty rather than crashing on None.get.
            config = external_calculator_config or {}
            self.external_calculator_weight = config.get("weight", 0.0)
            self.use_external_calculator = True
        self.uncertainty_calculator_config = uncertainty_calculator_config

    def _calculate_UDD(
        self, output: Dict[str, Any], A: float, B: float, NM: Optional[int] = None
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Compute ensemble-mean results and uncertainty-driven (UDD) biases.

        The bias energy is ``A * (exp(-var / (2 * NM * B)) - 1)``, where ``var``
        is the ensemble variance of the energy.

        Args:
            output: Decoded, inverse-transformed model output. If it contains
                ``"E_var_grad"``, the model is treated as a shallow ensemble
                that already provides ``"E_var"`` and ``"E_var_grad"``.
                Otherwise the last axis of ``"E"``/``"Fa"`` indexes the
                ensemble members.
            A: Bias amplitude, in the model's energy unit.
            B: Bias width, in the unit of the energy variance.
            NM: Accepted for configuration compatibility but ignored. The
                ensemble size always comes from the model or from ``output``.

        Returns:
            ``(results, biases)``. ``results`` holds ``energy``/``forces``, plus
            ``dipole``/``charges`` when present. ``biases`` holds the
            ``energy``/``forces`` bias terms. Energies and forces in both dicts
            are scaled by ``E_conversion_factor``.
        """
        results, biases = dict(), dict()
        if "E_var_grad" in output:
            # Shallow ensemble: variance and its gradient come from the model.
            E_var = output["E_var"][0]
            E_var_grad = output["E_var_grad"][0]
            E_mean = output["E"][0]
            F_mean = output["Fa"][0].squeeze()
            NM = self.model.shallow_ensemble_size
        else:
            # Explicit ensemble: the last axis indexes members.
            NM = output["E"][0].shape[-1]
            E_mean = np.mean(output["E"][0], axis=-1, keepdims=True)  # shape: (1, 1)
            E_dev = output["E"][0] - E_mean  # shape: (1, NM)
            E_var = np.mean(E_dev**2, axis=-1, keepdims=True).squeeze()
            F_mean_kept = np.mean(output["Fa"][0], axis=-1, keepdims=True)  # shape: (NA, 3, 1)
            F_dev = output["Fa"][0] - F_mean_kept  # shape: (NA, 3, NM)
            # Drop only the member axis. A blanket squeeze() would also drop
            # the atom axis when NA == 1, yielding forces of shape (3,).
            F_mean = F_mean_kept[..., 0]  # shape: (NA, 3)
            E_var_grad = -np.sum(F_dev * E_dev, axis=-1)  # shape: (NA, 3)
        scale = 1 / (2 * NM * B)
        neg_scaled_var = -E_var * scale
        exp_neg_scaled_var = np.exp(neg_scaled_var)
        E_bias = A * (exp_neg_scaled_var - 1)
        # NOTE(review): -d(E_bias)/dR = A * scale * exp(-scale * var) * d(var)/dR,
        # i.e. a factor of `scale`, not `1 / scale`. `A / scale * E_var_grad`
        # does not even have force units. Also, in the explicit-ensemble branch
        # E_var_grad equals (NM / 2) * d(var)/dR, assuming Fa_k = -dE_k/dR.
        # Kept unchanged pending confirmation of the intended convention. TODO confirm.
        F_bias = A / scale * exp_neg_scaled_var * E_var_grad  # shape: (NA, 3)
        results["energy"] = E_mean.squeeze() * self.E_conversion_factor
        biases["energy"] = E_bias * self.E_conversion_factor
        results["forces"] = F_mean * self.E_conversion_factor
        biases["forces"] = F_bias * self.E_conversion_factor
        if "M2" in output:
            results["dipole"] = output["M2"][0]
        if "Qa" in output:
            results["charges"] = output["Qa"][0]
        return results, biases

    def _calculate_internal(
        self, atoms: ase.Atoms
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run the wrapped model on ``atoms``.

        Returns:
            ``(results, biases)``. ``biases`` is empty unless an uncertainty
            calculator is configured.

        Raises:
            ValueError: If the configured uncertainty calculator is not
                ``"UDD"``.
        """
        features = {
            "Q": atoms.info.get("charge", 0),
            # NOTE(review): with no "spin" entry this gives S = -1. If "spin"
            # is a multiplicity, the natural default would be 1. TODO confirm.
            "S": atoms.info.get("spin", 0) - 1,
            "Ra": atoms.positions,
            "Za": atoms.numbers,
            "N": len(atoms),
        }
        if self.neighbor_list_type == "full":
            idx_i, idx_j = full_neighbor_list(features["N"])
            features["idx_i"] = idx_i
            features["idx_j"] = idx_j
            features["N_pair"] = len(idx_i)
        net_input, _ = _decorate_batch_input(
            batch=[(features, None)], device=self.device, dtype=self.dtype
        )
        net_input, _ = _to_device((net_input, {}), self.device)
        net_output = self.model(net_input)
        output, _ = _decorate_batch_output(
            output=net_output, features=net_input, targets=None
        )
        # `transform` is Optional in __init__, so only invert when one is set.
        if self.transform is not None:
            self.transform.inverse_transform(output)

        if self.uncertainty_calculator_config is None:
            results = {
                "energy": output["E"][0] * self.E_conversion_factor,
                "forces": output["Fa"][0] * self.E_conversion_factor,
            }
            if "M2" in output:
                results["dipole"] = output["M2"][0]
            if "Qa" in output:
                results["charges"] = output["Qa"][0]
            return results, dict()

        uncertainty_calculator_name = self.uncertainty_calculator_config.get("name", None)
        if uncertainty_calculator_name == "UDD":
            udd_params = self.uncertainty_calculator_config.get("params") or {}
            return self._calculate_UDD(output, **udd_params)
        raise ValueError(f"Uncertainty calculator {uncertainty_calculator_name} not supported")

    def calculate(
        self,
        atoms: Optional[ase.Atoms] = None,
        properties: Optional[List[str]] = None,
        system_changes=all_changes,
    ) -> None:
        """Populate ``self.results`` for the requested properties.

        Energy and forces combine the two sources as follows:

        * With an external calculator: ``w_int * internal + w_ext * external``.
          The internal term is skipped when ``w_int == 0``.
        * Otherwise: the internal model result scaled by ``w_int``.

        ``dipole`` and ``charges`` come from the internal model only, scaled by
        ``w_int``. When an uncertainty calculator is configured, its
        energy/force biases are added unweighted. Properties produced by no
        source are left out of ``self.results``.

        Args:
            atoms: Structure to evaluate. ``None`` reuses ``self.atoms``.
            properties: Properties to compute. Defaults to all
                ``implemented_properties``.
            system_changes: Forwarded to ``Calculator.calculate`` and to the
                external calculator.
        """
        if properties is None:
            properties = list(self.implemented_properties)
        Calculator.calculate(self, atoms, properties, system_changes)
        if atoms is None:
            # Calculator.calculate keeps the previously attached atoms here.
            atoms = self.atoms

        internal_results: Dict[str, Any] = {}
        biases: Dict[str, Any] = {}
        if self.internal_calculator_weight != 0 or self.uncertainty_calculator_config is not None:
            internal_results, biases = self._calculate_internal(atoms)

        external_calculator_properties = ("energy", "forces")
        external_results: Dict[str, Any] = {}
        if self.use_external_calculator:
            self.external_calculator.calculate(
                atoms, properties=properties, system_changes=system_changes
            )
            external_results = {
                name: self.external_calculator.results[name]
                for name in external_calculator_properties
            }

        for prop in properties:
            if prop in external_results:
                if self.internal_calculator_weight == 0:
                    value = self.external_calculator_weight * external_results[prop]
                else:
                    value = (
                        self.internal_calculator_weight * internal_results[prop]
                        + self.external_calculator_weight * external_results[prop]
                    )
            elif prop in internal_results:
                value = self.internal_calculator_weight * internal_results[prop]
            else:
                # Not produced (e.g. no dipole/charge output, or internal model
                # skipped). Leave it out so ASE reports it as not implemented.
                continue
            if self.uncertainty_calculator_config is not None and prop in ("energy", "forces"):
                value += biases[prop]
            self.results[prop] = value
def get_calculator(
    model_dir: str,
    device: str = "cuda",
    dtype: str = "float64",
    model_config_path: Optional[str] = None,
    calculator_patch: Optional[str] = None,
    neighbor_list_type: Optional[str] = "full",
    Hartree_in_E: float = 1,
    internal_calculator_weight: float = 1.0,
    uncertainty_calculator_config: Optional[Dict[str, Any]] = None,
    external_calculator: Optional[Calculator] = None,
    external_calculator_config: Optional[Dict[str, Any]] = None,
) -> ASECalculator:
    """Build an ``ASECalculator`` from a trained model directory.

    The model config is read from ``model_config_path``, which defaults to
    ``<model_dir>/config.yaml``. The first force field marked ``active`` among
    ``Modelhub.internal_FFs`` and ``Modelhub.external_FFs`` is built, its
    ``"best"`` checkpoint is loaded, and the model is put in eval mode.

    Args:
        model_dir: Directory holding the config and checkpoints.
        device: Torch device string.
        dtype: Key into ``DTYPE_MAPPING``.
        model_config_path: Optional explicit config path.
        calculator_patch: Optional ``.py`` plugin. When given,
            ``external_calculator_config["name"]`` selects the attribute of
            the plugin used as the external calculator, replacing
            ``external_calculator``.
        neighbor_list_type, Hartree_in_E, internal_calculator_weight,
        uncertainty_calculator_config, external_calculator_config:
            Forwarded to ``ASECalculator``.
        external_calculator: Used when no ``calculator_patch`` is given.

    Raises:
        ValueError: If the plugin's external calculator name is missing or not
            found, or if no force field is active.
    """
    from ..models import get_model_str, build_model, get_pretrain_path
    from ..utils import YamlHandler, logger
    from .trainer import DTYPE_MAPPING, _load_state_dict

    device = torch.device(device)
    dtype = DTYPE_MAPPING[dtype]
    if model_config_path is None:
        model_config_path = os.path.join(model_dir, 'config.yaml')
    model_config = YamlHandler(model_config_path).read_yaml()
    logger.info('Model Config: {}'.format(model_config))
    transform = Transform(model_config.Datahub.transforms, simulation_mode=True)

    if calculator_patch is not None:
        calculator_patch_module = _init_patch_module(calculator_patch)
        logger.info(f"Initialized calculator patch module: {calculator_patch}")
        external_calculator_name = (external_calculator_config or {}).get("name", None)
        if external_calculator_name is None:
            raise ValueError("External calculator name not specified!")
        if not hasattr(calculator_patch_module, external_calculator_name):
            raise ValueError(f"External calculator {external_calculator_name} not found in {calculator_patch_module}")
        # NOTE(review): this is whatever object the plugin exposes under that
        # name. ASECalculator calls `.calculate(...)` and reads `.results` on
        # it, so it presumably must be a Calculator instance. Verify.
        external_calculator = getattr(calculator_patch_module, external_calculator_name)
        logger.info(f"Initialized external calculator: {external_calculator_name}")
    # Without a patch, the caller-supplied external_calculator (possibly None)
    # is used as-is.
    if external_calculator is not None and external_calculator_config is None:
        external_calculator_config = {}

    force_fields = model_config.Modelhub.internal_FFs | model_config.Modelhub.external_FFs
    for FF_key, FF_params in force_fields.items():
        if not FF_params.get("active", False):
            continue
        model_str = get_model_str(FF_key, FF_params)
        model = build_model(FF_params.architecture, FF_params.layers, FF_params.build_params)
        model_path = get_pretrain_path(os.path.join(model_dir, model_str), "best")
        model = model.to(device).type(dtype)
        _load_state_dict(model, device, model_path, inference=True)
        model.eval()
        # Only the first active force field is used.
        return ASECalculator(
            model=model,
            device=device,
            dtype=dtype,
            transform=transform,
            neighbor_list_type=neighbor_list_type,
            Hartree_in_E=Hartree_in_E,
            internal_calculator_weight=internal_calculator_weight,
            uncertainty_calculator_config=uncertainty_calculator_config,
            external_calculator=external_calculator,
            external_calculator_config=external_calculator_config,
        )
    raise ValueError(f"No active force field found in {model_config_path}")