# Source code for enerzyme.models.schnet.core

"""
Description:    Code is adapted from torch geometric implementation https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py.
All rights reserved to original authors.
"""

from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import Linear, Module, ModuleList

from .interaction import InteractionBlock
from ..blocks.mlp import DenseLayer
from ..layers import BaseFFCore, DistanceLayer, RangeSeparationLayer
from ..activation import ACTIVATION_KEY_TYPE, ACTIVATION_PARAM_TYPE, get_activation_fn


# Default build-time hyperparameters for this model.
# ``dim_embedding``, ``num_rbf`` and ``cutoff_sr`` share their names with
# ``SchNetCore`` constructor arguments.
# NOTE(review): ``Hartree_in_E`` / ``Bohr_in_R`` look like unit-conversion
# factors (0.5291772108 is the Bohr radius in Angstrom); ``max_Za`` presumably
# bounds the supported atomic number — confirm against the model builder.
DEFAULT_BUILD_PARAMS = dict(
    dim_embedding=128,
    num_rbf=128,
    max_Za=94,
    cutoff_sr=5.0,
    Hartree_in_E=1,
    Bohr_in_R=0.5291772108,
)
# Default ordered layer configuration. Each entry carries a layer ``name`` and,
# optionally, a ``params`` dict; presumably the names are resolved to layer
# classes elsewhere (TODO confirm). The ``"Core"`` entry's params use the same
# names as ``SchNetCore.__init__`` arguments.
DEFAULT_LAYER_PARAMS = [
    {"name": "RangeSeparation"},
    {
        "name": "GaussianSmearing",
        "params": {
            "no_basis_at_infinity": False,
            "init_alpha": 1,
            "exp_weighting": False,
            "learnable_shape": True,
            "init_width_flavor": "PhysNet",
        },
    },
    {"name": "RandomAtomEmbedding"},
    {
        "name": "Core",
        "params": {
            "hidden_channels": 128,
            "num_interactions": 4,
            "activation_fn": "shifted_softplus",
            "activation_params": {
                "dim_feature": 1,
                "initial_alpha": 1,
                "initial_beta": 1,
                "learnable": False,
            },
        },
    },
    {
        "name": "AtomicAffine",
        "params": {
            "shifts": {
                "Ea": {"values": 0, "learnable": True},
                "Qa": {"values": 0, "learnable": True},
            },
            "scales": {
                "Ea": {"values": 1, "learnable": True},
                "Qa": {"values": 1, "learnable": True},
            },
        },
    },
    {"name": "ChargeConservation"},
    {"name": "AtomicCharge2Dipole"},
    {
        "name": "ElectrostaticEnergy",
        "params": {"cutoff_lr": None, "flavor": "PhysNet"},
    },
    {"name": "GrimmeD3Energy", "params": {"learnable": True}},
    {"name": "EnergyReduce"},
    {"name": "Force"},
]

class SchNetCore(BaseFFCore):
    """SchNet interaction core predicting per-atom ``Ea`` and ``Qa``.

    Adapted from the PyTorch Geometric ``SchNet`` model. Reads the fields
    ``idx_i_sr``, ``idx_j_sr``, ``Dij_sr``, ``rbf`` and ``atom_embedding``,
    refines the atom embeddings with ``num_interactions`` residual
    ``InteractionBlock`` updates, and maps them through a two-layer readout
    (``lin1`` -> activation -> ``lin2``) to two per-atom outputs.
    """

    def __init__(
        self,
        hidden_channels: int = 128,
        dim_embedding: int = 128,
        num_interactions: int = 4,
        num_rbf: int = 128,
        cutoff_sr: float = 5.0,
        activation_fn: ACTIVATION_KEY_TYPE = "shifted_softplus",  # activation function
        activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
        shallow_ensemble_size: int = 1,
    ):
        """Build the interaction stack and the readout head.

        Args:
            hidden_channels: Input width of ``lin1``; the readout hidden layer
                has ``hidden_channels // 2`` units.
            dim_embedding: Forwarded to every ``InteractionBlock`` as its third
                positional argument (stored as ``num_filters``).
            num_interactions: Number of ``InteractionBlock`` modules.
            num_rbf: Forwarded to every ``InteractionBlock`` (stored as
                ``num_gaussians``).
            cutoff_sr: Forwarded to every ``InteractionBlock`` (stored as
                ``cutoff``).
            activation_fn: Key handed to ``get_activation_fn`` for the readout
                nonlinearity.
            activation_params: Parameters handed to ``get_activation_fn``.
                ``None`` (the default) is treated as an empty dict.
            shallow_ensemble_size: Forwarded to the output ``DenseLayer``.
        """
        super().__init__(
            input_fields={"idx_i_sr", "idx_j_sr", "Dij_sr", "rbf", "atom_embedding"},
            output_fields={"Ea", "Qa"},
        )
        if activation_params is None:
            # A fresh dict per instance: the previous ``dict()`` default was a
            # single object shared by every call.
            activation_params = {}
        self.hidden_channels = hidden_channels
        self.num_filters = dim_embedding
        self.num_interactions = num_interactions
        self.num_gaussians = num_rbf
        self.cutoff = cutoff_sr
        # NOTE(review): get_output adds each interaction's result onto
        # ``atom_embedding`` and then feeds it to ``lin1`` (in: hidden_channels),
        # so presumably the incoming embedding width must equal hidden_channels.
        # Confirm how dim_embedding relates when the two differ.
        self.interactions = ModuleList(
            [
                InteractionBlock(hidden_channels, num_rbf, dim_embedding, cutoff_sr)
                for _ in range(num_interactions)
            ]
        )
        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = get_activation_fn(activation_fn, activation_params)
        # Two output columns: column 0 -> "Ea", column 1 -> "Qa" (see get_output).
        self.lin2 = DenseLayer(
            hidden_channels // 2,
            2,
            initial_weight="xavier_uniform",
            initial_bias="zero",
            shallow_ensemble_size=shallow_ensemble_size,
        )
        self.shallow_ensemble_size = shallow_ensemble_size
        self.reset_parameters()

    def __str__(self) -> str:
        return (
            "\n"
            "######################################################\n"
            "# Augmented SchNet (NeurIPS 2017, arXiv: 1706.08566) #\n"
            "######################################################\n"
        )

    def reset_parameters(self):
        r"""Re-initialise the interaction blocks and ``lin1``.

        Each interaction block's own ``reset_parameters`` is called, then
        ``lin1`` gets Xavier-uniform weights and a zero bias.
        """
        # NOTE(review): ``lin2`` is not re-initialised here; its initial weights
        # come only from the DenseLayer constructor arguments
        # (xavier_uniform / zero). Decide whether a later reset should cover it.
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.zeros_(self.lin1.bias)

    def build(self, built_layers: List[Module]) -> None:
        """Split ``built_layers`` into this core's pre- and post-core sequences.

        Layers before ``self`` in ``built_layers`` are appended to
        ``self.pre_sequence``; layers after it go to ``self.post_sequence``.
        ``self`` itself is skipped. The pre-core sequence always starts with a
        ``DistanceLayer`` configured via ``reset_field_name(Dij="Dij_lr")``.
        If the first layer already is a ``DistanceLayer``, it is reused
        (exactly once); otherwise a new one is inserted in front. A pre-core
        ``RangeSeparationLayer`` is configured with
        ``reset_field_name(idx_i_lr="idx_i", idx_j_lr="idx_j")``.

        Args:
            built_layers: The model's instantiated layers in order, presumably
                including this core.
        """
        # TODO: make this more flexible
        pre_core = True
        for i, layer in enumerate(built_layers):
            if i == 0:
                if isinstance(layer, DistanceLayer):
                    layer.reset_field_name(Dij="Dij_lr")
                    self.pre_sequence.append(layer)
                    # Already registered; falling through to the pre-core branch
                    # would append the same module to pre_sequence a second time.
                    continue
                calculate_distance = DistanceLayer()
                calculate_distance.reset_field_name(Dij="Dij_lr")
                self.pre_sequence.append(calculate_distance)
            if layer is self:
                pre_core = False
                continue
            if pre_core:
                # build pre-core sequence
                if isinstance(layer, RangeSeparationLayer):
                    layer.reset_field_name(idx_i_lr="idx_i", idx_j_lr="idx_j")
                self.pre_sequence.append(layer)
            else:
                # build post-core sequence
                self.post_sequence.append(layer)

    def get_output(
        self,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
        Dij_sr: Tensor,
        rbf: Tensor,
        atom_embedding: Tensor,
    ):
        """Run the interaction stack and readout.

        Args:
            idx_i_sr, idx_j_sr: Pair index tensors, stacked into ``edge_index``.
            Dij_sr: Pair distances, forwarded to each interaction block.
            rbf: Radial basis features, forwarded to each interaction block.
            atom_embedding: Initial atom features, updated residually.

        Returns:
            dict with ``"Ea"`` = column 0 and ``"Qa"`` = column 1 of the
            ``lin2`` output (indexed along dim 1).
        """
        # Presumably idx_*_sr are 1-D, giving a (2, num_pairs) edge_index in the
        # PyG convention — confirm with the InteractionBlock implementation.
        edge_index = torch.stack([idx_i_sr, idx_j_sr])
        for interaction in self.interactions:
            atom_embedding = atom_embedding + interaction(
                atom_embedding, edge_index, Dij_sr, rbf
            )
        atom_embedding = self.lin1(atom_embedding)
        atom_embedding = self.act(atom_embedding)
        output = self.lin2(atom_embedding)
        return {"Ea": output[:, 0], "Qa": output[:, 1]}