Outline

  • Auto-Encoder

  • Variational Auto-Encoders

Auto-Encoder

创建编解码器

import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.') def save_images(imgs, name):
new_im = Image.new('L', (280, 280)) index = 0
for i in range(0, 280, 28):
for j in range(0, 280, 28):
im = imgs[index]
im = Image.fromarray(im, mode='L')
new_im.paste(im, (i, j))
index += 1 new_im.save(name) h_dim = 20 # 784降维20维
batchsz = 512
lr = 1e-3 (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(
np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz) print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape) class AE(keras.Model):
def __init__(self):
super(AE, self).__init__() # Encoders
self.encoder = Sequential([
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(h_dim)
]) # Decoders
self.decoder = Sequential([
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(784)
]) def call(self, inputs, training=None):
# [b,784] ==> [b,19]
h = self.encoder(inputs) # [b,10] ==> [b,784]
x_hat = self.decoder(h) return x_hat model = AE()
model.build(input_shape=(None, 784)) # tensorflow尽量用元组
model.summary()
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________

训练

optimizer = tf.optimizers.Adam(lr=lr)

for epoch in range(10):

    for step, x in enumerate(train_db):

        # [b,28,28]==>[b,784]
x = tf.reshape(x, [-1, 784]) with tf.GradientTape() as tape:
x_rec_logits = model(x) rec_loss = tf.losses.binary_crossentropy(x,
x_rec_logits,
from_logits=True)
rec_loss = tf.reduce_min(rec_loss) grads = tape.gradient(rec_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables)) if step % 100 == 0:
print(epoch, step, float(rec_loss)) # evaluation x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
x_hat = tf.sigmoid(logits)
# [b,784]==>[b,28,28]
x_hat = tf.reshape(x_hat, [-1, 28, 28]) # [b,28,28] ==> [2b,28,28]
x_concat = tf.concat([x, x_hat], axis=0)
# x_concat = x # 原始图片
x_concat = x_hat
x_concat = x_concat.numpy() * 255.
x_concat = x_concat.astype(np.uint8) # 保存为整型
if not os.path.exists('ae_images'):
os.mkdir('ae_images')
save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703

Auto-Encoders实战的更多相关文章

  1. [Python] 机器学习库资料汇总

    声明:以下内容转载自平行宇宙. Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: ...

  2. python数据挖掘领域工具包

    原文:http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块:Numpy和Sc ...

  3. Theano3.1-练习之初步介绍

    来自 http://deeplearning.net/tutorial/,虽然比较老了,不过觉得想系统的学习theano,所以需要从python--numpy--theano的顺序学习.这里的资料都很 ...

  4. [resource]Python机器学习库

    reference: http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块: ...

  5. 机器学习——深度学习(Deep Learning)

    Deep Learning是机器学习中一个非常接近AI的领域,其动机在于建立.模拟人脑进行分析学习的神经网络,近期研究了机器学习中一些深度学习的相关知识,本文给出一些非常实用的资料和心得. Key W ...

  6. Deep Learning Tutorial - Classifying MNIST digits using Logistic Regression

    Deep Learning Tutorial 由 Montreal大学的LISA实验室所作,基于Theano的深度学习材料.Theano是一个python库,使得写深度模型更容易些,也可以在GPU上训 ...

  7. [转]Python机器学习工具箱

    原文在这里  Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: 比较成熟的(广播 ...

  8. 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks

    In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...

  9. Deep Learning(4)

    四.拓展学习推荐 Deep Learning 经典阅读材料: The monograph or review paper Learning Deep Architectures for AI (Fou ...

  10. 深度学习教程Deep Learning Tutorials

    Deep Learning Tutorials Deep Learning is a new area of Machine Learning research, which has been int ...

随机推荐

  1. bzoj 3894 文理分科【最小割+dinic】

    谁说这道和2127是双倍经验的来着完全不一样啊? 数组开小会TLE!数组开小会TLE!数组开小会TLE! 首先sum统计所有收益 对于当前点\( (i,j) \)考虑,设\( x=(i-1)*m+j ...

  2. 用Python解析HTML,BeautifulSoup使用简介

    Beautiful Soup,字面意思是美好的汤,是一个用于解析HTML文件的Python库.主页在http://www.crummy.com/software/BeautifulSoup/ , 下载 ...

  3. 通过IDEA制作包含Java应程序的Docker镜像

    IDEA官网在IDEA中把Java App制作成Docker镜像并启动一个容器运行 在idea上使用docker作为java的开发环境[][] ubuntu+docker+docker-compose ...

  4. ACM复习专项

    资料整理 ACM训练营 邝斌的ACM模板 牛客网哈理工ACM教学视频 视频网盘资料(密码:kntr) 1. 训练阶段 第一阶段:练习经典常用算法 (本周任务) 1. 最短路(Floyd.Dijstra ...

  5. Matlab调用C程序 分类: Matlab c/c++ 2015-01-06 19:18 464人阅读 评论(0) 收藏

    Matlab是矩阵语言,如果运算可以用矩阵实现,其运算速度非常快.但若运算中涉及到大量循环,Matlab的速度令人难以忍受的.当必须使用for循环且找不到对应的矩阵运算来等效时,可以将耗时长的函数用C ...

  6. rhel7安装oracle 11gR2,所需的依赖包

    binutils-2.23.52.0.1-30.el7.x86_64 compat-libstdc++-33-3.2.3-61.x86_64compat-libstdc++-33-3.2.3-61.i ...

  7. PHP使用Session遇到的一个Permission denied Notice解决办法

    搜索 session.save_path 在这里你有两个选择,一个是像我一样用; 把这一行注释掉,另一个选择就是修改一个 nobody 用户可以操作的目录,也就是说有读写权限的目录,我也查了下这个默认 ...

  8. 多个文本框点击复制 zClip (ZeroClipboard)有关问题

    <script type="text/javascript" src="js/jquery.min.js"$amp;>amp;$lt;/script ...

  9. 移动web开发基础(二)——viewport

    本文主要研究为什么移动web开发需要设置viewport,且一般设置为<meta name="viewport" content="width=device-wid ...

  10. Bing图片下载器(Python实现)

    分享一个Python实现的Bing图片下载器.下载首页图片并保存到到当前目录.其中用到了正则库re以及Request库. 大致流程如下: 1.Request抓取首页数据 2.re正则匹配首页图片URL ...