attention¶
- paddle.sparse.nn.functional. attention ( query, key, value, sparse_mask, key_padding_mask=None, attn_mask=None, name=None ) [源代码] ¶
注解
该 API 从 CUDA 11.7 开始支持。
稀疏 Attention,该 API 内部使用 SparseCsrTensor 来存储 Transformer 模块中的 attention 矩阵,从而达到减少显存占用、提高性能的目的。 参数 sparse_mask 描述了稀疏矩阵的非 0 元素索引布局。
其中:矩阵 Q K V 表示 attention 模块的三个输入 Tensor,其 shape 均为 [batch_size, num_heads, seq_len, head_dim] , 公式中的 d 代表 head_dim 。
参数¶
query (DenseTensor) - Attention 模块的 query 输入,4D Tensor,数据类型为 float32、float64。
key (DenseTensor) - Attention 模块的 key 输入,4D Tensor,数据类型为 float32、float64。
value (DenseTensor) - Attention 模块的 value 输入,4D Tensor,数据类型为 float32、float64。
sparse_mask (SparseCsrTensor) - Attention 模块的非 0 元素布局,是一个 3D 的 SparseCsrTensor,shape 为 [batch_size*num_heads, seq_len, seq_len] 。 同时每个批次的非 0 元素个数均相等。crows 和 cols 的数据类型为 int64,value 的数据类型为 float32、float64。
key_padding_mask (DenseTensor, 可选) - Attention 模块中的 key padding mask,是一个 2D 的 DenseTensor,shape 为 [batch_size, seq_len] 。 数据类型为 float32、float64。默认:None,表示无此掩码运算。
attn_mask (DenseTensor, 可选) - Attention 模块中的 attention mask,是一个 2D 的 DenseTensor,shape 为 [seq_len, seq_len] 。 数据类型为 float32、float64。默认:None,表示无此掩码运算。
name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。
返回¶
DenseTensor: 维度为 4,shape 为 [batch_size, num_heads, seq_len, head_dim] ,dtype 与输入相同。
代码示例¶
# required: gpu
import paddle
batch_size = 16
num_heads = 16
seq_len = 512
head_dim = 32
query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
value = paddle.rand([batch_size, num_heads, seq_len, head_dim])
query.stop_gradient = False
key.stop_gradient = False
value.stop_gradient = False
mask = paddle.nn.functional.dropout(paddle.ones([seq_len, seq_len])).expand([batch_size, num_heads, seq_len, seq_len])
sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()
kp_mask = paddle.randint(0, 2, [batch_size, seq_len])
attn_mask = paddle.randint(0, 2, [seq_len, seq_len])
output = paddle.sparse.nn.functional.attention(query, key, value, sp_mask, kp_mask, attn_mask)
output.backward()