GAN-生成手写数字-Keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import math
一、首先要定义一个生成器G,该生成器需要将输入的随机噪声变换为图像。
1. 该模型首先输入有100个元素的向量,该向量随机生成于某分布。
2. 随后利用两个全连接层接连将该输入向量扩展到1024维和128 * 7 * 7
3. 后面就开始将全连接层所产生的一维张量重新塑造成二维张量,即MNIST中的灰度图
4. 由全连接传递的数据会经过几个上采样层和卷积层,注意到最后一个卷积层所采用的卷积核为1,所以经过最后卷积层所生成的图像是一张二维灰度图
def generator_model():
# 下面搭建生成器的架构,首先导入序贯模型(sequential),即多个网络层的线性堆叠
model = Sequential()
# 添加一个全连接层,输入为100维向量,输出1024维
model.add(Dense(input_dim=100, output_dim=1024))
# 添加一个激活函数tanh
model.add(Activation('tanh'))
# 添加一个全连接层,输出为 128 * 7 * 7维度
model.add(Dense(128*7*7))
# 添加一个批量归一化层,该层在每个batch上将前一层的激活值重新规范化,即使得其输出数据的均值接近0,其标准差接近1
model.add(BatchNormalization())
model.add(Activation('tanh')) # Reshape层用来将输入shape转换为特定的shape,将含有 128*7*7 个元素的向量转换为 7*7*128 张量
model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
# 2维上采样层,即将数据的行和列分别重复2次
model.add(UpSampling2D(size=(2, 2)))
# 添加一个2维卷积层,卷积核大小为5X5,激活函数为tanh,共64个卷积核,并采用padding以保持图像尺寸不变
model.add(Conv2D(64, (5, 5), padding='same'))
model.add(Activation('tanh'))
model.add(UpSampling2D(size=(2, 2))) # 卷积核设为1即输出图像的维度
model.add(Conv2D(1, (5, 5), padding='same'))
model.add(Activation('tanh'))
return model
二、判别模型
判别模型就是比较传统的图像识别模型,可以按照经典的方法采用几个卷积层与最大池化层,而后再展开为一维张量并采用几个全连接层作为架构
def discrimiator_model():
# 下面搭建判别器架构,同样采用序贯模型
model = Sequential() # 添加一个2维卷积层,卷积核大小为5X5,激活函数为tanh,输入shape在 'channel_first' 模式下为 (samples, channels, rows, cols)
# 在 ‘channel_last’模式下为 (samples, rows, cols, channels),输出为64维。 元素的顺序发生了一定的改变
model.add(Conv2D(64, (5, 5),
padding='same',
input_shape=(28, 28, 1))
)
model.add(Activation('tanh')) # 为空域信号施加最大值池化,pool_size 取(2, 2)代表使图片在两个维度均变为原长的一半
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (5, 5)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2))) # Flatten层把多维输入一维化,常用在卷积层到全连接层的过渡
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation('tanh')) # 一个结点进行二值分类,并采用sigmoid函数的输出作为概念
model.add(Dense(1))
model.add(Activation('sigmoid'))
return model
三、模型拼接
我们在训练生成模型时,需要固定判别模型D以极小化价值函数而寻求更好的生成模型,这就意味着我们需要将生成模型与判别模型拼接在一起,并固定D的权重以训练G的权重。因此训练这个组合模型才能真正更新生成模型的参数。
def generator_containing_discriminator(g, d):
# 将前面定义的生成器架构和判别器架构拼接成一个大的神经网络,用于判别生成的图片
model = Sequential()
# 先添加生成器架构,再令d不可训练,即固定d
# 因此在给定d的情况下训练生成器,即通过将生成的结果投入到判别器进行辨别而优化生成器
model.add(g)
d.trainable = False
model.add(d)
return model
四、生成图片拼接
# 生成图片拼接
def combine_images(generated_images):
num = generated_images.shape[0]
width = int(math.sqrt(num))
height = int(math.ceil(float(num)/width))
shape = generated_images.shape[1:3]
image = np.zeros((height*shape[0], width*shape[1]),
dtype=generated_images.dtype) for index, img in enumerate(generated_images):
i = int(index / width)
j = index % width
image[i * shape[0] : (i + 1) * shape[0], j * shape[1] : (j + 1) * shape[1]] = img[:, :, 0]
return image
五、训练
1. 加载MNIST数据
2. 将数据分割为训练与测试集,并赋值给变量
3. 设置训练模型的超参数
4. 编译模型的训练过程
5. 在每一次迭代内,抽取生成图像与真实图像,并打上标注
6. 随后将数据投入到判别模型中,并进行训练与计算损失
7. 固定判别模型,训练生成模型并计算损失,结束这一次迭代
def train(BATCH_SIZE):
# 加载数据,将数据集下载到本地‘/.keras/datasers/’
# 下载地址:https://s3.amazonaws.com/img-datasets/mnist.npz
(X_train, y_train), (X_test, y_test) = mnist.load_data(r'C:/Users/Administrator/.keras/datasets/mnist.npz')
# image_data_format选择‘channels_last’或‘channels_first’,该选项指定了Keras将要将要使用的维度顺序
# ‘channels_first’假定2D数据的维度顺序为(channels, rows, cols), 3D数据的维度顺序为(channels, conv_dim1, conv_dim2, conv_dim3) # 转换字段类型,并将数据导入变量中
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = X_train[:, :, :, None]
X_test = X_test[:, :, :, None] # 将定义好的模型架构赋值给特定的变量
d = discrimiator_model()
g = generator_model()
d_on_g = generator_containing_discriminator(g, d) # 定义生成器模型、判别器模型,更新所使用的优化算法及超参数
d_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)
g_optim = SGD(lr=0.001, momentum=0.9, nesterov=True) # 编译三个神经网络并设置损失函数和优化算法,其中损失函数都是用二元分类交叉熵函数。编译是用来配置模型学习过程的
g.compile(loss='binary_crossentropy', optimizer='SGD')
d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim) # 前一个架构在固定判别器的情况下训练了生成器,所以在训练判别器之前先要设定其为可训练
d.trainable = True
d.compile(loss='binary_crossentropy', optimizer=d_optim) # 下面在满足epoch条件下进行训练
for epoch in range(10):
print("Epoch is", epoch) # 计算一个epoch所需要的迭代数量,即训练样本数除批量大小数的值取整,其中shape[0]就是读取矩阵第一维度的长度
print("Number of batches", int(X_train.shape[0] / BATCH_SIZE)) # 在一个epoch内进行迭代训练
for index in range(int(X_train.shape[0] / BATCH_SIZE)):
# 随机生成的噪声服从均匀分布,且采样下届为-1,采样上届为1, 输出BATCH_SIZE * 100个样本,即抽取一个批量的随机样本
noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100)) # 抽取一个批量的真实图片
image_batch = X_train[index * BATCH_SIZE : (index + 1) * BATCH_SIZE] # 生成的图片使用生成器对随机噪声进行推断,verbose为日志显示
# 0为不在标准输出流输出日志信息,1为输出进度条记录
generated_images = g.predict(noise, verbose=0) # 每经过100次迭代输出一张生成的图片
if index % 100 == 0:
image = combine_images(generated_images)
image = image * 127.5 + 127.5
Image.fromarray(image.astype(np.uint8)).save("C:/Users/Administrator/GAN/" + str(epoch) + "_" + str(index) + ".png") # 将真实图片和生成图片以多维数组的形式拼接在一起,真实图片在上,生成图片在下
X = np.concatenate((image_batch, generated_images)) # 生成图片真假标签,即一个包含两倍批量大小的列表
# 前一个批量大小都是1,代表真实图片,后一个批量大小都是0,代表伪造图片
y = [1] * BATCH_SIZE + [0] * BATCH_SIZE # 判别器的损失,在一个batch的数据上进行一次参数更新
d_loss = d.train_on_batch(X, y)
print("batch %d d_loss : %f" % (index, d_loss)) # 随机生成的噪声服从均匀分布
noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100)) # 固定判别器
d.trainable = False # 计算生成器损失,在一个batch的数据上进行一次参数更新
g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE) # 令判别器可训练
d.trainable = True
print("batch %d g_loss : %f" % (index, g_loss)) # 每100次迭代保存一次生成器和判别器的权重
if index % 100 == 0:
g.save_weights('generator', True)
d.save_weights('discrimiator', True)
train(32)
六、运行生成好的模型生成图片
# 训练完模型后,可以运行该函数生成图片
def generate(BATCH_SIZE, nice = False):
g = generator_model()
g.compile(loss='binary_crossentropy', optimizer='SGD')
g.load_weights('generator') if nice:
d = discrimiator_model()
d.compile(loss='binary_crossentropy', optimizer='SGD')
d.load_weights('discrimiator')
noise = np.random.uniform(-1, 1, (BATCH_SIZE * 20, 100))
generated_images = g.predict(noise, verbose=1)
d_pret = d.predict(generated_images, verbose=1)
index = np.arange(0, BATCH_SIZE * 20)
index.resize((BATCH_SIZE * 20, 1))
pre_with_index = list(np.append(d_pret, index, axis=1))
pre_with_index.sort(key=lambda x : x[0], reverse=True)
nice_images = np.zeros((BATCH_SIZE, ) + generated_images.shape[1:3], dtype=np.float32)
nice_image = nice_images[:, :, :, None] for i in range(BATCH_SIZE):
idx = int(pre_with_index[i][1])
nice_images[i, :, :, 0] = generated_images[idx, :, :, 0]
image = combine_images(nice_images)
else:
noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
generated_images = g.predict(noise, verbose=0)
image = combine_images(generated_images)
image = image * 127.5 + 127.5
Image.fromarray(image.astype(np.uint8)).save("C:/Users/Administrator/GAN/generated_image.png")
generate(32)
由于只迭代了10个epoch,效果不是很好,不过已经能看出手写数字了。最后生成的图片如下:

GAN-生成手写数字-Keras的更多相关文章
- GAN——生成手写数字
<Generative Adversarial Nets>是 GAN 系列的鼻祖.在这里通过 PyTorch 实现 GAN ,并且用于手写数字生成. 摘要: 我们提出了一个新的框架,通过对 ...
- GAN实战笔记——第三章第一个GAN模型:生成手写数字
第一个GAN模型-生成手写数字 一.GAN的基础:对抗训练 形式上,生成器和判别器由可微函数表示如神经网络,他们都有自己的代价函数.这两个网络是利用判别器的损失记性反向传播训练.判别器努力使真实样本输 ...
- 卷积生成对抗网络(DCGAN)---生成手写数字
深度卷积生成对抗网络(DCGAN) ---- 生成 MNIST 手写图片 1.基本原理 生成对抗网络(GAN)由2个重要的部分构成: 生成器(Generator):通过机器生成数据(大部分情况下是图像 ...
- Tensorflow:DCGAN生成手写数字
参考地址:https://blog.csdn.net/miracle_ma/article/details/78305991 使用DCGAN(deep convolutional GAN):深度卷积G ...
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
- 手写数字识别——利用keras高层API快速搭建并优化网络模型
在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...
- keras框架的MLP手写数字识别MNIST,梳理?
keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...
- 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接
参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...
- 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)
# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...
随机推荐
- python基础-----变量和简单数据类型
初识变量 1.变量命名规则: a.字母 b.数字(不能开头) c.下划线 ps.硬性规定,命名必须是字母,数字,下划线,且不能以数字开头. 软性规则,以下划线分割 2.实例: a.写法: age_of ...
- Eclipse同时显示两个编辑窗口
同时打开两个编辑窗口,点住一个窗口,拖到编辑窗口的最下面时或者最右面,会出现两个两个编辑窗口的轮廓,松开即可!
- Table Compression
https://docs.oracle.com/cd/E11882_01/server.112/e40540/tablecls.htm#CNCPT608
- 出现警告“user1 不在 sudoers 文件中。此事将被报告。”
linux中不是每个用户都有sudo权限. 在/etc/下有个文件sudoers 由此文件可知只有用户为sudo这个组的成员之后才能执行sudo命令 此时,我们查看用户user1的属性: 由此看出us ...
- VS2013 VS2015 VS2017调试出现无法启动iis express web服务器
最近老是遇到这个问题,天天如此,烦死人,网上答案繁多,但是都解决不了,也是由于各种环境不同导致的,这里把几种解决方法都记录下 一.其他项目都可以,就这么一个不行 因为其他项目都可以,就这么一个不行,所 ...
- Spring HibernateTemplate与HibernateDaoSupport对比
HibernateTemplate与HibernateDaoSupport两者都是spring整合hibernate提供的模板技术. 对于保存一个对象,HibernateTemplate需要先配置 配 ...
- js 测试性能
console.time('querySelector');for(var i=0; i<1000; i++){document.querySelector('body');}console.t ...
- Linux/Ubuntu安装搜狗输入法
零.你首先需要安装fcitx小企鹅输入法,相信绝大部分用linux的中国人都用这个输入法,安装fcitx后同时还能解决Sublime Text的中文输入问题. 安装fcitx输入法前首先要安装fcit ...
- kerberos简单介绍
重要术语 1. KDC 全称:key distributed center 作用:整个安全认证过程的票据生成管理服务,其中包含两个服务,AS和TGS 2. AS 全称:authentication s ...
- MySQL NULL处理
-- 首先在用户表中插入数据如下 TRUNCATE TABLE UserInfo ; INSERT INTO `userinfo`(`ID`,`UserName`,`UserLogin`,`User ...