Source code for enerzyme.models.init

import numpy as np
import torch
from torch import Tensor

# Generates a random square orthogonal matrix of size dim x dim.
def square_orthogonal_matrix(dim=3):
    """Draw a random ``dim`` x ``dim`` orthogonal matrix with determinant +1.

    The result is accumulated as a product of Householder reflections, each
    built from a Gaussian vector of decreasing length, followed by a
    row-wise sign correction. All random numbers are taken from NumPy's
    global ``np.random`` state.

    Args:
        dim: Number of rows and columns of the returned matrix.

    Returns:
        A NumPy array of shape ``(dim, dim)``.
    """
    rotation = np.eye(dim)
    signs = np.ones((dim,))
    for offset in range(dim - 1):
        block_size = dim - offset
        vec = np.random.normal(size=(block_size,))
        signs[offset] = np.sign(vec[0])
        vec[0] -= signs[offset] * np.sqrt((vec * vec).sum())
        # Householder reflection; it only acts on the trailing block.
        reflection = np.eye(block_size) - 2. * np.outer(vec, vec) / (vec * vec).sum()
        embedded = np.eye(dim)
        embedded[offset:, offset:] = reflection
        rotation = np.dot(rotation, embedded)
    # Choose the last sign so that the overall determinant equals +1.
    signs[-1] = (-1) ** (1 - (dim % 2)) * signs.prod()
    # Scale the rows by their signs; same result as
    # np.dot(np.diag(signs), rotation) without building the diagonal matrix.
    return (signs * rotation.T).T
# Generates a random (semi-)orthogonal matrix of size N x M.
def semi_orthogonal_matrix(N, M, seed=None):
    """Return a random (semi-)orthogonal matrix of shape ``(N, M)``.

    A square orthogonal matrix of size ``max(N, M)`` is drawn with
    ``square_orthogonal_matrix`` and its top-left ``N`` x ``M`` corner is
    returned. For ``N > M`` the columns are orthonormal; otherwise the rows
    are orthonormal.

    Args:
        N: Number of rows.
        M: Number of columns.
        seed: Optional seed for reproducible output. When given, NumPy's
            global random state is seeded with it for the duration of the
            draw and restored afterwards, so the caller's random stream is
            left undisturbed. When ``None`` the current global state is
            used unchanged.

    Returns:
        A NumPy array of shape ``(N, M)``.
    """
    dim = max(N, M)
    if seed is None:
        return square_orthogonal_matrix(dim=dim)[:N, :M]

    # square_orthogonal_matrix draws from the global np.random state, so seed
    # it temporarily and put the previous state back even if the draw fails.
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        square_matrix = square_orthogonal_matrix(dim=dim)
    finally:
        np.random.set_state(saved_state)
    return square_matrix[:N, :M]
# Generates a weight matrix with variance according to Glorot initialization,
# based on a random (semi-)orthogonal matrix.
# Neural networks are expected to learn better when features are decorrelated
# (stated by e.g. "Reducing overfitting in deep networks by decorrelating representations",
# "Dropout: a simple way to prevent neural networks from overfitting",
# "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks").
def semi_orthogonal_glorot_weights(n_in: int, n_out: int, scale: float=2.0) -> Tensor:
    """Build a Glorot-scaled weight tensor from a random semi-orthogonal matrix.

    A ``(n_in, n_out)`` matrix is obtained from ``semi_orthogonal_matrix`` and
    rescaled so that the variance of its entries (``ndarray.var``) equals
    ``scale / (n_in + n_out)``. The transpose is then wrapped in a tensor.

    Args:
        n_in: Number of input features.
        n_out: Number of output features.
        scale: Numerator of the target variance ``scale / (n_in + n_out)``.

    Returns:
        A tensor of shape ``(n_out, n_in)`` created with
        ``requires_grad=True``; its dtype is inherited from the NumPy array.

    NOTE(review): for ``n_in == n_out == 1`` the matrix has a single entry,
    its variance is zero and the rescale factor divides by zero - confirm
    that this case is never requested.
    """
    basis = semi_orthogonal_matrix(n_in, n_out)
    # Factor that maps the current entry variance onto the Glorot target.
    rescale = np.sqrt(scale / ((n_in + n_out) * basis.var()))
    weights = basis * rescale
    return torch.tensor(weights.T, requires_grad=True)