from typing import List, Tuple, Union
from collections import namedtuple
import torch
import e3nn.o3 as o3
from e3nn.util.jit import compile_mode
from e3nn.o3 import Irreps
# Leaf of a coupling path: index of the input tensor plus the [start, stop)
# slice of that input it reads.
_INPUT = namedtuple("_INPUT", ["tensor", "start", "stop"])
# Internal node of a coupling path: ``op`` is the (ir_left, ir_right, ir_out)
# triple and ``args`` holds the child paths that were coupled.
_TP = namedtuple("_TP", ["op", "args"])
def linear_out_irreps(irreps: Irreps, target_irreps: Irreps) -> Irreps:
    """Pick, for each irrep of ``irreps``, the matching entry of ``target_irreps``.

    Multiplicities of ``irreps`` are ignored; for every irrep in ``irreps`` the
    first ``(mul, ir)`` of ``target_irreps`` with an equal irrep is taken, so the
    result has the irreps of ``irreps`` (in order) with the target's
    multiplicities. ``irreps`` is assumed to be simplified.

    Raises:
        RuntimeError: if an irrep of ``irreps`` has no match in ``target_irreps``.
    """
    selected = []
    for _, ir_in in irreps:
        for target_mul, target_ir in target_irreps:
            if ir_in == target_ir:
                selected.append((target_mul, target_ir))
                break
        else:
            # The inner loop never hit ``break``: no matching target irrep.
            raise RuntimeError(f"{ir_in} not in {target_irreps}")
    return Irreps(selected)
def tp_out_irreps_with_instructions(
    irreps1: Irreps, irreps2: Irreps, target_irreps: Irreps
) -> Tuple[Irreps, List]:
    """Build ``"uvu"`` tensor-product instructions restricted to ``target_irreps``.

    For each ``(mul, ir_in)`` at index ``i`` of ``irreps1`` and each ``ir_edge`` at
    index ``j`` of ``irreps2``, every irrep in ``ir_in * ir_edge`` that is also in
    ``target_irreps`` adds an output ``(mul, ir_out)`` and a trainable
    instruction ``(i, j, k, "uvu", True)``, where ``k`` is that output's index.
    The outputs are then sorted and the instruction output indices remapped to
    the sorted order.

    Returns:
        ``(irreps_out, instructions)``. ``irreps_out`` is the sorted (not
        simplified) output irreps. ``instructions`` is the instruction list,
        ordered by output index.
    """
    trainable = True

    # Collect possible irreps and their instructions.
    # Each entry is (multiplicity, single Irrep), not an Irreps.
    irreps_out_list: List[Tuple[int, o3.Irrep]] = []
    instructions = []
    for i, (mul, ir_in) in enumerate(irreps1):
        for j, (_, ir_edge) in enumerate(irreps2):
            for ir_out in ir_in * ir_edge:  # | l1 - l2 | <= l <= l1 + l2
                if ir_out in target_irreps:
                    k = len(irreps_out_list)  # index of the output written
                    irreps_out_list.append((mul, ir_out))
                    instructions.append((i, j, k, "uvu", trainable))

    # Sort the output irreps of the tensor product so that they can be
    # simplified when provided to the second o3.Linear.
    irreps_out, permut, _ = Irreps(irreps_out_list).sort()

    # Remap each instruction's output index through ``permut`` so it points at
    # the sorted irreps, then order the instructions by that index.
    instructions = sorted(
        (
            (i_in1, i_in2, permut[i_out], mode, train)
            for i_in1, i_in2, i_out, mode, train in instructions
        ),
        key=lambda x: x[2],
    )
    return irreps_out, instructions
@compile_mode("script")
class reshape_irreps(torch.nn.Module):
    """Reshape a flat irreps feature tensor into per-channel blocks.

    ``forward`` expects a 2-D tensor ``[batch, irreps.dim]``. Each irrep
    segment ``mul x ir`` is sliced out and reshaped to ``[batch, mul, ir.dim]``.
    The segments are then concatenated along the last axis, giving
    ``[batch, mul, sum of ir.dim]``. Because concatenation is along the last
    axis, all irreps must share the same multiplicity, otherwise ``torch.cat``
    fails.
    """

    def __init__(self, irreps: Irreps) -> None:
        super().__init__()
        self.irreps = Irreps(irreps)
        # Per-irrep sizes as plain int lists, consumed by ``forward``.
        self.dims = [ir.dim for _, ir in self.irreps]
        self.muls = [mul for mul, _ in self.irreps]

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Split ``tensor`` by irrep, reshape each piece, and concatenate."""
        batch, _ = tensor.shape
        offset = 0
        pieces = []
        for mul, d in zip(self.muls, self.dims):
            width = mul * d
            # 2-D slice [batch, mul * d] for this irrep -> [batch, mul, d].
            piece = tensor[:, offset : offset + width].reshape(batch, mul, d)
            pieces.append(piece)
            offset += width
        return torch.cat(pieces, dim=-1)
def _wigner_nj(
    irrepss: List[Irreps],
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Recursively couple a list of irreps, left to right, with Wigner 3j symbols.

    All inputs but the last are coupled recursively. Each resulting ``ir_left``
    is then coupled with every irrep ``ir`` of the last input through
    ``o3.wigner_3j(ir_out.l, ir_left.l, ir.l)``.

    Args:
        irrepss: inputs to couple; each entry is passed through ``Irreps``.
        normalization: ``"component"`` scales each 3j tensor by
            ``sqrt(ir_out.dim)``. ``"norm"`` scales it by
            ``sqrt(ir_left.dim * ir.dim)``. Any other value applies no extra
            scaling (it is not validated).
        filter_ir_mid: optional collection of irreps (converted with
            ``o3.Irrep``). At every coupling step, output irreps not in it are
            skipped. The irreps of the first input (the single-input base case)
            are not filtered.
        dtype: dtype passed to every tensor created here.

    Returns:
        List of ``(ir_out, path, C)`` triples sorted by ``ir_out``.
        ``path`` is an ``_INPUT`` leaf for a single input, or a nested ``_TP``
        recording the coupling. ``C`` has shape
        ``(ir_out.dim, *(irreps.dim for irreps in irrepss))``.
    """
    irrepss = [Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]
    if len(irrepss) == 1:
        # Base case: each copy of each irrep is selected by a block of rows of
        # the identity, i.e. C has shape (ir.dim, irreps.dim).
        (irreps,) = irrepss
        ret = []
        e = torch.eye(irreps.dim, dtype=dtype)
        i = 0
        for mul, ir in irreps:
            for _ in range(mul):
                sl = slice(i, i + ir.dim)
                ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
                i += ir.dim
        return ret
    # Couple every input but the last recursively, then couple with the last.
    *irrepss_left, irreps_right = irrepss
    ret = []
    for ir_left, path_left, C_left in _wigner_nj(
        irrepss_left,
        normalization=normalization,
        filter_ir_mid=filter_ir_mid,
        dtype=dtype,
    ):
        i = 0  # offset of the current irrep within irreps_right
        for mul, ir in irreps_right:
            for ir_out in ir_left * ir:
                if filter_ir_mid is not None and ir_out not in filter_ir_mid:
                    continue
                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype)
                if normalization == "component":
                    C *= ir_out.dim**0.5
                if normalization == "norm":
                    C *= ir_left.dim**0.5 * ir.dim**0.5
                # Contract the ir_left axis of the 3j tensor with the leading
                # axis of C_left (its other axes flattened).
                C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
                # Restore one axis per left input: (ir_out.dim, *left dims, ir.dim).
                C = C.reshape(
                    ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
                )
                for u in range(mul):
                    # Embed C in the slice of the last input occupied by
                    # copy ``u`` of ``ir``; everything else stays zero.
                    E = torch.zeros(
                        ir_out.dim,
                        *(irreps.dim for irreps in irrepss_left),
                        irreps_right.dim,
                        dtype=dtype,
                    )
                    sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
                    E[..., sl] = C
                    ret += [
                        (
                            ir_out,
                            _TP(
                                op=(ir_left, ir, ir_out),
                                args=(
                                    path_left,
                                    _INPUT(len(irrepss_left), sl.start, sl.stop),
                                ),
                            ),
                            E,
                        )
                    ]
            # Skip past all ``mul`` copies of ``ir`` in the last input.
            i += mul * ir.dim
    # Sorting by output irrep puts equal irreps next to each other.
    return sorted(ret, key=lambda x: x[0])
def U_matrix_real(
    irreps_in: Union[str, Irreps],
    irreps_out: Union[str, Irreps],
    correlation: int,
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Collect the real coupling tensors of ``correlation`` copies of ``irreps_in``.

    ``correlation`` copies of ``irreps_in`` are coupled with ``_wigner_nj``.
    Every coupling whose output irrep is in ``irreps_out`` is kept. Couplings
    that share the same output irrep are stacked along a new last axis.

    Args:
        irreps_in: irreps of each coupled input.
        irreps_out: output irreps to keep.
        correlation: number of copies of ``irreps_in`` to couple.
        normalization: forwarded to ``_wigner_nj``.
        filter_ir_mid: forwarded to ``_wigner_nj``. NOTE: when
            ``correlation == 4`` it is replaced by a fixed filter of
            ``(l, (-1)**l)`` for ``l = 0..11``, whatever the caller passed.
        dtype: forwarded to ``_wigner_nj``.

    Returns:
        A flat list ``[ir_0, U_0, ir_1, U_1, ...]`` in the irrep order
        returned by ``_wigner_nj``. Each ``U`` stacks, along its last axis, the
        ``squeeze()``-d coupling tensors for that irrep (each of shape
        ``(ir.dim, irreps_in.dim, ..., irreps_in.dim)`` before squeezing). An
        empty list is returned when no coupling reaches ``irreps_out``; the
        previous implementation crashed with ``IndexError`` /
        ``UnboundLocalError`` in that case.
    """
    irreps_out = Irreps(irreps_out)
    irrepss = [Irreps(irreps_in)] * correlation
    if correlation == 4:
        filter_ir_mid = [
            (0, 1),
            (1, -1),
            (2, 1),
            (3, -1),
            (4, 1),
            (5, -1),
            (6, 1),
            (7, -1),
            (8, 1),
            (9, -1),
            (10, 1),
            (11, -1),
        ]
    wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype)

    # ``_wigner_nj`` returns its triples sorted by output irrep, so couplings
    # sharing an irrep are contiguous: gather them into one group per irrep.
    groups: List[Tuple[o3.Irrep, List[torch.Tensor]]] = []
    for ir, _, base_o3 in wigners:
        if ir not in irreps_out:
            continue
        column = base_o3.squeeze()
        if groups and groups[-1][0] == ir:
            groups[-1][1].append(column)
        else:
            groups.append((ir, [column]))

    out = []
    for ir, columns in groups:
        # Equivalent to concatenating ``column.unsqueeze(-1)`` along the last
        # axis, but without seeding from a default-dtype empty tensor, whose
        # type promotion in ``torch.cat`` could change the dtype of the first U.
        out += [ir, torch.stack(columns, dim=-1)]
    return out