生成对抗网络(Generative Adversarial Network[1], 简称GAN) 是一种非监督学习的方式,通过让两个神经网络相互博弈的方法进行学习,该方法由lan Goodfellow等人在2014年提出。生成对抗网络由一个生成网络和一个判别网络组成,生成网络从潜在的空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能的分辨出来。而生成网络则尽可能的欺骗判别网络,两个网络相互对抗,不断调整参数。 生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片,三维物体模型等。[2]
本图像生成模型库包含CGAN[3], DCGAN[4], Pix2Pix[5], CycleGAN[6], StarGAN[7], AttGAN[8], STGAN[9]。
注意: 1. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。 2. GAN模型目前仅仅验证了单机单卡训练和预测结果。 3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。 4. PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装。 5. 中间效果图保存在${output_dir}/test文件夹中。对于Pix2Pix来说,inputA 和inputB 代表输入的两种风格的图片,fakeB表示生成图片;对于CycleGAN来说,inputA表示输入图片,fakeB表示inputA根据生成的图片,cycA表示fakeB经过生成器重构出来的对应于inputA的重构图片;对于StarGAN,AttGAN和STGAN来说,第一行表示原图,之后的每一行都代表一种属性变换。
图像生成模型库库的目录结构如下:
├── download.py 下载数据
│
├── data_reader.py 数据预处理
│
├── train.py 模型的训练入口
│
├── infer.py 模型的预测入口
│
├── trainer 不同模型的训练脚本
│ ├── CGAN.py Conditional GAN的训练脚本
│ ├── ...
│ ├── STGAN.py STGAN的训练脚本
│
├── network 不同模型的网络结构
│ ├── base_network.py GAN模型需要的公共基础网络结构
│ ├── ...
│ ├── STGAN_network.py STGAN的网络结构
│
├── util 网络的基础配置和公共模块
│ ├── config.py 网络公用的基础配置
│ ├── utility.py 保存模型等网络公用的模块
│
├── scripts 多个模型的训练启动和测试启动示例
│ ├── run_....py 训练启动示例
│ ├── infer_....py 测试启动示例
│ ├── make_pair_data.py pix2pix GAN的数据list的生成脚本
│
├── data 下载的数据集存放的位置
│ ├── celeba
│ ├── ${image_dir} 存放实际图片
│ ├── list 文件
安装PaddlePaddle:
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.5或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据安装文档中的说明来更新PaddlePaddle。
其他依赖包:
1. pip install imageio
或者 pip install -r requirements.txt
安装imageio包(保存图片代码中所依赖的包)
Pix2Pix和CycleGAN采用cityscapes[10]数据集进行风格转换。 StarGAN,AttGAN和STGAN采用celeba[11]数据集进行属性迁移。
模型库中提供了download.py数据下载脚本,该脚本支持下载MNIST数据集,CycleGAN和Pix2Pix所需要的数据集。使用以下命令下载数据: python download.py --dataset=mnist 通过指定dataset参数来下载相应的数据集。
StarGAN, AttGAN和STGAN所需要的Celeba数据集可以自行下载。
自定义数据集:
如果您要使用自定义的数据集,只要设置成对应的生成模型所需要的数据格式,并放在data文件夹下,然后把--dataset
参数设置成您自定义数据集的名称,data_reader.py文件就会自动去data文件夹中寻找数据。
注意: pix2pix模型数据集准备中的list文件需要通过scripts文件夹里的make_pair_data.py来生成,可以使用以下命令来生成:
python scripts/make_pair_data.py \
--direction=A2B
用户可以通过设置--direction
参数生成list文件,从而确保图像风格转变的方向。
开始训练: 数据准备完毕后,可以通过以下方式启动训练:
python train.py \
--model_net=$(name_of_model) \
--dataset=$(name_of_dataset) \
--data_dir=$(path_to_data) \
--train_list=$(path_to_train_data_list) \
--test_list=$(path_to_test_data_list) \
--batch_size=$(batch_size)
可选参数见:
python train.py --help
每个GAN都给出了一份运行示例,放在scripts文件夹内,用户可以直接运行训练脚本快速开始训练。
--model_net
参数来选择想要训练的模型,通过设置--dataset
参数来选择训练所需要的数据集。模型测试是利用训练完成的生成模型进行图像生成。infer.py是主要的执行程序,调用示例如下:
执行以下命令得到CyleGAN的预测结果:
python infer.py \
--model_net=CycleGAN \
--init_model=$(path_to_init_model) \
--image_size=256 \
--dataset_dir=$(path_to_data) \
--input_style=$(A_or_B) \
--net_G=$(generator_network) \
--g_base_dims=$(base_dim_of_generator)
执行以下命令得到Pix2Pix的预测结果:
python infer.py \
--model_net=Pix2pix \
--init_model=$(path_to_init_model) \
--image_size=256 \
--dataset_dir=$(path_to_data) \
--net_G=$(generator_network)
执行以下命令得到StarGAN,AttGAN或者STGAN的预测结果:
python infer.py \
--model_net=$(StarGAN_or_AttGAN_or_STGAN) \
--init_model=$(path_to_init_model)\
--dataset_dir=$(path_to_data)
Pix2Pix和CycleGAN的效果如图所示:
Pix2Pix和CycleGAN的效果图
StarGAN,AttGAN和STGAN的效果如图所示:
StarGAN的效果图(图片属性分别为:origial image, Black hair, Blond Hair, Brown Hair, Male, Young)
AttGAN的效果图(图片属性分别为:original image, Bald, Bangs, Black Hair, Blond Hair, Brown Hair, Bushy Eyebrows, Eyeglasses, Male, Mouth Slightly Open, Mustache, No Beard, Pale Skin, Young)
STGAN的效果图(图片属性分别为:original image, Bald, Bangs, Black Hair, Blond Hair, Brown Hair, Bushy Eyebrows, Eyeglasses, Male, Mouth Slightly Open, Mustache, No Beard, Pale Skin, Young)
下载预训练模型:
本示例提供以下预训练模型:
Model | Pretrained model |
---|---|
Pix2Pix | Pix2Pix的预训练模型 |
CycleGAN | CycleGAN的预训练模型 |
StarGAN | StarGAN的预训练模型 |
AttGAN | AttGAN的预训练模型 |
STGAN | STGAN的预训练模型 |
CGAN,条件生成对抗网络,一种带条件约束的GAN,使用额外信息对模型增加条件,可以指导数据生成过程。
DCGAN,深度卷积生成对抗网络,将GAN和卷积网络结合起来,以解决GAN训练不稳定的问题,利用卷积神经网络作为网络结构进行图像生成,可以得到更加丰富的层次表达。
Pix2Pix利用成对的图片进行图像翻译,即输入为同一张图片的两种不同风格,可用于进行风格迁移。
CycleGAN可以利用非成对的图片进行图像翻译,即输入为两种不同风格的不同图片,自动进行风格转换。
StarGAN多领域属性迁移,引入辅助分类帮助单个判别器判断多个属性,可用于人脸属性转换。
AttGAN利用分类损失和重构损失来保证改变特定的属性,可用于人脸特定属性转换。
STGAN只输入有变化的标签,引入GRU结构,更好的选择变化的属性,可用于人脸特定属性转换。
convolution-batch norm-ReLU
作为基础结构,解码部分的网络结构由transpose convolution-batch norm-ReLU
组成,判别网络基本是由convolution-norm-leaky_ReLU
作为基础结构,详细的网络结构可以查看network/Pix2pix_network.py
文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。网络利用损失函数学习从输入图像到输出图像的映射,生成网络损失函数由GAN的损失函数和L1损失函数组成,判别网络损失函数由GAN的损失函数组成。生成器的网络结构如下图所示:
Pix2Pix生成网络结构图[5]
convolution-norm-ReLU
作为基础结构,解码部分的网络结构由transpose convolution-norm-ReLU
组成,判别网络基本是由convolution-norm-leaky_ReLU
作为基础结构,详细的网络结构可以查看network/CycleGAN_network.py
文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。生成网络损失函数由LSGAN的损失函数,重构损失和自身损失组成,判别网络的损失函数由LSGAN的损失函数组成。
CycleGAN生成网络结构图[5]
convolution-instance norm-ReLU
组成,解码部分主要由transpose convolution-norm-ReLU
组成,判别网络主要由convolution-leaky_ReLU
组成,详细网络结构可以查看network/StarGAN_network.py
文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
StarGAN的生成网络结构[左]和判别网络结构[右] [7]
convolution-instance norm-ReLU
组成,解码部分由transpose convolution-norm-ReLU
组成,判别网络主要由convolution-leaky_ReLU
组成,详细网络结构可以查看network/AttGAN_network.py
文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
AttGAN的网络结构[8]
convolution-instance norm-ReLU
组成,解码网络主要由transpose convolution-norm-leaky_ReLU
组成,判别网络主要由convolution-leaky_ReLU
组成,详细网络结构可以查看network/STGAN_network.py
文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
STGAN的网络结构[9]
注意:网络结构中的norm指的是用户可以选用batch norm或者instance norm来搭建自己的网络。
Q: StarGAN/AttGAN/STGAN中属性没有变化,为什么?
A: 查看是否所有的标签都转换对了。
Q: 预测结果不正常,是怎么回事?
A: 某些GAN预测的时候batch_norm的设置需要和训练的时候行为一致,查看模型库中相应的GAN中预测时batch_norm的行为和自己模型中的预测时batch_norm的行为是否一致。
Q: 为什么STGAN和ATTGAN中变男性得到的预测结果是变女性呢?
A: 这是由于预测时标签的设置,目标标签是基于原本的标签进行改变,比如原本图片是男生,预测代码对标签进行转变的时候会自动变成相对立的标签,即女
性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,保持要转变的标签不变即可。
Q: 如何使用自己的数据集进行训练?
A: 对于Pix2Pix来说,只要准备好类似于Cityscapes数据集的不同风格的成对的数据即可。对于CycleGAN来说,只要准备类似于Cityscapes数据集的不同风格的数据即可。对于StarGAN,AttGAN和STGAN来说,除了需要准备类似于CelebA数据集中图片,包含图片数量、名称和标签信息的list文件外,还需要把模型中的selected_attrs参数设置为想要改变的目标属性,c_dim参数设置为目标属性的个数。
Q: 如何从模型库中拿出单独的一个模型?
A: 由于trainer文件夹中的init.py文件默认导入了所有网络结构,所以需要删掉init.py文件中导入的当前模型之外的包,然后把trainer和network中不需要的模型文件删掉即可。
1Goodfellow, Ian J.; Pouget-Abadie, Jean; Mirza, Mehdi; Xu, Bing; Warde-Farley, David; Ozair, Sherjil; Courville, Aaron; Bengio, Yoshua. Generative Adversarial Networks. 2014. arXiv:1406.2661 stat.ML.](https://arxiv.org/abs/1406.2661)
2(https://zh.wikipedia.org/wiki/生成对抗网络)
3(https://arxiv.org/abs/1411.1784)
4(https://arxiv.org/abs/1511.06434)
5(https://arxiv.org/abs/1611.07004)
6(https://arxiv.org/abs/1703.10593)
7(https://arxiv.org/abs/1711.09020)
8(https://arxiv.org/abs/1711.10678)
9(https://arxiv.org/abs/1904.09709)
10(https://arxiv.org/abs/1604.01685)
11(https://arxiv.org/abs/1411.7766)
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。
Pix2Pix和CycleGAN的效果图
StarGAN的效果图(图片属性分别为:origial image, Black hair, Blond Hair, Brown Hair, Male, Young)
AttGAN的效果图(图片属性分别为:original image, Bald, Bangs, Black Hair, Blond Hair, Brown Hair, Bushy Eyebrows, Eyeglasses, Male, Mouth Slightly Open, Mustache, No Beard, Pale Skin, Young)
STGAN的效果图(图片属性分别为:original image, Bald, Bangs, Black Hair, Blond Hair, Brown Hair, Bushy Eyebrows, Eyeglasses, Male, Mouth Slightly Open, Mustache, No Beard, Pale Skin, Young)
Pix2Pix生成网络结构图[5]
CycleGAN生成网络结构图[5]
StarGAN的生成网络结构[左]和判别网络结构[右] [7]
AttGAN的网络结构[8]
STGAN的网络结构[9]