Source code for enerzyme.models.layers.electron_embedding

from typing import Dict, Optional, Literal
from abc import abstractmethod
import torch
from torch import Tensor
from torch.nn import Linear, init
import torch.nn.functional as F
from torch_scatter import segment_sum_coo
from ..blocks.mlp import ResidualMLP as _ResidualMLP
from . import BaseFFLayer
from ..activation import ACTIVATION_KEY_TYPE


[docs] def ResidualMLP( dim_embedding: int, num_residual: int, activation_fn: ACTIVATION_KEY_TYPE, use_bias: bool=True, zero_init: bool=True ) -> _ResidualMLP: return _ResidualMLP( dim_feature_in=dim_embedding, dim_feature_out=dim_embedding, num_residual=num_residual, activation_fn=activation_fn, activation_params = { "dim_feature": dim_embedding, "learnable": True }, initial_weight1="orthogonal", initial_weight2="zero", initial_weight_out="zero" if zero_init else "orthogonal", use_bias_residual=use_bias, use_bias_out=use_bias )
[docs] class BaseElectronEmbedding(BaseFFLayer):
[docs] def __init__( self, dim_embedding: int, num_residual: int, attribute: Literal["charge", "spin"]="charge" ) -> None: input_fields = {"atom_embedding", "batch_seg"} if attribute == "charge": input_fields.add("Q") elif attribute == "spin": input_fields.add("S") super().__init__(input_fields=input_fields, output_fields={"electron_embedding"}) self.dim_embedding = dim_embedding self.num_residual = num_residual self.attribute = attribute self.reset_field_name(electron_embedding=f"{attribute}_embedding")
[docs] @abstractmethod def get_electron_embedding(self, atom_embedding, Q, batch_seg) -> Dict[Literal["electron_embedding"], Tensor]: ...
[docs] class ElectronicEmbedding(BaseElectronEmbedding): """ Block for updating atomic features through nonlocal interactions with the electrons. Arguments: num_features (int): Dimensions of feature space. num_basis_functions (int): Number of radial basis functions. num_residual_pre_i (int): Number of residual blocks applied to atomic features in i branch (central atoms) before computing the interaction. num_residual_pre_j (int): Number of residual blocks applied to atomic features in j branch (neighbouring atoms) before computing the interaction. num_residual_post (int): Number of residual blocks applied to interaction features. activation (str): Kind of activation function. Possible values: 'swish': Swish activation function. 'ssp': Shifted softplus activation function. """
[docs] def __init__( self, dim_embedding: int, num_residual: int, activation_fn: ACTIVATION_KEY_TYPE="swish", attribute: Literal["charge", "spin"]="charge" ) -> None: """ Initializes the ElectronicEmbedding class. """ super().__init__(dim_embedding, num_residual, attribute) self.linear_q = Linear(dim_embedding, dim_embedding) self.sqrt_dim_embedding = dim_embedding ** 0.5 if attribute == "charge": # charges are duplicated to use separate weights for +/- self.linear_k = Linear(2, dim_embedding, bias=False) self.linear_v = Linear(2, dim_embedding, bias=False) else: self.linear_k = Linear(1, dim_embedding, bias=False) self.linear_v = Linear(1, dim_embedding, bias=False) self.resblock = ResidualMLP( dim_embedding, num_residual, activation_fn, use_bias=False, zero_init=True ) self.reset_parameters()
[docs] def reset_parameters(self) -> None: """ Initialize parameters. """ init.orthogonal_(self.linear_k.weight) init.orthogonal_(self.linear_v.weight) init.orthogonal_(self.linear_q.weight) init.zeros_(self.linear_q.bias) self.eps = 1e-8
[docs] def get_electron_embedding( self, atom_embedding: Tensor, Q: Optional[Tensor]=None, S: Optional[Tensor]=None, batch_seg: Optional[Tensor]=None ) -> torch.Tensor: """ Evaluate interaction block. N: Number of atoms. x (FloatTensor [N, num_features]): Atomic feature vectors. """ if batch_seg is None: # assume a single batch batch_seg = torch.zeros(atom_embedding.size(0), dtype=torch.long, device=atom_embedding.device) q = self.linear_q(atom_embedding) # queries if self.attribute == "charge": if Q is None: Q = torch.zeros(batch_seg[-1] + 1, dtype=atom_embedding.dtype, device=atom_embedding.device) e = F.relu(torch.stack([Q, -Q], dim=-1)) else: if S is None: S = torch.zeros(batch_seg[-1] + 1, dtype=atom_embedding.dtype, device=atom_embedding.device) e = torch.abs(S).unsqueeze(-1) # +/- spin is the same => abs enorm = torch.maximum(e, torch.ones_like(e)) k = self.linear_k(e / enorm)[batch_seg] # keys v = self.linear_v(e)[batch_seg] # values dot = torch.sum(k * q, dim=-1) / self.sqrt_dim_embedding # scaled dot product a = F.softplus(dot) # unnormalized attention weights anorm = segment_sum_coo(a, batch_seg) if a.device.type == "cpu": # indexing is faster on CPUs anorm = anorm[batch_seg] else: # gathering is faster on GPUs anorm = torch.gather(anorm, 0, batch_seg) return self.resblock((a / (anorm + self.eps)).unsqueeze(-1) * v)
[docs] class NonlinearElectronicEmbedding(BaseElectronEmbedding): """ Block for updating atomic features through nonlocal interactions with the electrons. Arguments: num_features (int): Dimensions of feature space. num_basis_functions (int): Number of radial basis functions. num_residual_pre_i (int): Number of residual blocks applied to atomic features in i branch (central atoms) before computing the interaction. num_residual_pre_j (int): Number of residual blocks applied to atomic features in j branch (neighbouring atoms) before computing the interaction. num_residual_post (int): Number of residual blocks applied to interaction features. activation (str): Kind of activation function. Possible values: 'swish': Swish activation function. 'ssp': Shifted softplus activation function. """
[docs] def __init__( self, dim_embedding: int, num_residual: int, activation_fn: str="swish", attribute: Literal["charge", "spin"]="charge" ) -> None: """ Initializes the NonlinearElectronicEmbedding class. """ super(NonlinearElectronicEmbedding, self).__init__(dim_embedding, num_residual, attribute) self.linear_q = Linear(dim_embedding, dim_embedding, bias=False) self.featurize_k = Linear(1, dim_embedding) self.resblock_k = ResidualMLP( dim_embedding, num_residual, activation_fn=activation_fn, zero_init=True ) self.featurize_v = Linear(1, dim_embedding, bias=False) self.resblock_v = ResidualMLP( dim_embedding, num_residual, activation_fn=activation_fn, zero_init=True, use_bias=False, ) self.reset_parameters()
[docs] def reset_parameters(self) -> None: """ Initialize parameters. """ init.orthogonal_(self.linear_q.weight) init.orthogonal_(self.featurize_k.weight) init.zeros_(self.featurize_k.bias) init.orthogonal_(self.featurize_v.weight)
[docs] def get_electron_embedding( self, atom_embedding: Tensor, Q: Optional[Tensor]=None, S: Optional[Tensor]=None, batch_seg: Optional[Tensor]=None, mask: Optional[Tensor] = None, eps: float = 1e-8, ) -> Tensor: """ Evaluate interaction block. N: Number of atoms. x (FloatTensor [N, num_features]): Atomic feature vectors. """ if batch_seg is None: # assume a single batch batch_seg = torch.zeros(atom_embedding.size(0), dtype=torch.long, device=atom_embedding.device) if self.attribute == "charge": E = Q else: E = S e = E.unsqueeze(-1) q = self.linear_q(atom_embedding) # queries k = self.resblock_k(self.featurize_k(e))[batch_seg] # keys v = self.resblock_v(self.featurize_v(e))[batch_seg] # values # dot product dot = torch.sum(k * q, dim=-1) # determine maximum dot product (for numerics) num_batch = batch_seg[-1] + 1 if num_batch > 1: if mask is None: mask = ( F.one_hot(batch_seg) .to(dtype=atom_embedding.dtype, device=atom_embedding.device) .transpose(-1, -2) ) tmp = dot.view(1, -1).expand(num_batch, -1) tmp, _ = torch.max(mask * tmp, dim=-1) if tmp.device.type == "cpu": # indexing is faster on CPUs maximum = tmp[batch_seg] else: # gathering is faster on GPUs maximum = torch.gather(tmp, 0, batch_seg) else: maximum = torch.max(dot) # attention d = k.shape[-1] a = torch.exp((dot - maximum) / d ** 0.5) anorm = segment_sum_coo(a, batch_seg) if a.device.type == "cpu": # indexing is faster on CPUs anorm = anorm[batch_seg] else: # gathering is faster on GPUs anorm = torch.gather(anorm, 0, batch_seg) return (a / (anorm + eps)).unsqueeze(-1) * v