attention
- paddle.sparse.nn.functional.attention(query, key, value, sparse_mask, key_padding_mask=None, attn_mask=None, name=None) [source]
Note
This API is only supported from CUDA 11.8.
A SparseCsrTensor is used to store the intermediate result of the attention matrix in the Transformer module, which reduces memory usage and improves performance. sparse_mask expresses the sparse layout in CSR format. The calculation equation is:
\[result = softmax(\frac{ Q * K^T }{\sqrt{d}}) * V\]
where Q, K, and V represent the three input parameters of the attention module. The shape of each of the three parameters is [batch_size, num_heads, seq_len, head_dim], and d represents head_dim.
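For reference, the following is a minimal dense sketch of the equation above. The shapes and variable names are illustrative assumptions, and the mask arguments of the sparse API (sparse_mask, key_padding_mask, attn_mask) are ignored here; it is not the sparse kernel itself.
>>> # Dense sketch of result = softmax(Q * K^T / sqrt(d)) * V; shapes are assumptions.
>>> import paddle
>>> q = paddle.rand([2, 4, 8, 16])   # [batch_size, num_heads, seq_len, head_dim]
>>> k = paddle.rand([2, 4, 8, 16])
>>> v = paddle.rand([2, 4, 8, 16])
>>> d = q.shape[-1]                  # head_dim
>>> scores = paddle.matmul(q, k, transpose_y=True) / (d ** 0.5)
>>> result = paddle.matmul(paddle.nn.functional.softmax(scores, axis=-1), v)
>>> # result has the same shape as q, k, and v: [2, 4, 8, 16]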
- Parameters
query (DenseTensor) – query in the Attention module. 4D Tensor with float32 or float64.
key (DenseTensor) – key in the Attention module. 4D Tensor with float32 or float64.
value (DenseTensor) – value in the Attention module. 4D Tensor with float32 or float64.
sparse_mask (SparseCsrTensor) – The sparse layout in the Attention module. Its dense shape is [batch_size*num_heads, seq_len, seq_len]. The nnz of each batch must be the same. The dtype of crows and cols must be int64, and the dtype of values can be float32 or float64. See the sketch after this parameter list for one way to build a conforming mask.
key_padding_mask (DenseTensor, optional) – The key padding mask tensor in the Attention module. 2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64. Default: None.
attn_mask (DenseTensor, optional) – The attention mask tensor in the Attention module. 2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64. Default: None.
name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
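As referenced in the sparse_mask description, the sketch below shows one way to build a mask that satisfies the constraints above. The shapes and the triangular layout are assumptions chosen for illustration; reusing the same 0/1 pattern for every (batch, head) pair keeps the nnz of each batch identical, and to_sparse_csr() produces int64 crows and cols.
>>> # Sketch only: hypothetical shapes, causal-style lower-triangular layout.
>>> import paddle
>>> batch_size, num_heads, seq_len = 2, 4, 8
>>> layout = paddle.tril(paddle.ones([seq_len, seq_len]))                 # 0/1 pattern, [seq_len, seq_len]
>>> dense_mask = layout.expand([batch_size, num_heads, seq_len, seq_len]) # same pattern per (batch, head)
>>> sparse_mask = dense_mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()
>>> # sparse_mask has dense shape [batch_size*num_heads, seq_len, seq_len]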
- Returns
4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. The dtype is the same as the input.
Examples
>>> import paddle
>>> paddle.device.set_device('gpu')
>>> batch_size = 16
>>> num_heads = 16
>>> seq_len = 512
>>> head_dim = 32
>>> query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
>>> key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
>>> value = paddle.rand([batch_size, num_heads, seq_len, head_dim])
>>> query.stop_gradient = False
>>> key.stop_gradient = False
>>> value.stop_gradient = False
>>> mask = paddle.nn.functional.dropout(paddle.ones([seq_len, seq_len])).expand([batch_size, num_heads, seq_len, seq_len])
>>> sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()
>>> kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype(paddle.float32)
>>> attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype(paddle.float32)
>>> output = paddle.sparse.nn.functional.attention(query, key, value, sp_mask, kp_mask, attn_mask)
>>> output.backward()