to_distributed

paddle.distributed.to_distributed(model, optimizer, dataloader, device_num, node_num=1, config=None) [源代码]

能够自动地将没有包含任何分布式代码的神经网络、优化器和数据加载器转化为适合分布式运行的神经网络、优化器和数据加载器,并确保正确性;同时,转化过程中会根据机器数和每台机器的设备数自动选择最优的分布式策略,以尽可能提升性能。

注解

此接口处于原型试用阶段,支持部分模型结构在单机多卡运行,后续会扩大支持的模型范围及支持多机多卡运行。

参数

  • model (paddle.nn.Layer) - 单卡视角的模型,没有包含任何分布式代码。

  • optimizer (paddle.optimizer.Optimizer) - 单卡视角的优化器,通过常规优化器 API 构造,如 paddle.optimizer.Adam

  • dataloader (paddle.io.DataLoader) - 单卡视角的数据加载器,通过常规方式构造,如使用 paddle.io.Dataset、paddle.io.Sampler,无需使用 paddle.io.DistributedBatchSampler。

  • device_num (int) - 每台机器的设备数。

  • node_num (int,可选) - 机器数,默认值为 1。当前接口仅支持单机多卡运行。

  • config (ToDistributedConfig,可选) - 用于配置输入数据信息和是否使用序列并行,默认值为 None。配置时使用数据类 paddle.distributed.auto_parallel.high_level_api.ToDistributedConfig 来完成。

    配置输入数据信息,即提供模型训练时最可能输入的数据的 shape、dtype 和 stop_gradient 信息,便于更快、更准确地自动选择最优的分布式策略。

    配置是否使用序列并行,即指定当最优的分布式策略中包含模型并行时,是否要使用序列并行。

返回

Model:一个具有分布式信息的 paddle.nn.Layer 对象,根据自动选择的最优分布式策略,可能包含分布式化的权重参数。

Optimizer:一个 Optimizer 对象,根据自动选择的最优分布式策略,可能包含分布式化的优化器状态。

DataLoader:一个 ShardDataloader 对象,能够为后续的分布式训练提供输入数据。

代码示例

>>> import math
>>> import numpy as np
>>> import paddle
>>> import paddle.nn.functional as F
>>> from paddle import nn
>>> from paddle.distributed import to_distributed
>>> from paddle.distributed.auto_parallel.high_level_api import ToDistributedConfig

>>> # Hyper-parameters of the toy decoder model and the synthetic dataset below.
>>> EPOCHS = 1
>>> VOCAB_SIZE = 8000
>>> BATCH_NUM = 2  # number of batches in the synthetic dataset
>>> BATCH_SIZE = 4
>>> HIDDEN_SIZE = 2048
>>> INTERMEDIATE_SIZE = 4096  # inner width of the gated MLP
>>> SEQ_LENGTH = 1024
>>> N_HEAD = 32  # attention heads; HIDDEN_SIZE must be divisible by N_HEAD
>>> NUM_HIDDEN_LAYERS = 4
>>> class RandomDataset(paddle.io.Dataset):
...     def __init__(self, inputs, labels, num_samples):
...         self.inputs = inputs
...         self.labels = labels
...         self.num_samples = num_samples
...     def __getitem__(self, idx):
...         return self.inputs[idx], self.labels[idx]
...     def __len__(self):
...         return self.num_samples

>>> class RotaryEmbedding(nn.Layer):
...     def __init__(self, dim, max_position_embeddings=2048, base=10000):
...         super().__init__()
...         self.dim = dim
...         self.max_position_embeddings = max_position_embeddings
...         self.base = base
...         self.inv_freq = 1.0 / (
...             self.base ** (
...                 paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32")
...                 / self.dim
...             )
...         )
...         self._set_cos_sin_cache(seq_len=max_position_embeddings)

...     def _set_cos_sin_cache(self, seq_len):
...         self.max_seq_len_cached = seq_len
...         t = paddle.arange(seq_len, dtype="float32")
...         freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
...         emb = paddle.concat([freqs, freqs], axis=-1)
...         self.cos_cached = emb.cos()[None, :, None, :]
...         self.sin_cached = emb.sin()[None, :, None, :]

...     def forward(self, x, seq_len=None):
...         cos = self.cos_cached[:, :seq_len, :, :]
...         sin = self.sin_cached[:, :seq_len, :, :]
...         return (
...             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
...             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
...         )

>>> def rotate_half(x):
...     x1 = x[..., : x.shape[-1] // 2]
...     x2 = x[..., x.shape[-1] // 2 :]
...     return paddle.concat([-x2, x1], axis=-1)

>>> def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
...     if position_ids is None:
...         cos = cos[:, : q.shape[1], :, :]
...         sin = sin[:, : q.shape[1], :, :]
...     else:
...         cos = cos.squeeze(axis=[0, 2])
...         sin = sin.squeeze(axis=[0, 2])
...         cos = cos[position_ids].unsqueeze(2)
...         sin = sin[position_ids].unsqueeze(2)
...     q_embed = (q * cos) + (rotate_half(q) * sin)
...     k_embed = (k * cos) + (rotate_half(k) * sin)
...     return q_embed, k_embed

>>> def scaled_dot_product_attention(
...     query_states,
...     key_states,
...     value_states,
...     attention_mask,
... ):
...     # Computes softmax(Q K^T / sqrt(head_dim) + attention_mask) V.
...     # Inputs are unpacked as [batch, seq, num_heads, head_dim]; the output is
...     # returned as [batch, q_len, num_heads * head_dim].
...     bsz, q_len, num_heads, head_dim = query_states.shape
...     _, kv_seq_len, _, _ = value_states.shape
...     # Move the heads axis in front of the sequence axis.
...     query_states = paddle.transpose(query_states, [0, 2, 1, 3])
...     key_states = paddle.transpose(key_states, [0, 2, 1, 3])
...     value_states = paddle.transpose(value_states, [0, 2, 1, 3])
...     attn_weights = paddle.matmul(
...         query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])
...     )
...     # The mask is added to the scores, so it must already be additive
...     # (e.g. 0 for visible positions, a very negative value for blocked ones).
...     attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
...     attn_weights = attn_weights + attention_mask
...     # Softmax is computed in float32 and cast back to the query dtype; in
...     # dynamic mode AMP is additionally disabled around it.
...     if not paddle.in_dynamic_mode():
...         attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(
...             query_states.dtype
...         )
...     else:
...         with paddle.amp.auto_cast(False):
...             attn_weights = F.softmax(
...                 attn_weights, axis=-1, dtype="float32"
...             ).astype(query_states.dtype)
...     attn_output = paddle.matmul(attn_weights, value_states)
...     # Back to [batch, q_len, heads, head_dim], then merge the heads.
...     attn_output = attn_output.transpose([0, 2, 1, 3])
...     attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
...     return attn_output

>>> class Attention(nn.Layer):
...     def __init__(self, hidden_size=HIDDEN_SIZE, n_head=N_HEAD):
...         super().__init__()
...         self.hidden_size = hidden_size
...         self.num_heads = n_head
...         self.head_dim = hidden_size // n_head
...         self.q_proj = nn.Linear(
...             hidden_size, hidden_size, bias_attr=False
...         )
...         self.k_proj = nn.Linear(
...             hidden_size, hidden_size, bias_attr=False
...         )
...         self.v_proj = nn.Linear(
...             hidden_size, hidden_size, bias_attr=False
...         )
...         self.o_proj = nn.Linear(
...             hidden_size, hidden_size, bias_attr=False
...         )
...         self.rotary_emb = RotaryEmbedding(
...             self.head_dim, max_position_embeddings=SEQ_LENGTH, base=10000
...         )

...     def forward(
...         self,
...         hidden_states,
...         position_ids=None,
...         attention_mask=None,
...     ):
...         query_states = self.q_proj(hidden_states)
...         key_states = self.k_proj(hidden_states)
...         value_states = self.v_proj(hidden_states)
...         target_query_shape = [0, 0, self.num_heads, self.head_dim]
...         target_key_value_shape = [0, 0, self.num_heads, self.head_dim]
...         query_states = query_states.reshape(shape=target_query_shape)
...         key_states = key_states.reshape(shape=target_key_value_shape)
...         value_states = value_states.reshape(shape=target_key_value_shape)
...         kv_seq_len = key_states.shape[-3]
...         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
...         query_states, key_states = apply_rotary_pos_emb(
...             query_states, key_states, cos, sin, position_ids
...         )
...         output = scaled_dot_product_attention(
...             query_states,
...             key_states,
...             value_states,
...             attention_mask,
...         )
...         attn_output = output
...         attn_output = self.o_proj(attn_output)
...         return attn_output

>>> class Mlp(nn.Layer):
...     def __init__(
...         self,
...         hidden_size=HIDDEN_SIZE,
...         intermediate_size=INTERMEDIATE_SIZE,
...     ):
...         super().__init__()
...         self.hidden_size = hidden_size
...         self.intermediate_size = intermediate_size
...         self.gate_proj = nn.Linear(
...             hidden_size, intermediate_size, bias_attr=False
...         )
...         self.up_proj = nn.Linear(
...             hidden_size, intermediate_size, bias_attr=False
...         )
...         self.down_proj = nn.Linear(
...             intermediate_size, hidden_size, bias_attr=False
...         )

...     def forward(self, x):
...         x = paddle.incubate.nn.functional.swiglu(
...             self.gate_proj(x), self.up_proj(x)
...         )
...         out = self.down_proj(x)
...         return out

>>> class RMSNorm(nn.Layer):
...     def __init__(self, hidden_size=HIDDEN_SIZE):
...         super().__init__()
...         self.hidden_size = hidden_size
...         self.weight = paddle.create_parameter(
...             shape=[self.hidden_size],
...             dtype=paddle.get_default_dtype(),
...             default_initializer=nn.initializer.Constant(1.0),
...         )
...         self.variance_epsilon = 1.0

...     def forward(self, hidden_states):
...         with paddle.amp.auto_cast(False):
...             hidden_states = hidden_states.astype("float32")
...             variance = hidden_states.pow(2).mean(-1, keepdim=True)
...             hidden_states = (
...                 paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
...             )
...         if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
...             hidden_states = paddle.cast(hidden_states, self.weight.dtype)
...         return hidden_states * self.weight

>>> class DecoderLayer(nn.Layer):
...     def __init__(
...         self,
...         hidden_size=HIDDEN_SIZE,
...         intermediate_size=INTERMEDIATE_SIZE,
...     ):
...         super().__init__()
...         self.hidden_size = hidden_size
...         self.intermediate_size = intermediate_size
...         self.self_attn = Attention(hidden_size)
...         self.mlp = Mlp()
...         self.input_layernorm = RMSNorm(hidden_size)
...         self.post_attn_layernorm = RMSNorm(hidden_size)

...     def forward(
...         self,
...         hidden_states,
...         position_ids=None,
...         attention_mask=None,
...     ):
...         residual = hidden_states
...         hidden_states = self.input_layernorm(hidden_states)
...         hidden_states = self.self_attn(
...             hidden_states, position_ids, attention_mask
...         )
...         hidden_states = residual + hidden_states
...         residual = hidden_states
...         hidden_states = self.post_attn_layernorm(hidden_states)
...         hidden_states = self.mlp(hidden_states)
...         hidden_states = residual + hidden_states
...         return hidden_states

>>> def _prepare_decoder_attention_mask(
...     attention_mask, input_shape, dtype
... ):
...     batch_size, src_length = attention_mask.shape[0], attention_mask.shape[-1]
...     batch_size, target_length = input_shape
...     attention_mask = attention_mask[:, None, None, :].astype("bool")
...     attention_mask.stop_gradient = True
...     expanded_attn_mask = attention_mask.expand([batch_size, 1, target_length, src_length])
...     mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
...     combined_attention_mask = mask[None, None, :, :].expand(
...         [batch_size, 1, target_length, target_length]
...     )
...     expanded_attn_mask = (expanded_attn_mask & combined_attention_mask)
...     expanded_attn_mask = paddle.where(
...         expanded_attn_mask, 0.0, paddle.finfo(dtype).min
...     ).astype(dtype)
...     return expanded_attn_mask

>>> class Model(nn.Layer):
...     def __init__(
...         self,
...         vocab_size=VOCAB_SIZE,
...         hidden_size=HIDDEN_SIZE,
...         intermediate_size=INTERMEDIATE_SIZE,
...     ):
...         super().__init__()
...         self.vocab_size = vocab_size
...         self.hidden_size = hidden_size
...         self.intermediate_size = intermediate_size
...         self.embed_tokens = nn.Embedding(
...             vocab_size,
...             hidden_size,
...         )
...         self.layers = nn.LayerList(
...             [
...                 DecoderLayer()
...                 for i in range(NUM_HIDDEN_LAYERS)
...             ]
...         )
...         self.norm = RMSNorm(hidden_size)
...         self.weight = self.create_parameter(
...             shape=[hidden_size, vocab_size],
...             dtype=paddle.get_default_dtype(),
...         )
...         self.ignore_index = -100
...         self.loss_func = paddle.nn.CrossEntropyLoss(
...             reduction="none", ignore_index=self.ignore_index
...         )

...     def forward(
...         self,
...         input_ids=None,
...         position_ids=None,
...         attention_mask=None,
...         labels=None,
...     ):
...         batch_size, seq_length = input_ids.shape
...         inputs_embeds = self.embed_tokens(input_ids)
...         attention_mask = paddle.ones(
...             (batch_size, seq_length), dtype=paddle.bool
...         )
...         if position_ids is None:
...             position_ids = paddle.arange(seq_length, dtype="int64").expand(
...                 (batch_size, seq_length)
...             )
...         attention_mask = _prepare_decoder_attention_mask(
...             attention_mask,
...             (batch_size, seq_length),
...             inputs_embeds.dtype,
...         )
...         hidden_states = inputs_embeds
...         for idx, (decoder_layer) in enumerate(self.layers):
...             layer_outputs = decoder_layer(
...                 hidden_states,
...                 position_ids,
...                 attention_mask,
...             )
...             hidden_states = layer_outputs
...         hidden_states = self.norm(hidden_states)
...         logits = paddle.matmul(hidden_states, self.weight)
...         loss = None
...         if labels is not None:
...             masked_lm_loss = self.loss_func(
...                 logits.astype("float32"),
...                 labels.unsqueeze(2),
...             )
...             binary_sequence = paddle.where(
...                 masked_lm_loss > 0,
...                 paddle.ones_like(masked_lm_loss),
...                 paddle.zeros_like(masked_lm_loss),
...             )
...             count = paddle.sum(binary_sequence)
...             if count == 0:
...                 loss = paddle.sum(masked_lm_loss * binary_sequence)
...             else:
...                 loss = paddle.sum(masked_lm_loss * binary_sequence) / count
...         return (loss, logits)

>>> model = Model() # There is no distributed code or markup in Model
>>> # Synthetic data: BATCH_SIZE * BATCH_NUM sequences of SEQ_LENGTH token ids
>>> # (and labels) drawn from [0, 1024).
>>> input_seqs = np.random.randint(
...     low=0, high=1024, size=(BATCH_SIZE * BATCH_NUM, SEQ_LENGTH)
... ).astype("int64")
>>> labels = np.random.randint(
...     low=0, high=1024, size=(BATCH_SIZE * BATCH_NUM, SEQ_LENGTH)
... ).astype("int64")
>>> dataset = RandomDataset(
...     input_seqs, labels, BATCH_SIZE * BATCH_NUM
... )
>>> # Plain single-card sampler/loader; no DistributedBatchSampler is needed.
>>> sampler = paddle.io.BatchSampler(
...     dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True
... )
>>> loader = paddle.io.DataLoader(
...     dataset, batch_sampler=sampler
... )
>>> opt = paddle.optimizer.SGD(
...     learning_rate=0.1, parameters=model.parameters()
... )
>>> # NOTE(review): input_seq_spec is built but never attached to dist_config
>>> # below, so it has no effect as written; its 'float32' dtype also differs
>>> # from the int64 token ids in input_seqs. Confirm the intended
>>> # ToDistributedConfig field and dtype before relying on it.
>>> input_seq_spec = paddle.static.InputSpec(
...     [BATCH_SIZE, SEQ_LENGTH], 'float32', 'input_seq', True
... )
>>> dist_config = ToDistributedConfig()
>>> # Use sequence parallelism if the chosen strategy includes model parallelism.
>>> dist_config.sequence_parallel = True

>>> # Wrap the model, optimizer and dataloader with to_distributed: device_num is
>>> # the number of devices per machine, node_num the number of machines.
>>> dist_model, dist_opt, dist_loader = to_distributed(
...     model,
...     opt,
...     loader,
...     device_num=8,
...     node_num=1,
...     config=dist_config,
... )

>>> for epoch in range(EPOCHS):
...     dist_model.train()
...     for i, data in enumerate(dist_loader()):
...         inputs, labels = data
...         loss, _ = dist_model(inputs, labels=labels)
...         print(f"epoch {epoch}, step {i}: loss {loss}")
...         loss.backward()
...         dist_opt.step()
...         dist_opt.clear_grad()
>>> # This case need to be executed in multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 {test_case}.py