moe_dispatch
paddle.incubate.nn.functional.moe_dispatch(x: Tensor, gating_output: Tensor, moe_topk: int, group_moe: bool = False, topk_only_mode: bool = False) → tuple[Tensor, Tensor, Tensor, Tensor, Tensor] [source]
Dispatches tokens to experts based on gating probabilities.
This function routes each token to its top-k experts according to the gating output, and prepares the inputs for expert processing by reordering tokens so that each expert's tokens are contiguous and by computing per-token scaling factors.
Parameters
x (Tensor) – The input tensor with shape [batch_size * seq_len, d_model].
gating_output (Tensor) – The gating output probabilities with shape [batch_size * seq_len, num_experts].
moe_topk (int) – The number of top experts to select for each token.
group_moe (bool, optional) – Whether to use group MoE. Default is False. The group size is num_experts // moe_topk.
topk_only_mode (bool, optional) – Whether to use plain top-k routing only. Default is False.
Returns
permute_input (Tensor): The permuted input tensor ready for expert processing.
token_nums_per_expert (Tensor): The number of tokens assigned to each expert.
permute_indices_per_token (Tensor): The index mapping for scattering outputs back to the original order.
expert_scales_float (Tensor): The scaling factors for each expert’s outputs.
top_k_indices (Tensor): The indices of the selected experts for each token.
Return type
Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
Examples
>>> import paddle
>>> from paddle.incubate.nn.functional import moe_dispatch
>>> x = paddle.randn([1280, 768])  # 1280 = bs * 128
>>> gating_output = paddle.rand([1280, 48])
>>> group_moe = False
>>> topk_only_mode = True
>>> moe_topk = 6
>>> (
...     permute_input,
...     token_nums_per_expert,
...     permute_indices_per_token,
...     expert_scales_float,
...     top_k_indices
... ) = moe_dispatch(x, gating_output, moe_topk, group_moe, topk_only_mode)
>>> print(permute_input.shape)
[7680, 768]
>>> print(token_nums_per_expert.shape)
[48]
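The routing semantics above can be sketched in plain NumPy. This is a simplified reference, not the fused GPU kernel, and the helper name `moe_dispatch_reference` is hypothetical; it illustrates how top-k selection, grouping by expert, and per-expert token counts relate to the returned tensors:

```python
import numpy as np

def moe_dispatch_reference(x, gating_output, moe_topk):
    """Simplified NumPy sketch of top-k MoE dispatch (illustration only)."""
    # Softmax over the expert dimension to get routing probabilities.
    logits = gating_output - gating_output.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)

    # Top-k expert ids and their scales for every token.
    top_k_indices = np.argsort(-probs, axis=-1)[:, :moe_topk]      # [N, k]
    expert_scales = np.take_along_axis(probs, top_k_indices, axis=-1)

    # Group the N*k (token, slot) pairs by expert id so that each
    # expert's tokens are contiguous in the permuted input.
    flat_experts = top_k_indices.reshape(-1)                       # [N*k]
    order = np.argsort(flat_experts, kind="stable")                # permutation
    permute_input = x[order // moe_topk]                           # [N*k, d]
    token_nums_per_expert = np.bincount(
        flat_experts, minlength=gating_output.shape[-1]
    )
    return permute_input, token_nums_per_expert, order, expert_scales, top_k_indices

# Tiny demo: 4 tokens, d_model=3, 2 experts, top-1 routing.
x = np.arange(12, dtype=np.float64).reshape(4, 3)
gating = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]])
out = moe_dispatch_reference(x, gating, 1)
```

The demo routes tokens 0 and 2 to expert 0 and tokens 1 and 3 to expert 1, so the permuted input stacks the rows of `x` in order 0, 2, 1, 3, matching the contiguous-per-expert layout that `permute_input` provides for expert processing.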