fused_moe
- paddle.incubate.nn.functional.fused_moe(x: Tensor, gate_weight: Tensor, ffn1_weight: Tensor, ffn2_weight: Tensor, ffn1_bias: Tensor | None = None, ffn1_scale: Tensor | None = None, ffn2_bias: Tensor | None = None, ffn2_scale: Tensor | None = None, quant_method: str = 'None', moe_topk: int = 2, norm_topk_prob: bool = True, group_moe: bool = False) → Tensor [source]
Applies the fused MoE (Mixture of Experts) kernel. This method requires the GPU's SM architecture to be one of sm75, sm80, or sm86.
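You can verify the device requirement before calling this API; a minimal sketch using paddle.device.cuda.get_device_capability, which returns the (major, minor) compute capability of the current GPU:

>>> import paddle
>>> major, minor = paddle.device.cuda.get_device_capability()
>>> # sm75 / sm80 / sm86 correspond to capabilities (7, 5), (8, 0), (8, 6).
>>> assert (major, minor) in [(7, 5), (8, 0), (8, 6)], (
...     f"fused_moe requires sm75/sm80/sm86, got sm{major}{minor}")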
- Parameters
x (Tensor) – the input Tensor. Its shape is [bsz, seq_len, d_model].
gate_weight (Tensor) – the gate Tensor used to choose the experts for each token. Its shape is [bsz, seq_len, num_experts].
ffn1_weight (Tensor) – the first batched matrix-multiply weight. Its shape is [num_experts, d_model, d_feed_forward * 2].
ffn2_weight (Tensor) – the second batched matrix-multiply weight. Its shape is [num_experts, d_feed_forward, d_model].
ffn1_bias (Tensor, optional) – the first batched matrix-multiply bias. Its shape is [num_experts, 1, d_feed_forward * 2].
ffn1_scale (Tensor, optional) – the scale Tensor used to dequantize ffn1_weight. Its shape is [num_experts, d_feed_forward * 2].
ffn2_bias (Tensor, optional) – the second batched matrix-multiply bias. Its shape is [num_experts, 1, d_model].
ffn2_scale (Tensor, optional) – the scale Tensor used to dequantize ffn2_weight. Its shape is [num_experts, d_model].
quant_method (str) – currently not supported.
moe_topk (int) – the number of top-scoring experts selected for each token.
norm_topk_prob (bool) – whether to normalize the selected moe_topk probabilities so they sum to 1 (see the sketch below).
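The routing controlled by moe_topk and norm_topk_prob can be illustrated in isolation; this is a toy sketch of the gating math, not the kernel's actual code:

>>> import paddle
>>> import paddle.nn.functional as F
>>> gate_logits = paddle.to_tensor([[2.0, 1.0, 0.5, -1.0]])  # one token, four experts
>>> probs = F.softmax(gate_logits, axis=-1)
>>> topk_probs, topk_idx = paddle.topk(probs, k=2, axis=-1)
>>> # With norm_topk_prob=True, the selected probabilities are rescaled to sum to 1.
>>> topk_probs = topk_probs / topk_probs.sum(axis=-1, keepdim=True)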
- Returns
the output Tensor, whose shape is the same as x: [bsz, seq_len, d_model].
- Return type
Tensor
Examples
>>> import paddle
>>> from paddle.incubate.nn.functional import fused_moe
>>> paddle.set_device('gpu')
>>> paddle.set_default_dtype("float16")
>>> x = paddle.randn([10, 128, 1024])
>>> gate_weight = paddle.randn([10, 128, 8], dtype=paddle.float32)
>>> ffn1_weight = paddle.randn([8, 1024, 4096])
>>> ffn1_bias = paddle.randn([8, 1, 4096])
>>> ffn2_weight = paddle.randn([8, 2048, 1024])
>>> ffn2_bias = paddle.randn([8, 1, 1024])
>>> moe_topk = 2
>>> out = fused_moe(x, gate_weight, ffn1_weight, ffn2_weight, ffn1_bias, None, ffn2_bias, None, "None", moe_topk, True)
>>> print(out.shape)
[10, 128, 1024]
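For reference, the computation the fused kernel performs can be sketched with plain Paddle ops. This dense, unoptimized illustration assumes the doubled ffn1 output (d_feed_forward * 2) feeds a gated activation (SiLU here); that activation choice is an assumption about the kernel's internals, not documented behavior:

>>> import paddle
>>> import paddle.nn.functional as F
>>> def naive_moe(x, gate_logits, ffn1_weight, ffn1_bias, ffn2_weight, ffn2_bias,
...               moe_topk=2, norm_topk_prob=True):
...     # gate_logits holds per-token expert scores, [bsz, seq_len, num_experts].
...     probs = F.softmax(gate_logits.astype('float32'), axis=-1)
...     topk_probs, topk_idx = paddle.topk(probs, k=moe_topk, axis=-1)
...     if norm_topk_prob:
...         topk_probs = topk_probs / topk_probs.sum(axis=-1, keepdim=True)
...     out = paddle.zeros_like(x)
...     for e in range(ffn1_weight.shape[0]):
...         # Run expert e densely on every token; the fused kernel instead
...         # dispatches each token only to its top-k experts.
...         h = paddle.matmul(x, ffn1_weight[e]) + ffn1_bias[e]
...         g, u = paddle.chunk(h, 2, axis=-1)
...         h = F.silu(g) * u  # gated activation (assumed)
...         h = paddle.matmul(h, ffn2_weight[e]) + ffn2_bias[e]
...         # Routing weight of expert e per token (zero if e was not selected).
...         w = (topk_probs * (topk_idx == e).astype(topk_probs.dtype)).sum(-1, keepdim=True)
...         out = out + w.astype(h.dtype) * h
...     return out
>>> # Mirrors the example above, reusing its tensors.
>>> ref = naive_moe(x, gate_weight, ffn1_weight, ffn1_bias, ffn2_weight, ffn2_bias)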