fused_moe

paddle.incubate.nn.functional.fused_moe(x: Tensor, gate_weight: Tensor, ffn1_weight: Tensor, ffn2_weight: Tensor, ffn1_bias: Tensor | None = None, ffn1_scale: Tensor | None = None, ffn2_bias: Tensor | None = None, ffn2_scale: Tensor | None = None, quant_method: str = 'None', moe_topk: int = 2, norm_topk_prob: bool = True, group_moe: bool = False) -> Tensor

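Applies a fused Mixture-of-Experts (MoE) kernel. Each token of x is routed to the moe_topk experts chosen from gate_weight, passed through those experts' two-layer feed-forward networks (ffn1_weight, ffn2_weight), and the expert outputs are combined into a Tensor with the same shape as x. This method requires a GPU whose SM architecture is sm75, sm80, or sm86.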
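The architecture requirement can be checked before calling the kernel. The following is a minimal sketch, assuming the device's compute capability is available as a (major, minor) tuple, for example from paddle.device.cuda.get_device_capability(). The helper names supports_fused_moe and sm_arch_name are hypothetical and not part of Paddle, and the supported set is taken literally from the sentence above, so any other architecture is reported as unsupported.

```python
from typing import Tuple

# Architectures listed for fused_moe: sm75, sm80, sm86, as (major, minor) pairs.
SUPPORTED_SM_ARCHS = {(7, 5), (8, 0), (8, 6)}


def sm_arch_name(capability: Tuple[int, int]) -> str:
    """Hypothetical helper: format (major, minor) as 'smXY', e.g. (8, 0) -> 'sm80'."""
    major, minor = capability
    return f"sm{major}{minor}"


def supports_fused_moe(capability: Tuple[int, int]) -> bool:
    """Hypothetical helper: True only for the architectures listed for fused_moe."""
    return tuple(capability) in SUPPORTED_SM_ARCHS
```

On a GPU machine, a guard such as supports_fused_moe(paddle.device.cuda.get_device_capability()) could be evaluated before calling fused_moe.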

Parameters
  • x (Tensor) – the input Tensor. Its shape is [bsz, seq_len, d_model].

  • gate_weight (Tensor) – the gate Tensor used to select the experts for each token. Its shape is [bsz, seq_len, num_experts].

  • ffn1_weight (Tensor) – the weight of the first batched matrix multiplication, one matrix per expert. Its shape is [num_experts, d_model, d_feed_forward*2].

  • ffn2_weight (Tensor) – the weight of the second batched matrix multiplication, one matrix per expert. Its shape is [num_experts, d_feed_forward, d_model].

  • ffn1_bias (Tensor, optional) – the bias of the first batched matrix multiplication. Its shape is [num_experts, 1, d_feed_forward*2]. Default: None.

  • ffn1_scale (Tensor, optional) – the scale Tensor used to dequantize ffn1_weight. Its shape is [num_experts, d_feed_forward*2]. Default: None.

  • ffn2_bias (Tensor, optional) – the bias of the second batched matrix multiplication. Its shape is [num_experts, 1, d_model]. Default: None.

  • ffn2_scale (Tensor, optional) – the scale Tensor used to dequantize ffn2_weight. Its shape is [num_experts, d_model]. Default: None.

  • quant_method (str, optional) – Currently not supported; keep the default value 'None'.

  • moe_topk (int, optional) – the number of experts (top k) selected for each token. Default: 2.

  • norm_topk_prob (bool, optional) – whether to normalize the probabilities of the moe_topk selected experts. Default: True.

  • group_moe (bool, optional) – Default: False.
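The roles of moe_topk and norm_topk_prob can be illustrated with a NumPy sketch of top-k routing. It assumes that gate_weight holds per-token router logits which are turned into probabilities by a softmax over the expert axis; that softmax step and the helper name topk_route are assumptions for illustration, not details taken from the kernel.

```python
import numpy as np


def topk_route(gate_logits: np.ndarray, moe_topk: int = 2, norm_topk_prob: bool = True):
    """Hypothetical sketch of top-k expert routing.

    gate_logits: [bsz, seq_len, num_experts] router scores (assumed to be logits).
    Returns (weights, experts), both shaped [bsz, seq_len, moe_topk].
    """
    # Assumption: probabilities come from a numerically stable softmax over experts.
    shifted = gate_logits - gate_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)

    # Indices of the moe_topk largest probabilities, highest first.
    experts = np.argsort(-probs, axis=-1)[..., :moe_topk]
    weights = np.take_along_axis(probs, experts, axis=-1)

    if norm_topk_prob:
        # Rescale so the selected experts' weights sum to 1 for each token.
        weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights, experts
```

For logits [2.0, 1.0, 0.0, -1.0] over four experts with moe_topk=2, the sketch selects experts 0 and 1. With norm_topk_prob=True their weights are rescaled to sum to 1 (about 0.73 and 0.27); with False they keep their raw softmax values, which sum to less than 1.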

Returns

the output Tensor, with the same shape as x: [bsz, seq_len, d_model].

Return type

Tensor
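The shape contracts in the parameter list can be summarized as a small validation helper that derives the output shape. This is a hypothetical pure-Python sketch (check_fused_moe_shapes is not a Paddle function); it only encodes the documented shapes and says nothing about dtypes or about what the kernel itself checks.

```python
from typing import Optional, Sequence, Tuple

Shape = Sequence[int]


def check_fused_moe_shapes(
    x: Shape,
    gate_weight: Shape,
    ffn1_weight: Shape,
    ffn2_weight: Shape,
    ffn1_bias: Optional[Shape] = None,
    ffn1_scale: Optional[Shape] = None,
    ffn2_bias: Optional[Shape] = None,
    ffn2_scale: Optional[Shape] = None,
) -> Tuple[int, int, int]:
    """Hypothetical helper: validate the documented shapes and return the output shape."""
    bsz, seq_len, d_model = x
    num_experts = ffn1_weight[0]
    d_ff2 = ffn1_weight[2]  # documented as d_feed_forward * 2
    if d_ff2 % 2 != 0:
        raise ValueError("last dim of ffn1_weight must be d_feed_forward*2 (even)")
    d_ff = d_ff2 // 2

    # (actual shape, expected shape) for every tensor that was provided.
    expected = {
        "gate_weight": (gate_weight, (bsz, seq_len, num_experts)),
        "ffn1_weight": (ffn1_weight, (num_experts, d_model, d_ff2)),
        "ffn2_weight": (ffn2_weight, (num_experts, d_ff, d_model)),
    }
    optional = {
        "ffn1_bias": (ffn1_bias, (num_experts, 1, d_ff2)),
        "ffn1_scale": (ffn1_scale, (num_experts, d_ff2)),
        "ffn2_bias": (ffn2_bias, (num_experts, 1, d_model)),
        "ffn2_scale": (ffn2_scale, (num_experts, d_model)),
    }
    expected.update({k: v for k, v in optional.items() if v[0] is not None})

    for name, (actual, want) in expected.items():
        if tuple(actual) != want:
            raise ValueError(f"{name}: expected shape {list(want)}, got {list(actual)}")
    # The output has the same shape as x: [bsz, seq_len, d_model].
    return (bsz, seq_len, d_model)
```

With the shapes from the example below, the helper returns (10, 128, 1024), matching the printed output shape.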

Examples

>>> import paddle
>>> from paddle.incubate.nn.functional import fused_moe

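>>> paddle.set_device('gpu')  # requires an sm75, sm80, or sm86 GPU
>>> paddle.set_default_dtype("float16")
>>> # bsz=10, seq_len=128, d_model=1024, num_experts=8, d_feed_forward=2048
>>> x = paddle.randn([10, 128, 1024])
>>> gate_weight = paddle.randn([10, 128, 8], dtype=paddle.float32)
>>> ffn1_weight = paddle.randn([8, 1024, 4096])  # [num_experts, d_model, d_feed_forward*2]
>>> ffn1_bias = paddle.randn([8, 1, 4096])
>>> ffn2_weight = paddle.randn([8, 2048, 1024])  # [num_experts, d_feed_forward, d_model]
>>> ffn2_bias = paddle.randn([8, 1, 1024])
>>> moe_topk = 2
>>> out = fused_moe(
...     x, gate_weight, ffn1_weight, ffn2_weight,
...     ffn1_bias=ffn1_bias, ffn2_bias=ffn2_bias,
...     quant_method="None", moe_topk=moe_topk, norm_topk_prob=True,
... )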
>>> print(out.shape)
[10, 128, 1024]
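To make the data flow concrete, the following NumPy sketch writes out an unfused, per-token top-k MoE forward pass using the documented weight shapes. Several details are assumptions rather than facts about the kernel: router probabilities come from a softmax over gate_weight, and the d_feed_forward*2 output of the first matmul is split into two halves combined SwiGLU-style as silu(a) * b, with the choice of which half is a made arbitrarily. It is meant for reasoning about shapes on small inputs, not as a numerical reference for fused_moe; moe_reference is a hypothetical name.

```python
import numpy as np


def silu(v: np.ndarray) -> np.ndarray:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + np.exp(-v))


def moe_reference(x, gate_logits, ffn1_weight, ffn2_weight,
                  ffn1_bias=None, ffn2_bias=None,
                  moe_topk=2, norm_topk_prob=True):
    """Hypothetical unfused top-k MoE forward; returns an array shaped like x."""
    bsz, seq_len, d_model = x.shape
    tokens = x.reshape(-1, d_model)                          # [T, d_model]
    logits = gate_logits.reshape(-1, gate_logits.shape[-1])  # [T, num_experts]

    # Assumption: softmax router, then top-k selection (optionally renormalized).
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    experts = np.argsort(-probs, axis=-1)[:, :moe_topk]      # [T, k]
    weights = np.take_along_axis(probs, experts, axis=-1)    # [T, k]
    if norm_topk_prob:
        weights /= weights.sum(axis=-1, keepdims=True)

    out = np.zeros_like(tokens)
    for t in range(tokens.shape[0]):
        for k in range(moe_topk):
            e = experts[t, k]
            h = tokens[t] @ ffn1_weight[e]                   # [d_feed_forward*2]
            if ffn1_bias is not None:
                h = h + ffn1_bias[e, 0]
            a, b = np.split(h, 2)                            # assumption: SwiGLU-style halves
            h = silu(a) * b                                  # [d_feed_forward]
            y = h @ ffn2_weight[e]                           # [d_model]
            if ffn2_bias is not None:
                y = y + ffn2_bias[e, 0]
            out[t] += weights[t, k] * y                      # weighted sum over experts
    return out.reshape(bsz, seq_len, d_model)
```

A useful sanity check on this sketch: if every expert shares the same weights and norm_topk_prob is True, the routing weights sum to 1 and the output reduces to a single expert's feed-forward output.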