enerzyme.models.spookynet.interaction.LocalInteraction
- class enerzyme.models.spookynet.interaction.LocalInteraction(dim_embedding: int, num_rbf: int, num_residual_x: int, num_residual_s: int, num_residual_p: int, num_residual_d: int, num_residual: int, activation_fn: Literal['shifted_softplus', 'swish'] = 'swish')
Bases: Module
Block for updating atomic features through local interactions with neighboring atoms (message-passing).
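A block of this kind works on a list of neighbour pairs rather than on every atom-atom combination. The plain-Python helper below is a minimal sketch that assumes a simple distance cutoff defines "neighbouring". It produces `idx_i`/`idx_j` lists in which every pair appears in both directions (ij and ji), which is what `forward` requires. The name `build_pairs` and the cutoff rule are hypothetical and not part of the enerzyme API. How enerzyme itself builds its neighbour lists is not described on this page.

```python
import math
from typing import List, Sequence, Tuple


def build_pairs(
    positions: Sequence[Tuple[float, float, float]], cutoff: float
) -> Tuple[List[int], List[int], List[float]]:
    """Hypothetical helper: return (idx_i, idx_j, distances) for all ordered
    pairs of distinct atoms closer than `cutoff`."""
    idx_i: List[int] = []
    idx_j: List[int] = []
    dist: List[float] = []
    n = len(positions)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # no self-interaction
            d = math.dist(positions[i], positions[j])
            if d < cutoff:
                # (i, j) and (j, i) are visited separately, so each
                # neighbouring pair is listed in both directions
                idx_i.append(i)
                idx_j.append(j)
                dist.append(d)
    return idx_i, idx_j, dist


# three atoms on a line, 1.0 apart; only adjacent atoms fall inside the cutoff
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
idx_i, idx_j, dist = build_pairs(pos, cutoff=1.5)
print(idx_i)  # [0, 1, 1, 2]
print(idx_j)  # [1, 0, 2, 1]
```

The quadratic double loop keeps the sketch short. Production code typically uses a cell list or a vectorised distance matrix instead.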
- Arguments:
- dim_embedding (int):
Dimension of the atomic feature space.
- num_rbf (int):
Number of radial basis functions.
- num_residual_x (int):
Number of residual blocks applied to the atomic features of the central atoms (i branch) before computing the interaction.
- num_residual_s (int), num_residual_p (int), num_residual_d (int):
Number of residual blocks applied to the atomic features of the neighbouring atoms (j branch) in the s-, p- and d-type interaction channels, respectively, before computing the interaction.
- num_residual (int):
Number of residual blocks applied to the interaction features.
- activation_fn (str):
Kind of activation function. Possible values: ‘swish’: Swish activation function. ‘shifted_softplus’: Shifted softplus activation function.
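The two `activation_fn` choices can be written out in plain Python. The sketch below uses the common textbook forms, swish(x) = x·σ(x) and shifted softplus ssp(x) = ln(1 + eˣ) − ln 2. This page does not say whether enerzyme's swish adds learnable scaling parameters, as some SpookyNet-style implementations do, so treat these exact forms as assumptions.

```python
import math


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0, so z lies in (0, 1) and cannot overflow
    return z / (1.0 + z)


def swish(x: float) -> float:
    """Plain Swish, x * sigmoid(x), without learnable parameters (assumed form)."""
    return x * _sigmoid(x)


def shifted_softplus(x: float) -> float:
    """Softplus shifted down by ln 2 so that shifted_softplus(0) == 0."""
    # softplus(x) = log(1 + exp(x)); for x > 0 rewrite it as
    # x + log1p(exp(-x)) so exp() never sees a large positive argument
    if x > 0:
        sp = x + math.log1p(math.exp(-x))
    else:
        sp = math.log1p(math.exp(x))
    return sp - math.log(2.0)


for v in (-2.0, 0.0, 2.0):
    print(f"{v:+.1f}  swish={swish(v):+.4f}  ssp={shifted_softplus(v):+.4f}")
```

Both functions are smooth. The ln 2 offset makes shifted softplus map a zero input to a zero output.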
- __init__(dim_embedding: int, num_rbf: int, num_residual_x: int, num_residual_s: int, num_residual_p: int, num_residual_d: int, num_residual: int, activation_fn: Literal['shifted_softplus', 'swish'] = 'swish') → None
Initializes the LocalInteraction class.
- forward(x: Tensor, rbf: Tensor, pij: Tensor, dij: Tensor, idx_i: Tensor, idx_j: Tensor) → Tensor
Evaluates the interaction block. N: number of atoms; P: number of atom pairs.
- x (FloatTensor [N, dim_embedding]):
Atomic feature vectors.
- rbf (FloatTensor [P, num_rbf]):
Values of the radial basis functions for the pairwise distances, one row per atom pair.
- idx_i (LongTensor [P]):
Index of atom i for all atomic pairs ij. Each pair must be specified as both ij and ji.
- idx_j (LongTensor [P]):
Same as idx_i, but for atom j.
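The pair-indexed inputs fit together in four steps. Each row of `rbf` belongs to one pair. A linear map turns that row into a radial filter. `idx_j` gathers the neighbour features, and `idx_i` sums the filtered messages onto the central atoms. The plain-Python sketch below shows only this scalar (s-type) channel. The weight matrix `w_s` and the helper names are hypothetical. The sketch also leaves out the residual blocks, the directional inputs `pij`/`dij` and any final projection, so it is not the enerzyme implementation.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear(v: Sequence[float], w: Matrix) -> List[float]:
    """Bias-free linear map: out[f] = sum_k v[k] * w[k][f]."""
    return [sum(v[k] * w[k][f] for k in range(len(v))) for f in range(len(w[0]))]


def s_channel_update(
    x: Matrix,              # [N, dim_embedding] atomic features
    rbf: Matrix,            # [P, num_rbf] radial basis values, one row per pair
    idx_i: Sequence[int],   # [P] central atom of each pair
    idx_j: Sequence[int],   # [P] neighbouring atom of each pair
    w_s: Matrix,            # [num_rbf, dim_embedding] hypothetical filter weights
) -> Matrix:
    """Return x_i + sum_j g_s(r_ij) * x_j for every atom i (scalar channel only)."""
    out = [row[:] for row in x]            # self term; copy so x is not mutated
    for p in range(len(idx_i)):
        g = linear(rbf[p], w_s)            # per-pair radial filter [dim_embedding]
        xj = x[idx_j[p]]                   # gather neighbour features
        i = idx_i[p]
        for f in range(len(g)):
            out[i][f] += g[f] * xj[f]      # scatter-add onto the central atom
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # N = 3, dim_embedding = 2
idx_i = [0, 1, 1, 2]                      # chain 0-1-2, each pair listed as ij and ji
idx_j = [1, 0, 2, 1]
rbf = [[1.0, 0.0] for _ in idx_i]         # P = 4, num_rbf = 2
w_s = [[0.5, 0.5], [0.0, 0.0]]
result = s_channel_update(x, rbf, idx_i, idx_j, w_s)
print(result)  # [[1.0, 0.5], [1.0, 1.5], [1.0, 1.5]]
```

In PyTorch, the explicit loop would typically be replaced by indexing with `idx_j` and a scatter-sum such as `Tensor.index_add_(0, idx_i, messages)`.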