Generating Street Scenes from Segmentation Masks with the Pix2Pix Adversarial Network on the Cityscapes Dataset

Author: Liyulingyue

Last updated: November 18, 2022

1. Introduction

About generating street scenes from segmentation masks with Pix2Pix on Cityscapes

The goal of this project is to generate the corresponding street scene from a pseudo-color segmentation map. Taking the image below as an example, the left side is a photo of a city street and the right side is its segmentation map; the goal is to generate the left image from the right one.

https://ai-studio-static-online.cdn.bcebos.com/0312d38398f5453f928b12b14bf794d9513a2bd554874c57890cda0cc3caa2e9

About GAN

A GAN (Generative Adversarial Network) is a framework for building generative models proposed by Ian J. Goodfellow et al. in 2014. In this framework, a generative model G, which models the target data, and a discriminative model D, which judges whether a sample is real, are trained alternately. In the ideal case G ends up reproducing the target data distribution exactly, while D's error rate approaches 1/2.
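
For reference, this adversarial game corresponds to the minimax objective from the original GAN paper (stated here only as background; the training code in Section 5 expresses it through BCELoss):

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$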

The GAN training procedure can be described as follows (a condensed code sketch is given after the list):

1. Sample a batch of real data
2. Generate fake data with G from random numbers
3. Train the discriminator D
4. Feed the generated data into D and compute the loss between D's output and the "real" label
5. Update G using the loss from step 4
6. Repeat steps 1 through 5 until convergence
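
As a concrete illustration, here is a minimal sketch of this alternating scheme in Paddle, using hypothetical toy models G and D on 1-D data; the full Pix2Pix training loop in Section 5 follows the same pattern.

import paddle

# Hypothetical toy models: G maps 8-D noise to a 1-D sample, D outputs a probability of "real".
G = paddle.nn.Sequential(paddle.nn.Linear(8, 16), paddle.nn.ReLU(), paddle.nn.Linear(16, 1))
D = paddle.nn.Sequential(
    paddle.nn.Linear(1, 16), paddle.nn.LeakyReLU(0.2), paddle.nn.Linear(16, 1), paddle.nn.Sigmoid()
)
opt_G = paddle.optimizer.Adam(parameters=G.parameters(), learning_rate=2e-4)
opt_D = paddle.optimizer.Adam(parameters=D.parameters(), learning_rate=2e-4)
bce = paddle.nn.BCELoss()

for _ in range(100):                                  # step 6: repeat until convergence
    real = paddle.randn([32, 1]) * 2 + 5              # step 1: a batch of "real" data
    fake = G(paddle.randn([32, 8]))                   # step 2: fake data from random noise

    # step 3: train D on real samples (label 1) and detached fake samples (label 0)
    opt_D.clear_grad()
    loss_D = bce(D(real), paddle.ones([32, 1])) + bce(D(fake.detach()), paddle.zeros([32, 1]))
    loss_D.backward()
    opt_D.step()

    # steps 4-5: feed the fake data into D, score it against the "real" label, and update G
    opt_G.clear_grad()
    loss_G = bce(D(fake), paddle.ones([32, 1]))
    loss_G.backward()
    opt_G.step()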

About Pix2Pix

Compared with pure generation, Pix2Pix focuses on style transfer: roughly speaking, the overall shape and structure of the input are kept unchanged. Using painting as an analogy, Pix2Pix acts like an artist who copies a Van Gogh oil painting in pencil. For this reason the generator used by Pix2Pix is a U-Net, which can also be replaced by other pixel-to-pixel networks.
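
In addition to the adversarial term, the Pix2Pix generator is trained with an L1 reconstruction loss that keeps the output pixel-wise close to the ground-truth photo; the training code in Section 5 uses a weight of λ = 100:

$$\mathcal{L}_G = \mathcal{L}_{\text{cGAN}}(G, D) + \lambda\,\mathbb{E}\big[\lVert y - G(x)\rVert_1\big], \qquad \lambda = 100$$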

About gradient computation in Paddle

In dynamic-graph mode, Paddle accumulates gradients automatically. When training a GAN, this means the losses of multiple batches or multiple modules can each call backward separately, but you also need to take care to clear gradients and to block backpropagation where required.

Gradient accumulation and blocking backpropagation are introduced below (a short sketch follows the list):

  1. Gradient accumulation: during training, instead of updating the model parameters as soon as the gradients of one batch have been computed, training continues on the following batches and the gradients keep adding up; after a chosen number of batches, the accumulated gradients are used for a single parameter update. This effectively enlarges the batch size: when GPU memory prevents using a larger batch, gradient accumulation achieves the same effect. Dynamic-graph mode supports this natively: as long as clear_grad is not called, the gradients keep accumulating.

  2. Blocking backpropagation: in some tasks you only want the forward prediction without updating the parameters, or you want to prune part of the backward pass to save computation. paddle.Tensor.detach() creates a new temporary variable that holds the same contents as the current variable but is separated from the current computation graph, so it blocks backpropagation.
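
A minimal sketch of both behaviors, assuming a hypothetical single-layer model and random data:

import paddle

# Hypothetical tiny model and optimizer, just for demonstration.
layer = paddle.nn.Linear(4, 1)
opt = paddle.optimizer.SGD(parameters=layer.parameters(), learning_rate=0.1)

# 1. Gradient accumulation: call backward() on several batches without clear_grad(),
#    then perform a single optimizer step with the summed gradients.
for _ in range(4):
    x = paddle.randn([8, 4])
    loss = layer(x).mean()
    loss.backward()          # gradients keep accumulating across iterations
opt.step()                   # one update using the accumulated gradients
opt.clear_grad()             # reset gradients before the next accumulation cycle

# 2. Blocking backpropagation: detach() returns a tensor with the same values but no
#    link to the computation graph, so nothing upstream of it receives gradients.
y = layer(paddle.randn([8, 4])).detach()
print(y.stop_gradient)       # True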

2. Environment Setup

Import the required packages, mainly paddle plus a few plotting helpers such as matplotlib.pyplot (plt).

import os
import random
import paddle
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import cv2

# print the Paddle version
print(paddle.__version__)
2.3.2

3. Dataset

First, prepare the dataset; here we simply download the Cityscapes archive prepared by PaddleSeg.

! wget https://paddleseg.bj.bcebos.com/dataset/cityscapes.tar
! tar -xf cityscapes.tar

Take a look at the images:

plt.subplot(1, 2, 1)
img = cv2.imread(
    "cityscapes/leftImg8bit/train/aachen/aachen_000001_000019_leftImg8bit.png"
)
plt.imshow(img)
plt.axis("off")
plt.subplot(1, 2, 2)
img = cv2.imread(
    "cityscapes/gtFine/train/aachen/aachen_000001_000019_gtFine_color.png"
)
plt.imshow(img)
plt.axis("off")
plt.xticks([])
plt.yticks([])
plt.subplots_adjust(wspace=0.1, hspace=0.1)

(output figure: the street photo on the left and its color segmentation mask on the right)

Using cityscapes/train.list, build a data reader:

import paddle
import os


class MyDateset(paddle.io.Dataset):
    def __init__(self, root_dir="cityscapes", txt_dir="cityscapes/train.list"):
        super(MyDateset, self).__init__()

        self.root_dir = root_dir
        with open(txt_dir, "r") as f:
            self.file_list = f.readlines()

    def __getitem__(self, index):
        file_dir = self.file_list[index][:-1]
        img_dir, mask_dir = file_dir.split(" ")
        mask_dir = mask_dir.replace("_labelTrainIds.png", "_color.png")

        img = cv2.imread(os.path.join(self.root_dir, img_dir))
        img = cv2.resize(img, (512, 256))  # (width, height)
        img = img / 255
        img = img.transpose([2, 0, 1])
        img = paddle.to_tensor(img).astype("float32")

        mask = cv2.imread(os.path.join(self.root_dir, mask_dir))
        mask = cv2.resize(mask, (512, 256))  # (width, height)
        mask = mask / 255
        mask = mask.transpose([2, 0, 1])
        mask = paddle.to_tensor(mask).astype("float32")

        return img, mask

    def __len__(self):
        return len(self.file_list)


if 1:
    dataset = MyDateset()

    dataloader = paddle.io.DataLoader(
        dataset, batch_size=16, shuffle=True, drop_last=False
    )

    for step, data in enumerate(dataloader):
        img, mask = data
        print(step, img.shape, mask.shape)
        break
0 [16, 3, 256, 512] [16, 3, 256, 512]

4. Model Architecture

Pix2Pix builds its generator in the form of a U-Net.

For details of the U-Net architecture, please refer to the paper U-Net: Convolutional Networks for Biomedical Image Segmentation.

The discriminator consists of a stack of convolutional blocks; it judges whether the two input images, a mask (the segmentation image) and a street scene stacked along the channel dimension, actually correspond to each other.

# Generator Code
class UnetGenerator(paddle.nn.Layer):
    def __init__(self, input_nc=3, output_nc=3, ngf=64):
        super(UnetGenerator, self).__init__()

        self.down1 = paddle.nn.Conv2D(
            input_nc, ngf, kernel_size=4, stride=2, padding=1
        )
        self.down2 = Downsample(ngf, ngf * 2)
        self.down3 = Downsample(ngf * 2, ngf * 4)
        self.down4 = Downsample(ngf * 4, ngf * 8)
        self.down5 = Downsample(ngf * 8, ngf * 8)
        self.down6 = Downsample(ngf * 8, ngf * 8)
        self.down7 = Downsample(ngf * 8, ngf * 8)

        self.center = Downsample(ngf * 8, ngf * 8)

        self.up7 = Upsample(ngf * 8, ngf * 8, use_dropout=True)
        self.up6 = Upsample(ngf * 8 * 2, ngf * 8, use_dropout=True)
        self.up5 = Upsample(ngf * 8 * 2, ngf * 8, use_dropout=True)
        self.up4 = Upsample(ngf * 8 * 2, ngf * 8)
        self.up3 = Upsample(ngf * 8 * 2, ngf * 4)
        self.up2 = Upsample(ngf * 4 * 2, ngf * 2)
        self.up1 = Upsample(ngf * 2 * 2, ngf)

        self.output_block = paddle.nn.Sequential(
            paddle.nn.ReLU(),
            paddle.nn.Conv2DTranspose(
                ngf * 2, output_nc, kernel_size=4, stride=2, padding=1
            ),
            paddle.nn.Tanh(),
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)

        c = self.center(d7)

        x = self.up7(c, d7)
        x = self.up6(x, d6)
        x = self.up5(x, d5)
        x = self.up4(x, d4)
        x = self.up3(x, d3)
        x = self.up2(x, d2)
        x = self.up1(x, d1)

        x = self.output_block(x)
        return x


class Downsample(paddle.nn.Layer):
    # LeakyReLU => conv => batch norm
    def __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1):
        super(Downsample, self).__init__()

        self.layers = paddle.nn.Sequential(
            paddle.nn.LeakyReLU(0.2),
            paddle.nn.Conv2D(
                in_dim, out_dim, kernel_size, stride, padding, bias_attr=False
            ),
            paddle.nn.BatchNorm2D(out_dim),
        )

    def forward(self, x):
        x = self.layers(x)
        return x


class Upsample(paddle.nn.Layer):
    # ReLU => deconv => batch norm => dropout
    def __init__(
        self,
        in_dim,
        out_dim,
        kernel_size=4,
        stride=2,
        padding=1,
        use_dropout=False,
    ):
        super(Upsample, self).__init__()

        sequence = [
            paddle.nn.ReLU(),
            paddle.nn.Conv2DTranspose(
                in_dim, out_dim, kernel_size, stride, padding, bias_attr=False
            ),
            paddle.nn.BatchNorm2D(out_dim),
        ]

        if use_dropout:
            sequence.append(paddle.nn.Dropout(p=0.5))

        self.layers = paddle.nn.Sequential(*sequence)

    def forward(self, x, skip):
        x = self.layers(x)
        x = paddle.concat([x, skip], axis=1)
        return x


# paddle.summary shows how an input of the given shape propagates through each module of the network
paddle.summary(UnetGenerator(), (8, 3, 256, 512))
-----------------------------------------------------------------------------------------------
  Layer (type)                 Input Shape                   Output Shape         Param #    
===============================================================================================
    Conv2D-1                [[8, 3, 256, 512]]            [8, 64, 128, 256]        3,136     
   LeakyReLU-1             [[8, 64, 128, 256]]            [8, 64, 128, 256]          0       
    Conv2D-2               [[8, 64, 128, 256]]            [8, 128, 64, 128]       131,072    
  BatchNorm2D-1            [[8, 128, 64, 128]]            [8, 128, 64, 128]         512      
  Downsample-1             [[8, 64, 128, 256]]            [8, 128, 64, 128]          0       
   LeakyReLU-2             [[8, 128, 64, 128]]            [8, 128, 64, 128]          0       
    Conv2D-3               [[8, 128, 64, 128]]             [8, 256, 32, 64]       524,288    
  BatchNorm2D-2             [[8, 256, 32, 64]]             [8, 256, 32, 64]        1,024     
  Downsample-2             [[8, 128, 64, 128]]             [8, 256, 32, 64]          0       
   LeakyReLU-3              [[8, 256, 32, 64]]             [8, 256, 32, 64]          0       
    Conv2D-4                [[8, 256, 32, 64]]             [8, 512, 16, 32]      2,097,152   
  BatchNorm2D-3             [[8, 512, 16, 32]]             [8, 512, 16, 32]        2,048     
  Downsample-3              [[8, 256, 32, 64]]             [8, 512, 16, 32]          0       
   LeakyReLU-4              [[8, 512, 16, 32]]             [8, 512, 16, 32]          0       
    Conv2D-5                [[8, 512, 16, 32]]             [8, 512, 8, 16]       4,194,304   
  BatchNorm2D-4             [[8, 512, 8, 16]]              [8, 512, 8, 16]         2,048     
  Downsample-4              [[8, 512, 16, 32]]             [8, 512, 8, 16]           0       
   LeakyReLU-5              [[8, 512, 8, 16]]              [8, 512, 8, 16]           0       
    Conv2D-6                [[8, 512, 8, 16]]               [8, 512, 4, 8]       4,194,304   
  BatchNorm2D-5              [[8, 512, 4, 8]]               [8, 512, 4, 8]         2,048     
  Downsample-5              [[8, 512, 8, 16]]               [8, 512, 4, 8]           0       
   LeakyReLU-6               [[8, 512, 4, 8]]               [8, 512, 4, 8]           0       
    Conv2D-7                 [[8, 512, 4, 8]]               [8, 512, 2, 4]       4,194,304   
  BatchNorm2D-6              [[8, 512, 2, 4]]               [8, 512, 2, 4]         2,048     
  Downsample-6               [[8, 512, 4, 8]]               [8, 512, 2, 4]           0       
   LeakyReLU-7               [[8, 512, 2, 4]]               [8, 512, 2, 4]           0       
    Conv2D-8                 [[8, 512, 2, 4]]               [8, 512, 1, 2]       4,194,304   
  BatchNorm2D-7              [[8, 512, 1, 2]]               [8, 512, 1, 2]         2,048     
  Downsample-7               [[8, 512, 2, 4]]               [8, 512, 1, 2]           0       
     ReLU-1                  [[8, 512, 1, 2]]               [8, 512, 1, 2]           0       
Conv2DTranspose-1            [[8, 512, 1, 2]]               [8, 512, 2, 4]       4,194,304   
  BatchNorm2D-8              [[8, 512, 2, 4]]               [8, 512, 2, 4]         2,048     
    Dropout-1                [[8, 512, 2, 4]]               [8, 512, 2, 4]           0       
   Upsample-1        [[8, 512, 1, 2], [8, 512, 2, 4]]      [8, 1024, 2, 4]           0       
     ReLU-2                 [[8, 1024, 2, 4]]              [8, 1024, 2, 4]           0       
Conv2DTranspose-2           [[8, 1024, 2, 4]]               [8, 512, 4, 8]       8,388,608   
  BatchNorm2D-9              [[8, 512, 4, 8]]               [8, 512, 4, 8]         2,048     
    Dropout-2                [[8, 512, 4, 8]]               [8, 512, 4, 8]           0       
   Upsample-2       [[8, 1024, 2, 4], [8, 512, 4, 8]]      [8, 1024, 4, 8]           0       
     ReLU-3                 [[8, 1024, 4, 8]]              [8, 1024, 4, 8]           0       
Conv2DTranspose-3           [[8, 1024, 4, 8]]              [8, 512, 8, 16]       8,388,608   
 BatchNorm2D-10             [[8, 512, 8, 16]]              [8, 512, 8, 16]         2,048     
    Dropout-3               [[8, 512, 8, 16]]              [8, 512, 8, 16]           0       
   Upsample-3       [[8, 1024, 4, 8], [8, 512, 8, 16]]     [8, 1024, 8, 16]          0       
     ReLU-4                 [[8, 1024, 8, 16]]             [8, 1024, 8, 16]          0       
Conv2DTranspose-4           [[8, 1024, 8, 16]]             [8, 512, 16, 32]      8,388,608   
 BatchNorm2D-11             [[8, 512, 16, 32]]             [8, 512, 16, 32]        2,048     
   Upsample-4      [[8, 1024, 8, 16], [8, 512, 16, 32]]   [8, 1024, 16, 32]          0       
     ReLU-5                [[8, 1024, 16, 32]]            [8, 1024, 16, 32]          0       
Conv2DTranspose-5          [[8, 1024, 16, 32]]             [8, 256, 32, 64]      4,194,304   
 BatchNorm2D-12             [[8, 256, 32, 64]]             [8, 256, 32, 64]        1,024     
   Upsample-5     [[8, 1024, 16, 32], [8, 256, 32, 64]]    [8, 512, 32, 64]          0       
     ReLU-6                 [[8, 512, 32, 64]]             [8, 512, 32, 64]          0       
Conv2DTranspose-6           [[8, 512, 32, 64]]            [8, 128, 64, 128]      1,048,576   
 BatchNorm2D-13            [[8, 128, 64, 128]]            [8, 128, 64, 128]         512      
   Upsample-6     [[8, 512, 32, 64], [8, 128, 64, 128]]   [8, 256, 64, 128]          0       
     ReLU-7                [[8, 256, 64, 128]]            [8, 256, 64, 128]          0       
Conv2DTranspose-7          [[8, 256, 64, 128]]            [8, 64, 128, 256]       262,144    
 BatchNorm2D-14            [[8, 64, 128, 256]]            [8, 64, 128, 256]         256      
   Upsample-7     [[8, 256, 64, 128], [8, 64, 128, 256]]  [8, 128, 128, 256]         0       
     ReLU-8                [[8, 128, 128, 256]]           [8, 128, 128, 256]         0       
Conv2DTranspose-8          [[8, 128, 128, 256]]            [8, 3, 256, 512]        6,147     
     Tanh-1                 [[8, 3, 256, 512]]             [8, 3, 256, 512]          0       
===============================================================================================
Total params: 54,425,923
Trainable params: 54,404,163
Non-trainable params: 21,760
-----------------------------------------------------------------------------------------------
Input size (MB): 12.00
Forward/backward pass size (MB): 2250.00
Params size (MB): 207.62
Estimated Total Size (MB): 2469.62
-----------------------------------------------------------------------------------------------






{'total_params': 54425923, 'trainable_params': 54404163}
# Discriminator Code
class NLayerDiscriminator(paddle.nn.Layer):
    def __init__(self, input_nc=6, ndf=64):
        super(NLayerDiscriminator, self).__init__()

        self.layers = paddle.nn.Sequential(
            paddle.nn.Conv2D(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            paddle.nn.LeakyReLU(0.2),
            ConvBlock(ndf, ndf * 2),
            ConvBlock(ndf * 2, ndf * 4),
            ConvBlock(ndf * 4, ndf * 8, stride=1),
            paddle.nn.Conv2D(ndf * 8, 1, kernel_size=4, stride=1, padding=1),
            paddle.nn.Sigmoid(),
        )

    def forward(self, input):
        return self.layers(input)


class ConvBlock(paddle.nn.Layer):
    # conv => batch norm => LeakyReLU
    def __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1):
        super(ConvBlock, self).__init__()

        self.layers = paddle.nn.Sequential(
            paddle.nn.Conv2D(
                in_dim, out_dim, kernel_size, stride, padding, bias_attr=False
            ),
            paddle.nn.BatchNorm2D(out_dim),
            paddle.nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        x = self.layers(x)
        return x


# paddle.summary shows how an input of the given shape propagates through each module of the network
paddle.summary(NLayerDiscriminator(), (16, 6, 256, 512))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-9     [[16, 6, 256, 512]]   [16, 64, 128, 256]       6,208     
  LeakyReLU-8   [[16, 64, 128, 256]]  [16, 64, 128, 256]         0       
   Conv2D-10    [[16, 64, 128, 256]]  [16, 128, 64, 128]      131,072    
BatchNorm2D-15  [[16, 128, 64, 128]]  [16, 128, 64, 128]        512      
  LeakyReLU-9   [[16, 128, 64, 128]]  [16, 128, 64, 128]         0       
  ConvBlock-1   [[16, 64, 128, 256]]  [16, 128, 64, 128]         0       
   Conv2D-11    [[16, 128, 64, 128]]  [16, 256, 32, 64]       524,288    
BatchNorm2D-16  [[16, 256, 32, 64]]   [16, 256, 32, 64]        1,024     
 LeakyReLU-10   [[16, 256, 32, 64]]   [16, 256, 32, 64]          0       
  ConvBlock-2   [[16, 128, 64, 128]]  [16, 256, 32, 64]          0       
   Conv2D-12    [[16, 256, 32, 64]]   [16, 512, 31, 63]      2,097,152   
BatchNorm2D-17  [[16, 512, 31, 63]]   [16, 512, 31, 63]        2,048     
 LeakyReLU-11   [[16, 512, 31, 63]]   [16, 512, 31, 63]          0       
  ConvBlock-3   [[16, 256, 32, 64]]   [16, 512, 31, 63]          0       
   Conv2D-13    [[16, 512, 31, 63]]    [16, 1, 30, 62]         8,193     
   Sigmoid-1     [[16, 1, 30, 62]]     [16, 1, 30, 62]           0       
===========================================================================
Total params: 2,770,497
Trainable params: 2,766,913
Non-trainable params: 3,584
---------------------------------------------------------------------------
Input size (MB): 48.00
Forward/backward pass size (MB): 1768.70
Params size (MB): 10.57
Estimated Total Size (MB): 1827.27
---------------------------------------------------------------------------






{'total_params': 2770497, 'trainable_params': 2766913}
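
As the summary shows, the discriminator does not produce a single scalar; it outputs a 30×62 grid of per-patch real/fake scores (a PatchGAN-style design). This is why the training loop in Section 5 builds its BCE labels with paddle.ones_like / paddle.zeros_like, so the labels match the score map. A quick sanity check (hypothetical random input; assumes the cell above has been run):

# the image and mask are stacked along the channel axis, hence 6 input channels
netD = NLayerDiscriminator()
pair = paddle.rand([1, 6, 256, 512])
print(netD(pair).shape)  # [1, 1, 30, 62]: one real/fake score per patch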

5. Model Training

The training configuration for this project is as follows:

  • Epochs: 30

  • Loss functions: BCELoss and L1Loss

  • Optimizer: Adam (see Section 2 of the Adam paper), which adapts each parameter's learning rate using first- and second-moment estimates of the gradients; its update rule is shown after this list.

  • Learning rate: 0.00002 (matching the optimizer settings in the code below)
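
For completeness, the Adam update applied to each parameter θ with gradient g_t (standard form, not Paddle-specific) is:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t, \quad v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2, \quad \hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \quad \theta_t = \theta_{t-1} - \alpha\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}$$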

The training loop below takes roughly six hours to run.

# create Net
netG = UnetGenerator()
netD = NLayerDiscriminator()

# to resume from previously saved weights keep "if 1:"; change it to "if 0:" to train from scratch
if 1:
    try:
        mydict = paddle.load("generator.params")
        netG.set_dict(mydict)
        mydict = paddle.load("discriminator.params")
        netD.set_dict(mydict)
    except Exception:
        print("failed to load model parameters")

netG.train()
netD.train()

optimizerD = paddle.optimizer.Adam(
    parameters=netD.parameters(), learning_rate=0.00002, beta1=0.5, beta2=0.999
)
optimizerG = paddle.optimizer.Adam(
    parameters=netG.parameters(), learning_rate=0.00002, beta1=0.5, beta2=0.999
)

bce_loss = paddle.nn.BCELoss()
l1_loss = paddle.nn.L1Loss()

# maximum number of training epochs
max_epoch = 30

now_step = 0
for epoch in range(max_epoch):
    for step, (img, mask) in enumerate(dataloader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################

        # clear D's gradients
        optimizerD.clear_grad()

        # feed the real (positive) pair to D and accumulate its gradients
        pos_img = paddle.concat((img, mask), 1)
        # label = paddle.full([pos_img.shape[0], 1], 1, dtype='float32')
        pre = netD(pos_img)
        loss_D_1 = bce_loss(pre, paddle.ones_like(pre))
        loss_D_1.backward()

        # generate fake images with G and pair them with the mask as negative samples for D
        fake_img = netG(mask).detach()
        neg_img = paddle.concat((fake_img, mask), 1)
        # label = paddle.full([pos_img.shape[0], 1], 0, dtype='float32')
        pre = netD(
            neg_img.detach()
        )  # detach blocks backpropagation, so updating D does not touch G's gradients
        loss_D_2 = bce_loss(pre, paddle.zeros_like(pre))
        loss_D_2.backward()

        # update D's parameters
        optimizerD.step()
        optimizerD.clear_grad()

        loss_D = loss_D_1 + loss_D_2

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################

        # clear G's gradients
        optimizerG.clear_grad()

        fake_img = netG(mask)
        fake = paddle.concat((fake_img, mask), 1)
        # label = paddle.full((pos_img.shape[0], 1), 1, dtype=np.float32,)
        output = netD(fake)
        loss_G_1 = l1_loss(fake_img, img) * 100.0  # L1 reconstruction loss, weighted by lambda = 100
        loss_G_2 = bce_loss(output, paddle.ones_like(output))  # adversarial loss against the "real" label
        loss_G = loss_G_1 + loss_G_2
        loss_G.backward()

        # update G's parameters
        optimizerG.step()
        optimizerG.clear_grad()

        now_step += 1

        print("\r now_step is:", now_step, end="")
        ###########################
        # visualization
        ###########################
        if now_step % 100 == 0:
            print()
            plt.figure(figsize=(15, 15))
            try:
                for i in range(10):
                    image = fake_img[i].numpy()  # convert the Paddle tensor to a NumPy array
                    image = np.where(image > 0, image, 0)  # clamp negative tanh outputs to 0 for display
                    image = image.transpose((1, 2, 0))
                    plt.subplot(10, 10, i + 1)

                    plt.imshow(image)
                    plt.axis("off")
                    plt.xticks([])
                    plt.yticks([])
                    plt.subplots_adjust(wspace=0.1, hspace=0.1)
                msg = "Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}".format(
                    epoch, now_step, loss_D.numpy()[0], loss_G.numpy()[0]
                )
                print(msg)
                plt.suptitle(msg, fontsize=20)
                plt.draw()
                # save the figure to the work folder
                plt.savefig(
                    "{}/{:04d}_{:04d}.png".format("work", epoch, now_step),
                    bbox_inches="tight",
                )
                plt.pause(0.01)
                break  # note: this break exits the dataloader loop, so each epoch ends right after its visualization step
            except IOError as e:
                print(e)
paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")
 now_step is: 100
Epoch ID=0 Batch ID=100 

 D-Loss=1.0732134580612183 G-Loss=6.412326812744141

(figure: generated samples at step 100)

 now_step is: 200
Epoch ID=1 Batch ID=200 

 D-Loss=1.018144130706787 G-Loss=6.176774501800537

(figure: generated samples at step 200)

 now_step is: 300
Epoch ID=2 Batch ID=300 

 D-Loss=0.9204654097557068 G-Loss=6.499708652496338

(figure: generated samples at step 300)

 now_step is: 400
Epoch ID=3 Batch ID=400 

 D-Loss=0.9853896498680115 G-Loss=6.3135786056518555

(figure: generated samples at step 400)

 now_step is: 500
Epoch ID=4 Batch ID=500 

 D-Loss=1.1680676937103271 G-Loss=6.198908805847168

(figure: generated samples at step 500)

6. Model Inference

The code below loads the generator we just trained and runs it on a segmentation image.

# load the trained generator
netG = UnetGenerator()
mydict = paddle.load("generator.params")
netG.set_dict(mydict)
netG.eval()

# read the segmentation image
mask = cv2.imread("test.png")
h, w, c = mask.shape

# plot the input mask
plt.subplot(1, 2, 1)
plt.imshow(mask)
plt.axis("off")
plt.xticks([])
plt.yticks([])

# preprocessing
mask = cv2.resize(mask, (512, 256))  # (width, height)
mask = mask / 255
mask = mask.transpose([2, 0, 1])
mask = paddle.to_tensor(mask).astype("float32")
mask = mask.reshape([1] + mask.shape)

# convert the inference result to an image
img = netG(mask)
img = img.numpy()[0]
img[img < 0] = 0
img = (img * 255).astype("uint8")
img = img.transpose((1, 2, 0))

# plot the generated result
plt.subplot(1, 2, 2)
plt.imshow(img)
plt.axis("off")
plt.xticks([])
plt.yticks([])

# save the generated image
cv2.imwrite("result.jpg", img)

img = cv2.resize(img, (w, h))  # resize back to the original (width, height)
cv2.imwrite("result_samesize.jpg", img)
True

(output figure: the input segmentation mask on the left and the generated street scene on the right)

7. Conclusion

This project walks through a basic Pix2Pix pipeline. If you need higher-quality results, refer to the two reference links in the acknowledgements; they cover similar content in more detail.

8. Acknowledgements

This project was written with reference to the following projects.
