global_scatter

paddle.distributed.utils.global_scatter(x, local_count, global_count, group=None, use_calc_stream=True) [source]

Scatter the data in x, which has been laid out contiguously by destination expert, to n_expert * world_size experts according to local_count, and receive tensors from n_expert * world_size experts according to global_count.
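
Concretely, both count tensors index the n_expert * world_size expert slots in rank-major order. This reading is inferred from the example below rather than stated by the API, so treat the helper here as an illustrative sketch, not part of paddle:

# hypothetical helper (not in paddle): decode a count index into its
# (rank, expert) pair under the rank-major slot layout
def count_index_to_rank_expert(i, n_expert):
    return i // n_expert, i % n_expert

# with n_expert = 2 and world_size = 2:
# index 0 -> (rank 0, expert 0), index 1 -> (rank 0, expert 1),
# index 2 -> (rank 1, expert 0), index 3 -> (rank 1, expert 1)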

Parameters
  • x (Tensor) – The input tensor to scatter. Its data type should be float16, float32, float64, int32 or int64.

  • local_count (Tensor) – A tensor with n_expert * world_size elements indicating how many rows of x are to be sent to each expert. Its data type should be int64. One way to build it is shown in the sketch after this parameter list.

  • global_count (Tensor) – A tensor with n_expert * world_size elements indicating how many rows are to be received from each expert. Its data type should be int64.

  • group (Group, optional) – The group instance returned by new_group, or None for the global default group. Default: None.

  • use_calc_stream (bool, optional) – Whether to use the calculation stream (True) or the communication stream (False). Default: True.

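local_count is typically produced by the router that assigns each input row to an expert slot. A minimal sketch, assuming each row of x already carries its destination slot index and x is sorted by that index (the variable names are illustrative, not part of this API):

import paddle

n_expert, world_size = 2, 2
# hypothetical routing result: one destination slot per input row
slot = paddle.to_tensor([0, 0, 1, 2, 3], dtype="int64")
local_count = paddle.bincount(slot, minlength=n_expert * world_size).astype("int64")
# local_count -> [2, 1, 1, 1], matching rank 0 in the example below
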
Returns

The data received from all experts, concatenated along the first dimension; the number of received rows equals the sum of global_count.

Return type

out (Tensor)

Examples

# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
init_parallel_env()
n_expert = 2
world_size = 2
d_model = 2
in_feat = d_model
# rows of local_input_buf are laid out contiguously by destination slot
local_input_buf = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], dtype=np.float32)
if paddle.distributed.ParallelEnv().local_rank == 0:
    # rank 0 sends 2 rows to (rank 0, expert 0) and 1 row to every other slot
    local_count = np.array([2, 1, 1, 1])
    global_count = np.array([2, 1, 1, 1])
else:
    # rank 1 sends 2 rows to (rank 1, expert 0) and 1 row to every other slot
    local_count = np.array([1, 1, 2, 1])
    global_count = np.array([1, 1, 2, 1])
local_input_buf = paddle.to_tensor(local_input_buf, dtype="float32", stop_gradient=False)
local_count = paddle.to_tensor(local_count, dtype="int64")
global_count = paddle.to_tensor(global_count, dtype="int64")
a = paddle.distributed.utils.global_scatter(local_input_buf, local_count, global_count)
a.stop_gradient = False
print(a)
# out for rank 0: [[1, 2], [3, 4], [1, 2], [5, 6], [3, 4]]
# out for rank 1: [[7, 8], [5, 6], [7, 8], [9, 10], [9, 10]]
# backward test
c = a * a
c.backward()
print("local_input_buf.grad: ", local_input_buf.grad)
# out for rank 0: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
# out for rank 1: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
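
On each rank, global_count is exactly the set of counts that the other ranks plan to send to it, so it can be obtained by exchanging local_count across ranks. A hedged sketch using paddle.distributed.alltoall (exchange_counts is an illustrative helper, not a paddle API):

import paddle
import paddle.distributed as dist

def exchange_counts(local_count, world_size):
    # one chunk of n_expert counts per destination rank
    in_chunks = list(paddle.split(local_count, world_size))
    out_chunks = []
    # after the exchange, out_chunks[r] holds the counts rank r sends here
    dist.alltoall(in_chunks, out_chunks)
    return paddle.concat(out_chunks)

# with the counts from the example above, rank 0 would end up holding
# [2, 1, 1, 1] and rank 1 [1, 1, 2, 1], i.e. exactly global_count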