wGAN如何解决GAN已有问题(附代码实现)

简介:

随着柯洁与AlphaGo的比赛结束以后,大家是不是对人工智能的底层奥秘越来越有兴趣?


深度学习已经在图像分类、检测等诸多领域取得了突破性的成绩。但是它也存在一些问题。


首先,它与传统的机器学习方法一样,通常假设训练数据与测试数据服从同样的分布,或者是在训练数据上的预测结果与在测试数据上的预测结果服从同样的分布,而实际上这两者存在一定的偏差。另一个问题是深度学习的模型(比如卷积神经网络)有时候并不能很好地学到训练数据中的一些特征。深度对抗学习(deep adversarial learning)就是为了解决上述问题而被提出的一种方法。


学习的过程可以看做是我们要得到一个模型,为了建模真实的数据分布,生成器学习生成实际的数据样本,而鉴别器学习确定这些样本是否是真实的。如果这个鉴别器的水平很高,而它无法分清它们之间的区别,那么就说明我们需要的模型具有很好的表达或者预测能力。


非监督学习是通往真正人工智能的方向,本文回顾了从传统机器学习,到wGAN的逻辑发展过程。GAN能自己生成特征、问题、评估函数,是近年来深度学习的一个突破。而wGAN解决了GAN已有的问题,“一个月内改变行业”,是深度学习的最新进展。本文让读者对wGAN的历史发展有个清晰的认识,并提供了wGAN的代码实现,是一篇很好的学习wGAN的入门材料。


gif;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAA


对抗学习是深度学习中最火的一个领域。网站arxiv-sanity的最近最流行的研究领域列表上,许多都是对抗学习,本文同样也是一篇讲对抗学习的文章。

 

在这篇文章中,我们主要学习以下三个方面的内容:


  • 为什么我们应该关注对抗学习

  • 生成对抗网络GANs(General Adversarial Networks) 和它面临的挑战

  • 能解决这些挑战的Wasserstein GAN和改进的稳定训练Wasserstein GAN的方法,还包括了代码实现。

 

从传统机器学习到深度学习


我在UIUC上“模拟信号与系统”课程的时候,教授在一开始就信誓旦旦地说:“这个课程将是你们上的最重要的课程,抽象是工程里面最重要的概念。” 


康奈尔大学的课程里面也有“解决复杂问题的方法就是抽象,也就是隐藏细节信息。抽象屏蔽掉无用的细节。为了设计一个复杂系统,你必须找出哪些是你想暴露给其他人的,哪些是你想隐藏起来的。暴露给其他人的部分,其他人可以进行设计。暴露的部分就是抽象。”

 

深度神经网络中的每层就是数据的抽象表示,层和层之间有依赖关系,最终形成一个层次结构。每一层都是上一层的一个更高级的抽象。给定一组原始数据和要解决的问题,然后定义一个目标函数来评估网络输出的答案,最终神经网络就能通过学习得到一个最优的解。

 

因此,特征是神经网络自己学习得来的。但是在传统的机器学习中,特征和算法都是人工定义的。

 

现在的数据的特征、结构、模式都是网络自我学习的,而不是像传统机器学习那样人工定义。所以以前无法实现的AI的算法现在可行了,并且在某些方面超过了人类。

 

从深度学习到深度对抗学习


很多年前,我学习过拳击。我的拳击教练不让新手问问题,说新手不知道问什么问题,连问的问题都是错误的,会得到没用的答案,会专注于错误的东西,越学越错。

 

Robert Half说过“会问问题和会解题一样,都需要一定的水平”

 

对抗学习的奇妙之处在于所有的东西都是从数据中学习得到的,包括要解决的问题,最终的答案以及评估答案的标准—目标函数。传统的深度学习中,是由人来决定要解决什么问题,人来决定用什么目标函数做评估。

 

Deep Mind公司用AlphaGo证明了深度对抗学习的厉害之处。在围棋比赛中,AlphaGo可以自己创造新的下法和招数。这开创了围棋的新纪元,突破了过去几千年的一个瓶颈,达到了新的高度。AlphaGo能做到这点是因为它能自己给自己打分,可以随时计算当前的局势的分数,而不用预先人工定义和预编程。这样,AlphaGo自己和自己下了几百万局的比赛。听起来很像对抗学习吧?


640?wx_fmt=jpeg&wxfrom=5&wx_lazy=1


AlphaGo不仅仅是暴力破解,而是真正掌握了围棋比赛,学到了围棋的招式。之所以这样,是因为它没有被人类束缚,既没有得到人类先验的输入,也不受我们对问题域理解的局限。无法想象,当我们把这些成果应用到实际生活中,AI会如何改造农业、医疗等等。但是这一定会发生。

 

生成对抗网络GAN


Richard Feynman说“如果要真正理解一个东西,我们必须要能够把它创造出来。”

 

正是这句话激励着我开始学习GANs。GANs的训练过程就是两个神经网络自己在作对抗,通过对抗不断的学习。当然学习是在原始数据的基础上学习。


640?wx_fmt=jpeg&wxfrom=5&wx_lazy=1


生成器通过对原始数据的分布进行建模,学习如何生成近似数据;而判别器用来判断数据是生成器生成的数据还是原始的真实的数据。这样生成器就能重新创造出原始数据的近似数据。我们相信为了能够理解一个东西,我们要能重新创造这个东西,所以GAN是非常有价值的,我们的努力也是值得的。


如果我们能成功使得GAN达到纳什均衡(完美的判别器也不能识别数据到底是真实数据还是生成数据),我们就能够把这个成果应用到几乎任何事情上,并且还能够有最好的性能。


存在的问题


GANs很难优化,并且训练过程不稳定。网络结构必须设计的非常好,生成器和判别器之间必须有个很好的协调,才能使得训练过程收敛。这些问题中,最显著的就是失去样本多样性(mode dropping, 即生成器只从很小一部分的数据集中学习)。还有由于GANs的学习曲线基本没什么意义,因此很难调试。

 

虽然如此,仍然通过GANs得到了最先进的一些成果。但是就是因为这些问题,GANs的应用被限制住了。


解决方法


Alex J. Champandard说“一个月内,传统的训练GANs的方法会被当做黑暗时代的方法”。

 

GANs的训练目标是生成数据和真实数据的分布的距离差的最小化。


最开始使用的是Jensen-Shannon散度。但是,Wasserstein GAN(wGAN)文章在理论和实际两个方面,都证明了最小化推土距离EMD(Earth Mover’s distance)才是解决上述问题的最优方法。当然在实际计算中,由于EMD的计算量过大,因此使用的是EMD的合理的近似值。


为了使得近似值有效,wGAN在判别器(在wGAN中使用了critic一词,和GAN中的discriminator是同一个意思)中使用了权重剪裁(weight clipping)。但是正是权重剪裁导致了上述的问题。

 

后来对wGAN的训练方法进行了改进,它通过在判别器引入梯度惩罚(gradient penalty)使得训练稳定。梯度惩罚只要简单的加到总损失函数中的Wasserstein距离就可以了。


640?wx_fmt=jpeg&wxfrom=5&wx_lazy=1


历史上第一次,终于可以训练GAN而几乎不用超参数调优了。其中包括了101层的残差网络和基于离散数据的语言模型。

 

Wasserstein距离的一个优势就是当判别器改进的时候,生成器能收到改进的梯度。但是在使用Jensen-Shannon散度的时候,当判别器改进的时候,产生的梯度消失,生成器无法学习改进。这个也是产生训练不稳定的主要原因。

 

如果想对这个理论有深入理解,我建议读一下下面两个文章:

  • Wasserstein GAN

  • Wasserstein GANs的改进的训练方法

 

随着新的目标函数的引入,我看待GANs的方式也发生了变化:

 

传统的GAN(Jensen-Shannon散度)下,生成器和判别器是竞争关系,如下图。


640?wx_fmt=jpeg&wxfrom=5&wx_lazy=1


在wGAN(Wasserstein距离)下,生成器和判别器是协作关系,如下图。


640?wx_fmt=jpeg&wxfrom=5&wx_lazy=1


代码实现

640?wx_fmt=jpeg&wxfrom=5&wx_lazy=1


结论


对抗学习的网络不受我们对问题域理解的任何限制,没有任何先验知识,网络就是从数据中学习。



原文发布时间为:2017-06-27

本文作者:Michael Dietz

本文来自云栖社区合作伙伴“数据派THU”,了解相关信息可以关注“数据派THU”微信公众号

相关文章
|
2月前
|
机器学习/深度学习 存储 算法
【复现】尝试使用numpy对卷积神经网络中各经典结构进行改写复现
【复现】尝试使用numpy对卷积神经网络中各经典结构进行改写复现
38 0
【复现】尝试使用numpy对卷积神经网络中各经典结构进行改写复现
|
2月前
|
机器学习/深度学习 编解码 计算机视觉
YOLOv8改进 | 主干篇 | SwinTransformer替换Backbone(附代码 + 详细修改步骤 +原理介绍)
YOLOv8改进 | 主干篇 | SwinTransformer替换Backbone(附代码 + 详细修改步骤 +原理介绍)
179 0
|
机器学习/深度学习 人工智能 自然语言处理
一文尽览 | 开放世界目标检测的近期工作及简析!(基于Captioning/CLIP/伪标签/Prompt)(上)
人类通过自然监督,即探索视觉世界和倾听他人描述情况,学会了毫不费力地识别和定位物体。我们人类对视觉模式的终身学习,并将其与口语词汇联系起来,从而形成了丰富的视觉和语义词汇,不仅可以用于检测物体,还可以用于其他任务,如描述物体和推理其属性和可见性。人类的这种学习模式为我们实现开放世界的目标检测提供了一个可以学习的角度。
一文尽览 | 开放世界目标检测的近期工作及简析!(基于Captioning/CLIP/伪标签/Prompt)(上)
|
2月前
|
前端开发 PyTorch 算法框架/工具
【基础实操】借用torch自带网络进行训练自己的图像数据
【基础实操】借用torch自带网络进行训练自己的图像数据
24 0
【基础实操】借用torch自带网络进行训练自己的图像数据
|
2月前
|
机器学习/深度学习 计算机视觉 网络架构
YOLOv8改进 | 2023主干篇 | 替换LSKNet遥感目标检测主干 (附代码+修改教程+结构讲解)
YOLOv8改进 | 2023主干篇 | 替换LSKNet遥感目标检测主干 (附代码+修改教程+结构讲解)
67 1
YOLOv8改进 | 2023主干篇 | 替换LSKNet遥感目标检测主干 (附代码+修改教程+结构讲解)
|
11月前
|
机器学习/深度学习 并行计算 固态存储
YOLO系列 | 一份YOLOX改进的实验报告,并提出更优秀的模型架构组合!
YOLO系列 | 一份YOLOX改进的实验报告,并提出更优秀的模型架构组合!
156 0
|
11月前
|
机器学习/深度学习 算法框架/工具 计算机视觉
又改ResNet | 重新思考ResNet:采用高阶方案的改进堆叠策略(附论文下载)(一)
又改ResNet | 重新思考ResNet:采用高阶方案的改进堆叠策略(附论文下载)(一)
137 0
|
11月前
|
计算机视觉
又改ResNet | 重新思考ResNet:采用高阶方案的改进堆叠策略(附论文下载)(二)
又改ResNet | 重新思考ResNet:采用高阶方案的改进堆叠策略(附论文下载)(二)
75 0
|
11月前
|
机器学习/深度学习 存储 缓存
随机YOLO|你用的YOLO在Dataset Shift时是否依旧鲁棒?这个策略可能是你想要的!!!
随机YOLO|你用的YOLO在Dataset Shift时是否依旧鲁棒?这个策略可能是你想要的!!!
140 0
|
11月前
|
机器学习/深度学习 编解码 运维
覆盖100余篇论文,这篇综述系统回顾了CV中的扩散模型
覆盖100余篇论文,这篇综述系统回顾了CV中的扩散模型
109 0