shard_layer

paddle.distributed. shard_layer ( layer, process_mesh, shard_fn=None, input_fn=None, output_fn=None )

根据参数 shard_fn 将传入的 paddle.nn.Layer 所有的参数转换为带有分布式切分信息的 Tensor。同时也支持指定 input_fnoutput_fn 用于控制输入和输出 Tensor 的转换。(具体指的是,将输入转换为带有分布式切分信息的 Tensor,将输出转回不带分布式切分信息的 Tensor。)

shard_fn 的函数签名为:def shard_fn(layer_name, layer, process_mesh) -> None。

input_fn 的函数签名为:def input_fn(inputs, process_mesh) -> list(paddle.Tensor),一般地,input_fn 返回值的类型为带有分布式切分信息的 Tensor

output_fn 的函数签名为:def output_fn(outputs, process_mesh) -> list(paddle.Tensor),一般地,output_fn 返回值的类型为不带分布式切分信息的 Tensor

参数

  • layer (paddle.nn.Layer) - 需要被切分的 Layer 对象。

  • process_mesh (paddle.distributed.ProcessMesh) - 执行当前 LayerProcessMesh 信息。

  • shard_fn (Callable) - 用于切分当前 Layer 参数的函数。如果没有指定,默认地我们将在当前 ProcessMesh 上复制所有的参数。

  • input_fn (Callable) - 指定如何切分 Layer 的输入。input_fn 函数将被注册为 Layer 的一个 forward pre-hook。默认我们将不会切分 Layer 的输入。

  • output_fn (Callable) - 指定如何切分 Layer 的输出,或者将 Layer 的输出转回不带分布式切分信息的 Tensoroutput_fn 函数将被注册为 Layer 的一个 forward post-hook。默认我们将不会切分或者转换 Layer 的输出。

返回

Layer:一个参数全部为带有分布式切分信息 TensorLayer 对象。

代码示例

COPY-FROM: paddle.distributed.shard_layer