moe_dispatch

paddle.incubate.nn.functional.moe_dispatch(x: Tensor, gating_output: Tensor, moe_topk: int, group_moe: bool = False, topk_only_mode: bool = False) → tuple[Tensor, Tensor, Tensor, Tensor, Tensor] [source]

Dispatches tokens to experts based on gating probabilities.

This function routes each token to its top-k selected experts according to the gating output. It prepares the inputs for expert processing by reordering and scaling.

Parameters
  • x (Tensor) – The input tensor with shape [batch_size * seq_len, d_model].

  • gating_output (Tensor) – The gating output probabilities with shape [batch_size * seq_len, num_experts].

  • moe_topk (int) – The number of top experts to select for each token.

  • group_moe (bool, optional) – Whether to use group MoE, where the group size is num_experts // moe_topk. Default is False.

  • topk_only_mode (bool, optional) – Whether to use top-k selection only. Default is False.

Returns

  • permute_input (Tensor): The permuted input tensor ready for expert processing.

  • token_nums_per_expert (Tensor): The number of tokens assigned to each expert.

  • permute_indices_per_token (Tensor): The index mapping for scattering outputs back to the original order.

  • expert_scales_float (Tensor): The scaling factors for each expert’s outputs.

  • top_k_indices (Tensor): The indices of the selected experts for each token.

Return type

Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
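To make the dispatch semantics concrete, the routing can be sketched in plain NumPy: replicate each token moe_topk times, then sort the copies by expert id so every expert receives a contiguous slab. This is only a conceptual reference, not the fused GPU kernel; the function name `moe_dispatch_reference` and the softmax-then-top-k gating here are assumptions for illustration.

```python
import numpy as np

def moe_dispatch_reference(x, gating_output, moe_topk):
    """Conceptual NumPy sketch of MoE dispatch (not the fused Paddle op)."""
    num_tokens, _ = x.shape
    num_experts = gating_output.shape[1]

    # Softmax over experts, then take the top-k probabilities per token.
    probs = np.exp(gating_output - gating_output.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    top_k_indices = np.argsort(-probs, axis=1)[:, :moe_topk]          # [T, k]
    expert_scales = np.take_along_axis(probs, top_k_indices, axis=1)  # [T, k]

    # Flatten the (token, k) pairs and stably sort them by expert id,
    # so each expert's tokens end up contiguous in memory.
    flat_experts = top_k_indices.reshape(-1)                          # [T*k]
    order = np.argsort(flat_experts, kind="stable")
    replicated = x[np.repeat(np.arange(num_tokens), moe_topk)]        # [T*k, d]
    permute_input = replicated[order]

    token_nums_per_expert = np.bincount(flat_experts, minlength=num_experts)
    return permute_input, token_nums_per_expert, order, expert_scales, top_k_indices
```

With moe_topk copies per token, `permute_input` has `T * moe_topk` rows, which is why the example below reports a first dimension of 7680 = 1280 * 6.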

Examples

>>> import paddle
>>> from paddle.incubate.nn.functional import moe_dispatch

>>> x = paddle.randn([1280, 768]) # 1280 = bs * 128
>>> gating_output = paddle.rand([1280, 48])
>>> group_moe = False
>>> topk_only_mode = True
>>> moe_topk = 6
>>> (
...     permute_input,
...     token_nums_per_expert,
...     permute_indices_per_token,
...     expert_scales_float,
...     top_k_indices
... ) = moe_dispatch(x, gating_output, moe_topk, group_moe, topk_only_mode)
>>> print(permute_input.shape)
[7680, 768]
>>> print(token_nums_per_expert.shape)
[48]
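After the experts process `permute_input`, the k scaled copies of each token have to be gathered back into the original order and summed. A minimal NumPy sketch of that combine step, under the assumption that the permutation indices record the gather order of the dispatch (all sizes and array names here are hypothetical stand-ins for the tensors above):

```python
import numpy as np

T, k, d = 3, 2, 4                      # hypothetical token / top-k / model sizes
expert_out = np.random.rand(T * k, d)  # expert outputs, expert-major order
order = np.random.permutation(T * k)   # stands in for permute_indices_per_token
scales = np.random.rand(T, k)          # stands in for expert_scales_float

# Undo the dispatch permutation: row i of expert_out came from slot order[i].
unpermuted = np.empty_like(expert_out)
unpermuted[order] = expert_out

# Weight each of the k copies by its gate probability and sum per token.
combined = (unpermuted.reshape(T, k, d) * scales[:, :, None]).sum(axis=1)
```

The result `combined` has shape [T, d], matching the original token layout of x.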