enerzyme.models.spookynet.interaction.LocalInteraction

class enerzyme.models.spookynet.interaction.LocalInteraction(dim_embedding: int, num_rbf: int, num_residual_x: int, num_residual_s: int, num_residual_p: int, num_residual_d: int, num_residual: int, activation_fn: Literal['shifted_softplus', 'swish'] = 'swish')

Bases: Module

Block for updating atomic features through local interactions with neighboring atoms (message-passing).
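The message-passing update can be illustrated with the NumPy sketch below. It is a simplified toy, not the enerzyme implementation: it assumes a SpookyNet-style scheme in which radial filters modulate scalar (L=0), vector (L=1, `pij`) and tensor (L=2, `dij`) messages, the residual MLPs are replaced by identity maps, the learned linear layers by fixed random matrices, and a Gaussian expansion stands in for whatever radial basis the model actually uses. The helpers `pair_geometry`, `init_weights` and `local_interaction_sketch` are hypothetical names introduced only for this example.

```python
import numpy as np


def pair_geometry(pos, idx_i, idx_j, n_rbf=8, cutoff=5.0):
    """Hypothetical helper: radial features plus L=1 / L=2 angular features per pair."""
    vij = pos[idx_j] - pos[idx_i]                      # [P, 3]
    rij = np.linalg.norm(vij, axis=1)                  # [P]
    u = vij / rij[:, None]                             # unit vectors
    x, y, z = u[:, 0], u[:, 1], u[:, 2]
    pij = u                                            # [P, 3], L=1 components
    # [P, 5], L=2 components: real spherical harmonics up to a common constant
    dij = np.stack([x * y, y * z, x * z,
                    (x**2 - y**2) / 2,
                    (3 * z**2 - 1) / (2 * np.sqrt(3))], axis=1)
    # Gaussian expansion used as a stand-in radial basis (assumption)
    centers = np.linspace(0.0, cutoff, n_rbf)
    rbf = np.exp(-((rij[:, None] - centers[None, :]) ** 2))  # [P, n_rbf]
    return rbf, pij, dij


def init_weights(n_rbf, n_feat, seed=0):
    """Random matrices standing in for the learned linear layers (assumption)."""
    rng = np.random.default_rng(seed)
    return {
        "W_s": rng.standard_normal((n_rbf, n_feat)),
        "W_p": rng.standard_normal((n_rbf, n_feat)),
        "W_d": rng.standard_normal((n_rbf, n_feat)),
        "P_p": rng.standard_normal((n_feat, 2 * n_feat)),
        "P_d": rng.standard_normal((n_feat, 2 * n_feat)),
    }


def local_interaction_sketch(x, rbf, pij, dij, idx_i, idx_j, w):
    """Toy local interaction; residual MLPs are replaced by identity maps."""
    gs = rbf @ w["W_s"]                                  # [P, F]
    gp = (rbf @ w["W_p"])[:, None, :] * pij[:, :, None]  # [P, 3, F]
    gd = (rbf @ w["W_d"])[:, None, :] * dij[:, :, None]  # [P, 5, F]
    xj = x[idx_j]                                        # gather neighbour features
    s = x.copy()                                         # central-atom (x) branch
    np.add.at(s, idx_i, gs * xj)                         # scatter-sum L=0 messages
    p = np.zeros((x.shape[0], 3, x.shape[1]))
    np.add.at(p, idx_i, gp * xj[:, None, :])             # scatter-sum L=1 messages
    d = np.zeros((x.shape[0], 5, x.shape[1]))
    np.add.at(d, idx_i, gd * xj[:, None, :])             # scatter-sum L=2 messages
    pa, pb = np.split(p @ w["P_p"], 2, axis=-1)
    da, db = np.split(d @ w["P_d"], 2, axis=-1)
    # dot products over the angular axis turn L=1 / L=2 features into invariant scalars
    return s + (pa * pb).sum(axis=1) + (da * db).sum(axis=1)


pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.2, 0.0], [0.1, 1.1, 0.3]])
idx_i = np.array([0, 1, 0, 2, 1, 2])  # every pair listed as both ij and ji
idx_j = np.array([1, 0, 2, 0, 2, 1])
x = np.random.default_rng(1).standard_normal((3, 4))
w = init_weights(n_rbf=8, n_feat=4)
rbf, pij, dij = pair_geometry(pos, idx_i, idx_j)
out = local_interaction_sketch(x, rbf, pij, dij, idx_i, idx_j, w)
print(out.shape)  # (3, 4)
```

Contracting the angular axis only after aggregation is what lets directional information enter the update while the output stays invariant under rotations of the input coordinates.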

Arguments:
dim_embedding (int):

Dimensions of feature space.

num_rbf (int):

Number of radial basis functions.

num_residual_x (int):

Number of residual blocks applied to atomic features in the x branch (central atoms i) before computing the interaction.

num_residual_s (int):

Number of residual blocks applied to atomic features in the s branch (neighbouring atoms j, L=0 interaction) before computing the interaction.

num_residual_p (int):

Same as num_residual_s, but for the p branch (L=1 interaction).

num_residual_d (int):

Same as num_residual_s, but for the d branch (L=2 interaction).

num_residual (int):

Number of residual blocks applied to the combined interaction features.

activation_fn (str):

Kind of activation function. Possible values: ‘swish’ (default): Swish activation function. ‘shifted_softplus’: Shifted softplus activation function.
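For reference, the two activation options follow the standard definitions swish(x) = x·sigmoid(x) and shifted softplus(x) = ln(1 + eˣ) − ln 2. The NumPy sketch below implements those textbook forms only; enerzyme's own activation modules may differ (for example by adding learnable scaling parameters), so treat it as illustrative.

```python
import numpy as np


def swish(x):
    # x * sigmoid(x); smooth and non-monotonic, approaches x for large x
    return x / (1.0 + np.exp(-x))


def shifted_softplus(x):
    # softplus(x) - ln 2, shifted so that shifted_softplus(0) == 0;
    # np.logaddexp(0, x) evaluates ln(1 + e^x) without overflow
    return np.logaddexp(0.0, x) - np.log(2.0)


print(swish(np.array([0.0, 50.0])))
print(shifted_softplus(np.array([0.0, -30.0])))
```

Both functions pass through the origin and are smooth, which matters for models whose forces are obtained by differentiating the energy.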

__init__(dim_embedding: int, num_rbf: int, num_residual_x: int, num_residual_s: int, num_residual_p: int, num_residual_d: int, num_residual: int, activation_fn: Literal['shifted_softplus', 'swish'] = 'swish') -> None

Initializes the LocalInteraction class.

forward(x: Tensor, rbf: Tensor, pij: Tensor, dij: Tensor, idx_i: Tensor, idx_j: Tensor) -> Tensor

Evaluate the interaction block and return the updated atomic features (FloatTensor [N, dim_embedding]). N: Number of atoms. P: Number of atom pairs.

x (FloatTensor [N, dim_embedding]):

Atomic feature vectors.

rbf (FloatTensor [P, num_rbf]):

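Values of the radial basis functions for the pairwise distances.

pij (FloatTensor [P, 3]):

p-type (L=1) angular features of each atom pair ij.

dij (FloatTensor [P, 5]):

d-type (L=2) angular features of each atom pair ij.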

idx_i (LongTensor [P]):

Index of atom i for all atomic pairs ij. Each pair must be specified as both ij and ji.

idx_j (LongTensor [P]):

Same as idx_i, but for atom j.
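Because every pair has to appear as both ij and ji, one way to produce idx_i and idx_j is to take all ordered pairs of distinct atoms closer than a cutoff. The sketch below assumes a simple O(N²) distance check; `build_pair_indices` and its `cutoff` argument are hypothetical and are not enerzyme's own neighbour-list code.

```python
import numpy as np


def build_pair_indices(pos, cutoff):
    """Return idx_i, idx_j for all ordered pairs (i != j) closer than cutoff."""
    diff = pos[None, :, :] - pos[:, None, :]   # [N, N, 3], entry (i, j) is r_j - r_i
    dist = np.linalg.norm(diff, axis=-1)        # [N, N], symmetric
    mask = (dist < cutoff) & ~np.eye(len(pos), dtype=bool)  # drop self-pairs
    # the mask is symmetric, so each pair is emitted as both (i, j) and (j, i)
    idx_i, idx_j = np.nonzero(mask)
    return idx_i, idx_j


pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
idx_i, idx_j = build_pair_indices(pos, cutoff=2.5)
print(idx_i, idx_j)  # [0 1 1 2] [1 0 2 1]
```

Listing both directions lets a single scatter-sum over idx_i deliver every message to its central atom, so each atom aggregates contributions from all of its neighbours without any special handling of pair order.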

reset_parameters() -> None

Initialize parameters.