shard_optimizer

paddle.distributed. shard_optimizer ( optimizer, shard_fn=None ) [源代码]

将单卡视角的优化器转变为分布式视角。可以通过指定 shard_fn 来定制化优化器状态的切分方式,否则会将参数的分布式信息传递给对应的优化器状态。

shard_fn 的函数签名为:def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator。

参数

  • optimizer (paddle.optimizer.Optimizer) - 单卡视角的优化器。

  • shard_fn (Callable,可选) - 用于切分优化器状态函数。如果没有指定,默认地我们将参数的分布式信息传递给对应的优化器状态。

返回

Optimizer:一个具有分布式视角的 Optimizer 对象。

代码示例

>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt)
>>> for _ in range(5):
>>>     loss = layer(batch)
>>>     loss.backward()
>>>     opt.step()
>>>     opt.clear_grad()
>>> # This case need to be executed in multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py