Source code for enerzyme.models.layers.geometry

from typing import Optional, Dict, Literal
from torch import Tensor
import torch.nn.functional as F
from . import BaseFFLayer
from ..cutoff import CUTOFF_KEY_TYPE, CUTOFF_REGISTER


class DistanceLayer(BaseFFLayer):
    '''
    Compute the distance between atoms

    For every atom pair (idx_i, idx_j) the layer emits the Euclidean
    distance "Dij". After `with_vector_on` has been called it also emits
    the displacement vector from atom i to (offset) atom j as "vij".
    '''

    def __init__(self) -> None:
        super().__init__(
            input_fields={"Ra", "idx_i", "idx_j", "offsets"},
            output_fields={"Dij", "vij"},
        )
        # The displacement vector is opt-in; it is only produced after
        # with_vector_on() has been called.
        self._with_vector = False

    def with_vector_on(self, vij_name: str = "vij") -> None:
        '''
        Enable the displacement vector output

        Params:
        -----
        vij_name: Name forwarded to `reset_field_name(vij=...)`
        '''
        self._with_vector = True
        self.reset_field_name(vij=vij_name)

    def get_output(
        self,
        Ra: Tensor,
        idx_i: Tensor,
        idx_j: Tensor,
        offsets: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        '''
        Compute the distance with atom pair indices

        Params:
        -----
        Ra: Float tensor of atom positions, shape [N * batch_size, 3]

        idx_i: Long tensor of the first pair indices, shape [N_pair * batch_size]

        idx_j: Long tensor of the second pair indices, shape [N_pair * batch_size]

        offsets: Optional float tensor added to the positions of the second
        atoms before the distance is taken; presumably shape
        [N_pair * batch_size, 3] (TODO confirm against the caller)

        Returns:
        -----
        Dij: Float tensor of distances, shape [N_pair * batch_size]

        vij: Float tensor of displacement vectors (second atom minus first
        atom, offsets included), shape [N_pair * batch_size, 3]. Only present
        after `with_vector_on` has been called.
        '''
        if Ra.device.type == "cpu":  # indexing is faster on CPUs
            Ri = Ra[idx_i]
            Rj = Ra[idx_j]
        else:
            # Take the gather width from Ra instead of hard-coding 3 so that
            # this branch selects exactly the same rows as the CPU branch.
            n_components = Ra.shape[-1]
            Ri = Ra.gather(0, idx_i.view(-1, 1).expand(-1, n_components))
            Rj = Ra.gather(0, idx_j.view(-1, 1).expand(-1, n_components))

        if offsets is not None:
            # Out-of-place add: Ra and the gathered rows are left untouched.
            Rj = Rj + offsets

        relevant_output = {"Dij": F.pairwise_distance(Ri, Rj, eps=1e-15)}
        if self._with_vector:
            relevant_output["vij"] = Rj - Ri
        return relevant_output
class RangeSeparationLayer(BaseFFLayer):
    '''
    Extract the short-range subset of a long-range pair list

    Pairs whose distance is strictly below `cutoff_sr` are kept and written
    to the "*_sr" output fields.
    '''

    def __init__(self, cutoff_sr, cutoff_fn: Optional[CUTOFF_KEY_TYPE] = None) -> None:
        '''
        Params:
        -----
        cutoff_sr: Short-range cutoff; pairs with Dij_lr < cutoff_sr are kept

        cutoff_fn: Optional key into CUTOFF_REGISTER. When given, the
        registered function is also evaluated on the selected distances.
        '''
        super().__init__()
        self.cutoff_sr = cutoff_sr
        self.cutoff_fn = None if cutoff_fn is None else CUTOFF_REGISTER[cutoff_fn]

    def get_output(
        self,
        Dij_lr: Tensor,
        idx_i_lr: Tensor,
        idx_j_lr: Tensor,
        vij_lr: Optional[Tensor] = None,
    ) -> Dict[
        Literal["Dij_sr", "idx_i_sr", "idx_j_sr", "vij_sr", "cutoff_values_sr"], Tensor
    ]:
        '''
        Mask the long-range pair tensors down to the short-range pairs

        Params:
        -----
        Dij_lr: Tensor of long-range pair distances

        idx_i_lr: Tensor of the first pair indices, indexed with the same mask

        idx_j_lr: Tensor of the second pair indices, indexed with the same mask

        vij_lr: Optional tensor of pair vectors, indexed with the same mask

        Returns:
        -----
        Dij_sr, idx_i_sr, idx_j_sr: The inputs restricted to the kept pairs

        cutoff_values_sr: `cutoff_fn(Dij_sr, cutoff_sr)`, only present when a
        cutoff function was configured

        vij_sr: `vij_lr` restricted to the kept pairs, only present when
        `vij_lr` was given
        '''
        # Strict comparison: a pair exactly at the cutoff is dropped.
        within_sr = Dij_lr < self.cutoff_sr
        Dij_sr = Dij_lr[within_sr]

        short_range = {
            "Dij_sr": Dij_sr,
            "idx_i_sr": idx_i_lr[within_sr],
            "idx_j_sr": idx_j_lr[within_sr],
        }
        if self.cutoff_fn is not None:
            short_range["cutoff_values_sr"] = self.cutoff_fn(Dij_sr, self.cutoff_sr)
        if vij_lr is not None:
            short_range["vij_sr"] = vij_lr[within_sr]
        return short_range
class RadiusGraphLayer(BaseFFLayer):
    '''
    Build short-range pair indices with `torch_geometric`'s `radius_graph`
    '''

    def __init__(self, cutoff: float, max_num_neighbors: int) -> None:
        '''
        Params:
        -----
        cutoff: Passed to `radius_graph` as `r`

        max_num_neighbors: Passed to `radius_graph` as `max_num_neighbors`
        '''
        super().__init__(
            input_fields={"Ra", "batch_seg"},
            output_fields={"idx_i_sr", "idx_j_sr"},
        )
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors

    def get_output(self, Ra: Tensor, batch_seg: Tensor) -> Dict[str, Tensor]:
        '''
        Compute the pair indices of the radius graph

        Params:
        -----
        Ra: Tensor of atom positions, passed to `radius_graph` as its first argument

        batch_seg: Tensor passed to `radius_graph` as `batch`

        Returns:
        -----
        idx_i_sr: Row 0 of the edge index returned by `radius_graph`

        idx_j_sr: Row 1 of the edge index returned by `radius_graph`
        '''
        # Imported here rather than at module level, so torch_geometric is
        # only required when this layer is actually evaluated.
        from torch_geometric.nn import radius_graph

        pair_index = radius_graph(
            Ra,
            r=self.cutoff,
            batch=batch_seg,
            max_num_neighbors=self.max_num_neighbors,
        )
        first_row = pair_index[0]
        second_row = pair_index[1]
        return {"idx_i_sr": first_row, "idx_j_sr": second_row}