flash_attn_qkvpacked

paddle.nn.functional.flash_attn_qkvpacked ( qkv, dropout=0.0, causal=False, return_softmax=False, *, fixed_seed_offset=None, rng_name='', training=True, name=None ) [source]

The equation is:

\[result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V\]

where Q, K, and V represent the query, key, and value of the attention module. The three have the same dimensions, and d is the size of their last dimension.
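For reference, the same computation can be written with plain Paddle ops. This is only a minimal sketch of the formula above with illustrative shapes; flash attention produces the same result without materializing the full seq_len x seq_len score matrix.

>>> import paddle
>>> import paddle.nn.functional as F
>>> # q, k, v laid out as [batch_size, num_heads, seq_len, head_dim] for this sketch
>>> q = paddle.rand((1, 2, 128, 16))
>>> k = paddle.rand((1, 2, 128, 16))
>>> v = paddle.rand((1, 2, 128, 16))
>>> d = q.shape[-1]
>>> scores = paddle.matmul(q, k, transpose_y=True) / (d ** 0.5)
>>> out = paddle.matmul(F.softmax(scores, axis=-1), v)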

Warning

This API only supports inputs with dtype float16 and bfloat16. Don’t call this API if flash_attn is not supported.

Parameters
  • qkv (Tensor) – The packed query/key/value tensor in the Attention module. 5-D tensor with shape: [batch_size, seq_len, num_heads/num_heads_k + 2, num_heads_k, head_dim]. The dtype can be float16 or bfloat16. See the packing sketch after this parameter list.

  • dropout (float) – The dropout ratio. Default is 0.0.

  • causal (bool) – Whether to enable causal mode. Default is False.

  • return_softmax (bool) – Whether to return the softmax tensor. Default is False.

  • fixed_seed_offset (Tensor, optional) – With a fixed seed, the offset for the dropout mask. Default is None.

  • training (bool) – Whether it is in the training phase. Default is True.

  • rng_name (str) – The name used to select the random number Generator. Default is ''.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
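For illustration, here is a hedged sketch of packing separate query/key/value tensors into the layout that qkv expects, for the common multi-head case where num_heads == num_heads_k (so num_heads/num_heads_k + 2 == 3). The shapes are illustrative assumptions, and in practice the tensors should be float16 or bfloat16 on a device with flash_attn support, as noted in the warning above.

>>> import paddle
>>> batch_size, seq_len, num_heads, head_dim = 1, 128, 2, 16
>>> q = paddle.rand((batch_size, seq_len, num_heads, head_dim))
>>> k = paddle.rand((batch_size, seq_len, num_heads, head_dim))
>>> v = paddle.rand((batch_size, seq_len, num_heads, head_dim))
>>> # Stack along a new axis 2: [batch_size, seq_len, 3, num_heads, head_dim]
>>> qkv = paddle.stack([q, k, v], axis=2)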

Returns

  • out (Tensor) – The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

  • softmax (Tensor) – The softmax tensor. None if return_softmax is False.
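The call always returns a two-element tuple, as the example below also shows. A minimal usage sketch, assuming qkv has already been packed into the layout described above:

>>> out, softmax = paddle.nn.functional.flash_attn_qkvpacked(qkv)
>>> # softmax is None here because return_softmax defaults to False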

Examples

>>> import paddle
>>> paddle.seed(2023)
>>> q = paddle.rand((1, 128, 2, 16))
>>> qkv = paddle.stack([q, q, q], axis=2)
>>> output = paddle.nn.functional.flash_attn_qkvpacked(qkv, 0.9, False, False)
>>> print(output)
(Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[[0.34992966, 0.34456208, 0.45826620, ..., 0.39883569,
    0.42132431, 0.39157745],
   [0.76687670, 0.65837246, 0.69117945, ..., 0.82817286,
    0.76690865, 0.71485823]],
  ...,
  [[0.71662450, 0.57275224, 0.57053083, ..., 0.48108247,
    0.53336465, 0.54540104],
   [0.59137970, 0.51350880, 0.50449550, ..., 0.38860250,
    0.40526697, 0.60541755]]]]), None)