enerzyme.models.blocks.attention.Attention

class enerzyme.models.blocks.attention.Attention(dim_qk: int, num_random_features: int | None = None)

Bases: Module

Efficient (linearly scaling) approximation of attention, as described in Choromanski, K., et al., “Rethinking Attention with Performers.”
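The core idea of the Performer approximation can be illustrated with positive random features. For projection vectors w drawn from a standard normal distribution, the map phi(x)_i = exp(w_i · x - |x|²/2) / sqrt(m) has inner products phi(q) · phi(k) that are unbiased estimates of the softmax kernel exp(q · k). The pure-Python sketch below is illustrative only: the function names are hypothetical, it uses i.i.d. Gaussian projections (the paper additionally orthogonalizes them), and it makes no claim about how enerzyme builds its feature map. Inside attention, q and k would also be scaled by dim_qk^(-1/4) so that the estimate targets exp(q · k / sqrt(dim_qk)).

```python
import math
import random


def sample_omegas(num_features, dim, rng):
    # Hypothetical helper: i.i.d. standard-normal projection vectors w_i.
    # (Orthogonalization, as used in the Performer paper, is omitted here.)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_features)]


def positive_random_features(x, omegas):
    """phi(x)_i = exp(w_i . x - |x|^2 / 2) / sqrt(m); every entry is positive."""
    m = len(omegas)
    sq_norm = sum(xi * xi for xi in x)
    return [
        math.exp(sum(wi * xi for wi, xi in zip(w, x)) - sq_norm / 2) / math.sqrt(m)
        for w in omegas
    ]


def approx_softmax_kernel(q, k, num_features=20000, seed=0):
    """Monte Carlo estimate of exp(q . k) as the inner product phi(q) . phi(k)."""
    rng = random.Random(seed)
    omegas = sample_omegas(num_features, len(q), rng)
    phi_q = positive_random_features(q, omegas)
    phi_k = positive_random_features(k, omegas)
    return sum(a * b for a, b in zip(phi_q, phi_k))


q = [0.3, -0.2, 0.1]
k = [0.1, 0.4, -0.3]
exact = math.exp(sum(a * b for a, b in zip(q, k)))
print("exact:", exact, "estimate:", approx_softmax_kernel(q, k))
```

Positive features are used instead of sine/cosine features because the estimated attention weights can then never be negative, which keeps the normalization in softmax attention well behaved.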

Arguments:
dim_qk (int):

Dimension of query/key vectors.

dim_v (int):

Dimension of value vectors.

num_random_features (int):

Number of random features for approximating attention matrix. If this is 0, the exact attention matrix is computed.
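The sketch below illustrates what num_random_features controls, assuming that 0 selects exact softmax attention and a positive value selects a Performer-style approximation with that many random features. It is a plain-Python illustration built around a hypothetical `attention` function, not the enerzyme implementation (which operates on PyTorch tensors; the behaviour of the default `None` is not modelled here). The approximate branch shows where the linear scaling comes from: the key features are aggregated once into sums of size m × dim_v, so each query then costs O(m · dim_v) instead of O(N · dim_v).

```python
import math
import random


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def _features(x, omegas):
    # Positive random features: exp(w . x - |x|^2 / 2) / sqrt(m).
    m = len(omegas)
    return [math.exp(_dot(w, x) - _dot(x, x) / 2) / math.sqrt(m) for w in omegas]


def attention(Q, K, V, num_random_features=0, seed=0):
    """Softmax attention over Q, K ([N][dim_qk]) and V ([N][dim_v]) given as lists.

    num_random_features == 0: exact O(N^2) softmax attention.
    num_random_features > 0: Performer-style O(N * m) approximation.
    """
    dim_qk, dim_v = len(Q[0]), len(V[0])
    scale = dim_qk ** -0.25  # split 1/sqrt(dim_qk) evenly between queries and keys
    Qs = [[scale * x for x in q] for q in Q]
    Ks = [[scale * x for x in k] for k in K]

    if num_random_features == 0:
        out = []
        for q in Qs:
            logits = [_dot(q, k) for k in Ks]
            top = max(logits)  # subtract the maximum for numerical stability
            weights = [math.exp(z - top) for z in logits]
            total = sum(weights)
            out.append([sum(w * v[c] for w, v in zip(weights, V)) / total
                        for c in range(dim_v)])
        return out

    rng = random.Random(seed)
    omegas = [[rng.gauss(0.0, 1.0) for _ in range(dim_qk)]
              for _ in range(num_random_features)]
    phi_K = [_features(k, omegas) for k in Ks]
    # Aggregate over all keys once: kv[i][c] = sum_j phi(k_j)_i * v_j[c].
    kv = [[sum(pk[i] * v[c] for pk, v in zip(phi_K, V)) for c in range(dim_v)]
          for i in range(num_random_features)]
    # Normalizer terms: ksum[i] = sum_j phi(k_j)_i.
    ksum = [sum(pk[i] for pk in phi_K) for i in range(num_random_features)]
    out = []
    for q in Qs:
        pq = _features(q, omegas)
        norm = _dot(pq, ksum)  # approximates the softmax denominator
        out.append([sum(pq[i] * kv[i][c] for i in range(num_random_features)) / norm
                    for c in range(dim_v)])
    return out
```

With enough random features the two branches agree closely. In the exact branch, all-zero queries give uniform weights, so every output row is simply the mean of the value vectors.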

__init__(dim_qk: int, num_random_features: int | None = None) → None

Initializes the Attention class.

forward(Q: Tensor, K: Tensor, V: Tensor, num_batch: int, batch_seg: Tensor, mask: Tensor | None = None) → Tensor

Compute attention for the given query, key, and value vectors. The tensor shapes below use the following symbols: N is the number of input values, dim_qk is the dimension of the query/key vectors, and dim_v is the dimension of the value vectors.

Arguments:
Q (FloatTensor [N, dim_qk]):

Matrix of N query vectors.

K (FloatTensor [N, dim_qk]):

Matrix of N key vectors.

V (FloatTensor [N, dim_v]):

Matrix of N value vectors.

num_batch (int):

Number of distinct batches in the input.

batch_seg (LongTensor [N]):

Batch index of each input, specifying which batch it belongs to. For example, if the input consists of a sequence of size 3 followed by a sequence of size 5, batch_seg is [0, 0, 0, 1, 1, 1, 1, 1] and num_batch is 2.
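The role of batch_seg can be sketched as follows, assuming that inputs attend only to other inputs with the same batch index. For clarity, the hypothetical helpers `exact_attention` and `segmented_attention` below use exact attention and an explicit Python loop over batches; a tensor implementation would more likely use scatter or index-add operations, so this is not a description of enerzyme's code.

```python
import math


def exact_attention(Q, K, V):
    """Exact softmax attention (logits scaled by 1/sqrt(dim_qk)) within one segment."""
    dim_qk = len(Q[0])
    out = []
    for q in Q:
        logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim_qk) for k in K]
        top = max(logits)  # subtract the maximum for numerical stability
        weights = [math.exp(z - top) for z in logits]
        total = sum(weights)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) / total
                    for c in range(len(V[0]))])
    return out


def segmented_attention(Q, K, V, num_batch, batch_seg):
    """Apply attention independently to each batch defined by batch_seg.

    Inputs with different batch_seg values never attend to one another.
    """
    groups = [[] for _ in range(num_batch)]
    for idx, b in enumerate(batch_seg):
        groups[b].append(idx)
    out = [None] * len(Q)
    for idx_list in groups:
        if not idx_list:  # skip empty batches
            continue
        seg_out = exact_attention([Q[i] for i in idx_list],
                                  [K[i] for i in idx_list],
                                  [V[i] for i in idx_list])
        for i, row in zip(idx_list, seg_out):  # scatter results back in input order
            out[i] = row
    return out


# The example from the docstring: a sequence of size 3 followed by one of size 5.
batch_seg = [0, 0, 0, 1, 1, 1, 1, 1]
num_batch = 2
```

Because each batch is processed in isolation, changing a value vector in the second sequence cannot affect the outputs of the first.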

Returns:
y (FloatTensor [N, dim_v]):

Attention-weighted sum of value vectors.