reshard

paddle.distributed. reshard ( dist_tensor, mesh, placements ) [源代码]

根据新的分布式信息 placements ,对一个带有分布式信息的 Tensor 进行 reshard 操作,重新进行 Tensor 的分布/切片,返回一个新的分布式 Tensor 。

dist_tensor 需要是一个具有分布式信息的 paddle.Tensor。

参数

  • dist_tensor (Tensor) - 具有分布式信息的 Tensor ,为 paddle.Tensor 类型。

  • mesh (paddle.distributed.ProcessMesh) - 表示进程拓扑信息的 ProcessMesh 对象。

  • placements (list(Placement)) - 分布式 Tensor 的切分表示列表,描述 Tensor 在 mesh 上如何切分。

返回

将输入的 dist_tensor 按照新的方式进行分布/切分的分布式 Tensor。

代码示例

>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

>>> # dense tensor
>>> a = paddle.ones([10, 20])

>>> # distributed tensor
>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Partial()])

>>> out_d_tensor = dist.reshard(d_tensor, mesh, [dist.Replicate()])

>>> print(out_d_tensor)