blha_get_max_len
paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size)
Apply the fused BlhaGetMaxLen kernel. It is typically used before the block_multihead_attention operator.
Parameters
seq_lens_encoder (Tensor) – Sequence lengths of the encoder.
seq_lens_decoder (Tensor) – Sequence lengths of the decoder.
batch_size (Tensor) – The batch size.
Returns
A tuple of two Tensors: (max_enc_len_this_time, max_dec_len_this_time).
Examples
>>> import paddle
>>> paddle.device.set_device('gpu')
>>> seq_lens_encoder = paddle.cast(paddle.randn(shape=[10]), dtype=paddle.int32)
>>> seq_lens_decoder = paddle.cast(paddle.randn(shape=[10]), dtype=paddle.int32)
>>> bsz = 10
>>> batch_size = paddle.ones(shape=[bsz])
>>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size)
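For intuition, the following is a minimal reference sketch. It assumes the fused kernel reduces each length tensor to its batch-wide maximum; this equivalence is an illustrative assumption, not part of the documented contract, and the fused operator may additionally place the results where the attention scheduling expects them.

>>> import paddle
>>> # Assumption: the outputs correspond to the per-batch maxima of the two
>>> # length tensors, so plain reductions give comparable reference values.
>>> seq_lens_encoder = paddle.to_tensor([3, 7, 5], dtype='int32')
>>> seq_lens_decoder = paddle.to_tensor([1, 1, 2], dtype='int32')
>>> max_enc_len_ref = paddle.max(seq_lens_encoder)  # Tensor holding 7
>>> max_dec_len_ref = paddle.max(seq_lens_decoder)  # Tensor holding 2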