# Source code for enerzyme.models.xpainn.core

from typing import List, Optional, Union
import torch
from torch.nn import Module
from ..layers import BaseFFCore, DistanceLayer, RangeSeparationLayer
try:
    from xequinet.nn.model import XPaiNN
    from xequinet.utils.config import NetConfig
    from torch_geometric.data import Data
except ImportError as e:
    raise ImportError("External FF: XPaiNN is not installed. Please install it following https://github.com/X1X1010/XequiNet.", e)


# Build-level defaults. ``cutoff_sr`` is forwarded to XequiNet's ``NetConfig``
# under the name ``cutoff`` by ``XPaiNNWrapper.__init__``.
DEFAULT_BUILD_PARAMS = {
    'cutoff_sr': 5.0,
}
# Default layer stack: a range-separation layer followed by the XPaiNN core.
# The keys of the "Core" params match the keyword arguments of
# ``XPaiNNWrapper.__init__`` (together with ``cutoff_sr`` above).
DEFAULT_LAYER_PARAMS = [
    {'name': 'RangeSeparation'},
    {"name": "Core", "params": {
        'embed_basis': "gfn2-xtb",                  # embedding basis type
        'aux_basis': "aux56",                       # auxiliary basis type
        'node_dim': 128,                            # node feature dimension (an int, not an irreps string)
        'edge_irreps': "128x0e + 64x1o + 32x2e",    # edge irreps for the input
        'hidden_dim': 64,                           # hidden dimension for the output
        'hidden_irreps': "64x0e + 32x1o + 16x2e",   # hidden irreps for the output
        'rbf_kernel': "bessel",                     # radial basis function type
        'num_basis': 20,                            # number of the radial basis functions
        'cutoff_fn': "cosine",                      # cutoff function type
        'max_edges': 100,                           # maximum number of the edges
        'action_blocks': 3,                         # number of the action blocks
        'activation': "silu",                       # activation function type
        'norm_type': "nonorm",                      # normalization layer type
        'output_mode': "grad",                      # task type (`scalar` is for energy like, `grad` is for force like, etc.);
                                                    # keep `grad`: XPaiNNWrapper.get_output unpacks an (E, Fa) pair
        'output_dim': 1,                            # output dimension of multi-task (only for `scalar` mode)
        'node_average': False,                      # whether to add the node average to the output (only for `scalar` mode)
        'default_dtype': "float32",                 # default data type
    }}
]


class XPaiNNWrapper(BaseFFCore):
    """Adapter exposing XequiNet's ``XPaiNN`` model as an enerzyme force-field core.

    The core consumes the fields ``Ra``, ``Za``, ``batch_seg``, ``idx_i`` and
    ``idx_j`` and produces ``E`` and ``Fa``.
    """

    def __init__(
        self,
        cutoff_sr: float,
        embed_basis: str,
        aux_basis: str,
        node_dim: int,
        edge_irreps: str,
        hidden_dim: int,
        hidden_irreps: str,
        rbf_kernel: str,
        num_basis: int,
        cutoff_fn: str,
        max_edges: int,
        action_blocks: int,
        activation: str,
        norm_type: str,
        output_mode: str,
        output_dim: int,
        node_average: Union[bool, float],
        default_dtype: str,
    ):
        """Build the wrapped ``XPaiNN`` model from explicit hyperparameters.

        Args:
            cutoff_sr: short-range cutoff, passed to ``NetConfig`` as ``cutoff``.
            embed_basis, aux_basis, node_dim, edge_irreps, hidden_dim,
            hidden_irreps, rbf_kernel, num_basis, cutoff_fn, max_edges,
            action_blocks, activation, norm_type, output_mode, output_dim,
            node_average, default_dtype: forwarded unchanged to ``NetConfig``
                under the same names. ``output_mode`` should produce an
                (energy, forces) pair, since ``get_output`` unpacks two values.
        """
        super().__init__(
            input_fields={"Ra", "Za", "batch_seg", "idx_i", "idx_j"},
            output_fields={"E", "Fa"},
        )
        config = NetConfig(
            cutoff=cutoff_sr,
            embed_basis=embed_basis,
            aux_basis=aux_basis,
            node_dim=node_dim,
            edge_irreps=edge_irreps,
            hidden_dim=hidden_dim,
            hidden_irreps=hidden_irreps,
            rbf_kernel=rbf_kernel,
            num_basis=num_basis,
            cutoff_fn=cutoff_fn,
            max_edges=max_edges,
            action_blocks=action_blocks,
            activation=activation,
            norm_type=norm_type,
            output_mode=output_mode,
            output_dim=output_dim,
            node_average=node_average,
            default_dtype=default_dtype,
        )
        self.model = XPaiNN(config)

    def __str__(self) -> str:
        return """
####################################################################
# Wrapped XPaiNN (J. Chem. Theory Comput. 2024, 20, 21, 9500–9511) #
####################################################################
"""

    def build(self, built_layers: List[Module]) -> None:
        """Distribute the surrounding layers into pre- and post-core sequences.

        A ``DistanceLayer`` whose ``Dij`` field is reset to ``Dij_lr`` is always
        appended to ``self.pre_sequence`` first. Layers appearing before this
        core in ``built_layers`` are appended to ``self.pre_sequence``; layers
        after it go to ``self.post_sequence``; the core itself is skipped.

        A ``RangeSeparationLayer`` found before the core has its ``idx_i_lr`` /
        ``idx_j_lr`` fields reset to ``idx_i`` / ``idx_j``. Presumably this is
        so its pair list feeds ``get_output`` directly (confirm against
        ``reset_field_name`` semantics).
        """
        distance_layer = DistanceLayer()
        distance_layer.reset_field_name(Dij="Dij_lr")
        self.pre_sequence.append(distance_layer)

        before_core = True
        for layer in built_layers:
            if layer is self:
                before_core = False
                continue
            if before_core:
                if isinstance(layer, RangeSeparationLayer):
                    layer.reset_field_name(idx_i_lr="idx_i", idx_j_lr="idx_j")
                self.pre_sequence.append(layer)
            else:
                self.post_sequence.append(layer)

    def get_output(self, Ra, Za, batch_seg, idx_i, idx_j):
        """Run XPaiNN on one batch and return its energy and force outputs.

        Args:
            Ra: atomic positions, passed to the model as ``pos``
                (presumably shape (num_atoms, 3); TODO confirm).
            Za: atomic numbers, passed as ``at_no``.
            batch_seg: passed as ``batch`` (presumably the per-atom molecule
                index).
            idx_i, idx_j: pair index tensors, stacked into ``edge_index``
                with ``idx_i`` as row 0 and ``idx_j`` as row 1.

        Returns:
            ``{"E": E, "Fa": Fa}``. ``E`` is the model's first output with its
            trailing dimension squeezed; ``Fa`` is its second output, unchanged.

        Raises:
            RuntimeError: if the model does not return exactly two outputs,
                e.g. when ``output_mode`` is not ``grad``.
        """
        data = Data(
            edge_index=torch.stack([idx_i, idx_j]),
            pos=Ra,
            at_no=Za,
            batch=batch_seg,
        )
        output = self.model(data)
        # If the model returned a single tensor instead of a pair, ``E, Fa = output``
        # would unpack it along dim 0. With exactly two molecules in the batch,
        # that silently mislabels per-molecule rows as E and Fa, so reject it.
        if not isinstance(output, (tuple, list)) or len(output) != 2:
            raise RuntimeError(
                f"XPaiNN returned {type(output).__name__} instead of an "
                "(energy, forces) pair; set output_mode='grad' for XPaiNNWrapper."
            )
        E, Fa = output
        return {"E": E.squeeze(-1), "Fa": Fa}