from typing import Optional, Union, Literal
from numpy import ndarray
from torch import Tensor
from torch.nn import Module, init, Sequential, Parameter, Dropout
import torch.nn.functional as F
import torch
from ..activation import get_activation_fn, ACTIVATION_PARAM_TYPE, ACTIVATION_KEY_TYPE
from ..init import semi_orthogonal_glorot_weights
# Named weight-initialisation schemes recognised by DenseLayer.
_WEIGHT_INIT_KEYS = Literal["semi_orthogonal_glorot", "orthogonal", "zero", "xavier_uniform"]
# Named bias-initialisation schemes recognised by DenseLayer.
_BIAS_INIT_KEYS = Literal["zero"]

# An initial value is either an explicit array/tensor or one of the named schemes.
INITIAL_WEIGHT_TYPE = Union[Tensor, ndarray, _WEIGHT_INIT_KEYS]
INITIAL_BIAS_TYPE = Union[Tensor, ndarray, _BIAS_INIT_KEYS]
class NeuronLayer(Module):
    """Common base for the layers in this module.

    Records the input/output feature sizes and, when ``activation_fn`` is
    given, builds an activation via ``get_activation_fn``. This class defines
    no ``forward`` of its own.
    """

    def __str__(self) -> str:
        return f"[ {self.dim_feature_in} -> {self.dim_feature_out} ]"

    def __init__(
        self,
        dim_feature_in: int,
        dim_feature_out: int,
        activation_fn: Optional[ACTIVATION_KEY_TYPE] = None,
        activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
    ) -> None:
        """
        :param dim_feature_in: number of input features.
        :param dim_feature_out: number of output features.
        :param activation_fn: key passed to ``get_activation_fn``; ``None``
            means no activation (``self.activation_fn`` is set to ``None``).
        :param activation_params: extra parameters for ``get_activation_fn``;
            ``None`` is treated as an empty dict. (Default is ``None`` rather
            than ``dict()`` so instances never share one mutable default.)
        """
        super().__init__()
        self.dim_feature_in = dim_feature_in
        self.dim_feature_out = dim_feature_out
        if activation_fn is None:
            self.activation_fn = None
        else:
            params = {} if activation_params is None else activation_params
            self.activation_fn = get_activation_fn(activation_fn, params)
class DenseLayer(NeuronLayer):
    """Fully connected layer ``y = act(F.linear(x, W, b))`` with optional shallow ensembling.

    The weight has ``dim_feature_out * shallow_ensemble_size`` output rows.
    When ``shallow_ensemble_size > 1``, ``forward`` reshapes the result to
    ``(-1, dim_feature_out, shallow_ensemble_size)``.
    """

    def __str__(self) -> str:
        return "Dense layer: " + super().__str__()

    def __init__(
        self,
        dim_feature_in: int,
        dim_feature_out: int,
        activation_fn: Optional[ACTIVATION_KEY_TYPE] = None,
        activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
        initial_weight: INITIAL_WEIGHT_TYPE = "orthogonal",
        initial_bias: INITIAL_BIAS_TYPE = "zero",
        use_bias: bool = True,
        shallow_ensemble_size: int = 1,
    ) -> None:
        """
        :param dim_feature_in: size of the last dimension of the input.
        :param dim_feature_out: output units per ensemble member.
        :param activation_fn: activation key forwarded to ``NeuronLayer``.
        :param activation_params: activation parameters; ``None`` means ``{}``.
        :param initial_weight: ``"semi_orthogonal_glorot"``, ``"orthogonal"``,
            ``"zero"``, ``"xavier_uniform"``, or an explicit tensor/array that
            is used as the weight. For ``F.linear`` in ``forward`` an explicit
            weight should have shape
            ``(dim_feature_out * shallow_ensemble_size, dim_feature_in)``.
            With ``shallow_ensemble_size > 1``, every named key except
            ``"semi_orthogonal_glorot"`` is replaced by an orthogonal init.
        :param initial_bias: ``"zero"`` or an explicit tensor/array.
        :param use_bias: if False, ``self.bias`` is ``None``.
        :param shallow_ensemble_size: number of ensemble members (>= 1).
        :raises ValueError: on an unknown initialisation key or
            ``shallow_ensemble_size < 1``.
        """
        if shallow_ensemble_size < 1:
            raise ValueError(
                f"shallow_ensemble_size must be >= 1, got {shallow_ensemble_size}"
            )
        if activation_params is None:
            activation_params = {}
        super().__init__(dim_feature_in, dim_feature_out, activation_fn, activation_params)
        # Assigned after Module.__init__, which is the safe order for nn.Module attributes.
        self.shallow_ensemble_size = shallow_ensemble_size
        dim_out_total = dim_feature_out * shallow_ensemble_size

        # Only compare against option names when the value is a string:
        # ``ndarray == "zero"`` is evaluated element-wise by NumPy and the
        # resulting bool array cannot be used as an ``if`` condition.
        if isinstance(initial_weight, str):
            if initial_weight == "semi_orthogonal_glorot":
                weight = semi_orthogonal_glorot_weights(dim_feature_in, dim_out_total)
            else:
                weight = torch.empty(dim_out_total, dim_feature_in)
                # NOTE(review): ensembles are forced to an orthogonal init,
                # presumably so members do not start identical (e.g. with the
                # "zero" output-layer default used by ResidualMLP) — confirm.
                if initial_weight == "orthogonal" or shallow_ensemble_size > 1:
                    init.orthogonal_(weight)
                elif initial_weight == "zero":
                    init.zeros_(weight)
                elif initial_weight == "xavier_uniform":
                    init.xavier_uniform_(weight)
                else:
                    raise ValueError(
                        f"Unknown initial_weight {initial_weight!r}; expected "
                        "'semi_orthogonal_glorot', 'orthogonal', 'zero', "
                        "'xavier_uniform' or an explicit tensor/array."
                    )
        elif isinstance(initial_weight, Tensor):
            # Explicit weights are honoured even for ensembles.
            weight = initial_weight
        else:
            weight = torch.tensor(initial_weight)
        self.weight = Parameter(weight)

        if not use_bias:
            self.bias = None
        elif isinstance(initial_bias, str):
            if initial_bias != "zero":
                raise ValueError(
                    f"Unknown initial_bias {initial_bias!r}; expected 'zero' "
                    "or an explicit tensor/array."
                )
            self.bias = Parameter(torch.zeros(dim_out_total))
        elif isinstance(initial_bias, Tensor):
            self.bias = Parameter(initial_bias)
        else:
            self.bias = Parameter(torch.tensor(initial_bias))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear map and activation, then split ensemble members."""
        y = F.linear(x, self.weight, self.bias)
        if self.activation_fn is not None:
            y = self.activation_fn(y)
        if self.shallow_ensemble_size > 1:
            # All leading dims are flattened into one; output unit r maps to
            # (feature r // ensemble_size, member r % ensemble_size).
            return y.view(-1, self.dim_feature_out, self.shallow_ensemble_size)
        return y

    def l2loss(self) -> Tensor:
        """Return ``sum(W ** 2) / 2`` divided by the ensemble size (bias excluded)."""
        return self.weight.pow(2).sum() / 2 / self.shallow_ensemble_size
class ResidualLayer(NeuronLayer):
    """Residual block computing ``x + dense2(dense1(dropout(act(x))))``.

    The activation step is omitted when ``activation_fn`` is ``None``.
    ``dense1`` uses the same activation; ``dense2`` has none. With
    ``use_residual=False`` the skip connection is dropped and only the
    residual branch is returned.
    """

    def __str__(self) -> str:
        return "Residual layer: " + super().__str__()

    def __init__(
        self,
        dim_feature_in: int,
        dim_feature_out: int,
        activation_fn: Optional[ACTIVATION_KEY_TYPE] = None,
        activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
        initial_weight1: INITIAL_WEIGHT_TYPE = "orthogonal",
        initial_weight2: INITIAL_WEIGHT_TYPE = "zero",
        initial_bias: INITIAL_BIAS_TYPE = "zero",
        dropout_rate: float = 0,
        use_bias: bool = True,
        use_residual: bool = True,
    ) -> None:
        """
        :param dim_feature_in: input feature size.
        :param dim_feature_out: output feature size of both dense layers.
        :param activation_fn: activation key (also used inside ``dense1``).
        :param activation_params: activation parameters; ``None`` means ``{}``.
        :param initial_weight1: weight init of ``dense1``.
        :param initial_weight2: weight init of ``dense2``.
        :param initial_bias: bias init of both dense layers.
        :param dropout_rate: dropout probability applied before ``dense1``.
        :param use_bias: whether the dense layers have a bias.
        :param use_residual: add the input back onto the branch output.
        :raises ValueError: if ``use_residual`` and the feature sizes differ.
            Such a skip connection either fails in ``forward`` or silently
            broadcasts, e.g. when ``dim_feature_in == 1``.
        """
        if use_residual and dim_feature_in != dim_feature_out:
            raise ValueError(
                "use_residual=True requires dim_feature_in == dim_feature_out, "
                f"got {dim_feature_in} and {dim_feature_out}"
            )
        if activation_params is None:
            activation_params = {}
        super().__init__(dim_feature_in, dim_feature_out, activation_fn, activation_params)
        self.use_residual = use_residual
        dropout = Dropout(dropout_rate)
        dense1 = DenseLayer(
            dim_feature_in=dim_feature_in,
            dim_feature_out=dim_feature_out,
            activation_fn=activation_fn,
            activation_params=activation_params,
            initial_weight=initial_weight1,
            initial_bias=initial_bias,
            use_bias=use_bias,
        )
        dense2 = DenseLayer(
            dim_feature_in=dim_feature_out,
            dim_feature_out=dim_feature_out,
            activation_fn=None,
            initial_weight=initial_weight2,
            initial_bias=initial_bias,
            use_bias=use_bias,
        )
        if activation_fn is not None:
            # The same activation module is registered both as
            # ``self.activation_fn`` and as the first stage of the branch.
            self.residual = Sequential(self.activation_fn, dropout, dense1, dense2)
        else:
            self.residual = Sequential(dropout, dense1, dense2)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x + residual(x)``, or just ``residual(x)`` without skip."""
        out = self.residual(x)
        return x + out if self.use_residual else out
class ResidualStack(NeuronLayer):
    """``num_residual`` identical-width ``ResidualLayer``s applied in sequence.

    With ``num_residual == 0`` the stack is an empty ``Sequential`` and
    returns its input unchanged.
    """

    def __str__(self) -> str:
        return f"Residual stack ({self.num_residual} layers): " + super().__str__()

    def __init__(
        self,
        dim_feature: int,
        num_residual: int,
        activation_fn: Optional[ACTIVATION_KEY_TYPE] = None,
        activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
        initial_weight1: INITIAL_WEIGHT_TYPE = "orthogonal",
        initial_weight2: INITIAL_WEIGHT_TYPE = "zero",
        initial_bias: INITIAL_BIAS_TYPE = "zero",
        dropout_rate: float = 0,
        use_bias: bool = True,
        use_residual: bool = True,
    ) -> None:
        """
        :param dim_feature: input and output width of every layer.
        :param num_residual: number of residual layers (>= 0).
        :param activation_params: activation parameters; ``None`` means ``{}``.
        :raises ValueError: if ``num_residual`` is negative.

        The remaining parameters are forwarded unchanged to each ``ResidualLayer``.
        The stack itself holds no activation.
        """
        if num_residual < 0:
            raise ValueError(f"num_residual must be >= 0, got {num_residual}")
        if activation_params is None:
            activation_params = {}
        super().__init__(dim_feature, dim_feature)
        self.num_residual = num_residual
        # Keyword arguments keep this call correct if ResidualLayer's
        # parameter order ever changes.
        self.stack = Sequential(*(
            ResidualLayer(
                dim_feature_in=dim_feature,
                dim_feature_out=dim_feature,
                activation_fn=activation_fn,
                activation_params=activation_params,
                initial_weight1=initial_weight1,
                initial_weight2=initial_weight2,
                initial_bias=initial_bias,
                dropout_rate=dropout_rate,
                use_bias=use_bias,
                use_residual=use_residual,
            )
            for _ in range(num_residual)
        ))

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every residual layer in order."""
        return self.stack(x)
class ResidualMLP(NeuronLayer):
    """Residual MLP computing ``output(act(stack(x)))``.

    ``stack`` is a ``ResidualStack`` of width ``dim_feature_in``. ``act`` is
    the optional activation and is skipped when ``activation_fn`` is ``None``.
    ``output`` is a ``DenseLayer`` projecting to ``dim_feature_out``, with
    optional shallow ensembling.
    """

    def __str__(self) -> str:
        return f"Residual MLP ({self.num_residual} residual layers): " + super().__str__()

    def __init__(
        self,
        dim_feature_in: int,
        dim_feature_out: int,
        num_residual: int,
        activation_fn: Optional[ACTIVATION_KEY_TYPE] = None,
        activation_params: Optional[ACTIVATION_PARAM_TYPE] = None,
        initial_weight1: INITIAL_WEIGHT_TYPE = "orthogonal",
        initial_weight2: INITIAL_WEIGHT_TYPE = "zero",
        initial_weight_out: INITIAL_WEIGHT_TYPE = "zero",
        initial_bias_residual: INITIAL_BIAS_TYPE = "zero",
        initial_bias_out: INITIAL_BIAS_TYPE = "zero",
        dropout_rate: float = 0,
        use_bias_residual: bool = True,
        use_bias_out: bool = True,
        shallow_ensemble_size: int = 1,
        use_residual: bool = True,
    ) -> None:
        """
        :param dim_feature_in: input width, also the width of the residual stack.
        :param dim_feature_out: output width of the final dense layer.
        :param num_residual: number of residual layers in the stack.
        :param activation_params: activation parameters; ``None`` means ``{}``.
        :param initial_weight1: ``dense1`` weight init in each residual layer.
        :param initial_weight2: ``dense2`` weight init in each residual layer.
        :param initial_weight_out: output-layer weight init.
        :param initial_bias_residual: bias init in the residual layers.
        :param initial_bias_out: output-layer bias init.
        :param use_bias_residual: whether the residual layers have biases.
        :param use_bias_out: whether the output layer has a bias.
        :param shallow_ensemble_size: ensemble size of the output layer.
        :param use_residual: whether residual layers add their skip connection.
        """
        if activation_params is None:
            activation_params = {}
        super().__init__(dim_feature_in, dim_feature_out, activation_fn, activation_params)
        self.num_residual = num_residual
        self.stack = ResidualStack(
            dim_feature=dim_feature_in,
            num_residual=num_residual,
            activation_fn=activation_fn,
            activation_params=activation_params,
            initial_weight1=initial_weight1,
            initial_weight2=initial_weight2,
            initial_bias=initial_bias_residual,
            dropout_rate=dropout_rate,
            use_bias=use_bias_residual,
            use_residual=use_residual,
        )
        self.output = DenseLayer(
            dim_feature_in=dim_feature_in,
            dim_feature_out=dim_feature_out,
            initial_weight=initial_weight_out,
            initial_bias=initial_bias_out,
            use_bias=use_bias_out,
            shallow_ensemble_size=shallow_ensemble_size,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual stack, the optional activation, then the output layer."""
        hidden = self.stack(x)
        if self.activation_fn is not None:
            hidden = self.activation_fn(hidden)
        return self.output(hidden)