Overview of GAN

本文最后更新于:2024年1月15日 晚上

Overview of GAN

1 GAN的诞生

GAN的全称为Generative Adversarial Net,意思是对抗生成网络。

在这之前,生成式网络总会遇到很多难以实现的概率计算方法,这些问题在最大似然估计等策略中经常出现。

而GAN采用一种完全不同的方式来学习数据集的分布。

首先,我们定义一个生成器GG,生成器的输入是一个随机噪声zpz(z)z\sim p_z(z),生成器就是一个将噪声映射到图像数据空间的函数G(z;θG)G(z;\theta_G),其中θG\theta_G是生成器的参数。我们希望生成器最终学到数据的分布xp(x)x\sim p(x)

接下来,我们定义一个判别器D(x;θD)D(x;\theta_D),判别器的输入是一幅图像,输出是0,10,1,分别代表其判断为假图像(来自GG)和真图像(来自数据集),θD\theta_DDD的参数。从范式上讲,判别器就是一个二分类器,用来判断看见的图像是真是假。

graph LR

A[Generator]
B[Discriminator]
C[noise]
D[data]
E[Valid/Fake]

A --> B
C --> A
D --> B
B --> E

现在,我们的训练方式是

  • 让生成器根据噪声生成一张图片G(z)G(z)
  • 判别器根据图片判断是真图片还是假图片,输出为D(G(z))D(G(z))
  • 计算损失LG=BCELoss(D(G(z)),1)L_G=\mathrm{BCELoss}(D(G(z)),\boldsymbol{1}),或者说是log(1D(G(z)))\log(1-D(G(z)))。这样,GG就要尽可能让DD输出1来最小化LGL_G
  • 梯度回传并单独优化θG\theta_G
  • 计算损失LDvalid=BCELoss(D(x),1)L_D^{valid}=\mathrm{BCELoss}(D(x),\boldsymbol{1}),即判别器看到真图像时应该输出1。计算损失LDfake=BCELoss(D(G(z)),0)L_D^{fake}=\mathrm{BCELoss}(D(G(z)),\boldsymbol{0}),即判别器看到假图像时要输出0。总损失为LD=12(LDvalid+LDfake)L_D=\frac{1}{2}(L_D^{valid}+L_D^{fake})
  • 梯度回传并单独优化θD\theta_D

在这样的训练方式中,GG要不断提升自己的造假能力来骗过DD,而DD要不断提升自己的判别能力来判别是否是假图像,形成了一个对抗的范式。

在完成训练之后,只需要给GG输入噪声,就可以让GG生成类似于数据集的内容。

2 Conditional GAN

在最初的2014年的Conditional GAN中,作者的做法很简单,首先将条件用向量表示,例如分类的标签可以用one-hot向量表示。接着将条件向量与输入拼接一起输入网络即可。

graph LR

A[Noise]
B[Data]
C[Conditioanl Vec]
D[Generator]
E[Discriminator]
F[Valid/Fake]

A --> D
C --> D
D --> E
B --> E
E --> F
C --> E

由于每次输入时噪声是随机的,所以我们以某个条件做生成的时候,可以实现类似one to many的映射。在MNIST数据集中,作者用MNIST分类标签做为条件,成功利用同一个条件,例如数字0,生成了不同的手写体数字0的图片。

此外,作者尝试了使用图片-语言描述的数据集来做图片描述任务。图片描述是用几个字来描述一张图片,由于一张图片可以用不同的语言来描述,所以是一个one to many映射的问题,在这里作者把卷积网络提取的图片特征作为条件,让模型学数据集中语言描述的分布,最终实现了模型能够自动描述一张图的内容。


Overview of GAN
https://jesseprince.github.io/2024/01/15/convnets/generative/gan/
作者
林正
发布于
2024年1月15日
许可协议