Source code for enerzyme.models.activation

import math
from abc import ABC, abstractmethod
from typing import Literal, Dict, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module, Parameter


class BaseScaledTemperedActivation(ABC, Module):
    """
    Abstract base for activations with a feature-wise scale ``alpha`` and
    temperature ``beta``.

    Two evaluation modes are selected once, at construction time:

    * *simple* mode (``alpha == beta == 1`` and not learnable): ``alpha`` and
      ``beta`` are stored as plain Python floats and ``forward`` dispatches to
      :meth:`simple_activation_fn`.
    * *parametrized* mode (anything else): ``alpha`` and ``beta`` are
      registered as ``Parameter`` tensors of shape ``(dim_feature,)`` and
      ``forward`` dispatches to :meth:`activation_fn`.

    Subclasses must implement both :meth:`simple_activation_fn` and
    :meth:`activation_fn`.
    """

    def __init__(
        self,
        dim_feature: int = 1,
        initial_alpha: float = 1.0,
        initial_beta: float = 1.0,
        learnable: bool = False,
    ) -> None:
        """
        Arguments:
            dim_feature (int): Length of the ``alpha``/``beta`` parameter
                vectors (only used in parametrized mode).
            initial_alpha (float): Initial value of the scale ``alpha``.
            initial_beta (float): Initial value of the temperature ``beta``.
            learnable (bool): Whether ``alpha`` and ``beta`` require gradients.
        """
        super().__init__()
        # Convert once so that integer inputs (e.g. ``initial_alpha=2``) do not
        # make ``torch.full`` infer an integer dtype: an integer tensor cannot
        # require gradients and would end up as int64 in the state dict.
        initial_alpha = float(initial_alpha)
        initial_beta = float(initial_beta)

        if initial_alpha == 1.0 and initial_beta == 1.0 and not learnable:
            # Fixed unit scale/temperature: skip the parameters entirely.
            self.simple = True
            self.alpha = 1.0
            self.beta = 1.0
            self._activation_fn = self.simple_activation_fn
        else:
            self.simple = False
            self.register_parameter(
                "alpha",
                Parameter(
                    torch.full((dim_feature,), initial_alpha),
                    requires_grad=learnable,
                ),
            )
            self.register_parameter(
                "beta",
                Parameter(
                    torch.full((dim_feature,), initial_beta),
                    requires_grad=learnable,
                ),
            )
            self._activation_fn = self.activation_fn

    @abstractmethod
    def simple_activation_fn(self, x: Tensor) -> Tensor:
        """Evaluate the activation for the fixed case ``alpha == beta == 1``."""
        ...

    @abstractmethod
    def activation_fn(self, x: Tensor) -> Tensor:
        """Evaluate the activation using the ``alpha``/``beta`` parameters."""
        ...

    def forward(self, x: Tensor) -> Tensor:
        """Apply whichever activation variant was selected in ``__init__``."""
        return self._activation_fn(x)
class ShiftedSoftplus(BaseScaledTemperedActivation):
    """
    Shifted softplus activation function with learnable feature-wise
    parameters:

        f(x) = alpha/beta * (softplus(beta*x) - log(2))
        softplus(x) = log(exp(x) + 1)

    For beta -> 0  : f(x) -> 0.5*alpha*x
    For beta -> inf: f(x) -> max(0, alpha*x)
    """

    def __init__(
        self,
        dim_feature: int = 1,
        initial_alpha: float = 1.0,
        initial_beta: float = 1.0,
        learnable: bool = False,
    ) -> None:
        """
        Arguments:
            dim_feature (int): Dimensions of feature space.
            initial_alpha (float): Initial "scale" alpha of the softplus
                function.
            initial_beta (float): Initial "temperature" beta of the softplus
                function.
            learnable (bool): Whether alpha and beta require gradients.
        """
        super().__init__(dim_feature, initial_alpha, initial_beta, learnable)
        self.log2 = math.log(2.0)

    def simple_activation_fn(self, x: Tensor) -> Tensor:
        """Evaluate ``softplus(x) - log(2)`` (the alpha = beta = 1 case)."""
        return F.softplus(x) - self.log2

    def activation_fn(self, x: Tensor) -> Tensor:
        """
        Evaluate ``alpha/beta * (softplus(beta*x) - log(2))``, falling back to
        the limit ``0.5*alpha*x`` for features where ``beta == 0``.
        """
        nonzero = self.beta != 0
        # ``torch.where`` evaluates both branches. Dividing by the raw beta
        # would compute 0/0 for beta == 0, and although that value is masked
        # out in the forward pass, its backward pass yields NaN gradients for
        # both x and beta. Substituting 1 for the masked entries keeps the
        # discarded branch finite; entries with beta != 0 are unaffected.
        # NOTE: features with beta == 0 consequently receive a zero gradient
        # with respect to beta.
        safe_beta = torch.where(nonzero, self.beta, torch.ones_like(self.beta))
        tempered = (F.softplus(safe_beta * x) - self.log2) / safe_beta
        return self.alpha * torch.where(nonzero, tempered, 0.5 * x)
class Swish(BaseScaledTemperedActivation):
    """
    Swish activation function with learnable feature-wise parameters:

        f(x) = alpha*x * sigmoid(beta*x)
        sigmoid(x) = 1/(1 + exp(-x))

    For beta -> 0  : f(x) -> 0.5*alpha*x
    For beta -> inf: f(x) -> max(0, alpha*x)

    Arguments:
        dim_feature (int): Dimensions of feature space.
        initial_alpha (float): Initial "scale" alpha of the "linear
            component".
        initial_beta (float): Initial "temperature" of the "sigmoid
            component". The default value of 1.702 has the effect of
            initializing swish to an approximation of the Gaussian Error
            Linear Unit (GELU) activation function from Hendrycks, Dan, and
            Gimpel, Kevin. "Gaussian error linear units (GELUs)."
        learnable (bool): Whether alpha and beta require gradients.
    """

    def __init__(
        self,
        dim_feature: int = 1,
        initial_alpha: float = 1.0,
        initial_beta: float = 1.702,
        learnable: bool = True,
    ) -> None:
        """Initialize the Swish activation (see the class docstring)."""
        super().__init__(dim_feature, initial_alpha, initial_beta, learnable)

    def simple_activation_fn(self, x: Tensor) -> Tensor:
        """Evaluate ``x * sigmoid(x)`` (the alpha = beta = 1 case)."""
        return F.silu(x)

    def activation_fn(self, x: Tensor) -> Tensor:
        """
        Evaluate the activation function for the input features x.

        Arguments:
            x (Tensor): Input features.

        Returns:
            Tensor: Activated features, ``alpha*x * sigmoid(beta*x)``.
        """
        scaled = self.alpha * x
        gate = torch.sigmoid(self.beta * x)
        return scaled * gate
# Registry mapping each configuration key to its activation class.
ACTIVATION_REGISTER = {
    "shifted_softplus": ShiftedSoftplus,
    "swish": Swish,
}

# Keys accepted when looking up an activation in ACTIVATION_REGISTER.
ACTIVATION_KEY_TYPE = Literal["shifted_softplus", "swish"]

# Keyword arguments that may be forwarded to an activation constructor.
ACTIVATION_PARAM_TYPE = Dict[
    Literal["dim_feature", "initial_alpha", "initial_beta", "learnable"],
    Union[int, float, bool],
]
def get_activation_fn(
    key: ACTIVATION_KEY_TYPE,
    activation_params: Union[ACTIVATION_PARAM_TYPE, None] = None,
) -> BaseScaledTemperedActivation:
    """
    Build the activation module registered under ``key``.

    Arguments:
        key: Name of the activation in ``ACTIVATION_REGISTER``.
        activation_params: Keyword arguments forwarded to the activation's
            constructor. ``None`` (the default) means "no arguments", i.e. the
            activation's own defaults are used.

    Returns:
        The newly constructed activation module.

    Raises:
        KeyError: If ``key`` is not present in ``ACTIVATION_REGISTER``.
    """
    # ``None`` replaces the former shared ``dict()`` default so that no
    # mutable object is shared between calls.
    if activation_params is None:
        activation_params = {}
    try:
        activation_cls = ACTIVATION_REGISTER[key]
    except KeyError:
        raise KeyError(
            f"Unknown activation {key!r}; "
            f"available: {sorted(ACTIVATION_REGISTER)}"
        ) from None
    return activation_cls(**activation_params)