# Source code for enerzyme.models.physnet.core

from typing import Dict, List, Optional

import torch
from torch import Tensor
from torch.nn import Module, Sequential

from .interaction import InteractionBlock, OutputBlock
from ..layers import DistanceLayer, RangeSeparationLayer, BaseFFCore
from ..activation import ACTIVATION_KEY_TYPE, ACTIVATION_PARAM_TYPE


# Default build parameters for the PhysNet model.
# NOTE(review): the per-key meanings below are inferred from the key names;
# confirm them against the code that consumes this dict.
DEFAULT_BUILD_PARAMS = dict(
    dim_embedding=128,  # presumably the width of the atomic feature vectors
    num_rbf=64,  # presumably the number of radial basis functions
    max_Za=94,  # presumably the largest supported atomic number
    cutoff_sr=10.0,  # short-range cutoff (assumed to be in the length unit of R)
    Hartree_in_E=1,  # presumably one Hartree expressed in the energy unit E
    Bohr_in_R=0.5291772108,  # equals the Bohr radius in Angstrom, i.e. assumes R is in Angstrom
    cutoff_fn='polynomial',
)
# Default ordered list of layer specifications for PhysNet.
# Each entry has a 'name' and, optionally, a 'params' dict of layer-specific
# settings. The 'Core' entry's params use the same keyword names as
# PhysNetCore.__init__ (num_blocks, num_residual_*, activation_*, dropout_rate).
# NOTE(review): list order presumably is the layer execution order — confirm
# against the model builder that consumes this list.
DEFAULT_LAYER_PARAMS = [
    {'name': 'RangeSeparation'},
    {
        'name': 'ExponentialGaussianRBF',
        'params': {
            'no_basis_at_infinity': False,
            'init_alpha': 1,
            'exp_weighting': False,
            'learnable_shape': True,
            'init_width_flavor': 'PhysNet',
        },
    },
    {'name': 'RandomAtomEmbedding'},
    {
        'name': 'Core',
        'params': {
            'num_blocks': 5,
            'num_residual_atomic': 2,
            'num_residual_interaction': 3,
            'num_residual_output': 1,
            'activation_fn': 'shifted_softplus',
            'activation_params': {
                'dim_feature': 1,
                'initial_alpha': 1,
                'initial_beta': 1,
                'learnable': False,
            },
            'dropout_rate': 0.0,
        },
    },
    {
        'name': 'AtomicAffine',
        'params': {
            # per-property shift/scale settings, keyed by output field name
            'shifts': {
                'Ea': {'values': 0, 'learnable': True},
                'Qa': {'values': 0, 'learnable': True},
            },
            'scales': {
                'Ea': {'values': 1, 'learnable': True},
                'Qa': {'values': 1, 'learnable': True},
            },
        },
    },
    {'name': 'ChargeConservation'},
    {'name': 'AtomicCharge2Dipole'},
    {
        'name': 'ElectrostaticEnergy',
        'params': {'cutoff_lr': None, 'flavor': 'PhysNet'},
    },
    {'name': 'GrimmeD3Energy', 'params': {'learnable': True}},
    {'name': 'EnergyReduce'},
    {'name': 'Force'},
]


class PhysNetCore(BaseFFCore):
    """PhysNet core: stacked interaction and output blocks.

    Declared input fields: ``rbf``, ``atom_embedding``, ``idx_i_sr``,
    ``idx_j_sr``. Declared output fields: ``Ea``, ``Qa``, ``nh_loss``.
    ``l2_penalty`` is also produced when ``shallow_ensemble_size > 1``.
    """

    def __str__(self) -> str:
        return """
###################################################################################
# Pytorch Implementation of PhysNet (J. Chem. Theory Comput. 2019, 15, 3678−3693) #
###################################################################################
"""

    def __init__(
        self,
        dim_embedding: int,
        num_rbf: int,
        num_blocks: int = 3,  # number of building blocks to be stacked
        num_residual_atomic: int = 2,  # number of residual layers for atomic refinements of feature vector
        num_residual_interaction: int = 2,  # number of residual layers for refinement of message vector
        num_residual_output: int = 1,  # number of residual layers for the output blocks
        activation_fn: ACTIVATION_KEY_TYPE = "shifted_softplus",  # activation function
        activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
        dropout_rate: float = 0.0,
        shallow_ensemble_size: int = 1,
    ) -> None:
        """Create ``num_blocks`` interaction blocks and ``num_blocks`` output blocks.

        Args:
            dim_embedding: feature width passed to every InteractionBlock and
                OutputBlock.
            num_rbf: number of radial basis features passed to every
                InteractionBlock.
            num_blocks: number of interaction/output block pairs.
            num_residual_atomic: residual layers for atomic refinement
                (InteractionBlock).
            num_residual_interaction: residual layers for message refinement
                (InteractionBlock).
            num_residual_output: residual layers in each OutputBlock.
            activation_fn: activation key forwarded to all blocks.
            activation_params: activation parameters forwarded to all blocks.
                ``None`` (the default) means an empty dict. A fresh dict is
                created per call so instances never share a mutable default.
            dropout_rate: dropout rate forwarded to all blocks. It is also
                stored as ``self.drop_out``.
            shallow_ensemble_size: forwarded to every OutputBlock. Values > 1
                also make ``get_output`` report an ``l2_penalty``.
        """
        if activation_params is None:
            activation_params = {}
        super().__init__(
            input_fields={"rbf", "atom_embedding", "idx_i_sr", "idx_j_sr"},
            output_fields={"Ea", "Qa", "nh_loss"},
        )
        self.num_blocks = num_blocks
        self.drop_out = dropout_rate
        self.shallow_ensemble_size = shallow_ensemble_size
        self.interaction_block = Sequential(*[
            InteractionBlock(
                num_rbf,
                dim_embedding,
                num_residual_atomic,
                num_residual_interaction,
                activation_fn=activation_fn,
                activation_params=activation_params,
                dropout_rate=dropout_rate,
            )
            for _ in range(num_blocks)
        ])
        self.output_block = Sequential(*[
            OutputBlock(
                dim_embedding,
                num_residual_output,
                activation_fn=activation_fn,
                activation_params=activation_params,
                dropout_rate=dropout_rate,
                shallow_ensemble_size=shallow_ensemble_size,
            )
            for _ in range(num_blocks)
        ])

    def build(self, built_layers: List[Module]) -> None:
        """Distribute already-built layers into the pre-core and post-core sequences.

        Layers that precede ``self`` in ``built_layers`` are appended to
        ``self.pre_sequence``. Layers that follow it are appended to
        ``self.post_sequence``. ``self`` itself is skipped.

        A ``DistanceLayer`` whose ``Dij`` field is reset to ``Dij_lr`` always
        leads the pre-core sequence. The first built layer is used if it is a
        ``DistanceLayer``; otherwise a new ``DistanceLayer`` is prepended.
        Pre-core ``RangeSeparationLayer`` instances have ``idx_i_lr`` /
        ``idx_j_lr`` reset to ``idx_i`` / ``idx_j``.

        NOTE(review): ``pre_sequence`` / ``post_sequence`` are presumably
        provided by ``BaseFFCore`` — confirm.
        """
        # build necessary fixed pre-core layers
        # TODO: make this more flexible
        pre_core = True
        for i, layer in enumerate(built_layers):
            if i == 0:
                if isinstance(layer, DistanceLayer):
                    # Only rename here. The pre-core branch below appends this
                    # layer; appending in both places would put it into
                    # pre_sequence twice and run it twice.
                    layer.reset_field_name(Dij="Dij_lr")
                else:
                    calculate_distance = DistanceLayer()
                    calculate_distance.reset_field_name(Dij="Dij_lr")
                    self.pre_sequence.append(calculate_distance)
            if layer is self:
                pre_core = False
                continue
            if pre_core:
                # reset pre-core layers
                if isinstance(layer, RangeSeparationLayer):
                    layer.reset_field_name(idx_i_lr="idx_i", idx_j_lr="idx_j")
                # build pre-core sequence
                self.pre_sequence.append(layer)
            else:
                # build post-core sequence
                self.post_sequence.append(layer)

    def get_output(
        self,
        rbf: Tensor,
        atom_embedding: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
    ) -> Dict[str, Tensor]:
        '''
        Compute raw atomic properties.

        Each interaction block refines ``atom_embedding`` in turn. The matching
        output block's result is accumulated into the returned properties.

        Args:
            rbf: radial basis features, passed to every interaction block.
            atom_embedding: initial atomic features, refined block by block.
            idx_i_sr, idx_j_sr: short-range pair indices, passed to every
                interaction block.

        Returns:
            A dict with these entries:
            - "Ea": sum over blocks of ``out[:, 0]``.
            - "Qa": sum over blocks of ``out[:, 1]``.
            - "nh_loss": sum over blocks i >= 1 of
              ``mean(out_i**2 / (out_i**2 + out_{i-1}**2 + 1e-7))``. This is
              the plain integer 0 when ``num_blocks <= 1``.
            - "l2_penalty": present only when ``shallow_ensemble_size > 1``.
              It is the sum of ``output_block[i].output.l2loss()`` over all
              blocks.
        '''
        Ea = 0  # atomic energy
        Qa = 0  # atomic charge
        nhloss = 0  # non-hierarchicality loss
        lastout2 = None  # squared output of the previous block
        for i in range(self.num_blocks):
            atom_embedding = self.interaction_block[i](atom_embedding, rbf, idx_i_sr, idx_j_sr)
            out = self.output_block[i](atom_embedding)
            # NOTE(review): assumes out is (num_atoms, >=2), column 0 the energy
            # and column 1 the charge contribution — confirm against OutputBlock.
            Ea += out[:, 0]
            Qa += out[:, 1]
            # Non-hierarchicality loss: the ratio grows when this block's
            # output is large relative to the previous block's output.
            out2 = out ** 2
            if i > 0:
                nhloss += torch.mean(out2 / (out2 + lastout2 + 1e-7))
            lastout2 = out2
        output = {"Ea": Ea, "Qa": Qa, "nh_loss": nhloss}
        if self.shallow_ensemble_size > 1:
            output["l2_penalty"] = 0
            for i in range(self.num_blocks):
                output["l2_penalty"] += self.output_block[i].output.l2loss()
        return output