Source code for enerzyme.models.layers.dispersion.grimme_d4

import os, math
from typing import Optional
import torch
from torch import Tensor
from torch.nn import Parameter
import torch.nn.functional as F
from .. import BaseFFLayer
from ...functional import softplus_inverse, segment_sum_coo
from ...cutoff import smooth_transition


"""
computes D4 dispersion energy
HF      s6=1.00000000, s8=1.61679827, a1=0.44959224, a2=3.35743605
"""


[docs] class GrimmeD4EnergyLayer(BaseFFLayer):
[docs] def __init__( self, cutoff: Optional[float] = None, s6: float = 1.00000000, s8: float = 1.61679827, a1: float = 0.44959224, a2: float = 3.35743605, g_a: float = 3.0, g_c: float = 2.0, k2: float = 1.3333333333333333, # 4/3 k4: float = 4.10451, k5: float = 19.08857, k6: float = 254.5553148552, # 2*11.28174**2 kn: float = 7.5, wf: float = 6.0, max_Za: int = 87, Hartree_in_E: float=1, Bohr_in_R: float=0.5291772108 ) -> None: """ Initializes the D4DispersionEnergy class. """ super().__init__(output_fields={"E_disp_a"}) # Grimme's D4 dispersion is only parametrized up to Rn (Z=86) assert max_Za <= 87 # trainable parameters self.register_parameter( "_s6", Parameter(softplus_inverse(s6), requires_grad=False) ) # s6 is usually not fitted (correct long-range) self.register_parameter( "_s8", Parameter(softplus_inverse(s8), requires_grad=True) ) self.register_parameter( "_a1", Parameter(softplus_inverse(a1), requires_grad=True) ) self.register_parameter( "_a2", Parameter(softplus_inverse(a2), requires_grad=True) ) self.register_parameter( "_scaleq", Parameter(softplus_inverse(1.0), requires_grad=True) ) # for scaling charges of reference systems # D4 constants self.Zmax = max_Za self.Hartree_in_E = Hartree_in_E self.Bohr_in_R = Bohr_in_R # factor of 0.5 prevents double counting self.set_cutoff(cutoff) self.g_a = g_a self.g_c = g_c self.k2 = k2 self.k4 = k4 self.k5 = k5 self.k6 = k6 self.kn = kn self.wf = wf # load D4 data directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "grimme_d4_tables") self.register_buffer( "refsys", # [Zmax,max_nref] torch.load(os.path.join(directory, "refsys.pth"))[:max_Za], ) self.register_buffer( "zeff", torch.load(os.path.join(directory, "zeff.pth"))[:max_Za] # [Zmax] ) self.register_buffer( "refh", # [Zmax,max_nref] torch.load(os.path.join(directory, "refh.pth"))[:max_Za], ) self.register_buffer( "sscale", torch.load(os.path.join(directory, "sscale.pth")) # [18] ) self.register_buffer( "secaiw", torch.load(os.path.join(directory, "secaiw.pth")) # [18,23] ) self.register_buffer( "gam", torch.load(os.path.join(directory, "gam.pth"))[:max_Za] # [Zmax] ) self.register_buffer( "ascale", # [Zmax,max_nref] torch.load(os.path.join(directory, "ascale.pth"))[:max_Za], ) self.register_buffer( "alphaiw", # [Zmax,max_nref,23] torch.load(os.path.join(directory, "alphaiw.pth"))[:max_Za], ) self.register_buffer( "hcount", # [Zmax,max_nref] torch.load(os.path.join(directory, "hcount.pth"))[:max_Za], ) self.register_buffer( "casimir_polder_weights", # [23] torch.load(os.path.join(directory, "casimir_polder_weights.pth"))[:max_Za], ) self.register_buffer( "rcov", torch.load(os.path.join(directory, "rcov.pth"))[:max_Za] # [Zmax] ) self.register_buffer( "en", torch.load(os.path.join(directory, "en.pth"))[:max_Za] # [Zmax] ) self.register_buffer( "ncount_mask", # [Zmax,max_nref,max_ncount] torch.load(os.path.join(directory, "ncount_mask.pth"))[:max_Za], ) self.register_buffer( "ncount_weight", # [Zmax,max_nref,max_ncount] torch.load(os.path.join(directory, "ncount_weight.pth"))[:max_Za], ) self.register_buffer( "cn", # [Zmax,max_nref,max_ncount] torch.load(os.path.join(directory, "cn.pth"))[:max_Za], ) self.register_buffer( "fixgweights", # [Zmax,max_nref] torch.load(os.path.join(directory, "fixgweights.pth"))[:max_Za], ) self.register_buffer( "refq", # [Zmax,max_nref] torch.load(os.path.join(directory, "refq.pth"))[:max_Za], ) self.register_buffer( "sqrt_r4r2", # [Zmax] torch.load(os.path.join(directory, "sqrt_r4r2.pth"))[:max_Za], ) self.register_buffer( "alpha", # [Zmax,max_nref,23] torch.load(os.path.join(directory, "alpha.pth"))[:max_Za], ) self.max_nref = self.refsys.size(-1) self._compute_refc6() self.reset_parameters()
[docs] def reset_parameters(self) -> None: """ For compatibility with other modules. """ pass
[docs] def set_cutoff(self, cutoff: Optional[float] = None) -> None: """ Can be used to change the cutoff. """ if cutoff is None: self.cutoff = None self.cuton = None else: self.cutoff = cutoff / self.Bohr_in_R self.cuton = self.cutoff - self.Bohr_in_R
def _compute_refc6(self) -> None: """ Function to compute the refc6 tensor. Important: If the charges of reference systems are scaled and the scaleq parameter changes (e.g. during training), then the refc6 tensor must be recomputed for correct results. """ with torch.no_grad(): allZ = torch.arange(self.Zmax) is_ = self.refsys[allZ, :] iz = self.zeff[is_] refh = self.refh[allZ, :] * F.softplus(self._scaleq) qref = iz qmod = iz + refh ones_like_qmod = torch.ones_like(qmod) qmod_ = torch.where(qmod > 1e-8, qmod, ones_like_qmod) alpha = ( self.sscale[is_].view(-1, self.max_nref, 1) * self.secaiw[is_] * torch.where( qmod > 1e-8, torch.exp( self.g_a * (1 - torch.exp(self.gam[is_] * self.g_c * (1 - qref / qmod_))) ), math.exp(self.g_a) * ones_like_qmod, ).view(-1, self.max_nref, 1) ) alpha = torch.max( self.ascale[allZ, :].view(-1, self.max_nref, 1) * ( self.alphaiw[allZ, :, :] - self.hcount[allZ, :].view(-1, self.max_nref, 1) * alpha ), torch.zeros_like(alpha), ) alpha_expanded = alpha.view( alpha.size(0), 1, alpha.size(1), 1, -1 ) * alpha.view(1, alpha.size(0), 1, alpha.size(1), -1) self.register_buffer( "refc6", 3.0 / math.pi * torch.sum( alpha_expanded * self.casimir_polder_weights.view(1, 1, 1, 1, -1), -1, ), persistent=False, )
[docs] def get_E_disp_a(self, Za: Tensor, Qa: Tensor, Dij_lr: Tensor, idx_i: Tensor, idx_j: Tensor) -> Tensor: # initialization of Zi/Zj and unit conversion Dij_lr_ = Dij_lr / self.Bohr_in_R # convert distances to Bohr Zi = Za[idx_i] Zj = Za[idx_j] # calculate coordination numbers rco = self.k2 * (self.rcov[Zi] + self.rcov[Zj]) den = self.k4 * torch.exp( -((torch.abs(self.en[Zi] - self.en[Zj]) + self.k5) ** 2) / self.k6 ) tmp = den * 0.5 * (1.0 + torch.erf(-self.kn * (Dij_lr_ - rco) / rco)) if self.cutoff is not None: tmp = tmp * smooth_transition(Dij_lr_, self.cutoff, self.cuton) covcn = segment_sum_coo(tmp, idx_i, dim_size=len(Za)) # calculate gaussian weights gweights = torch.sum( self.ncount_mask[Za] * torch.exp( -self.wf * self.ncount_weight[Za] * (covcn.view(-1, 1, 1) - self.cn[Za]) ** 2 ), -1, ) norm = torch.sum(gweights, -1, True) # norm_ is used to prevent nans in backwards pass norm_ = torch.where(norm > 1e-8, norm, torch.ones_like(norm)) gweights = torch.where(norm > 1e-8, gweights / norm_, self.fixgweights[Za]) # calculate dispersion energy iz = self.zeff[Za].view(-1, 1) refq = self.refq[Za] * F.softplus(self._scaleq) qref = iz + refq qmod = iz + Qa.view(-1, 1).expand(-1, self.refq.size(1)) ones_like_qmod = torch.ones_like(qmod) qmod_ = torch.where(qmod > 1e-8, qmod, ones_like_qmod) zeta = ( torch.where( qmod > 1e-8, torch.exp( self.g_a * ( 1 - torch.exp( self.gam[Za].view(-1, 1) * self.g_c * (1 - qref / qmod_) ) ) ), math.exp(self.g_a) * ones_like_qmod, ) * gweights ) if zeta.device.type == "cpu": # indexing is faster on CPUs zetai = zeta[idx_i] zetaj = zeta[idx_j] else: # gathering is faster on GPUs zetai = torch.gather(zeta, 0, idx_i.view(-1, 1).expand(-1, zeta.size(1))) zetaj = torch.gather(zeta, 0, idx_j.view(-1, 1).expand(-1, zeta.size(1))) refc6ij = self.refc6[Zi, Zj, :, :] zetaij = zetai.view(zetai.size(0), zetai.size(1), 1) * zetaj.view( zetaj.size(0), 1, zetaj.size(1) ) c6ij = torch.sum((refc6ij * zetaij).view(refc6ij.size(0), -1), -1) sqrt_r4r2ij = math.sqrt(3) * self.sqrt_r4r2[Zi] * self.sqrt_r4r2[Zj] a1 = F.softplus(self._a1) a2 = F.softplus(self._a2) r0 = a1 * sqrt_r4r2ij + a2 if self.cutoff is None: oor6 = 1 / (Dij_lr_ ** 6 + r0 ** 6) oor8 = 1 / (Dij_lr_ ** 8 + r0 ** 8) else: cut2 = self.cutoff ** 2 cut6 = cut2 ** 3 cut8 = cut2 * cut6 tmp6 = r0 ** 6 tmp8 = r0 ** 8 cut6tmp6 = cut6 + tmp6 cut8tmp8 = cut8 + tmp8 tmpc = Dij_lr_ / self.cutoff - 1 oor6 = ( 1 / (Dij_lr_ ** 6 + tmp6) - 1 / cut6tmp6 + 6 * cut6 / cut6tmp6 ** 2 * tmpc ) oor8 = ( 1 / (Dij_lr_ ** 8 + tmp8) - 1 / cut8tmp8 + 8 * cut8 / cut8tmp8 ** 2 * tmpc ) oor6 = torch.where(Dij_lr_ < self.cutoff, oor6, torch.zeros_like(oor6)) oor8 = torch.where(Dij_lr_ < self.cutoff, oor8, torch.zeros_like(oor8)) s6 = F.softplus(self._s6) s8 = F.softplus(self._s8) pairwise = -c6ij * (s6 * oor6 + s8 * sqrt_r4r2ij ** 2 * oor8) * self.Hartree_in_E / 2 edisp = segment_sum_coo(pairwise, idx_i, dim_size=len(Za)) return edisp