shard_tensor

paddle.distributed. shard_tensor ( data, mesh, placements, dtype=None, place=None, stop_gradient=True ) [源代码]

通过已知的 data 来创建一个带有分布式信息的 Tensor,Tensor 类型为 paddle.Tensordata 可以是 scalar,tuple,list,numpy.ndarray,paddle.Tensor。

如果 data 已经是一个 Tensor,将其转换为一个分布式 Tensor。

参数

  • data (scalar|tuple|list|ndarray|Tensor) - 初始化 Tensor 的数据,可以是 scalar,list,tuple,numpy.ndarray,paddle.Tensor 类型。

  • mesh (paddle.distributed.ProcessMesh) - 表示进程拓扑信息的 ProcessMesh 对象。

  • placements (list(Placement)) - 分布式 Tensor 的切分表示列表,描述 Tensor 在 mesh 上如何切分。

  • dtype (str,可选) - 创建 Tensor 的数据类型,可以是 bool、float16、float32、float64、int8、int16、int32、int64、uint8、complex64、complex128。 默认值为 None,如果 data 为 python 浮点类型,则从 get_default_dtype 获取类型,如果 data 为其他类型,则会自动推导类型。

  • place (CPUPlace|CUDAPinnedPlace|CUDAPlace,可选) - 创建 tensor 的设备位置,可以是 CPUPlace、CUDAPinnedPlace、CUDAPlace。默认值为 None,使用全局的 place。

  • stop_gradient (bool,可选) - 是否阻断 Autograd 的梯度传导。默认值为 True,此时不进行梯度传传导。

返回

通过 data 创建的带有分布式信息的 Tensor。

代码示例

>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])

>>> # dense tensor
>>> a = paddle.to_tensor([[1,2,3],
...                       [5,6,7]])

>>> # distributed tensor
>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Shard(0), dist.Shard(1)])

>>> print(d_tensor)