flash_attn_varlen_qkvpacked

paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, dropout=0.0, causal=False, return_softmax=False, fixed_seed_offset=None, rng_name='', varlen_padded=True, training=True, name=None)

The equation is:

\[result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V\]

where Q, K, and V are the three inputs of the attention module, packed here into qkv. All three have the same dimensions, and d is the size of their last dimension (head_dim).
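The formula above can be sketched in plain NumPy for a single head (an illustrative reference only, not the fused FlashAttention kernel this API dispatches to):

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference softmax(Q K^T / sqrt(d)) V for one head.

    q, k, v: [seq_len, head_dim] arrays. Illustrative only; the
    actual API runs a fused FlashAttention kernel on GPU.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                   # [seq_len, seq_len]
    scores -= scores.max(axis=-1, keepdims=True)    # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v                              # [seq_len, head_dim]

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
out = naive_attention(q, q, q)
print(out.shape)  # (4, 8)
```

Because each softmax row sums to 1, the output is a convex combination of the value rows.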

Warning

This API only supports inputs with dtype float16 and bfloat16.

Parameters
  • qkv (Tensor) – The padded query/key/value packed tensor of the attention module; the padding part is not computed. 4-D tensor with shape: [total_seq_len, num_heads/num_heads_k + 2, num_heads_k, head_dim]. The dtype can be float16 or bfloat16.

  • cu_seqlens_q (Tensor) – The cumulative sequence lengths of the sequences in the batch, used to index query.

  • cu_seqlens_k (Tensor) – The cumulative sequence lengths of the sequences in the batch, used to index key and value.

  • max_seqlen_q (int) – Maximum sequence length of query in the batch. Note that this is the padded length, not the maximum actual sequence length.

  • max_seqlen_k (int) – Maximum sequence length of key/value in the batch.

  • scale (float) – The scaling of QK^T before applying softmax.

  • dropout (float) – The dropout ratio.

  • causal (bool) – Whether to enable causal mode.

  • return_softmax (bool) – Whether to return the softmax tensor.

  • fixed_seed_offset (Tensor, optional) – With fixed seed, the offset for the dropout mask.

  • rng_name (str) – The name used to select the random Generator.

  • training (bool) – Whether it is in the training phase.

  • name (str, optional) – The default value is None. Normally there is no need for user to set this property. For more information, please refer to Name.
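The two cu_seqlens tensors are the exclusive cumulative sums of the per-sequence lengths. A minimal sketch of how they are typically built (plain Python, with a hypothetical helper name; in practice the result is passed as an int32 paddle.Tensor):

```python
import itertools

def build_cu_seqlens(seq_lens):
    """Build cumulative sequence offsets [0, l0, l0+l1, ...] from
    per-sequence lengths. Hypothetical helper for illustration."""
    return list(itertools.accumulate(seq_lens, initial=0))

# A batch of 3 sequences with actual lengths 5, 3 and 7:
print(build_cu_seqlens([5, 3, 7]))  # [0, 5, 8, 15]
```

With offsets cu, the tokens of the i-th sequence occupy rows cu[i]:cu[i+1] of the packed total_seq_len dimension.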

Returns

  • out (Tensor) – The attention tensor, padded with zeros. 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

  • softmax (Tensor) – The softmax tensor. None if return_softmax is False.

Examples

>>> import paddle
>>> paddle.seed(2023)
>>> q = paddle.rand((2, 128, 8, 16), dtype='float16')
>>> # cumulative sequence lengths for 2 sequences of length 128: [0, 128, 256]
>>> cu = paddle.arange(0, 384, 128, dtype='int32')
>>> qq = paddle.reshape(q, [256, 8, 16])
>>> # pack Q, K and V along axis 1: shape [256, 3, 8, 16]
>>> qkv = paddle.stack([qq, qq, qq], axis=1)
>>> output = paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False)