深度学习之 GAN 进行 mnist 图片的生成

mport numpy as np
import os
import codecs
import torch
from PIL import Image
import PIL def get_int(b):
return int(codecs.encode(b, 'hex'), 16) def extract_image(path, extract_path):
with open(path, 'rb') as f:
data = f.read()
assert get_int(data[:4]) == 2051
length = get_int(data[4:8])
num_rows = get_int(data[8:12])
num_cols = get_int(data[12:16])
images = []
parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
parsed = parsed.reshape(length, num_rows, num_cols) for image_i, image in enumerate(parsed):
Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i))) image_path = './mnist/t10k-images.idx3-ubyte'
extract_path = './mnist/data/image' import math def images_square_grid(images, mode):
save_size = math.floor(np.sqrt(images.shape[0])) # Scale to 0-255
images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8) # Put images in a square arrangement
images_in_square = np.reshape(
images[:save_size*save_size],
(save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
if mode == 'L':
images_in_square = np.squeeze(images_in_square, 4) # Combine images to grid image
new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
for col_i, col_images in enumerate(images_in_square):
for image_i, image in enumerate(col_images):
im = Image.fromarray(image, mode)
new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2])) return new_im def get_image(image_path, width, height, mode): image = Image.open(image_path) if image.size != (width, height):
face_width = face_width = 108
j = (image.size[0] - face_width) // 2
i = (image.size[1] - face_height) // 2 image = image.crop([j, i, j + face_width, i + face_height])
image = image.resize([width, height], Image.BILINEAR) return np.array(image.convert(mode)) def get_batch(image_files, width, height, mode):
data_batch = np.array([get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32) if len(data_batch.shape) < 4:
data_batch = data_batch.reshape(data_batch.shape + (1,)) return data_batch %matplotlib inline
import os
from glob import glob
from matplotlib import pyplot data_dir = './mnist/data'
show_n_images = 25 mnist_images = get_batch(glob(os.path.join(data_dir, 'image/*.jpg'))[:show_n_images], 28, 28, 'L') pyplot.imshow(images_square_grid(mnist_images, 'L'), cmap='gray') from torch.utils import data
import torchvision as tv batch_size = 50 transforms = tv.transforms.Compose([
tv.transforms.Resize(96),
PIL.ImageOps.grayscale,
tv.transforms.ToTensor()
]) root="d:\\work\\yoho\\dl\\dl-study\\chapter8\\mnist\\data" dataset = tv.datasets.ImageFolder(root, transform=transforms)
dataloader = data.DataLoader(dataset, batch_size, shuffle=True, num_workers=1, drop_last=True) import torch.nn as nn
import torch.optim as optim
from torch.nn.modules import loss
from torch.autograd import Variable as V class GNet(nn.Module):
def __init__(self, opt):
super(GNet, self).__init__() ngf = opt["ngf"]
target = opt["target"] or 3 self.main = nn.Sequential(
nn.ConvTranspose2d( opt["nz"], ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True), nn.ConvTranspose2d( ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True), nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True), nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True), nn.ConvTranspose2d( ngf, target, 5, 3, 1, bias=False),
nn.Tanh()
) def forward(self, input):
return self.main(input) class DNet(nn.Module):
def __init__(self, opt):
super(DNet, self).__init__() ndf = opt["ndf"]
input = opt["input"] or 3 self.main = nn.Sequential(
nn.Conv2d(input, ndf, 5, 3, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.3, inplace=True), nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
) def forward(self, input):
return self.main(input).view(-1) lr_g = 0.01
lr_d = 0.01
ngf = 64
ndf = 64
raw_f = 1
nz = 100
d_every = 1
g_every = 5 net_g = GNet({"target": raw_f, "ngf": ngf, 'nz': nz})
net_d = DNet({"input": raw_f, "ndf": ndf}) opt_g = optim.Adam(net_g.parameters(), lr_g, betas=(0.5, 0.999))
opt_d = optim.Adam(net_d.parameters(), lr_g, betas=(0.5, 0.999)) criterion = torch.nn.BCELoss() true_labels = V(torch.ones(batch_size))
fake_labels = V(torch.zeros(batch_size))
fix_noises = V(torch.randn(batch_size, nz, 1, 1))
noises = V(torch.randn(batch_size, nz, 1, 1)) def train():
for ii, (img, _) in enumerate(dataloader):
real_img = V(img) if (ii + 1) % d_every == 0:
opt_d.zero_grad()
output = net_d(real_img)
loss_d = criterion(output, true_labels)
loss_d.backward() noises.data.copy_(torch.randn(batch_size, nz, 1, 1)) fake_img = net_g(noises) fake_img = fake_img.detach()
fake_output = net_d(fake_img)
loss_fake_d = criterion(fake_output, fake_labels)
loss_fake_d.backward() opt_d.step() if (ii + 1) % g_every == 0:
opt_g.zero_grad()
noises.data.copy_(torch.randn(batch_size, nz, 1, 1))
fake_image = net_g(noises) fake_output = net_d(fake_img) loss_g = criterion(fake_output, true_labels) loss_g.backward()
opt_g.step() def print_image():
fix_fake_imgs = net_g(fix_noises)
fix_fake_imgs = fix_fake_imgs.data.view(batch_size, 96, 96, 1).numpy()
pyplot.imshow(images_square_grid(fix_fake_imgs, 'L'), cmap='gray') epochs = 20
def main():
for i in range(epochs):
print("epoch {}".format(i))
train() if i % 2 == 0:
print_image()
main()

注意 GAN 很慢,要使用 GPU来工作

深度学习之 GAN 进行 mnist 图片的生成的更多相关文章

  1. 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)

    变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...

  2. 【深度学习】--GAN从入门到初始

    一.前述 GAN,生成对抗网络,在2016年基本火爆深度学习,所有有必要学习一下.生成对抗网络直观的应用可以帮我们生成数据,图片. 二.具体 1.生活案例 比如假设真钱 r 坏人定义为G  我们通过 ...

  3. 深度学习-Wasserstein GAN论文理解笔记

    GAN存在问题 训练困难,G和D多次尝试没有稳定性,Loss无法知道能否优化,生成样本单一,改进方案靠暴力尝试 WGAN GAN的Loss函数选择不合适,使模型容易面临梯度消失,梯度不稳定,优化目标不 ...

  4. 机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET

    DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络.我使用的是由MatConvNet 提供的DAGNN API.选择这套API作为 ...

  5. 深度学习之GAN对抗神经网络

    1.结构图 2.知识点 生成器(G):将噪音数据生成一个想要的数据 判别器(D):将生成器的结果进行判别, 3.代码及案例 # coding: utf-8 # ## 对抗生成网络案例 ## # # # ...

  6. 【Python开发】【神经网络与深度学习】网络爬虫之图片自动下载器

    python爬虫实战--图片自动下载器 之前介绍了那么多基本知识[Python爬虫]入门知识(没看的赶紧去看)大家也估计手痒了.想要实际做个小东西来看看,毕竟: talk is cheap show ...

  7. 为什么要用深度学习来做个性化推荐 CTR 预估

    欢迎大家前往腾讯云技术社区,获取更多腾讯海量技术实践干货哦~ 作者:苏博览 深度学习应该这一两年计算机圈子里最热的一个词了.基于深度学习,工程师们在图像,语音,NLP等领域都取得了令人振奋的进展.而深 ...

  8. NLP+VS︱深度学习数据集标注工具、方法摘录,欢迎补充~~

    ~~因为不太会使用opencv.matlab工具,所以在找一些比较简单的工具. . . 一.NLP标注工具BRAT BRAT是一个基于web的文本标注工具,主要用于对文本的结构化标注,用BRAT生成的 ...

  9. 好书推荐计划:Keras之父作品《Python 深度学习》

    大家好,我禅师的助理兼人工智能排版住手助手条子.可能非常多人都不知道我.由于我真的难得露面一次,天天给禅师做底层工作. wx_fmt=jpeg" alt="640? wx_fmt= ...

随机推荐

  1. NancyFX 附录: Nuget程序包

    Nancy.Authentication.Forms 该程序包向Nancy提供标准的基于ASP.NET/IIS的表单身份验证服务. 采用这个模块启用身份验证后,可以获得标准ASP.NET表单验证方式. ...

  2. SQL Server The target database ('db') is in an availability group and currently does not allow read only connections. For more information about application intent, see SQL Server Books Online.

    一.问题概述 在错误日志中看到非常多的alwayson群集只读连接错误,错误信息的描述为“目标数据库位于可用性组,当前不允许通过read only连接”.错误日志如下: 当前的业务系统使用监听ip对数 ...

  3. grub4dos和winsetupfromusb1.4

    其实grub4dos也是一个多系统启动盘制作软件,GRUB4DOS 最大的成功之处就是既学习了windows的方便易用,又引入linux的强大功能.http://baike.baidu.com/lin ...

  4. Asp.Net MVC 实现将Easy-UI展示数据下载为Excel 文件

    在一个项目中,需要做一个将Easy-UI界面展示数据下载为Excel文件的功能,经过一段时间努力,完成了一个小Demo.界面如下: 但按下导出Excel后,Excel文件将会下载到本地,在office ...

  5. Linux Centos 下安装软件 三种方式(转)

    Linux学习的路还很远呢,各位码农,新年快乐哈! 1)一种是软件的源代码,您需要自己动手编译它.这种软件安装包通常是用gzip压缩过的tar包(后缀为.tar.gz). 2)另一种是软件的可执行程序 ...

  6. Cesium剖面分析

  7. mysql存储过程(查询数据库内表 游标循环 if判断 插入别的表内)

    BEGIN declare f_age int;DECLARE incode1 VARCHAR(100);DECLARE incode2 VARCHAR(100);DECLARE incode3 VA ...

  8. Vue解析三之过滤器

    export function formatDate(date, fmt) { if (/(y+)/.test(fmt)) { fmt = fmt.replace(RegExp.$1, (date.g ...

  9. poj-1056-IMMEDIATE DECODABILITY(字典)

    Description An encoding of a set of symbols is said to be immediately decodable if no code for one s ...

  10. 数据库 --> sqlite3总结

    Sqlite3总结 SQLite,是一款轻型的数据库,是遵守ACID的关系型数据库管理系统,它包含在一个相对小的C库中. sqlite语句 #sqlite3 test.db //设置宽度为2sqlit ...