Source code for enerzyme.models.layers.electrostatics

from typing import Dict, Literal, Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Module
from . import BaseFFLayer
from ..functional import segment_sum_coo
from ..cutoff import CUTOFF_KEY_TYPE, CUTOFF_REGISTER


class ChargeConservationLayer(BaseFFLayer):
    def __init__(self) -> None:
        r"""
        Shift the predicted atomic charges uniformly within each molecule so
        that they sum up to the requested total charge [1]:

            q^{corrected}_i = q_i - 1 / N (\sum_{j=1}^N q_j - Q)

        References:
        -----
        [1] J. Chem. Theory Comput. 2019, 15, 3678−3693.
        """
        super().__init__()

    def get_output(
        self,
        Za: Tensor,
        Qa: Tensor,
        Q: Optional[Tensor]=None,
        batch_seg: Optional[Tensor]=None
    ) -> Dict[Literal["Qa", "Q"], Tensor]:
        '''
        Enforce charge conservation on the atomic charges

        Params:
        -----
        Za: Long tensor of atomic numbers, shape [N * batch_size]; only used
            as a template for the default batch indices

        Qa: Float tensor of atomic charges, shape [N * batch_size]

        Q: Float tensor of total charges, shape [batch_size]; a total charge
            of zero is assumed when omitted

        batch_seg: Long tensor of batch indices, shape [N * batch_size]; all
            atoms are assigned to batch 0 when omitted

        Returns:
        -----
        A dictionary with two entries:

        "Qa": Float tensor of corrected atomic charges, shape [N * batch_size]

        "Q": Float tensor of the summed atomic charges *before* the
            correction, shape [batch_size]
        '''
        # Without explicit batch indices every atom belongs to molecule 0.
        seg = batch_seg if batch_seg is not None else torch.zeros_like(Za, dtype=torch.long)

        # Atom count of each molecule; the charge excess is spread evenly
        # over this many atoms.
        atom_counts = segment_sum_coo(torch.ones_like(seg), seg)

        # Total charge of each molecule as predicted, prior to any correction.
        raw_Q = segment_sum_coo(Qa, seg)

        # Neutral molecules are assumed if no target charge was supplied.
        target_Q = torch.zeros_like(atom_counts) if Q is None else Q

        # Per-molecule quantities get a trailing singleton axis when the
        # charges carry an extra dimension, so that they broadcast against it.
        if Qa.dim() == 1:
            per_mol_shape = (-1,)
        else:
            per_mol_shape = (-1, 1)

        # Uniform per-atom shift for each molecule, scattered back to atoms.
        shift_per_mol = (target_Q.view(per_mol_shape) - raw_Q) / atom_counts.view(per_mol_shape)
        return {
            "Qa": Qa + shift_per_mol[seg],
            "Q": raw_Q
        }
class ElectrostaticEnergyLayer(BaseFFLayer):
    def __init__(
        self,
        cutoff_sr: float,
        cutoff_lr: Optional[float]=None,
        Bohr_in_R: float=0.5291772108,
        Hartree_in_E: float=1,
        dielectric_constant: float=1,
        cutoff_fn: CUTOFF_KEY_TYPE="smooth",
        flavor: Literal["PhysNet", "SpookyNet"]="SpookyNet"
    ) -> None:
        r"""
        Calculate the electrostatic energy from atomic charges and pair distances

        Coulomb's law at long range and a damped term at short range (which
        avoids the singularity at r = 0) are interpolated by a switch
        function \phi [1]:

            \chi(r) = \phi(r) / \sqrt{r^2 + 1} + (1 - \phi(r)) / r

        Params:
        -----
        cutoff_sr: the short-range cutoff. The switch function is evaluated
            with `cutoff = cutoff_sr / 2, cuton = 0` for the "PhysNet" flavor
            and `cutoff = 0.75 * cutoff_sr, cuton = 0.25 * cutoff_sr` for the
            "SpookyNet" flavor.

        cutoff_lr: the long-range cutoff, outside which the electrostatics
            are ignored. `None` or a non-positive value disables the
            long-range cutoff.

        Bohr_in_R: the numerical value of one Bohr in the unit of atom positions.

        Hartree_in_E: the numerical value of one Hartree in the unit of energy.

        dielectric_constant: the pair energies are divided by this value.

        cutoff_fn: key of the switch function in `CUTOFF_REGISTER`.

        flavor: "PhysNet" or "SpookyNet"; selects the switching window and
            the form of the shielded interaction under a long-range cutoff.

        Raises:
        -----
        ValueError: if `flavor` is neither "PhysNet" nor "SpookyNet".

        References:
        -----
        [1] J. Chem. Theory Comput. 2019, 15, 3678−3693.
        """
        super().__init__(input_fields={"Dij_lr", "Qa", "idx_i", "idx_j"}, output_fields={"E_ele_a"})
        # Half of the Coulomb constant: every pair is expected to show up
        # twice (i->j and j->i) in the per-atom summation.
        self.kehalf = 0.5 * Bohr_in_R * Hartree_in_E
        if flavor == "PhysNet":
            self.cutoff = cutoff_sr / 2
            self.cuton = 0
        elif flavor == "SpookyNet":
            self.cutoff = cutoff_sr * 0.75
            self.cuton = cutoff_sr * 0.25
        else:
            # Fail early: otherwise the layer is left half-initialised and
            # only breaks with an AttributeError at evaluation time.
            raise ValueError(
                f"Unknown flavor {flavor!r}, expected 'PhysNet' or 'SpookyNet'"
            )
        self.cutoff_lr = cutoff_lr
        self.cutoff_fn = CUTOFF_REGISTER[cutoff_fn]
        self.dielectric_constant = dielectric_constant
        if cutoff_lr is not None and cutoff_lr > 0:
            self.cutoff_lr2 = self.cutoff_lr * self.cutoff_lr
            self.two_div_cut = 2.0 / self.cutoff_lr
            if flavor == "PhysNet":
                self.lr_shield = self._simple_lr_shield
            else:
                # Constants that make the shielded SpookyNet interaction
                # vanish at Dij == cutoff_lr (see _smooth_lr_shield).
                self.rcutconstant = self.cutoff_lr / (self.cutoff_lr ** 2 + 1.0) ** 1.5
                self.cutconstant = (2 * self.cutoff_lr ** 2 + 1.0) / (self.cutoff_lr ** 2 + 1.0) ** 1.5
                self.lr_shield = self._smooth_lr_shield

    def _lr_ordinary(self, Dij: Tensor) -> Tensor:
        # Shifted Coulomb term 1/r + r/rc^2 - 2/rc, which is zero at r == rc.
        return 1.0 / Dij + Dij / self.cutoff_lr2 - self.two_div_cut

    def _shield(self, Dij: Tensor) -> Tensor:
        # Damped distance sqrt(r^2 + 1): stays finite (== 1) at r == 0.
        return torch.sqrt(Dij * Dij + 1.0)

    def _simple_lr_shield(self, Dij: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        # PhysNet flavor: the shielded term is the shifted Coulomb term
        # evaluated at the damped distance.
        # Returns (ordinary, shielded, inside-cutoff mask, zeros); both
        # interaction terms are zeroed outside the long-range cutoff.
        Dij_shield = self._shield(Dij)
        zeros = torch.zeros_like(Dij)
        condition = Dij < self.cutoff_lr
        return (
            torch.where(condition, self._lr_ordinary(Dij), zeros),
            torch.where(condition, self._lr_ordinary(Dij_shield), zeros),
            condition,
            zeros
        )

    def _smooth_lr_shield(self, Dij: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        # SpookyNet flavor: the shielded term is 1/sqrt(r^2 + 1) plus a
        # linear correction chosen so that it is zero at r == cutoff_lr.
        # Returns (ordinary, shielded, inside-cutoff mask, zeros); both
        # interaction terms are zeroed outside the long-range cutoff.
        Dij_shield = self._shield(Dij)
        zeros = torch.zeros_like(Dij)
        condition = Dij < self.cutoff_lr
        return (
            torch.where(condition, self._lr_ordinary(Dij), zeros),
            torch.where(condition, 1.0 / Dij_shield + Dij * self.rcutconstant - self.cutconstant, zeros),
            condition,
            zeros
        )

    def get_E_ele_a(self, Dij_lr: Tensor, Qa: Tensor, idx_i: Tensor, idx_j: Tensor) -> Tensor:
        '''
        Compute the atomic electrostatic energy

        Params:
        -----
        Dij_lr: Float tensor of pair distances, shape [N_pair * batch_size]

        Qa: Float tensor of atomic charges, shape [N * batch_size]

        idx_i: Long tensor of the first indices of pairs, shape [N_pair * batch_size]

        idx_j: Long tensor of the second indices of pairs, shape [N_pair * batch_size]

        Returns:
        -----
        Ea: Float tensor of atomic electrostatic energy, shape [N * batch_size]
        '''
        # Charge-product prefactor of every pair. Plain indexing is used on
        # CPU and for charges with a trailing dimension, gather otherwise.
        if Qa.device.type == "cpu" or Qa.dim() > 1:
            fac = self.kehalf * Qa[idx_i] * Qa[idx_j] / self.dielectric_constant
        else:
            fac = self.kehalf * Qa.gather(0, idx_i) * Qa.gather(0, idx_j) / self.dielectric_constant
        # `switch` weights the shielded term, `cswitch` the ordinary one.
        switch = self.cutoff_fn(Dij_lr, self.cutoff, self.cuton)
        cswitch = 1 - switch
        # Per-pair quantities get a trailing singleton axis when the charges
        # carry an extra dimension, so that they broadcast against `fac`.
        view_shape = (-1, 1) if Qa.dim() > 1 else (-1,)
        if self.cutoff_lr is None or self.cutoff_lr <= 0:
            Eele_ordinary = 1.0 / Dij_lr
            Eele_shielded = 1.0 / self._shield(Dij_lr)
            Eele = fac * (switch * Eele_shielded + cswitch * Eele_ordinary).view(view_shape)
        else:
            Eele_ordinary, Eele_shielded, condition, zeros = self.lr_shield(Dij_lr)
            # combine shielded and ordinary interactions and apply prefactors
            Eele = fac * (switch * Eele_shielded + cswitch * Eele_ordinary).view(view_shape)
            # The mask and the zeros are per-pair (1-D); reshape them like the
            # interaction terms so they line up with a 2-D `Eele` as well.
            Eele = torch.where(condition.view(view_shape), Eele, zeros.view(view_shape))
        # Sum the pair energies onto the first atom of each pair.
        return segment_sum_coo(Eele, idx_i, dim_size=len(Qa))
class AtomicCharge2DipoleLayer(BaseFFLayer):
    def __init__(self) -> None:
        """
        Accumulate charge-weighted atomic positions per molecule
        """
        super().__init__(input_fields={"Qa", "Ra", "batch_seg"}, output_fields={"M2"})

    def get_M2(self, Qa: Tensor, Ra: Tensor, batch_seg: Optional[Tensor]=None) -> Tensor:
        '''
        Sum `Qa * Ra` over the atoms of each molecule

        Params:
        -----
        Qa: Float tensor of atomic charges; the first axis runs over atoms,
            an optional second axis is carried through to the output

        Ra: Float tensor of atomic positions, viewable as [number of atoms, 3]

        batch_seg: Long tensor of batch indices, one per atom; all atoms are
            assigned to batch 0 when omitted

        Returns:
        -----
        M2: Float tensor of the summed charge-weighted positions, one row per
            batch index; the second axis has size 3, followed by the extra
            axis of `Qa` if it has one
        '''
        if batch_seg is None:
            # One index per atom, regardless of any trailing axis of Qa
            # (zeros_like(Qa) would copy that axis into the index tensor).
            batch_seg = torch.zeros(Qa.shape[0], dtype=torch.long, device=Qa.device)
        # 1-D charges: [N, 1] * [N, 3]; 2-D charges: [N, 1, K] * [N, 3, 1].
        Pa = Qa.unsqueeze(1) * Ra.view((-1, 3, 1) if Qa.dim() > 1 else (-1, 3))
        return segment_sum_coo(Pa, batch_seg)