global_gather
- paddle.distributed.utils.global_gather(x, local_count, global_count, group=None, use_calc_stream=True)
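Send the data in x to n_expert * world_size experts according to global_count, and receive data from n_expert * world_size experts according to local_count. Here n_expert is the number of experts on each card and world_size is the number of cards in the group. In this operator, the roles of local_count and global_count are the reverse of those in global_scatter.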
- Parameters
x (Tensor) – The input tensor whose rows are sent to the experts. Its data type should be float16, float32, float64, int32 or int64.
local_count (Tensor) – A tensor with n_expert * world_size elements indicating how many rows to receive: local_count[i] is the number of rows this card receives from card i // n_expert for its expert i % n_expert. Its data type should be int64.
global_count (Tensor) – A tensor with n_expert * world_size elements indicating how many rows to send: global_count[i] is the number of rows of x sent to expert i % n_expert on card i // n_expert. Its data type should be int64.
group (Group, optional) – The group instance returned by new_group, or None for the global default group. Default: None.
use_calc_stream (bool, optional) – Whether to use the calculation stream (True) or the communication stream (False). Default: True.
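The count tensors on different ranks must agree with each other: the number of rows rank src sends to expert e on card dst has to equal the number of rows card dst expects to receive from src for expert e. A minimal sketch of that check follows, assuming the index convention i = card * n_expert + expert (consistent with the example below). counts_are_consistent is a hypothetical helper, not a Paddle API.

```python
# Hypothetical helper (not a Paddle API). Assumes entry i of a count tensor
# refers to expert i % n_expert on card i // n_expert.
def counts_are_consistent(local_counts, global_counts, n_expert):
    """local_counts[r] and global_counts[r] are the lists passed on rank r."""
    world_size = len(local_counts)
    for src in range(world_size):
        for dst in range(world_size):
            for expert in range(n_expert):
                # Rows that rank src sends to this expert on card dst ...
                sent = global_counts[src][dst * n_expert + expert]
                # ... must equal the rows card dst expects from src for it.
                received = local_counts[dst][src * n_expert + expert]
                if sent != received:
                    return False
    return True


# Counts from the example: each rank passes identical local_count and global_count.
counts = [[2, 1, 1, 1], [1, 1, 2, 1]]
print(counts_are_consistent(counts, counts, n_expert=2))  # True
print(divmod(3, 2))  # (1, 1): entry 3 means card 1, expert 1
```

If this check fails, the sends and receives issued by different ranks cannot be paired up.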
- Returns
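Tensor. The data received from all experts: sum(local_count) rows, concatenated in the index order of local_count.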
Examples
```python
# required: distributed
# Needs two processes with one GPU each, e.g.:
#   python -m paddle.distributed.launch --gpus=0,1 this_script.py
import numpy as np
import paddle
from paddle.distributed import init_parallel_env

init_parallel_env()
n_expert = 2
world_size = 2
d_model = 2
in_feat = d_model
local_input_buf = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], dtype=np.float32)
# Each rank passes its own counts; entry i refers to expert i % n_expert on card i // n_expert.
if paddle.distributed.ParallelEnv().local_rank == 0:
    local_count = np.array([2, 1, 1, 1])
    global_count = np.array([2, 1, 1, 1])
else:
    local_count = np.array([1, 1, 2, 1])
    global_count = np.array([1, 1, 2, 1])
local_input_buf = paddle.to_tensor(local_input_buf, dtype="float32", stop_gradient=False)
local_count = paddle.to_tensor(local_count, dtype="int64")
global_count = paddle.to_tensor(global_count, dtype="int64")
a = paddle.distributed.utils.global_gather(local_input_buf, local_count, global_count)
print(a)
# out for rank 0: [[1, 2], [3, 4], [7, 8], [1, 2], [7, 8]]
# out for rank 1: [[5, 6], [9, 10], [3, 4], [5, 6], [9, 10]]
a.stop_gradient = False
c = a * a
# Backward routes each gradient row back to the rank and row it came from.
c.backward()
print("local_input_buf.grad", local_input_buf.grad)
# out for rank 0: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
# out for rank 1: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
```
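To make the routing concrete, the sketch below models global_gather in a single Python process, with one list of rows per rank. It is not Paddle's implementation. The send layout (for each expert, then each destination card, global_count[card * n_expert + expert] consecutive rows of x) and the receive layout (received blocks concatenated in the index order of local_count) are assumptions, chosen because they reproduce the example outputs above. simulate_global_gather and simulate_backward_of_square are hypothetical names.

```python
# Single-process model of global_gather's routing (not Paddle code).
# xs[r], local_counts[r], global_counts[r] are what rank r would pass in.
# Assumed layouts: rank r's x holds, for each expert e and then each
# destination card d, global_counts[r][d * n_expert + e] consecutive rows;
# each output concatenates received blocks in index order of local_count.
def simulate_global_gather(xs, local_counts, global_counts, n_expert):
    world_size = len(xs)
    messages = {}  # (src, dst, expert) -> rows sent from src to that expert on dst
    for src in range(world_size):
        ptr = 0
        for expert in range(n_expert):
            for dst in range(world_size):
                n = global_counts[src][dst * n_expert + expert]
                messages[(src, dst, expert)] = xs[src][ptr:ptr + n]
                ptr += n
    outs = []
    for dst in range(world_size):
        out = []
        for idx, n in enumerate(local_counts[dst]):
            src, expert = divmod(idx, n_expert)
            rows = messages[(src, dst, expert)]
            if len(rows) != n:
                raise ValueError("sender and receiver counts disagree")
            out.extend(rows)
        outs.append(out)
    return outs


def simulate_backward_of_square(xs, local_counts, global_counts, n_expert):
    # Route (rank, row) tags instead of values to find each output row's origin.
    tags = [[(src, k) for k in range(len(x))] for src, x in enumerate(xs)]
    routed = simulate_global_gather(tags, local_counts, global_counts, n_expert)
    grads = [[[0] * len(row) for row in x] for x in xs]
    for out_tags in routed:
        for src, k in out_tags:
            # d(a * a)/da = 2a, sent back to the row it came from.
            for col, v in enumerate(xs[src][k]):
                grads[src][k][col] += 2 * v
    return grads


buf = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
counts = [[2, 1, 1, 1], [1, 1, 2, 1]]  # per rank, local_count == global_count here
outs = simulate_global_gather([buf, buf], counts, counts, n_expert=2)
print(outs[0])  # [[1, 2], [3, 4], [7, 8], [1, 2], [7, 8]]
print(outs[1])  # [[5, 6], [9, 10], [3, 4], [5, 6], [9, 10]]
print(simulate_backward_of_square([buf, buf], counts, counts, n_expert=2)[0])
# [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
```

Routing tags rather than values reflects that the gradient of a pure data-movement operation simply follows each row back along its forward path. Because every row of x is sent exactly once in this example, the gradient of a * a with respect to x is 2 * x on both ranks.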