Source code for enerzyme.models.cutoff
from typing import Literal, Callable
import torch
from torch import Tensor
[docs]
def scale(cutoff_fn: Callable[[Tensor, Tensor, Tensor], Tensor]) -> Callable[..., Tensor]:
    """
    Decorator that lifts a transition kernel defined on the unit interval to an
    arbitrary interval [cuton, cutoff].

    The decorated function is called as ``f(x, cutoff, cuton=0)``. It first
    rescales the input to ``x_ = (x - cuton) / (cutoff - cuton)``. It then
    evaluates ``cutoff_fn(x_, zeros, ones)``, where ``zeros``/``ones`` come
    from ``torch.zeros_like(x)``/``torch.ones_like(x)``. Kernels can use these
    tensors as ``torch.where`` fill values to mask out-of-range inputs. The
    decorated function returns:
      - 1 where x_ <= 0,
      - 0 where x_ >= 1,
      - the kernel value where 0 < x_ < 1.
    For cutoff > cuton this means f = 1 for x <= cuton and f = 0 for x >= cutoff.
    ``cutoff`` must differ from ``cuton`` because the rescaling divides by
    ``cutoff - cuton``.

    The wrapper copies ``__module__``, ``__name__``, ``__qualname__`` and
    ``__doc__`` from ``cutoff_fn``, so decorated functions keep their name and
    docstring.

    Params:
    -----
    cutoff_fn: Kernel taking (x_, zeros, ones). It is expected to return a tensor
        broadcastable with x_.
    """
    def scaled_transition_fn(x: Tensor, cutoff: float, cuton: float = 0) -> Tensor:
        zeros = torch.zeros_like(x)
        ones = torch.ones_like(x)
        x_ = (x - cuton) / (cutoff - cuton)
        kernel = cutoff_fn(x_, zeros, ones)
        # The kernel is evaluated on every element. Only its values strictly
        # inside (0, 1) are kept; the outer regions are pinned to exactly 1 / 0.
        return torch.where(x_ > 0, torch.where(x_ < 1, kernel, zeros), ones)

    # Copy identity/doc attributes by hand instead of using functools.wraps.
    # wraps also sets ``__wrapped__``, which would make inspect.signature show
    # the kernel's (x_, zeros, ones) signature instead of (x, cutoff, cuton).
    for attr in ("__module__", "__name__", "__qualname__", "__doc__"):
        if hasattr(cutoff_fn, attr):
            setattr(scaled_transition_fn, attr, getattr(cutoff_fn, attr))
    return scaled_transition_fn
[docs]
@scale
def polynomial_transition(x_: Tensor, zeros: Tensor, ones: Tensor) -> Tensor:
    """
    Quintic polynomial cutoff that goes smoothly from f = 1 to f = 0 [1].

    Decorated with ``scale``, so it is called as
    ``polynomial_transition(x, cutoff, cuton=0)``. With
    ``x_ = (x - cuton) / (cutoff - cuton)``:
      - f = 1 for x_ <= 0,
      - f = 1 - 10 x_^3 + 15 x_^4 - 6 x_^5 for 0 < x_ < 1,
      - f = 0 for x_ >= 1.
    The polynomial's first and second derivatives vanish at x_ = 0 and x_ = 1,
    so the transition joins the constant regions smoothly.

    Params:
    -----
    x: Input tensor.
    cutoff: Value of x at and beyond which f = 0.
    cuton: Value of x at and below which f = 1 (default 0).

    References:
    -----
    [1] Texturing & Modeling: A Procedural Approach; Morgan Kaufmann: 2003.
    """
    cube = x_ ** 3
    fourth = cube * x_
    fifth = fourth * x_
    # Same left-to-right evaluation order as 1 - 6*x^5 + 15*x^4 - 10*x^3.
    result = 1 - 6 * fifth
    result = result + 15 * fourth
    return result - 10 * cube
[docs]
@scale
def bump_transition(x_: Tensor, zeros: Tensor, ones: Tensor) -> Tensor:
    """
    Smooth bump-shaped cutoff that goes from f = 1 to f = 0 [1].

    Decorated with ``scale``, so it is called as
    ``bump_transition(x, cutoff, cuton=0)``. With
    ``x_ = (x - cuton) / (cutoff - cuton)``:
      - f = 1 for x_ <= 0,
      - f = exp(-x_^2 / (1 - x_^2)) for 0 < x_ < 1,
      - f = 0 for x_ >= 1.

    Params:
    -----
    x: Input tensor.
    cutoff: Value of x at and beyond which f = 0.
    cuton: Value of x at and below which f = 1 (default 0).

    References:
    -----
    [1] Nat. Commun., 2021, 12, 7273.
    """
    # Entries outside the open interval (0, 1) are replaced by 0, so that
    # 1 - x_^2 stays strictly positive and the expression stays finite.
    # ``scale`` overwrites those positions with 1 / 0 afterwards anyway.
    inside = (0 < x_) & (x_ < 1)
    safe = torch.where(inside, x_, zeros)
    sq = safe ** 2
    return torch.exp(-sq / (1 - sq))
def _smooth_transition(x_: Tensor, ones: Tensor) -> Tensor:
    """
    Return ``exp(-1 / x_)`` where ``x_ > 0`` and ``exp(-1)`` elsewhere.

    Non-positive entries are swapped for 1 before the division. This keeps both
    the result and its gradient finite. In this module the values at
    non-positive inputs are masked out afterwards by ``scale``.
    """
    positive = x_ > 0
    denom = torch.where(positive, x_, ones)
    return torch.exp(-1 / denom)
[docs]
@scale
def smooth_transition(x_: Tensor, zeros: Tensor, ones: Tensor) -> Tensor:
    """
    Infinitely differentiable step cutoff that goes from f = 1 to f = 0.

    Decorated with ``scale``, so it is called as
    ``smooth_transition(x, cutoff, cuton=0)``. With
    ``x_ = (x - cuton) / (cutoff - cuton)`` and g(t) = exp(-1 / t) for t > 0:
      - f = 1 for x_ <= 0,
      - f = g(1 - x_) / (g(x_) + g(1 - x_)) for 0 < x_ < 1,
      - f = 0 for x_ >= 1.

    Params:
    -----
    x: Input tensor.
    cutoff: Value of x at and beyond which f = 0.
    cuton: Value of x at and below which f = 1 (default 0).
    """
    g_right = _smooth_transition(1 - x_, ones)
    g_left = _smooth_transition(x_, ones)
    return g_right / (g_left + g_right)
# Maps each cutoff name to its scaled transition function f(x, cutoff, cuton=0).
CUTOFF_REGISTER = dict(
    polynomial=polynomial_transition,
    bump=bump_transition,
    smooth=smooth_transition,
)
# The accepted CUTOFF_REGISTER keys, spelled out literally so that static
# type checkers can validate them.
CUTOFF_KEY_TYPE = Literal["polynomial", "bump", "smooth"]