SplitPoint

class paddle.distributed. SplitPoint [源代码]

用于流水线并行下切分位置的确认。目前支持 BEGINNINGEND

BEGINNING 表明在标识的 Layer 之前进行切分。

END 表明在标识的 Layer 之后进行切分。

代码示例

>>> import paddle
>>> import paddle.distributed as dist

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> layer = MLP()
>>> pp_config = {
...     'fc1': dist.SplitPoint.END
... }