enerzyme.models.spookynet.interaction.NonlocalInteraction

class enerzyme.models.spookynet.interaction.NonlocalInteraction(dim_embedding: int, num_residual_q: int, num_residual_k: int, num_residual_v: int, activation_fn: Literal['shifted_softplus', 'swish'] = 'swish')

Bases: Module

Block for updating atomic features through nonlocal interactions with all atoms.

Arguments:
dim_embedding (int):

Dimensions of the atomic feature (embedding) space.

num_residual_q (int):

Number of residual blocks applied to the atomic features to compute the queries.

num_residual_k (int):

Number of residual blocks applied to the atomic features to compute the keys.

num_residual_v (int):

Number of residual blocks applied to the atomic features to compute the values.

activation_fn (str):

Activation function to use. Possible values: ‘swish’ (Swish activation function, the default) or ‘shifted_softplus’ (shifted softplus activation function).
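For reference, the two names accepted by activation_fn conventionally denote the functions sketched below: swish(x) = x · sigmoid(x) and shifted_softplus(x) = ln(1 + eˣ) − ln 2. These are the standard textbook forms; the module itself may use parameterized variants (SpookyNet-style models, for example, can learn scaling parameters), so the function names here are illustrative only.

```python
import math


def swish(x: float) -> float:
    """Swish: x * sigmoid(x)."""
    # Numerically stable logistic sigmoid for large |x|.
    if x >= 0:
        sig = 1.0 / (1.0 + math.exp(-x))
    else:
        e = math.exp(x)
        sig = e / (1.0 + e)
    return x * sig


def shifted_softplus(x: float) -> float:
    """Shifted softplus: ln(1 + e^x) - ln(2), so that f(0) = 0."""
    # softplus(x) = max(x, 0) + log1p(exp(-|x|)) avoids overflow for large x.
    softplus = max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return softplus - math.log(2.0)


for x in (-2.0, 0.0, 2.0):
    print(f"{x:+.1f}  swish={swish(x):+.4f}  ssp={shifted_softplus(x):+.4f}")
```

The shift by ln 2 makes shifted softplus pass through the origin like swish does, so both keep zero inputs mapped to zero.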

__init__(dim_embedding: int, num_residual_q: int, num_residual_k: int, num_residual_v: int, activation_fn: Literal['shifted_softplus', 'swish'] = 'swish') -> None

Initializes the NonlocalInteraction class.

forward(x: Tensor, num_batch: int, batch_seg: Tensor, mask: Tensor | None = None) -> Tensor

Evaluates the interaction block. N denotes the number of atoms.

x (FloatTensor [N, dim_embedding]):

Atomic feature vectors.
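forward receives the atoms of all molecules in a batch stacked along the first dimension of x. The docstring does not describe num_batch and batch_seg; assuming they follow the usual SpookyNet convention (num_batch molecules, with batch_seg[i] giving the index of the molecule that atom i belongs to), they can be built from per-molecule atom counts as sketched below. make_batch_seg is a hypothetical helper, not part of enerzyme.

```python
from typing import List, Sequence, Tuple


def make_batch_seg(atoms_per_molecule: Sequence[int]) -> Tuple[int, List[int]]:
    """Return (num_batch, batch_seg) for molecules concatenated along the atom axis.

    Hypothetical helper: assumes batch_seg[i] is the index of the molecule
    that atom i belongs to, with molecules numbered 0 .. num_batch - 1.
    """
    if any(n <= 0 for n in atoms_per_molecule):
        raise ValueError("every molecule must contain at least one atom")
    batch_seg: List[int] = []
    for mol_index, n_atoms in enumerate(atoms_per_molecule):
        # Repeat the molecule index once per atom of that molecule.
        batch_seg.extend([mol_index] * n_atoms)
    return len(atoms_per_molecule), batch_seg


# A water molecule (3 atoms) followed by H2 (2 atoms): N = 5 rows in x.
num_batch, batch_seg = make_batch_seg([3, 2])
print(num_batch, batch_seg)  # 2 [0, 0, 0, 1, 1]
```

With PyTorch tensors, the same index tensor can be obtained with torch.repeat_interleave(torch.arange(num_batch), counts), where counts is a tensor of per-molecule atom counts.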