Source code for enerzyme.models.blocks.attention

import math
from typing import Optional
import torch
from torch.nn import Module
import torch.nn.functional as F
import numpy as np



class Attention(Module):
    """
    Efficient (linear scaling) approximation for attention described in
    Choromanski, K., et al. "Rethinking Attention with Performers.".

    Arguments:
        dim_qk (int):
            Dimension of query/key vectors.
        num_random_features (int or None):
            Number of random features for approximating the attention
            matrix. If this is None or 0, the exact attention matrix is
            computed instead.
    """

    def __init__(
        self, dim_qk: int, num_random_features: Optional[int] = None
    ) -> None:
        """
        Initializes the Attention class.

        Raises:
            ValueError: If num_random_features is negative.
        """
        super().__init__()
        if num_random_features is not None and num_random_features < 0:
            raise ValueError(
                "num_random_features must be None or a non-negative integer"
            )
        self.num_random_features = num_random_features
        if not num_random_features:
            # exact attention needs no projection; keep an empty buffer so
            # that the module always exposes an `omega` attribute
            omega = []
        else:
            omega = self._omega(num_random_features, dim_qk)
        self.register_buffer("omega", torch.tensor(omega, dtype=torch.float64))

    def _omega(self, nrows: int, ncols: int) -> np.ndarray:
        """
        Return the random feature matrix.

        nrows row vectors of length ncols are drawn in blocks of ncols
        mutually orthogonal vectors (from the QR decomposition of a Gaussian
        matrix); the last block is truncated when nrows is not a multiple of
        ncols. The stacked matrix is returned transposed, i.e. with shape
        (ncols, nrows), so inputs can be right-multiplied by it directly.
        """
        blocks = []
        for start in range(0, nrows, ncols):
            q, _ = np.linalg.qr(np.random.normal(size=(ncols, ncols)))
            # slicing is a no-op for full blocks and truncates the last one
            blocks.append(np.transpose(q)[: nrows - start])
        # renormalize rows so they still follow N(0,1)
        norm = np.linalg.norm(
            np.random.normal(size=(nrows, ncols)), axis=1, keepdims=True
        )
        return (norm * np.vstack(blocks)).T

    def _phi(
        self,
        X: torch.Tensor,
        is_query: bool,
        num_batch: int,
        batch_seg: torch.Tensor,
        eps: float = 1e-4,
    ) -> torch.Tensor:
        """
        Normalize X and project into random feature space.

        Arguments:
            X (FloatTensor [N, dim_qk]):
                Query or key vectors.
            is_query (bool):
                If True, the stabilizing maximum is taken per row; otherwise
                it is taken per batch segment (or globally if num_batch is 1).
            num_batch (int):
                Number of different batches in the input values.
            batch_seg (LongTensor [N]):
                Batch index of each row of X.
            eps (float):
                Small constant added to every feature.
        Returns:
            FloatTensor [N, num_random_features]: Positive random features.
        """
        d = X.shape[-1]
        m = self.omega.shape[-1]
        U = torch.matmul(X / d ** 0.25, self.omega.type_as(X))
        h = torch.sum(X ** 2, dim=-1, keepdim=True) / (2 * d ** 0.5)
        # determine maximum (is subtracted to prevent numerical overflow)
        if is_query:
            maximum = U.max(dim=-1, keepdim=True).values
        elif num_batch > 1:
            # The membership mask is constant along the feature axis, so the
            # maximum over features can be taken first. This yields the same
            # values as masking a (num_batch, N, m) expansion of U while only
            # needing (num_batch, N) memory.
            row_max = U.max(dim=-1).values
            seg_ids = torch.arange(
                num_batch, dtype=batch_seg.dtype, device=batch_seg.device
            )
            member = (batch_seg.view(1, -1) == seg_ids.view(-1, 1)).to(U.dtype)
            # rows of other segments contribute 0, so the per-segment value
            # is effectively clamped from below at 0; any per-segment
            # constant is fine for stabilization
            seg_max = (member * row_max).max(dim=-1).values
            if seg_max.device.type == "cpu":  # indexing faster on CPU
                maximum = seg_max[batch_seg].unsqueeze(-1)
            else:  # gathering is faster on GPUs
                maximum = torch.gather(seg_max, 0, batch_seg).unsqueeze(-1)
        else:
            maximum = torch.max(U)
        return (torch.exp(U - h - maximum) + eps) / math.sqrt(m)

    def _exact_attention(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        num_batch: int,
        batch_seg: torch.Tensor,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """
        Compute exact attention.

        Entries belonging to different batch segments are excluded, and the
        exponentials are stabilized with the maximum of each row (taken over
        the entries that are kept), so that no row can underflow to zero
        just because another row holds much larger scores.
        """
        d = Q.shape[-1]
        scores = (Q @ K.T) / d ** 0.5  # scaled dot product
        if num_batch > 1:  # mask out entries of different batches
            same_batch = batch_seg.view(-1, 1) == batch_seg.view(1, -1)
            scores = scores.masked_fill(~same_batch, float("-inf"))
        row_max = scores.max(dim=-1, keepdim=True).values
        A = torch.exp(scores - row_max)  # attention matrix
        norm = torch.sum(A, -1, keepdim=True) + eps
        return (A / norm) @ V

    def _approximate_attention(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        num_batch: int,
        batch_seg: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """
        Compute approximate attention.

        If mask is given it is used as the [N, N] same-batch indicator
        matrix; otherwise it is built from batch_seg.
        """
        Q = self._phi(Q, True, num_batch, batch_seg)  # random projection of Q
        K = self._phi(K, False, num_batch, batch_seg)  # random projection of K
        if num_batch > 1:
            d = Q.shape[-1]

            # compute norm
            idx = batch_seg.unsqueeze(-1).expand(-1, d)
            tmp = K.new_zeros(num_batch, d).scatter_add_(0, idx, K)
            norm = torch.gather(Q @ tmp.T, -1, batch_seg.unsqueeze(-1)) + eps

            # the ops below are equivalent to this loop (but more efficient):
            # return torch.cat([Q[b==batch_seg]@(
            #    K[b==batch_seg].transpose(-1,-2)@V[b==batch_seg])
            #    for b in range(num_batch)])/norm
            if mask is None:  # mask can be shared across multiple attentions
                # the class count is known, so it does not have to be
                # inferred from the data
                one_hot = F.one_hot(batch_seg, num_classes=num_batch).to(
                    dtype=V.dtype, device=V.device
                )
                mask = one_hot @ one_hot.transpose(-1, -2)
            return ((mask * (K @ Q.transpose(-1, -2))).transpose(-1, -2) @ V) / norm
        else:
            norm = Q @ torch.sum(K, 0, keepdim=True).T + eps
            return (Q @ (K.T @ V)) / norm

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        num_batch: int,
        batch_seg: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute attention for the given query, key and value vectors.
        N: Number of input values.
        dim_qk: Dimension of query/key vectors.
        dim_v: Dimension of value vectors.

        Arguments:
            Q (FloatTensor [N, dim_qk]):
                Matrix of N query vectors.
            K (FloatTensor [N, dim_qk]):
                Matrix of N key vectors.
            V (FloatTensor [N, dim_v]):
                Matrix of N value vectors.
            num_batch (int):
                Number of different batches in the input values.
            batch_seg (LongTensor [N]):
                Index for each input that specifies to which batch it belongs.
                For example, when the input consists of a sequence of size 3
                and another sequence of size 5, batch_seg would be
                [0, 0, 0, 1, 1, 1, 1, 1] (num_batch would be 2 then).
            mask (FloatTensor [N, N], optional):
                Precomputed same-batch indicator matrix. Only used by the
                approximate attention when num_batch > 1; ignored by the
                exact attention.
        Returns:
            y (FloatTensor [N, dim_v]):
                Attention-weighted sum of value vectors.
        """
        if not self.num_random_features:  # None or 0 -> exact attention
            return self._exact_attention(Q, K, V, num_batch, batch_seg)
        return self._approximate_attention(Q, K, V, num_batch, batch_seg, mask)