条件生成对抗模型生成数字图片

简介: 这次我们选用条件生成对抗模型(Conditional Generative Adversarial Networks)来生成数字图片

在上个数字识别的例子中,我们使用了一个简单的3层神经网络来识别给定图片的中的数字。

这次我们在上次的例子中在提升一下,这次我们选用条件生成对抗模型(Conditional Generative Adversarial Networks)来生成数字图片。

下面就让我们开始吧!

第一步:import 我们需要的数据库

%matplotlib inline

from __future__ import print_function, division

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.cm as cm

import seaborn as sns
sns.set_style('white')

from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD

第二步:数据预处理

在上个例子中,我们使用的是28*28的二值图像,也就是说像素只有0和1,0表示黑色,1表示白色。

在上个例子中,我们使用28*28的灰度图像,每个像素的值都是从0~255的数值,值越大,越接近白色。

2.1 数据加载函数

首先定义一个数据加载函数 load_data 用来加载数据。

不同于上一个例子,我们的数据存放在 npz 文件中,numpy 提供了 load 接口可以直接读取。

通过函数的输出我们就可以看到,npz文件里的内容是 x_traln , y_traln , x_test , y_ test 。

这几个内容标签分别对应训练图片数据,训练图片数据的 label,测试图片数据,测试图片数据的 label。

def load_data():
    data = np.load('mnist.npz')
    print(data.files)
    x_train = data['x_train']
    y_train = data['y_train']
    x_test = data['x_test']
    y_test = data['y_test']

    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    x_train = np.expand_dims(x_train, axis=3)
    y_train = y_train.reshape(-1, 1)

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test)=load_data()

2.2 数据查看

在任何模型建立之前,常规的操作是先查看数据的情况,比如数据集的大小,训练集和测试集的数据数量,标签的数据数量分布等等。

2.2.1 查看原始数据的纬度

训练集有60000条数据,测试集有100000条数据,并且每一条数据有28*28的图片像素数据。

print(x_train.shape)

print(x_test.shape)

2.2.2 查看标签的数量

通过查看训练标签跟测试标签的数量,我们可以观察到,训练和测试的数据集跟训练和测试的标签在数量上是一一对应的。这也是我么想要的结果,表示我们的数据集是完整的。

print(y_train.shape)

print(y_test.shape)

2.2.3 查看所有的标签种类

可以看出标签表示了从0-9的数字,没有其他的错误数据。

np.unique(y_train)

np.unique(y_test)

2.3数据可视化

接下来我随机的选取一些我们已经转换好的图片数据,用 matplot 来查看下,标签和图片是否一致。

plt.figure(figsize=(15, 9))
for i in range(50):
    random_selection = np.random.randint(0, 500)
    plt.subplot(5, 10, 1+i)
    plt.title(y_train[random_selection])
    plt.imshow(x_train[random_selection][:,:,0], cmap=cm.gray)

2.4 查看数据是否平衡

分类的设计都是基于类分布大致平衡这一假设,通常假定用于训练的数据集是平衡的,即各类所含的样本数大致相当。

均匀的数据分布,将会提高模型的精度。如果数据不均匀,我么就要考虑进行平衡处理,常用的处理方式包括采样、加权、数据合成等。

我们看下标签的分布情况,看下每个标签种类的数据量是否分布均匀。

在 MNIST 数据集中,我们的数据是比较均匀分布的。

sns.distplot(y_train, kde=False, bins=10)

第三步:构建模型

接下来让我们定义模型:

我们选用的是条件生成对抗模型(Conditional Generative Adversarial Networks)

首先先让我们来认识下基本的生成对抗模型(Generative Adversarial Networks)的架构

3.1 GAN(Generative Adversarial Networks)的模型示意图

从模型的示意图中我们可以看到,GAN的模型分成两个模型,一个是生成模型(Generator Network), 还有一个是判别模型(Discriminator Network)

我们的输入数据分成两个,一个是真实的图片,一个是噪声图片。

首先,噪声图片输入到生成模型中,通过生成模型输出一张假的图片,然后我们同时将得到的假的图片跟真实的图片输入到判别模型中,通过判别模型,我们输出一个预测的标签。

这个是最基本的GAN的模型流程。

3.2 条件生成对抗模型(Conditional Generative Adversarial Networks)

从基本的生成对抗模型(Generative Adversarial Networks)模型中我们看到,输入的只是一张随机的噪声图片,并没有指定这个噪声图片对应的标签的任何信息。

那么在我们的这个例子里,我们希望输入的噪声图片,是指定的一个数字的标签,并且在通过GAN模型以后,能够输出对于我们输入标签的数字图片。

因此我们需要在他的基础上做些修改,这个模型就是我们这次使用的模型,叫做条件生成对抗模型(Conditional Generative Adversarial Networks)。

模型示意图

可以看到我们做了如下修改:

我们在生成网络的输入数据中加入了我们的随机噪声图片所对应的标签,

我们在判别网络中加入了,真实图片所对应的标签。

3.3 定义网络

下面让我们来定义我们需要的模型

3.3.1先定义一些常量

# 输入图片数据的维度
img_shape = (28, 28, 1)

# 图片通道数
channels = 1

# 标签数目
num_classes = 10

# 噪声图片的输入维度
latent_dim = 100

3.3.2 定义优化器

这里我们使用的优化器是Adam(Adaptive Moment Estimation)

Adam 是一种可以替代传统随机梯度下降(SGD)过程的一阶优化算法,它能基于训练数据迭代地更新神经网络权重。

Adam最开始是由OpenAI的Diederik Kingma和多伦多大学的Jimmy Ba在提交到2015年ICLR论文(Adam: A Method for Stochastic Optimization)中提出的。

Adam优化器有以下特点:

1.实现简单,计算高效,对内存需求少

2.参数的更新不受梯度的伸缩变换影响

3.超参数具有很好的解释性,且通常无需调整或仅需很少的微调

4.更新的步长能够被限制在大致的范围内(初始学习率)

5.能自然地实现步长退火过程(自动调整学习率)

6.很适合应用于大规模的数据及参数的场景

7.适用于不稳定目标函数

8.适用于梯度稀疏或梯度存在很大噪声的问题

optimizer = Adam(0.0002, 0.5)

3.3.3 定义生成模型

def build_generator():
    model = Sequential()

    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))

    model.summary()

    noise = Input(shape=(latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(num_classes, latent_dim)(label))

    model_input = multiply([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], img)

3.3.4 定义判别模型

def build_discriminator():
    model = Sequential()

    model.add(Dense(512, input_dim=np.prod(img_shape)))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    img = Input(shape=img_shape)
    label = Input(shape=(1,), dtype='int32')

    label_embedding = Flatten()(Embedding(num_classes, np.prod(img_shape))(label))
    flat_img = Flatten()(img)

    model_input = multiply([flat_img, label_embedding])

    validity = model(model_input)

    return Model([img, label], validity)

在上面的生成模型跟判别模型中,我们使用了几个新的网络, LeakyReLU, Dropout, BatchNormalization。

下面我们对这些层次进行一些简单的说明跟介绍。

3.4 带泄露修正线性单元(Leaky ReLU)

在上一个数字识别的例子中, 我们使用了线性整流函数(Rectified Linear Unit)就是我们常说的 ReLU 来作为激活函数。

我们也同时介绍了它的优缺点,其中一个重要的缺点就是前向传播过程中,在x<0时,神经元保持非激活状态。

这样会导致权重无法得到更新,也就是网络无法学习,为了解决 Relu 函数这个缺点,在 Relu 函数的负半区间引入一个泄露(Leaky)值, 使得ReLU在这个区间不为零。因此 Leaky ReLU 的图像如下, 通过参数a来控制函数负半区的值。

3.5 Dropout

在机器学习的模型中,如果模型的参数太多,而训练样本又太少,训练出来的模型很容易产生过拟合的现象, 具体表现在模型在训练数据上损失函数较小,预测准确率较高。

但是在测试数据上损失函数比较大,预测准确率较低。为了解决过拟合问题,Hinton在其论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了 Dropout 。

Dropout 的工作原理是我们在前向传播的时候,让某个神经元的激活值以一定的概率停止工作,这样可以使模型泛化性更强,因为它不会太依赖某些局部的特征。

它的工作的可视化表示如下图所示:

Dropout 可以有效的防止模型过拟合。

3.6 Batch Normallzatlon

机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。

Batch Normalization就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布。

基本思想其实相当直观:因为深层神经网络在做非线性变换前的激活输入值随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(参考Sigmoid函数),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。

而 Batch Normalization 就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布。

这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度

Batch Normalization有如下几个有特点:

1.使得网络中每层输入数据的分布相对稳定,加速模型学习速度

2.使得模型对网络中的参数不那么敏感,简化调参过程,使得网络学习更加稳定

3.允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题

4.具有一定的正则化效果

3.7 定义中间结果显示函数

这个函数主要用于在训练的时候,显示当前模型的预测情况,其中epoch这个参数表示当前第几个 epoch,generator 表示生成函数模型。

我们将查看当前 epoch 时,生成模型对于0-9这个几个数字的生成情况。

函数中,我们使用 numpy 产生一个0-9的标签数组,并且对每个0-9的数字产生一个噪声图片的数组,然后我们将噪声图片,以及对应的标签交给生成模型预测。将产生的结果,用 matplot 分两行绘制在一张图片内。

其中上面是0,1,2,3,4,下面是5,6,7,8,9, 并且将这张图片保存在images文件夹中, 文件名为当前 epoch,然后我们用 matplot 将这个图像显示jupyter上,方便查看。

def sample_images(epoch, generator):
    print("第%d个epoch的预测结果" % epoch)

    # 2行,5列
    r, c = 2, 5

    # 噪声图片
    noise = np.random.normal(0, 1, (r * c, 100))

    # 噪声图片对应的标签
    sampled_labels = np.arange(0, 10).reshape(-1, 1)

    # 用生成模型预测结果
    gen_imgs = generator.predict([noise, sampled_labels])
    gen_imgs = 0.5 * gen_imgs + 0.5

    # 将结果绘制在一张图片上并保存
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            # 我们绘制的时灰度图片
            axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
            # 将结果图片的标签页绘制在结果图片的上方
            axs[i,j].set_title("%d" % sampled_labels[cnt])
            # 关闭坐标轴
            axs[i,j].axis('off')
            cnt += 1
    #保存图片
    fig.savefig("images/epoch_%d.png" % epoch)
    plt.close()

    # 读取刚才的图片,并显示在jupyter上
    img = mpimg.imread('images/epoch_%d.png' %epoch)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

3.8 定义条件生成对抗模型(Conditional Generative Adversarial Networks)

class ConditionalGAN():
    def __init__(self):

        # loss值记录列表,用于最后显示Loss值的趋势,查看训练效果
        self.g_loss = []
        # epoch的记录列表
        self.epoch_range = []

        # ---------------------
        # 判别模型部分
        # ---------------------
        self.discriminator = build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'],
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # ---------------------
        # 生成模型部分
        # ---------------------
        self.generator = build_generator()

        # ---------------------
        # 合并模型部分
        # ---------------------
        noise = Input(shape=(latent_dim,))
        label = Input(shape=(1,))
        img = self.generator([noise, label])
        # 在合并模型中,我们只训练生成模型
        self.discriminator.trainable = False
        valid = self.discriminator([img, label])
        # 结合判别模型跟生成模型
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
                              optimizer=optimizer)

    # 训练函数
    def train(self, epochs, batch_size=128, sample_interval=50):

        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs+1):
            # ---------------------
            #  训练判别模型部分
            # ---------------------
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            imgs, labels = x_train[idx], y_train[idx]
            # 生成噪声图片
            noise = np.random.normal(0, 1, (batch_size, 100))
            # 生成模型通过噪声图片跟标签,生成相应的图片
            gen_imgs = self.generator.predict([noise, labels])
            # 训练判别模型
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  训练生成模型部分
            # ---------------------
            sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
            # 训练生成模型模型
            g_loss = self.combined.train_on_batch([noise, labels], valid)

            # 记录训练结果的值
            self.g_loss.append(g_loss)
            self.epoch_range.append(epoch)

            # 每200个epoch输出一次结果,查看效果
            if epoch % sample_interval == 0:
                sample_images(epoch, self.generator)

创建条件生成对抗模型(Conditional Generative Adversarial Networks)对象

gan = ConditionalGAN()

训练

这次我们训练20000个epochs,设置Batch Size大小为32, 同时每200个epoch,我们输出一次预测结果,看下0-9这几个数字在当前模型下的生成情况。

从中间的每200个epoch的结果来看,我们的模型从最开始的随机图像,先慢慢的产生出黑色的背景图,然后在每个图像的中间慢慢的产生出内容,随着epoch迭代的增加,中间输出的图像的内容也慢慢的变得更加有意义,直到迭代结束。我们输出的结果图像,基本可以用肉眼看到这个是什么数字。

gan.train(epochs=20000, batch_size=32, sample_interval=200)

查看训练过程中条件生成对抗模型(Conditional Generative Adversarial Networks)的损失值(loss)情况:

图表中显示了生成模型跟对抗模型的损失值(loss)的趋势,

查看图表,我们可以看到损失值(loss)从一开始的非常大的数值,下降到了一个稳定的值。

这个表明我们的模型在不断的迭代的过程中,产生的结果的误差是在逐步逐步的减小,最后趋于一个稳定的值,说明我们的模型一直在收敛。

plt.plot(gan.epoch_range, gan.g_loss, '-r', label= "Generator Loss")
plt.legend()
plt.xlabel('epoch')
plt.ylabel('loss')

脚本地址:https://github.com/matpool/mnist_gan

矩池云现已经把脚本镜像以上线,有感兴趣的用户可以在矩池云中体验。

目录
相关文章
|
1月前
|
机器学习/深度学习 人工智能 数据可视化
什么是条件生成对抗性网络?
什么是条件生成对抗性网络?
|
4月前
|
机器学习/深度学习 vr&ar
【深度强化学习】值函数逼近的详解(图文解释)
【深度强化学习】值函数逼近的详解(图文解释)
33 0
|
移动开发 文字识别 算法
论文推荐|[PR 2019]SegLink++:基于实例感知与组件组合的任意形状密集场景文本检测方法
本文简要介绍Pattern Recognition 2019论文“SegLink++: Detecting Dense and Arbitrary-shaped Scene Text by Instance-aware Component Grouping”的主要工作。该论文提出一种对文字实例敏感的自下而上的文字检测方法,解决了自然场景中密集文本和不规则文本的检测问题。
1874 0
论文推荐|[PR 2019]SegLink++:基于实例感知与组件组合的任意形状密集场景文本检测方法
|
3天前
|
测试技术
Vript:最为详细的视频文本数据集,每个视频片段平均超过140词标注 | 多模态大模型,文生视频
[Vript](https://github.com/mutonix/Vript) 是一个大规模的细粒度视频文本数据集,包含12K个高分辨率视频和400k+片段,以视频脚本形式进行密集注释,每个场景平均有145个单词的标题。除了视觉信息,还转录了画外音,提供额外背景。新发布的Vript-Bench基准包括三个挑战性任务:Vript-CAP(详细视频描述)、Vript-RR(视频推理)和Vript-ERO(事件时序推理),旨在推动视频理解的发展。
14 1
Vript:最为详细的视频文本数据集,每个视频片段平均超过140词标注 | 多模态大模型,文生视频
|
8月前
|
机器学习/深度学习 自然语言处理 算法
解读未知:文本识别算法的突破与实际应用
解读未知:文本识别算法的突破与实际应用
解读未知:文本识别算法的突破与实际应用
|
8月前
|
JSON 算法 数据格式
优化cv2.findContours()函数提取的目标边界点,使语义分割进行远监督辅助标注
可以看到cv2.findContours()函数可以将目标的所有边界点都进行导出来,但是他的点存在一个问题,太过密集,如果我们想将语义分割的结果重新导出成labelme格式的json文件进行修正时,这就会存在点太密集没有办法进行修改,这里展示一个示例:没有对导出的结果进行修正,在labelme中的效果图。
82 0
|
9月前
|
机器学习/深度学习 数据可视化 数据挖掘
字符级CNN分类模型的实现
字符级CNN分类模型的实现
|
11月前
|
机器学习/深度学习 编解码 自动驾驶
联合训练2D-3D多任务学习 | 深度估计、检测、分割、3D检测通吃
联合训练2D-3D多任务学习 | 深度估计、检测、分割、3D检测通吃
248 0
|
11月前
|
算法 固态存储
分别使用SAD匹配,NCC匹配,SSD匹配三种算法提取双目图像的深度信息
分别使用SAD匹配,NCC匹配,SSD匹配三种算法提取双目图像的深度信息
111 0
分别使用SAD匹配,NCC匹配,SSD匹配三种算法提取双目图像的深度信息
|
机器学习/深度学习 算法 数据挖掘
K近邻算法(KNN)(包含手写体识别、约会类型识别的代码)
是有监督学习、属于判别模型 、支持多分类以及回归、非线性、有预测函数、无优化目标、无优化求解算法。(算法地图) 对应每个训练数据xi有对应的标签yi--监督学习;
135 0
K近邻算法(KNN)(包含手写体识别、约会类型识别的代码)