中国人工智能学会通讯——最优传输理论在机器学习中的应用 1.1 最优传输理论与 WGAN 模型

简介:

image

最优传输理论是连接几何和概率的桥梁, 它用几何的方法为概率分布的建模和衡量概 率分布之间的距离提供了强有力的工具。最 近,最优传输理论的概念和方法日益渗透进 机器学习领域,为机器学习原理的解释提供 了新的视角,为机器学习算法的改进提供了 新的指导方向。

本文介绍最优传输理论的基本概念和原 理,解释如何用最优传输理论的框架来表示 概率分布,度量概率分布间的距离,如何降 维逼近,并进一步解释这些手法在机器学习 中的应用,给出机器学习原理和特点的最优 传输理论阐释。

1.1 最优传输理论与 WGAN 模型

1. 生成对抗网络简介

训练模型生成对抗网络 (GAN, Generative Adversarial Networks)[1] 是一个“自 相矛盾”的系统,就是“以己之矛,攻己之盾”, 在矛盾中发展,使得矛更加锋利,盾更加强 韧。这里的矛被称为判别器(Descriminator), 这里的盾被称为生成器(Generator)。如图 1~3 所示。

生成器 G 一般是将一个随机变量(例如 高斯分布,或者均匀分布),通过参数化的 概率生成模型(通常是用一个深度神经网来 进行参数化),进行概率分布的逆变换采样, 从而得到一个生成的概率分布。如图 2 所示。 判别器 D 也通常采用深度卷积神经网。

image
image

我们的目的是要找出给定的真实数据内 部的统计规律,将其概率分布表示为 Pr。为 此制作了一个随机变量生成器 G,G 能够产生 随机变量,其概率分布是 Pg,我们用 Pg 来尽 量接近 Pr。为了区分真实概率分布 Pr 和生成 概率分布 Pg,又制作了一个判别器 D,D 用 来判别一个样本是来自真实数据,还是来自 G 生成的伪造数据。为了使 GAN 中的判别器尽 可能将真实样本判为正例,将生成样本判为负 例,Goodfellow 设计了如下的损失函数(loss function):

image

这里第一项不依赖于生成器 G。 此式也可用 于定义 GAN 中生成器的损失函数。

矛盾的交锋过程如下:在训练过程中, 判别器 D 和生成器 G 交替学习,最终达到纳 什均衡(零和游戏)。在均衡状态,判别器 无法区分真实样本和生成样本,此时的生成 概率分布 Pg,可以被视作是真实概率分布 Pr 的一个良好逼近。如图 1~3 所示。

GAN 具有非常重要的优越性:当真实 数据的概率分布 Pr 不可计算时,依赖于数 据内在解释的传统生成模型无法被直接应 用。但是 GAN 依然可以被使用,这是因 为 GAN 引入了内部对抗的训练机制,能 够逼近难以计算的概率分布。Yann LeCun 一直积极倡导 GAN,因为 GAN 为无监督 学习提供了一个强有力的算法框架,而无 监督学习被广泛认为是通往人工智能的重 要一环。

原始 GAN 形式具有致命缺陷:判别器 越好,生成器的梯度消失越严重。我们固定 生成器 G 来优化判别器 D。考察任意一个样 本 x,其对判别器损失函数的贡献是

image

在这种情况下(判别器最优),如果 Pr 和 Pg 的支撑集合 (support) 交集为零测度,则生成 器的损失函数恒为 0,梯度消失。

本质上,JS 散度给出了概率分布 Pr 、 Pg 之间的差异程度,亦即概率分布间的度 量。我们可以用其他的度量来替换 JS 散度。Wasserstein 距离就是一个好的选择,因为 即便 Pr 、Pg 的交集为零测度,它们之间的 Wasserstein 距离依然非零。这样我们就得到 了 Wasserstein GAN 的模式 [2-3]。Wasserstein 距离的好处在于,即便 Pr、 Pg 两个分布之间 没有重叠,Wasserstein 距离依然能够度量它们的远近。

为此,我们引入最优传输的几何理论 (Optimal Mass Transportation),这个理论可视 化了 W-GAN 的关键概念,例如概率分布、 概率生成模型(生成器)、Wasserstein 距离。 更为重要的,这套理论中所有的概念、原理 都是透明的。例如,对于概率生成模型,理 论上我们可以用最优传输的框架取代深度神 经网络来构造生成器,从而使得机器学习的 黑箱变得透明。

2. 最优传输理论梗概

image
image

蒙 日-安 培 方 程 解 的 存 在 性、 唯 一 性 等价于经典的凸几何中的亚历山大定理 (Alexandrov Theorem)。

image
image

3. W-GAN 中关键概念可视化

W-GAN 模型中,关键的概念包括概率分 布(概率测度)、概率测度间的最优传输映 射(生成器)、概率测度间的Wasserstein距离。 下面我们详细解释每个概念的含义、所对应 的构造方法和相应的几何意义。

概率分布 GAN 模型中有两个至关重要 的概率分布(probability measure),一个 是真实数据的概率分布 Pr;一个是生成数 据的概率分布 Pg。另外,生成器的输入随 机变量可以是任意标准概率分布,例如高 斯分布、均匀分布等。

概率测度可以看成是一种推广的面积(或 者体积)。我们可以用几何变换随意构造一 个概率测度。如图 5 所示,我们用三维扫描 仪获取一张人脸曲面,那么人脸曲面上的面 积就是一个概率测度。我们缩放变换人脸曲 面,使得总面积等于 π;然后,用保角变换 将人脸曲面映射到平面圆盘。如图 5 所示, 保角变换将人脸曲面上的无穷小圆映到平面 上的无穷小圆,但是,小圆的面积发生了变化。 每对小圆的面积比率定义了平面圆盘上的概 率密度函数。

image
image
image
image
image
image
image
image
image
image
image

4. 小结

image

在 W-GAN 模型中,通常生成器和判别 器是用深度神经网络来实现的。根据最优传 输理论,可以用 Briener 势函数来代替深度 神经网络这个黑箱,从而使得整个系统变得透明。在另一层面上,深度神经网络本质上 是在训练概率分布间的传输映射,因此有可 能隐含地在学习最优传输映射,或者等价地 Brenier 势能函数。对这些问题的深入了解, 将有助于我们看穿黑箱。和图6中的例子类似, 图 12 显示了用最优传输映射计算的曲面保面 积参数化。最优传输理论在任意维空间都成立, 图 13 显示了一个三维体的最优传输例子。

image

相关文章
|
18天前
|
机器学习/深度学习 数据采集 人工智能
构建高效机器学习模型的五大技巧
【4月更文挑战第7天】 在数据科学迅猛发展的今天,机器学习已成为解决复杂问题的重要工具。然而,构建一个既精确又高效的机器学习模型并非易事。本文将分享五种提升机器学习模型性能的有效技巧,包括数据预处理、特征工程、模型选择、超参数调优以及交叉验证。这些方法不仅能帮助初学者快速提高模型准确度,也为经验丰富的数据科学家提供了进一步提升模型性能的思路。
|
18天前
招募!阿里云x魔搭社区发起Create@AI创客松邀你探索下一代多维智能体应用
招募!阿里云x魔搭社区发起Create@AI创客松邀你探索下一代多维智能体应用
303 0
|
18天前
|
人工智能 自然语言处理 开发者
AIGC创作活动 | 跟着UP主秋葉一起部署AI视频生成应用!
本次AI创作活动由 B 站知名 AI Up 主“秋葉aaaki”带您学习在阿里云 模型在线服务(PAI-EAS)中零代码、一键部署基于ComfyUI和Stable Video Diffusion模型的AI视频生成Web应用,快速实现文本生成视频的AI生成解决方案,帮助您完成社交平台短视频内容生成、动画制作等任务。制作上传专属GIF视频,即有机会赢取乐歌M2S台式升降桌、天猫精灵、定制保温杯等好礼!
|
2天前
|
人工智能 监控 数据处理
【AI大模型应用开发】【LangSmith: 生产级AI应用维护平台】1. 快速上手数据集与测试评估过程
【AI大模型应用开发】【LangSmith: 生产级AI应用维护平台】1. 快速上手数据集与测试评估过程
17 0
|
2天前
|
人工智能 监控 数据可视化
【AI大模型应用开发】【LangSmith: 生产级AI应用维护平台】0. 一文全览Tracing功能,让你的程序运行过程一目了然
【AI大模型应用开发】【LangSmith: 生产级AI应用维护平台】0. 一文全览Tracing功能,让你的程序运行过程一目了然
7 0
|
2天前
|
人工智能 API 开发者
【AI大模型应用开发】0.2 智谱AI API接入详细步骤和简单应用
【AI大模型应用开发】0.2 智谱AI API接入详细步骤和简单应用
9 0
|
7天前
|
机器学习/深度学习 人工智能 算法
未来AI技术的发展与应用前景
随着人工智能(AI)技术的迅速发展,其在各个领域的应用前景备受关注。本文将探讨未来AI技术的发展趋势,以及其在医疗、交通、教育等领域的潜在应用,展望AI技术对未来社会的影响和改变。
15 1
|
10天前
|
索引 机器学习/深度学习 Python
fast.ai 机器学习笔记(二)(3)
fast.ai 机器学习笔记(二)
24 0
fast.ai 机器学习笔记(二)(3)
|
10天前
|
机器学习/深度学习 算法框架/工具 PyTorch
fast.ai 机器学习笔记(三)(2)
fast.ai 机器学习笔记(三)
37 0
fast.ai 机器学习笔记(三)(2)
|
机器学习/深度学习 算法 计算机视觉
fast.ai 机器学习笔记(四)(4)
fast.ai 机器学习笔记(四)
18 0
fast.ai 机器学习笔记(四)(4)