flash_attention_with_sparse_mask

paddle.nn.functional.flash_attention_with_sparse_mask(query, key, value, attn_mask_start_row_indices, attn_mask_start_row=0, dropout_p=0.0, is_causal=False, return_softmax=False, return_softmax_lse=False, return_seed_offset=False, training=True, name=None) [source]

The equation is:

\[result = softmax\left(\frac{QK^T}{\sqrt{d}}\right)V\]

where Q, K, and V are the three input tensors of the attention module, all with the same shape, and d is the size of their last dimension.

Warning

This API only supports inputs with dtype float16 and bfloat16.
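
For reference, the equation above corresponds to the plain dense attention computation sketched below in NumPy. This is only an illustrative re-statement of the formula using the [batch_size, seq_len, num_heads, head_dim] layout documented for this API; it is not the fused FlashAttention kernel and omits dropout and masking.

import numpy as np

def naive_attention(q, k, v):
    # q, k, v: [batch_size, seq_len, num_heads, head_dim]
    d = q.shape[-1]
    # Move heads before the sequence dimension: [batch_size, num_heads, seq_len, head_dim]
    q, k, v = (np.transpose(x, (0, 2, 1, 3)) for x in (q, k, v))
    # Scaled dot-product scores: [batch_size, num_heads, seq_len, seq_len]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    # Row-wise softmax (numerically stabilized)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v
    # Back to [batch_size, seq_len, num_heads, head_dim]
    return np.transpose(out, (0, 2, 1, 3))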

Parameters
  • query (Tensor) – The query tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

  • key (Tensor) – The key tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

  • value (Tensor) – The value tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

  • attn_mask_start_row_indices (Tensor) – A sparse attention mask index tensor with shape [batch_size, num_heads, seq_len]. Each element gives the row index at which the mask starts in the score matrix (see the sketch after this parameter list). The dtype must be int32.

  • attn_mask_start_row (int, optional) – When attn_mask_start_row_indices is passed and its minimum row index is known to be greater than 0, setting attn_mask_start_row accordingly can improve performance. The default value is 0.

  • dropout_p (float) – The dropout ratio.

  • is_causal (bool) – Whether to enable causal mode.

  • training (bool) – Whether it is in the training phase.

  • name (str, optional) – The default value is None. Normally there is no need for user to set this property. For more information, please refer to Name.
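
To make the mask layout concrete, the sketch below expands attn_mask_start_row_indices into the dense boolean mask it is assumed to describe: for each column j of the score matrix, rows at or below start_row_indices[b, h, j] are treated as masked. This reading of the index semantics follows the parameter description above and is an assumption, not an exact reproduction of the kernel's behavior; dense_mask_from_start_rows is a hypothetical helper, not part of the Paddle API.

import numpy as np

def dense_mask_from_start_rows(start_row_indices):
    # start_row_indices: [batch_size, num_heads, seq_len] (int32);
    # entry [b, h, j] is the row at which masking is assumed to start for column j.
    bz, num_heads, seq_len = start_row_indices.shape
    rows = np.arange(seq_len).reshape(1, 1, seq_len, 1)    # row index i
    start = start_row_indices[:, :, None, :]               # broadcast over rows
    # masked[b, h, i, j] is True where score (i, j) is assumed to be excluded,
    # i.e. i >= start_row_indices[b, h, j].
    return rows >= start

Under this interpretation, an entry of seq_len or larger (such as the rows + 1 placeholder used in the Examples below) adds no extra masking for that column beyond the causal constraint.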

Returns

out (Tensor), the attention output tensor.

4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.

Examples

>>> import paddle
>>> import numpy as np
>>> # Randomly pick, for each (batch, head, column), the row at which masking starts.
>>> def generate_start_rows(bz, num_head, rows, cols, start_row):
...     assert rows == cols, f"rows {rows} must be equal to cols {cols}."
...     start_rows_list = []
...     for bz_idx in range(bz):
...         for head_idx in range(num_head):
...             start_rows = np.array([rows + 1] * cols)
...             mask_pos = np.random.choice(cols - 1, cols - start_row, replace=False)
...             index = np.arange(start_row, rows)
...             mask_pos = np.concatenate([mask_pos[mask_pos < index - 1], mask_pos[mask_pos >= index - 1]])
...             start_rows[mask_pos] = index
...             start_rows_list.append(start_rows)
...     start_rows_arr = np.array(start_rows_list).reshape([bz, num_head, rows])
...     return start_rows_arr
>>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> attn_mask_start_row = 48
>>> start_row_indices = generate_start_rows(1, 2, 128, 128, attn_mask_start_row)
>>> attn_mask_start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32)
>>> out = paddle.nn.functional.flash_attention.flash_attention_with_sparse_mask(
...     q, q, q,
...     attn_mask_start_row_indices=attn_mask_start_row_indices,
...     attn_mask_start_row=attn_mask_start_row,
...     dropout_p=0.9,
...     is_causal=True,
... )
>>> print(out)