fused_multi_head_attention

paddle.incubate.nn.functional.fused_multi_head_attention(x, qkv_weight, linear_weight, pre_layer_norm=False, pre_ln_scale=None, pre_ln_bias=None, ln_scale=None, ln_bias=None, pre_ln_epsilon=1e-05, qkv_bias=None, linear_bias=None, attn_mask=None, dropout_rate=0.5, attn_dropout_rate=0.5, ln_epsilon=1e-05, training=True, mode='upscale_in_train', name=None) [source]

Attention maps queries and a set of key-value pairs to outputs, and Multi-Head Attention performs multiple attention computations in parallel to jointly attend to information from different representation subspaces. This API only supports self-attention. The pseudo code is as follows:

if pre_layer_norm:
    out = layer_norm(x)
    out = qkv_linear(out) + qkv_bias
else:
    out = qkv_linear(x) + qkv_bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out.
q = out[0:1,::]
k = out[1:2,::]
v = out[2:3,::]
out = q * k^t
out = attn_mask + out
out = softmax(out)
out = dropout(out)
out = out * v
out = transpose(out, perm=[0, 2, 1, 3])
out = out_linear(out)
if pre_layer_norm:
    out = x + dropout(linear_bias + out)
else:
    out = layer_norm(x + dropout(linear_bias + out))
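
For readers who want the computation spelled out, the following is a minimal, unfused sketch built from standard paddle and paddle.nn.functional ops. It assumes the post_layer_norm path (pre_layer_norm=False) and that q is scaled by 1/sqrt(head_dim) before the dot product; naive_fused_mha is an illustrative helper, not part of this API, and it is not expected to match the fused kernel exactly in numerics or performance.

import paddle
import paddle.nn.functional as nnf

def naive_fused_mha(x, qkv_weight, linear_weight, qkv_bias, linear_bias,
                    attn_mask, ln_scale, ln_bias, dropout_rate=0.5,
                    attn_dropout_rate=0.5, ln_epsilon=1e-5, training=True):
    batch_size, seq_len, embed_dim = x.shape
    _, num_head, head_dim, _ = qkv_weight.shape

    # qkv projection: view qkv_weight as [3*num_head*head_dim, embed_dim]
    w = paddle.reshape(qkv_weight, [3 * num_head * head_dim, embed_dim])
    qkv = paddle.matmul(x, w, transpose_y=True)            # [bs, seq, 3*n*h]
    if qkv_bias is not None:
        qkv = qkv + paddle.reshape(qkv_bias, [3 * num_head * head_dim])
    qkv = paddle.reshape(qkv, [batch_size, seq_len, 3, num_head, head_dim])
    qkv = paddle.transpose(qkv, perm=[2, 0, 3, 1, 4])      # [3, bs, n, seq, h]
    q, k, v = qkv[0], qkv[1], qkv[2]

    # scaled dot-product attention (the 1/sqrt(head_dim) scale is an assumption)
    attn = paddle.matmul(q, k, transpose_y=True) * (head_dim ** -0.5)
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = nnf.softmax(attn, axis=-1)
    attn = nnf.dropout(attn, attn_dropout_rate, training=training)

    out = paddle.matmul(attn, v)                           # [bs, n, seq, h]
    out = paddle.transpose(out, perm=[0, 2, 1, 3])         # [bs, seq, n, h]
    out = paddle.reshape(out, [batch_size, seq_len, embed_dim])

    # output projection, dropout, residual and layer norm (post_layer_norm path)
    out = paddle.matmul(out, linear_weight) + linear_bias
    out = x + nnf.dropout(out, dropout_rate, training=training)
    return nnf.layer_norm(out, [embed_dim], ln_scale, ln_bias, ln_epsilon)
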
Parameters
  • x (Tensor) – The input tensor of fused_multi_head_attention. The shape is [batch_size, sequence_len, embed_dim].

  • qkv_weight (Tensor) – The fused qkv weight tensor. The shape is [3, num_head, dim_head, embed_dim]. For one way to pack separate q/k/v projection weights into this layout, see the sketch after this parameter list.

  • linear_weight (Tensor) – The linear weight tensor. The shape is [embed_dim, embed_dim].

  • pre_layer_norm (bool, optional) – Whether to use the pre_layer_norm (True) or post_layer_norm (False) architecture. Default False.

  • pre_ln_scale (Tensor, optional) – The weight tensor of pre layernorm. Default None.

  • pre_ln_bias (Tensor, optional) – The bias tensor of pre layernorm. Default None.

  • ln_scale (Tensor, optional) – The weight tensor of layernorm. Default None.

  • ln_bias (Tensor, optional) – The bias tensor of layernorm. Default None.

  • pre_ln_epsilon (float, optional) – Small float value added to denominator of the pre layer_norm to avoid dividing by zero. Default is 1e-5.

  • qkv_bias (Tensor, optional) – The bias of qkv computation. The shape is [3, num_head, dim_head]. Default None.

  • linear_bias (Tensor, optional) – The bias of linear. The shape is [embed_dim]. Default None.

  • attn_mask (Tensor, optional) – A tensor used in multi-head attention to prevent attention to some unwanted positions, usually the paddings or the subsequent positions. Its shape must be broadcastable to [batch_size, n_head, sequence_length, sequence_length]. When the data type is bool, the unwanted positions have False values and the others have True values. When the data type is int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have -INF values and the others have 0 values. It can be None when no positions need to be masked. Default None. A padding-mask example is included in the sketch after this parameter list.

  • dropout_rate (float, optional) – The dropout probability of the dropout applied after the output linear projection (the dropout after attention in the pseudo code above). 0 for no dropout. Default 0.5.

  • attn_dropout_rate (float, optional) – The dropout probability applied to the attention weights (the softmax output) inside attention. 0 for no dropout. Default 0.5.

  • ln_epsilon (float, optional) – Small float value added to denominator of layer_norm to avoid dividing by zero. Default is 1e-5.

  • training (bool, optional) – A flag indicating whether it is in the training phase or not. Default True.

  • mode (str, optional) –

    ['upscale_in_train' (default) | 'downscale_in_infer']

    1. upscale_in_train (default): scale the output up at training time

      • train: out = input * mask / (1.0 - p)

      • inference: out = input

    2. downscale_in_infer: scale the output down at inference time

      • train: out = input * mask

      • inference: out = input * (1.0 - p)

  • name (str, optional) – Name for the operation (optional, default is None). For more information, please refer to Name.
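
The qkv_weight and qkv_bias layouts above differ from the usual [in_features, out_features] weights of paddle.nn.Linear. The sketch below shows one way they could be packed, together with a boolean padding mask; the variable names, and the assumption that the separate q/k/v weights follow nn.Linear's [in_features, out_features] layout, are illustrative rather than mandated by this API.

import paddle

embed_dim, num_head = 128, 4
head_dim = embed_dim // num_head

# Separate q/k/v projection weights, each [in_features, out_features].
q_w = paddle.rand([embed_dim, embed_dim])
k_w = paddle.rand([embed_dim, embed_dim])
v_w = paddle.rand([embed_dim, embed_dim])

# Stack to [3, embed_dim, embed_dim], split the output dim into heads, then
# move the input dim last: [3, num_head, dim_head, embed_dim].
qkv_weight = paddle.stack([q_w, k_w, v_w])
qkv_weight = paddle.reshape(qkv_weight, [3, embed_dim, num_head, head_dim])
qkv_weight = paddle.transpose(qkv_weight, perm=[0, 2, 3, 1])

# Matching biases, packed to [3, num_head, dim_head].
q_b, k_b, v_b = (paddle.zeros([embed_dim]) for _ in range(3))
qkv_bias = paddle.reshape(paddle.stack([q_b, k_b, v_b]), [3, num_head, head_dim])

# Boolean padding mask, broadcastable to
# [batch_size, num_head, seq_len, seq_len]: True = attend, False = masked.
seq_len = 4
valid_lens = paddle.to_tensor([4, 2])                          # per-sample lengths
positions = paddle.arange(seq_len)
attn_mask = positions.unsqueeze(0) < valid_lens.unsqueeze(1)   # [bs, seq_len]
attn_mask = attn_mask.unsqueeze([1, 2])                        # [bs, 1, 1, seq_len]
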

Returns

The output Tensor. Its data type and shape are the same as those of x.

Return type

Tensor

Examples

# required: gpu
import paddle
import paddle.incubate.nn.functional as F

# input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# qkv_weight: [3, num_head, head_dim, embed_dim]
qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
# qkv_bias: [3, num_head, head_dim]
qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
# linear_weight: [embed_dim, embed_dim]
linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
# linear_bias: [embed_dim]
linear_bias = paddle.rand(shape=[128], dtype="float32")
# self attention mask: [batch_size, num_heads, seq_len, seq_len]
attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32")

# output: [batch_size, seq_len, embed_dim]
output = F.fused_multi_head_attention(
    x, qkv_weight, linear_weight,
    pre_layer_norm=False,
    pre_ln_epsilon=1e-5,
    qkv_bias=qkv_bias,
    linear_bias=linear_bias,
    attn_mask=attn_mask)
# [2, 4, 128]
print(output.shape)
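
As an additional, hedged illustration (not part of the original example), the pre_layer_norm path can be exercised with keyword arguments. Based on the pseudo code above, it is assumed that only pre_ln_scale and pre_ln_bias are needed on this path; the tensors below are created purely for demonstration and reuse x, qkv_weight, etc. from the example above.

# pre layer norm parameters: [embed_dim]
pre_ln_scale = paddle.ones(shape=[128], dtype="float32")
pre_ln_bias = paddle.zeros(shape=[128], dtype="float32")

output_pre_ln = F.fused_multi_head_attention(
    x, qkv_weight, linear_weight,
    pre_layer_norm=True,
    pre_ln_scale=pre_ln_scale,
    pre_ln_bias=pre_ln_bias,
    pre_ln_epsilon=1e-5,
    qkv_bias=qkv_bias,
    linear_bias=linear_bias,
    attn_mask=attn_mask,
    dropout_rate=0.0,
    attn_dropout_rate=0.0)
# [2, 4, 128]
print(output_pre_ln.shape)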