学习笔记TF051:生成式对抗网络
生成式对抗网络(gennerative adversarial network,GAN),谷歌2014年提出网络模型。灵感自二人博弈的零和博弈,目前最火的非监督深度学习。GAN之父,Ian J.Goodfellow,公认人工智能顶级专家。
原理。
生成式对搞网络包含一个生成模型(generative model,G)和一个判别模型(discriminative model,D)。Ian J.Goodfellow、Jean Pouget-Abadie、Mehdi Mirza、Bing Xu、David Warde-Farley、Sherjil Ozair、Aaron Courville、Yoshua Bengio论文,《Generative Adversarial Network》,https://arxiv.org/abs/1406.2661 。
生成式对抗网络结构:
噪声数据->生成模型->假图片---|
                                                   |->判别模型->真/假
打乱训练数据->训练集->真图片-|
生成式对抗网络主要解决如何从训练样本中学习出新样本。生成模型负责训练出样本的分布,如果训练样本是图片就生成相似的图片,如果训练样本是文章名子就生成相似的文章名子。判别模型是一个二分类器,用来判断输入样本是真实数据还是训练生成的样本。
生成式对抗网络优化,是一个二元极小极大博弈(minimax two-player game)问题。使生成模型输出在输入给判别模型时,判断模型秀难判断是真实数据还是虚似数据。训练好的生成模型,能把一个噪声向量转化成和训练集类似的样本。Argustus Odena、Christopher Olah、Jonathon Shlens论文《Coditional Image Synthesis with Auxiliary Classifier GANs》。
辅助分类器生成式对抗网络(auxiliary classifier GAN,AC-GAN)实现。
生成式对抗网络应用。生成数字,生成人脸图像。
生成式对抗网络实现。https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py 。
Augustus Odena、Chistopher Olah和Jonathon Shlens 论文《Conditional Image Synthesis With Auxiliary Classifier GANs》。
通过噪声,让生成模型G生成虚假数据,和真实数据一起送到判别模型D,判别模型一方面输出数据真/假,一方面输出图片分类。
首先定义生成模型,目的是生成一对(z,L)数据,z是噪声向量,L是(1,28,28)的图像空间。
def build_generator(latent_size):
        cnn = Sequential()
        cnn.add(Dense(1024, input_dim=latent_size, activation='relu'))
        cnn.add(Dense(128 * 7 * 7, activation='relu'))
        cnn.add(Reshape((128, 7, 7)))
        #上采样,图你尺寸变为 14X14
        cnn.add(UpSampling2D(size=(2,2)))
        cnn.add(Convolution2D(256, 5, 5, border_mode='same', activation='relu', init='glorot_normal'))
        #上采样,图像尺寸变为28X28
        cnn.add(UpSampling2D(size=(2,2)))
        cnn.add(Convolution2D(128, 5, 5, border_mode='same', activation='relu', init='glorot_normal'))
        #规约到1个通道
        cnn.add(Convolution2D(1, 2, 2, border_mode='same', activation='tanh', init='glorot_normal'))
        #生成模型输入层,特征向量
        latent = Input(shape=(latent_size, ))
        #生成模型输入层,标记
        image_class = Input(shape=(1,), dtype='int32')
        cls = Flatten()(Embedding(10, latent_size, init='glorot_normal')(image_class))
        h = merge([latent, cls], mode='mul')
        fake_image = cnn(h) #输出虚假图片
        return Model(input=[latent, image_class], output=fake_image)
定义判别模型,输入(1,28,28)图片,输出两个值,一个是判别模型认为这张图片是否是虚假图片,另一个是判别模型认为这第图片所属分类。
def build_discriminator();
        #采用激活函数Leaky ReLU来替换标准的卷积神经网络中的激活函数
        cnn = Wequential()
        cnn.add(Convolution2D(32, 3, 3, border_mode='same', subsample=(2, 2), input_shape=(1, 28, 28)))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))
        cnn.add(Convolution2D(64, 3, 3, border_mode='same', subsample=(1, 1)))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))
        cnn.add(Convolution2D(128, 3, 3, border_mode='same', subsample=(1, 1)))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))
        cnn.add(Convolution2D(256, 3, 3, border_mode='same', subsample=(1, 1)))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))
        cnn.add(Flatten())
        image = Input(shape=(1, 28, 28))
        features = cnn(image)
        #有两个输出
        #输出真假值,范围在0~1
        fake = Dense(1, activation='sigmoid',name='generation')(features)
        #辅助分类器,输出图片分类
       aux = Dense(10, activation='softmax', name='auxiliary')(features)
        return Model(input=image, output=[fake, aux])
训练过程,50轮(epoch),把权重保存,每轮把虚假数据生成图处保存,观察虚假数据演化过程。
if __name__ =='__main__':
        #定义超参数
        nb_epochs = 50
        batch_size = 100
        latent_size = 100
        #优化器学习率
        adam_lr = 0.0002
        adam_beta_l = 0.5
        #构建判别网络
        discriminator = build_discriminator()
        discriminator.compile(optimizer=adam(lr=adam_lr, beta_l=adam_beta_l), loss='binary_crossentropy')
        latent = Input(shape=(lastent_size, ))
        image_class = Input(shape-(1, ), dtype='int32')
        #生成组合模型
        discriminator.trainable = False
        fake, aux = discriminator(fake)
        combined = Model(input=[latent, image_class], output=[fake, aux])
        combined.compile(optimizer=Adam(lr=adam_lr, beta_l=adam_beta_1), loss=['binary_crossentropy', 'sparse_categorical_crossentropy'])
        #将mnist数据转化为(...,1,28,28)维度,取值范围为[-1,1]
        (X_train,y_train),(X_test,y_test) = mnist.load_data()
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=1)
        X_test = (X_test.astype(np.float32) - 127.5) / 127.5
        X_test = np.expand_dims(X_test, axis=1)
        num_train, num_test = X_train.shape[0], X_test.shape[0]
        train_history = defaultdict(list)
        test_history = defaultdict(list)
        for epoch in range(epochs):
            print('Epoch {} of {}'.format(epoch + 1, epochs))
            num_batches = int(X_train.shape[0] / batch_size)
            progress_bar = Progbar(target=num_batches)
            epoch_gen_loss = []
            epoch_disc_loss = []
            for index in range(num_batches):
                progress_bar.update(index)
                #产生一个批次的噪声数据
                noise = np.random.uniform(-1, 1, (batch_size, latent_size))
                # 获取一个批次的真实数据
                image_batch = X_train[index * batch_size:(index + 1) * batch_size]
                label_batch = y_train[index * batch_size:(index + 1) * batch_size]
                # 生成一些噪声标记
                sampled_labels = np.random.randint(0, 10, batch_size)
                # 产生一个批次的虚假图片
                generated_images = generator.predict(
                [noise, sampled_labels.reshape((-1, 1))], verbose=0)
                X = np.concatenate((image_batch, generated_images))
                y = np.array([1] * batch_size + [0] * batch_size)
                aux_y = np.concatenate((label_batch, sampled_labels), axis=0)
                epoch_disc_loss.append(discriminator.train_on_batch(X, [y, aux_y]))
                # 产生两个批次噪声和标记
                noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size))
                sampled_labels = np.random.randint(0, 10, 2 * batch_size)
                # 训练生成模型来欺骗判别模型,输出真/假都设为真
                trick = np.ones(2 * batch_size)
                epoch_gen_loss.append(combined.train_on_batch(
                   [noise, sampled_labels.reshape((-1, 1))],
                    [trick, sampled_labels]))
            print('\nTesting for epoch {}:'.format(epoch + 1))
            # 评估测试集,产生一个新批次噪声数据
            noise = np.random.uniform(-1, 1, (num_test, latent_size))
            sampled_labels = np.random.randint(0, 10, num_test)
            generated_images = generator.predict(
                [noise, sampled_labels.reshape((-1, 1))], verbose=False)
            X = np.concatenate((X_test, generated_images))
            y = np.array([1] * num_test + [0] * num_test)
            aux_y = np.concatenate((y_test, sampled_labels), axis=0)
            # 判别模型是否能判别
            discriminator_test_loss = discriminator.evaluate(
                X, [y, aux_y], verbose=False)
            discriminator_train_loss = np.mean(np.array(epoch_disc_loss), axis=0)
            # 创建两个批次新噪声数据
            noise = np.random.uniform(-1, 1, (2 * num_test, latent_size))
            sampled_labels = np.random.randint(0, 10, 2 * num_test)
            trick = np.ones(2 * num_test)
            generator_test_loss = combined.evaluate(
                [noise, sampled_labels.reshape((-1, 1))],
                [trick, sampled_labels], verbose=False)
            generator_train_loss = np.mean(np.array(epoch_gen_loss), axis=0)
            # 损失值等性能指标记录下来,并输出
            train_history['generator'].append(generator_train_loss)
            train_history['discriminator'].append(discriminator_train_loss)
            test_history['generator'].append(generator_test_loss)
            test_history['discriminator'].append(discriminator_test_loss)
            print('{0:<22s} | {1:4s} | {2:15s} | {3:5s}'.format(
                'component', *discriminator.metrics_names))
            print('-' * 65)
            ROW_FMT = '{0:<22s} | {1:<4.2f} | {2:<15.2f} | {3:<5.2f}'
            print(ROW_FMT.format('generator (train)',
                             *train_history['generator'][-1]))
            print(ROW_FMT.format('generator (test)',
                             *test_history['generator'][-1]))
            print(ROW_FMT.format('discriminator (train)',
                             *train_history['discriminator'][-1]))
            print(ROW_FMT.format('discriminator (test)',
                             *test_history['discriminator'][-1]))
            # 每个epoch保存一次权重
            generator.save_weights(
                'params_generator_epoch_{0:03d}.hdf5'.format(epoch), True)
            discriminator.save_weights(
                'params_discriminator_epoch_{0:03d}.hdf5'.format(epoch), True)
            # 生成一些可视化虚假数字看演化过程
            noise = np.random.uniform(-1, 1, (100, latent_size))
            sampled_labels = np.array([
                [i] * 10 for i in range(10)
            ]).reshape(-1, 1)
            generated_images = generator.predict(
                [noise, sampled_labels], verbose=0)
            # 整理到一个方格
            img = (np.concatenate([r.reshape(-1, 28)
                               for r in np.split(generated_images, 10)
                               ], axis=-1) * 127.5 + 127.5).astype(np.uint8)
            Image.fromarray(img).save(
                'plot_epoch_{0:03d}_generated.png'.format(epoch))
        pickle.dump({'train': train_history, 'test': test_history},
                    open('acgan-history.pkl', 'wb'))
训练结束,创建3类文件。params_discriminator_epoch_{{epoch_number}}.hdf5,判别模型权重参数。params_generator_epoch_{{epoch_number}}.hdf5,生成模型权重参数。plot_epoch_{{epoch_number}}_generated.png 。
生成式对抗网络改进。生成式对抗网络(generative adversarial network,GAN)在无监督学习非常有效。常规生成式对抗网络判别器使用Sigmoid交叉熵损失函数,学习过程梯度消失。Wasserstein生成式对抗网络(Wasserstein generative adversarial network,WGAN),使用Wasserstein距离度量,而不是Jensen-Shannon散度(Jensen-Shannon divergence,JSD)。使用最小二乘生成式对抗网络(least squares generative adversarial network,LSGAN),判别模型用最小平方损失小函数(least squares loss function)。Sebastian Nowozin、Botond Cseke、Ryota Tomioka论文《f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization》。
参考资料:
《TensorFlow技术解析与实战》
欢迎付费咨询(150元每小时),我的微信:qingxingfengzi
学习笔记TF051:生成式对抗网络的更多相关文章
- 学习笔记GAN001:生成式对抗网络,只需10步,从零开始到调试
		生成式对抗网络(gennerative adversarial network,GAN),目前最火的非监督深度学习.一个生成网络无中生有,一个判别网络推动进化.学技术,不先着急看书看文章.先把Demo ... 
- 不要怂,就是GAN (生成式对抗网络) (一)
		前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ... 
- 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介
		前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ... 
- 生成式对抗网络(GAN)学习笔记
		图像识别和自然语言处理是目前应用极为广泛的AI技术,这些技术不管是速度还是准确度都已经达到了相当的高度,具体应用例如智能手机的人脸解锁.内置的语音助手.这些技术的实现和发展都离不开神经网络,可是传统的 ... 
- 【神经网络与深度学习】生成式对抗网络GAN研究进展(五)——Deep Convolutional Generative Adversarial Nerworks,DCGAN
		[前言] 本文首先介绍生成式模型,然后着重梳理生成式模型(Generative Models)中生成对抗网络(Generative Adversarial Network)的研究与发展.作者 ... 
- (转)【重磅】无监督学习生成式对抗网络突破,OpenAI 5大项目落地
		[重磅]无监督学习生成式对抗网络突破,OpenAI 5大项目落地 [新智元导读]"生成对抗网络是切片面包发明以来最令人激动的事情!"LeCun前不久在Quroa答问时毫不加掩饰对生 ... 
- 生成式对抗网络(GAN)实战——书法字体生成练习赛
		https://www.tinymind.cn/competitions/ai 生成式对抗网络(GAN)是近年来大热的深度学习模型. 目前GAN最常使用的场景就是图像生成,作为一种优秀的生成式模型,G ... 
- 生成式对抗网络GAN 的研究进展与展望
		生成式对抗网络GAN的研究进展与展望.pdf 摘要: 生成式对抗网络GAN (Generative adversarial networks) 目前已经成为人工智能学界一个热门的研究方向. GAN的基 ... 
- 【CV论文阅读】生成式对抗网络GAN
		生成式对抗网络GAN 1. 基本GAN 在论文<Generative Adversarial Nets>提出的GAN是最原始的框架,可以看成极大极小博弈的过程,因此称为“对抗网络”.一般 ... 
随机推荐
- Java开源博客My-Blog之docker组件化修改
			前言 5月13号上线了自己的个人博客,<Docker+SpringBoot+Mybatis+thymeleaf的Java博客系统开源啦>,紧接着也在github上开源了博客的代码,到现在为 ... 
- c#编程-线程同步
			线程同步 上一篇介绍了如何开启线程,线程间相互传递参数,及线程中本地变量和全局共享变量区别. 本篇主要说明线程同步. 如果有多个线程同时访问共享数据的时候,就必须要用线程同步,防止共享数据被破坏.如果 ... 
- .NetCore~TagHelpers标签的使用
			回到目录 TagHelpers 能够让服务端代码参与创建和渲染 HTML 元素,让整个View视图看起来只有Html代码,也让前台开发人员开发的页面直接被后台开发人员使用,而不需要重复的书写代码,这种 ... 
- 理解Linux文件系统之inode
			很少转发别人的文章,但是这篇写的太好了. 理解inode 作者: 阮一峰 inode是一个重要概念,是理解Unix/Linux文件系统和硬盘储存的基础. 我觉得,理解inode,不仅有助于提高系统 ... 
- Log4Net(一):快速入门
			概览 Log4Net是Apache Log4J框架在.NET平台上的实现,它是一个帮助开发者将日志信息以多种方式(数据库.控制台.文件等)输出的开源工具. 为什么要使用日志记录 提供应用程序运行时状态 ... 
- UglifyJS-- 对你的js做了什么
			也不是闲着没事去看压缩代码,但今天调试自己代码的时候发现有点意思.因为是自己写的,虽然压缩了,格式化之后还是很好辨认.当然作为min的首要准则不是可读性,而是精简.那么它会尽量的缩短代码,尽量的保持一 ... 
- Spring源码情操陶冶-AbstractApplicationContext#onRefresh
			承接前文Spring源码情操陶冶-AbstractApplicationContext#initApplicationEventMulticaster 约定web.xml配置的contextClass ... 
- 6.RDD持久性
			RDD持久性 1 Why Apache Spark 2 关于Apache Spark 3 如何安装Apache Spark 4 Apache Spark的工作原理 5 spark弹性分布式数据集 6 ... 
- JS组件系列——在ABP中封装BootstrapTable
			前言:关于ABP框架,博主关注差不多有两年了吧,一直迟迟没有尝试.一方面博主觉得像这种复杂的开发框架肯定有它的过人之处,系统的稳定性和健壮性比一般的开源框架肯定强很多,可是另一方面每每想到它繁琐的封装 ... 
- 再起航,我的学习笔记之JavaScript设计模式01
			我的学习笔记是根据我的学习情况来定期更新的,预计2-3天更新一章,主要是给大家分享一下,我所学到的知识,如果有什么错误请在评论中指点出来,我一定虚心接受,那么废话不多说开始我们今天的学习分享吧! 在通 ... 
