SplitPoint¶
用于流水线并行下切分位置的确认。目前支持 BEGINNING
和 END
。
BEGINNING
表明在标识的 Layer 之前进行切分。
END
表明在标识的 Layer 之后进行切分。
代码示例
>>> import paddle
>>> import paddle.distributed as dist
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> pp_config = {
... 'fc1': dist.SplitPoint.END
... }