Source code for enerzyme.models.irreps_tools

from typing import List, Tuple, Union
from collections import namedtuple
import torch
import e3nn.o3 as o3
from e3nn.util.jit import compile_mode
from e3nn.o3 import Irreps

_INPUT = namedtuple("_INPUT", "tensor, start, stop")
_TP = namedtuple("_TP", "op, args")


def linear_out_irreps(irreps: Irreps, target_irreps: Irreps) -> Irreps:
    """Match every irrep of ``irreps`` against ``target_irreps``.

    For each ``(_, ir)`` entry of ``irreps`` (its multiplicity is ignored), the
    first entry of ``target_irreps`` whose irrep equals ``ir`` is selected
    together with that entry's multiplicity. The result has one entry per
    entry of ``irreps``, in the same order.

    ``target_irreps`` is assumed to be simplified (each irrep listed once);
    only the first match is ever used.

    Raises:
        RuntimeError: if an irrep of ``irreps`` does not occur in
            ``target_irreps``.
    """
    selected = []
    for _, wanted in irreps:
        # First (mul, ir_out) of the target whose irrep matches; None if absent.
        hit = next(
            ((mul, ir_out) for mul, ir_out in target_irreps if wanted == ir_out),
            None,
        )
        if hit is None:
            raise RuntimeError(f"{wanted} not in {target_irreps}")
        selected.append(hit)
    return Irreps(selected)
def tp_out_irreps_with_instructions(
    irreps1: Irreps, irreps2: Irreps, target_irreps: Irreps
) -> Tuple[Irreps, List]:
    """Enumerate "uvu" tensor-product paths from ``irreps1 x irreps2`` into ``target_irreps``.

    Every pair of entries ``(i, j)`` of ``irreps1`` and ``irreps2`` is coupled.
    Each resulting irrep (``ir_in * ir_edge``) that belongs to ``target_irreps``
    contributes one output entry ``(mul_i, ir_out)``, where ``mul_i`` is the
    multiplicity of ``irreps1[i]``. It also contributes one instruction
    ``(i, j, k, "uvu", True)``, with ``k`` the index of that output entry.

    The output irreps are then sorted (``Irreps.sort``), the instruction output
    indices are remapped through the sort permutation, and the instructions
    are ordered by their output index.

    Returns:
        ``(irreps_out, instructions)``: the sorted (not simplified) output
        irreps and the matching instruction list.
    """
    trainable = True
    collected: List[Tuple[int, o3.Irrep]] = []
    raw_instructions = []
    for idx1, (mul, ir_in) in enumerate(irreps1):
        for idx2, (_, ir_edge) in enumerate(irreps2):
            for ir_out in ir_in * ir_edge:  # | l1 - l2 | <= l <= l1 + l2
                if ir_out not in target_irreps:
                    continue
                # Output index k is the position this entry is about to take.
                raw_instructions.append(
                    (idx1, idx2, len(collected), "uvu", trainable)
                )
                collected.append((mul, ir_out))

    # Sorting lets the output irreps be simplified when they are later fed
    # to a second o3.Linear.
    irreps_out, permut, _ = Irreps(collected).sort()

    # Point each instruction at the sorted position of its output irrep.
    remapped = [
        (src1, src2, permut[k], mode, train)
        for src1, src2, k, mode, train in raw_instructions
    ]
    remapped.sort(key=lambda instr: instr[2])  # stable, like sorted()
    return irreps_out, remapped
@compile_mode("script")
class reshape_irreps(torch.nn.Module):
    """Split a flat irreps feature tensor into per-irrep ``(mul, dim)`` blocks.

    ``forward`` takes a 2-D tensor whose columns are the concatenation of one
    ``mul * ir.dim`` chunk per entry of ``irreps``. The total width presumably
    equals ``irreps.dim``; this is not checked. Each chunk is reshaped to
    ``(batch, mul, ir.dim)``, and the chunks are concatenated along the last
    axis, giving ``(batch, mul, sum of ir.dim)``.

    Because of that final concatenation, every entry of ``irreps`` must share
    the same ``mul``; otherwise ``torch.cat`` raises.
    """

    def __init__(self, irreps: Irreps) -> None:
        super().__init__()
        self.irreps = Irreps(irreps)
        # Per-entry component dimension and multiplicity, as plain int lists
        # consumed by ``forward``.
        self.dims = [ir.dim for _, ir in self.irreps]
        self.muls = [mul for mul, _ in self.irreps]

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape ``tensor`` of shape ``(batch, features)`` into ``(batch, mul, sum(dims))``."""
        n_batch, _ = tensor.shape
        offset = 0
        chunks = []
        for mul, d in zip(self.muls, self.dims):
            width = mul * d
            # Columns [offset, offset + width) hold this entry's mul copies of
            # a d-dimensional irrep.
            chunks.append(tensor[:, offset : offset + width].reshape(n_batch, mul, d))
            offset += width
        return torch.cat(chunks, dim=-1)
def _wigner_nj(
    irrepss: List[Irreps],
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Recursively build coupling tensors for the product of several Irreps.

    The factors ``irrepss[0] x irrepss[1] x ... x irrepss[-1]`` are coupled
    left to right with ``o3.wigner_3j``.

    Args:
        irrepss: the factors; each is converted with ``Irreps(...)``.
        normalization: ``"component"`` scales each 3j symbol by
            ``sqrt(ir_out.dim)``, and ``"norm"`` by ``sqrt(ir_left.dim * ir.dim)``.
            Any other value applies no scaling.
        filter_ir_mid: optional iterable of irrep specs (converted with
            ``o3.Irrep``). When given, only coupled irreps in this list are
            kept, at every level of the recursion.
        dtype: dtype passed to ``torch.eye``, ``o3.wigner_3j`` and
            ``torch.zeros``.

    Returns:
        A list of ``(ir_out, path, tensor)`` triples sorted by ``ir_out``:

        - ``path`` is an ``_INPUT`` leaf (single factor) or a ``_TP`` node
          recording ``op=(ir_left, ir, ir_out)`` and the sub-paths.
        - ``tensor`` has shape ``(ir_out.dim, *(irreps.dim for each factor))``.
    """
    irrepss = [Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]

    # Base case: one factor. Each copy of each irrep is selected by the
    # matching rows of the identity, i.e. shape (ir.dim, irreps.dim).
    if len(irrepss) == 1:
        (irreps,) = irrepss
        ret = []
        e = torch.eye(irreps.dim, dtype=dtype)
        i = 0
        for mul, ir in irreps:
            for _ in range(mul):
                sl = slice(i, i + ir.dim)
                ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
                i += ir.dim
        return ret

    # Recursive case: couple every result for the left factors with every
    # irrep (and every copy) of the rightmost factor.
    *irrepss_left, irreps_right = irrepss

    ret = []
    for ir_left, path_left, C_left in _wigner_nj(
        irrepss_left,
        normalization=normalization,
        filter_ir_mid=filter_ir_mid,
        dtype=dtype,
    ):
        # i is the column offset of the current (mul, ir) entry inside
        # irreps_right.
        i = 0
        for mul, ir in irreps_right:
            for ir_out in ir_left * ir:
                if filter_ir_mid is not None and ir_out not in filter_ir_mid:
                    continue

                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype)
                if normalization == "component":
                    C *= ir_out.dim**0.5
                if normalization == "norm":
                    C *= ir_left.dim**0.5 * ir.dim**0.5

                # Contract the ir_left index of the 3j symbol with the left
                # coupling tensor (flattened over all left-factor axes), then
                # restore one axis per left factor.
                C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
                C = C.reshape(
                    ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
                )
                # One separate path per copy u of ir: embed C in the columns
                # of irreps_right that belong to that copy.
                for u in range(mul):
                    E = torch.zeros(
                        ir_out.dim,
                        *(irreps.dim for irreps in irrepss_left),
                        irreps_right.dim,
                        dtype=dtype,
                    )
                    sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
                    E[..., sl] = C
                    ret += [
                        (
                            ir_out,
                            _TP(
                                op=(ir_left, ir, ir_out),
                                args=(
                                    path_left,
                                    _INPUT(len(irrepss_left), sl.start, sl.stop),
                                ),
                            ),
                            E,
                        )
                    ]
            i += mul * ir.dim
    # Sort by output irrep so that paths with equal ir_out are contiguous.
    return sorted(ret, key=lambda x: x[0])
def U_matrix_real(
    irreps_in: Union[str, Irreps],
    irreps_out: Union[str, Irreps],
    correlation: int,
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Stack the coupling tensors of ``correlation`` copies of ``irreps_in`` per output irrep.

    ``_wigner_nj`` is evaluated on ``[Irreps(irreps_in)] * correlation``. Its
    paths are grouped by output irrep, and every group whose irrep belongs to
    ``irreps_out`` is stacked along a new last axis, one slice per path. Each
    path tensor is first passed through ``.squeeze()``, so every size-1 axis
    is dropped before stacking.

    Args:
        irreps_in: irreps of one factor of the product.
        irreps_out: irreps to keep (anything accepted by ``Irreps``).
        correlation: number of factors in the product.
        normalization: forwarded to ``_wigner_nj``.
        filter_ir_mid: forwarded to ``_wigner_nj``. Note: when
            ``correlation == 4`` it is always replaced by the irreps
            ``(l, (-1)**l)`` for ``l`` in ``0..11``, even if the caller
            passed a value.
        dtype: forwarded to ``_wigner_nj``.

    Returns:
        A flat list ``[ir_1, U_1, ir_2, U_2, ...]`` with one pair per matching
        output irrep, in the sorted order produced by ``_wigner_nj``.

    Raises:
        ValueError: if no coupling path produces an irrep contained in
            ``irreps_out``. Previously this surfaced as an IndexError or
            UnboundLocalError.
    """
    from itertools import groupby

    irreps_out = Irreps(irreps_out)
    irrepss = [Irreps(irreps_in)] * correlation
    if correlation == 4:
        # Hard-coded restriction of intermediate couplings, kept for
        # compatibility: l = 0..11 with parity (-1)**l.
        filter_ir_mid = [(degree, (-1) ** degree) for degree in range(12)]
    wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype)

    out = []
    # _wigner_nj returns paths sorted by output irrep, so equal irreps form
    # contiguous runs; each run in irreps_out becomes one stacked tensor.
    for ir, group in groupby(wigners, key=lambda path: path[0]):
        if ir not in irreps_out:
            continue
        columns = [base_o3.squeeze().unsqueeze(-1) for _, _, base_o3 in group]
        # Concatenating only the path tensors keeps every output in the
        # dtype of the paths, instead of promoting it against a
        # default-dtype seed tensor.
        out += [ir, torch.cat(columns, dim=-1)]

    if not out:
        raise ValueError(
            f"No coupling path of {correlation} copies of {irreps_in} "
            f"produces an irrep in {irreps_out}"
        )
    return out