enerzyme.models.leftnet.core.EquiMessagePassing
- class enerzyme.models.leftnet.core.EquiMessagePassing(hidden_channels, num_radial, hidden_channels_chi=96, head: int = 16, chi1: int = 32, chi2: int = 8, has_dropout_flag: bool = False, has_norm_before_flag=True, has_norm_after_flag=False, reduce_mode='sum', device=device(type='cpu'))
Bases: MessagePassing
- __init__(hidden_channels, num_radial, hidden_channels_chi=96, head: int = 16, chi1: int = 32, chi2: int = 8, has_dropout_flag: bool = False, has_norm_before_flag=True, has_norm_after_flag=False, reduce_mode='sum', device=device(type='cpu'))
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- aggregate(features: Tuple[Tensor, Tensor], index: Tensor, ptr: Tensor | None, dim_size: int | None) → Tuple[Tensor, Tensor]
Aggregates messages from neighbors as \(\bigoplus_{j \in \mathcal{N}(i)}\).
Takes the output of message computation as its first argument, along with any argument that was initially passed to
propagate(). In the base MessagePassing class, this function delegates its call to the underlying
Aggregation module, which reduces messages as specified in __init__() by the aggr argument.
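Because this override receives features as a (scalar, vector) tuple rather than a single tensor, a plausible reading is that each part is reduced separately onto the target nodes given by index. The sketch below illustrates that in plain PyTorch. It assumes scalar messages of shape [E, F], vector messages of shape [E, 3, F], and that reduce_mode chooses between summing and (hypothetically) averaging; aggregate_sketch is a hypothetical name, not the enerzyme implementation.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


def aggregate_sketch(
    features: Tuple[Tensor, Tensor],
    index: Tensor,
    dim_size: Optional[int] = None,
    reduce_mode: str = "sum",
) -> Tuple[Tensor, Tensor]:
    """Hypothetical stand-in for aggregate(): reduce per-edge scalar and
    vector messages onto the target nodes listed in `index`."""
    x_msg, vec_msg = features  # assumed shapes: [E, F] and [E, 3, F]
    if dim_size is None:
        dim_size = int(index.max()) + 1
    # Add each edge's message into the row of its target node.
    x_out = x_msg.new_zeros((dim_size, *x_msg.shape[1:])).index_add_(0, index, x_msg)
    vec_out = vec_msg.new_zeros((dim_size, *vec_msg.shape[1:])).index_add_(0, index, vec_msg)
    if reduce_mode == "mean":
        # Divide by each node's in-degree; clamp so nodes without edges stay zero.
        deg = torch.bincount(index, minlength=dim_size).clamp(min=1).to(x_msg.dtype)
        x_out = x_out / deg.view(-1, *([1] * (x_out.dim() - 1)))
        vec_out = vec_out / deg.view(-1, *([1] * (vec_out.dim() - 1)))
    elif reduce_mode != "sum":
        raise ValueError(f"unsupported reduce_mode: {reduce_mode!r}")
    return x_out, vec_out


# Three edges whose target nodes are 1, 2 and 1; node 0 receives nothing.
index = torch.tensor([1, 2, 1])
x_sum, vec_sum = aggregate_sketch((torch.ones(3, 2), torch.ones(3, 3, 2)), index, dim_size=3)
print(x_sum[:, 0].tolist())  # [0.0, 2.0, 1.0]
```

Node 1 collects two messages and node 2 one, while node 0 keeps its zero initialization. Passing dim_size explicitly matters here: inferring it from index.max() would silently drop trailing nodes that have no incoming edges.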
- forward(x, vec, edge_index, edge_rbf, weight, edge_vector, rope)
Runs the forward pass of the module.
- message(xh_j, vec_j, rbfh_ij, r_ij)
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in
edge_index. This function can take as input any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
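The _i/_j convention amounts to indexing node tensors with the two rows of edge_index. Under PyG's default flow="source_to_target", row 0 holds the source node \(j\) and row 1 the target node \(i\). The hypothetical helper lift_to_edges mimics that gather for a single tensor; it is a sketch of the convention, not PyG's internal code. Arguments whose names do not end in _i or _j, such as rbfh_ij and r_ij in the signature above, are passed to message() unchanged, so they are presumably already per-edge tensors.

```python
from typing import Tuple

import torch
from torch import Tensor


def lift_to_edges(x: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:
    """Hypothetical helper: gather node features onto edges the way
    propagate() fills `x_i` and `x_j` under flow="source_to_target"."""
    src, dst = edge_index[0], edge_index[1]  # row 0: source j, row 1: target i
    x_j = x.index_select(0, src)  # features of the sending node j, one row per edge
    x_i = x.index_select(0, dst)  # features of the receiving node i, one row per edge
    return x_i, x_j


# Three nodes and three directed edges: 0 -> 1, 1 -> 2, 2 -> 1.
x = torch.tensor([[10.0], [20.0], [30.0]])
edge_index = torch.tensor([[0, 1, 2], [1, 2, 1]])
x_i, x_j = lift_to_edges(x, edge_index)
print(x_j.squeeze(1).tolist())  # [10.0, 20.0, 30.0]
print(x_i.squeeze(1).tolist())  # [20.0, 30.0, 20.0]
```

Both outputs have one row per edge, so a message function can combine them elementwise with other per-edge tensors before aggregate() scatters the results back onto the target nodes.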