enerzyme.models.blocks.attention.Attention
- class enerzyme.models.blocks.attention.Attention(dim_qk: int, num_random_features: int | None = None)
Bases: Module
Efficient (linear-scaling) approximation of attention, as described in Choromanski, K., et al., “Rethinking Attention with Performers.”
- Arguments:
- dim_qk (int):
Dimension of query/key vectors.
- dim_v (int):
Dimension of value vectors.
- num_random_features (int | None):
Number of random features for approximating attention matrix. If this is 0, the exact attention matrix is computed.
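Example (illustrative sketch, not the code of this class): the positive random features from the cited Performer paper approximate the unnormalized softmax kernel exp(q · k) by an inner product of feature vectors. The helper name `positive_random_features` and the chosen sizes are assumptions made for this example only; standard-library Python is used so the snippet runs without PyTorch.

```python
import math
import random


def positive_random_features(x, W):
    """Map x to len(W) positive features whose inner products estimate exp(q . k).

    W is a list of Gaussian projection vectors (hypothetical helper, for illustration).
    """
    half_sq_norm = sum(v * v for v in x) / 2.0
    m = len(W)
    return [
        math.exp(sum(w_i * x_i for w_i, x_i in zip(w, x)) - half_sq_norm) / math.sqrt(m)
        for w in W
    ]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
dim_qk, num_random_features = 4, 4096
# One standard-normal projection vector per random feature.
W = [[rng.gauss(0.0, 1.0) for _ in range(dim_qk)] for _ in range(num_random_features)]

q = [0.3, -0.1, 0.2, 0.0]
k = [0.1, 0.2, -0.3, 0.4]

phi_q = positive_random_features(q, W)
phi_k = positive_random_features(k, W)
approx = sum(a * b for a, b in zip(phi_q, phi_k))   # random-feature estimate
exact = math.exp(sum(a * b for a, b in zip(q, k)))  # exact kernel value
print(approx, exact)
```

The estimate is unbiased, and its error shrinks as the number of random features grows. The feature maps for keys and values can be combined once and reused for every query, which is where the linear scaling comes from.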
- __init__(dim_qk: int, num_random_features: int | None = None) → None
Initializes the Attention class.
- forward(Q: Tensor, K: Tensor, V: Tensor, num_batch: int, batch_seg: Tensor, mask: Tensor | None = None) → Tensor
Compute attention for the given query, key, and value vectors. In the shapes below, N is the number of input values, dim_qk is the dimension of the query/key vectors, and dim_v is the dimension of the value vectors.
- Arguments:
- Q (FloatTensor [N, dim_qk]):
Matrix of N query vectors.
- K (FloatTensor [N, dim_qk]):
Matrix of N key vectors.
- V (FloatTensor [N, dim_v]):
Matrix of N value vectors.
- num_batch (int):
Number of different batches in the input values.
- batch_seg (LongTensor [N]):
Index for each input that specifies which batch it belongs to. For example, when the input consists of one sequence of size 3 followed by another of size 5, batch_seg is [0, 0, 0, 1, 1, 1, 1, 1] and num_batch is 2.
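Example (illustrative sketch): one way to build `batch_seg` and `num_batch` from a list of sequence lengths. The helper `make_batch_seg` is hypothetical and not part of this module. It returns a plain Python list, which would still need converting to a LongTensor before being passed to `forward`. With PyTorch tensors, `torch.repeat_interleave` offers an equivalent construction.

```python
from itertools import chain


def make_batch_seg(lengths):
    """Return (batch_seg, num_batch) for sequences of the given lengths.

    Hypothetical helper: entry i of `lengths` contributes lengths[i] copies of index i.
    """
    batch_seg = list(chain.from_iterable([i] * n for i, n in enumerate(lengths)))
    return batch_seg, len(lengths)


batch_seg, num_batch = make_batch_seg([3, 5])
print(batch_seg)  # [0, 0, 0, 1, 1, 1, 1, 1]
print(num_batch)  # 2
```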
- Returns:
- y (FloatTensor [N, dim_v]):
Attention-weighted sum of value vectors.
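Example (illustrative reference, not the implementation of this class): a minimal sketch of exact softmax attention computed on plain Python lists. It rests on two assumptions not stated above: each input attends only to inputs with the same `batch_seg` index, and the logits are scaled by 1/sqrt(dim_qk). The function name `segmented_attention` is hypothetical.

```python
import math


def segmented_attention(Q, K, V, batch_seg):
    """Exact attention in which row i attends only to rows j of the same batch.

    Q, K: N lists of dim_qk floats; V: N lists of dim_v floats; batch_seg: N ints.
    Returns N lists of dim_v floats.
    """
    n = len(Q)
    dim_v = len(V[0])
    scale = 1.0 / math.sqrt(len(Q[0]))  # assumed scaling of the logits
    out = []
    for i in range(n):
        # Inputs sharing the batch index of row i (always includes i itself).
        peers = [j for j in range(n) if batch_seg[j] == batch_seg[i]]
        logits = [scale * sum(q * k for q, k in zip(Q[i], K[j])) for j in peers]
        shift = max(logits)  # subtract the maximum for numerical stability
        weights = [math.exp(l - shift) for l in logits]
        total = sum(weights)
        out.append(
            [sum(w * V[j][c] for w, j in zip(weights, peers)) / total for c in range(dim_v)]
        )
    return out


# With all-zero queries and keys the weights are uniform, so each output
# row is the mean of the value vectors in its own batch.
seg = [0, 0, 0, 1, 1, 1, 1, 1]
Q = [[0.0, 0.0] for _ in seg]
K = [[0.0, 0.0] for _ in seg]
V = [[1.0], [2.0], [3.0], [10.0], [20.0], [30.0], [40.0], [50.0]]
y = segmented_attention(Q, K, V, seg)
print(y[0], y[3])  # [2.0] [30.0]
```

This reference costs O(N²) time. The random-feature approximation described for this class is intended to avoid that cost.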