"""Source code for enerzyme.models.layers.zbl: a ZBL-inspired short-range repulsion layer."""
from typing import Optional
import torch
from torch import Tensor
from torch.nn import Parameter, init
import torch.nn.functional as F
from torch_scatter import segment_sum_coo
from . import BaseFFLayer
from ..functional import softplus_inverse
from ..cutoff import CUTOFF_REGISTER, CUTOFF_KEY_TYPE
class ZBLRepulsionEnergyLayer(BaseFFLayer):
    """
    Short-range repulsive potential with learnable parameters inspired by the
    Ziegler-Biersack-Littmark (ZBL) potential described in Ziegler, J.F.,
    Biersack, J.P., and Littmark, U., "The stopping and range of ions in
    solids".

    For a pair of atoms i, j at distance r_ij the screened Coulomb energy is

        E_ij = ke * Z_i * Z_j / r_ij * sum_k c_k * exp(-A_k * a_ij * r_ij) * f_cut(r_ij)
        a_ij = (Z_i ** p + Z_j ** p) * d

    where c_k (normalized to sum to 1), A_k, p and d are learnable and kept
    positive through a softplus. The per-atom energy is half the sum over the
    pairs of atom i, because every pair is listed both as ij and ji.

    Arguments:
        Bohr_in_R (float):
            Bohr radius expressed in the chosen length unit (default value
            corresponds to lengths in Angstrom).
        Hartree_in_E (float):
            Hartree expressed in the chosen energy unit (default value 1,
            i.e. energies in Hartree). The Coulomb constant used is
            ke = Bohr_in_R * Hartree_in_E.
        cutoff_sr (float, optional):
            Cutoff radius forwarded to the cutoff function as ``cutoff=``.
        cutoff_fn (CUTOFF_KEY_TYPE, optional):
            Key into ``CUTOFF_REGISTER`` selecting the cutoff function. If
            None, ``cutoff_values_sr`` must be passed to ``get_E_zbl_a``.
    """

    # Names of the learnable scalars, all stored in softplus-inverse space.
    _PARAM_NAMES = (
        "_adiv", "_apow",
        "_c1", "_c2", "_c3", "_c4",
        "_a1", "_a2", "_a3", "_a4",
    )

    def __init__(
        self, Bohr_in_R: float=0.5291772105638411, Hartree_in_E: float=1, cutoff_sr: Optional[float]=None,
        cutoff_fn: CUTOFF_KEY_TYPE=None
    ) -> None:
        """ Initializes the ZBLRepulsionEnergy class. """
        super().__init__(output_fields={"E_zbl_a"})
        self.a0 = Bohr_in_R
        # ke equals 1 Hartree * Bohr in atomic units. The extra factor 0.5
        # compensates for every pair being listed twice (ij and ji).
        self.kehalf = 0.5 * Bohr_in_R * Hartree_in_E
        # Always define both attributes so a missing cutoff function is
        # reported clearly in get_E_zbl_a instead of as an AttributeError.
        self.cutoff_fn = CUTOFF_REGISTER[cutoff_fn] if cutoff_fn is not None else None
        self.cutoff_sr = cutoff_sr
        for name in self._PARAM_NAMES:
            # Allocated uninitialised; reset_parameters fills in the values.
            self.register_parameter(name, Parameter(torch.empty(1)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize parameters to the default ZBL potential.

        Each value is passed through ``softplus_inverse`` so that applying
        ``softplus`` in ``get_E_zbl_a`` recovers the standard ZBL constant.
        """
        zbl_defaults = (
            # screening length 0.8854 * a0 (inverted)
            ("_adiv", 1 / (0.8854 * self.a0)),
            # exponent applied to the nuclear charges
            ("_apow", 0.23),
            # weights of the four exponentials
            ("_c1", 0.18180),
            ("_c2", 0.50990),
            ("_c3", 0.28020),
            ("_c4", 0.02817),
            # decay rates of the four exponentials
            ("_a1", 3.20000),
            ("_a2", 0.94230),
            ("_a3", 0.40280),
            ("_a4", 0.20160),
        )
        for name, value in zbl_defaults:
            init.constant_(getattr(self, name), softplus_inverse(value))

    def get_E_zbl_a(
        self,
        Za: Tensor,
        Dij_sr: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
        cutoff_values_sr: Optional[Tensor]=None,
    ) -> Tensor:
        """
        Evaluate the short-range repulsive potential.

        N: Number of atoms.
        P: Number of atom pairs.

        Arguments:
            Za (Tensor [N]):
                Nuclear charges of atoms (converted to the parameters' dtype).
            Dij_sr (FloatTensor [P]):
                Pairwise interatomic distances; must be non-zero, since the
                energy contains 1 / Dij_sr.
            idx_i_sr (LongTensor [P]):
                Index of atom i for all atomic pairs ij. Each pair must be
                specified as both ij and ji, and the indices must be sorted
                in ascending order (required by ``segment_sum_coo``).
            idx_j_sr (LongTensor [P]):
                Same as idx_i_sr, but for atom j.
            cutoff_values_sr (FloatTensor [P], optional):
                Values of a cutoff function for the distances Dij_sr. If
                None, they are computed with the layer's cutoff function.

        Returns:
            e (FloatTensor [N]):
                Atomic contributions to the total repulsive energy.

        Raises:
            ValueError: if ``cutoff_values_sr`` is None and the layer was
                constructed without a ``cutoff_fn``.
        """
        if cutoff_values_sr is None:
            if self.cutoff_fn is None:
                raise ValueError(
                    "cutoff_values_sr must be provided when "
                    "ZBLRepulsionEnergyLayer was constructed without cutoff_fn"
                )
            cutoff_values_sr = self.cutoff_fn(Dij_sr, cutoff=self.cutoff_sr)
        # calculate ZBL parameters
        Zf = Za.type_as(self._a1)
        z = Zf ** F.softplus(self._apow)
        a = (z[idx_i_sr] + z[idx_j_sr]) * F.softplus(self._adiv)
        a1 = F.softplus(self._a1) * a
        a2 = F.softplus(self._a2) * a
        a3 = F.softplus(self._a3) * a
        a4 = F.softplus(self._a4) * a
        c1 = F.softplus(self._c1)
        c2 = F.softplus(self._c2)
        c3 = F.softplus(self._c3)
        c4 = F.softplus(self._c4)
        # normalize c coefficients (necessary to get asymptotically correct
        # behaviour for r -> 0, where the energy must approach bare Coulomb)
        csum = c1 + c2 + c3 + c4
        c1 = c1 / csum
        c2 = c2 / csum
        c3 = c3 / csum
        c4 = c4 / csum
        # compute interactions
        zizj = Zf[idx_i_sr] * Zf[idx_j_sr]
        f = (
            c1 * torch.exp(-a1 * Dij_sr)
            + c2 * torch.exp(-a2 * Dij_sr)
            + c3 * torch.exp(-a3 * Dij_sr)
            + c4 * torch.exp(-a4 * Dij_sr)
        ) * cutoff_values_sr
        # Accumulate pair energies onto atom i. kehalf already halves the
        # double-counted ij/ji pairs.
        return segment_sum_coo(self.kehalf * f * zizj / Dij_sr, idx_i_sr, dim_size=len(Za))