import math
from abc import abstractmethod
from typing import Literal, Optional
import numpy as np
import torch
from torch import Tensor
from torch.nn import Parameter
import torch.nn.functional as F
from . import BaseFFLayer
from ..cutoff import CUTOFF_REGISTER
from ..functional import softplus_inverse
class BaseRBF(BaseFFLayer):
    """
    Base class for short-range radial basis function (RBF) layers.

    Subclasses implement ``_get_rbf``, which expands distances into radial
    basis values. ``get_rbf`` multiplies that expansion row-wise by cutoff
    values, computing them with ``cutoff_fn`` when they are not supplied.

    Arguments:
        num_rbf (int):
            Number of radial basis functions.
        cutoff_sr (float):
            Short-range cutoff radius, passed to the cutoff function as
            ``cutoff``.
        cutoff_fn (str):
            Key into ``CUTOFF_REGISTER`` selecting the cutoff function
            (``"polynomial"`` or ``"bump"``).
    """

    def __init__(
        self,
        num_rbf: int,
        cutoff_sr: float,
        cutoff_fn: Literal["polynomial", "bump"]
    ) -> None:
        super().__init__(input_fields={"Dij_sr", "cutoff_values_sr"}, output_fields={"rbf"})
        self.num_rbf = num_rbf
        self.cutoff_sr = cutoff_sr
        self.cutoff_fn = CUTOFF_REGISTER[cutoff_fn]

    def get_rbf(self, Dij_sr: Tensor, cutoff_values_sr: Optional[Tensor]=None, **kwargs) -> Tensor:
        """
        Evaluate the cutoff-weighted radial basis functions.

        Arguments:
            Dij_sr (Tensor):
                Distances; flattened to one value per row.
            cutoff_values_sr (Tensor, optional):
                Precomputed cutoff values, one per distance. If None they are
                computed as ``self.cutoff_fn(Dij_sr, cutoff=self.cutoff_sr)``.
            **kwargs:
                Accepted for interface compatibility and ignored.

        Returns:
            ``cutoff_values_sr.view(-1, 1) * self._get_rbf(Dij_sr)``;
            presumably of shape [N, num_rbf] given the subclass implementations.
        """
        if cutoff_values_sr is None:
            cutoff_values_sr = self.cutoff_fn(Dij_sr, cutoff=self.cutoff_sr)
        return cutoff_values_sr.view(-1, 1) * self._get_rbf(Dij_sr)

    @abstractmethod
    def _get_rbf(self, Dij: Tensor) -> Tensor:
        """
        Expand distances ``Dij`` into radial basis values (no cutoff applied).

        Must be implemented by subclasses.
        """
        # ``@abstractmethod`` is only enforced when the class metaclass is
        # ABCMeta, which cannot be guaranteed for BaseFFLayer; fail loudly
        # instead of silently returning None.
        raise NotImplementedError(f"{type(self).__name__} must implement _get_rbf()")
class GaussianRBFLayer(BaseRBF):
    r"""
    Radial basis functions built from Gaussians of the distance:

    .. math:: g_i(x) = \exp(-\mathtt{width}\cdot(x-\mathtt{center}_i)^2)

    :math:`i` runs from :math:`0` to :math:`\mathtt{num\_rbf}-1`. The centers
    are ``num_rbf`` evenly spaced points on :math:`[0, \mathtt{cutoff\_sr}]`
    and every function shares the single width ``num_rbf / cutoff_sr``.

    Arguments:
        num_rbf (int):
            Number of radial basis functions.
        cutoff_sr (float):
            Short-range cutoff radius; also the position of the last center.
        cutoff_fn (str):
            Name of the cutoff function applied by ``get_rbf``
            (``"polynomial"`` or ``"bump"``; default ``"bump"``).
    """

    def __init__(self, num_rbf: int, cutoff_sr: float, cutoff_fn: Literal["polynomial", "bump"]="bump") -> None:
        """Create the float64 ``center`` and ``width`` buffers."""
        super().__init__(num_rbf, cutoff_sr, cutoff_fn)
        centers = torch.linspace(0, cutoff_sr, num_rbf, dtype=torch.float64)
        shared_width = torch.tensor(num_rbf / cutoff_sr, dtype=torch.float64)
        self.register_buffer("center", centers)
        self.register_buffer("width", shared_width)

    def _get_rbf(self, Dij: Tensor) -> Tensor:
        """
        Evaluate every Gaussian at every distance (cutoff not applied here).

        Arguments:
            Dij (Tensor):
                Distances; reshaped to a column of shape [N, 1].

        Returns:
            Tensor of shape [N, num_rbf] holding
            ``exp(-width * (Dij - center_i) ** 2)``.
        """
        offsets = Dij.view(-1, 1) - self.center  # broadcast to [N, num_rbf]
        return torch.exp(-self.width * offsets.pow(2))
class BernsteinRBFLayer(BaseRBF):
    """
    Radial basis functions based on Bernstein polynomials of the scaled
    distance, given by:

        b_{v,n}(x) = (n over v) * (x/cutoff)**v * (1-(x/cutoff))**(n-v)

    (see https://en.wikipedia.org/wiki/Bernstein_polynomial)
    The polynomials are evaluated in log space from the ``logc``, ``n`` and
    ``v`` buffers produced by ``get_berstein_coefficient(num_rbf)``.

    Arguments:
        num_rbf (int):
            Number of radial basis functions.
        cutoff_sr (float):
            Short-range cutoff radius used to scale distances to x = Dij / cutoff_sr.
        cutoff_fn (str):
            Name of the cutoff function applied by ``get_rbf``
            (``"polynomial"`` or ``"bump"``; default ``"bump"``).
    """

    def __init__(self, num_rbf: int, cutoff_sr: float, cutoff_fn: Literal["polynomial", "bump"]="bump") -> None:
        """Precompute the log-space coefficients and store them as float64 buffers."""
        super().__init__(num_rbf, cutoff_sr, cutoff_fn)
        from ..special import get_berstein_coefficient
        v, n, logc = get_berstein_coefficient(self.num_rbf)
        for buffer_name, values in (("logc", logc), ("n", n), ("v", v)):
            self.register_buffer(buffer_name, torch.tensor(values, dtype=torch.float64))

    def _get_rbf(self, r: Tensor) -> Tensor:
        """
        Evaluate the basis at every distance (cutoff not applied here).

        Computes ``exp(logc + n * log(x) + v * log(1 - x))`` with
        ``x = r / cutoff_sr``.

        Arguments:
            r (Tensor):
                Distances; reshaped to a column of shape [N, 1].

        Returns:
            Tensor of shape [N, num_rbf].
        """
        scaled = r.view(-1, 1) / self.cutoff_sr
        # For x >= 1, log(1 - x) would be nan; substitute 0.5 there. The
        # resulting values are presumably zeroed by the cutoff factor applied
        # in get_rbf (assumes the cutoff function vanishes beyond cutoff_sr).
        scaled = torch.where(scaled < 1.0, scaled, torch.full_like(scaled, 0.5))
        log_x = torch.log(scaled)
        log_one_minus_x = torch.log(-torch.expm1(log_x))  # log(1 - x), stable
        return torch.exp(self.logc + self.n * log_x + self.v * log_one_minus_x)
class SincRBFLayer(BaseRBF):
    """
    Radial basis functions based on sinc functions given by:

        g_i(x) = sinc((i+1)*x/cutoff_sr)

    with i = 0 ... num_rbf-1. ``sinc`` is ``torch.sinc`` when it can be
    imported, otherwise the fallback from ``..special``.

    Arguments:
        num_rbf (int):
            Number of radial basis functions.
        cutoff_sr (float):
            Short-range cutoff radius.
        cutoff_fn (str):
            Name of the cutoff function applied by ``get_rbf``
            (``"polynomial"`` or ``"bump"``; default ``"bump"``).
    """

    def __init__(self, num_rbf: int, cutoff_sr: float, cutoff_fn: Literal["polynomial", "bump"]="bump") -> None:
        """Create the float64 ``factor`` buffer (i+1)/cutoff_sr and pick a sinc implementation."""
        super().__init__(num_rbf, cutoff_sr, cutoff_fn)
        frequencies = torch.linspace(1, num_rbf, num_rbf, dtype=torch.float64) / cutoff_sr
        self.register_buffer("factor", frequencies)
        try:
            from torch import sinc as sinc_impl
        except ImportError:
            # older torch builds without torch.sinc
            from ..special import sinc as sinc_impl
        self.sinc = sinc_impl

    def _get_rbf(self, Dij: Tensor) -> Tensor:
        """
        Evaluate every sinc function at every distance (cutoff not applied here).

        Arguments:
            Dij (Tensor):
                Distances; reshaped to a column of shape [N, 1].

        Returns:
            Tensor of shape [N, num_rbf] holding ``sinc(factor_i * Dij)``.
        """
        return self.sinc(Dij.view(-1, 1) * self.factor)
class ExponentialRBF(BaseRBF):
    """
    Base class of radial basis functions with a general exponential form.

    It entails the physical knowledge that bound state wave functions in
    two-body systems decay exponentially [1,2,3]:

        RBF(r; alpha) = cutoff_fn(r) * exp(inner_fn(r; alpha)) * (exp(-alpha*r) if exp_weighting)

    Subclasses implement ``inner_fn``; the cutoff factor is applied by the
    inherited ``get_rbf``.

    References:
    ----------
    [1] Commun. Math. Phys. 1973, 32, 319−340.
    [2] J. Chem. Theory Comput. 2019, 15, 3678−3693.
    [3] Nat. Chem. 2020, 12, 891–897.
    """

    def __init__(
        self,
        num_rbf: int,
        no_basis_at_infinity: bool=False,
        init_alpha: float=0.9448630629184640,
        exp_weighting: bool=False,
        learnable_shape: bool=True,
        cutoff_sr: float=float("inf"),
        cutoff_fn: Literal["polynomial", "bump"]="bump"
    ) -> None:
        '''
        Params:
        ----------
        num_rbf: Number of radial basis functions.
        no_basis_at_infinity: If True, subclasses put no basis function at
            exp(-alpha*x) = 0, i.e. x = infinity. Stored for subclasses.
        init_alpha: Initial value for the scaling parameter alpha (default value
            corresponds to 0.5 1/Bohr converted to 1/Angstrom).
        exp_weighting: If `True`, basis functions are weighted with a factor exp(-alpha*r).
        learnable_shape: If `True`, subclass shape parameters are learnable.
            Stored for subclasses; alpha itself is always a trainable Parameter.
        cutoff_sr: Short range cutoff threshold for radial basis functions.
            NOTE(review): the default is inf, which is passed to the cutoff
            function as `cutoff=inf` by get_rbf — confirm the registered
            cutoff functions accept it.
        cutoff_fn: Short range cutoff function, called as `cutoff_fn(x, cutoff=cutoff_sr)`
            where x is the distance.
        '''
        super().__init__(num_rbf, cutoff_sr, cutoff_fn)
        self.exp_weighting = exp_weighting
        self.learnable_shape = learnable_shape
        self.no_basis_at_infinity = no_basis_at_infinity
        # alpha is stored in softplus-inverse space; _get_rbf applies softplus,
        # so the effective alpha is always positive.
        self.register_parameter(
            "alpha", Parameter(softplus_inverse(torch.tensor(init_alpha, dtype=torch.float64)))
        )

    @abstractmethod
    def inner_fn(self, alphar: Optional[Tensor]=None, expalphar: Optional[Tensor]=None) -> Tensor:
        """
        Return the exponent of the basis functions (exp is applied by the caller).

        Params:
        ----------
        alphar: -softplus(alpha) * r, shape [M, 1].
        expalphar: exp(alphar), shape [M, 1].
        """
        # ``@abstractmethod`` is not enforced without ABCMeta; fail loudly
        # instead of returning None (which would make torch.exp raise later).
        raise NotImplementedError(f"{type(self).__name__} must implement inner_fn()")

    def _get_rbf(self, Dij: Tensor) -> Tensor:
        '''
        Evaluate the RBF values without the cutoff factor.

        Params:
        ----------
        Dij: Float tensor of distances; reshaped to [M, 1].

        Returns:
        ----------
        rbf: exp(inner_fn(alphar, expalphar)), additionally multiplied by
             exp(-alpha*r) when `exp_weighting` is set; presumably of shape
             [M, num_rbf] given the subclass implementations.
        '''
        alphar = -F.softplus(self.alpha) * Dij.view(-1, 1)
        expalphar = torch.exp(alphar)
        rbf = torch.exp(self.inner_fn(alphar, expalphar))
        if self.exp_weighting:
            rbf = rbf * expalphar
        return rbf
class ExponentialGaussianRBFLayer(ExponentialRBF):
    r"""
    Radial basis functions based on exponential Gaussian functions given by:

        g_i(x) = exp(-width_i*(exp(-alpha*x)-center_i)**2)

    Centers and widths are stored in softplus-inverse space and mapped back
    with softplus in ``inner_fn``.
    """

    def __init__(
        self,
        num_rbf: int,
        no_basis_at_infinity: bool=False,
        init_alpha: float=0.9448630629184640,
        exp_weighting: bool=False,
        learnable_shape: bool=True,
        cutoff_sr: float=float("inf"),
        cutoff_fn: Literal["polynomial", "bump"]="polynomial",
        init_width_flavor: Literal["PhysNet", "SpookyNet"]="PhysNet"
    ) -> None:
        r'''
        Params:
        ----------
        num_rbf: Number of radial basis functions (K below).
        no_basis_at_infinity: If True and `cutoff_sr` is infinite, no center is
            placed at exp(-alpha*x) = 0, i.e. x = infinity. Also adds 1 to the
            `SpookyNet` width.
        init_alpha: Initial value for scaling parameter alpha (default value
            corresponds to 0.5 1/Bohr converted to 1/Angstrom).
        exp_weighting: If `True`, basis functions are weighted with a factor exp(-alpha*r).
        learnable_shape: If `True`, centers and widths of exponentials are learnable.
        cutoff_sr: Short range cutoff threshold for radial basis functions.
        cutoff_fn: Short range cutoff function, called as `cutoff_fn(x, cutoff=cutoff_sr)`
            where x is the distance.
        init_width_flavor: Initialization flavor for the widths. Options:
            - `PhysNet`: the constant (2K^{-1}(1-\exp(-`cutoff_sr`)))^{-2}, repeated
              for each of the K functions [1].
            - `SpookyNet`: a single constant K, or K+1 when `no_basis_at_infinity=True`.

        Raises:
        ----------
        ValueError: if `init_width_flavor` is not one of the options above.

        References:
        ----------
        [1] J. Chem. Theory Comput. 2019, 15, 3678−3693.
        '''
        super().__init__(
            num_rbf=num_rbf,
            no_basis_at_infinity=no_basis_at_infinity,
            init_alpha=init_alpha,
            exp_weighting=exp_weighting,
            learnable_shape=learnable_shape,
            cutoff_sr=cutoff_sr,
            cutoff_fn=cutoff_fn
        )
        if cutoff_sr == float("inf") and no_basis_at_infinity:
            # K+1 evenly spaced centers on [1, 0]; drop the one at 0 (r = infinity)
            centers = torch.linspace(1, 0, num_rbf + 1, dtype=torch.float64)[:-1]
        else:
            # NOTE(review): for an infinite cutoff this includes a center at
            # exp(-inf) = 0; softplus_inverse(0) is presumably -inf — confirm
            # the helper's behavior at 0.
            centers = torch.linspace(1, math.exp(-cutoff_sr), num_rbf, dtype=torch.float64)
        self.register_parameter(
            "centers",
            Parameter(softplus_inverse(centers), requires_grad=self.learnable_shape),
        )
        self._init_width(init_width_flavor)

    def _init_width(self, init_width_flavor: Literal["PhysNet", "SpookyNet"]="PhysNet") -> None:
        '''
        Register the `widths` parameter according to the chosen flavor.

        Raises:
        ----------
        ValueError: if `init_width_flavor` is neither "PhysNet" nor "SpookyNet".
        '''
        if init_width_flavor == "SpookyNet":
            widths = torch.tensor(
                1.0 * self.num_rbf + int(self.no_basis_at_infinity), dtype=torch.float64
            )
        elif init_width_flavor == "PhysNet":
            widths = torch.tensor(
                [(0.5 / ((-math.expm1(-self.cutoff_sr)) / self.num_rbf)) ** 2] * self.num_rbf,
                dtype=torch.float64,
            )
        else:
            # Previously an unknown flavor silently left `widths` unregistered,
            # only surfacing as an AttributeError at the first forward pass.
            raise ValueError(
                f"Unknown init_width_flavor {init_width_flavor!r}; "
                "expected 'PhysNet' or 'SpookyNet'."
            )
        self.register_parameter(
            "widths",
            Parameter(softplus_inverse(widths), requires_grad=self.learnable_shape),
        )

    def inner_fn(self, alphar: Tensor, expalphar: Tensor) -> torch.Tensor:
        """Return -softplus(widths) * (exp(-alpha*r) - softplus(centers))**2."""
        return -F.softplus(self.widths) * (expalphar - F.softplus(self.centers)) ** 2
class ExponentialBernsteinRBFLayer(ExponentialRBF):
    """
    Radial basis functions based on exponential Bernstein polynomials given by:

        b_{v,n}(x) = (n over v) * exp(-alpha*x)**v * (1-exp(-alpha*x))**(n-v)

    (see https://en.wikipedia.org/wiki/Bernstein_polynomial)
    For n to infinity, linear combinations of b_{v,n} can approximate any
    continuous function on the interval [0, 1] uniformly [1].

    NOTE: There is a problem for x = 0, as log(-expm1(0)) will be log(0) = -inf.
    This itself is not an issue, but the buffer v contains an entry 0 and
    0*(-inf)=nan. The correct behaviour could be recovered by replacing the nan
    with 0.0, but should not be necessary because issues are only present when
    r = 0, which will not occur with chemically meaningful inputs.

    References:
    ----------
    [1] Commun. Kharkov Math. Soc. 1912, 13, 1.
    """

    def __init__(
        self,
        num_rbf: int,
        no_basis_at_infinity: bool=False,
        init_alpha: float=0.9448630629184640,
        exp_weighting: bool=False,
        learnable_shape: bool=True,
        cutoff_sr: float=float("inf"),
        cutoff_fn: Literal["polynomial", "bump"]="bump",
    ) -> None:
        '''
        Params:
        ----------
        num_rbf: Number of radial basis functions produced.
        no_basis_at_infinity: If True, coefficients are built for num_rbf+1
            functions and the last one (at infinity) is dropped.
        init_alpha, exp_weighting, learnable_shape, cutoff_sr, cutoff_fn:
            Forwarded to ExponentialRBF.
        '''
        super().__init__(
            num_rbf=num_rbf,
            no_basis_at_infinity=no_basis_at_infinity,
            init_alpha=init_alpha,
            exp_weighting=exp_weighting,
            learnable_shape=learnable_shape,
            cutoff_sr=cutoff_sr,
            cutoff_fn=cutoff_fn
        )
        from ..special import get_berstein_coefficient
        # Use a local count: mutating self.num_rbf here made it report one more
        # basis function than is actually produced after the slice below.
        num_coefficients = self.num_rbf + int(no_basis_at_infinity)
        v, n, logc = get_berstein_coefficient(num_coefficients)
        if no_basis_at_infinity:  # remove last basis function at infinity
            v = v[:-1]
            n = n[:-1]
            logc = logc[:-1]
        # float64, consistent with BernsteinRBFLayer's use of the same helper
        self.register_buffer("logc", torch.tensor(logc, dtype=torch.float64))
        self.register_buffer("n", torch.tensor(n, dtype=torch.float64))
        self.register_buffer("v", torch.tensor(v, dtype=torch.float64))

    def inner_fn(self, alphar: Tensor, expalphar: Tensor) -> Tensor:
        """Return logc + n*alphar + v*log(1 - exp(alphar)) (log of the basis values)."""
        return self.logc + self.n * alphar + self.v * torch.log(-torch.expm1(alphar))
class BesselRBFLayer(BaseRBF):
    """
    Radial basis functions based on spherical-Bessel-like sines:

        g_k(x) = sqrt(2/cutoff_sr) * sin(w_k * x) / x,   w_k = k * pi / cutoff_sr

    with k = 1 ... num_rbf.

    Arguments:
        num_rbf (int):
            Number of radial basis functions.
        cutoff_sr (float):
            Short-range cutoff radius.
        cutoff_fn (str):
            Name of the cutoff function applied by ``get_rbf``
            (``"polynomial"`` or ``"bump"``; default ``"polynomial"``).
        trainable (bool):
            If True, the frequencies ``w_k`` are a trainable Parameter;
            otherwise they are a buffer.
    """

    def __init__(self, num_rbf: int, cutoff_sr: float, cutoff_fn: Literal["polynomial", "bump"]="polynomial", trainable: bool=False) -> None:
        super().__init__(num_rbf, cutoff_sr, cutoff_fn)
        bessel_weights = (
            np.pi
            / cutoff_sr
            * torch.linspace(
                start=1.0,
                end=num_rbf,
                steps=num_rbf,
                dtype=torch.get_default_dtype(),
            )
        )
        if trainable:
            self.bessel_weights = torch.nn.Parameter(bessel_weights)
        else:
            self.register_buffer("bessel_weights", bessel_weights)
        self.register_buffer(
            "prefactor",
            torch.tensor(np.sqrt(2.0 / cutoff_sr), dtype=torch.get_default_dtype()),
        )

    def _get_rbf(self, x: Tensor) -> Tensor:
        """
        Evaluate ``prefactor * sin(w_k * x) / x`` (cutoff not applied here).

        At ``x == 0`` the quotient is replaced by its analytic limit ``w_k``
        instead of 0/0 = nan. The denominator is made safe *before* dividing
        (double-``where``), so no nan reaches the gradient either.

        Arguments:
            x (Tensor):
                Distances; reshaped to a column of shape [N, 1].

        Returns:
            Tensor of shape [N, num_rbf].
        """
        weights = self.bessel_weights.view(1, -1)  # [1, num_rbf]
        dist = x.view(-1, 1)  # [N, 1]
        at_origin = dist == 0
        safe_dist = torch.where(at_origin, torch.ones_like(dist), dist)
        ratio = torch.sin(weights * safe_dist) / safe_dist  # [N, num_rbf]
        # lim_{x->0} sin(w x) / x = w, broadcast to [N, num_rbf]
        limit = weights * torch.ones_like(dist)
        ratio = torch.where(at_origin, limit, ratio)
        return self.prefactor * ratio
class GaussianSmearing(BaseFFLayer):
    """
    Gaussian expansion of distances on evenly spaced offsets (no cutoff applied):

        rbf_k(d) = exp(coeff * (d - offset_k)**2),   coeff = -0.5 / spacing**2

    where ``offset = linspace(cuton, cutoff_sr, num_rbf)`` and ``spacing`` is
    the distance between adjacent offsets.

    Arguments:
        num_rbf (int):
            Number of Gaussians; must be at least 2 so a spacing exists.
        cutoff_sr (float):
            Position of the last offset.
        cuton (float):
            Position of the first offset (default 0.0); must differ from
            ``cutoff_sr``.

    Raises:
        ValueError: if ``num_rbf < 2`` or the offsets coincide (zero spacing).
    """

    def __init__(
        self,
        num_rbf: int,
        cutoff_sr: float,
        cuton: float=0.0,
    ):
        # Fail fast with a clear message; these inputs previously raised a bare
        # IndexError (offset[1]) or ZeroDivisionError.
        if num_rbf < 2:
            raise ValueError(f"GaussianSmearing requires num_rbf >= 2, got {num_rbf}.")
        super().__init__(input_fields={"Dij_sr"}, output_fields={"rbf"})
        offset = torch.linspace(cuton, cutoff_sr, num_rbf)
        spacing = (offset[1] - offset[0]).item()
        if spacing == 0.0:
            raise ValueError(
                f"GaussianSmearing requires cuton != cutoff_sr, got {cuton} and {cutoff_sr}."
            )
        self.coeff = -0.5 / spacing ** 2
        self.register_buffer('offset', offset)

    def get_rbf(self, Dij_sr: Tensor) -> Tensor:
        """
        Expand distances into Gaussians.

        Arguments:
            Dij_sr (Tensor):
                Distances; reshaped to a column of shape [N, 1].

        Returns:
            Tensor of shape [N, num_rbf].
        """
        dist = Dij_sr.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))