blha_get_max_len

paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size) [source]

Apply the fused BlhaGetMaxLen kernel. It is typically used before the block_multihead_attention operator.

Parameters
  • seq_lens_encoder (Tensor) – Sequence lengths of the encoder, one entry per batch item.

  • seq_lens_decoder (Tensor) – Sequence lengths of the decoder, one entry per batch item.

  • batch_size (Tensor) – The batch size.

Returns

A tuple of two Tensors, (max_enc_len_this_time, max_dec_len_this_time): the maximum encoder and decoder sequence lengths for the current step.
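Conceptually, the fused kernel reduces each length tensor to its batch-wide maximum. A minimal pure-Python sketch of this assumed behavior (the function name blha_get_max_len_reference is hypothetical, for illustration only; the real op runs fused on GPU):

```python
def blha_get_max_len_reference(seq_lens_encoder, seq_lens_decoder):
    # Hypothetical reference semantics: reduce each per-batch length
    # list to its maximum for the current step.
    max_enc_len_this_time = max(seq_lens_encoder)
    max_dec_len_this_time = max(seq_lens_decoder)
    return max_enc_len_this_time, max_dec_len_this_time

# Example: batch of 3 sequences in each phase.
print(blha_get_max_len_reference([3, 7, 5], [1, 2, 9]))  # (7, 9)
```

The two maxima let the subsequent block_multihead_attention pass size its work buffers for the longest sequence in the batch rather than per-sequence.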

Examples

>>> import paddle
>>> paddle.device.set_device('gpu')

>>> seq_lens_encoder = paddle.randint(low=1, high=128, shape=[10], dtype='int32')
>>> seq_lens_decoder = paddle.randint(low=1, high=128, shape=[10], dtype='int32')
>>> bsz = 10
>>> batch_size = paddle.ones(shape=[bsz])
>>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size)