# Source code for enerzyme.models.leftnet.core

# borrowed from https://github.com/yuanqidu/M2Hub/blob/master/m2models/models/leftnet.py

### feature dimension
### variable name

import math
from math import pi
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import Embedding
from torch_geometric.nn import radius_graph
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter, segment_coo, segment_csr

from ..layers import BaseFFCore, DistanceLayer, RangeSeparationLayer


# Default build-level settings for a LEFTNet force field.
DEFAULT_BUILD_PARAMS = dict(
    cutoff_sr=5.0,
    max_Za=94,
)

# Default ordered layer stack. Each entry names a layer and, optionally, the
# keyword arguments used to build it; "Core" is the LEFTNet model itself.
DEFAULT_LAYER_PARAMS = [
    {"name": "RangeSeparation"},
    {
        "name": "Core",
        "params": dict(
            num_layers=4,
            hidden_channels=128,
            num_radial=96,
            eps=1e-10,
            head=16,
            main_chi1=24,
            mp_chi1=24,
            chi2=6,
            hidden_channels_chi=96,
            has_dropout_flag=True,
            has_norm_before_flag=True,
            has_norm_after_flag=False,
            reduce_mode="sum",
        ),
    },
    {
        "name": "AtomicAffine",
        "params": {
            "shifts": {
                "Ea": {"values": 0, "learnable": True},
                "Qa": {"values": 0, "learnable": True},
            },
            "scales": {
                "Ea": {"values": 1, "learnable": True},
                "Qa": {"values": 1, "learnable": True},
            },
        },
    },
    {"name": "ChargeConservation"},
    {"name": "AtomicCharge2Dipole"},
    {
        "name": "ElectrostaticEnergy",
        "params": {"cutoff_lr": None, "flavor": "PhysNet"},
    },
    {"name": "EnergyReduce"},
    {"name": "Force"},
]







def swish(x):
    """Apply the Swish (SiLU) activation elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
def get_max_neighbors_mask(
    natoms, index, atom_distance, max_num_neighbors_threshold
):
    """Give a mask that keeps at most ``max_num_neighbors_threshold`` edges per atom.

    For each atom, its closest neighbours (smallest ``atom_distance``) are kept.
    ``index`` must be sorted, because the segment reductions and the dense
    padding layout below assume that all edges of one atom are contiguous.

    Args:
        natoms: number of atoms in each image.
        index: atom index of each edge (sorted ascending).
        atom_distance: per-edge value used for ranking (e.g. squared distance).
        max_num_neighbors_threshold: neighbour cap per atom; ``<= 0`` disables it.

    Returns:
        ``(mask, num_neighbors_image)``: a boolean mask over edges, and the
        capped neighbour counts summed over each image.
    """
    device = natoms.device
    total_atoms = natoms.sum()

    # Per-atom neighbour counts (segment_coo assumes a sorted index).
    edge_ones = index.new_ones(1).expand_as(index)
    neighbor_count = segment_coo(edge_ones, index, dim_size=total_atoms)
    largest_count = neighbor_count.max()
    capped_count = neighbor_count.clamp(max=max_num_neighbors_threshold)

    # Sum the capped counts per image using CSR-style offsets.
    image_ptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
    image_ptr[1:] = torch.cumsum(natoms, dim=0)
    num_neighbors_image = segment_csr(capped_count, image_ptr)

    # Nothing to drop: every atom is already within the cap, or the cap is off.
    cap_disabled = max_num_neighbors_threshold <= 0
    if largest_count <= max_num_neighbors_threshold or cap_disabled:
        keep_all = torch.tensor([True], dtype=bool, device=device).expand_as(index)
        return keep_all, num_neighbors_image

    # Scatter the distances into a dense (atoms x largest_count) table.
    # Unused slots stay +inf so they sort last and can be filtered out.
    padded = torch.full([total_atoms * largest_count], np.inf, device=device)
    first_edge = torch.cumsum(neighbor_count, dim=0) - neighbor_count
    first_edge_per_edge = torch.repeat_interleave(first_edge, neighbor_count)
    slot = (
        index * largest_count
        + torch.arange(len(index), device=device)
        - first_edge_per_edge
    )
    padded.index_copy_(0, slot, atom_distance)
    padded = padded.view(total_atoms, largest_count)

    # Rank every atom's neighbours by distance and keep the closest ones.
    ranked_dist, ranked_col = torch.sort(padded, dim=1)
    ranked_dist = ranked_dist[:, :max_num_neighbors_threshold]
    ranked_col = ranked_col[:, :max_num_neighbors_threshold]

    # Translate per-atom column positions back into global edge indices.
    ranked_edge = ranked_col + first_edge.view(-1, 1).expand(
        -1, max_num_neighbors_threshold
    )

    # Padding slots carry +inf; discard them.
    ranked_edge = torch.masked_select(ranked_edge, torch.isfinite(ranked_dist))

    keep = torch.zeros(len(index), device=device, dtype=bool)
    keep.index_fill_(0, ranked_edge, True)
    return keep, num_neighbors_image
def check_and_reshape_cell(cell):
    """Return ``cell`` as a batch of 3x3 lattice matrices.

    A 2-D tensor of shape ``(3 * batch_size, 3)`` (cell matrices stacked
    row-wise) is reshaped to ``(batch_size, 3, 3)``. A tensor that is already
    ``(batch_size, 3, 3)`` is returned unchanged (same object).

    Raises:
        ValueError: if ``cell`` has any other shape.
    """
    stacked_rows = cell.dim() == 2 and cell.size(0) % 3 == 0 and cell.size(1) == 3
    if stacked_rows:
        # Every 3 consecutive rows form one cell matrix.
        return cell.reshape(cell.size(0) // 3, 3, 3)

    already_batched = cell.dim() == 3 and cell.size(1) == 3 and cell.size(2) == 3
    if not already_batched:
        raise ValueError(
            "Invalid cell shape. Expected (batch_size, 3, 3), but got {}".format(
                cell.size()
            )
        )
    return cell
def radius_graph_pbc(
    data, radius, max_num_neighbors_threshold, pbc=None
):
    """Build a radius graph for a batch of periodic structures.

    Every atom pair inside each image is combined with every lattice
    translation needed to reach ``radius``. Pairs whose squared distance is at
    most ``radius**2`` and greater than ``1e-4`` (which drops self-pairs) are
    kept. Each atom is then capped at ``max_num_neighbors_threshold``
    neighbours via ``get_max_neighbors_mask``.

    Args:
        data: batch object providing ``pos``, ``natoms``, ``cell`` and,
            optionally, ``pbc``. ``data.cell`` is replaced by its
            ``(batch, 3, 3)`` form, and ``data.pbc`` (if present) by its
            ``atleast_2d`` form.
        radius: cutoff radius.
        max_num_neighbors_threshold: neighbour cap per atom (``<= 0`` disables it).
        pbc: per-axis periodicity flags, default ``[True, True, True]``. When
            ``data`` has a ``pbc`` attribute, it overrides these flags. The
            argument itself is never modified.

    Returns:
        ``(edge_index, unit_cell, num_neighbors_image)``:
        - ``edge_index`` is ``stack((neighbor, center))``.
        - ``unit_cell`` holds, per edge, the lattice-translation multiples
          applied to the neighbour.
        - ``num_neighbors_image`` is the capped neighbour count per image.

    Raises:
        RuntimeError: if structures in the batch disagree on an axis' PBC.
    """
    # Work on a private copy. The previous mutable default list was mutated
    # below, so one non-periodic batch silently disabled PBC for every later
    # call relying on the default (and a caller's own list was clobbered too).
    pbc = [True, True, True] if pbc is None else list(pbc)

    device = data.pos.device
    batch_size = len(data.natoms)
    data.cell = check_and_reshape_cell(data.cell)

    if hasattr(data, "pbc"):
        data.pbc = torch.atleast_2d(data.pbc)
        for i in range(3):
            if not torch.any(data.pbc[:, i]).item():
                pbc[i] = False
            elif torch.all(data.pbc[:, i]).item():
                pbc[i] = True
            else:
                raise RuntimeError(
                    "Different structures in the batch have different PBC configurations. This is not currently supported."
                )

    # Position of the atoms.
    atom_pos = data.pos

    # Enumerate all (i, j) atom pairs within each image, for the whole batch.
    num_atoms_per_image = data.natoms
    num_atoms_per_image_sqr = (num_atoms_per_image**2).long()

    # Index offset between images.
    index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image
    index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr)
    num_atoms_per_image_expand = torch.repeat_interleave(
        num_atoms_per_image, num_atoms_per_image_sqr
    )

    # Vectorised equivalent of concatenating
    # torch.arange(num_atoms_per_image_sqr[b]) for every image b.
    num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
    index_sqr_offset = (
        torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
    )
    index_sqr_offset = torch.repeat_interleave(index_sqr_offset, num_atoms_per_image_sqr)
    atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset

    # Pair indices via division and modulo (may lose precision for huge systems).
    index1 = (
        torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor")
        + index_offset_expand
    )
    index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand

    pos1 = torch.index_select(atom_pos, 0, index1)
    pos2 = torch.index_select(atom_pos, 0, index2)

    # Number of cell repetitions needed along each lattice vector. The distance
    # between the planes spanned by a2 and a3 is 1 / ||(a2 x a3) / V||, where
    # V = a1 . (a2 x a3) is the cell volume.
    cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
    cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)

    if pbc[0]:
        inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, dim=-1)
        rep_a1 = torch.ceil(radius * inv_min_dist_a1)
    else:
        rep_a1 = data.cell.new_zeros(1)

    if pbc[1]:
        cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
        inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, dim=-1)
        rep_a2 = torch.ceil(radius * inv_min_dist_a2)
    else:
        rep_a2 = data.cell.new_zeros(1)

    if pbc[2]:
        cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
        inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, dim=-1)
        rep_a3 = torch.ceil(radius * inv_min_dist_a3)
    else:
        rep_a3 = data.cell.new_zeros(1)

    # Use the maximum over all images (padding). This can add many distances
    # when images need very different repetition counts.
    max_rep = [int(rep_a1.max()), int(rep_a2.max()), int(rep_a3.max())]

    # Tensor of unit-cell translations.
    cells_per_dim = [
        torch.arange(-rep, rep + 1, device=device, dtype=torch.float32)
        for rep in max_rep
    ]
    unit_cell = torch.cartesian_prod(*cells_per_dim)
    num_cells = len(unit_cell)
    unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1)
    unit_cell = torch.transpose(unit_cell, 0, 1)
    unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1)

    # Cartesian offsets of every translation in every image.
    data_cell = torch.transpose(data.cell, 1, 2)
    pbc_offsets = torch.bmm(data_cell, unit_cell_batch)
    pbc_offsets_per_atom = torch.repeat_interleave(
        pbc_offsets, num_atoms_per_image_sqr, dim=0
    )

    # Expand positions and indices over all translations.
    pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells)
    pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells)
    index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1)
    index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1)
    # The translation is applied to the second (neighbour) atom.
    pos2 = pos2 + pbc_offsets_per_atom

    atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1).view(-1)

    # Keep pairs within the radius that are not the same atom (distance ~ 0).
    mask_within_radius = torch.le(atom_distance_sqr, radius * radius)
    mask_not_same = torch.gt(atom_distance_sqr, 0.0001)
    mask = torch.logical_and(mask_within_radius, mask_not_same)
    index1 = torch.masked_select(index1, mask)
    index2 = torch.masked_select(index2, mask)
    unit_cell = torch.masked_select(
        unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3)
    ).view(-1, 3)
    atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask)

    mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
        natoms=data.natoms,
        index=index1,
        atom_distance=atom_distance_sqr,
        max_num_neighbors_threshold=max_num_neighbors_threshold,
    )

    if not torch.all(mask_num_neighbors):
        # Enforce at most max_num_neighbors_threshold neighbours per atom.
        index1 = torch.masked_select(index1, mask_num_neighbors)
        index2 = torch.masked_select(index2, mask_num_neighbors)
        unit_cell = torch.masked_select(
            unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3)
        ).view(-1, 3)

    edge_index = torch.stack((index2, index1))

    return edge_index, unit_cell, num_neighbors_image
def get_pbc_distances(
    pos,
    edge_index,
    cell,
    cell_offsets,
    neighbors,
    return_offsets=False,
    return_distance_vec=False,
):
    """Compute edge lengths, including periodic lattice offsets.

    Args:
        pos: atomic positions indexed by ``edge_index``.
        edge_index: ``(row, col)`` pair. The edge vector is
            ``pos[row] - pos[col]`` plus the periodic offset.
        cell: one 3x3 lattice matrix per image.
        cell_offsets: integer lattice multiples, one row of 3 per edge.
        neighbors: number of edges per image, used to repeat ``cell`` per edge.
        return_offsets: also return the Cartesian offset vectors.
        return_distance_vec: also return the edge vectors.

    Returns:
        dict with ``"edge_index"`` and ``"distances"`` (zero-length edges
        removed) and, when requested, ``"distance_vec"`` and ``"offsets"``.
    """
    src, dst = edge_index
    edge_vec = pos[src] - pos[dst]

    # Periodic correction: each edge's integer offsets times its own cell.
    per_edge_cell = torch.repeat_interleave(cell, neighbors.to(cell.device), dim=0)
    shift = torch.bmm(
        cell_offsets.float().view(-1, 1, 3), per_edge_cell.float()
    ).view(-1, 3)
    edge_vec += shift

    lengths = edge_vec.norm(dim=-1)

    # Drop degenerate (zero-length) edges.
    keep = torch.arange(len(lengths), device=lengths.device)[lengths != 0]

    result = {"edge_index": edge_index[:, keep], "distances": lengths[keep]}
    if return_distance_vec:
        result["distance_vec"] = edge_vec[keep]
    if return_offsets:
        result["offsets"] = shift[keep]
    return result
## Radial basis function used to embed interatomic distances.
## TODO: add a comment citing the code this implementation is based on.
class rbf_emb(nn.Module):
    """Exponential-normal radial basis expansion with a cosine envelope.

    A distance ``d`` is mapped to ``num_rbf`` features
    ``env(d) * exp(-beta * (exp(-d) - mu_k) ** 2)``, where:

    - the centres ``mu_k`` are evenly spaced between ``exp(-rbound_upper)``
      and ``exp(-rbound_lower)``;
    - ``env(d) = 0.5 * (cos(pi * d / rbound_upper) + 1)`` for
      ``d < rbound_upper``, and 0 otherwise.

    Original note: "modified: delete cutoff with r".

    When ``rbf_trainable`` is True, ``means`` and ``betas`` are learnable
    parameters; otherwise they are buffers. The attribute names are the same
    either way, so state_dict keys do not change.
    """

    def __init__(self, num_rbf, rbound_upper, rbf_trainable=False):
        super().__init__()
        self.rbound_upper = rbound_upper
        self.rbound_lower = 0
        self.num_rbf = num_rbf
        self.rbf_trainable = rbf_trainable
        means, betas = self._initial_params()
        if rbf_trainable:
            # Previously the flag was stored but ignored (always buffers).
            self.register_parameter("means", nn.Parameter(means))
            self.register_parameter("betas", nn.Parameter(betas))
        else:
            self.register_buffer("means", means)
            self.register_buffer("betas", betas)

    def _initial_params(self):
        """Return the initial ``(means, betas)`` tensors, each of length ``num_rbf``."""
        start_value = torch.exp(torch.scalar_tensor(-self.rbound_upper))
        end_value = torch.exp(torch.scalar_tensor(-self.rbound_lower))
        means = torch.linspace(start_value, end_value, self.num_rbf)
        # One shared width for every basis function.
        beta = (2 / self.num_rbf * (end_value - start_value)) ** -2
        betas = torch.full((self.num_rbf,), float(beta))
        return means, betas

    def reset_parameters(self):
        """Restore ``means`` and ``betas`` to their initial values in place."""
        means, betas = self._initial_params()
        self.means.data.copy_(means)
        self.betas.data.copy_(betas)

    def forward(self, dist):
        """Expand ``dist`` into features of shape ``dist.shape + (num_rbf,)``."""
        dist = dist.unsqueeze(-1)
        envelope = 0.5 * (torch.cos(dist * pi / self.rbound_upper) + 1.0)
        # Zero every feature at or beyond the upper bound.
        envelope = envelope * (dist < self.rbound_upper).float()
        return envelope * torch.exp(
            -self.betas * torch.square(torch.exp(-dist) - self.means)
        )
class NeighborEmb(MessagePassing):
    """Add edge-weighted neighbour element embeddings to node features.

    Each atom type ``z`` is embedded (95-entry table) and layer-normalised.
    The embeddings of neighbouring atoms are multiplied by per-edge weights
    ``embs`` and summed (``aggr='add'``) onto the receiving atom. The result
    is then added to ``s``.
    """

    def __init__(self, hid_dim: int):
        super().__init__(aggr='add')
        self.embedding = nn.Embedding(95, hid_dim)
        self.hid_dim = hid_dim
        self.ln_emb = nn.LayerNorm(hid_dim, elementwise_affine=False)

    def forward(self, z, s, edge_index, embs):
        neighbor_feat = self.ln_emb(self.embedding(z))
        summed = self.propagate(edge_index, x=neighbor_feat, norm=embs)
        return s + summed

    def message(self, x_j, norm):
        # Per-edge weights are viewed as (num_edges, hid_dim).
        edge_weight = norm.view(-1, self.hid_dim)
        return edge_weight * x_j
class S_vector(MessagePassing):
    """Build per-node vector features from edge directions.

    In LEFTNet this is used for local 3D substructure encoding. Node features
    are passed through Linear -> LayerNorm -> SiLU. Each neighbour's result is
    scaled per edge by ``emb * v`` and summed (``aggr='add'``). The output is
    returned as ``(num_nodes, 3, hid_dim)``.
    """

    def __init__(self, hid_dim: int):
        super().__init__(aggr='add')
        self.hid_dim = hid_dim
        self.lin1 = nn.Sequential(
            nn.Linear(hid_dim, hid_dim),
            nn.LayerNorm(hid_dim, elementwise_affine=False),
            nn.SiLU(),
        )

    def forward(self, s, v, edge_index, emb):
        node_feat = self.lin1(s)
        edge_weight = emb.unsqueeze(1) * v
        summed = self.propagate(edge_index, x=node_feat, norm=edge_weight)
        return summed.view(-1, 3, self.hid_dim)

    def message(self, x_j, norm):
        # Broadcast each neighbour's feature over the 3 spatial components,
        # then flatten so the aggregation works on 2-D messages.
        weighted = norm.view(-1, 3, self.hid_dim) * x_j.unsqueeze(1)
        return weighted.view(-1, 3 * self.hid_dim)
class EquiMessagePassing(MessagePassing):
    """LEFTNet equivariant message passing with a complex-valued edge channel.

    Each edge message is built from the projected source features ``xh_j``
    gated by ``rbfh_ij`` and has three parts:

    - a scalar part ``x``;
    - a phase vector of length ``chi1``. It is obtained by contracting a
      complex per-edge state (built from ``x``) with learnable complex
      kernels, mixing with a radial-dependent diagonal, then taking
      ``torch.angle``;
    - an equivariant vector part built from ``vec_j`` and ``r_ij``.

    ``forward`` returns ``(rope, dy, dvec)``:

    - ``rope = exp(i * scale2(phase))`` is a complex tensor applied to the
      node features of the next layer;
    - ``dy`` is the aggregated scalar update;
    - ``dvec`` is the aggregated vector update.

    ``device`` only decides where the raw Parameters are created. Tensors
    allocated during ``message`` follow the input tensors, so the module keeps
    working after ``.to(...)``.
    """

    def __init__(
        self,
        hidden_channels,
        num_radial,
        hidden_channels_chi=96,
        head: int = 16,
        chi1: int = 32,
        chi2: int = 8,
        has_dropout_flag: bool = False,
        has_norm_before_flag=True,
        has_norm_after_flag=False,
        reduce_mode='sum',
        device=torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu"),
    ):
        super(EquiMessagePassing, self).__init__(aggr="add", node_dim=0)

        # NOTE: construction order is kept as-is so that seeded initialisation
        # (RNG consumption) and parameter ordering are unchanged.
        self.device = device
        self.reduce_mode = reduce_mode
        self.chi1 = chi1
        self.chi2 = chi2
        self.head = head
        self.hidden_channels = hidden_channels
        self.hidden_channels_chi = hidden_channels_chi
        self.scale = nn.Linear(self.hidden_channels, self.hidden_channels_chi * 2)
        self.num_radial = num_radial
        self.dir_proj = nn.Sequential(
            nn.Linear(3 * self.hidden_channels + self.num_radial, self.hidden_channels * 3),
            nn.SiLU(inplace=True),
            nn.Linear(self.hidden_channels * 3, self.hidden_channels * 3),
        )
        self.x_proj = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, hidden_channels * 3),
        )
        self.rbf_proj = nn.Linear(num_radial, hidden_channels * 3)
        self.x_layernorm = nn.LayerNorm(hidden_channels)
        self.diagonal = nn.Sequential(
            nn.Linear(hidden_channels * 3, hidden_channels_chi // 2),
            nn.SiLU(),
            nn.Linear(hidden_channels_chi // 2, self.chi2),
        )
        self.has_dropout_flag = has_dropout_flag
        self.has_norm_before_flag = has_norm_before_flag
        self.has_norm_after_flag = has_norm_after_flag

        if self.has_norm_after_flag:
            self.dx_layer_norm = nn.LayerNorm(self.chi1)
        if self.has_norm_before_flag:
            self.dx_layer_norm = nn.LayerNorm(self.chi1 + self.hidden_channels)
            if self.has_norm_after_flag:
                # With both flags set, the single shared norm used to be the
                # (chi1 + hidden)-sized one, which crashed on the chi1-sized
                # "after" input. Give the "after" norm its own module.
                self.dx_layer_norm_after = nn.LayerNorm(self.chi1)

        self.dropout = nn.Dropout(p=0.5)
        self.diachi1 = torch.nn.Parameter(torch.randn((self.chi1), device=self.device))
        self.scale2 = nn.Sequential(
            nn.Linear(self.chi1, hidden_channels // 2),
        )
        self.kernel_real = torch.nn.Parameter(
            torch.randn(
                (self.head + 1, (self.hidden_channels_chi) // self.head, self.chi2),
                device=self.device,
            )
        )
        self.kernel_imag = torch.nn.Parameter(
            torch.randn(
                (self.head + 1, (self.hidden_channels_chi) // self.head, self.chi2),
                device=self.device,
            )
        )
        self.fc_mps = nn.Linear(self.chi1, self.chi1)
        # fc_dx and unitary are not used by forward/message in this class;
        # they are kept so parameter sets and state_dict keys stay unchanged.
        self.fc_dx = nn.Linear(self.chi1, hidden_channels)
        self.dia = nn.Linear(self.chi1, self.chi1)
        self.unitary = torch.nn.Parameter(torch.randn((self.chi1, self.chi1), device=self.device))
        self.activation = nn.SiLU()
        self.inv_sqrt_3 = 1 / math.sqrt(3.0)
        self.inv_sqrt_h = 1 / math.sqrt(hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise x_proj, rbf_proj, x_layernorm and dir_proj.

        NOTE: scale, diagonal, scale2, fc_mps, dia and the raw Parameters
        (kernels, diachi1, unitary) are not touched here.
        """
        nn.init.xavier_uniform_(self.x_proj[0].weight)
        self.x_proj[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.x_proj[2].weight)
        self.x_proj[2].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.rbf_proj.weight)
        self.rbf_proj.bias.data.fill_(0)
        self.x_layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.dir_proj[0].weight)
        self.dir_proj[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.dir_proj[2].weight)
        self.dir_proj[2].bias.data.fill_(0)

    def forward(self, x, vec, edge_index, edge_rbf, weight, edge_vector, rope):
        """Run one message-passing step.

        ``rope`` is the complex phase returned by the previous layer, or None
        for the first layer. When given, it multiplies the node features
        viewed as complex numbers ``x[:, :h/2] + i * x[:, h/2:]``.
        """
        if rope is not None:
            half = self.hidden_channels // 2
            real, imag = torch.split(x, [half, half], dim=-1)
            rotated = torch.complex(real=real, imag=imag) * rope
            x = torch.cat([rotated.real, rotated.imag], dim=-1)

        xh = self.x_proj(self.x_layernorm(x))

        rbfh = self.rbf_proj(edge_rbf)
        weight = self.dir_proj(weight)
        rbfh = rbfh * weight

        # propagate_type: (xh: Tensor, vec: Tensor, rbfh_ij: Tensor, r_ij: Tensor)
        dx, dvec = self.propagate(
            edge_index,
            xh=xh,
            vec=vec,
            rbfh_ij=rbfh,
            r_ij=edge_vector,
            size=None,
        )

        if self.has_norm_before_flag:
            dx = self.dx_layer_norm(dx)
        # First chi1 channels: aggregated phases; remaining: scalar update.
        dx, dy = torch.split(dx, [self.chi1, self.hidden_channels], dim=-1)
        if self.has_norm_after_flag:
            norm_after = (
                self.dx_layer_norm_after if self.has_norm_before_flag else self.dx_layer_norm
            )
            dx = norm_after(dx)
        dx = self.scale2(dx)
        dx = torch.complex(torch.cos(dx), torch.sin(dx))
        return dx, dy, dvec

    def message(self, xh_j, vec_j, rbfh_ij, r_ij):
        x, xh2, xh3 = torch.split(xh_j * rbfh_ij, self.hidden_channels, dim=-1)
        xh2 = xh2 * self.inv_sqrt_3

        num_edges = x.shape[0]
        head_dim = self.hidden_channels_chi // self.head
        real, imagine = torch.split(self.scale(x), self.hidden_channels_chi, dim=-1)
        real = real.reshape(num_edges, self.head, head_dim)
        imagine = imagine.reshape(num_edges, self.head, head_dim)
        if self.has_dropout_flag:
            real = self.dropout(real)
            imagine = self.dropout(imagine)

        # Complex invariant per-edge state. Its dtype (complex64 for float32,
        # complex128 for float64) replaces the old if/elif, which left cdtype
        # unbound for any other dtype.
        phi = torch.complex(real, imagine)
        cdtype = phi.dtype

        # Prepend a constant "1" head, then contract with the complex kernel:
        # (E, head + 1, head_dim) x (E, head + 1, head_dim, chi2) -> (E, chi2).
        ones_head = torch.ones(num_edges, 1, head_dim, device=phi.device, dtype=cdtype)
        kernel = (
            torch.complex(self.kernel_real, self.kernel_imag)
            / math.sqrt((self.hidden_channels) // self.head)
        ).expand(num_edges, -1, -1, -1)
        conv = torch.einsum('ijl, ijlk->ik', torch.cat([ones_head, phi], dim=1), kernel.to(cdtype))

        # Radial-dependent diagonal mixing: (E, chi2, chi1).
        gate = self.activation(self.diagonal(rbfh_ij))
        b = gate.unsqueeze(-1) * self.diachi1.unsqueeze(0).unsqueeze(0) + torch.ones(
            num_edges, self.chi2, self.chi1, device=gate.device
        )
        dia = self.dia(b)
        mixed = torch.einsum('ik,ikl->il', conv, dia.to(cdtype))
        mixed_real, mixed_imag = self.fc_mps(mixed.real), self.fc_mps(mixed.imag)
        phase = torch.angle(torch.complex(mixed_real, mixed_imag))

        agg = torch.cat([phase, x], dim=-1)
        vec = vec_j * xh2.unsqueeze(1) + xh3.unsqueeze(1) * r_ij.unsqueeze(2)
        vec = vec * self.inv_sqrt_h
        return agg, vec

    def aggregate(
        self,
        features: Tuple[torch.Tensor, torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scatter scalar messages with ``reduce_mode`` and vector messages with sum."""
        x, vec = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.reduce_mode)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec

    def update(
        self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs
class FTE(nn.Module):
    """Frame transition encoding: exchange information between scalar and vector features.

    The vector features are projected to two channel groups ``(v1, v2)``.
    Scalars receive ``(a + b + <v1, v2> / sqrt(h)) / sqrt(2)``, where
    ``a, b`` come from an MLP on ``[x, |v1|_1]``. Vectors receive ``c * v2``,
    with ``c`` from the same MLP.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 2, bias=False)
        self.xvec_proj = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, hidden_channels * 3),
        )

        self.inv_sqrt_2 = 1 / math.sqrt(2.0)
        self.inv_sqrt_h = 1 / math.sqrt(hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.vec_proj.weight)
        for linear in (self.xvec_proj[0], self.xvec_proj[2]):
            nn.init.xavier_uniform_(linear.weight)
            linear.bias.data.fill_(0)

    def forward(self, x, vec, node_frame):
        # node_frame is accepted for interface compatibility but not used.
        projected = self.vec_proj(vec)
        vec1, vec2 = torch.split(projected, self.hidden_channels, dim=-1)

        # NOTE(review): this is an L1 norm over the spatial axis (dim -2),
        # which is not rotation invariant; the L2 form would be
        # torch.sqrt(torch.sum(vec1 ** 2, dim=-2) + 1e-10). Confirm intended.
        scalar = torch.norm(vec1, dim=-2, p=1)
        vec_dot = (vec1 * vec2).sum(dim=1) * self.inv_sqrt_h

        mlp_out = self.xvec_proj(torch.cat([x, scalar], dim=-1))
        xvec1, xvec2, xvec3 = torch.split(mlp_out, self.hidden_channels, dim=-1)

        dx = (xvec1 + xvec2 + vec_dot) * self.inv_sqrt_2
        dvec = xvec3.unsqueeze(1) * vec2
        return dx, dvec
class aggregate_pos(MessagePassing):
    """Aggregate per-node vectors over graph neighbours (``aggr='mean'`` by default).

    No ``message`` override is defined, so the base class's default message
    is used.
    """

    def __init__(self, aggr='mean'):
        super().__init__(aggr=aggr)

    def forward(self, vector, edge_index):
        return self.propagate(edge_index, x=vector)
class EquiOutput(nn.Module):
    """Map node scalars and vectors to one vector per node via gated equivariant blocks."""

    def __init__(self, hidden_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        # A single block projecting to one output channel.
        self.output_network = nn.ModuleList([GatedEquivariantBlock(hidden_channels, 1)])
        self.reset_parameters()

    def reset_parameters(self):
        for block in self.output_network:
            block.reset_parameters()

    def forward(self, x, vec):
        for block in self.output_network:
            x, vec = block(x, vec)
        # NOTE(review): squeeze() drops every size-1 dim, including a batch of
        # one node; confirm callers expect that.
        return vec.squeeze()
# Borrowed from TorchMD-Net
class GatedEquivariantBlock(nn.Module):
    """Gated Equivariant Block as defined in Schütt et al. (2021):
    Equivariant message passing for the prediction of tensorial properties and molecular spectra.

    The scalar output is ``SiLU(a)``. The vector output is ``b * W2 v``,
    where ``(a, b)`` come from an MLP on ``[x, ||W1 v||]`` and the norm is
    taken over the spatial axis (dim -2).
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
    ):
        super().__init__()
        self.out_channels = out_channels

        self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
        self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False)

        self.update_net = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, out_channels * 2),
        )

        self.act = nn.SiLU()

    def reset_parameters(self):
        for proj in (self.vec1_proj, self.vec2_proj):
            nn.init.xavier_uniform_(proj.weight)
        for linear in (self.update_net[0], self.update_net[2]):
            nn.init.xavier_uniform_(linear.weight)
            linear.bias.data.fill_(0)

    def forward(self, x, v):
        vec_norm = torch.norm(self.vec1_proj(v), dim=-2)
        vec_out = self.vec2_proj(v)

        gated = self.update_net(torch.cat([x, vec_norm], dim=-1))
        x_out, gate = torch.split(gated, self.out_channels, dim=-1)

        return self.act(x_out), gate.unsqueeze(1) * vec_out
class LEFTNet(BaseFFCore):
    """LEFTNet core predicting per-atom ``Ea`` and ``Qa``.

    Inputs (see ``get_output``):

    - ``Ra``: positions;
    - ``Za``: atomic numbers;
    - ``batch_seg``: structure index per atom;
    - ``idx_i_sr`` / ``idx_j_sr``: short-range pair indices;
    - ``Dij_sr`` / ``vij_sr``: pair distances and vectors.

    The model:

    - alternates ``EquiMessagePassing`` and ``FTE`` layers on scalar and
      vector node features;
    - evolves a complex "quantum" node state alongside them;
    - reads both out through linear heads into two per-atom targets.
    """

    def __init__(
        self,
        cutoff_sr: float = 5.0,
        num_layers: int = 4,
        hidden_channels: int = 128,
        num_radial: int = 96,
        eps: Union[float, str] = 1e-10,
        use_sigmoid: bool = False,
        head: int = 16,
        main_chi1: int = 24,
        mp_chi1: int = 24,
        chi2: int = 6,
        hidden_channels_chi: int = 96,
        has_dropout_flag=True,
        has_norm_before_flag=True,
        has_norm_after_flag=False,
        reduce_mode='sum'
    ):
        super(LEFTNet, self).__init__(
            input_fields={"Ra", "Za", "batch_seg", "idx_i_sr", "idx_j_sr", "Dij_sr", "vij_sr"},
            output_fields={"Ea", "Qa"},
        )
        # NOTE: construction order is kept as-is so seeded initialisation
        # (RNG consumption) and parameter ordering are unchanged.
        self.eps = float(eps)
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.cutoff = cutoff_sr
        self.chi1 = main_chi1
        self.pos_require_grad = True
        if self.pos_require_grad:
            self.out_forces = EquiOutput(hidden_channels)

        self.z_emb_ln = nn.LayerNorm(hidden_channels, elementwise_affine=False)
        self.z_emb = Embedding(95, hidden_channels)
        # Projects the atom embedding to the initial complex state
        # (real and imaginary halves of size chi1 each).
        self.kernel1 = torch.nn.Parameter(torch.randn((hidden_channels, self.chi1 * 2)))
        kernels_real = []
        kernels_imag = []

        self.radial_emb = rbf_emb(num_radial, cutoff_sr)
        self.radial_lin = nn.Sequential(
            nn.Linear(num_radial, hidden_channels),
            nn.SiLU(inplace=True),
            nn.Linear(hidden_channels, hidden_channels),
        )

        self.neighbor_emb = NeighborEmb(hidden_channels)
        self.S_vector = S_vector(hidden_channels)

        self.lin = nn.Sequential(
            nn.Linear(3, hidden_channels // 4),
            nn.SiLU(inplace=True),
            nn.Linear(hidden_channels // 4, 1),
        )

        self.message_layers = nn.ModuleList()
        self.FTEs = nn.ModuleList()
        for _ in range(num_layers):
            self.message_layers.append(
                EquiMessagePassing(
                    hidden_channels=hidden_channels,
                    num_radial=num_radial,
                    head=head,
                    chi2=chi2,
                    chi1=mp_chi1,
                    has_dropout_flag=has_dropout_flag,
                    has_norm_before_flag=has_norm_before_flag,
                    has_norm_after_flag=has_norm_after_flag,
                    hidden_channels_chi=hidden_channels_chi,
                    reduce_mode=reduce_mode,
                )
            )
            self.FTEs.append(FTE(hidden_channels))
            kernels_real.append(torch.randn((hidden_channels, self.chi1, self.chi1)))
            kernels_imag.append(torch.randn((hidden_channels, self.chi1, self.chi1)))
        # One complex (hidden, chi1, chi1) kernel per layer for the quantum-state update.
        self.kernels_real = torch.nn.Parameter(torch.stack(kernels_real))
        self.kernels_imag = torch.nn.Parameter(torch.stack(kernels_imag))

        self.num_targets = 2
        self.last_layer = nn.Linear(hidden_channels, self.num_targets)
        self.last_layer_quantum = nn.Linear(self.chi1 * 2, self.num_targets)

        # For node-wise frame.
        self.mean_neighbor_pos = aggregate_pos(aggr='mean')

        self.inv_sqrt_2 = 1 / math.sqrt(2.0)
        self.reset_parameters()
        self.use_sigmoid = use_sigmoid

    def __str__(self) -> str:
        return """
############################################
# LEFTNet (NeurIPS 2023, arXiv:2304.04757) #
############################################
"""

    def reset_parameters(self):
        self.z_emb.reset_parameters()
        self.radial_emb.reset_parameters()
        for layer in self.message_layers:
            layer.reset_parameters()
        for layer in self.FTEs:
            layer.reset_parameters()
        self.last_layer.reset_parameters()
        for layer in self.radial_lin:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
        for layer in self.lin:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def _edge_frame_scalars(self, scalrization):
        """Collapse the 3 frame projections of each channel into one scalar.

        ``scalrization`` is ``(num_edges, 3, hidden)``; the result is
        ``(num_edges, hidden)``.
        """
        per_channel = torch.permute(scalrization, (0, 2, 1))  # (E, hidden, 3)
        return (
            self.lin(per_channel) + per_channel[:, :, 0].unsqueeze(2)
        ).squeeze(-1) / math.sqrt(self.hidden_channels)

    def __forward(self, Ra, batch, Za, idx_i, idx_j, dist, vecs):
        # NOTE(review): torch_scatter's default reduce is "sum", so this
        # subtracts the per-structure sum of positions, not the centroid.
        # Positions below only enter as differences within a structure.
        pos = Ra - scatter(Ra, batch, dim=0)[batch]
        edge_index = torch.stack([idx_i, idx_j], dim=0)

        # Embed atomic numbers.
        z_emb = self.z_emb_ln(self.z_emb(Za))

        # radial_emb: (num_edges, num_radial); radial_hidden: (num_edges, hidden).
        radial_emb = self.radial_emb(dist)
        radial_hidden = self.radial_lin(radial_emb)
        rbounds = 0.5 * (torch.cos(dist * pi / self.cutoff) + 1.0)
        radial_hidden = rbounds.unsqueeze(-1) * radial_hidden

        # Invariant node features (num_nodes, hidden) and
        # equivariant node features (num_nodes, 3, hidden).
        s = self.neighbor_emb(Za, z_emb, edge_index, radial_hidden)
        vec = torch.zeros(s.size(0), 3, s.size(1), device=s.device, dtype=s.dtype)

        # Edge-wise frame. An explicit dim=-1 is required: without it,
        # torch.cross picks the first size-3 dimension, which is the edge
        # dimension when there are exactly 3 edges.
        edge_diff = vecs / (dist.unsqueeze(1) + self.eps)
        # dim_size keeps mean[idx_i] in range even if the highest node index
        # receives no edge.
        mean = scatter(
            pos[edge_index[0]], edge_index[1], reduce='mean', dim=0, dim_size=Ra.size(0)
        )
        edge_cross = torch.cross(pos[idx_i] - mean[idx_i], pos[idx_j] - mean[idx_i], dim=-1)
        edge_vertical = torch.cross(edge_diff, edge_cross, dim=-1)
        # edge_frame: (num_edges, 3, 3).
        edge_frame = torch.cat(
            (edge_diff.unsqueeze(-1), edge_cross.unsqueeze(-1), edge_vertical.unsqueeze(-1)),
            dim=-1,
        )
        node_frame = 0

        # LSE: local 3D substructure encoding; S_i_j: (num_nodes, 3, hidden).
        S_i_j = self.S_vector(s, edge_diff.unsqueeze(-1), edge_index, radial_hidden)
        scalrization1 = torch.sum(S_i_j[idx_i].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1)
        scalrization2 = torch.sum(S_i_j[idx_j].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1)
        # Take the absolute value of the projection on the cross-product axis.
        scalrization1[:, 1, :] = torch.abs(scalrization1[:, 1, :].clone())
        scalrization2[:, 1, :] = torch.abs(scalrization2[:, 1, :].clone())

        scalar3 = self._edge_frame_scalars(scalrization1)
        scalar4 = self._edge_frame_scalars(scalrization2)

        edge_weight = torch.cat((scalar3, scalar4), dim=-1) * rbounds.unsqueeze(-1)
        edge_weight = torch.cat((edge_weight, radial_hidden, radial_emb), dim=-1)

        # Initial complex node state from the atom embedding.
        quantum = torch.einsum('ik,bi->bk', self.kernel1, z_emb)
        real, imagine = torch.split(quantum, self.chi1, dim=-1)
        quantum = torch.complex(real, imagine)

        rope = None  # the first message layer receives no phase
        for i in range(self.num_layers):
            rope, ds, dvec = self.message_layers[i](
                s, vec, edge_index, radial_emb, edge_weight, edge_diff, rope
            )
            s = s + ds
            vec = vec + dvec

            # Update and renormalise the complex state with this layer's kernel.
            kerneli = torch.complex(self.kernels_real[i], self.kernels_imag[i])
            quantum = torch.einsum('ikl,bi,bl->bk', kerneli, s.to(kerneli.dtype), quantum)
            quantum = quantum / quantum.abs().to(kerneli.dtype)

            # FTE: frame transition encoding.
            ds, dvec = self.FTEs[i](s, vec, node_frame)
            s = s + ds
            vec = vec + dvec

        s = self.last_layer(s) + self.last_layer_quantum(
            torch.cat([quantum.real, quantum.imag], dim=-1)
        ) / self.chi1
        return s[:, 0], s[:, 1]

    def build(self, built_layers) -> None:
        """Attach the layers around this core.

        First, a DistanceLayer (with vector output ``vij_lr`` and distance field
        ``Dij_lr``) is added to ``pre_sequence``. Then layers listed before
        ``self`` in ``built_layers`` go to ``pre_sequence``;
        RangeSeparationLayer instances first get ``idx_i_lr``/``idx_j_lr``
        renamed to ``idx_i``/``idx_j``. Layers listed after ``self`` go to
        ``post_sequence``.
        """
        calculate_distance = DistanceLayer()
        calculate_distance.with_vector_on("vij_lr")
        calculate_distance.reset_field_name(Dij="Dij_lr")
        self.pre_sequence.append(calculate_distance)
        pre_core = True
        for layer in built_layers:
            if layer is self:
                pre_core = False
                continue
            if pre_core:
                if isinstance(layer, RangeSeparationLayer):
                    layer.reset_field_name(idx_i_lr="idx_i", idx_j_lr="idx_j")
                self.pre_sequence.append(layer)
            else:
                self.post_sequence.append(layer)

    def get_output(self, Ra, Za, batch_seg, idx_i_sr, idx_j_sr, Dij_sr, vij_sr):
        """Return ``{"Ea": ..., "Qa": ...}``, one value per atom for each key."""
        Ea, Qa = self.__forward(Ra, batch_seg, Za, idx_i_sr, idx_j_sr, Dij_sr, vij_sr)
        return {"Ea": Ea, "Qa": Qa}

    @property
    def num_params(self):
        """Total number of parameter elements in the model."""
        return sum(p.numel() for p in self.parameters())