Transformer
class paddle.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation='relu', attn_dropout=None, act_dropout=None, normalize_before=False, weight_attr=None, bias_attr=None, custom_encoder=None, custom_decoder=None)
A Transformer model composed of an instance of TransformerEncoder and an instance of TransformerDecoder. The embedding layer and the output layer are not included.
Please refer to Attention Is All You Need, and see TransformerEncoder and TransformerDecoder for more details.
Users can configure the model architecture with the corresponding parameters. Note the use of normalize_before, which determines where layer normalization is applied (in the pre-process or the post-process of multi-head attention and FFN). Transformer-like models differ on this: BERT, for example, normalizes in the post-process, while GPT2 normalizes in the pre-process. The default architecture here places layer normalization in the post-process and applies another layer normalization to the output of the last encoder/decoder layer.
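The difference between the two normalize_before settings is easiest to see on a single residual sub-layer. The following is a minimal NumPy sketch, not Paddle's implementation: it assumes dropout is disabled (inference), omits the learnable scale and shift of layer normalization, and uses a hypothetical `toy_sublayer` (one linear map) in place of MHA or FFN.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the last (feature) axis; learnable scale and shift are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def post_norm_sublayer(x, sublayer):
    # normalize_before=False: sublayer -> residual add -> layer normalization
    return layer_norm(x + sublayer(x))


def pre_norm_sublayer(x, sublayer):
    # normalize_before=True: layer normalization -> sublayer -> residual add
    return x + sublayer(layer_norm(x))


rng = np.random.default_rng(0)
d_model = 8
w = rng.standard_normal((d_model, d_model)) * 0.1


def toy_sublayer(h):
    # Hypothetical stand-in for an MHA or FFN sub-layer: a single linear map.
    return h @ w


x = rng.standard_normal((2, 4, d_model))  # [batch_size, seq_len, d_model]
post = post_norm_sublayer(x, toy_sublayer)
pre = pre_norm_sublayer(x, toy_sublayer)

# Post-norm outputs have (near) zero mean per position; pre-norm outputs do not,
# because the residual stream itself is never normalized in the pre-norm variant.
print(np.abs(post.mean(axis=-1)).max(), np.abs(pre.mean(axis=-1)).max())
```

Because the pre-norm variant leaves the residual stream unnormalized, a final layer normalization on the stack output is a common companion to normalize_before=True.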
- Parameters
d_model (int, optional) – The expected feature size of the encoder/decoder input and output. Default: 512.
nhead (int, optional) – The number of heads in multi-head attention (MHA). Default: 8.
num_encoder_layers (int, optional) – The number of layers in the encoder. Default: 6.
num_decoder_layers (int, optional) – The number of layers in the decoder. Default: 6.
dim_feedforward (int, optional) – The hidden layer size of the feedforward network (FFN). Default: 2048.
dropout (float, optional) – The dropout probability used in the pre-process and post-process of the MHA and FFN sub-layers. Default: 0.1.
activation (str, optional) – The activation function of the feedforward network. Default: relu.
attn_dropout (float, optional) – The dropout probability applied inside MHA to drop some attention targets. If None, the value of dropout is used. Default: None.
act_dropout (float, optional) – The dropout probability applied after the FFN activation. If None, the value of dropout is used. Default: None.
normalize_before (bool, optional) – Whether to apply layer normalization in the pre-process of the MHA and FFN sub-layers. If True, the pre-process is layer normalization and the post-process consists of dropout and the residual connection. Otherwise, there is no pre-process and the post-process consists of dropout, the residual connection, and layer normalization. Default: False.
weight_attr (ParamAttr|list|tuple, optional) – The weight parameter property. If it is a list or tuple, its length can be 1, 2, or 3. With 3 entries, weight_attr[0] is used for self attention, weight_attr[1] for the cross attention of TransformerDecoder, and weight_attr[2] for the linear layers in the FFN. With 2 entries, weight_attr[0] is used for both self attention and cross attention, and weight_attr[1] for the linear layers in the FFN. With 1 entry, weight_attr[0] is used for self attention, cross attention, and the linear layers in the FFN. Otherwise, all three sub-layers use weight_attr itself to create their parameters. Default: None, which means the default weight parameter property is used. See ParamAttr for usage details.
bias_attr (ParamAttr|list|tuple|bool, optional) – The bias parameter property. If it is a list or tuple, its length can be 1, 2, or 3. With 3 entries, bias_attr[0] is used for self attention, bias_attr[1] for the cross attention of TransformerDecoder, and bias_attr[2] for the linear layers in the FFN. With 2 entries, bias_attr[0] is used for both self attention and cross attention, and bias_attr[1] for the linear layers in the FFN. With 1 entry, bias_attr[0] is used for self attention, cross attention, and the linear layers in the FFN. Otherwise, all three sub-layers use bias_attr itself to create their parameters. The value False means the corresponding layer has no trainable bias parameter. Default: None, which means the default bias parameter property is used. See ParamAttr for usage details.
custom_encoder (Layer, optional) – If provided, it is used as the encoder. Default: None.
custom_decoder (Layer, optional) – If provided, it is used as the decoder. Default: None.
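The fallback and distribution rules for attn_dropout, act_dropout, weight_attr, and bias_attr can be restated as two small functions. This is a hedged, pure-Python sketch of the documented rules only: `resolve_dropouts` and `expand_param_attr` are hypothetical helpers, not part of Paddle, and plain strings stand in for ParamAttr objects.

```python
from typing import Any, List, Optional, Tuple


def resolve_dropouts(dropout: float,
                     attn_dropout: Optional[float] = None,
                     act_dropout: Optional[float] = None) -> Tuple[float, float, float]:
    """Hypothetical helper: return (dropout, attn_dropout, act_dropout) after fallback."""
    # attn_dropout and act_dropout fall back to `dropout` when left as None.
    attn = dropout if attn_dropout is None else attn_dropout
    act = dropout if act_dropout is None else act_dropout
    return dropout, attn, act


def expand_param_attr(attr: Any) -> List[Any]:
    """Hypothetical helper: return [self_attn_attr, cross_attn_attr, ffn_attr]."""
    if isinstance(attr, (list, tuple)):
        if len(attr) == 3:
            return [attr[0], attr[1], attr[2]]
        if len(attr) == 2:
            # attr[0] is shared by self and cross attention; attr[1] goes to the FFN.
            return [attr[0], attr[0], attr[1]]
        if len(attr) == 1:
            return [attr[0]] * 3
        # Other lengths are not described, so this sketch rejects them.
        raise ValueError("weight_attr/bias_attr lists must have length 1, 2 or 3")
    # A single ParamAttr, a bool (bias_attr only), or None applies to all three.
    return [attr] * 3


print(expand_param_attr(["w_attn", "w_ffn"]))
print(resolve_dropouts(0.1, attn_dropout=0.0))
```

Encoder layers have no cross attention, so the middle entry is only meaningful for the decoder layers of TransformerDecoder.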
Examples
>>> import paddle
>>> from paddle.nn import Transformer
>>> # src: [batch_size, src_len, d_model]
>>> enc_input = paddle.rand((2, 4, 128))
>>> # tgt: [batch_size, tgt_len, d_model]
>>> dec_input = paddle.rand((2, 6, 128))
>>> # src_mask: [batch_size, n_head, src_len, src_len]
>>> enc_self_attn_mask = paddle.rand((2, 2, 4, 4))
>>> # tgt_mask: [batch_size, n_head, tgt_len, tgt_len]
>>> dec_self_attn_mask = paddle.rand((2, 2, 6, 6))
>>> # memory_mask: [batch_size, n_head, tgt_len, src_len]
>>> cross_attn_mask = paddle.rand((2, 2, 6, 4))
>>> # d_model=128, nhead=2, 4 encoder layers, 4 decoder layers, dim_feedforward=512
>>> transformer = Transformer(128, 2, 4, 4, 512)
>>> output = transformer(enc_input,
...                      dec_input,
...                      enc_self_attn_mask,
...                      dec_self_attn_mask,
...                      cross_attn_mask)
>>> print(output.shape)
[2, 6, 128]
forward(src, tgt, src_mask=None, tgt_mask=None, memory_mask=None)
Applies a Transformer model on the inputs.
- Parameters
src (Tensor) – The input of Transformer encoder. It is a tensor with shape [batch_size, source_length, d_model]. The data type should be float32 or float64.
tgt (Tensor) – The input of Transformer decoder. It is a tensor with shape [batch_size, target_length, d_model]. The data type should be float32 or float64.
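memory (Tensor) – Not an argument of forward. It is the output of the Transformer encoder, computed internally from src and src_mask and passed to the decoder's cross attention. It is a tensor with shape [batch_size, source_length, d_model]. The data type is float32 or float64.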
src_mask (Tensor, optional) – A tensor used in the encoder's multi-head self attention to prevent attention to unwanted positions, usually padding or subsequent positions. Its shape must be broadcastable to [batch_size, n_head, source_length, source_length]. When the data type is bool, unwanted positions are False and the others are True. When the data type is int, unwanted positions are 0 and the others are 1. When the data type is float, unwanted positions are -INF and the others are 0. It can be None when no positions need to be masked. Default: None.
tgt_mask (Tensor, optional) – A tensor used in the decoder's self attention to prevent attention to unwanted positions, usually subsequent positions. Its shape must be broadcastable to [batch_size, n_head, target_length, target_length]. When the data type is bool, unwanted positions are False and the others are True. When the data type is int, unwanted positions are 0 and the others are 1. When the data type is float, unwanted positions are -INF and the others are 0. It can be None when no positions need to be masked. Default: None.
memory_mask (Tensor, optional) – A tensor used in the decoder-encoder cross attention to prevent attention to unwanted positions, usually padding. Its shape must be broadcastable to [batch_size, n_head, target_length, source_length]. When the data type is bool, unwanted positions are False and the others are True. When the data type is int, unwanted positions are 0 and the others are 1. When the data type is float, unwanted positions are -INF and the others are 0. It can be None when no positions need to be masked. Default: None.
- Returns
A tensor with the same shape and data type as tgt, representing the output of the Transformer decoder.
- Return type
Tensor
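The bool, int, and float mask conventions described for src_mask, tgt_mask, and memory_mask are interchangeable ways of marking blocked positions. The following is a minimal NumPy sketch of those documented semantics, not Paddle's internal implementation: `to_additive_mask` and `masked_softmax` are hypothetical helpers, and the example builds a padding memory_mask of shape [batch_size, 1, 1, source_length] that broadcasts over heads and target positions.

```python
import numpy as np


def to_additive_mask(mask):
    """Hypothetical helper: convert a bool/int/float mask to the float (additive) form."""
    mask = np.asarray(mask)
    if mask.dtype == np.bool_:
        # True = may attend, False = blocked
        return np.where(mask, 0.0, -np.inf)
    if np.issubdtype(mask.dtype, np.integer):
        # 1 = may attend, 0 = blocked
        return np.where(mask != 0, 0.0, -np.inf)
    # Float masks are already additive: 0 = may attend, -inf = blocked.
    return mask.astype(np.float64)


def masked_softmax(scores, mask):
    # scores: [batch_size, n_head, query_len, key_len]; mask must broadcast to it.
    # Every query row needs at least one unblocked key, otherwise the row is NaN.
    logits = scores + to_additive_mask(mask)
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


batch_size, n_head, tgt_len, src_len = 2, 2, 3, 4
src_lengths = np.array([4, 2])  # the second source sequence ends in 2 padding positions
# bool padding mask of shape [batch_size, 1, 1, src_len]
memory_mask = (np.arange(src_len)[None, :] < src_lengths[:, None])[:, None, None, :]

scores = np.random.default_rng(0).standard_normal((batch_size, n_head, tgt_len, src_len))
weights = masked_softmax(scores, memory_mask)

# The int and float encodings of the same mask give identical attention weights.
int_weights = masked_softmax(scores, memory_mask.astype(np.int64))
float_weights = masked_softmax(scores, np.where(memory_mask, 0.0, -np.inf))
```

Padded source positions receive exactly zero attention weight in every head and for every target position, whichever encoding is used.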
generate_square_subsequent_mask(length)
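Generates a square causal mask for a sequence. Entries on and below the diagonal are 0 and entries above it are -inf, so position i can attend only to positions 0 through i. With decoder inputs shifted right by one position, this ensures that the prediction for position i depends only on the known outputs at positions less than i.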
- Parameters
length (int|Tensor) – The length of the sequence.
- Returns
Tensor, the generated square mask for the given length, with shape [length, length].
Examples
>>> import paddle
>>> from paddle.nn.layer.transformer import Transformer
>>> length = 5
>>> d_model, n_head, dim_feedforward = 8, 4, 64
>>> transformer_paddle = Transformer(
...     d_model, n_head, dim_feedforward=dim_feedforward)
>>> mask = transformer_paddle.generate_square_subsequent_mask(length)
>>> print(mask)
Tensor(shape=[5, 5], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[ 0. , -inf., -inf., -inf., -inf.],
        [ 0. ,  0. , -inf., -inf., -inf.],
        [ 0. ,  0. ,  0. , -inf., -inf.],
        [ 0. ,  0. ,  0. ,  0. , -inf.],
        [ 0. ,  0. ,  0. ,  0. ,  0. ]])
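The pattern printed above (0 on and below the diagonal, -inf strictly above it) can be reproduced with a short NumPy sketch. `square_subsequent_mask` is a hypothetical stand-in for illustration only; unlike the Paddle method, it accepts only an int length and returns a NumPy array rather than a Paddle Tensor.

```python
import numpy as np


def square_subsequent_mask(length: int) -> np.ndarray:
    """Hypothetical NumPy counterpart of generate_square_subsequent_mask (int length only)."""
    # Start from an all -inf matrix; np.triu with k=1 keeps the entries strictly
    # above the diagonal (future positions) and sets the diagonal and below to 0.
    full = np.full((length, length), -np.inf, dtype=np.float32)
    return np.triu(full, k=1)


mask = square_subsequent_mask(5)
print(mask)
```

Because its shape [length, length] broadcasts to [batch_size, n_head, target_length, target_length] when length equals target_length, a mask of this form can serve as a float tgt_mask.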