fused_multi_head_attention
- paddle.incubate.nn.functional.fused_multi_head_attention ( x, qkv_weight, linear_weight, pre_layer_norm=False, pre_ln_scale=None, pre_ln_bias=None, ln_scale=None, ln_bias=None, pre_ln_epsilon=1e-05, qkv_bias=None, linear_bias=None, cache_kv=None, attn_mask=None, dropout_rate=0.5, attn_dropout_rate=0.5, ln_epsilon=1e-05, training=True, mode='upscale_in_train', ring_id=-1, add_residual=True, num_heads=-1, transpose_qkv_wb=False, name=None )
Attention maps a query and a set of key-value pairs to an output, and multi-head attention runs several attention computations in parallel to jointly attend to information from different representation subspaces. This API only supports self-attention. The pseudo code is as follows:
residual = x
if pre_layer_norm:
    out = layer_norm(x)
else:
    out = x
# compute q, k, v
out = matmul(out, qkv_weight) + qkv_bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out
q = out[0:1, ::] * (head_dim ** -0.5)
k = out[1:2, ::]
v = out[2:3, ::]
out = matmul(q, k, transpose_y=True)
out = out + attn_mask
out = softmax(out)
out = dropout(out)
out = matmul(out, v)
# combine heads
out = transpose(out, perm=[0, 2, 1, 3])
# project to output
out = linear(out)
if add_residual:
    out = residual + dropout(out)
else:
    out = dropout(out)
if not pre_layer_norm:
    out = layer_norm(out)
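To make the data flow above concrete, the following is a minimal unfused sketch of the same computation written with plain Paddle ops, assuming pre_layer_norm=False, add_residual=True, and no attn_mask or dropout for brevity; it mirrors the [3, num_head, dim_head, dim_embed] qkv_weight layout described in the parameter list below and is an illustration, not the fused kernel itself.

import paddle
import paddle.nn.functional as F

batch, seq_len, embed_dim, num_head = 2, 4, 128, 4
head_dim = embed_dim // num_head

x = paddle.rand((batch, seq_len, embed_dim))
qkv_weight = paddle.rand((3, num_head, head_dim, embed_dim))
qkv_bias = paddle.rand((3, num_head, head_dim))
linear_weight = paddle.rand((embed_dim, embed_dim))
linear_bias = paddle.rand((embed_dim,))

residual = x

# q, k, v projection: x [b, s, e] with qkv_weight [3, h, d, e] -> qkv [3, b, h, s, d]
w = qkv_weight.reshape([3 * num_head * head_dim, embed_dim])
qkv = paddle.matmul(x, w, transpose_y=True) + qkv_bias.reshape([-1])
qkv = qkv.reshape([batch, seq_len, 3, num_head, head_dim])
qkv = paddle.transpose(qkv, perm=[2, 0, 3, 1, 4])
q, k, v = qkv[0] * head_dim ** -0.5, qkv[1], qkv[2]

# scaled dot-product attention (attn_mask and dropout omitted here)
scores = paddle.matmul(q, k, transpose_y=True)   # [b, h, s, s]
probs = F.softmax(scores, axis=-1)
ctx = paddle.matmul(probs, v)                    # [b, h, s, d]

# combine heads and apply the output projection
ctx = paddle.transpose(ctx, perm=[0, 2, 1, 3]).reshape([batch, seq_len, embed_dim])
out = paddle.matmul(ctx, linear_weight) + linear_bias

# post layer_norm path (pre_layer_norm=False) with the residual connection
out = F.layer_norm(residual + out, normalized_shape=[embed_dim])
print(out.shape)  # [2, 4, 128]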
- Parameters
x (Tensor) – The input tensor of fused_multi_head_attention. The shape is [batch_size, sequence_len, embed_dim].
qkv_weight (Tensor) – The qkv weight tensor. If transpose_qkv_wb is False, the shape is [3, num_head, dim_head, dim_embed]. Otherwise, the shape is [dim_embed, 3 * dim_embed].
linear_weight (Tensor) – The linear weight tensor. The shape is [embed_dim, embed_dim].
pre_layer_norm (bool, optional) – Whether it is a pre_layer_norm (True) or post_layer_norm (False) architecture. Default False.
pre_ln_scale (Tensor, optional) – The weight tensor of pre layernorm. Default None.
pre_ln_bias (Tensor, optional) – The bias tensor of pre layernorm. Default None.
ln_scale (Tensor, optional) – The weight tensor of layernorm. Default None.
ln_bias (Tensor, optional) – The bias tensor of layernorm. Default None.
pre_ln_epsilon (float, optional) – Small float value added to denominator of the pre layer_norm to avoid dividing by zero. Default is 1e-5.
qkv_bias (Tensor, optional) – The bias of qkv computation. If transpose_qkv_wb is False, the shape is [3, num_head, dim_head]. Otherwise, the shape is [3 * dim_embed]. Default None.
linear_bias (Tensor, optional) – The bias of linear. The shape is [embed_dim]. Default None.
cache_kv (Tensor, optional) – The cache structure tensor used by the generation model. The shape is [2, bsz, num_head, seq_len, head_dim]. Default None.
attn_mask (Tensor, optional) – A tensor used in multi-head attention to prevent attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with a shape broadcastable to [batch_size, n_head, sequence_length, sequence_length]. When the data type is bool, the unwanted positions have False values and the others have True values. When the data type is int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have -INF values and the others have 0 values. It can be None when no positions need to be prevented from being attended to; see the mask-construction sketch after this parameter list. Default None.
dropout_rate (float, optional) – The dropout probability for the dropout applied after the output projection (the dropout after attention in the pseudo code above). 0 for no dropout. Default 0.5.
attn_dropout_rate (float, optional) – The dropout probability used on attention weights to drop some attention targets for the dropout in attention. 0 for no dropout. Default 0.5.
ln_epsilon (float, optional) – Small float value added to denominator of layer_norm to avoid dividing by zero. Default is 1e-5.
training (bool, optional) – A flag indicating whether it is in the training phase or not. Default True.
mode (str, optional) – ['upscale_in_train' (default) | 'downscale_in_infer']
upscale_in_train (default): upscale the output at training time
    train: out = input * mask / (1.0 - p)
    inference: out = input
downscale_in_infer: downscale the output at inference time
    train: out = input * mask
    inference: out = input * (1.0 - p)
ring_id (int, optional) – The ring id used for the distributed forward pass in model parallel; only NCCL and the forward pass are supported. Default is -1, which means model parallel is not used.
add_residual (bool, optional) – Whether to add the residual connection at the end. Default is True.
num_heads (int, optional) – The number of attention heads; it must be provided when transpose_qkv_wb is enabled. Default is -1, which means transpose_qkv_wb is not used.
transpose_qkv_wb (bool, optional) – Whether to transpose the qkv_weight and qkv_bias inside the op. Only GPU is supported for now. Default is False, which means qkv_weight and qkv_bias are not transposed.
name (str, optional) – For details, please refer to Name. Generally, no setting is required. Default: None.
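As an illustration of the float attn_mask convention described above (0 at positions that may be attended to, -INF at positions that must be masked), the following is a minimal sketch that builds an additive padding mask from assumed per-sample valid lengths; seq_lens and the chosen sizes are illustrative only.

import paddle

batch_size, seq_len = 2, 4
# assumed per-sample valid lengths; key positions beyond them are padding
seq_lens = paddle.to_tensor([4, 2])

# True for valid key positions, False for padding: [batch_size, seq_len]
valid = paddle.arange(seq_len).unsqueeze(0) < seq_lens.unsqueeze(1)

# additive float mask: 0.0 where attention is allowed, -INF where it is masked out;
# shape [batch_size, 1, 1, seq_len] broadcasts to [batch_size, n_head, seq_len, seq_len]
zeros = paddle.zeros([batch_size, seq_len], dtype="float32")
attn_mask = paddle.where(valid, zeros, paddle.full_like(zeros, float("-inf")))
attn_mask = attn_mask.reshape([batch_size, 1, 1, seq_len])
print(attn_mask.shape)  # [2, 1, 1, 4]

A mask built this way can be passed as the attn_mask argument; make sure every query row keeps at least one unmasked key position, since a row of all -INF values makes the softmax produce NaN.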
- Returns
The output Tensor, whose data type and shape are the same as x.
- Return type
Tensor
Examples
# required: gpu
import paddle
import paddle.incubate.nn.functional as F

# input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# qkv_weight: [3, num_head, head_dim, embed_dim]
qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
# qkv_bias: [3, num_head, head_dim]
qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
# linear_weight: [embed_dim, embed_dim]
linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
# linear_bias: [embed_dim]
linear_bias = paddle.rand(shape=[128], dtype="float32")
# self attention mask: [batch_size, num_heads, seq_len, seq_len]
attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32")
# output: [batch_size, seq_len, embed_dim]
output = F.fused_multi_head_attention(
    x, qkv_weight, linear_weight, False,
    None, None, None, None, 1e-5,
    qkv_bias, linear_bias, None, attn_mask)
# [2, 4, 128]
print(output.shape)
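The example above uses the [3, num_head, head_dim, embed_dim] qkv layout with transpose_qkv_wb left at its default. The following is a minimal sketch of the transpose_qkv_wb=True variant based on the parameter descriptions above, where qkv_weight is the fused [embed_dim, 3 * embed_dim] matrix, qkv_bias has shape [3 * embed_dim], and num_heads must be passed explicitly; the shapes and values are illustrative, and a GPU device is assumed since transpose_qkv_wb currently only supports GPU.

# required: gpu
import paddle
import paddle.incubate.nn.functional as F

# input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# qkv_weight in the transposed layout: [embed_dim, 3 * embed_dim]
qkv_weight = paddle.rand(shape=(128, 3 * 128), dtype="float32")
# qkv_bias in the transposed layout: [3 * embed_dim]
qkv_bias = paddle.rand(shape=(3 * 128,), dtype="float32")
linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
linear_bias = paddle.rand(shape=(128,), dtype="float32")

# output: [batch_size, seq_len, embed_dim]
output = F.fused_multi_head_attention(
    x, qkv_weight, linear_weight,
    qkv_bias=qkv_bias, linear_bias=linear_bias,
    num_heads=4, transpose_qkv_wb=True)
# [2, 4, 128]
print(output.shape)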