# Source code for enerzyme.tasks.server

from addict import Dict
from ase.units import Bohr
import torch
from torch.nn import Module
from ..data.transform import Transform
from ..data.neighbor_list import full_neighbor_list
from .trainer import DTYPE_MAPPING, _load_state_dict
from .batch import _decorate_batch_input, _to_device, _decorate_batch_output


class Server:
    """Run a trained model on one structure at a time.

    The model, its weights and the runtime settings (device, dtype,
    neighbor-list type) are fixed at construction time. ``calculate`` then
    turns the input features of a single structure into model predictions.
    """

    def __init__(
        self,
        config: Dict,
        model: Module,
        model_path: str,
        out_dir: str,
        transform: Transform,
    ):
        """Select device/dtype, load the trained weights and enter eval mode.

        Args:
            config: Configuration whose ``Server`` section may set
                ``neighbor_list`` (default ``"full"``), ``cuda`` (default
                ``False``), ``dtype`` (key into ``DTYPE_MAPPING``, default
                ``"float64"``), ``Hartree_in_E`` (default ``1``) and
                ``Bohr_in_R`` (default ``ase.units.Bohr``).
            model: Network to serve. It is moved to the selected device and
                cast to the selected dtype.
            model_path: Checkpoint location handed to ``_load_state_dict``.
            out_dir: Output directory, stored on the instance.
            transform: Transform whose ``inverse_transform`` is applied to
                every prediction returned by ``calculate``.
        """
        server_cfg = config.Server
        self.neighbor_list_type = server_cfg.get("neighbor_list", "full")
        self.cuda = server_cfg.get('cuda', False)
        self.dtype = DTYPE_MAPPING[server_cfg.get("dtype", "float64")]
        # Stored but not used inside this class; presumably unit-conversion
        # factors consumed by callers -- NOTE(review): confirm their meaning.
        self.Hartree_in_E = server_cfg.get("Hartree_in_E", 1)
        self.Bohr_in_R = server_cfg.get("Bohr_in_R", Bohr)
        # CUDA is used only when it was requested AND is available.
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() and self.cuda else "cpu"
        )
        # single ff simulation
        self.model = model.to(self.device).type(self.dtype)
        # Module.to/.type return the module itself, so self.model is `model`.
        # Loading into self.model makes it explicit that the served instance
        # receives the weights.
        _load_state_dict(self.model, self.device, model_path, inference=True)
        self.model.eval()
        self.calculator = None
        self.out_dir = out_dir
        self.transform = transform

    def calculate(self, info):
        """Predict the properties of one structure.

        Args:
            info: Mapping whose optional ``"features"`` entry holds the input
                features of a single structure. ``features["Ra"]`` must be
                present. ``features["N"]`` (presumably the atom count) is set
                to ``len(features["Ra"])`` when it is missing or ``None``.
                NOTE: ``features`` is modified in place. ``N`` may be filled
                in, and with the ``"full"`` neighbor list ``idx_i``,
                ``idx_j`` and ``N_pair`` are written.

        Returns:
            The decorated model output after ``transform.inverse_transform``,
            or an empty dict when ``info`` carries no features.
        """
        features = info.get("features", None)
        if features is None:
            return {}
        # `.get` treats an absent "N" key like an explicit None. Direct
        # indexing raised KeyError on a plain dict, or returned an empty
        # addict Dict instead of None.
        if features.get("N") is None:
            features["N"] = len(features["Ra"])
        # Only the "full" neighbor list is built here. For any other type,
        # idx_i/idx_j are whatever `features` already contains.
        if self.neighbor_list_type == "full":
            idx_i, idx_j = full_neighbor_list(features["N"])
            features["idx_i"] = idx_i
            features["idx_j"] = idx_j
            features["N_pair"] = len(idx_i)
        # Wrap the single structure as a batch of one, with no targets.
        net_input, _ = _decorate_batch_input(
            batch=[(features, None)], device=self.device, dtype=self.dtype
        )
        net_input, _ = _to_device((net_input, {}), self.device)
        # Not wrapped in torch.no_grad(). Presumably the model needs autograd
        # (e.g. derivatives w.r.t. positions). TODO confirm before adding it.
        net_output = self.model(net_input)
        output, _ = _decorate_batch_output(
            output=net_output, features=net_input, targets=None
        )
        # The return value is ignored, so inverse_transform is expected to
        # update `output` in place.
        self.transform.inverse_transform(output)
        return output