PrepareLayerOutput¶
使用用户提供的函数,对标记 Layer 的输出进行处理。
参数¶
fn (callable,可选) - 用来处理标记 Layer 输出的函数,该函数需要接受并且仅接受一个参数 process_mesh ,并返回真正用来处理输出的函数。默认为 None。
代码示例
>>> import paddle
>>> import paddle.distributed as dist
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> def layer_output_hook(process_mesh):
... def hook(layer, input, output):
... return output
... return hook
>>> layer = MLP()
>>> mp_config = {
... 'fc1': dist.PrepareLayerOutput(layer_output_hook)
... }