Source code for enerzyme.models.layers.gather_embedding
from typing import Dict, Set
import torch
from torch import Tensor
from . import BaseFFLayer
[docs]
class GatherAtomEmbedding(BaseFFLayer):
[docs]
def __init__(self) -> None:
super().__init__(input_fields={}, output_fields={"atom_embedding"})
[docs]
def get_relevant_input_fields(self, net_input_fields: Set[str]) -> Set[str]:
relevant_input_fields = set()
for field in net_input_fields:
if field.endswith("_embedding"):
relevant_input_fields.add(field)
return relevant_input_fields
[docs]
def get_output(self, **relevant_input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return {"atom_embedding": torch.sum(torch.stack([v for v in relevant_input.values()], dim=0), dim=0)}