Auto-Encoders实战
Outline
Auto-Encoder
Variational Auto-Encoders
Auto-Encoder
创建编解码器
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
def save_images(imgs, name):
new_im = Image.new('L', (280, 280))
index = 0
for i in range(0, 280, 28):
for j in range(0, 280, 28):
im = imgs[index]
im = Image.fromarray(im, mode='L')
new_im.paste(im, (i, j))
index += 1
new_im.save(name)
h_dim = 20 # 784降维20维
batchsz = 512
lr = 1e-3
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(
np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
class AE(keras.Model):
def __init__(self):
super(AE, self).__init__()
# Encoders
self.encoder = Sequential([
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(h_dim)
])
# Decoders
self.decoder = Sequential([
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(784)
])
def call(self, inputs, training=None):
# [b,784] ==> [b,19]
h = self.encoder(inputs)
# [b,10] ==> [b,784]
x_hat = self.decoder(h)
return x_hat
model = AE()
model.build(input_shape=(None, 784)) # tensorflow尽量用元组
model.summary()
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
训练
optimizer = tf.optimizers.Adam(lr=lr)
for epoch in range(10):
for step, x in enumerate(train_db):
# [b,28,28]==>[b,784]
x = tf.reshape(x, [-1, 784])
with tf.GradientTape() as tape:
x_rec_logits = model(x)
rec_loss = tf.losses.binary_crossentropy(x,
x_rec_logits,
from_logits=True)
rec_loss = tf.reduce_min(rec_loss)
grads = tape.gradient(rec_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
if step % 100 == 0:
print(epoch, step, float(rec_loss))
# evaluation
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
x_hat = tf.sigmoid(logits)
# [b,784]==>[b,28,28]
x_hat = tf.reshape(x_hat, [-1, 28, 28])
# [b,28,28] ==> [2b,28,28]
x_concat = tf.concat([x, x_hat], axis=0)
# x_concat = x # 原始图片
x_concat = x_hat
x_concat = x_concat.numpy() * 255.
x_concat = x_concat.astype(np.uint8) # 保存为整型
if not os.path.exists('ae_images'):
os.mkdir('ae_images')
save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703
Auto-Encoders实战的更多相关文章
- [Python] 机器学习库资料汇总
声明:以下内容转载自平行宇宙. Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: ...
- python数据挖掘领域工具包
原文:http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块:Numpy和Sc ...
- Theano3.1-练习之初步介绍
来自 http://deeplearning.net/tutorial/,虽然比较老了,不过觉得想系统的学习theano,所以需要从python--numpy--theano的顺序学习.这里的资料都很 ...
- [resource]Python机器学习库
reference: http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块: ...
- 机器学习——深度学习(Deep Learning)
Deep Learning是机器学习中一个非常接近AI的领域,其动机在于建立.模拟人脑进行分析学习的神经网络,近期研究了机器学习中一些深度学习的相关知识,本文给出一些非常实用的资料和心得. Key W ...
- Deep Learning Tutorial - Classifying MNIST digits using Logistic Regression
Deep Learning Tutorial 由 Montreal大学的LISA实验室所作,基于Theano的深度学习材料.Theano是一个python库,使得写深度模型更容易些,也可以在GPU上训 ...
- [转]Python机器学习工具箱
原文在这里 Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: 比较成熟的(广播 ...
- 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks
In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...
- Deep Learning(4)
四.拓展学习推荐 Deep Learning 经典阅读材料: The monograph or review paper Learning Deep Architectures for AI (Fou ...
- 深度学习教程Deep Learning Tutorials
Deep Learning Tutorials Deep Learning is a new area of Machine Learning research, which has been int ...
随机推荐
- React实战之将数据库返回的时间转换为几分钟前、几小时前、几天前的形式。
React实战之将数据库返回的时间转换为几分钟前.几小时前.几天前的形式. 不知道大家的时间格式是什么样子的,我先展示下我这里数据库返回的时间格式 ‘2019-05-05T15:52:19Z’ 是这个 ...
- 组合数的几种球阀 By cellur925
先来了解几个概念:排列数,组合数. 一.定义及有用的性质 排列数:从n个不同元素中依次取出m个元素排成一列的方案数.P(n,m)=n!/(n-m)! 组合数:从n个不同元素中依次取出m个元素形成一个集 ...
- jQuery笔记之Easing Plugin
jQuery easing 使用方法首先,项目中如果需要使用特殊的动画效果,则需要在引入jQuery之后引入jquery.easing.1.3.js<script type="text ...
- shiro之SimpleAccountRealm
我使用的是maven构建的工程,junit测试 Shiro认证过程 创建SecurityManager--->主体提交认证--->SecurityManager认证--->Authe ...
- Hdu 3488 Tour (KM 有向环覆盖)
题目链接: Hdu 3488 Tour 题目描述: 有n个节点,m条有权单向路,要求用一个或者多个环覆盖所有的节点.每个节点只能出现在一个环中,每个环中至少有两个节点.问最小边权花费为多少? 解题思路 ...
- 454 4Sum II 四数相加 II
给定四个包含整数的数组列表 A , B , C , D ,计算有多少个元组 (i, j, k, l) ,使得 A[i] + B[j] + C[k] + D[l] = 0.为了使问题简单化,所有的 A, ...
- 会jQuery,该如何用AngularJS编程思想?
我可以熟练使用jQuery进行客户端应用的开发,但是现在我希望开始使用Angular.js.哪位能描述一下这个过程中必要的模式变化吗?希望您的答案能够围绕下面这些具体的问题: 1. 我如何对客户端we ...
- HTML标签,简单归纳
列表标签 有序列表: <ol><li></li></ol> 无序列表: <ul><li></li></ul&g ...
- Thinkphp删除缓存
控制器代码 public function delcache(){ //当找到有Runtime的文件夹时,进入if if(is_dir(RUNTIME_PATH)){ delDir(RUNTIME ...
- android中使用图文并茂的按钮
代码: <LinearLayout android:orientation="horizontal" android:layout_width="match_par ...