深度学习之 GAN 进行 mnist 图片的生成

mport numpy as np
import os
import codecs
import torch
from PIL import Image
import PIL def get_int(b):
return int(codecs.encode(b, 'hex'), 16) def extract_image(path, extract_path):
with open(path, 'rb') as f:
data = f.read()
assert get_int(data[:4]) == 2051
length = get_int(data[4:8])
num_rows = get_int(data[8:12])
num_cols = get_int(data[12:16])
images = []
parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
parsed = parsed.reshape(length, num_rows, num_cols) for image_i, image in enumerate(parsed):
Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i))) image_path = './mnist/t10k-images.idx3-ubyte'
extract_path = './mnist/data/image' import math def images_square_grid(images, mode):
save_size = math.floor(np.sqrt(images.shape[0])) # Scale to 0-255
images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8) # Put images in a square arrangement
images_in_square = np.reshape(
images[:save_size*save_size],
(save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
if mode == 'L':
images_in_square = np.squeeze(images_in_square, 4) # Combine images to grid image
new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
for col_i, col_images in enumerate(images_in_square):
for image_i, image in enumerate(col_images):
im = Image.fromarray(image, mode)
new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2])) return new_im def get_image(image_path, width, height, mode): image = Image.open(image_path) if image.size != (width, height):
face_width = face_width = 108
j = (image.size[0] - face_width) // 2
i = (image.size[1] - face_height) // 2 image = image.crop([j, i, j + face_width, i + face_height])
image = image.resize([width, height], Image.BILINEAR) return np.array(image.convert(mode)) def get_batch(image_files, width, height, mode):
data_batch = np.array([get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32) if len(data_batch.shape) < 4:
data_batch = data_batch.reshape(data_batch.shape + (1,)) return data_batch %matplotlib inline
import os
from glob import glob
from matplotlib import pyplot data_dir = './mnist/data'
show_n_images = 25 mnist_images = get_batch(glob(os.path.join(data_dir, 'image/*.jpg'))[:show_n_images], 28, 28, 'L') pyplot.imshow(images_square_grid(mnist_images, 'L'), cmap='gray') from torch.utils import data
import torchvision as tv batch_size = 50 transforms = tv.transforms.Compose([
tv.transforms.Resize(96),
PIL.ImageOps.grayscale,
tv.transforms.ToTensor()
]) root="d:\\work\\yoho\\dl\\dl-study\\chapter8\\mnist\\data" dataset = tv.datasets.ImageFolder(root, transform=transforms)
dataloader = data.DataLoader(dataset, batch_size, shuffle=True, num_workers=1, drop_last=True) import torch.nn as nn
import torch.optim as optim
from torch.nn.modules import loss
from torch.autograd import Variable as V class GNet(nn.Module):
def __init__(self, opt):
super(GNet, self).__init__() ngf = opt["ngf"]
target = opt["target"] or 3 self.main = nn.Sequential(
nn.ConvTranspose2d( opt["nz"], ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True), nn.ConvTranspose2d( ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True), nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True), nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True), nn.ConvTranspose2d( ngf, target, 5, 3, 1, bias=False),
nn.Tanh()
) def forward(self, input):
return self.main(input) class DNet(nn.Module):
def __init__(self, opt):
super(DNet, self).__init__() ndf = opt["ndf"]
input = opt["input"] or 3 self.main = nn.Sequential(
nn.Conv2d(input, ndf, 5, 3, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.3, inplace=True), nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
) def forward(self, input):
return self.main(input).view(-1) lr_g = 0.01
lr_d = 0.01
ngf = 64
ndf = 64
raw_f = 1
nz = 100
d_every = 1
g_every = 5 net_g = GNet({"target": raw_f, "ngf": ngf, 'nz': nz})
net_d = DNet({"input": raw_f, "ndf": ndf}) opt_g = optim.Adam(net_g.parameters(), lr_g, betas=(0.5, 0.999))
opt_d = optim.Adam(net_d.parameters(), lr_g, betas=(0.5, 0.999)) criterion = torch.nn.BCELoss() true_labels = V(torch.ones(batch_size))
fake_labels = V(torch.zeros(batch_size))
fix_noises = V(torch.randn(batch_size, nz, 1, 1))
noises = V(torch.randn(batch_size, nz, 1, 1)) def train():
for ii, (img, _) in enumerate(dataloader):
real_img = V(img) if (ii + 1) % d_every == 0:
opt_d.zero_grad()
output = net_d(real_img)
loss_d = criterion(output, true_labels)
loss_d.backward() noises.data.copy_(torch.randn(batch_size, nz, 1, 1)) fake_img = net_g(noises) fake_img = fake_img.detach()
fake_output = net_d(fake_img)
loss_fake_d = criterion(fake_output, fake_labels)
loss_fake_d.backward() opt_d.step() if (ii + 1) % g_every == 0:
opt_g.zero_grad()
noises.data.copy_(torch.randn(batch_size, nz, 1, 1))
fake_image = net_g(noises) fake_output = net_d(fake_img) loss_g = criterion(fake_output, true_labels) loss_g.backward()
opt_g.step() def print_image():
fix_fake_imgs = net_g(fix_noises)
fix_fake_imgs = fix_fake_imgs.data.view(batch_size, 96, 96, 1).numpy()
pyplot.imshow(images_square_grid(fix_fake_imgs, 'L'), cmap='gray') epochs = 20
def main():
for i in range(epochs):
print("epoch {}".format(i))
train() if i % 2 == 0:
print_image()
main()

注意 GAN 很慢,要使用 GPU来工作

深度学习之 GAN 进行 mnist 图片的生成的更多相关文章

  1. 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)

    变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...

  2. 【深度学习】--GAN从入门到初始

    一.前述 GAN,生成对抗网络,在2016年基本火爆深度学习,所有有必要学习一下.生成对抗网络直观的应用可以帮我们生成数据,图片. 二.具体 1.生活案例 比如假设真钱 r 坏人定义为G  我们通过 ...

  3. 深度学习-Wasserstein GAN论文理解笔记

    GAN存在问题 训练困难,G和D多次尝试没有稳定性,Loss无法知道能否优化,生成样本单一,改进方案靠暴力尝试 WGAN GAN的Loss函数选择不合适,使模型容易面临梯度消失,梯度不稳定,优化目标不 ...

  4. 机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET

    DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络.我使用的是由MatConvNet 提供的DAGNN API.选择这套API作为 ...

  5. 深度学习之GAN对抗神经网络

    1.结构图 2.知识点 生成器(G):将噪音数据生成一个想要的数据 判别器(D):将生成器的结果进行判别, 3.代码及案例 # coding: utf-8 # ## 对抗生成网络案例 ## # # # ...

  6. 【Python开发】【神经网络与深度学习】网络爬虫之图片自动下载器

    python爬虫实战--图片自动下载器 之前介绍了那么多基本知识[Python爬虫]入门知识(没看的赶紧去看)大家也估计手痒了.想要实际做个小东西来看看,毕竟: talk is cheap show ...

  7. 为什么要用深度学习来做个性化推荐 CTR 预估

    欢迎大家前往腾讯云技术社区,获取更多腾讯海量技术实践干货哦~ 作者:苏博览 深度学习应该这一两年计算机圈子里最热的一个词了.基于深度学习,工程师们在图像,语音,NLP等领域都取得了令人振奋的进展.而深 ...

  8. NLP+VS︱深度学习数据集标注工具、方法摘录,欢迎补充~~

    ~~因为不太会使用opencv.matlab工具,所以在找一些比较简单的工具. . . 一.NLP标注工具BRAT BRAT是一个基于web的文本标注工具,主要用于对文本的结构化标注,用BRAT生成的 ...

  9. 好书推荐计划:Keras之父作品《Python 深度学习》

    大家好,我禅师的助理兼人工智能排版住手助手条子.可能非常多人都不知道我.由于我真的难得露面一次,天天给禅师做底层工作. wx_fmt=jpeg" alt="640? wx_fmt= ...

随机推荐

  1. JBox使用详解

    插件说明 - jBox 是一款基于 jQuery 的多功能对话框插件,能够实现网站的整体风格效果,给用户一个新的视觉享受. 运行环境 - 兼容 IE6+.Firefox.Chrome.Safari.O ...

  2. Bond UVA - 11354(LCA应用题)

    Once again, James Bond is on his way to saving the world. Bond's latest mission requires him to trav ...

  3. git pull error

    在图形界面中,执行拉取操作时,出现下面的错误. You asked to pull from the remote 'origin', but did not specifya branch. Bec ...

  4. 使用localtunne一分钟搞定微信公众号接入

      记得15年那个刚刚进入工作的时候,公司有个微信公众号的项目,那个时候微信官方没有什么调试工具,也没有什么比较好的本地调试工具.当时有个功能需要调用微信JSSDK里面的扫一扫的功能.由于本地不能调试 ...

  5. 元素化设计原理及规则v1.0

    一.元素设计架构 元素设计架构展示在基于元素化设计的思想下,系统各元素之间如何相互协作,并完成整个系统搭建. 架构中以Entity(数据)为中心,由Entity产生数据库表结构,并且Entity作为业 ...

  6. Clion下jni配置

    Clion非常适合写C++程序 首先把C:\Program Files\Java\jdk1.7.0_79\include和C:\Program Files\Java\jdk1.7.0_79\inclu ...

  7. 以太坊挖矿源码:ethash算法

    本文具体分析以太坊的共识算法之一:实现了POW的以太坊共识引擎ethash. 关键字:ethash,共识算法,pow,Dagger Hashimoto,ASIC,struct{},nonce,FNV ...

  8. 笔记:XML-解析文档

    要处理XML文档,就要先解析(parse)他,解析器时这样一个程序,读入一个文件,确认整个文件具有正确的格式,然后将其分解成各种元素,使得程序员能够访问这些元素,Java库提供了两种XML解析器: 像 ...

  9. 【Saltstack】Saltstack简单说明

    [Saltstack] Saltstack是一个服务器集中管理中心平台,可以帮助管理员轻松的对若干台服务器进行统一操作.类似的工具还有Ansible,Puppet,func等等.相比于这些工具,sal ...

  10. 【Python】 SQLAlchemy的初步使用

    SQLAlchemy 在很多Python的web框架中都整合进了SQLAlchemy这个主要发挥ORM作用的模块.所谓ORM,就是把复杂的SQL语句给包装成更加面向对象,易于理解的样子.在操作数据库的 ...