dtensor_from_fn
- paddle.distributed.dtensor_from_fn(fn, mesh, placements, *args, **kwargs)
Construct a distributed tensor by calling fn with the given arguments and placing the result on mesh according to placements.
- Parameters
fn (callable) – A callable that accepts *args and **kwargs and returns a Tensor, for example paddle.ones.
mesh (paddle.distributed.ProcessMesh) – The ProcessMesh object that describes the Cartesian topology of the processes in use.
placements (list[paddle.distributed.Placement]) – The placements that describe how the tensor is placed on the ProcessMesh; each entry can be Shard, Replicate, or Partial.
*args (tuple) – Positional arguments passed to fn.
**kwargs (dict) – Keyword arguments passed to fn.
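The effect of the placements parameter can be pictured through the shape each process ends up holding. The following is a pure-Python conceptual sketch, not Paddle's implementation: Replicate, Shard, and Partial here are hypothetical dataclasses standing in for the paddle.distributed placement types, local_shape is a hypothetical helper, and the model assumes one placement per mesh dimension (as in the example, where a one-dimensional mesh takes a one-element list) and only handles dimensions that divide evenly.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


# Hypothetical stand-ins for paddle.distributed.Replicate / Shard / Partial.
@dataclass(frozen=True)
class Replicate:
    """Every process holds the full tensor."""


@dataclass(frozen=True)
class Shard:
    """Tensor dimension `dim` is split evenly across one mesh dimension."""
    dim: int


@dataclass(frozen=True)
class Partial:
    """Every process holds a full-shape partial value awaiting a reduction."""


Placement = Union[Replicate, Shard, Partial]


def local_shape(global_shape: Sequence[int], mesh_shape: Sequence[int],
                placements: Sequence[Placement]) -> List[int]:
    """Hypothetical helper: per-process shape under the given placements."""
    # Assumption of this sketch: exactly one placement per mesh dimension.
    if len(placements) != len(mesh_shape):
        raise ValueError("need exactly one placement per mesh dimension")
    shape = list(global_shape)
    for mesh_size, p in zip(mesh_shape, placements):
        if isinstance(p, Shard):
            # Only even splits are modeled; uneven sharding is out of scope.
            if shape[p.dim] % mesh_size != 0:
                raise ValueError("sketch only handles even splits")
            shape[p.dim] //= mesh_size
        # Replicate and Partial keep the local shape equal to the global one.
    return shape


print(local_shape([1], [2], [Replicate()]))               # [1]
print(local_shape([4, 6], [2], [Shard(0)]))               # [2, 6]
print(local_shape([4, 6], [2, 3], [Shard(0), Shard(1)]))  # [2, 2]
```

Under this model, the Examples section's call with [dist.Replicate()] on the two-process mesh leaves each process holding a tensor of shape [1], while a Shard placement would divide the sharded dimension by the mesh size.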
- Returns
Tensor: A Tensor constructed by fn, with the distributed attributes given by mesh and placements.
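The return value can be read as the output of fn with distributed attributes attached. The pure-Python sketch below models that reading, assuming (this page does not confirm the internals) that fn is called once with the forwarded *args and **kwargs and the result is then tagged with mesh and placements. DistTensorSketch, dtensor_from_fn_sketch, and the ones helper are hypothetical names; placements are plain strings rather than Placement objects, and no data is actually distributed.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Sequence


@dataclass
class DistTensorSketch:
    """Hypothetical stand-in for a distributed tensor: the value built by
    `fn` plus the mesh and placements attached to it."""
    value: Any
    mesh: List[int]        # process ranks, e.g. [0, 1]
    placements: List[str]  # e.g. ["Replicate"]; strings, not Placement objects


def dtensor_from_fn_sketch(fn: Callable[..., Any], mesh: Sequence[int],
                           placements: Sequence[str], *args, **kwargs):
    # Hypothetical model: positional and keyword arguments are forwarded to
    # `fn` unchanged; only `mesh` and `placements` are consumed here.
    value = fn(*args, **kwargs)
    # Attach the distributed attributes to the value that `fn` produced.
    return DistTensorSketch(value=value, mesh=list(mesh),
                            placements=list(placements))


def ones(shape):
    """Pure-Python stand-in for paddle.ones: nested lists filled with 1.0."""
    if len(shape) == 1:
        return [1.0] * shape[0]
    return [ones(shape[1:]) for _ in range(shape[0])]


d = dtensor_from_fn_sketch(ones, [0, 1], ["Replicate"], shape=[2, 3])
print(d.value)       # [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
print(d.placements)  # ['Replicate']
```

Because shape=[2, 3] is passed as a keyword argument, it reaches ones unchanged, mirroring how shape=[1] reaches paddle.ones in the Examples section.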
Examples
>>> import paddle
>>> import paddle.distributed as dist
>>> # Create a one-dimensional process mesh over ranks 0 and 1
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> # Build a tensor of ones via paddle.ones(shape=[1]), replicated on every rank of the mesh
>>> d_tensor = dist.dtensor_from_fn(paddle.ones, mesh, [dist.Replicate()], shape=[1])
>>> print(d_tensor)