\u200E
  • 开始使用
  • 特性
  • 文档
    • API
    • 使用指南
  • 工具平台
    • 工具
      • AutoDL
  • develop
  • 2.1
  • 2.0
  • 1.8
  • 1.7
  • 1.6
  • 1.5
  • 1.4
  • 1.3
  • 1.2
  • 1.1
  • 1.0
  • 0.15.0
  • 0.14.0
  • 0.13.0
  • 0.12.0
  • 0.11.0
  • 0.10.0
  • 中文(简)
  • English(En)
  • 安装指南
    • Pip安装
      • Linux下的PIP安装
      • MacOS下的PIP安装
      • Windows下的PIP安装
    • Conda安装
      • Linux下的Conda安装
      • MacOS下的Conda安装
      • Windows下的Conda安装
    • Docker安装
      • Linux下的Docker安装
      • MacOS下的Docker安装
    • 从源码编译
      • Linux下从源码编译
      • MacOS下从源码编译
      • Windows下从源码编译
      • 飞腾/鲲鹏下从源码编译
      • 申威下从源码编译
      • 兆芯下从源码编译
    • 昆仑XPU芯片安装及运行飞桨
    • 海光DCU芯片运行飞桨
    • 附录
  • 使用教程
    • 整体介绍
      • 基本概念
        • Tensor概念介绍
        • 广播 (broadcasting)
        • 自动微分机制介绍
        • 自动混合精度训练
      • 升级指南
      • 版本迁移工具
    • 模型开发
      • 10分钟快速上手飞桨(PaddlePaddle)
      • 数据集定义与加载
      • 数据预处理
      • 模型组网
      • 训练与预测
      • 单机多卡训练
      • 自定义指标
      • 模型保存与载入
      • 模型导出ONNX协议
    • VisualDL 工具
      • VisualDL 工具简介
      • VisualDL 使用指南
    • 动态图转静态图
      • 基本用法
      • 内部架构原理
      • 支持语法列表
      • InputSpec 功能介绍
      • 报错信息处理
      • 调试方法
    • 推理部署
      • 服务器部署 — Paddle Inference
      • 移动端/嵌入式部署 — Paddle Lite
      • 模型压缩 — PaddleSlim
    • 分布式训练
      • 分布式训练快速开始
      • 使用FleetAPI进行分布式训练
    • 自定义算子
      • 自定义原生算子
      • 原生算子开发注意事项
      • 自定义外部算子
      • 自定义Python算子
    • 算子映射
      • Paddle 1.8 与 Paddle 2.0 API映射表
      • PyTorch-PaddlePaddle API映射表
    • 硬件支持
      • 飞桨产品硬件支持表
      • 昆仑XPU芯片运行飞桨
        • 飞桨对昆仑XPU芯片的支持
        • 飞桨框架昆仑XPU版安装说明
        • 飞桨框架昆仑XPU版训练示例
        • 飞桨预测库昆仑XPU版安装及使用示例
      • 海光DCU芯片运行飞桨
        • 飞桨框架ROCm版支持模型
        • 飞桨框架ROCm版安装说明
        • 飞桨框架ROCm版训练示例
        • 飞桨框架ROCm版预测示例
      • 昇腾NPU芯片运行飞桨
        • 飞桨框架NPU版安装说明
        • 飞桨框架昇腾NPU版训练示例
    • 参与开发
      • 本地开发指南
      • 提交PR注意事项
      • FAQ
  • 应用实践
    • 快速上手
      • hello paddle: 从普通程序走向机器学习程序
      • 动态图
      • 飞桨高层API使用指南
      • 模型保存及加载
      • 使用线性回归预测波士顿房价
    • 计算机视觉
      • 使用LeNet在MNIST数据集实现图像分类
      • 使用卷积神经网络进行图像分类
      • 基于图片相似度的图片搜索
      • 基于U-Net卷积神经网络实现宠物图像分割
      • 通过OCR实现验证码识别
      • 通过Sub-Pixel实现图像超分辨率
    • 自然语言处理
      • 用N-Gram模型在莎士比亚文集中训练word embedding
      • IMDB 数据集使用BOW网络的文本分类
      • 使用注意力机制的LSTM的机器翻译
      • 使用序列到序列模型完成数字加法
    • 时序数据
      • 通过AutoEncoder实现时序数据异常检测
    • 强化学习
      • 强化学习——Actor Critic Method
      • 强化学习——Advantage Actor-Critic(A2C)
      • 强化学习——Deep Deterministic Policy Gradient (DDPG)
    • 推荐算法
      • 使用协同过滤实现电影推荐
  • API 文档
    • paddle
      • Overview
      • abs
      • acos
      • add
      • add_n
      • addmm
      • all
      • allclose
      • any
      • arange
      • argmax
      • argmin
      • argsort
      • asin
      • assign
      • atan
      • bernoulli
      • bmm
      • broadcast_shape
      • broadcast_to
      • cast
      • ceil
      • cholesky
      • chunk
      • clip
      • concat
      • conj
      • cos
      • cosh
      • CPUPlace
      • create_parameter
      • crop
      • cross
      • CUDAPinnedPlace
      • CUDAPlace
      • cumsum
      • DataParallel
      • diag
      • disable_static
      • dist
      • divide
      • dot
      • empty
      • empty_like
      • enable_static
      • equal
      • equal_all
      • erf
      • exp
      • expand
      • expand_as
      • eye
      • flatten
      • flip
      • floor
      • floor_divide
      • floor_mod
      • flops
      • full
      • full_like
      • gather
      • gather_nd
      • get_cuda_rng_state
      • get_default_dtype
      • grad
      • greater_equal
      • greater_than
      • histogram
      • imag
      • in_dynamic_mode
      • increment
      • index_sample
      • index_select
      • inverse
      • is_empty
      • is_tensor
      • isfinite
      • isinf
      • isnan
      • kron
      • less_equal
      • less_than
      • linspace
      • load
      • log
      • log10
      • log1p
      • log2
      • logical_and
      • logical_not
      • logical_or
      • logical_xor
      • logsumexp
      • masked_select
      • matmul
      • max
      • maximum
      • mean
      • median
      • meshgrid
      • min
      • minimum
      • mm
      • Model
      • multinomial
      • multiplex
      • multiply
      • mv
      • no_grad
      • nonzero
      • norm
      • normal
      • not_equal
      • NPUPlace
      • numel
      • ones
      • ones_like
      • ParamAttr
      • pow
      • prod
      • rand
      • randint
      • randn
      • randperm
      • rank
      • real
      • reciprocal
      • reshape
      • reshape_
      • roll
      • round
      • rsqrt
      • save
      • scale
      • scatter
      • scatter_
      • scatter_nd
      • scatter_nd_add
      • seed
      • set_cuda_rng_state
      • set_default_dtype
      • set_device
      • set_printoptions
      • shape
      • shard_index
      • sign
      • sin
      • sinh
      • slice
      • sort
      • split
      • sqrt
      • square
      • squeeze
      • squeeze_
      • stack
      • standard_normal
      • stanh
      • std
      • strided_slice
      • subtract
      • sum
      • summary
      • t
      • tan
      • tanh
      • tanh_
      • Tensor
      • tile
      • to_tensor
      • tolist
      • topk
      • trace
      • transpose
      • tril
      • triu
      • unbind
      • uniform
      • unique
      • unsqueeze
      • unsqueeze_
      • unstack
      • var
      • where
      • zeros
      • zeros_like
    • paddle.amp
      • Overview
      • auto_cast
      • GradScaler
    • paddle.autograd
      • backward
      • PyLayer
      • PyLayerContext
    • paddle.callbacks
      • Overview
      • Callback
      • EarlyStopping
      • LRScheduler
      • ModelCheckpoint
      • ProgBarLogger
      • ReduceLROnPlateau
      • VisualDL
    • paddle.compat
      • floor_division
      • get_exception_message
      • long_type
      • round
      • to_bytes
      • to_text
    • paddle.device
      • get_cudnn_version
      • get_device
      • is_compiled_with_cuda
      • is_compiled_with_npu
      • is_compiled_with_rocm
      • is_compiled_with_xpu
      • XPUPlace
    • paddle.distributed
      • Overview
      • all_gather
      • all_reduce
      • barrier
      • broadcast
      • fleet
        • DistributedStrategy
        • Fleet
        • PaddleCloudRoleMaker
        • UserDefinedRoleMaker
        • UtilBase
        • utils
          • HDFSClient
          • LocalFS
      • get_rank
      • get_world_size
      • init_parallel_env
      • InMemoryDataset
      • ParallelEnv
      • QueueDataset
      • recv
      • reduce
      • ReduceOp
      • scatter
      • send
      • spawn
      • split
    • paddle.distribution
      • Overview
      • Categorical
      • Distribution
      • Normal
      • Uniform
    • paddle.fluid
      • clip
        • ErrorClipByValue
        • set_gradient_clip
      • create_lod_tensor
      • create_random_int_lodtensor
      • cuda_pinned_places
      • data
      • DataFeedDesc
      • DataFeeder
      • dataset
        • DatasetFactory
        • InMemoryDataset
        • QueueDataset
      • dygraph
        • BilinearTensorProduct
        • Conv2D
        • Conv2DTranspose
        • Conv3D
        • Conv3DTranspose
        • Dropout
        • Embedding
        • enabled
        • GroupNorm
        • GRUCell
        • GRUUnit
        • LambdaDecay
        • LayerNorm
        • Linear
        • load_dygraph
        • LSTMCell
        • MultiStepDecay
        • NCE
        • Pool2D
        • PRelu
        • prepare_context
        • ReduceLROnPlateau
        • save_dygraph
        • StepDecay
        • TreeConv
      • evaluator
        • ChunkEvaluator
        • DetectionMAP
        • EditDistance
      • get_flags
      • initializer
        • Constant
        • ConstantInitializer
        • MSRA
        • MSRAInitializer
        • Normal
        • NumpyArrayInitializer
        • TruncatedNormal
        • Uniform
        • Xavier
      • io
        • get_program_parameter
        • get_program_persistable_vars
        • load_params
        • load_persistables
        • load_vars
        • PyReader
        • save_params
        • save_persistables
        • save_vars
        • shuffle
      • layers
        • adaptive_pool2d
        • adaptive_pool3d
        • add_position_encoding
        • affine_channel
        • affine_grid
        • anchor_generator
        • argmax
        • argmin
        • argsort
        • array_length
        • array_read
        • array_write
        • autoincreased_step_counter
        • BasicDecoder
        • beam_search
        • beam_search_decode
        • bipartite_match
        • box_clip
        • box_coder
        • box_decoder_and_assign
        • bpr_loss
        • brelu
        • Categorical
        • center_loss
        • collect_fpn_proposals
        • concat
        • continuous_value_model
        • cosine_decay
        • create_array
        • create_py_reader_by_data
        • create_tensor
        • crop
        • cross_entropy
        • ctc_greedy_decoder
        • cumsum
        • data
        • DecodeHelper
        • Decoder
        • deformable_conv
        • deformable_roi_pooling
        • density_prior_box
        • detection_output
        • diag
        • distribute_fpn_proposals
        • double_buffer
        • dropout
        • dynamic_gru
        • dynamic_lstm
        • dynamic_lstmp
        • DynamicRNN
        • edit_distance
        • elementwise_add
        • elementwise_div
        • elementwise_floordiv
        • elementwise_max
        • elementwise_min
        • elementwise_mod
        • elementwise_pow
        • elementwise_sub
        • elu
        • embedding
        • equal
        • expand
        • expand_as
        • exponential_decay
        • eye
        • fc
        • fill_constant
        • filter_by_instag
        • flatten
        • fsp_matrix
        • gather
        • gather_nd
        • gaussian_random
        • gelu
        • generate_mask_labels
        • generate_proposal_labels
        • generate_proposals
        • get_tensor_from_selected_rows
        • greater_equal
        • greater_than
        • GreedyEmbeddingHelper
        • grid_sampler
        • gru_unit
        • GRUCell
        • hard_shrink
        • hard_sigmoid
        • hard_swish
        • hash
        • hsigmoid
        • huber_loss
        • IfElse
        • im2sequence
        • image_resize
        • image_resize_short
        • inplace_abn
        • inverse_time_decay
        • iou_similarity
        • isfinite
        • kldiv_loss
        • l2_normalize
        • label_smooth
        • leaky_relu
        • less_equal
        • less_than
        • linear_chain_crf
        • linear_lr_warmup
        • locality_aware_nms
        • lod_append
        • lod_reset
        • lrn
        • lstm
        • lstm_unit
        • LSTMCell
        • margin_rank_loss
        • matmul
        • matrix_nms
        • maxout
        • mean
        • merge_selected_rows
        • mse_loss
        • mul
        • multiclass_nms
        • MultivariateNormalDiag
        • natural_exp_decay
        • noam_decay
        • Normal
        • not_equal
        • one_hot
        • ones
        • ones_like
        • pad
        • pad2d
        • pad_constant_like
        • piecewise_decay
        • pixel_shuffle
        • polygon_box_transform
        • polynomial_decay
        • pool2d
        • pool3d
        • prior_box
        • prroi_pool
        • psroi_pool
        • py_reader
        • random_crop
        • range
        • rank_loss
        • read_file
        • reduce_all
        • reduce_any
        • reduce_max
        • reduce_mean
        • reduce_min
        • reduce_prod
        • reduce_sum
        • relu
        • relu6
        • reorder_lod_tensor_by_rank
        • reshape
        • resize_bilinear
        • resize_nearest
        • resize_trilinear
        • retinanet_detection_output
        • retinanet_target_assign
        • reverse
        • rnn
        • RNNCell
        • roi_align
        • roi_perspective_transform
        • roi_pool
        • rpn_target_assign
        • SampleEmbeddingHelper
        • sampling_id
        • scatter
        • selu
        • shuffle_channel
        • sigmoid_cross_entropy_with_logits
        • sigmoid_focal_loss
        • sign
        • similarity_focus
        • size
        • smooth_l1
        • soft_relu
        • softmax
        • softshrink
        • space_to_depth
        • split
        • squeeze
        • ssd_loss
        • stack
        • StaticRNN
        • strided_slice
        • sum
        • sums
        • swish
        • Switch
        • target_assign
        • teacher_student_sigmoid_loss
        • tensor_array_to_tensor
        • thresholded_relu
        • topk
        • TrainingHelper
        • Uniform
        • uniform_random
        • unique
        • unique_with_counts
        • unsqueeze
        • warpctc
        • where
        • While
        • yolo_box
        • yolov3_loss
        • zeros
        • zeros_like
      • metrics
        • Accuracy
        • Auc
        • ChunkEvaluator
        • CompositeMetric
        • DetectionMAP
        • EditDistance
        • MetricBase
        • Precision
        • Recall
      • nets
        • glu
        • img_conv_group
        • scaled_dot_product_attention
        • sequence_conv_pool
        • simple_img_conv_pool
      • one_hot
      • optimizer
        • MomentumOptimizer
        • SGDOptimizer
      • reader
        • PyReader
      • regularizer
        • L1DecayRegularizer
        • L2DecayRegularizer
      • set_flags
      • transpiler
        • DistributeTranspiler
        • DistributeTranspilerConfig
        • HashName
        • memory_optimize
        • release_memory
    • paddle.hub
      • Overview
      • help
      • list
      • load
    • paddle.io
      • Overview
      • BatchSampler
      • ChainDataset
      • ComposeDataset
      • DataLoader
      • Dataset
      • DistributedBatchSampler
      • get_worker_info
      • IterableDataset
      • random_split
      • RandomSampler
      • Sampler
      • SequenceSampler
      • Subset
      • TensorDataset
      • WeightedRandomSampler
    • paddle.jit
      • Overview
      • load
      • ProgramTranslator
      • save
      • set_code_level
      • set_verbosity
      • to_static
      • TracedLayer
      • TranslatedLayer
    • paddle.metric
      • Overview
      • Accuracy
      • accuracy
      • Auc
      • Metric
      • Precision
      • Recall
    • paddle.nn
      • Overview
      • AdaptiveAvgPool1D
      • AdaptiveAvgPool2D
      • AdaptiveAvgPool3D
      • AdaptiveMaxPool1D
      • AdaptiveMaxPool2D
      • AdaptiveMaxPool3D
      • AlphaDropout
      • AvgPool1D
      • AvgPool2D
      • AvgPool3D
      • BatchNorm
      • BatchNorm1D
      • BatchNorm2D
      • BatchNorm3D
      • BCELoss
      • BCEWithLogitsLoss
      • BeamSearchDecoder
      • Bilinear
      • BiRNN
      • ClipGradByGlobalNorm
      • ClipGradByNorm
      • ClipGradByValue
      • Conv1D
      • Conv1DTranspose
      • Conv2D
      • Conv2DTranspose
      • Conv3D
      • Conv3DTranspose
      • CosineSimilarity
      • CrossEntropyLoss
      • CTCLoss
      • Dropout
      • Dropout2D
      • Dropout3D
      • dynamic_decode
      • ELU
      • Embedding
      • Flatten
      • functional
        • adaptive_avg_pool1d
        • adaptive_avg_pool2d
        • adaptive_avg_pool3d
        • adaptive_max_pool1d
        • adaptive_max_pool2d
        • adaptive_max_pool3d
        • affine_grid
        • alpha_dropout
        • avg_pool1d
        • avg_pool2d
        • avg_pool3d
        • batch_norm
        • bilinear
        • binary_cross_entropy
        • binary_cross_entropy_with_logits
        • conv1d
        • conv1d_transpose
        • conv2d
        • conv2d_transpose
        • conv3d
        • conv3d_transpose
        • cosine_similarity
        • cross_entropy
        • ctc_loss
        • diag_embed
        • dice_loss
        • dropout
        • dropout2d
        • dropout3d
        • elu
        • elu_
        • embedding
        • gather_tree
        • gelu
        • grid_sample
        • hardshrink
        • hardsigmoid
        • hardswish
        • hardtanh
        • hsigmoid_loss
        • instance_norm
        • interpolate
        • kl_div
        • l1_loss
        • label_smooth
        • layer_norm
        • leaky_relu
        • linear
        • local_response_norm
        • log_loss
        • log_sigmoid
        • log_softmax
        • margin_ranking_loss
        • max_pool1d
        • max_pool2d
        • max_pool3d
        • maxout
        • mse_loss
        • nll_loss
        • normalize
        • npair_loss
        • one_hot
        • pad
        • pixel_shuffle
        • prelu
        • relu
        • relu6
        • relu_
        • selu
        • sequence_mask
        • sigmoid
        • sigmoid_focal_loss
        • silu
        • smooth_l1_loss
        • softmax
        • softmax_
        • softmax_with_cross_entropy
        • softplus
        • softshrink
        • softsign
        • square_error_cost
        • swish
        • tanhshrink
        • temporal_shift
        • thresholded_relu
        • unfold
        • upsample
      • GELU
      • GroupNorm
      • GRU
      • GRUCell
      • Hardshrink
      • Hardsigmoid
      • Hardswish
      • Hardtanh
      • HSigmoidLoss
      • initializer
        • Assign
        • Bilinear
        • Constant
        • KaimingNormal
        • KaimingUniform
        • Normal
        • set_global_initializer
        • TruncatedNormal
        • Uniform
        • XavierNormal
        • XavierUniform
      • InstanceNorm1D
      • InstanceNorm2D
      • InstanceNorm3D
      • KLDivLoss
      • L1Loss
      • Layer
      • LayerDict
      • LayerList
      • LayerNorm
      • LeakyReLU
      • Linear
      • LocalResponseNorm
      • LogSigmoid
      • LogSoftmax
      • LSTM
      • LSTMCell
      • MarginRankingLoss
      • Maxout
      • MaxPool1D
      • MaxPool2D
      • MaxPool3D
      • MSELoss
      • MultiHeadAttention
      • NLLLoss
      • Pad1D
      • Pad2D
      • Pad3D
      • PairwiseDistance
      • ParameterList
      • PixelShuffle
      • PReLU
      • ReLU
      • ReLU6
      • RNN
      • RNNCellBase
      • SELU
      • Sequential
      • Sigmoid
      • Silu
      • SimpleRNN
      • SimpleRNNCell
      • SmoothL1Loss
      • Softmax
      • Softplus
      • Softshrink
      • Softsign
      • SpectralNorm
      • Swish
      • SyncBatchNorm
      • Tanh
      • Tanhshrink
      • ThresholdedReLU
      • Transformer
      • TransformerDecoder
      • TransformerDecoderLayer
      • TransformerEncoder
      • TransformerEncoderLayer
      • Unfold
      • Upsample
      • UpsamplingBilinear2D
      • UpsamplingNearest2D
      • utils
        • remove_weight_norm
        • spectral_norm
        • weight_norm
    • paddle.onnx
      • export
    • paddle.optimizer
      • Overview
      • Adadelta
      • Adagrad
      • Adam
      • Adamax
      • AdamW
      • Lamb
      • lr
        • CosineAnnealingDecay
        • ExponentialDecay
        • InverseTimeDecay
        • LambdaDecay
        • LinearWarmup
        • LRScheduler
        • MultiStepDecay
        • NaturalExpDecay
        • NoamDecay
        • PiecewiseDecay
        • PolynomialDecay
        • ReduceOnPlateau
        • StepDecay
      • Momentum
      • Optimizer
      • RMSProp
      • SGD
    • paddle.regularizer
      • L1Decay
      • L2Decay
    • paddle.static
      • Overview
      • accuracy
      • append_backward
      • auc
      • BuildStrategy
      • CompiledProgram
      • cpu_places
      • create_global_var
      • cuda_places
      • data
      • default_main_program
      • default_startup_program
      • deserialize_persistables
      • deserialize_program
      • device_guard
      • ExecutionStrategy
      • Executor
      • global_scope
      • gradients
      • InputSpec
      • load
      • load_from_file
      • load_inference_model
      • load_program_state
      • name_scope
      • nn
        • batch_norm
        • bilinear_tensor_product
        • case
        • conv2d
        • conv2d_transpose
        • conv3d
        • conv3d_transpose
        • crf_decoding
        • data_norm
        • deform_conv2d
        • embedding
        • fc
        • group_norm
        • instance_norm
        • layer_norm
        • multi_box_head
        • nce
        • prelu
        • row_conv
        • sequence_concat
        • sequence_conv
        • sequence_enumerate
        • sequence_expand
        • sequence_expand_as
        • sequence_first_step
        • sequence_last_step
        • sequence_pad
        • sequence_pool
        • sequence_reshape
        • sequence_reverse
        • sequence_scatter
        • sequence_slice
        • sequence_softmax
        • sequence_unpad
        • spectral_norm
        • switch_case
        • while_loop
      • normalize_program
      • ParallelExecutor
      • Print
      • Program
      • program_guard
      • py_func
      • save
      • save_inference_model
      • save_to_file
      • scope_guard
      • serialize_persistables
      • serialize_program
      • set_program_state
      • Variable
      • WeightNormParamAttr
      • xpu_places
    • paddle.sysconfig
      • get_include
      • get_lib
    • paddle.text
      • Overview
      • Conll05st
      • Imdb
      • Imikolov
      • Movielens
      • UCIHousing
      • WMT14
      • WMT16
    • paddle.utils
      • Overview
      • cpp_extension
        • CppExtension
        • CUDAExtension
        • get_build_directory
        • load
        • setup
      • deprecated
      • download
        • get_weights_path_from_url
      • profiler
        • cuda_profiler
        • profiler
        • reset_profiler
        • start_profiler
        • stop_profiler
      • require_version
      • run_check
      • unique_name
        • generate
        • guard
        • switch
    • paddle.vision
      • Overview
      • datasets
        • Cifar10
        • Cifar100
        • DatasetFolder
        • FashionMNIST
        • Flowers
        • ImageFolder
        • MNIST
        • VOC2012
      • get_image_backend
      • image_load
      • models
        • LeNet
        • mobilenet_v1
        • mobilenet_v2
        • MobileNetV1
        • MobileNetV2
        • ResNet
        • resnet101
        • resnet152
        • resnet18
        • resnet34
        • resnet50
        • VGG
        • vgg11
        • vgg13
        • vgg16
        • vgg19
      • ops
        • deform_conv2d
        • DeformConv2D
        • yolo_box
        • yolo_loss
      • set_image_backend
      • transforms
        • adjust_brightness
        • adjust_contrast
        • adjust_hue
        • BaseTransform
        • BrightnessTransform
        • center_crop
        • CenterCrop
        • ColorJitter
        • Compose
        • ContrastTransform
        • crop
        • Grayscale
        • hflip
        • HueTransform
        • normalize
        • Normalize
        • pad
        • Pad
        • RandomCrop
        • RandomHorizontalFlip
        • RandomResizedCrop
        • RandomRotation
        • RandomVerticalFlip
        • resize
        • Resize
        • rotate
        • SaturationTransform
        • to_grayscale
        • to_tensor
        • ToTensor
        • Transpose
        • vflip
  • 常见问题与解答
    • 2.0 升级常见问题
    • 安装常见问题
    • 数据及其加载常见问题
    • 组网、训练、评估常见问题
    • 模型保存常见问题
    • 参数调整常见问题
    • 分布式训练常见问题
    • 其他常见问题
  • Release Note
  • beam_search
  • »
  • beam_search
  • 在 GitHub 上修改

beam_search¶

paddle.fluid.layers. beam_search ( pre_ids, pre_scores, ids, scores, beam_size, end_id, level=0, is_accumulated=True, name=None, return_parent_idx=False ) [源代码] ¶

束搜索(Beam search)是在机器翻译等生成任务中选择候选词的一种经典算法

更多细节参考 Beam Search

该OP仅支持LoDTensor,在计算产生得分之后使用,完成单个时间步内的束搜索。具体而言,在计算部分产生 ids 和 scores 后,对于每个源句(样本)该OP从 ids 中根据其对应的 scores 选择当前时间步 top-K (K 是 beam_size)的候选词id。而 pre_id 和 pre_scores 是上一时间步 beam_search 的输出,加入输入用于特殊处理到达结束的翻译候选。

注意,如果 is_accumulated 为 True,传入的 scores 应该是累积分数。反之,scores 是单步得分,会在该OP内被转化为log值并累积到 pre_scores 作为最终得分。如需使用长度惩罚,应在计算累积分数前使用其他OP完成。

束搜索的完整用法请参阅以下示例:

fluid/tests/book/test_machine_translation.py

参数:
  • pre_ids (Variable) - LoD level为2的LodTensor,表示前一时间步选择的候选id,是前一时间步 beam_search 的输出。第一步时,其形状应为为 \([batch\_size,1]\) , lod应为 \([[0,1,...,batch\_size],[0,1,...,batch\_size]]\) 。数据类型为int64。

  • pre_scores (Variable) - 维度和LoD均与 pre_ids 相同的LodTensor,表示前一时间步所选id对应的累积得分,是前一时间步 beam_search 的输出。数据类型为float32。

  • ids (None|Variable) - 包含候选id的LodTensor。LoD应与 pre_ids 相同,形状为 \([batch\_size \times beam\_size,K]\) ,其中第一维大小与 pre_ids 相同且``batch_size`` 会随样本到达结束而自动减小, K 应该大于 beam_size 。数据类型为int64。可为空,为空时使用 scores 上的索引作为id。

  • scores (Variable) - 表示 ids 对应的累积分数的LodTensor变量, 维度和LoD均与 ids 相同。

  • beam_size (int) - 指明束搜索中的束宽度。

  • end_id (int) - 指明标识序列结束的id。

  • level (int,可选) - 可忽略,当前不能更改 。知道LoD level为2即可,两层lod的意义如下: 第一级表示每个源句(样本)包含的beam大小,若满足结束条件(达到 beam_size 个结束)则变为0;第二级是表示每个beam被选择的次数。

  • is_accumulated (bool,可选) - 指明输入分数 scores 是否为累积分数,默认为True。

  • name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值为None。

  • return_parent_idx (bool,可选) - 指明是否返回一个额外的Tensor,该Tensor保存了选择的id的父节点(beam)在 pre_id 中索引,可用于通过gather OP更新其他Tensor的内容。默认为False。

返回:Variable的二元组或三元组。二元组中包含了当前时间步选择的id和对应的累积得分两个LodTensor,形状相同且均为 \([batch\_size×beam\_size,1]\) ,LoD相同且level均为2,数据类型分别为int64和float32;若 return_parent_idx 为True时为三元组,多返回一个保存了父节点在 pre_id 中索引的Tensor,形状为 \([batch\_size \times beam\_size]\) ,数据类型为int64。

返回类型:tuple

代码示例

import paddle.fluid as fluid

# 假设 `probs` 包含计算神经元所得的预测结果
# `pre_ids` 和 `pre_scores` 为beam_search之前时间步的输出
beam_size = 4
end_id = 1
pre_ids = fluid.layers.data(
    name='pre_id', shape=[1], lod_level=2, dtype='int64')
pre_scores = fluid.layers.data(
    name='pre_scores', shape=[1], lod_level=2, dtype='float32')
probs = fluid.layers.data(
    name='probs', shape=[10000], dtype='float32')
topk_scores, topk_indices = fluid.layers.topk(probs, k=beam_size)
accu_scores = fluid.layers.elementwise_add(
                                      x=fluid.layers.log(x=topk_scores),
                                      y=fluid.layers.reshape(
                                          pre_scores, shape=[-1]),
                                      axis=0)
selected_ids, selected_scores = fluid.layers.beam_search(
                                      pre_ids=pre_ids,
                                      pre_scores=pre_scores,
                                      ids=topk_indices,
                                      scores=accu_scores,
                                      beam_size=beam_size,
                                      end_id=end_id)