from typing import Dict, Union, Literal
import torch
from torch import nn, Tensor
from . import BaseFFLayer
from ...data.transform import PERIODIC_TABLE
# Per-property affine specification:
# {property: {"values": scalar-or-{element: value}, "learnable": bool}}.
_AffineSpec = Dict[
    Literal["Ea", "Qa"],
    Dict[Literal["values", "learnable"], Union[Dict[str, float], float, bool]],
]


class AtomicAffineLayer(BaseFFLayer):
    """Apply a per-element affine transform to atomic properties.

    For every property ``p`` that has a shift and/or a scale, the output is
    ``(p + shift[Za]) * scale[Za]``. Here ``shift`` and ``scale`` are
    coefficient vectors of length ``max_Za + 1`` indexed by atomic number
    ``Za``. If a property has only a shift (or only a scale), the missing
    term is treated as the identity (no shift / no scaling).

    Args:
        max_Za: Largest supported atomic number; each coefficient vector
            has ``max_Za + 1`` entries.
        shifts: Mapping ``property -> {"values": ..., "learnable": ...}``.
            ``values`` is one of:
            - a scalar applied to every element;
            - a dict keyed by element symbol (``str``) or atomic number.
            Elements not listed in a dict default to 0.
            ``learnable`` sets ``requires_grad`` and defaults to ``True``
            when omitted. ``None`` selects zero, learnable shifts for
            ``"Ea"`` and ``"Qa"``.
        scales: Same layout as ``shifts``; elements not listed default
            to 1. ``None`` selects unit, learnable scales for ``"Ea"``
            and ``"Qa"``.
    """

    def __init__(
        self,
        max_Za: int,
        shifts: Union[_AffineSpec, None] = None,
        scales: Union[_AffineSpec, None] = None,
    ) -> None:
        # Build fresh defaults on every call instead of sharing mutable
        # default dicts between all instances.
        if shifts is None:
            shifts = {"Ea": {"values": 0, "learnable": True}, "Qa": {"values": 0, "learnable": True}}
        if scales is None:
            scales = {"Ea": {"values": 1, "learnable": True}, "Qa": {"values": 1, "learnable": True}}
        # Every property that has a shift or a scale is both read and written.
        atomic_properties = shifts.keys() | scales.keys()
        super().__init__(input_fields={"Za"} | atomic_properties, output_fields=atomic_properties)
        self.max_Za = max_Za
        self.shifts = self.build_affine(shifts, 0)
        self.scales = self.build_affine(scales, 1)

    def build_affine(
        self,
        params: _AffineSpec,
        default_value: float,
    ) -> nn.ParameterDict:
        """Create one ``(max_Za + 1,)`` coefficient vector per property.

        Args:
            params: Affine specification (``property -> {"values", "learnable"}``).
            default_value: Fill value used in two cases:
                - elements not listed when ``values`` is a dict;
                - the whole vector when ``values`` is omitted.

        Returns:
            ``nn.ParameterDict`` mapping each property name to its
            coefficient vector, indexed by atomic number.

        Raises:
            IndexError: if an atomic number lies outside ``[0, max_Za]``.
        """
        affine_dict = {}
        for name, param in params.items():
            values = param.get("values", default_value)
            if isinstance(values, dict):
                affine_param = torch.full((self.max_Za + 1,), float(default_value))
                for idx, value in values.items():
                    # Element symbols are resolved to atomic numbers through
                    # the periodic table; other keys are used as indices.
                    z = int(PERIODIC_TABLE.loc[idx]["Za"]) if isinstance(idx, str) else idx
                    affine_param[z] = value
            else:
                affine_param = torch.full((self.max_Za + 1,), float(values))
            affine_dict[name] = nn.Parameter(affine_param, requires_grad=bool(param.get("learnable", True)))
        return nn.ParameterDict(affine_dict)

    def _affine_transform(self, name: str, x: Tensor, Za: Tensor) -> Tensor:
        """Return ``(x + shift[Za]) * scale[Za]`` for property ``name``.

        A shift or scale that is not configured for ``name`` is skipped.
        ``Za`` is used as a ``gather`` index along dim 0 of the 1-D
        coefficient vectors, so it must be a 1-D int64 tensor with one
        entry per row of ``x``.
        """
        # Reshape per-atom coefficients to (N,) for 1-D inputs or to
        # (N, 1, ..., 1) otherwise, so they broadcast over every trailing
        # dimension of x. A 2-D input gets (N, 1), as before.
        ndim = x.dim()
        shape = (-1,) if ndim == 1 else (-1,) + (1,) * max(ndim - 1, 1)
        if name in self.shifts:
            x = x + self.shifts[name].gather(0, Za).view(shape)
        if name in self.scales:
            x = x * self.scales[name].gather(0, Za).view(shape)
        return x

    def get_Ea(self, Ea: Tensor, Qa: Tensor, Za: Tensor) -> Tensor:
        """Return ``(Ea + shifts["Ea"][Za]) * scales["Ea"][Za]``.

        ``Qa`` is unused here. Presumably it is accepted so that every
        ``get_*`` method shares one call signature (verify against
        ``BaseFFLayer``).
        """
        return self._affine_transform("Ea", Ea, Za)

    def get_Qa(self, Ea: Tensor, Qa: Tensor, Za: Tensor) -> Tensor:
        """Return ``(Qa + shifts["Qa"][Za]) * scales["Qa"][Za]``.

        ``Ea`` is unused here. Presumably it is accepted so that every
        ``get_*`` method shares one call signature (verify against
        ``BaseFFLayer``).
        """
        return self._affine_transform("Qa", Qa, Za)

    def _load_from_state_dict(self, state_dict: Dict[str, Tensor], prefix: str, *args, **kwargs) -> None:
        """Resize checkpointed coefficient vectors to ``max_Za + 1`` before loading.

        Vectors longer than ``max_Za + 1`` are truncated. Shorter ones are
        padded with this layer's current values for the missing atomic
        numbers. This lets a checkpoint built with a different ``max_Za``
        still load.
        """
        size = self.max_Za + 1
        for group_name, group in (("shifts", self.shifts), ("scales", self.scales)):
            for name, current in group.items():
                # Match this layer's own keys exactly. A suffix match could
                # also catch same-named parameters of other modules.
                key = f"{prefix}{group_name}.{name}"
                loaded = state_dict.get(key)
                if loaded is None or loaded.dim() != 1:
                    # Missing or non-vector entries are left to the base loader.
                    continue
                n_loaded = loaded.shape[0]
                if n_loaded > size:
                    state_dict[key] = loaded[:size]
                elif n_loaded < size:
                    # Detach so the padding carries no autograd history, and
                    # match the checkpoint's device/dtype so the concatenation
                    # works across devices.
                    tail = current.detach()[n_loaded:].to(device=loaded.device, dtype=loaded.dtype)
                    state_dict[key] = torch.cat([loaded, tail], dim=0)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)