DistAttr

class paddle.distributed. DistAttr ( mesh, sharding_specs ) [源代码]

DistAttr 指定 Tensor 在 ProcessMesh 上的分布或切片方式。

参数

  • mesh (paddle.distributed.ProcessMesh) - 表示进程拓扑信息的 ProcessMesh 对象。

  • sharding_specs (list[str|None]) - 描述 Tensor 的切分规则。

代码示例

>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
>>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])

>>> print(dist_attr)