import os.path as osp
from typing import Dict, Optional
from torch.nn import Module
import numpy as np
import ase
import ase.io
import torch
from copy import copy
from typing import Literal
from itertools import combinations
from functools import partial
from ase.constraints import FixBondLengths
from ase.units import Hartree, fs, kB, Bohr
from ..data.transform import Transform
from ..utils import logger
from .trainer import DTYPE_MAPPING, _load_state_dict
def get_optimizer(
    optimizer_name: Literal["BFGS", "LBFGS", "MDMin", "FIRE", "GPMin", "BFGSLineSearch", "LBFGSLineSearch", "odesolver", "static"],
    neb: bool = False,
):
    """Return the ASE optimizer class (or factory) selected by ``optimizer_name``.

    The returned object is meant to be called as ``factory(atoms_or_neb, **kwargs)``.

    Args:
        optimizer_name: One of the ``ase.optimize`` classes listed in the
            signature, or -- only when ``neb`` is true -- ``"odesolver"`` or
            ``"static"``, which map to ``ase.mep.neb.NEBOptimizer`` with
            ``method="ODE"`` and ``method="static"`` respectively.
        neb: Whether the optimizer will drive an NEB band; enables the two
            ``NEBOptimizer`` names above.

    Returns:
        An ``ase.optimize`` optimizer class, or a ``functools.partial`` of
        ``NEBOptimizer`` with ``method`` already bound.

    Raises:
        ValueError: If the name is not recognised (this includes the
            NEB-only names when ``neb`` is false).
    """
    ase_optimizers = (
        "BFGS",
        "LBFGS",
        "MDMin",
        "FIRE",
        "GPMin",
        "BFGSLineSearch",
        "LBFGSLineSearch",
    )
    if optimizer_name in ase_optimizers:
        # Imported lazily so ASE optimizers are only loaded when requested.
        from ase import optimize as ase_optimize

        return getattr(ase_optimize, optimizer_name)

    if not neb:
        raise ValueError(f"Optimizer {optimizer_name} not supported")

    neb_methods = {"odesolver": "ODE", "static": "static"}
    if optimizer_name not in neb_methods:
        raise ValueError(f"NEB optimizer {optimizer_name} not supported")

    from ase.mep.neb import NEBOptimizer

    return partial(NEBOptimizer, method=neb_methods[optimizer_name])
class Simulation:
    """Run a single-force-field simulation task through ASE.

    ``__init__`` reads the ``Simulation``/``System`` config sections, loads the
    model weights and calls ``_init_<environment>_env`` (only ``ase`` is
    defined here). :meth:`run` then dispatches to ``_run_<task>``. The tasks
    defined here are ``sp``, ``opt``, ``scan``, ``md``, ``neb`` and ``plumed``.
    """

    def __init__(
        self,
        config: Dict,
        model: Module,
        model_path: str,
        out_dir: str,
        transform: "Transform",
        calculator_patch_module: Optional[object] = None,
        plumed_patch_module: Optional[object] = None,
    ) -> None:
        """Read simulation settings from ``config`` and prepare the environment.

        Args:
            config: Configuration with ``Simulation`` and ``System`` sections.
                Attribute access and ``.get`` are used on it (presumably an
                OmegaConf-style object -- TODO confirm).
            model: Network whose weights are loaded from ``model_path``.
            model_path: Checkpoint path passed to ``_load_state_dict``.
            out_dir: Directory that all output files are written to.
            transform: Data transform handed to the ASE calculator.
            calculator_patch_module: Optional module providing the external
                calculator named in ``Simulation.external_calculator.name``.
            plumed_patch_module: Optional module providing the PLUMED input
                generator named in ``Simulation.plumed_config_generator.name``.

        Raises:
            ValueError: If ``Simulation.environment`` has no matching
                ``_init_<environment>_env`` method.
        """
        sim = config.Simulation
        self.environment = sim.environment
        self.task = sim.task
        self.structure_file = config.System.structure_file
        self.charge = config.System.charge
        self.multiplicity = config.System.multiplicity
        self.transform = transform
        self.external_calculator_config = sim.get("external_calculator", dict())
        self.uncertainty_calculator_config = sim.get("uncertainty_calculator", None)
        self.plumed_config_generator_config = sim.get("plumed_config_generator", None)
        self.internal_calculator_weight = sim.get("internal_calculator_weight", 1.0)
        # Subtracted from every atom index given in the config (1 => 1-based input).
        self.idx_start_from = sim.get("idx_start_from", 1)
        self.integrate_config = sim.get("integrate", None)
        self.initialize_config = sim.get("initialize", None)
        self.constraint_config = sim.get("constraint", None)
        self.sampling_config = sim.get("sampling", None)
        self.optimize_config = sim.get("optimize", None)
        # Integrator time step is ``time_step * fs / fs_in_t``.
        self.fs_in_t = sim.get("fs_in_t", 1)
        self.log_interval = sim.get("log_interval", 20)
        self.neighbor_list_type = sim.get("neighbor_list", "full")
        self.cuda = sim.get("cuda", False)
        self.dtype = DTYPE_MAPPING[sim.get("dtype", "float64")]
        self.device = torch.device("cuda:0" if torch.cuda.is_available() and self.cuda else "cpu")
        # Single force-field simulation. Module.to/.type return the module itself,
        # so self.model is the same object as ``model``.
        self.model = model.to(self.device).type(self.dtype)
        _load_state_dict(self.model, self.device, model_path, inference=True)
        self.model.eval()
        self.calculator = None
        # Always defined: scan/NEB read it even when no constraints are configured.
        self.constraints = []
        self.out_dir = out_dir
        self.calculator_patch_module = calculator_patch_module
        self.plumed_patch_module = plumed_patch_module

        init_env = getattr(self, f"_init_{self.environment}_env", None)
        if init_env is None:
            raise ValueError(f"Unsupported simulation environment: {self.environment}")
        init_env()

    def _init_ase_env(self):
        """Load the structures and attach the calculator and constraints to each.

        Every frame of ``System.structure_file`` gets ``spin``/``charge`` in
        ``atoms.info``. Each frame also gets its own shallow copy of the
        :class:`ASECalculator` and the configured constraints. ``self.system``
        is the last frame.
        """
        from .calculator import ASECalculator

        self.initial_structures = ase.io.read(self.structure_file, ":")
        for atoms in self.initial_structures:
            atoms.info.update({"spin": self.multiplicity, "charge": self.charge})
        # Shallow list copy: the Atoms objects are shared with initial_structures.
        self.systems = self.initial_structures.copy()
        self.system = self.systems[-1]

        self.external_calculator = self._load_external_calculator()
        self.calculator = ASECalculator(
            model=self.model,
            device=self.device,
            dtype=self.dtype,
            neighbor_list_type=self.neighbor_list_type,
            transform=self.transform,
            internal_calculator_weight=self.internal_calculator_weight,
            external_calculator=self.external_calculator,
            external_calculator_config=self.external_calculator_config,
            uncertainty_calculator_config=self.uncertainty_calculator_config,
        )
        # Built after the calculator: Hookean springs need its Hartree_in_E.
        self.constraints = self._build_constraints()
        for system in self.systems:
            system.calc = copy(self.calculator)
            # A fresh list per system so that no two Atoms alias one list object.
            system.set_constraint(list(self.constraints))

    def _load_external_calculator(self):
        """Return the external calculator named in the config, or None without a patch module.

        Raises:
            ValueError: If a patch module is given but the name is missing or
                not an attribute of that module.
        """
        if self.calculator_patch_module is None:
            return None
        name = self.external_calculator_config.get("name", None)
        if name is None:
            raise ValueError("External calculator name not specified!")
        if not hasattr(self.calculator_patch_module, name):
            raise ValueError(f"External calculator {name} not found in {self.calculator_patch_module}")
        external_calculator = getattr(self.calculator_patch_module, name)
        logger.info(f"Initialized external calculator: {name}")
        return external_calculator

    def _build_constraints(self):
        """Translate ``Simulation.constraint`` into a list of ASE constraints.

        Supported entries:
          * ``fix_atom``: ``FixAtoms`` on ``indices``.
          * ``Hookean_allpairs``: a ``Hookean`` spring between every pair of
            ``indices``. The rest length is the pair's current distance in
            ``self.system``. The force constant is ``k * Hartree / Hartree_in_E``.
        Config indices are shifted by ``idx_start_from`` before use. Unknown
        entries are logged and ignored.
        """
        constraints = []
        if self.constraint_config is None:
            return constraints
        offset = self.idx_start_from
        for name, params in self.constraint_config.items():
            if name == "fix_atom":
                from ase.constraints import FixAtoms

                constraints.append(FixAtoms(indices=[idx - offset for idx in params.indices]))
            elif name == "Hookean_allpairs":
                from ase.constraints import Hookean

                spring_k = params.k * Hartree / self.calculator.Hartree_in_E
                for i, j in combinations(params.indices, 2):
                    # Atoms AND rest length must both use the shifted (0-based) indices.
                    a1, a2 = i - offset, j - offset
                    constraints.append(
                        Hookean(a1, a2, k=spring_k, rt=self.system.get_distance(a1, a2))
                    )
            else:
                logger.warning(f"Unknown constraint type {name!r} in Simulation.constraint is ignored")
        return constraints

    def _fmax(self, scale=1.0):
        """Return the optimizer force threshold passed to ASE ``run(fmax=...)``.

        The raw value is ``optimize.fmax`` (4.5e-4 when the option or the whole
        ``optimize`` section is missing). It is divided by the calculator's
        ``Hartree_in_E``, multiplied by ``Hartree / Bohr`` from ``ase.units``,
        and then multiplied by ``scale``.
        """
        raw = 4.5e-4
        if self.optimize_config is not None:
            raw = self.optimize_config.get("fmax", raw)
        return raw / self.calculator.Hartree_in_E * Hartree / Bohr * scale

    def _xyz_writer(self, filename):
        """Return an observer that appends its ``atoms`` argument to ``out_dir/filename``."""
        path = osp.join(self.out_dir, filename)

        def write_xyz(atoms=None):
            ase.io.write(path, atoms, append=True)

        return write_xyz

    def _run_sp(self):
        """Compute single-point results for every system and append them to ``sp.xyz``."""
        for system in self.systems:
            # Drop result keys carried over from the input file (presumably so
            # only freshly computed values are written).
            for key in ("energy", "forces", "dipole", "charges"):
                system.info.pop(key, None)
            system.get_potential_energy()
            ase.io.write(osp.join(self.out_dir, "sp.xyz"), system, append=True)

    def _run_opt(self):
        """Optimize ``self.system``.

        Every step is written to ``traj-opt.xyz`` and the result is appended
        to ``optim.xyz``.
        """
        logger.info(f"Running optimization with {self.optimize_config.optimizer} optimizer")
        optimizer = get_optimizer(self.optimize_config.optimizer)(self.system)
        optimizer.attach(self._xyz_writer("traj-opt.xyz"), interval=1, atoms=self.system)
        optimizer.run(fmax=self._fmax())
        ase.io.write(osp.join(self.out_dir, "optim.xyz"), self.system, append=True)
        logger.info(f"Final energy: {self.system.get_potential_energy()}")

    def _run_scan(self):
        """Run a relaxed scan of the distance between atoms ``i0`` and ``i1``.

        ``num`` evenly spaced distances between ``x0`` and ``x1`` are visited
        in turn. At each distance:
          * the pair is moved to that distance;
          * the pair is frozen with ``FixBondLengths``, on top of the
            configured constraints;
          * the structure is optimized;
          * the trajectory is written to ``traj-<step>.xyz`` and the final
            structure is appended to ``scan_optim.xyz``.

        Raises:
            ValueError: If ``sampling.cv`` is not ``"distance"``.
        """
        if self.sampling_config.cv != "distance":
            raise ValueError(f"Scan collective variable {self.sampling_config.cv} not supported")
        params = self.sampling_config.params
        i0 = params.i0 - self.idx_start_from
        i1 = params.i1 - self.idx_start_from
        for step, target in enumerate(np.linspace(params.x0, params.x1, params.num)):
            # Reset constraints so the previous step's bond constraint is removed.
            del self.system.constraints
            logger.info(f"Setting distance to: {target} (current: {self.system.get_distance(i0, i1)})")
            self.system.set_distance(i0, i1, target)
            bond = FixBondLengths([(i0, i1)], bondlengths=[target], tolerance=1e-6)
            self.system.set_constraint(self.constraints + [bond])
            optimizer = get_optimizer(self.optimize_config.optimizer)(self.system)
            optimizer.attach(self._xyz_writer(f"traj-{step}.xyz"), interval=1, atoms=self.system)
            optimizer.run(fmax=self._fmax())
            ase.io.write(osp.join(self.out_dir, "scan_optim.xyz"), self.system, append=True)
            logger.info(f"Final energy: {self.system.get_potential_energy()}")
            logger.info(f"Final distance: {self.system.get_distance(i0, i1)}")

    def _get_integrator(self, traj_file):
        """Build the MD integrator for ``self.system``.

        Only Langevin is supported; the name is matched case-insensitively.
        Every ``log_interval`` steps a constraint-free copy of the system is
        appended to ``out_dir/traj_file``.

        Raises:
            ValueError: For any other ``integrate.integrator``.
        """
        integrator_name = self.integrate_config.integrator
        if integrator_name.lower() != "langevin":
            raise ValueError(f"Integrator {integrator_name} not supported")
        from ase.md.langevin import Langevin

        dyn = Langevin(
            self.system,
            timestep=self.integrate_config.time_step * fs / self.fs_in_t,
            temperature_K=self.integrate_config.temperature_in_K,
            friction=self.integrate_config.friction,
            logfile="-",
            loginterval=self.log_interval,
        )
        path = osp.join(self.out_dir, traj_file)

        def write_xyz(atoms=None):
            snapshot = atoms.copy()
            snapshot.constraints = []
            ase.io.write(path, snapshot, append=True)

        dyn.attach(write_xyz, interval=self.log_interval, atoms=self.system)
        return dyn

    def _md_initialize(self):
        """Draw Maxwell-Boltzmann velocities for ``self.system``.

        Uses ``initialize.temperature_in_K``, falling back to
        ``integrate.temperature_in_K``.
        """
        from ase.md.velocitydistribution import MaxwellBoltzmannDistribution

        temperature = self.initialize_config.get("temperature_in_K", self.integrate_config.temperature_in_K)
        MaxwellBoltzmannDistribution(self.system, temperature_K=temperature)

    def _run_md(self) -> None:
        """Run ``integrate.n_step`` MD steps, writing frames to ``md.traj.xyz``."""
        if self.initialize_config is not None:
            self._md_initialize()
        dyn = self._get_integrator(traj_file="md.traj.xyz")
        dyn.run(self.integrate_config.n_step)

    def _build_neb_images(self, num_images):
        """Assemble the NEB image list from the loaded structures.

        Returns ``(images, mode)``, where ``mode`` is one of:
          * ``"reactant_product"``: exactly 2 structures were given. The list
            is ``num_images - 1`` reactant copies followed by a product copy.
          * ``None``: exactly ``num_images`` structures were given and
            ``self.systems`` is used as-is.
          * ``"reactant_ts_product"``: exactly 3 structures were given. The
            TS guess sits at index ``num_images // 2``, with reactant copies
            before it and product copies after it.
        Newly built images get a calculator copy and the configured constraints.

        Raises:
            ValueError: For any other number of structures.
        """
        structures = self.initial_structures
        if len(structures) == 2:
            reactant, product = structures
            images = [reactant.copy() for _ in range(num_images - 1)] + [product.copy()]
            mode = "reactant_product"
        elif len(structures) == num_images:
            return self.systems, None
        elif len(structures) == 3:
            reactant, ts_guess, product = structures
            middle_idx = num_images // 2
            images = (
                [reactant.copy() for _ in range(middle_idx)]
                + [ts_guess.copy()]
                + [product.copy() for _ in range(num_images - middle_idx - 1)]
            )
            mode = "reactant_ts_product"
        else:
            raise ValueError(
                f"Number of initial structures {len(structures)} is not compatible with the number of images {num_images}"
            )
        for image in images:
            image.calc = copy(self.calculator)
            image.set_constraint(list(self.constraints))
        return images, mode

    def _interpolate_neb_segment(self, segment):
        """Interpolate the interior images of ``segment`` between its two endpoints.

        Segments with fewer than 3 images have no interior image and are left
        untouched.
        """
        if len(segment) < 3:
            return
        from ase.mep.neb import NEB

        interpolation = self.sampling_config.params.interpolation
        NEB(segment).interpolate(
            method=interpolation.method,
            apply_constraint=interpolation.apply_constraint,
        )

    def _relax_neb_endpoints(self, images):
        """Optimize the first and last image, logging each to its own xyz file."""
        logger.info("Relaxing endpoints")
        for endpoint, filename in (
            (images[0], "neb-relax-reactant.xyz"),
            (images[-1], "neb-relax-product.xyz"),
        ):
            optimizer = get_optimizer(self.optimize_config.optimizer)(endpoint)
            optimizer.attach(self._xyz_writer(filename), interval=1, atoms=endpoint)
            optimizer.run(fmax=self._fmax())

    def _optimize_neb(self, neb, optimizer_name, images, prefix, fmax):
        """Optimize the band ``neb`` and write its outputs.

        Files written under ``out_dir``:
          * ``<prefix>.traj``: the optimizer trajectory.
          * ``<prefix>-<i>.xyz``: per-image history.
          * ``<prefix>.xyz``: a snapshot of the whole band.
        """
        optimizer = get_optimizer(optimizer_name, neb=True)(
            neb, trajectory=osp.join(self.out_dir, f"{prefix}.traj")
        )
        band_path = osp.join(self.out_dir, f"{prefix}.xyz")

        def write_xyz(images=None):
            for i, image in enumerate(images):
                ase.io.write(osp.join(self.out_dir, f"{prefix}-{i}.xyz"), image, append=True)
            ase.io.write(band_path, images, append=False)

        optimizer.attach(write_xyz, interval=1, images=images)
        optimizer.run(fmax=fmax)
        ase.io.write(band_path, images, append=False)

    def _run_neb(self):
        """Run a nudged elastic band calculation, optionally followed by climbing-image NEB.

        With ``sampling.params.climb`` true, the plain band is first converged
        to twice the force threshold. Climbing is then switched on and the band
        is converged to the full threshold.

        Raises:
            ValueError: If ``num_images`` <= 2 or the structure count does not
                fit any supported layout.
        """
        params = self.sampling_config.params
        num_images = params.num_images
        if num_images <= 2:
            raise ValueError("Number of images must be greater than 2")
        from ase.mep.neb import NEB, NEBTools

        images, mode = self._build_neb_images(num_images)
        if params.get("relax_endpoints", True):
            self._relax_neb_endpoints(images)

        if mode == "reactant_product":
            logger.info(f"Interpolating {num_images} images between reactant and product")
            self._interpolate_neb_segment(images)
        elif mode == "reactant_ts_product":
            logger.info(f"Interpolating {num_images} images between reactant, guessed TS and product")
            middle_idx = num_images // 2
            # Both halves share the TS guess as an endpoint, so it stays fixed.
            self._interpolate_neb_segment(images[: middle_idx + 1])
            self._interpolate_neb_segment(images[middle_idx:])

        neb = NEB(
            images,
            k=params.spring_constants / self.calculator.Hartree_in_E * Hartree / (Bohr ** 2),
            climb=False,
            allow_shared_calculator=True,
        )
        neb_tools = NEBTools(neb.images)
        climb = params.get("climb", False)

        # Plain NEB (pre-converged more loosely when CI-NEB follows).
        neb_optimizer_name = params.get("neb_optimizer", "odesolver")
        self._optimize_neb(neb, neb_optimizer_name, images, "neb", self._fmax(2 if climb else 1))
        barrier, dE = neb_tools.get_barrier()
        logger.info(f"NEB barrier: {barrier}, dE: {dE}")

        if climb:
            neb.climb = True
            ci_neb_optimizer_name = params.get("ci_neb_optimizer", neb_optimizer_name)
            self._optimize_neb(neb, ci_neb_optimizer_name, images, "ci-neb", self._fmax())
            barrier, dE = neb_tools.get_barrier()
            logger.info(f"CI-NEB barrier: {barrier}, dE: {dE}")

    def _run_plumed(self):
        """Run PLUMED-biased MD on ``self.system``.

        The PLUMED input is taken from one of two places:
          * the generator in ``plumed_patch_module`` (also written to
            ``plumed.dat``);
          * ``sampling.params.plumed_config`` directly, when no patch module
            is given.
        The ``Plumed`` calculator is attached to ``self.system`` for the whole
        run and finalized afterwards.

        Raises:
            ValueError: If a patch module is given but the generator name is
                missing or not found in it.
        """
        if self.plumed_patch_module is not None:
            generator_config = self.plumed_config_generator_config or {}
            generator_name = generator_config.get("name", None)
            if generator_name is None:
                raise ValueError("Plumed config generator name not specified!")
            if not hasattr(self.plumed_patch_module, generator_name):
                raise ValueError(
                    f"Plumed config generator {generator_name} not found in {self.plumed_patch_module}"
                )
            self.plumed_config_generator = getattr(self.plumed_patch_module, generator_name)
            logger.info(f"Initialized plumed config generator: {generator_name}")
            plumed_config = self.plumed_config_generator(
                self.system,
                integrate_config=self.integrate_config,
                idx_start_from=self.idx_start_from,
                **self.sampling_config.params.plumed_config,
            )
            with open(osp.join(self.out_dir, "plumed.dat"), "w") as f:
                f.writelines([line + "\n" for line in plumed_config])
        else:
            plumed_config = self.sampling_config.params.plumed_config

        from ase.calculators.plumed import Plumed

        plumed_calc = Plumed(
            self.calculator,
            plumed_config,
            timestep=self.integrate_config.time_step * fs / self.fs_in_t,
            atoms=self.system,
            kT=self.integrate_config.temperature_in_K * kB,
            log=osp.join(self.out_dir, "plumed.log"),
            restart=False,
        )
        # The context manager finalizes PLUMED (flushing its output) when the run ends.
        with plumed_calc:
            # Without this the dynamics would use the unbiased calculator.
            self.system.calc = plumed_calc
            if self.initialize_config is not None:
                self._md_initialize()
            dyn = self._get_integrator(traj_file="plumed.traj.xyz")
            dyn.run(self.integrate_config.n_step)

    def run(self):
        """Run the configured task by dispatching to ``_run_<task>``.

        Raises:
            ValueError: If no ``_run_<task>`` method exists for ``Simulation.task``.
        """
        task_runner = getattr(self, f"_run_{self.task}", None)
        if task_runner is None:
            raise ValueError(f"Unsupported simulation task: {self.task}")
        task_runner()