fused_rotary_position_embedding¶
- paddle.incubate.nn.functional. fused_rotary_position_embedding ( q, k=None, v=None, sin=None, cos=None, position_ids=None, use_neox_rotary_style=True ) [源代码] ¶
融合旋转位置编码。
参数¶
q (Tensor) - 输入张量。 数据类型可以是 bfloat16, float16, float32 或 float64。 q 的形状必须是 [batch_size, seq_len, num_heads, head_dim],并且 head_dim 必须是 2 的倍数。
k (Tensor, 可选) - 输入张量。 数据类型可以是 bfloat16, float16, float32 或 float64。 k 的形状必须是 [batch_size, seq_len, num_heads, head_dim],并且 head_dim 必须是 2 的倍数。
v (Tensor, 可选) - 输入张量。 数据类型可以是 bfloat16, float16, float32 或 float64。 v 的形状必须是 [batch_size, seq_len, num_heads, head_dim],并且 head_dim 必须是 2 的倍数。
sin (Tensor, 可选) - 输入张量。 数据类型可以是 bfloat16, float16, float32 或 float64。 sin 的形状必须是 [seq_len, head_dim] 或 [1, seq_len, 1, head_dim], 并且 head_dim 必须是 2 的倍数。
cos (Tensor, 可选) - 输入张量。 数据类型可以是 bfloat16, float16, float32 或 float64。 cos 的形状必须是 [seq_len, head_dim] 或 [1, seq_len, 1, head_dim], 并且 head_dim 必须是 2 的倍数。
position_ids (Tensor, 可选) - 输入张量。 数据类型为 int64. position_ids 的形状为[batch_size, seq_len]。
use_neox_rotary_style (可选|bool) - 当 use_neox_rotary_style 为 True, 每两个相邻的数字计算一次。 当 use_neox_rotary_style 为 False, 计算与前半段和后半段位置相对应的数字。 默认值为 True。
time_major (可选|bool) - 指定输入张量的时间维度是否为第一个维度。 如果为 True,则输入张量的形状应为 [seq_len, batch_size, num_heads, head_dim]。 如果为 False,则输入张量的形状应为 [batch_size, seq_len, num_heads, head_dim] 。 默认值为 False。
rotary_emb_base (可选|float) - 计算旋转角使用的底数。 默认值为 10000.0。
返回¶
out_q/out_k/out_v 表示融合旋转位置嵌入的张量,具有与 q 相同的形状和数据类型。
代码示例¶
>>> import paddle
>>> from paddle.incubate.nn.functional import fused_rotary_position_embedding
>>> paddle.set_device('gpu')
>>> # batch_size = 2
>>> # seq_len = 2
>>> # num_heads = 2
>>> # head_dim = 2
>>> paddle.seed(1204)
>>> # q, k, v: [batch_size, seq_len, num_heads, head_dim]
>>> q = paddle.randn([2, 2, 2, 2], dtype='float16')
>>> k = paddle.randn([2, 2, 2, 2], dtype='float16')
>>> v = paddle.randn([2, 2, 2, 2], dtype='float16')
>>> # sin, cos: [1, seq_len, 1, head_dim]
>>> x = paddle.randn([1, 2, 1, 2], dtype='float16')
>>> y = paddle.randn([1, 2, 1, 2], dtype='float16')
>>> sin = paddle.sin(x)
>>> cos = paddle.cos(y)
>>> # position_ids: [batch_size, seq_len]
>>> position_ids = paddle.randint(high=2, shape=[2, 2], dtype='int64')
>>> # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
>>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
>>> print(out_q)
Tensor(shape=[2, 2, 2, 2], dtype=float16, place=Place(gpu:0), stop_gradient=True,
[[[[-0.54931641, 0.64990234],
[-1.08691406, 1.18261719]],
[[ 0.57812500, 0.11749268],
[-0.63281250, 0.15551758]]],
[[[-0.77050781, 0.07733154],
[-0.73730469, -0.16735840]],
[[ 0.07116699, -0.90966797],
[-0.03628540, -0.20202637]]]])