Source code for enerzyme.models.schnet.interaction

from math import pi as PI
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module, Sequential, Linear, init
from torch_geometric.nn import MessagePassing

from ..activation import ACTIVATION_KEY_TYPE, ACTIVATION_PARAM_TYPE, get_activation_fn


class InteractionBlock(Module):
    """SchNet-style interaction block: ``CFConv -> activation -> Linear``.

    A filter network (``Linear -> activation -> Linear``) maps the per-edge
    ``edge_attr`` features (``num_gaussians`` columns) to ``num_filters``
    filter channels. The same ``Sequential`` object is handed to the inner
    :class:`CFConv` and also kept as ``self.mlp``, so
    :meth:`reset_parameters` can reinitialise it.

    Args:
        hidden_channels: Size of the node feature vectors (input and output).
        num_gaussians: Number of per-edge features in ``edge_attr`` that are
            fed to the filter network (presumably a Gaussian distance
            expansion -- confirm against the caller).
        num_filters: Number of filter channels used inside the convolution.
        cutoff: Cutoff forwarded to :class:`CFConv`, where it defines the
            cosine envelope applied to ``edge_weight``.
        activation_fn: Activation key understood by ``get_activation_fn``.
        activation_params: Parameters passed to ``get_activation_fn``.
            ``None`` (the default) is replaced by a fresh empty dict, so no
            mutable default is shared between instances.
    """

    def __init__(
        self,
        hidden_channels: int,
        num_gaussians: int,
        num_filters: int,
        cutoff: float,
        activation_fn: ACTIVATION_KEY_TYPE = "shifted_softplus",
        activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
    ):
        super().__init__()
        # Avoid a shared mutable default: build a new dict per construction.
        params = {} if activation_params is None else activation_params
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            get_activation_fn(activation_fn, params),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters, self.mlp, cutoff)
        # A separate activation instance for the node-update path.
        self.act = get_activation_fn(activation_fn, params)
        self.lin = Linear(hidden_channels, hidden_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform all weights and zero all biases in this block.

        Indices 0 and 2 of ``self.mlp`` are its two ``Linear`` layers; index 1
        is the activation. ``CFConv.reset_parameters`` does not touch the
        filter network, which is why it is reset here explicitly.
        """
        init.xavier_uniform_(self.mlp[0].weight)
        init.zeros_(self.mlp[0].bias)
        init.xavier_uniform_(self.mlp[2].weight)
        init.zeros_(self.mlp[2].bias)
        self.conv.reset_parameters()
        init.xavier_uniform_(self.lin.weight)
        init.zeros_(self.lin.bias)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor) -> Tensor:
        """Apply the interaction block to node features.

        Args:
            x: Node features whose last dimension is ``hidden_channels``.
            edge_index: Graph connectivity in PyG ``edge_index`` format.
            edge_weight: One scalar per edge, compared against ``cutoff``
                inside :class:`CFConv` (presumably interatomic distances).
            edge_attr: Per-edge features with ``num_gaussians`` columns.

        Returns:
            Updated node features with last dimension ``hidden_channels``.
        """
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x
class CFConv(MessagePassing):
    """Continuous-filter convolution (SchNet) expressed as PyG message passing.

    Node features are projected to ``num_filters`` channels by ``lin1``. Each
    edge's filter is produced by ``nn`` from ``edge_attr`` and scaled by a
    cosine cutoff envelope on ``edge_weight``. The filter is multiplied
    elementwise with the source node's projected features, messages are
    summed (``aggr='add'``), and the result is projected to ``out_channels``
    by ``lin2``.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the output node features.
        num_filters: Number of filter channels.
        nn: Filter-generating network applied to ``edge_attr``; its output
            must have ``num_filters`` columns. Its parameters are NOT reset by
            :meth:`reset_parameters`; the owner is responsible for that.
        cutoff: Cutoff of the cosine envelope. Edges with
            ``edge_weight >= cutoff`` contribute nothing.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_filters: int,
        nn: Sequential,
        cutoff: float,
    ):
        super().__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn
        self.cutoff = cutoff
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform both projection weights and zero ``lin2``'s bias."""
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.zeros_(self.lin2.bias)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor) -> Tensor:
        """Run the continuous-filter convolution.

        Args:
            x: Node features with ``in_channels`` in the last dimension.
            edge_index: Graph connectivity in PyG ``edge_index`` format.
            edge_weight: One scalar per edge (presumably a distance), used
                for the cosine cutoff envelope.
            edge_attr: Per-edge features consumed by ``nn``.

        Returns:
            Node features with ``out_channels`` in the last dimension.
        """
        # Cosine envelope 0.5 * (cos(pi * r / r_c) + 1). cos is periodic, so
        # without the mask an edge with r_c < r < 2 r_c would receive a
        # rising, non-zero weight; zero everything at or beyond the cutoff.
        # For edges inside the cutoff the value is unchanged.
        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
        C = C * (edge_weight < self.cutoff).to(C.dtype)
        # One envelope value per edge, broadcast across filter channels.
        W = self.nn(edge_attr) * C.view(-1, 1)
        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j: Tensor, W: Tensor) -> Tensor:
        """Edge message: source-node features gated elementwise by the filter."""
        return x_j * W