enerzyme.models.leftnet.core.EquiMessagePassing

class enerzyme.models.leftnet.core.EquiMessagePassing(hidden_channels, num_radial, hidden_channels_chi=96, head: int = 16, chi1: int = 32, chi2: int = 8, has_dropout_flag: bool = False, has_norm_before_flag=True, has_norm_after_flag=False, reduce_mode='sum', device=device(type='cpu'))

Bases: MessagePassing

__init__(hidden_channels, num_radial, hidden_channels_chi=96, head: int = 16, chi1: int = 32, chi2: int = 8, has_dropout_flag: bool = False, has_norm_before_flag=True, has_norm_after_flag=False, reduce_mode='sum', device=device(type='cpu'))

Initialize internal Module state, shared by both nn.Module and ScriptModule.

aggregate(features: Tuple[Tensor, Tensor], index: Tensor, ptr: Tensor | None, dim_size: int | None) -> Tuple[Tensor, Tensor]

Aggregates messages from neighbors as \(\bigoplus_{j \in \mathcal{N}(i)}\).

Takes in the output of message computation as first argument and any argument which was initially passed to propagate().

By default, this function will delegate its call to the underlying Aggregation module to reduce messages as specified in __init__() by the aggr argument.
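The reduction described above can be illustrated with a small framework-free sketch. It assumes that reduce_mode='sum' means a per-target-node sum and that each member of the features tuple is reduced independently; the source does not confirm either point. The name sum_aggregate and the nested-list layout are hypothetical stand-ins for the tensor-based method, not part of enerzyme.

```python
from typing import List, Optional, Tuple

Rows = List[List[float]]


def sum_aggregate(
    features: Tuple[Rows, Rows],
    index: List[int],
    dim_size: Optional[int] = None,
) -> Tuple[Rows, Rows]:
    """Sum per-edge rows into per-node rows, separately for each tuple member.

    Hypothetical helper: mirrors the shape of aggregate() with plain lists.
    """
    if dim_size is None:
        # Fall back to the largest target index seen on any edge.
        dim_size = max(index) + 1 if index else 0

    def reduce(rows: Rows) -> Rows:
        width = len(rows[0]) if rows else 0
        out = [[0.0] * width for _ in range(dim_size)]
        for row, target in zip(rows, index):
            for k, value in enumerate(row):
                out[target][k] += value
        return out

    scalar_msgs, vector_msgs = features
    return reduce(scalar_msgs), reduce(vector_msgs)


# Three edges whose target nodes are 0, 0 and 1; node 2 receives no message.
scalar_msgs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
vector_msgs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
index = [0, 0, 1]

x_out, vec_out = sum_aggregate((scalar_msgs, vector_msgs), index, dim_size=3)
```

Passing dim_size explicitly keeps a row for nodes that receive no messages, which is why node 2 ends up with all zeros here.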

forward(x, vec, edge_index, edge_rbf, weight, edge_vector, rope)

Runs the forward pass of the module.

message(xh_j, vec_j, rbfh_ij, r_ij)

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
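The _i / _j naming convention amounts to an index lookup per edge. The sketch below assumes the default "source_to_target" flow, in which the first row of edge_index holds the source nodes \(j\) and the second row holds the target nodes \(i\); whether this class changes the flow is not stated here. The helper gather_pairs is hypothetical and uses plain lists in place of tensors.

```python
from typing import List, Tuple

Rows = List[List[float]]


def gather_pairs(x: Rows, edge_index: List[List[int]]) -> Tuple[Rows, Rows]:
    """Return (x_i, x_j): per-edge features of the target and source nodes.

    Hypothetical helper; assumes edge_index = [sources, targets].
    """
    sources, targets = edge_index
    x_j = [x[j] for j in sources]  # features of the sending node of each edge
    x_i = [x[i] for i in targets]  # features of the receiving node of each edge
    return x_i, x_j


# Three nodes with one feature each, and edges 0 -> 1, 1 -> 2, 2 -> 0.
x = [[10.0], [20.0], [30.0]]
edge_index = [[0, 1, 2], [1, 2, 0]]

x_i, x_j = gather_pairs(x, edge_index)
```

Under this convention an argument such as xh_j in message() would hold one row per edge, taken from the source node of that edge.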

reset_parameters()

Resets all learnable parameters of the module.

update(inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().