scaled_dot_product_attention

paddle.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, training=True, name=None) [source]

The equation is:

\[result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V\]

where Q, K, and V represent the three input tensors of the attention module; all three have the same dimensions, and d is the size of their last dimension (head_dim).
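
As a rough reference, the same computation can be written with basic Paddle ops. This is a minimal sketch for illustration only: float32 is used so it runs without the float16/bfloat16 requirement of the fused API, and the transposes assume the [batch_size, seq_len, num_heads, head_dim] layout described in the Parameters section below.

>>> import math
>>> import paddle
>>> import paddle.nn.functional as F
>>> q = k = v = paddle.rand((1, 128, 2, 16))  # [batch_size, seq_len, num_heads, head_dim]
>>> # Move heads next to batch so every head attends independently.
>>> qt, kt, vt = (x.transpose([0, 2, 1, 3]) for x in (q, k, v))
>>> d = qt.shape[-1]
>>> scores = paddle.matmul(qt, kt, transpose_y=True) / math.sqrt(d)   # Q * K^T / sqrt(d)
>>> weights = F.softmax(scores, axis=-1)                              # softmax over the key positions
>>> result = paddle.matmul(weights, vt).transpose([0, 2, 1, 3])       # back to the input layout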

Warning

This API only supports inputs with dtype float16 and bfloat16.

Parameters
  • query (Tensor) – The query tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

  • key (Tensor) – The key tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

  • value (Tensor) – The value tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

  • attn_mask (Tensor, optional) – A float mask with the same dtype as query, key, and value that is added to the attention scores (see the sketch after this parameter list). Default is None.

  • dropout_p (float, optional) – The dropout ratio. Default is 0.0.

  • is_causal (bool, optional) – Whether to enable causal mode. Default is False.

  • training (bool, optional) – Whether the call is in the training phase. Default is True.

  • name (str, optional) – The default value is None. Normally there is no need for user to set this property. For more information, please refer to Name.
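
A common pattern for attn_mask is an additive causal mask. The following is a hedged sketch: the [1, 1, seq_len, seq_len] shape is an assumption chosen so the mask broadcasts against per-head attention scores.

>>> import paddle
>>> seq_len = 128
>>> # -inf above the diagonal suppresses attention to future positions once it is added to the scores.
>>> causal_mask = paddle.triu(paddle.full((1, 1, seq_len, seq_len), float('-inf')), diagonal=1)
>>> causal_mask = causal_mask.astype(paddle.bfloat16)  # must match the dtype of query/key/value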

Returns

The attention tensor.

4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

Return type

out (Tensor)

Examples

>>> import paddle
>>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
>>> print(output)
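
A further usage sketch (hypothetical values; assumes a device and Paddle build where bfloat16 attention is available):

>>> import paddle
>>> import paddle.nn.functional as F
>>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> k = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> v = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> # Causal (autoregressive) attention with dropout disabled, as in inference.
>>> out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0.0, training=False)
>>> print(out.shape)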