from abc import ABC, abstractmethod
from typing import Optional, List, Tuple, Dict, Union, Callable
import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList, Parameter, ParameterList
import torch.nn.functional as F
from opt_einsum_fx import optimize_einsums_full
from e3nn import o3
from e3nn.nn import FullyConnectedNet, Activation
from e3nn.o3 import Irreps, TensorProduct, FullyConnectedTensorProduct
from e3nn.util.jit import compile_mode
from torch_scatter import scatter_sum
from ..irreps_tools import tp_out_irreps_with_instructions, reshape_irreps, U_matrix_real, linear_out_irreps
# Batch dimension of the random example inputs used when optimizing the
# einsum graphs built in ``Contraction``.
BATCH_EXAMPLE = 10
# Index letters used to build the einsum subscript strings in ``Contraction``.
ALPHABET = list("wxvnzrtyuops")
@compile_mode("script")
class Contraction(Module):
    """Symmetric contraction of node features for a single output irrep.

    At construction time a ``U_matrix_{nu}`` buffer is registered (built by
    ``U_matrix_real``) for every order ``nu`` in ``1..correlation``, and one
    weight tensor of shape ``(num_elements, num_params, num_features)`` is
    created per order. The einsum expressions used in ``forward`` are traced
    with ``torch.fx`` and optimized once with ``optimize_einsums_full`` using
    random example inputs.

    ``forward`` starts with the highest order (``graph_opt_main``). For every
    lower order it then adds a weighted U-tensor to the running result and
    contracts once more with the features.
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irrep_out: Irreps,
        correlation: int,
        internal_weights: bool = True,
        num_elements: Optional[int] = None,
        weights: Optional[Tensor] = None,
    ) -> None:
        """Register the U-tensors, build the optimized einsum graphs and weights.

        Args:
            irreps_in: input irreps. ``num_features`` is taken from
                ``irreps_in.count((0, 1))``. The coupling irreps are the irreps
                of ``irreps_in`` with multiplicities dropped.
            irrep_out: the output irrep. ``2 * irrep_out.lmax + 1`` sets the
                equivariant dimension of the example U-tensors.
            correlation: highest correlation order to build (must be >= 1).
            internal_weights: if False, the weights are taken from ``weights``:
                ``weights[:-1]`` for the lower orders and ``weights[-1]`` for the
                highest order.
            num_elements: number of element types; first axis of every weight.
            weights: externally supplied weights. Only used when
                ``internal_weights`` is False.

        Raises:
            ValueError: if ``internal_weights`` is False and ``weights`` is None.
        """
        super().__init__()
        if not internal_weights and weights is None:
            raise ValueError(
                "Contraction: `weights` must be provided when `internal_weights` is False"
            )

        self.num_features = irreps_in.count((0, 1))
        self.coupling_irreps = Irreps([irrep.ir for irrep in irreps_in])
        self.correlation = correlation
        dtype = torch.get_default_dtype()
        for nu in range(1, correlation + 1):
            U_matrix = U_matrix_real(
                irreps_in=self.coupling_irreps,
                irreps_out=irrep_out,
                correlation=nu,
                dtype=dtype,
            )[-1]
            self.register_buffer(f"U_matrix_{nu}", U_matrix)

        # Tensor contraction equations
        self.contractions_weighting = ModuleList()
        self.contractions_features = ModuleList()

        # Create weight for product basis
        self.weights = ParameterList([])

        # Walk from the highest order down to 1; the highest order gets its own
        # fused graph (graph_opt_main), the others a weighting + features pair.
        for i in range(correlation, 0, -1):
            # Shapes defining the example inputs below
            num_params = self.U_tensors(i).size()[-1]
            num_equivariance = 2 * irrep_out.lmax + 1
            num_ell = self.U_tensors(i).size()[-2]

            if i == correlation:
                parse_subscript_main = (
                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
                    + ["ik,ekc,bci,be -> bc"]
                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
                )
                graph_module_main = torch.fx.symbolic_trace(
                    lambda x, y, w, z: torch.einsum(
                        "".join(parse_subscript_main), x, y, w, z
                    )
                )

                # Optimizing the contractions
                self.graph_opt_main = optimize_einsums_full(
                    model=graph_module_main,
                    example_inputs=(
                        torch.randn(
                            [num_equivariance] + [num_ell] * i + [num_params]
                        ).squeeze(0),
                        torch.randn((num_elements, num_params, self.num_features)),
                        torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
                        torch.randn((BATCH_EXAMPLE, num_elements)),
                    ),
                )
                # Parameters for the product basis
                w = Parameter(
                    torch.randn((num_elements, num_params, self.num_features))
                    / num_params
                )
                self.weights_max = w
            else:
                # Generate optimized contractions equations
                parse_subscript_weighting = (
                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
                    + ["k,ekc,be->bc"]
                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
                )
                parse_subscript_features = (
                    ["bc"]
                    + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
                    + ["i,bci->bc"]
                    + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
                )

                # Symbolic tracing of contractions
                graph_module_weighting = torch.fx.symbolic_trace(
                    lambda x, y, z: torch.einsum(
                        "".join(parse_subscript_weighting), x, y, z
                    )
                )
                graph_module_features = torch.fx.symbolic_trace(
                    lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
                )

                # Optimizing the contractions
                graph_opt_weighting = optimize_einsums_full(
                    model=graph_module_weighting,
                    example_inputs=(
                        torch.randn(
                            [num_equivariance] + [num_ell] * i + [num_params]
                        ).squeeze(0),
                        torch.randn((num_elements, num_params, self.num_features)),
                        torch.randn((BATCH_EXAMPLE, num_elements)),
                    ),
                )
                graph_opt_features = optimize_einsums_full(
                    model=graph_module_features,
                    example_inputs=(
                        torch.randn(
                            [BATCH_EXAMPLE, self.num_features, num_equivariance]
                            + [num_ell] * i
                        ).squeeze(2),
                        torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
                    ),
                )
                self.contractions_weighting.append(graph_opt_weighting)
                self.contractions_features.append(graph_opt_features)
                # Parameters for the product basis
                w = Parameter(
                    torch.randn((num_elements, num_params, self.num_features))
                    / num_params
                )
                self.weights.append(w)

        if not internal_weights:
            # ``weights`` (a ParameterList) and ``weights_max`` (a Parameter) are
            # already registered on this module. nn.Module.__setattr__ refuses to
            # overwrite a registered submodule/parameter with a plain tensor or
            # list, so unregister them before installing the external weights.
            del self.weights
            del self.weights_max
            self.weights = weights[:-1]
            self.weights_max = weights[-1]

    def forward(self, x: Tensor, y: Tensor):
        """Contract ``x`` with the U-tensors and the element-selected weights.

        Args:
            x: node features. The example inputs used at construction have shape
                ``(batch, num_features, num_ell)``; presumably the same at runtime.
            y: per-node element encoding. The example shape is
                ``(batch, num_elements)`` (presumably one-hot — TODO confirm).

        Returns:
            The contraction result reshaped to ``(out.shape[0], -1)``.
        """
        out = self.graph_opt_main(
            self.U_tensors(self.correlation),
            self.weights_max,
            x,
            y,
        )
        for i, (weight, contract_weights, contract_features) in enumerate(
            zip(self.weights, self.contractions_weighting, self.contractions_features)
        ):
            c_tensor = contract_weights(
                self.U_tensors(self.correlation - i - 1),
                weight,
                y,
            )
            c_tensor = c_tensor + out
            out = contract_features(c_tensor, x)

        return out.view(out.shape[0], -1)

    def U_tensors(self, nu: int):
        """Return the registered ``U_matrix_{nu}`` buffer."""
        return dict(self.named_buffers())[f"U_matrix_{nu}"]
@compile_mode("script")
class SymmetricContraction(nn.Module):
    """Apply one ``Contraction`` per output irrep and concatenate the results.

    Args:
        irreps_in: irreps of the input node features.
        irreps_out: output irreps; one ``Contraction`` is built per entry.
        correlation: a single int used for every output irrep, or a dict giving
            the correlation order per output irrep. Dict keys may be the entries
            of ``Irreps(irreps_out)`` or their string form (e.g. ``"32x0e"``).
        irrep_normalization: must be "component", "norm" or "none" (``None``
            means "component"). It is only validated in this class.
        path_normalization: must be "element", "path" or "none" (``None`` means
            "element"). It is only validated in this class.
        internal_weights: forwarded to each ``Contraction`` (``None`` means True).
        shared_weights: stored on the module and passed on as ``weights``.
        num_elements: number of element types, forwarded to each ``Contraction``.
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        correlation: Union[int, Dict[str, int]],
        irrep_normalization: Optional[str] = "component",
        path_normalization: Optional[str] = "element",
        internal_weights: Optional[bool] = None,
        shared_weights: Optional[bool] = None,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()

        if irrep_normalization is None:
            irrep_normalization = "component"

        if path_normalization is None:
            path_normalization = "element"

        assert irrep_normalization in ["component", "norm", "none"]
        assert path_normalization in ["element", "path", "none"]

        self.irreps_in = Irreps(irreps_in)
        self.irreps_out = Irreps(irreps_out)

        del irreps_in, irreps_out

        # Normalise ``correlation`` to a dict keyed by the entries of irreps_out.
        # (The previous ``isinstance(correlation, tuple)`` check wrapped a dict
        # argument inside another dict, which then crashed in ``Contraction``.)
        if isinstance(correlation, dict):
            per_irrep = {}
            for irrep_out in self.irreps_out:
                if irrep_out in correlation:
                    per_irrep[irrep_out] = correlation[irrep_out]
                else:
                    per_irrep[irrep_out] = correlation[str(irrep_out)]
            correlation = per_irrep
        else:
            corr = correlation
            correlation = {irrep_out: corr for irrep_out in self.irreps_out}

        assert shared_weights or not internal_weights

        if internal_weights is None:
            internal_weights = True

        self.internal_weights = internal_weights
        self.shared_weights = shared_weights

        del internal_weights, shared_weights

        self.contractions = ModuleList()
        for irrep_out in self.irreps_out:
            self.contractions.append(
                Contraction(
                    irreps_in=self.irreps_in,
                    irrep_out=Irreps(str(irrep_out.ir)),
                    correlation=correlation[irrep_out],
                    internal_weights=self.internal_weights,
                    num_elements=num_elements,
                    # NOTE(review): this passes the Optional[bool] shared_weights
                    # flag where Contraction expects a weights tensor; it is only
                    # read by Contraction when internal_weights is False.
                    weights=self.shared_weights,
                )
            )

    def forward(self, x: Tensor, y: Tensor):
        """Run every per-irrep contraction on ``(x, y)`` and concatenate on the last axis."""
        outs = [contraction(x, y) for contraction in self.contractions]
        return torch.cat(outs, dim=-1)
@compile_mode("script")
class EquivariantProductBasisBlock(nn.Module):
    """Symmetric contraction followed by an equivariant linear map.

    If ``use_sc`` is set and a skip tensor ``sc`` is supplied to ``forward``,
    that tensor is added to the output.
    """

    def __init__(
        self,
        node_feats_irreps: Irreps,
        target_irreps: Irreps,
        correlation: int,
        use_sc: bool = True,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.use_sc = use_sc
        self.symmetric_contractions = SymmetricContraction(
            irreps_in=node_feats_irreps,
            irreps_out=target_irreps,
            correlation=correlation,
            num_elements=num_elements,
        )
        # Shared, internally-weighted linear on the target irreps.
        self.linear = o3.Linear(
            target_irreps, target_irreps, internal_weights=True, shared_weights=True
        )

    def forward(
        self,
        node_feats: Tensor,
        sc: Optional[Tensor],
        node_attrs: Tensor,
    ) -> Tensor:
        """Contract ``node_feats`` with ``node_attrs``, apply the linear, add ``sc`` if enabled."""
        contracted = self.symmetric_contractions(node_feats, node_attrs)
        result = self.linear(contracted)
        if self.use_sc and sc is not None:
            result = result + sc
        return result
@compile_mode("script")
class TensorProductWeightsBlock(Module):
    """Element-dependent linear map from edge features to tensor-product weights.

    Holds one Xavier-initialised weight tensor of shape
    ``(num_elements, num_edge_feats, num_feats_out)``.
    """

    def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int):
        super().__init__()
        weight_shape = (num_elements, num_edge_feats, num_feats_out)
        initial = torch.empty(weight_shape, dtype=torch.get_default_dtype())
        torch.nn.init.xavier_uniform_(initial)
        self.weights = Parameter(initial)

    def forward(
        self,
        sender_or_receiver_node_attrs: Tensor,  # assumes that the node attributes are one-hot encoded
        edge_feats: Tensor,
    ):
        """Compute ``bk = sum_{e,a} edge_feats[b, e] * attrs[b, a] * weights[a, e, k]``."""
        equation = "be, ba, aek -> bk"
        return torch.einsum(
            equation, edge_feats, sender_or_receiver_node_attrs, self.weights
        )

    def __repr__(self):
        dims = self.weights.shape
        shape_text = ", ".join(str(s) for s in dims)
        n_weights = np.prod(dims)
        return f"{self.__class__.__name__}(shape=({shape_text}), " f"weights={n_weights})"
@compile_mode("script")
class InteractionBlock(ABC, Module):
    """Abstract base for message-passing interaction blocks.

    Stores the irreps configuration and hyper-parameters on the instance, then
    calls ``_setup`` so that subclasses can build their sub-modules from them.
    """

    def __init__(
        self,
        node_attrs_irreps: Irreps,
        node_feats_irreps: Irreps,
        edge_attrs_irreps: Irreps,
        edge_feats_irreps: Irreps,
        target_irreps: Irreps,
        hidden_irreps: Irreps,
        avg_num_neighbors: float,
        radial_MLP: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        self.node_attrs_irreps = node_attrs_irreps  # one hot atom type
        self.node_feats_irreps = node_feats_irreps  # atom embedding
        self.edge_attrs_irreps = edge_attrs_irreps  # spherical harmonics
        self.edge_feats_irreps = edge_feats_irreps  # radial basis functions
        self.target_irreps = target_irreps
        self.hidden_irreps = hidden_irreps
        self.avg_num_neighbors = avg_num_neighbors
        # Hidden layer widths of the radial MLP; default is three layers of 64.
        self.radial_MLP = [64, 64, 64] if radial_MLP is None else radial_MLP
        self._setup()

    @abstractmethod
    def _setup(self) -> None:
        """Build the sub-modules; called at the end of ``__init__``."""
        ...

    @abstractmethod
    def forward(
        self,
        node_attrs: Tensor,
        node_feats: Tensor,
        edge_attrs: Tensor,
        edge_feats: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Compute updated node features from node/edge inputs and edge indices."""
        ...
@compile_mode("script")
class ResidualElementDependentInteractionBlock(InteractionBlock):
    """Interaction with element-dependent convolution weights and a residual skip.

    The per-edge tensor-product weights come from ``TensorProductWeightsBlock``
    applied to the sender's node attributes and the edge features. The skip path
    is a fully connected tensor product of node features with node attributes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _setup(self) -> None:
        feats_irreps = self.node_feats_irreps
        # Channel-mixing linear applied before the convolution.
        self.linear_up = o3.Linear(
            feats_irreps, feats_irreps, internal_weights=True, shared_weights=True
        )
        # Convolution tensor product; its weights are supplied per edge.
        mid_irreps, tp_instructions = tp_out_irreps_with_instructions(
            feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = TensorProduct(
            feats_irreps,
            self.edge_attrs_irreps,
            mid_irreps,
            instructions=tp_instructions,
            shared_weights=False,
            internal_weights=False,
        )
        self.conv_tp_weights = TensorProductWeightsBlock(
            num_elements=self.node_attrs_irreps.num_irreps,
            num_edge_feats=self.edge_feats_irreps.num_irreps,
            num_feats_out=self.conv_tp.weight_numel,
        )
        # Output linear from the simplified intermediate irreps.
        mid_irreps = mid_irreps.simplify()
        self.irreps_out = linear_out_irreps(mid_irreps, self.target_irreps).simplify()
        self.linear = o3.Linear(
            mid_irreps, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Residual: node features x node attributes -> output irreps.
        self.skip_tp = FullyConnectedTensorProduct(
            feats_irreps, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: Tensor,
        node_feats: Tensor,
        edge_attrs: Tensor,
        edge_feats: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
    ) -> Tensor:
        """Aggregate messages sent along ``idx_i_sr`` -> ``idx_j_sr`` and add the skip path."""
        n_nodes = node_feats.shape[0]
        skip = self.skip_tp(node_feats, node_attrs)
        lifted = self.linear_up(node_feats)
        edge_weights = self.conv_tp_weights(node_attrs[idx_i_sr], edge_feats)
        # [n_edges, irreps]
        edge_messages = self.conv_tp(lifted[idx_i_sr], edge_attrs, edge_weights)
        # [n_nodes, irreps]
        aggregated = scatter_sum(
            src=edge_messages, index=idx_j_sr, dim=0, dim_size=n_nodes
        )
        return self.linear(aggregated) / self.avg_num_neighbors + skip
@compile_mode("script")
class AgnosticNonlinearInteractionBlock(InteractionBlock):
    """Interaction whose convolution weights come from a radial MLP on edge features.

    After aggregation and the output linear, the node features are mixed with the
    node attributes through a fully connected tensor product (no residual add).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _setup(self) -> None:
        feats_irreps = self.node_feats_irreps
        self.linear_up = o3.Linear(
            feats_irreps, feats_irreps, internal_weights=True, shared_weights=True
        )
        # Convolution tensor product with externally supplied (per-edge) weights.
        mid_irreps, tp_instructions = tp_out_irreps_with_instructions(
            feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = TensorProduct(
            feats_irreps,
            self.edge_attrs_irreps,
            mid_irreps,
            instructions=tp_instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Radial MLP: edge features -> tensor-product weights.
        layer_sizes = [
            self.edge_feats_irreps.num_irreps,
            *self.radial_MLP,
            self.conv_tp.weight_numel,
        ]
        self.conv_tp_weights = FullyConnectedNet(layer_sizes, F.silu)
        mid_irreps = mid_irreps.simplify()
        self.irreps_out = linear_out_irreps(mid_irreps, self.target_irreps).simplify()
        self.linear = o3.Linear(
            mid_irreps, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Element selector applied to the aggregated messages.
        self.skip_tp = FullyConnectedTensorProduct(
            self.irreps_out, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: Tensor,
        node_feats: Tensor,
        edge_attrs: Tensor,
        edge_feats: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
    ) -> Tensor:
        """Aggregate messages and mix the result with ``node_attrs``."""
        n_nodes = node_feats.shape[0]
        edge_weights = self.conv_tp_weights(edge_feats)
        lifted = self.linear_up(node_feats)
        # [n_edges, irreps]
        edge_messages = self.conv_tp(lifted[idx_i_sr], edge_attrs, edge_weights)
        # [n_nodes, irreps]
        aggregated = scatter_sum(
            src=edge_messages, index=idx_j_sr, dim=0, dim_size=n_nodes
        )
        normalised = self.linear(aggregated) / self.avg_num_neighbors
        return self.skip_tp(normalised, node_attrs)  # [n_nodes, irreps]
@compile_mode("script")
class AgnosticResidualNonlinearInteractionBlock(InteractionBlock):
    """Radial-MLP interaction with a residual skip from the input node features.

    The skip path is a fully connected tensor product of the incoming node
    features with the node attributes, added to the normalised messages.
    """

    def _setup(self) -> None:
        feats_irreps = self.node_feats_irreps
        # First linear
        self.linear_up = o3.Linear(
            feats_irreps, feats_irreps, internal_weights=True, shared_weights=True
        )
        mid_irreps, tp_instructions = tp_out_irreps_with_instructions(
            feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = TensorProduct(
            feats_irreps,
            self.edge_attrs_irreps,
            mid_irreps,
            instructions=tp_instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Radial MLP producing the per-edge tensor-product weights.
        layer_sizes = [
            self.edge_feats_irreps.num_irreps,
            *self.radial_MLP,
            self.conv_tp.weight_numel,
        ]
        self.conv_tp_weights = FullyConnectedNet(layer_sizes, F.silu)
        mid_irreps = mid_irreps.simplify()
        self.irreps_out = linear_out_irreps(mid_irreps, self.target_irreps).simplify()
        self.linear = o3.Linear(
            mid_irreps, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Residual: node features x node attributes -> output irreps.
        self.skip_tp = FullyConnectedTensorProduct(
            feats_irreps, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: Tensor,
        node_feats: Tensor,
        edge_attrs: Tensor,
        edge_feats: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
    ) -> Tensor:
        """Aggregate messages along ``idx_i_sr`` -> ``idx_j_sr`` and add the residual."""
        n_nodes = node_feats.shape[0]
        residual = self.skip_tp(node_feats, node_attrs)
        lifted = self.linear_up(node_feats)
        edge_weights = self.conv_tp_weights(edge_feats)
        # [n_edges, irreps]
        edge_messages = self.conv_tp(lifted[idx_i_sr], edge_attrs, edge_weights)
        # [n_nodes, irreps]
        aggregated = scatter_sum(
            src=edge_messages, index=idx_j_sr, dim=0, dim_size=n_nodes
        )
        normalised = self.linear(aggregated) / self.avg_num_neighbors
        return normalised + residual  # [n_nodes, irreps]
@compile_mode("script")
class RealAgnosticInteractionBlock(InteractionBlock):
    """Radial-MLP interaction without a residual; output mixed with node attributes.

    ``forward`` returns ``(reshaped_message, None)``. The second element is
    always ``None`` because this block has no skip connection.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _setup(self) -> None:
        # First linear
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

        # Convolution weights: radial MLP from edge features to TP weights.
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = FullyConnectedNet(
            [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
            F.silu,
        )

        # Linear
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = self.target_irreps
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )

        # Selector TensorProduct
        self.skip_tp = FullyConnectedTensorProduct(
            self.irreps_out, self.node_attrs_irreps, self.irreps_out
        )
        self.reshape = reshape_irreps(self.irreps_out)

    def forward(
        self,
        node_attrs: Tensor,
        node_feats: Tensor,
        edge_attrs: Tensor,
        edge_feats: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Aggregate messages, mix with ``node_attrs``, and reshape.

        Returns:
            ``(message, None)``. The second slot is annotated ``Optional[Tensor]``
            because this block never produces a skip tensor. The previous
            ``Tuple[Tensor, Tensor]`` annotation did not match the ``None`` actually
            returned, which is a type error when the module is scripted.
        """
        sender = idx_i_sr
        receiver = idx_j_sr
        num_nodes = node_feats.shape[0]
        node_feats = self.linear_up(node_feats)
        tp_weights = self.conv_tp_weights(edge_feats)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.linear(message) / self.avg_num_neighbors
        message = self.skip_tp(message, node_attrs)
        return (
            self.reshape(message),
            None,
        )  # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticResidualInteractionBlock(InteractionBlock):
    """Radial-MLP interaction that also returns a skip tensor in ``hidden_irreps``.

    ``forward`` returns ``(reshaped_message, sc)``, where ``sc`` is a fully
    connected tensor product of the input node features with the node attributes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _setup(self) -> None:
        feats_irreps = self.node_feats_irreps
        # First linear
        self.linear_up = o3.Linear(
            feats_irreps, feats_irreps, internal_weights=True, shared_weights=True
        )
        mid_irreps, tp_instructions = tp_out_irreps_with_instructions(
            feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = TensorProduct(
            feats_irreps,
            self.edge_attrs_irreps,
            mid_irreps,
            instructions=tp_instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Radial MLP: edge features -> tensor-product weights.
        layer_sizes = [
            self.edge_feats_irreps.num_irreps,
            *self.radial_MLP,
            self.conv_tp.weight_numel,
        ]
        self.conv_tp_weights = FullyConnectedNet(layer_sizes, F.silu)
        mid_irreps = mid_irreps.simplify()
        self.irreps_out = self.target_irreps
        self.linear = o3.Linear(
            mid_irreps, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Skip path into the hidden irreps.
        self.skip_tp = FullyConnectedTensorProduct(
            feats_irreps, self.node_attrs_irreps, self.hidden_irreps
        )
        self.reshape = reshape_irreps(self.irreps_out)

    def forward(
        self,
        node_attrs: Tensor,
        node_feats: Tensor,
        edge_attrs: Tensor,
        edge_feats: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Return the reshaped aggregated message and the skip tensor."""
        n_nodes = node_feats.shape[0]
        skip = self.skip_tp(node_feats, node_attrs)
        # Linear map of node features, then lift radial features to TP weights.
        lifted = self.linear_up(node_feats)
        edge_weights = self.conv_tp_weights(edge_feats)
        # [n_edges, irreps]: tensor product of sender features with edge attributes.
        edge_messages = self.conv_tp(lifted[idx_i_sr], edge_attrs, edge_weights)
        # [n_nodes, irreps]: sum the messages arriving at each receiver.
        aggregated = scatter_sum(
            src=edge_messages, index=idx_j_sr, dim=0, dim_size=n_nodes
        )
        normalised = self.linear(aggregated) / self.avg_num_neighbors
        # [n_nodes, channels, (lmax + 1)**2]
        return self.reshape(normalised), skip
@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
    """Residual interaction whose TP weights also depend on sender/receiver features.

    The node features are projected down to ``64x0e``. The projections of the
    sender and receiver nodes are concatenated with the edge features and fed
    to an MLP with three hidden layers of width 256, which produces the
    tensor-product weights. The skip path is a plain linear into
    ``hidden_irreps``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _setup(self) -> None:
        self.node_feats_down_irreps = Irreps("64x0e")
        feats_irreps = self.node_feats_irreps
        # First linear
        self.linear_up = o3.Linear(
            feats_irreps, feats_irreps, internal_weights=True, shared_weights=True
        )
        mid_irreps, tp_instructions = tp_out_irreps_with_instructions(
            feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = TensorProduct(
            feats_irreps,
            self.edge_attrs_irreps,
            mid_irreps,
            instructions=tp_instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Scalar projection of node features used to condition the TP weights.
        self.linear_down = o3.Linear(
            feats_irreps,
            self.node_feats_down_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        mlp_in = (
            self.edge_feats_irreps.num_irreps
            + 2 * self.node_feats_down_irreps.num_irreps
        )
        self.conv_tp_weights = FullyConnectedNet(
            [mlp_in] + [256] * 3 + [self.conv_tp.weight_numel],
            F.silu,
        )
        mid_irreps = mid_irreps.simplify()
        self.irreps_out = self.target_irreps
        self.linear = o3.Linear(
            mid_irreps,
            self.irreps_out,
            internal_weights=True,
            shared_weights=True,
        )
        self.reshape = reshape_irreps(self.irreps_out)
        # Skip connection.
        self.skip_linear = o3.Linear(feats_irreps, self.hidden_irreps)

    def forward(
        self,
        node_attrs: Tensor,
        node_feats: Tensor,
        edge_attrs: Tensor,
        edge_feats: Tensor,
        idx_i_sr: Tensor,
        idx_j_sr: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Return the reshaped aggregated message and the linear skip tensor."""
        n_nodes = node_feats.shape[0]
        skip = self.skip_linear(node_feats)
        lifted = self.linear_up(node_feats)
        projected = self.linear_down(node_feats)
        edge_inputs = torch.cat(
            [edge_feats, projected[idx_i_sr], projected[idx_j_sr]],
            dim=-1,
        )
        edge_weights = self.conv_tp_weights(edge_inputs)
        # [n_edges, irreps]
        edge_messages = self.conv_tp(lifted[idx_i_sr], edge_attrs, edge_weights)
        # [n_nodes, irreps]
        aggregated = scatter_sum(
            src=edge_messages, index=idx_j_sr, dim=0, dim_size=n_nodes
        )
        normalised = self.linear(aggregated) / self.avg_num_neighbors
        # [n_nodes, channels, (lmax + 1)**2]
        return self.reshape(normalised), skip
@compile_mode("script")
class LinearReadoutBlock(nn.Module):
    """Linear readout to ``2 * shallow_ensemble_size`` scalar (``0e``) channels."""

    def __init__(self, irreps_in: Irreps, shallow_ensemble_size: int = 1):
        super().__init__()
        readout_irreps = Irreps(f"{shallow_ensemble_size * 2}x0e")
        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=readout_irreps)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``[n_nodes, irreps]`` to ``[n_nodes, 2 * shallow_ensemble_size]``."""
        readout = self.linear(x)
        return readout
@compile_mode("script")
class NonLinearReadoutBlock(nn.Module):
    """Linear -> gated activation -> linear readout to ``2 * shallow_ensemble_size`` scalars."""

    def __init__(
        self,
        irreps_in: Irreps,
        MLP_irreps: Irreps,
        gate: Optional[Callable],
        shallow_ensemble_size: int = 1,
    ):
        super().__init__()
        self.hidden_irreps = MLP_irreps
        self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps)
        self.non_linearity = Activation(irreps_in=self.hidden_irreps, acts=[gate])
        readout_irreps = Irreps(f"{shallow_ensemble_size * 2}x0e")
        self.linear_2 = o3.Linear(
            irreps_in=self.hidden_irreps, irreps_out=readout_irreps
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map ``[n_nodes, irreps]`` to ``[n_nodes, 2 * shallow_ensemble_size]``."""
        hidden = self.linear_1(x)
        hidden = self.non_linearity(hidden)
        return self.linear_2(hidden)
# Mapping from each interaction-block class name to the class itself.
INTERACTION_CLASSES = {
    block_cls.__name__: block_cls
    for block_cls in (
        ResidualElementDependentInteractionBlock,
        AgnosticNonlinearInteractionBlock,
        AgnosticResidualNonlinearInteractionBlock,
        RealAgnosticInteractionBlock,
        RealAgnosticResidualInteractionBlock,
        RealAgnosticAttResidualInteractionBlock,
    )
}