import math
from typing import Dict, List, Optional, Tuple, Literal
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, Linear
import torch.nn.functional as F
from .interaction import InteractionModule
from ..blocks.mlp import DenseLayer
from ..layers import DistanceLayer, RangeSeparationLayer, BaseFFCore, BaseAtomEmbedding, BaseElectronEmbedding, BaseRBF, ChargeConservationLayer, GatherAtomEmbedding
from ..activation import ACTIVATION_KEY_TYPE
# Default global build parameters for the SpookyNet model.
# NOTE(review): cutoff_sr equals 10 x 0.5291772105638412, while Bohr_in_R is
# 0.5291772108 -- the two conversion factors differ from the 10th significant
# digit on; confirm this is intended.
DEFAULT_BUILD_PARAMS = {
    "dim_embedding": 64,
    "num_rbf": 16,
    "max_Za": 86,
    "cutoff_sr": 5.291772105638412,
    "Hartree_in_E": 1,
    "Bohr_in_R": 0.5291772108,
    "activation_fn": "swish",
}
# Default layer stack for the SpookyNet model, listed in the order in which
# the entries appear here. Each entry names a layer and, optionally, the
# keyword parameters it is configured with.
DEFAULT_LAYER_PARAMS = [
    {
        "name": "RangeSeparation",
        "params": {"cutoff_fn": "bump"},
    },
    {
        "name": "ExponentialBernsteinRBF",
        "params": {
            "no_basis_at_infinity": False,
            "init_alpha": 0.944863062918464,
            "exp_weighting": False,
            "learnable_shape": True,
        },
    },
    {
        "name": "NuclearEmbedding",
        "params": {"zero_init": True, "use_electron_config": True},
    },
    {
        "name": "ElectronicEmbedding",
        "params": {"num_residual": 1, "attribute": "charge"},
    },
    {
        "name": "ElectronicEmbedding",
        "params": {"num_residual": 1, "attribute": "spin"},
    },
    {
        "name": "Core",
        "params": {
            "num_modules": 3,
            "num_residual_pre": 1,
            "num_residual_local_x": 1,
            "num_residual_local_s": 1,
            "num_residual_local_p": 1,
            "num_residual_local_d": 1,
            "num_residual_local": 1,
            "num_residual_nonlocal_q": 1,
            "num_residual_nonlocal_k": 1,
            "num_residual_nonlocal_v": 1,
            "num_residual_post": 1,
            "num_residual_output": 1,
            "use_irreps": True,
            "dropout_rate": 0.0,
        },
    },
    {
        "name": "AtomicAffine",
        "params": {
            "shifts": {
                "Ea": {"values": 0, "learnable": True},
                "Qa": {"values": 0, "learnable": True},
            },
            "scales": {
                "Ea": {"values": 1, "learnable": True},
                "Qa": {"values": 1, "learnable": True},
            },
        },
    },
    {"name": "ChargeConservation"},
    {"name": "AtomicCharge2Dipole"},
    {"name": "ZBLRepulsionEnergy"},
    {"name": "ElectrostaticEnergy", "params": {"flavor": "SpookyNet"}},
    {"name": "GrimmeD4Energy", "params": {"learnable": True}},
    {"name": "EnergyReduce"},
    {"name": "Force"},
]
# SpookyNet core network.
class SpookyNetCore(BaseFFCore):
    """Core network of SpookyNet (Nat. Commun., 2021, 12(1): 7273).

    A stack of ``InteractionModule`` blocks refines the atom embedding. The
    per-module outputs are summed and projected to two channels, which are
    returned as the atomic energy ``Ea`` and the partial charge ``Qa``.
    """

    def __str__(self) -> str:
        """Return the banner identifying this model."""
        return (
            "\n"
            "###############################################\n"
            "# SpookyNet (Nat. Commun., 2021, 12(1): 7273) #\n"
            "###############################################\n"
        )

    def __init__(
        self, dim_embedding: int, num_rbf: int, num_modules: int, num_residual_pre: int,
        num_residual_local_x: int, num_residual_local_s: int, num_residual_local_p: int,
        num_residual_local_d: int, num_residual_local: int,
        num_residual_nonlocal_q: int, num_residual_nonlocal_k: int, num_residual_nonlocal_v: int,
        num_residual_post: int, num_residual_output: int, activation_fn: ACTIVATION_KEY_TYPE,
        use_irreps: bool, dropout_rate: float = 0.0,
        shallow_ensemble_size: int = 1
    ) -> None:
        """Create the interaction stack and the output projection.

        Args:
            dim_embedding: Width of the atom embedding; input size of the
                output projection.
            num_rbf: Number of radial basis functions, forwarded to every
                interaction module.
            num_modules: Number of ``InteractionModule`` instances to stack.
            num_residual_*: Residual-block counts forwarded unchanged to
                every ``InteractionModule``.
            activation_fn: Activation key forwarded to every interaction
                module.
            use_irreps: If true, use the 5-component irreducible second-order
                angular functions; otherwise the 6-component Cartesian ones.
            dropout_rate: Module-level dropout rate; the keep probability is
                ``1 - dropout_rate``.
            shallow_ensemble_size: If greater than 1, the output projection
                is a ``DenseLayer`` built with this ensemble size; otherwise
                a plain bias-free ``Linear`` layer.
        """
        super().__init__()
        self.interaction = ModuleList(
            [
                InteractionModule(
                    dim_embedding=dim_embedding,
                    num_rbf=num_rbf,
                    num_residual_pre=num_residual_pre,
                    num_residual_local_x=num_residual_local_x,
                    num_residual_local_s=num_residual_local_s,
                    num_residual_local_p=num_residual_local_p,
                    num_residual_local_d=num_residual_local_d,
                    num_residual_local=num_residual_local,
                    num_residual_nonlocal_q=num_residual_nonlocal_q,
                    num_residual_nonlocal_k=num_residual_nonlocal_k,
                    num_residual_nonlocal_v=num_residual_nonlocal_v,
                    num_residual_post=num_residual_post,
                    num_residual_output=num_residual_output,
                    activation_fn=activation_fn,
                )
                for _ in range(num_modules)
            ]
        )
        # Two output channels: index 0 is read as Ea, index 1 as Qa.
        if shallow_ensemble_size > 1:
            self.output = DenseLayer(
                dim_embedding, 2, use_bias=False, shallow_ensemble_size=shallow_ensemble_size
            )
        else:
            self.output = Linear(dim_embedding, 2, bias=False)
        self.use_irreps = use_irreps
        # Constants used by the hand-written angular functions.
        self._sqrt2 = math.sqrt(2.0)
        self._sqrt3 = math.sqrt(3.0)
        self._sqrt3half = 0.5 * self._sqrt3
        self.module_keep_prob = 1 - dropout_rate
        # Both layers are assigned in ``build``.
        self.calculate_distance: Optional[DistanceLayer] = None
        self.range_separation: Optional[RangeSeparationLayer] = None
        self.shallow_ensemble_size = shallow_ensemble_size

    def build(self, built_layers: List[Module]) -> None:
        """Sort the already-built layers into pre-core and post-core sequences.

        Layers listed before ``self`` in ``built_layers`` are appended to
        ``self.pre_sequence``; layers after it go to ``self.post_sequence``.
        Known layer types are also stored as attributes of this core.

        NOTE(review): ``pre_sequence`` and ``post_sequence`` are not created
        in this class; presumably ``BaseFFCore`` provides them -- confirm.

        Args:
            built_layers: All layers of the model, including this core.
        """
        # A long-range distance layer always comes first in the pre-core
        # sequence.
        self.calculate_distance = DistanceLayer()
        self.calculate_distance.with_vector_on("vij_lr")
        self.calculate_distance.reset_field_name(Dij="Dij_lr")
        self.pre_sequence.append(self.calculate_distance)

        pre_core = True
        for layer in built_layers:
            if layer is self:
                # Everything after the core itself is post-core.
                pre_core = False
                continue
            if pre_core:
                if isinstance(layer, RangeSeparationLayer):
                    self.range_separation = layer
                    self.range_separation.reset_field_name(idx_i_lr="idx_i", idx_j_lr="idx_j")
                elif isinstance(layer, BaseAtomEmbedding):
                    self.atom_embedding = layer
                elif isinstance(layer, BaseElectronEmbedding):
                    if layer.attribute == "charge":
                        self.charge_embedding = layer
                    elif layer.attribute == "spin":
                        self.spin_embedding = layer
                elif isinstance(layer, BaseRBF):
                    self.radial_basis_function = layer
                self.pre_sequence.append(layer)
            else:
                if isinstance(layer, ChargeConservationLayer):
                    self.charge_conservation = layer
                self.post_sequence.append(layer)

        # The embedding gather step closes the pre-core sequence.
        self.gather_embedding = GatherAtomEmbedding()
        self.pre_sequence.append(self.gather_embedding)

    def _atomic_properties_static(
        self, Dij_sr: Tensor, vij_sr: Tensor, batch_seg: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], int]:
        """Compute the inputs that do not depend on the atom embedding.

        Args:
            Dij_sr: Pair distances; broadcast against ``vij_sr`` via
                ``unsqueeze(-1)``.
            vij_sr: Pair vectors; the column indexing below requires a 2-D
                tensor with at least three columns.
            batch_seg: Per-atom structure index, or ``None`` for a single
                structure. NOTE(review): the last entry is taken as the
                largest index, i.e. the tensor is assumed sorted -- confirm.

        Returns:
            ``(pij, dij, mask, num_batch)`` where ``pij`` is
            ``vij_sr / Dij_sr``, ``dij`` holds the second-order angular
            functions of ``pij`` (5 columns if ``use_irreps`` else 6),
            ``mask`` is an atom-by-atom matrix that is 1 where both atoms
            share a ``batch_seg`` value (``None`` for a single structure),
            and ``num_batch`` is the number of structures.
        """
        pij = vij_sr / Dij_sr.unsqueeze(-1)
        px, py, pz = pij[:, 0], pij[:, 1], pij[:, 2]

        if self.use_irreps:  # irreducible representation
            # Only the import is guarded, so an unrelated error raised while
            # evaluating the harmonics is not mistaken for a missing package.
            try:
                from e3nn.o3 import spherical_harmonics
            except ImportError:
                spherical_harmonics = None

            if spherical_harmonics is not None:
                # Axis and component permutations kept for strict
                # reproduction of the hand-written fallback ordering.
                dij = spherical_harmonics(
                    2, pij[:, [1, 2, 0]], normalize=True, normalization="norm"
                )[:, [0, 3, 1, 2, 4]]
            else:
                dij = torch.stack(
                    [
                        self._sqrt3 * px * py,  # xy
                        self._sqrt3 * px * pz,  # xz
                        self._sqrt3 * py * pz,  # yz
                        0.5 * (3 * pz * pz - 1.0),  # z2
                        self._sqrt3half * (px * px - py * py),  # x2-y2
                    ],
                    dim=-1,
                )
        else:  # reducible Cartesian functions
            dij = torch.stack(
                [
                    px * px,  # x2
                    py * py,  # y2
                    pz * pz,  # z2
                    self._sqrt2 * px * py,  # x*y
                    self._sqrt2 * px * pz,  # x*z
                    self._sqrt2 * py * pz,  # y*z
                ],
                dim=-1,
            )

        # Convert to a Python int so the value matches the declared type.
        num_batch = 1 if batch_seg is None else int(batch_seg[-1]) + 1

        mask: Optional[Tensor] = None
        if num_batch > 1:
            one_hot = F.one_hot(batch_seg).to(dtype=Dij_sr.dtype, device=Dij_sr.device)
            # Entry (i, j) is 1 iff atoms i and j carry the same batch index.
            mask = one_hot @ one_hot.transpose(-1, -2)
        return pij, dij, mask, num_batch

    def _atomic_properties_dynamic(
        self, atom_embedding: Tensor, num_batch: int,
        rbf: Tensor, pij: Tensor, dij: Tensor, idx_i_sr: Tensor, idx_j_sr: Tensor,
        mask: Optional[Tensor], batch_seg: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Run the interaction stack and project to energy and charge.

        Each module returns an updated embedding and an output term; the
        output terms are summed and passed through ``self.output``.

        During training with ``module_keep_prob < 1`` a per-structure
        dropout mask is applied to the module outputs. The mask is
        multiplied by a fresh Bernoulli sample *after* each module, so the
        first module is never dropped and a structure that has been dropped
        once stays dropped for all later modules.

        Args:
            atom_embedding: Initial per-atom features.
            num_batch: Number of structures in the batch.
            rbf, pij, dij, idx_i_sr, idx_j_sr, mask: Forwarded unchanged to
                every interaction module.
            batch_seg: Per-atom structure index, or ``None`` for a single
                structure.

        Returns:
            ``(ea, qa)``: channels 0 and 1 of the projected output along
            dimension 1, with that dimension squeezed.
        """
        x = atom_embedding
        dropout_mask = torch.ones((num_batch, 1), dtype=x.dtype, device=x.device)
        f = x.new_zeros(x.size())
        use_dropout = self.training and self.module_keep_prob < 1.0

        for module in self.interaction:
            x, y = module(
                x, rbf, pij, dij, idx_i_sr, idx_j_sr, num_batch, batch_seg, mask
            )
            if use_dropout:
                # Indexing with ``None`` would insert a new axis, so for a
                # single structure the (1, 1) mask is broadcast directly.
                atom_mask = dropout_mask if batch_seg is None else dropout_mask[batch_seg]
                y = y * atom_mask
                dropout_mask = dropout_mask * torch.bernoulli(
                    self.module_keep_prob * torch.ones_like(dropout_mask)
                )
            f = f + y

        out = self.output(f)
        ea = out.narrow(1, 0, 1).squeeze(1)  # atomic energy
        qa = out.narrow(1, 1, 1).squeeze(1)  # partial charge
        return ea, qa

    def get_output(
        self, Dij_sr: Tensor, vij_sr: Tensor, idx_i_sr: Tensor, idx_j_sr: Tensor,
        rbf: Tensor, atom_embedding: Tensor, batch_seg: Optional[Tensor] = None
    ) -> Dict[Literal["Ea", "Qa"], Tensor]:
        """Compute atomic energies and partial charges.

        Args:
            Dij_sr: Pair distances.
            vij_sr: Pair vectors.
            idx_i_sr: First pair index, forwarded to the interaction modules.
            idx_j_sr: Second pair index, forwarded to the interaction modules.
            rbf: Radial basis expansion, forwarded to the interaction modules.
            atom_embedding: Initial per-atom features.
            batch_seg: Per-atom structure index, or ``None`` for a single
                structure.

        Returns:
            Dictionary with the atomic energy under ``"Ea"`` and the partial
            charge under ``"Qa"``.
        """
        pij, dij, mask, num_batch = self._atomic_properties_static(Dij_sr, vij_sr, batch_seg)
        ea, qa = self._atomic_properties_dynamic(
            atom_embedding, num_batch, rbf, pij, dij, idx_i_sr, idx_j_sr, mask, batch_seg
        )
        return {"Ea": ea, "Qa": qa}