# Source code for enerzyme.models.layers.atom_embedding

import math
from abc import abstractmethod
from typing import Dict
import numpy as np
import torch
from torch import Tensor
from torch.nn import Embedding, Parameter, Linear, init
from . import BaseFFLayer


class BaseAtomEmbedding(BaseFFLayer):
    """Shared interface for layers that map atomic numbers to feature vectors.

    The layer declares ``"Za"`` as its only input field and
    ``"atom_embedding"`` as its only output field. Concrete subclasses provide
    the actual mapping by implementing :meth:`get_embedding`.
    """

    def __init__(self, max_Za, dim_embedding) -> None:
        """Register the layer's fields and remember the table dimensions.

        Arguments:
            max_Za: Stored as ``self.max_Za``. The subclasses in this file
                allocate ``max_Za + 1`` table rows from it.
            dim_embedding: Stored as ``self.dim_embedding``; the width of one
                embedding vector in the subclasses in this file.
        """
        super().__init__(
            input_fields={"Za"},
            output_fields={"atom_embedding"},
        )
        self.max_Za, self.dim_embedding = max_Za, dim_embedding

    @abstractmethod
    def get_embedding(self, Za: Tensor) -> Tensor:
        """Return the embedding vectors for the atomic numbers in ``Za``.

        Must be overridden by subclasses; the base implementation is empty.
        """

    def get_atom_embedding(self, Za: Tensor) -> Tensor:
        """Produce the ``atom_embedding`` output by delegating to ``get_embedding``."""
        atom_embedding = self.get_embedding(Za)
        return atom_embedding
class RandomAtomEmbedding(BaseAtomEmbedding):
    """Atom embedding backed by a learnable ``torch.nn.Embedding`` table.

    The table has ``max_Za + 1`` rows, so atomic numbers ``0..max_Za`` can be
    looked up directly. Checkpoints whose table has a different number of rows
    are adapted on load (see :meth:`_load_from_state_dict`).
    """

    def __init__(self, max_Za, dim_embedding) -> None:
        """Create the lookup table.

        Arguments:
            max_Za: Largest atomic number that can be embedded.
            dim_embedding: Width of each embedding vector.
        """
        super().__init__(max_Za, dim_embedding)
        self.embedding = Embedding(max_Za + 1, dim_embedding)

    def get_embedding(self, Za: Tensor) -> Tensor:
        """Look up the table rows selected by the atomic numbers ``Za``."""
        return self.embedding(Za)

    @property
    def weight(self) -> Tensor:
        """The weight matrix of the underlying ``Embedding`` module."""
        return self.embedding.weight

    def _load_from_state_dict(self, state_dict: Dict[str, Tensor], *args, **kwargs):
        """Resize a checkpointed embedding table to this layer before loading.

        Only this layer's own table (``<prefix>embedding.weight``) is touched,
        so other entries of a shared state dict that merely end with the same
        suffix are left alone.

        * More rows than ``max_Za + 1``: the surplus rows are dropped and a
          warning is issued.
        * Fewer rows than ``max_Za + 1``: the missing rows are filled with the
          layer's current values for those rows.

        Arguments:
            state_dict: Mapping of parameter names to tensors; modified in
                place when the table has to be resized.
            *args, **kwargs: Forwarded unchanged to the parent implementation.
                The first positional argument (or ``prefix`` keyword) is the
                key prefix of this module.
        """
        import warnings

        prefix = args[0] if args else kwargs.get("prefix", "")
        key = f"{prefix}embedding.weight"
        saved = state_dict.get(key)
        if saved is not None:
            num_rows = self.max_Za + 1
            num_saved = len(saved)
            if num_saved > num_rows:
                warnings.warn(
                    f"Checkpoint entry '{key}' has {num_saved} rows but the "
                    f"layer holds {num_rows}; the extra rows are dropped."
                )
                state_dict[key] = saved[:num_rows]
            elif num_saved < num_rows:
                # Detach and match the checkpoint's device/dtype so that the
                # concatenation neither tracks gradients nor fails when the
                # checkpoint lives on a different device than the module.
                tail = self.embedding.weight.detach()[num_saved:].to(
                    device=saved.device, dtype=saved.dtype
                )
                state_dict[key] = torch.concat([saved, tail], dim=0)
        super()._load_from_state_dict(state_dict, *args, **kwargs)
# Per-element descriptor table: one row per element, 20 columns per row.
# Row i describes the element with atomic number i (row 0, labelled "n", is an
# all-zero placeholder), so the table can be indexed directly by atomic number
# for Z = 0..86.
# Columns: atomic number Z, then the occupation of each listed subshell
# (1s ... 6p), then four columns labelled vs/vp/vd/vf -- presumably the number
# of valence electrons of s/p/d/f character (TODO confirm).
ELECTRON_CONFIG = np.array([
    #  Z  1s  2s  2p  3s  3p  4s  3d  4p  5s  4d  5p  6s  4f  5d  6p  vs  vp  vd  vf
    [  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],  # n
    [  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0],  # H
    [  2,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  0,  0],  # He
    [  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0],  # Li
    [  4,  2,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  0,  0],  # Be
    [  5,  2,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  1,  0,  0],  # B
    [  6,  2,  2,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  2,  0,  0],  # C
    [  7,  2,  2,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  0,  0],  # N
    [  8,  2,  2,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  4,  0,  0],  # O
    [  9,  2,  2,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  5,  0,  0],  # F
    [ 10,  2,  2,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  6,  0,  0],  # Ne
    [ 11,  2,  2,  6,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0],  # Na
    [ 12,  2,  2,  6,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  0,  0],  # Mg
    [ 13,  2,  2,  6,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  1,  0,  0],  # Al
    [ 14,  2,  2,  6,  2,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  2,  0,  0],  # Si
    [ 15,  2,  2,  6,  2,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  0,  0],  # P
    [ 16,  2,  2,  6,  2,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  4,  0,  0],  # S
    [ 17,  2,  2,  6,  2,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  5,  0,  0],  # Cl
    [ 18,  2,  2,  6,  2,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  6,  0,  0],  # Ar
    [ 19,  2,  2,  6,  2,  6,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0],  # K
    [ 20,  2,  2,  6,  2,  6,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  0,  0],  # Ca
    [ 21,  2,  2,  6,  2,  6,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  1,  0],  # Sc
    [ 22,  2,  2,  6,  2,  6,  2,  2,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  2,  0],  # Ti
    [ 23,  2,  2,  6,  2,  6,  2,  3,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  3,  0],  # V
    [ 24,  2,  2,  6,  2,  6,  1,  5,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  5,  0],  # Cr
    [ 25,  2,  2,  6,  2,  6,  2,  5,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  5,  0],  # Mn
    [ 26,  2,  2,  6,  2,  6,  2,  6,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  6,  0],  # Fe
    [ 27,  2,  2,  6,  2,  6,  2,  7,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  7,  0],  # Co
    [ 28,  2,  2,  6,  2,  6,  2,  8,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0,  8,  0],  # Ni
    [ 29,  2,  2,  6,  2,  6,  1, 10,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0, 10,  0],  # Cu
    [ 30,  2,  2,  6,  2,  6,  2, 10,  0,  0,  0,  0,  0,  0,  0,  0,  2,  0, 10,  0],  # Zn
    [ 31,  2,  2,  6,  2,  6,  2, 10,  1,  0,  0,  0,  0,  0,  0,  0,  2,  1, 10,  0],  # Ga
    [ 32,  2,  2,  6,  2,  6,  2, 10,  2,  0,  0,  0,  0,  0,  0,  0,  2,  2, 10,  0],  # Ge
    [ 33,  2,  2,  6,  2,  6,  2, 10,  3,  0,  0,  0,  0,  0,  0,  0,  2,  3, 10,  0],  # As
    [ 34,  2,  2,  6,  2,  6,  2, 10,  4,  0,  0,  0,  0,  0,  0,  0,  2,  4, 10,  0],  # Se
    [ 35,  2,  2,  6,  2,  6,  2, 10,  5,  0,  0,  0,  0,  0,  0,  0,  2,  5, 10,  0],  # Br
    [ 36,  2,  2,  6,  2,  6,  2, 10,  6,  0,  0,  0,  0,  0,  0,  0,  2,  6, 10,  0],  # Kr
    [ 37,  2,  2,  6,  2,  6,  2, 10,  6,  1,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0],  # Rb
    [ 38,  2,  2,  6,  2,  6,  2, 10,  6,  2,  0,  0,  0,  0,  0,  0,  2,  0,  0,  0],  # Sr
    [ 39,  2,  2,  6,  2,  6,  2, 10,  6,  2,  1,  0,  0,  0,  0,  0,  2,  0,  1,  0],  # Y
    [ 40,  2,  2,  6,  2,  6,  2, 10,  6,  2,  2,  0,  0,  0,  0,  0,  2,  0,  2,  0],  # Zr
    [ 41,  2,  2,  6,  2,  6,  2, 10,  6,  1,  4,  0,  0,  0,  0,  0,  1,  0,  4,  0],  # Nb
    [ 42,  2,  2,  6,  2,  6,  2, 10,  6,  1,  5,  0,  0,  0,  0,  0,  1,  0,  5,  0],  # Mo
    [ 43,  2,  2,  6,  2,  6,  2, 10,  6,  2,  5,  0,  0,  0,  0,  0,  2,  0,  5,  0],  # Tc
    [ 44,  2,  2,  6,  2,  6,  2, 10,  6,  1,  7,  0,  0,  0,  0,  0,  1,  0,  7,  0],  # Ru
    [ 45,  2,  2,  6,  2,  6,  2, 10,  6,  1,  8,  0,  0,  0,  0,  0,  1,  0,  8,  0],  # Rh
    [ 46,  2,  2,  6,  2,  6,  2, 10,  6,  0, 10,  0,  0,  0,  0,  0,  0,  0, 10,  0],  # Pd
    [ 47,  2,  2,  6,  2,  6,  2, 10,  6,  1, 10,  0,  0,  0,  0,  0,  1,  0, 10,  0],  # Ag
    [ 48,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  0,  0,  0,  0,  0,  2,  0, 10,  0],  # Cd
    [ 49,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  1,  0,  0,  0,  0,  2,  1, 10,  0],  # In
    [ 50,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  2,  0,  0,  0,  0,  2,  2, 10,  0],  # Sn
    [ 51,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  3,  0,  0,  0,  0,  2,  3, 10,  0],  # Sb
    [ 52,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  4,  0,  0,  0,  0,  2,  4, 10,  0],  # Te
    [ 53,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  5,  0,  0,  0,  0,  2,  5, 10,  0],  # I
    [ 54,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  0,  0,  0,  0,  2,  6, 10,  0],  # Xe
    [ 55,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  1,  0,  0,  0,  1,  0,  0,  0],  # Cs
    [ 56,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  0,  0,  0,  2,  0,  0,  0],  # Ba
    [ 57,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  0,  1,  0,  2,  0,  1,  0],  # La
    [ 58,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  1,  1,  0,  2,  0,  1,  1],  # Ce
    [ 59,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  3,  0,  0,  2,  0,  0,  3],  # Pr
    [ 60,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  4,  0,  0,  2,  0,  0,  4],  # Nd
    [ 61,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  5,  0,  0,  2,  0,  0,  5],  # Pm
    [ 62,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  6,  0,  0,  2,  0,  0,  6],  # Sm
    [ 63,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  7,  0,  0,  2,  0,  0,  7],  # Eu
    [ 64,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  7,  1,  0,  2,  0,  1,  7],  # Gd
    [ 65,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2,  9,  0,  0,  2,  0,  0,  9],  # Tb
    [ 66,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 10,  0,  0,  2,  0,  0, 10],  # Dy
    [ 67,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 11,  0,  0,  2,  0,  0, 11],  # Ho
    [ 68,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 12,  0,  0,  2,  0,  0, 12],  # Er
    [ 69,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 13,  0,  0,  2,  0,  0, 13],  # Tm
    [ 70,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14,  0,  0,  2,  0,  0, 14],  # Yb
    [ 71,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14,  1,  0,  2,  0,  1, 14],  # Lu
    [ 72,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14,  2,  0,  2,  0,  2, 14],  # Hf
    [ 73,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14,  3,  0,  2,  0,  3, 14],  # Ta
    [ 74,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14,  4,  0,  2,  0,  4, 14],  # W
    [ 75,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14,  5,  0,  2,  0,  5, 14],  # Re
    [ 76,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14,  6,  0,  2,  0,  6, 14],  # Os
    [ 77,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14,  7,  0,  2,  0,  7, 14],  # Ir
    [ 78,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  1, 14,  9,  0,  1,  0,  9, 14],  # Pt
    [ 79,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  1, 14, 10,  0,  1,  0, 10, 14],  # Au
    [ 80,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14, 10,  0,  2,  0, 10, 14],  # Hg
    [ 81,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14, 10,  1,  2,  1, 10, 14],  # Tl
    [ 82,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14, 10,  2,  2,  2, 10, 14],  # Pb
    [ 83,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14, 10,  3,  2,  3, 10, 14],  # Bi
    [ 84,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14, 10,  4,  2,  4, 10, 14],  # Po
    [ 85,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14, 10,  5,  2,  5, 10, 14],  # At
    [ 86,  2,  2,  6,  2,  6,  2, 10,  6,  2, 10,  6,  2, 14, 10,  6,  2,  6, 10, 14]   # Rn
], dtype=np.float64)
# normalize entries (between 0.0 and 1.0): every column is divided by its own
# maximum over all rows, so the scaling of a row does not depend on how many
# rows a consumer later uses. Every column has a non-zero maximum, so the
# division is well defined.
ELECTRON_CONFIG = ELECTRON_CONFIG / np.max(ELECTRON_CONFIG, axis=0)
class NuclearEmbedding(BaseAtomEmbedding):
    """Atom embedding built from a free per-element vector and, optionally, a
    linear projection of the element's electron-configuration descriptor.

    The per-element embedding table is
    ``element_embedding + config_linear(electron_config)`` when
    ``use_electron_config`` is true and just ``element_embedding`` otherwise.
    During training the table is recomputed on every call of
    :meth:`get_embedding`; switching to evaluation mode computes it once and
    caches it in the non-persistent buffer ``embedding``.

    Note: the cache is only refreshed by :meth:`train`. If parameters are
    changed while the module is in evaluation mode (for example by loading a
    state dict), call ``eval()`` again to refresh it.
    """

    def __init__(
        self,
        max_Za: int,
        dim_embedding: int,
        zero_init: bool = True,
        use_electron_config: bool = True,
    ) -> None:
        """Allocate parameters and buffers and initialize them.

        Arguments:
            max_Za: Largest atomic number that can be embedded; the tables
                have ``max_Za + 1`` rows.
            dim_embedding: Width of each embedding vector.
            zero_init: Passed to :meth:`reset_parameters`.
            use_electron_config: Whether to add the projected
                electron-configuration descriptor to the element embedding.

        Raises:
            ValueError: If ``use_electron_config`` is true and ``max_Za``
                exceeds the largest atomic number in ``ELECTRON_CONFIG``.
        """
        super().__init__(max_Za, dim_embedding)
        num_rows = max_Za + 1
        self.register_parameter(
            "element_embedding", Parameter(Tensor(num_rows, dim_embedding))
        )
        # Cache of the combined table. It starts uninitialized and is filled
        # by get_embedding() (training) or train(False) (evaluation).
        self.register_buffer(
            "embedding", Tensor(num_rows, dim_embedding), persistent=False
        )
        self.use_electron_config = use_electron_config
        if use_electron_config:
            if num_rows > ELECTRON_CONFIG.shape[0]:
                raise ValueError(
                    f"max_Za={max_Za} exceeds the largest atomic number "
                    f"({ELECTRON_CONFIG.shape[0] - 1}) in the electron "
                    "configuration table; lower max_Za or set "
                    "use_electron_config=False."
                )
            # Keep exactly one descriptor row per embedding row so that the
            # sum with element_embedding is shape-compatible for any max_Za.
            self.register_buffer(
                "electron_config", torch.tensor(ELECTRON_CONFIG[:num_rows])
            )
            self.config_linear = Linear(
                self.electron_config.size(1), dim_embedding, bias=False
            )
        self.reset_parameters(zero_init)

    def reset_parameters(self, zero_init: bool = True) -> None:
        """Initialize parameters.

        Arguments:
            zero_init: If true, all parameters start at zero. Otherwise the
                element embedding is drawn uniformly from
                ``[-sqrt(3), sqrt(3)]`` and the projection weight (if present)
                is initialized orthogonally.
        """
        if zero_init:
            init.zeros_(self.element_embedding)
            if self.use_electron_config:
                init.zeros_(self.config_linear.weight)
        else:
            init.uniform_(self.element_embedding, -math.sqrt(3), math.sqrt(3))
            if self.use_electron_config:
                init.orthogonal_(self.config_linear.weight)

    def _compute_embedding(self) -> Tensor:
        """Build the full embedding table from the current parameters."""
        if self.use_electron_config:
            # The descriptor table is created as float64; match the projection
            # weight's dtype (a no-op when they already agree).
            descriptor = self.electron_config.to(self.config_linear.weight.dtype)
            return self.element_embedding + self.config_linear(descriptor)
        # Hand out a plain-tensor view rather than the Parameter itself:
        # assigning a Parameter to ``self.embedding`` would re-register the
        # buffer as a second parameter and add a duplicate state-dict entry.
        return self.element_embedding.view_as(self.element_embedding)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load this module's entries, ignoring a legacy ``embedding`` key.

        ``embedding`` is a non-persistent cache. Older checkpoints may still
        contain it because it used to be registered as a parameter alias when
        ``use_electron_config`` was false; dropping the key keeps those
        checkpoints loadable in strict mode.
        """
        state_dict.pop(prefix + "embedding", None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def train(self, mode: bool = True) -> "NuclearEmbedding":
        """Switch between training and evaluation mode.

        When switching to evaluation mode the embedding table is computed once
        (without gradient tracking) and cached.

        Returns:
            The module itself, as ``torch.nn.Module.train`` does, so that
            ``module.eval()`` and call chaining keep working.
        """
        super().train(mode=mode)
        if not self.training:
            with torch.no_grad():
                self.embedding = self._compute_embedding()
        return self

    def get_embedding(self, Za: Tensor) -> Tensor:
        """
        Assign corresponding embeddings to nuclear charges.
        N: Number of atoms.
        num_features: Dimensions of feature space.

        Arguments:
            Za (LongTensor [N]):
                Nuclear charges (atomic numbers) of atoms.
        Returns:
            x (FloatTensor [N, num_features]):
                Embeddings of all atoms.
        """
        if self.training:
            # during training, the embedding needs to be recomputed
            self.embedding = self._compute_embedding()
        if self.embedding.device.type == "cpu":
            # indexing is faster on CPUs
            return self.embedding[Za]
        # gathering is faster on GPUs
        return torch.gather(
            self.embedding, 0, Za.view(-1, 1).expand(-1, self.dim_embedding)
        )