PrepareLayerOutput

class paddle.distributed. PrepareLayerOutput ( fn=None ) [源代码]

使用用户提供的函数,对标记 Layer 的输出进行处理。

参数

  • fn (callable,可选) - 用来处理标记 Layer 输出的函数,该函数需要接受并且仅接受一个参数 process_mesh ,并返回真正用来处理输出的函数。默认为 None。

代码示例

>>> import paddle
>>> import paddle.distributed as dist

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> def layer_output_hook(process_mesh):
...     def hook(layer, input, output):
...         return output
...     return hook

>>> layer = MLP()
>>> mp_config = {
...     'fc1': dist.PrepareLayerOutput(layer_output_hook)
... }