Source code for enerzyme.models.physnet.interaction

from typing import Optional, Union
import numpy as np
import torch
from torch import Tensor
from torch.nn import Dropout, Parameter
from ..functional import segment_sum_coo
from ..activation import ACTIVATION_KEY_TYPE, ACTIVATION_PARAM_TYPE
from ..blocks.mlp import DenseLayer as _DenseLayer
from ..blocks.mlp import ResidualStack as _ResidualStack
from ..blocks.mlp import INITIAL_WEIGHT_TYPE, INITIAL_BIAS_TYPE, NeuronLayer, ResidualMLP


# Annotation used in this module for user-supplied initial weights/biases:
# a tensor, a NumPy array, a string (e.g. "zero"), or None. The helpers
# weight_default / bias_default replace None with this module's default string.
DEFAULT_TYPE = Optional[Union[Tensor, np.ndarray, str]]


def weight_default(initial_weight: DEFAULT_TYPE=None) -> INITIAL_WEIGHT_TYPE:
    """Resolve the weight initializer used by the layer factories in this module.

    Args:
        initial_weight: Caller-supplied initializer, or ``None`` to request
            the module default.

    Returns:
        ``initial_weight`` unchanged when it is not ``None``; otherwise the
        string ``"semi_orthogonal_glorot"``.
    """
    # Anything explicitly supplied by the caller wins over the default.
    if initial_weight is not None:
        return initial_weight
    return "semi_orthogonal_glorot"
def bias_default(initial_weight: DEFAULT_TYPE=None) -> INITIAL_BIAS_TYPE:
    """Resolve the bias initializer used by the layer factories in this module.

    Args:
        initial_weight: Caller-supplied *bias* initializer, or ``None`` to
            request the module default. NOTE(review): the parameter is named
            ``initial_weight`` although it carries a bias initializer; the
            name is kept so existing keyword callers keep working.

    Returns:
        The argument unchanged when it is not ``None``; otherwise the string
        ``"zero"``.
    """
    # Anything explicitly supplied by the caller wins over the default.
    if initial_weight is not None:
        return initial_weight
    return "zero"
def DenseLayer(
    dim_feature_in: int,
    dim_feature_out: int,
    activation_fn: Optional[ACTIVATION_KEY_TYPE]=None,
    activation_params: ACTIVATION_PARAM_TYPE=dict(),
    initial_weight: DEFAULT_TYPE=None,
    initial_bias: DEFAULT_TYPE=None,
    use_bias: bool=True
) -> _DenseLayer:
    """Build a generic dense layer with this module's default initializers.

    Thin factory around ``blocks.mlp.DenseLayer``: every argument is forwarded
    as-is, except that a ``None`` weight/bias initializer is first replaced by
    the result of ``weight_default`` / ``bias_default``.

    Args:
        dim_feature_in: Forwarded as ``dim_feature_in``.
        dim_feature_out: Forwarded as ``dim_feature_out``.
        activation_fn: Forwarded as ``activation_fn``.
        activation_params: Forwarded as ``activation_params``.
        initial_weight: Weight initializer; ``None`` selects the default.
        initial_bias: Bias initializer; ``None`` selects the default.
        use_bias: Forwarded as ``use_bias``.

    Returns:
        The constructed ``blocks.mlp.DenseLayer`` instance.
    """
    # Resolve the initializers up front so the constructor call stays a
    # plain pass-through.
    resolved_weight = weight_default(initial_weight)
    resolved_bias = bias_default(initial_bias)

    layer = _DenseLayer(
        dim_feature_in=dim_feature_in,
        dim_feature_out=dim_feature_out,
        activation_fn=activation_fn,
        activation_params=activation_params,
        initial_weight=resolved_weight,
        initial_bias=resolved_bias,
        use_bias=use_bias,
    )
    return layer
def ResidualStack(
    dim_feature: int,
    num_residual: int,
    activation_fn: Optional[ACTIVATION_KEY_TYPE]=None,
    activation_params: ACTIVATION_PARAM_TYPE=dict(),
    initial_weight: Optional[Union[Tensor, np.ndarray]]=None,
    initial_bias: Optional[Union[Tensor, np.ndarray]]=None,
    use_bias: bool=True,
    dropout_rate: float=0.0
) -> _ResidualStack:
    """Build a generic residual stack with this module's default initializers.

    Thin factory around ``blocks.mlp.ResidualStack``. The weight initializer
    is resolved once via ``weight_default`` and passed as both
    ``initial_weight1`` and ``initial_weight2``; the bias initializer is
    resolved via ``bias_default``.

    Args:
        dim_feature: Forwarded as ``dim_feature``.
        num_residual: Forwarded as ``num_residual``.
        activation_fn: Forwarded as ``activation_fn``.
        activation_params: Forwarded as ``activation_params``.
        initial_weight: Weight initializer; ``None`` selects the default.
        initial_bias: Bias initializer; ``None`` selects the default.
        use_bias: Forwarded as ``use_bias``.
        dropout_rate: Forwarded as ``dropout_rate``.

    Returns:
        The constructed ``blocks.mlp.ResidualStack`` instance.
    """
    # One resolved initializer is shared by both weight slots of the stack.
    shared_weight = weight_default(initial_weight)
    resolved_bias = bias_default(initial_bias)

    stack = _ResidualStack(
        dim_feature=dim_feature,
        num_residual=num_residual,
        activation_fn=activation_fn,
        activation_params=activation_params,
        initial_weight1=shared_weight,
        initial_weight2=shared_weight,
        initial_bias=resolved_bias,
        use_bias=use_bias,
        dropout_rate=dropout_rate,
    )
    return stack
class InteractionLayer(NeuronLayer):
    """Interaction layer of the PhysNet model.

    For every atom a "message" is assembled from its own transformed features
    plus the summed, RBF-masked transformed features of its neighbours. The
    message goes through a residual stack and a final dense layer and is added
    to an element-wise scaled copy of the input features.
    """

    def __str__(self) -> str:
        return f"Interaction layer: {super().__str__()}"

    def __init__(
        self,
        num_rbf,
        dim_embedding,
        num_residual,
        activation_fn: Optional[ACTIVATION_KEY_TYPE]=None,
        activation_params: ACTIVATION_PARAM_TYPE=dict(),
        dropout_rate=0.0
    ) -> None:
        """Create the sub-modules of the interaction layer.

        Args:
            num_rbf: Input width of the layer mapping radial basis functions
                to feature space.
            dim_embedding: Width of the atomic feature vectors.
            num_residual: Number of residual blocks applied to the message.
            activation_fn: Activation key handed to the base class and to the
                activated sub-layers.
            activation_params: Parameters accompanying ``activation_fn``.
            dropout_rate: Rate for the pre-activation dropout and for the
                residual stack.
        """
        super().__init__(num_rbf, dim_embedding, activation_fn, activation_params)
        # NOTE: sub-modules are created in a fixed order; keep it, since it
        # determines parameter registration order.
        self.dropout = Dropout(dropout_rate)
        # transforms radial basis functions to feature space
        self.k2f = DenseLayer(
            dim_feature_in=num_rbf,
            dim_feature_out=dim_embedding,
            initial_weight="zero",
            use_bias=False,
        )
        # rearrange feature vectors for computing the "message"
        # central atoms
        self.dense_i = DenseLayer(
            dim_feature_in=dim_embedding,
            dim_feature_out=dim_embedding,
            activation_fn=activation_fn,
            activation_params=activation_params,
        )
        # neighbouring atoms
        self.dense_j = DenseLayer(
            dim_feature_in=dim_embedding,
            dim_feature_out=dim_embedding,
            activation_fn=activation_fn,
            activation_params=activation_params,
        )
        # for performing residual transformation on the "message"
        self.residual_stack = ResidualStack(
            dim_feature=dim_embedding,
            num_residual=num_residual,
            activation_fn=activation_fn,
            activation_params=activation_params,
            dropout_rate=dropout_rate,
        )
        # for performing the final update to the feature vectors
        self.dense = DenseLayer(dim_feature_in=dim_embedding, dim_feature_out=dim_embedding)
        # learnable element-wise scale applied to the incoming features
        self.u = Parameter(torch.ones((dim_embedding,)))

    def forward(self, x: Tensor, rbf: Tensor, idx_i: Tensor, idx_j: Tensor) -> Tensor:
        """Update the atomic features with one interaction step.

        Args:
            x: Atomic feature vectors (presumably one row per atom with
                ``dim_embedding`` columns -- TODO confirm).
            rbf: Radial basis function values (presumably one row per
                ``(idx_i, idx_j)`` pair -- TODO confirm).
            idx_i: Indices of the central atoms; used as the segment index of
                the neighbour sum.
            idx_j: Indices of the neighbouring atoms; used to gather rows of
                the transformed features.

        Returns:
            ``self.u * x + self.dense(m)`` where ``m`` is the processed message.
        """
        activation = self.activation_fn

        # pre-activation (skipped when no activation is configured), then dropout
        pre_activated = x if activation is None else activation(x)
        xa = self.dropout(pre_activated)

        # calculate feature mask from radial basis functions
        feature_mask = self.k2f(rbf)

        # calculate contribution of central atom and of neighbors
        central = self.dense_i(xa)
        neighbor_features = self.dense_j(xa)[idx_j]
        neighbor_sum = segment_sum_coo(
            feature_mask * neighbor_features, idx_i, dim_size=central.shape[0]
        )

        # add contributions to get the "message" and refine it
        message = self.residual_stack(central + neighbor_sum)
        if activation is not None:
            message = activation(message)

        return self.u * x + self.dense(message)
class InteractionBlock(NeuronLayer):
    """An ``InteractionLayer`` followed by a residual stack on the atomic features."""

    def __str__(self) -> str:
        return f"Interaction Block: {super().__str__()}"

    def __init__(
        self,
        num_rbf: int,
        dim_embedding: int,
        num_residual_atomic: int,
        num_residual_interaction: int,
        activation_fn: ACTIVATION_KEY_TYPE=None,
        activation_params: ACTIVATION_PARAM_TYPE=dict(),
        dropout_rate: float=0.0
    ) -> None:
        """Create the interaction layer and the atomic residual stack.

        Args:
            num_rbf: Passed to the interaction layer and the base class.
            dim_embedding: Feature width shared by both sub-modules.
            num_residual_atomic: Number of residual blocks in the stack that
                follows the interaction layer.
            num_residual_interaction: Number of residual blocks inside the
                interaction layer.
            activation_fn: Activation key forwarded to both sub-modules.
            activation_params: Parameters accompanying ``activation_fn``.
            dropout_rate: Dropout rate forwarded to both sub-modules.
        """
        # The base class only receives the two sizes; the activation settings
        # are handed to the sub-modules instead.
        super().__init__(num_rbf, dim_embedding)
        # interaction layer
        self.interaction = InteractionLayer(
            num_rbf=num_rbf,
            dim_embedding=dim_embedding,
            num_residual=num_residual_interaction,
            activation_fn=activation_fn,
            activation_params=activation_params,
            dropout_rate=dropout_rate,
        )
        # residual layers
        self.residual_stack = ResidualStack(
            dim_feature=dim_embedding,
            num_residual=num_residual_atomic,
            activation_fn=activation_fn,
            activation_params=activation_params,
            dropout_rate=dropout_rate,
        )

    def forward(self, x: Tensor, rbf: Tensor, idx_i: Tensor, idx_j: Tensor) -> Tensor:
        """Run the interaction layer, then the residual stack.

        Args:
            x: Atomic feature vectors.
            rbf: Radial basis function values.
            idx_i: Indices of the central atoms.
            idx_j: Indices of the neighbouring atoms.

        Returns:
            Output of the residual stack applied to the interaction result.
        """
        interacted = self.interaction(x, rbf, idx_i, idx_j)
        return self.residual_stack(interacted)
def OutputBlock(
    dim_embedding: int,
    num_residual: int,
    activation_fn: ACTIVATION_KEY_TYPE=None,
    activation_params: ACTIVATION_PARAM_TYPE=dict(),
    dropout_rate: float=0.0,
    shallow_ensemble_size: int=1
) -> ResidualMLP:
    """Build the output block as a ``ResidualMLP``.

    The residual part uses this module's default weight and bias
    initializers; the output layer is created with ``"zero"`` initial weights
    and without a bias.

    Args:
        dim_embedding: Input feature width (``dim_feature_in``).
        num_residual: Number of residual blocks.
        activation_fn: Forwarded as ``activation_fn``.
        activation_params: Forwarded as ``activation_params``.
        dropout_rate: Forwarded as ``dropout_rate``.
        shallow_ensemble_size: Forwarded as ``shallow_ensemble_size``.

    Returns:
        The constructed ``ResidualMLP`` with ``dim_feature_out=2``.
    """
    residual_weight = weight_default()
    residual_bias = bias_default()

    block = ResidualMLP(
        dim_feature_in=dim_embedding,
        # Fixed at two outputs per input row (presumably per-atom energy and
        # charge as in PhysNet -- TODO confirm).
        dim_feature_out=2,
        num_residual=num_residual,
        activation_fn=activation_fn,
        activation_params=activation_params,
        initial_weight1=residual_weight,
        initial_weight2=residual_weight,
        initial_weight_out="zero",
        initial_bias_residual=residual_bias,
        use_bias_out=False,
        dropout_rate=dropout_rate,
        shallow_ensemble_size=shallow_ensemble_size,
    )
    return block