flash_attention_with_sparse_mask
- paddle.nn.functional.flash_attention_with_sparse_mask(query, key, value, attn_mask_start_row_indices, attn_mask_start_row=0, dropout_p=0.0, is_causal=False, return_softmax=False, return_softmax_lse=False, return_seed_offset=False, training=True, name=None) [source]
The equation is:

\[result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V\]

where Q, K, and V represent the three input parameters of the attention module. The dimensions of the three parameters are the same. d represents the size of the last dimension of the three parameters.

Warning

This API only supports inputs with dtype float16 and bfloat16.
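For reference, below is a minimal dense sketch of the equation above. The helper name naive_attention is illustrative only and not part of the API; it ignores the sparse mask, dropout, and causal options, and materializes the full score matrix, which flash attention avoids.

import paddle
import paddle.nn.functional as F

def naive_attention(q, k, v):
    # q, k, v: [batch_size, seq_len, num_heads, head_dim]
    d = q.shape[-1]
    q_t = q.transpose([0, 2, 1, 3])   # -> [batch_size, num_heads, seq_len, head_dim]
    k_t = k.transpose([0, 2, 1, 3])
    v_t = v.transpose([0, 2, 1, 3])
    # Q * K^T / sqrt(d), then softmax over the key dimension
    scores = paddle.matmul(q_t, k_t, transpose_y=True) / (d ** 0.5)
    probs = F.softmax(scores, axis=-1)
    out = paddle.matmul(probs, v_t)
    # back to [batch_size, seq_len, num_heads, head_dim]
    return out.transpose([0, 2, 1, 3])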
- Parameters
query (Tensor) – The query tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
key (Tensor) – The key tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
value (Tensor) – The value tensor in the Attention module. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
attn_mask_start_row_indices (Tensor) – A sparse attention mask indices tensor with shape [batch_size, num_head, seq_len]. Each element indicates the row index at which the mask starts in the score matrix. The dtype must be int32 (see the sketch after this parameter list for one way to visualize the masking).
attn_mask_start_row (int, optional) – When attn_mask_start_row_indices is passed and the minimum row index is known to be greater than 0, attn_mask_start_row can be set to that value for a performance improvement. The default value is 0.
dropout_p (float, optional) – The dropout ratio. The default value is 0.0.
is_causal (bool, optional) – Whether to enable causal mode. The default value is False.
training (bool, optional) – Whether the computation is in the training phase. The default value is True.
name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
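The masking semantics are easiest to see as a dense mask. The sketch below is a hedged illustration, assuming that for each column j of the score matrix, rows i >= attn_mask_start_row_indices[b, h, j] are masked out; the helper name dense_mask_from_start_rows is hypothetical and not part of the API.

import paddle

def dense_mask_from_start_rows(start_row_indices, seq_len):
    # start_row_indices: int32 tensor of shape [batch_size, num_head, seq_len]
    row_ids = paddle.arange(seq_len, dtype='int32').reshape([1, 1, seq_len, 1])  # [1, 1, seq_len, 1]
    start = start_row_indices.unsqueeze(2)                                       # [batch_size, num_head, 1, seq_len]
    # True where the score S[i, j] would be masked (row index at or past the start row for column j)
    return row_ids >= start                                                      # [batch_size, num_head, seq_len, seq_len]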
- Returns
out (Tensor), The attention output tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
Examples
>>> import paddle
>>> import numpy as np
>>> def generate_start_rows(bz, num_head, rows, cols, start_row):
...     assert rows == cols, f"rows {rows} must be equal to cols {cols}."
...     start_rows_list = []
...     for bz_idx in range(bz):
...         for head_idx in range(num_head):
...             start_rows = np.array([rows + 1] * cols)
...             mask_pos = np.random.choice(cols - 1, cols - start_row, replace=False)
...             index = np.arange(start_row, rows)
...             mask_pos = np.concatenate([mask_pos[mask_pos < index - 1], mask_pos[mask_pos >= index - 1]])
...             start_rows[mask_pos] = index
...             start_rows_list.append(start_rows)
...     start_rows_arr = np.array(start_rows_list).reshape([bz, num_head, rows])
...     return start_rows_arr
>>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> attn_mask_start_row = 48
>>> start_row_indices = generate_start_rows(1, 2, 128, 128, attn_mask_start_row)
>>> attn_mask_start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32)
>>> out = paddle.nn.functional.flash_attention.flash_attention_with_sparse_mask(
...     q, q, q,
...     attn_mask_start_row_indices=attn_mask_start_row_indices,
...     attn_mask_start_row=attn_mask_start_row,
...     dropout_p=0.9,
...     is_causal=True,
... )
>>> print(out)