unshard_dtensor

unshard_dtensor ( dist_tensor )

将带有分布式信息的分布式 Tensor 转换为普通 Tensor。

参数

  • dist_tensor (paddle.Tensor) - 带有分布式信息的分布式 Tensor。

返回

paddle.Tensor: 不带分布式信息的普通 Tensor,包含 dist_tensor 的全局数据。

代码示例

>>> import paddle
>>> import paddle.distributed as dist
>>> from paddle.distributed import Replicate, Shard

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> original_tensor = paddle.rand([4, 1024, 512])
>>> dist_tensor = dist.shard_tensor(original_tensor, mesh, [Shard(0)])
>>> # dense_tensor's shape is the same as original_tensor
>>> dense_tensor = dist.unshard_dtensor(dist_tensor)