from typing import Dict, List, Optional, Literal, Union
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
from e3nn.util.jit import compile_mode
from e3nn.o3 import Irreps, SphericalHarmonics
from .interaction import INTERACTION_CLASSES, EquivariantProductBasisBlock, LinearReadoutBlock, NonLinearReadoutBlock
from ..layers import BaseFFCore, DistanceLayer, RangeSeparationLayer, BaseAtomEmbedding, BaseElectronEmbedding, BaseRBF, ChargeConservationLayer
DEFAULT_BUILD_PARAMS = {
'r_max': 5.0,
'atomic_numbers': [1, 6, 7, 8, 15, 16],
}
DEFAULT_LAYER_PARAMS = [{
'name': 'Core',
'params': {
'num_bessel': 8,
'num_polynomial_cutoff': 5,
'interaction_cls': "RealAgnosticResidualInteractionBlock",
'interaction_cls_first': "RealAgnosticResidualInteractionBlock",
'max_ell': 3,
'correlation': 3,
'num_interactions': 2,
'MLP_irreps': "16x0e",
'radial_MLP': [64, 64, 64],
'hidden_irreps': "128x0e + 128x1o",
'gate': "silu",
'avg_num_neighbors': 8.0,
}
}]
GATE_FUNCTIONS = {
"abs": torch.abs,
"tanh": torch.tanh,
"silu": torch.nn.functional.silu,
"None": None,
}
[docs]
class MACEWrapper(BaseFFCore):
[docs]
def __init__(self,
atomic_numbers: List[int],
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
interaction_cls: str,
interaction_cls_first: str,
max_ell: int,
correlation: int,
num_interactions: int,
MLP_irreps: str,
radial_MLP: List[int],
hidden_irreps: str,
gate: str,
avg_num_neighbors: float
):
try:
from mace.modules.models import ScaleShiftMACE
from mace.modules import interaction_classes, gate_dict
from mace.tools import get_atomic_number_table_from_zs
except ImportError:
raise ImportError("External FF: MACE is not installed. Please install it with `pip install mace-torch`.")
super().__init__(input_fields={"Ra", "Za", "batch_seg"}, output_fields={"E", "Fa"})
self.z_table = get_atomic_number_table_from_zs(atomic_numbers)
self.r_max = r_max
self.model = ScaleShiftMACE(
atomic_inter_scale=1.0,
atomic_inter_shift=0.0,
r_max=r_max,
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
max_ell=max_ell,
interaction_cls=interaction_classes[interaction_cls],
interaction_cls_first=interaction_classes[interaction_cls_first],
num_interactions=num_interactions,
num_elements=len(self.z_table),
hidden_irreps=Irreps(hidden_irreps),
MLP_irreps=Irreps(MLP_irreps),
atomic_energies=np.zeros(len(self.z_table)),
avg_num_neighbors=avg_num_neighbors,
atomic_numbers=self.z_table.zs,
correlation=correlation,
gate=gate_dict[gate],
radial_MLP=radial_MLP
)
def __str__(self) -> str:
return """
#################################################
# Wrapped MACE (NeurIPS 2022, arXiv:2206.07697) #
#################################################
"""
[docs]
def build(self, built_layers) -> None:
pass
[docs]
def get_output(self, Ra: Tensor, Za: Tensor, batch_seg: Tensor) -> Dict[str, Tensor]:
from mace.data.neighborhood import get_neighborhood
from mace.tools.utils import atomic_numbers_to_indices
from mace.data.atomic_data import to_one_hot
mace_data = dict()
indices = atomic_numbers_to_indices(Za.cpu(), z_table=self.z_table)
one_hot = to_one_hot(torch.tensor(indices, dtype=torch.long, device=Za.device).unsqueeze(-1), num_classes=len(self.z_table))
mace_data["batch"] = batch_seg
mace_data["ptr"] = [0]
mace_data["edge_index"], mace_data["shifts"], mace_data["unit_shifts"] = None, 0, 0
count = 0
for i in range(batch_seg[-1] + 1):
mask = batch_seg == i
edge_index, shifts, unit_shifts = get_neighborhood(
positions=Ra[mask].detach().cpu().numpy(), cutoff=self.r_max, pbc=None, cell=None
)
edge_index = torch.tensor(edge_index, dtype=torch.long, device=Ra.device)
if i == 0:
mace_data["edge_index"] = edge_index
else:
mace_data["edge_index"] = torch.cat([mace_data["edge_index"], edge_index + count], dim=1)
N = mask.sum().item()
count += N
mace_data["ptr"].append(count)
mace_data["ptr"] = torch.tensor(mace_data["ptr"], dtype=torch.long, device=Ra.device)
mace_data["positions"] = Ra
mace_data["node_attrs"] = one_hot
mace_data["batch"] = batch_seg
mace_data["cell"] = None
output = self.model(mace_data, compute_force=True, compute_virials=False, compute_stress=False, compute_displacement=False, compute_hessian=False, training=self.model.training)
return {"E": output["energy"], "Fa": output["forces"]}
[docs]
@compile_mode("script")
class MACECore(BaseFFCore):
[docs]
def __init__(self,
max_Za: int, max_ell: int, dim_embedding: int, num_rbf: int,
additional_hidden_irreps: str,
interaction_cls_first: Literal["RealAgnosticResidualInteractionBlock"],
interaction_cls: Literal["RealAgnosticResidualInteractionBlock"],
correlation: Union[int, List[int]],
num_interactions: int,
avg_num_neighbors: float,
MLP_irreps: str,
radial_MLP: List[int],
gate: str,
shallow_ensemble_size: int=1
):
super().__init__(input_fields={"Za", "vij_sr", "idx_i_sr", "idx_j_sr", "rbf", "atom_embedding", "charge_embedding", "spin_embedding"}, output_fields={"Ea", "Qa"})
self.max_Za = max_Za
if isinstance(correlation, int):
correlation = [correlation] * num_interactions
sh_irreps = Irreps.spherical_harmonics(max_ell)
self.spherical_harmonics = SphericalHarmonics(
sh_irreps, normalize=True, normalization="component"
)
node_attrs_irreps = Irreps([(max_Za + 1, (0, 1))])
node_feats_irreps = Irreps([(dim_embedding, (0, 1))])
edge_feats_irreps = Irreps([(num_rbf, (0, 1))])
hidden_irreps = Irreps(f"{dim_embedding}x0e+" + additional_hidden_irreps)
interaction_irreps = (sh_irreps * dim_embedding).sort()[0].simplify()
inter_first = INTERACTION_CLASSES[interaction_cls_first](
node_attrs_irreps=node_attrs_irreps,
node_feats_irreps=node_feats_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions = torch.nn.ModuleList([inter_first])
self.shallow_ensemble_size = shallow_ensemble_size
use_sc_first = False
if "Residual" in interaction_cls_first:
use_sc_first = True
node_feats_irreps_out = inter_first.target_irreps
prod = EquivariantProductBasisBlock(
node_feats_irreps=node_feats_irreps_out,
target_irreps=hidden_irreps,
correlation=correlation[0],
num_elements=max_Za + 1,
use_sc=use_sc_first,
)
self.products = torch.nn.ModuleList([prod])
self.readouts = torch.nn.ModuleList([LinearReadoutBlock(hidden_irreps, shallow_ensemble_size)])
for i in range(num_interactions - 1):
if i == num_interactions - 2:
hidden_irreps_out = str(
hidden_irreps[0]
) # Select only scalars for last layer
else:
hidden_irreps_out = hidden_irreps
inter = INTERACTION_CLASSES[interaction_cls](
node_attrs_irreps=node_attrs_irreps,
node_feats_irreps=hidden_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
node_feats_irreps=interaction_irreps,
target_irreps=hidden_irreps_out,
correlation=correlation[i + 1],
num_elements=max_Za + 1,
use_sc=True,
)
self.products.append(prod)
if i == num_interactions - 2:
self.readouts.append(
NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, GATE_FUNCTIONS[gate], shallow_ensemble_size)
)
else:
self.readouts.append(LinearReadoutBlock(hidden_irreps, shallow_ensemble_size))
def __str__(self) -> str:
return """
###################################################
# Augmented MACE (NeurIPS 2022, arXiv:2206.07697) #
###################################################
"""
[docs]
def build(self, built_layers) -> None:
self.calculate_distance = DistanceLayer()
self.calculate_distance = DistanceLayer()
self.calculate_distance.with_vector_on("vij_lr")
self.calculate_distance.reset_field_name(Dij="Dij_lr")
self.pre_sequence.append(self.calculate_distance)
pre_core = True
for layer in built_layers:
if layer is self:
pre_core = False
continue
if pre_core:
# reset pre-core layers
if isinstance(layer, RangeSeparationLayer):
self.range_separation = layer
self.range_separation.reset_field_name(idx_i_lr="idx_i", idx_j_lr="idx_j")
elif isinstance(layer, BaseAtomEmbedding):
self.atom_embedding = layer
elif isinstance(layer, BaseElectronEmbedding):
if layer.attribute == "charge":
self.charge_embedding = layer
elif layer.attribute == "spin":
self.spin_embedding = layer
elif isinstance(layer, BaseRBF):
self.radial_basis_function = layer
# build pre-core sequence
self.pre_sequence.append(layer)
else:
# build post-core sequence
if isinstance(layer, ChargeConservationLayer):
self.charge_conservation = layer
self.post_sequence.append(layer)
[docs]
def get_output(self,
Za: Tensor, vij_sr: Tensor,
idx_i_sr: Tensor, idx_j_sr: Tensor, rbf: Tensor,
atom_embedding: Tensor, charge_embedding: Optional[Tensor]=None, spin_embedding: Optional[Tensor]=None
) -> Dict[str, Tensor]:
# prepare initial attributes and features
if charge_embedding is None:
charge_embedding = torch.zeros_like(atom_embedding)
if spin_embedding is None:
spin_embedding = torch.zeros_like(atom_embedding)
node_feats = atom_embedding + charge_embedding + spin_embedding
node_attrs = F.one_hot(Za, num_classes=self.max_Za + 1).to(node_feats.dtype)
edge_attrs = self.spherical_harmonics(vij_sr)
edge_feats = rbf
node_properties_list = []
for interaction, product, readout in zip(
self.interactions, self.products, self.readouts
):
node_feats, sc = interaction(
node_attrs=node_attrs,
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
idx_i_sr=idx_i_sr,
idx_j_sr=idx_j_sr,
)
node_feats = product(
node_feats=node_feats, sc=sc, node_attrs=node_attrs
)
if self.shallow_ensemble_size > 1:
node_properties_list.append(readout(node_feats).reshape(-1, 2, self.shallow_ensemble_size))
else:
node_properties_list.append(readout(node_feats))
node_properties = torch.sum(
torch.stack(node_properties_list, dim=0), dim=0
)
return {"Ea": node_properties[:, 0], "Qa": node_properties[:, 1]}