from math import pi as PI
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module, Sequential, Linear, init
from torch_geometric.nn import MessagePassing

from ..activation import ACTIVATION_KEY_TYPE, ACTIVATION_PARAM_TYPE, get_activation_fn
class InteractionBlock(Module):
    """Interaction block: continuous-filter convolution -> activation -> linear.

    Per-edge filters are produced from ``edge_attr`` by a two-layer MLP
    (``num_gaussians -> num_filters -> num_filters``) and consumed by a
    :class:`CFConv` that applies a cosine cutoff at ``cutoff``. The node
    feature width stays ``hidden_channels`` from input to output.

    Args:
        hidden_channels: Width of the node features (input and output).
        num_gaussians: Size of the last dimension of ``edge_attr``
            (presumably a Gaussian expansion of edge distances -- confirm
            against the caller).
        num_filters: Width of the filter space inside the convolution.
        cutoff: Cutoff distance forwarded to :class:`CFConv`.
        activation_fn: Activation key resolved by ``get_activation_fn``;
            the resulting activation is used both inside the filter MLP and
            after the convolution.
        activation_params: Parameters forwarded to ``get_activation_fn``
            together with ``activation_fn``. ``None`` (the default) is
            treated as an empty dict.
    """

    def __init__(self, hidden_channels: int, num_gaussians: int,
                 num_filters: int, cutoff: float,
                 activation_fn: ACTIVATION_KEY_TYPE = "shifted_softplus",
                 activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
                 ):
        super().__init__()
        # Build a fresh dict per call; a literal default would be one
        # object shared by every InteractionBlock ever constructed.
        if activation_params is None:
            activation_params = {}
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            get_activation_fn(activation_fn, activation_params),
            Linear(num_filters, num_filters),
        )
        # The filter MLP is handed to the convolution, which does not
        # initialise it itself; reset_parameters below takes care of it.
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters,
                           self.mlp, cutoff)
        self.act = get_activation_fn(activation_fn, activation_params)
        self.lin = Linear(hidden_channels, hidden_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all linear layers (Xavier-uniform weights, zero biases).

        Covers both linear layers of the filter MLP, the convolution's own
        projections (via ``CFConv.reset_parameters``) and the output layer.
        """
        # Indices 0 and 2 are the Linear layers; index 1 is the activation.
        for layer in (self.mlp[0], self.mlp[2]):
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)
        self.conv.reset_parameters()
        init.xavier_uniform_(self.lin.weight)
        init.zeros_(self.lin.bias)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,
                edge_attr: Tensor) -> Tensor:
        """Return updated node features with last dimension ``hidden_channels``.

        Args:
            x: Node features; last dimension must be ``hidden_channels``.
            edge_index: Connectivity passed on to ``CFConv``/``propagate``
                (presumably PyG's ``(2, num_edges)`` layout).
            edge_weight: Per-edge scalar used for the cosine cutoff
                (presumably distances, in the same units as ``cutoff``).
            edge_attr: Per-edge features fed to the filter MLP; last
                dimension must be ``num_gaussians``.
        """
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        return self.lin(x)
class CFConv(MessagePassing):
    """Continuous-filter convolution with a cosine distance cutoff.

    Steps:

    1. Project node features to ``num_filters`` channels (``lin1``, no bias).
    2. Multiply each gathered source feature by a per-edge filter. The
       filter is ``nn(edge_attr)``, damped by a cosine cutoff of
       ``edge_weight``.
    3. Sum the messages per node (``aggr='add'``).
    4. Project the result to ``out_channels`` (``lin2``).

    Args:
        in_channels: Size of the last dimension of the input node features.
        out_channels: Size of the last dimension of the output features.
        num_filters: Width of the message/filter space.
        nn: Module mapping ``edge_attr`` to per-edge filters. Its output is
            multiplied element-wise with the ``num_filters``-wide projected
            features. Its parameters are NOT re-initialised by
            :meth:`reset_parameters`; its owner must initialise them.
        cutoff: Distance at which the cosine cutoff reaches zero. Must be in
            the same units as ``edge_weight``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_filters: int,
        nn: Sequential,
        cutoff: float,
    ):
        super().__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn
        self.cutoff = cutoff
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``lin1``/``lin2`` (Xavier-uniform weights, zero bias).

        ``self.nn`` is intentionally left untouched.
        """
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.zeros_(self.lin2.bias)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,
                edge_attr: Tensor) -> Tensor:
        """Run one convolution and return features with ``out_channels`` width.

        Args:
            x: Node features; last dimension must be ``in_channels``.
            edge_index: Connectivity forwarded to ``propagate``
                (presumably PyG's ``(2, num_edges)`` layout).
            edge_weight: One scalar per edge used for the cutoff
                (presumably distances); flattened with ``view(-1, 1)``.
            edge_attr: Per-edge input to ``self.nn``.
        """
        # Cosine cutoff: 1 at distance 0, decaying smoothly to 0 at
        # `cutoff`. The cosine is periodic, so without the mask, edges
        # beyond the cutoff would get non-zero weights again; the mask
        # pins them to 0. Values inside the cutoff are unaffected.
        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
        C = C * (edge_weight < self.cutoff).to(C.dtype)
        W = self.nn(edge_attr) * C.view(-1, 1)
        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j: Tensor, W: Tensor) -> Tensor:
        """Scale each gathered source-node feature ``x_j`` by its edge filter ``W``."""
        return x_j * W