shard_optimizer¶
将单卡视角的优化器转变为分布式视角。可以通过指定 shard_fn 来定制化优化器状态的切分方式,否则会将参数的分布式信息传递给对应的优化器状态。
shard_fn 的函数签名为:def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator。
参数¶
optimizer (paddle.optimizer.Optimizer) - 单卡视角的优化器。
shard_fn (Callable,可选) - 用于切分优化器状态函数。如果没有指定,默认地我们将参数的分布式信息传递给对应的优化器状态。
返回¶
Optimizer:一个具有分布式视角的 Optimizer 对象。
代码示例¶
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt)
>>> for _ in range(5):
>>> loss = layer(batch)
>>> loss.backward()
>>> opt.step()
>>> opt.clear_grad()
>>> # This case need to be executed in multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py