enerzyme.models.spookynet.interaction.InteractionModule
- class enerzyme.models.spookynet.interaction.InteractionModule(dim_embedding: int, num_rbf: int, num_residual_pre: int, num_residual_local_x: int, num_residual_local_s: int, num_residual_local_p: int, num_residual_local_d: int, num_residual_local: int, num_residual_nonlocal_q: int, num_residual_nonlocal_k: int, num_residual_nonlocal_v: int, num_residual_post: int, num_residual_output: int, activation_fn: Literal['shifted_softplus', 'swish'] = 'swish')
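Bases: Module
InteractionModule of SpookyNet, which computes a single interaction iteration: it updates the latent atomic features and returns their contribution to the output features.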
- Arguments:
- dim_embedding (int):
Dimension of the atomic feature (embedding) space.
- num_rbf (int):
Number of radial basis functions.
- num_residual_pre (int):
Number of residual blocks applied to atomic features before interaction with neighbouring atoms.
- num_residual_post (int):
Number of residual blocks applied to atomic features after interaction with neighbouring atoms.
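- num_residual_local_x (int):
Number of residual blocks applied to atomic features in the x branch (central atom i) of the local interaction.
- num_residual_local_s (int):
Number of residual blocks applied to atomic features in the s-type branch of the local interaction, whose features are collected from neighbouring atoms j.
- num_residual_local_p (int):
Same as num_residual_local_s, but for the p-type branch.
- num_residual_local_d (int):
Same as num_residual_local_s, but for the d-type branch.
- num_residual_local (int):
Number of residual blocks applied to the combined local interaction features.
- num_residual_nonlocal_q (int):
Number of residual blocks applied to atomic features to compute the queries of the nonlocal interaction.
- num_residual_nonlocal_k (int):
Same as num_residual_nonlocal_q, but for the keys.
- num_residual_nonlocal_v (int):
Same as num_residual_nonlocal_q, but for the values.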
- num_residual_output (int):
Number of residual blocks applied to atomic features in output branch.
- activation_fn (str):
Kind of activation function. Possible values: 'swish' (Swish activation function, the default) and 'shifted_softplus' (shifted softplus activation function).
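The `num_residual_*` arguments control how many residual blocks each branch stacks, and `activation_fn` selects the nonlinearity used inside them. The sketch below illustrates these two ideas in isolation and is not enerzyme's implementation: it assumes a pre-activation residual block of the form x + W2(σ(W1(σ(x)))), as used in SpookyNet, and the textbook activations swish(x) = x·sigmoid(x) and shifted softplus(x) = softplus(x) - log(2); the library may add learnable parameters to either. `make_activation`, `ResidualBlock`, and `residual_stack` are hypothetical names.

```python
import math
from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn

ActivationKind = Literal["shifted_softplus", "swish"]


class ShiftedSoftplus(nn.Module):
    """softplus(x) - log(2), shifted so that f(0) == 0."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.softplus(x) - math.log(2.0)


def make_activation(kind: ActivationKind) -> nn.Module:
    """Map an activation name to a module (hypothetical helper)."""
    if kind == "swish":
        return nn.SiLU()  # x * sigmoid(x), no learnable parameters
    if kind == "shifted_softplus":
        return ShiftedSoftplus()
    raise ValueError(f"unknown activation: {kind!r}")


class ResidualBlock(nn.Module):
    """Pre-activation residual block: x + W2(act(W1(act(x))))."""

    def __init__(self, dim: int, kind: ActivationKind) -> None:
        super().__init__()
        self.act1 = make_activation(kind)
        self.linear1 = nn.Linear(dim, dim)
        self.act2 = make_activation(kind)
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.linear2(self.act2(self.linear1(self.act1(x))))


def residual_stack(dim: int, num_blocks: int, kind: ActivationKind = "swish") -> nn.Sequential:
    """Stack num_blocks residual blocks; zero blocks give an identity map."""
    return nn.Sequential(*[ResidualBlock(dim, kind) for _ in range(num_blocks)])


# Example: a branch with 2 residual blocks on 8-dimensional features.
stack = residual_stack(dim=8, num_blocks=2, kind="shifted_softplus")
x = torch.randn(5, 8)  # 5 atoms, 8 features each
print(stack(x).shape)  # torch.Size([5, 8])
```

Every block preserves the feature dimension, so the blocks can be stacked freely. In this sketch, a count of zero yields an empty `nn.Sequential`, which returns its input unchanged.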
- __init__(dim_embedding: int, num_rbf: int, num_residual_pre: int, num_residual_local_x: int, num_residual_local_s: int, num_residual_local_p: int, num_residual_local_d: int, num_residual_local: int, num_residual_nonlocal_q: int, num_residual_nonlocal_k: int, num_residual_nonlocal_v: int, num_residual_post: int, num_residual_output: int, activation_fn: Literal['shifted_softplus', 'swish'] = 'swish') → None
Initializes the InteractionModule class.
- forward(x: Tensor, rbf: Tensor, pij: Tensor, dij: Tensor, idx_i: Tensor, idx_j: Tensor, num_batch: int, batch_seg: Tensor, mask: Tensor | None = None) → Tuple[Tensor, Tensor]
Evaluate all modules in the block.
N: Number of atoms.
P: Number of atom pairs.
B: Batch size (number of different molecules).
- Arguments:
- x (FloatTensor [N, dim_embedding]):
Latent atomic feature vectors.
- rbf (FloatTensor [P, num_rbf]):
Values of the radial basis functions for the pairwise distances.
- pij (FloatTensor [P, ...]):
p-type (l = 1) directional features for all atomic pairs ij, used by the local interaction.
- dij (FloatTensor [P, ...]):
Same as pij, but d-type (l = 2).
- idx_i (LongTensor [P]):
Index of atom i for all atomic pairs ij. Each pair must be specified as both ij and ji.
- idx_j (LongTensor [P]):
Same as idx_i, but for atom j.
- num_batch (int):
Batch size (number of different molecules).
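- batch_seg (LongTensor [N]):
Index for each atom that specifies to which molecule in the batch it belongs.
- mask (Tensor, optional):
Optional mask passed to the nonlocal (attention) interaction. Defaults to None.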
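The pair and batch indices can be built as follows. This is a minimal sketch, assuming every atom interacts with every other atom of the same molecule (no distance cutoff); `build_pair_indices` is a hypothetical helper, not part of enerzyme. It lists each pair as both ij and ji, as `idx_i` and `idx_j` require, and then shows the two scatter patterns these indices enable: summing pair-level values onto the central atom with `idx_i`, and summing atom-level values per molecule with `batch_seg`.

```python
from typing import List, Tuple

import torch


def build_pair_indices(
    atoms_per_molecule: List[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Return idx_i, idx_j, batch_seg, num_batch for a fully connected batch."""
    idx_i: List[int] = []
    idx_j: List[int] = []
    batch_seg: List[int] = []
    offset = 0  # index of the first atom of the current molecule
    for mol, n in enumerate(atoms_per_molecule):
        batch_seg.extend([mol] * n)
        for a in range(n):
            for b in range(n):
                if a != b:  # ordered pairs: both (a, b) and (b, a) appear
                    idx_i.append(offset + a)
                    idx_j.append(offset + b)
        offset += n
    return (
        torch.tensor(idx_i, dtype=torch.long),
        torch.tensor(idx_j, dtype=torch.long),
        torch.tensor(batch_seg, dtype=torch.long),
        len(atoms_per_molecule),
    )


# Two molecules with 3 and 2 atoms: N = 5 atoms, P = 3*2 + 2*1 = 8 pairs, B = 2.
idx_i, idx_j, batch_seg, num_batch = build_pair_indices([3, 2])
N, P = batch_seg.numel(), idx_i.numel()

# Scatter pair-level values onto the central atom i (here: count neighbours).
pair_values = torch.ones(P)
per_atom = torch.zeros(N).index_add_(0, idx_i, pair_values)

# Sum atom-level values per molecule using batch_seg.
per_molecule = torch.zeros(num_batch).index_add_(0, batch_seg, per_atom)
print(per_atom)      # tensor([2., 2., 2., 1., 1.])
print(per_molecule)  # tensor([6., 2.])
```

Listing both directions means that a sum over `idx_i` reaches every atom of a pair, so each atom receives contributions from all of its neighbours.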
- Returns:
- x (FloatTensor [N, dim_embedding]):
Updated latent atomic feature vectors.
- y (FloatTensor [N, dim_embedding]):
Contribution to output atomic features (environment descriptors).
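Since `forward` returns both the updated features `x` and a contribution `y`, several interaction modules can be chained: `x` is passed to the next module and the `y` outputs are summed into the final atomic descriptors, which is how the original SpookyNet model combines its modules. The sketch below shows only that control flow. `ToyInteraction` is a hypothetical stand-in that copies the `forward` signature documented above; its internal update is arbitrary and unrelated to SpookyNet's local and nonlocal interactions, and the tensors passed as `rbf`, `pij`, and `dij` are unused placeholders.

```python
from typing import Optional, Tuple

import torch
from torch import nn


class ToyInteraction(nn.Module):
    """Hypothetical stand-in with the same forward signature as InteractionModule."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.update = nn.Linear(dim, dim)
        self.output = nn.Linear(dim, dim)

    def forward(
        self,
        x: torch.Tensor,
        rbf: torch.Tensor,
        pij: torch.Tensor,
        dij: torch.Tensor,
        idx_i: torch.Tensor,
        idx_j: torch.Tensor,
        num_batch: int,
        batch_seg: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Arbitrary toy update: sum neighbour features x[j] onto each atom i.
        messages = torch.zeros_like(x).index_add_(0, idx_i, x[idx_j])
        x = x + self.update(messages)
        y = self.output(x)  # this module's contribution to the output features
        return x, y


dim = 4
# One molecule with 3 atoms; every pair listed as both ij and ji (P = 6).
idx_i = torch.tensor([0, 1, 0, 2, 1, 2])
idx_j = torch.tensor([1, 0, 2, 0, 2, 1])
batch_seg = torch.zeros(3, dtype=torch.long)
num_batch = 1
rbf = torch.zeros(idx_i.numel(), 6)          # arbitrary width, unused by the toy
placeholder = torch.zeros(idx_i.numel(), 0)  # unused pij/dij stand-ins

modules = nn.ModuleList([ToyInteraction(dim) for _ in range(3)])
x = torch.randn(3, dim)
f = torch.zeros_like(x)  # accumulated atomic descriptors
for module in modules:
    x, y = module(x, rbf, placeholder, placeholder, idx_i, idx_j, num_batch, batch_seg)
    f = f + y  # sum the y contributions of all modules
print(f.shape)  # torch.Size([3, 4])
```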