fused_multi_head_attention

paddle.incubate.nn.functional.fused_multi_head_attention(x, qkv_weight, linear_weight, pre_layer_norm=False, pre_ln_scale=None, pre_ln_bias=None, ln_scale=None, ln_bias=None, pre_ln_epsilon=1e-05, qkv_bias=None, linear_bias=None, cache_kv=None, attn_mask=None, dropout_rate=0.5, attn_dropout_rate=0.5, ln_epsilon=1e-05, training=True, mode='upscale_in_train', ring_id=-1, add_residual=True, num_heads=-1, transpose_qkv_wb=False, name=None)

Attention maps queries and a set of key-value pairs to outputs. Multi-Head Attention performs several attention computations in parallel so that the model can jointly attend to information from different representation subspaces. This API supports only self-attention. The pseudo code is as follows:

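>>> residual = x
>>> if pre_layer_norm:
...     out = layer_norm(x)  # uses pre_ln_scale, pre_ln_bias, pre_ln_epsilon
... else:
...     out = x
>>> # compute q, k, v: [batch_size, seq_len, 3, num_heads, head_dim]
>>> out = matmul(out, qkv_weight) + qkv_bias
>>> # -> [3, batch_size, num_heads, seq_len, head_dim]
>>> out = transpose(out, perm=[2, 0, 3, 1, 4])
>>> # extract q, k and v from out; q is scaled by 1 / sqrt(head_dim)
>>> q = out[0] * (head_dim ** -0.5)
>>> k = out[1]
>>> v = out[2]
>>> out = matmul(q, k, transpose_y=True)
>>> out = out + attn_mask
>>> out = softmax(out)
>>> out = dropout(out)  # attn_dropout_rate
>>> out = matmul(out, v)
>>> # combine heads: [batch_size, num_heads, seq_len, head_dim] -> [batch_size, seq_len, embed_dim]
>>> out = transpose(out, perm=[0, 2, 1, 3])
>>> out = reshape(out, [batch_size, seq_len, embed_dim])
>>> # project to output
>>> out = linear(out)  # uses linear_weight, linear_bias
>>> if add_residual:
...     out = residual + dropout(out)  # dropout_rate
... else:
...     out = dropout(out)  # dropout_rate
>>> if not pre_layer_norm:
...     out = layer_norm(out)  # uses ln_scale, ln_bias, ln_epsilon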
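To make the pseudo code concrete, here is an unfused NumPy sketch of the same forward computation with both dropouts disabled (which matches training=False under the default upscale_in_train mode). It is not Paddle's implementation: the helpers reference_mha, layer_norm and softmax are hypothetical, the sketch covers only the transpose_qkv_wb=False weight layout, ignores cache_kv and ring_id, and assumes that qkv_weight is contracted over its last (embed_dim) axis and that attn_mask is already an additive float mask.

```python
import numpy as np


def layer_norm(x, scale=None, bias=None, epsilon=1e-5):
    # Normalize over the last (embed_dim) axis, then apply the optional affine terms.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    out = (x - mean) / np.sqrt(var + epsilon)
    if scale is not None:
        out = out * scale
    if bias is not None:
        out = out + bias
    return out


def softmax(x):
    # Subtract the row max for numerical stability; -inf entries become exactly 0.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def reference_mha(x, qkv_weight, linear_weight, qkv_bias=None, linear_bias=None,
                  attn_mask=None, pre_layer_norm=False, add_residual=True,
                  pre_ln_scale=None, pre_ln_bias=None, ln_scale=None, ln_bias=None,
                  pre_ln_epsilon=1e-5, ln_epsilon=1e-5):
    """Hypothetical unfused forward pass with dropout disabled.

    x:           [batch_size, seq_len, embed_dim]
    qkv_weight:  [3, num_heads, head_dim, embed_dim]  (transpose_qkv_wb=False layout)
    qkv_bias:    [3, num_heads, head_dim]
    attn_mask:   additive float mask broadcastable to [batch_size, num_heads, seq_len, seq_len]
    """
    head_dim = qkv_weight.shape[2]
    residual = x
    out = layer_norm(x, pre_ln_scale, pre_ln_bias, pre_ln_epsilon) if pre_layer_norm else x
    # Contract over embed_dim (assumption): [b, s, e] x [3, h, d, e] -> [b, s, 3, h, d]
    qkv = np.einsum("bse,thde->bsthd", out, qkv_weight)
    if qkv_bias is not None:
        qkv = qkv + qkv_bias
    # perm=[2, 0, 3, 1, 4]: [b, s, 3, h, d] -> [3, b, h, s, d]
    qkv = qkv.transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0] * head_dim ** -0.5, qkv[1], qkv[2]
    scores = q @ k.swapaxes(-1, -2)  # [b, h, s, s]
    if attn_mask is not None:
        scores = scores + attn_mask
    ctx = softmax(scores) @ v  # [b, h, s, d]; attention dropout skipped
    # Combine heads: [b, h, s, d] -> [b, s, h, d] -> [b, s, h * d]
    b, h, s, d = ctx.shape
    out = ctx.transpose(0, 2, 1, 3).reshape(b, s, h * d) @ linear_weight
    if linear_bias is not None:
        out = out + linear_bias
    if add_residual:
        out = residual + out  # output dropout skipped
    if not pre_layer_norm:
        out = layer_norm(out, ln_scale, ln_bias, ln_epsilon)
    return out


# Shapes match the Examples section; the random values are placeholders.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 128)).astype(np.float32)
qkv_weight = (0.05 * rng.standard_normal((3, 4, 32, 128))).astype(np.float32)
qkv_bias = np.zeros((3, 4, 32), dtype=np.float32)
linear_weight = (0.05 * rng.standard_normal((128, 128))).astype(np.float32)
linear_bias = np.zeros(128, dtype=np.float32)
# Additive causal mask: position i may attend only to positions j <= i.
causal_mask = np.triu(np.full((4, 4), -np.inf, dtype=np.float32), k=1)

out = reference_mha(x, qkv_weight, linear_weight, qkv_bias, linear_bias,
                    attn_mask=causal_mask)
print(out.shape)  # (2, 4, 128)
```

With pre_layer_norm=False and no ln_scale or ln_bias, every token of the result is normalized, so each row has roughly zero mean and unit variance; with the causal mask, changing the last token leaves the outputs at earlier positions unchanged.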
Parameters
  • x (Tensor) – The input tensor of fused_multi_head_attention. The shape is [batch_size, sequence_len, embed_dim].

  • qkv_weight (Tensor) – The qkv weight tensor. If transpose_qkv_wb is False, the shape is [3, num_heads, head_dim, embed_dim]. Otherwise, the shape is [embed_dim, 3 * embed_dim].

  • linear_weight (Tensor) – The linear weight tensor. The shape is [embed_dim, embed_dim].

  • pre_layer_norm (bool, optional) – Whether to use the pre-layer-norm architecture (True) or the post-layer-norm architecture (False). Default False.

  • pre_ln_scale (Tensor, optional) – The weight (scale) tensor of the pre layer_norm, which is applied to x when pre_layer_norm is True. Default None.

  • pre_ln_bias (Tensor, optional) – The bias tensor of the pre layer_norm, which is applied to x when pre_layer_norm is True. Default None.

  • ln_scale (Tensor, optional) – The weight (scale) tensor of the post layer_norm, which is applied to the output when pre_layer_norm is False. Default None.

  • ln_bias (Tensor, optional) – The bias tensor of the post layer_norm, which is applied to the output when pre_layer_norm is False. Default None.

  • pre_ln_epsilon (float, optional) – Small float value added to the denominator of the pre layer_norm to avoid division by zero. Default is 1e-5.

  • qkv_bias (Tensor, optional) – The bias of the qkv computation. If transpose_qkv_wb is False, the shape is [3, num_heads, head_dim]. Otherwise, the shape is [3 * embed_dim]. Default None.

  • linear_bias (Tensor, optional) – The bias of linear. The shape is [embed_dim]. Default None.

  • cache_kv (Tensor, optional) – The key/value cache for generation models. The shape is [2, batch_size, num_heads, seq_len, head_dim]. Default None.

  • attn_mask (Tensor, optional) – A tensor used in multi-head attention to prevent attention to unwanted positions, usually the paddings or the subsequent positions. Its shape must be broadcastable to [batch_size, num_heads, sequence_length, sequence_length]. When the data type is bool, the unwanted positions have False values and the others have True values. When the data type is int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have -INF values and the others have 0 values. Set it to None when no positions need to be masked. Default None.

  • dropout_rate (float, optional) – The dropout probability applied to the output of the linear projection, after attention and before the optional residual addition. 0 means no dropout. Default 0.5.

  • attn_dropout_rate (float, optional) – The dropout probability applied to the attention weights (the softmax output) inside attention. 0 means no dropout. Default 0.5.

  • ln_epsilon (float, optional) – Small float value added to the denominator of the post layer_norm to avoid division by zero. Default is 1e-5.

  • training (bool, optional) – A flag indicating whether it is in the training phase. Default True.

  • mode (str, optional) –

    [‘upscale_in_train’(default) | ‘downscale_in_infer’]

    1. upscale_in_train(default), upscale the output at training time

      • train: out = input * mask / ( 1.0 - p )

      • inference: out = input

    2. downscale_in_infer, downscale the output at inference

      • train: out = input * mask

      • inference: out = input * (1.0 - p)

  • ring_id (int, optional) – The ring ID used for the distributed forward pass under model parallelism (mp). Only NCCL and the forward pass are supported. Default is -1, which means model parallelism is not used.

  • add_residual (bool, optional) – Whether to add the residual connection at the end. Default is True.

  • num_heads (int, optional) – The number of attention heads. It must be provided when transpose_qkv_wb is True. Default is -1, which means qkv_weight and qkv_bias are not transposed.

  • transpose_qkv_wb (bool, optional) – Whether to transpose qkv_weight and qkv_bias inside the op. Only GPU is supported for now. Default is False, which means qkv_weight and qkv_bias are not transposed.

  • name (str, optional) – For details, please refer to Name. Generally, no setting is required. Default: None.
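The attn_mask parameter accepts three encodings. To illustrate how they relate, the hypothetical helper to_additive_mask below (not part of Paddle, and not a description of how the op handles masks internally) converts the bool and int encodings into the float encoding, which is the form the pseudo code adds directly to the attention scores.

```python
import numpy as np


def to_additive_mask(mask, dtype=np.float32):
    """Hypothetical helper: map an attn_mask to the additive float encoding.

    bool:  True = attend, False = blocked
    int:   1 = attend, 0 = blocked
    float: assumed to already be additive (0 = attend, -inf = blocked)
    """
    mask = np.asarray(mask)
    if mask.dtype == np.bool_ or np.issubdtype(mask.dtype, np.integer):
        return np.where(mask.astype(bool), 0.0, -np.inf).astype(dtype)
    return mask.astype(dtype)


# Causal (subsequent-position) mask for seq_len = 4 in the bool encoding:
# position i may attend to positions j <= i.
seq_len = 4
causal_bool = np.tril(np.ones((seq_len, seq_len), dtype=bool))
causal_float = to_additive_mask(causal_bool)
# Leading size-1 axes broadcast to [batch_size, num_heads, seq_len, seq_len].
causal_float = causal_float[None, None]
print(causal_float.shape)  # (1, 1, 4, 4)
```

In the resulting float mask, row i holds 0 for columns j <= i and -inf for the later columns, so softmax assigns those later positions zero weight.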
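The two dropout modes listed under mode differ only in where the 1 - p scaling is applied, and both keep the expected output the same between training and inference. The minimal NumPy sketch below implements those formulas with a Bernoulli keep-mask; dropout_ref is a hypothetical helper for illustration, not a Paddle API.

```python
import numpy as np


def dropout_ref(x, p, training, mode="upscale_in_train", rng=None):
    """Hypothetical reference for the two `mode` formulas; not a Paddle API."""
    if mode not in ("upscale_in_train", "downscale_in_infer"):
        raise ValueError(f"unknown mode: {mode}")
    if not training:
        # upscale_in_train: out = input; downscale_in_infer: out = input * (1 - p)
        return x if mode == "upscale_in_train" else x * (1.0 - p)
    if rng is None:
        rng = np.random.default_rng()
    # Bernoulli keep-mask: each element survives with probability 1 - p.
    mask = (rng.random(x.shape) >= p).astype(x.dtype)
    if mode == "downscale_in_infer":
        return x * mask
    # upscale_in_train: out = input * mask / (1 - p); p == 1 drops everything.
    return x * mask / (1.0 - p) if p < 1.0 else np.zeros_like(x)


rng = np.random.default_rng(0)
x = np.ones(100_000)
for mode in ("upscale_in_train", "downscale_in_infer"):
    train_mean = dropout_ref(x, 0.5, True, mode, rng).mean()
    infer_mean = dropout_ref(x, 0.5, False, mode).mean()
    # In both modes the two means agree up to sampling noise.
    print(mode, round(float(train_mean), 3), float(infer_mean))
```

upscale_in_train leaves inference untouched, which is why it is the usual default: the deployed model needs no extra scaling step.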

Returns

The output tensor, with the same data type and shape as x.

Return type

Tensor

Examples

>>> import paddle
>>> paddle.device.set_device('gpu')
>>> import paddle.incubate.nn.functional as F

>>> # input: [batch_size, seq_len, embed_dim]
>>> x = paddle.rand(shape=(2, 4, 128), dtype="float32")
>>> # qkv_weight: [3, num_head, head_dim, embed_dim]
>>> qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
>>> # qkv_bias: [3, num_head, head_dim]
>>> qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
>>> # linear_weight: [embed_dim, embed_dim]
>>> linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
>>> # linear_bias: [embed_dim]
>>> linear_bias = paddle.rand(shape=[128], dtype="float32")
>>> # self attention mask: [batch_size, num_heads, seq_len, seq_len]
>>> attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32")

>>> # output: [batch_size, seq_len, embed_dim]
>>> output = F.fused_multi_head_attention(
...     x, qkv_weight, linear_weight,
...     pre_layer_norm=False,
...     pre_ln_epsilon=1e-5,
...     qkv_bias=qkv_bias,
...     linear_bias=linear_bias,
...     attn_mask=attn_mask)
>>> print(output.shape)
[2, 4, 128]