Auto-Encoders实战
Outline
Auto-Encoder
Variational Auto-Encoders
Auto-Encoder

创建编解码器
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
def save_images(imgs, name):
new_im = Image.new('L', (280, 280))
index = 0
for i in range(0, 280, 28):
for j in range(0, 280, 28):
im = imgs[index]
im = Image.fromarray(im, mode='L')
new_im.paste(im, (i, j))
index += 1
new_im.save(name)
h_dim = 20 # 784降维20维
batchsz = 512
lr = 1e-3
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(
np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
class AE(keras.Model):
def __init__(self):
super(AE, self).__init__()
# Encoders
self.encoder = Sequential([
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(h_dim)
])
# Decoders
self.decoder = Sequential([
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(784)
])
def call(self, inputs, training=None):
# [b,784] ==> [b,19]
h = self.encoder(inputs)
# [b,10] ==> [b,784]
x_hat = self.decoder(h)
return x_hat
model = AE()
model.build(input_shape=(None, 784)) # tensorflow尽量用元组
model.summary()
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
训练
optimizer = tf.optimizers.Adam(lr=lr)
for epoch in range(10):
for step, x in enumerate(train_db):
# [b,28,28]==>[b,784]
x = tf.reshape(x, [-1, 784])
with tf.GradientTape() as tape:
x_rec_logits = model(x)
rec_loss = tf.losses.binary_crossentropy(x,
x_rec_logits,
from_logits=True)
rec_loss = tf.reduce_min(rec_loss)
grads = tape.gradient(rec_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
if step % 100 == 0:
print(epoch, step, float(rec_loss))
# evaluation
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
x_hat = tf.sigmoid(logits)
# [b,784]==>[b,28,28]
x_hat = tf.reshape(x_hat, [-1, 28, 28])
# [b,28,28] ==> [2b,28,28]
x_concat = tf.concat([x, x_hat], axis=0)
# x_concat = x # 原始图片
x_concat = x_hat
x_concat = x_concat.numpy() * 255.
x_concat = x_concat.astype(np.uint8) # 保存为整型
if not os.path.exists('ae_images'):
os.mkdir('ae_images')
save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703
Auto-Encoders实战的更多相关文章
- [Python] 机器学习库资料汇总
声明:以下内容转载自平行宇宙. Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: ...
- python数据挖掘领域工具包
原文:http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块:Numpy和Sc ...
- Theano3.1-练习之初步介绍
来自 http://deeplearning.net/tutorial/,虽然比较老了,不过觉得想系统的学习theano,所以需要从python--numpy--theano的顺序学习.这里的资料都很 ...
- [resource]Python机器学习库
reference: http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块: ...
- 机器学习——深度学习(Deep Learning)
Deep Learning是机器学习中一个非常接近AI的领域,其动机在于建立.模拟人脑进行分析学习的神经网络,近期研究了机器学习中一些深度学习的相关知识,本文给出一些非常实用的资料和心得. Key W ...
- Deep Learning Tutorial - Classifying MNIST digits using Logistic Regression
Deep Learning Tutorial 由 Montreal大学的LISA实验室所作,基于Theano的深度学习材料.Theano是一个python库,使得写深度模型更容易些,也可以在GPU上训 ...
- [转]Python机器学习工具箱
原文在这里 Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: 比较成熟的(广播 ...
- 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks
In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...
- Deep Learning(4)
四.拓展学习推荐 Deep Learning 经典阅读材料: The monograph or review paper Learning Deep Architectures for AI (Fou ...
- 深度学习教程Deep Learning Tutorials
Deep Learning Tutorials Deep Learning is a new area of Machine Learning research, which has been int ...
随机推荐
- bzoj 2326: [HNOI2011]数学作业【dp+矩阵快速幂】
矩阵乘法一般不满足交换律!!所以快速幂里需要注意乘的顺序!! 其实不难,设f[i]为i的答案,那么f[i]=(f[i-1]w[i]+i)%mod,w[i]是1e(i的位数),这个很容易写成矩阵的形式, ...
- 【原创】《从0开始学Elasticsearch》—集群健康和索引管理
内容目录 1.搭建Kibana2.集群健康3.索引操作 1.搭建Kibana 正如<Kibana 用户手册>中所介绍,Kibana 是一款开源的数据分析和可视化平台,因此我们可以借助 Ki ...
- centos 允许远程连接mysql
GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' IDENTIFIED BY '123456' WITH GRANT OPTION;
- B - Burning Midnight Oil CodeForces - 165B
One day a highly important task was commissioned to Vasya — writing a program in a night. The progra ...
- Qt - 锁屏界面加虚拟小键盘
一.实现效果 鼠标点击"密码输入栏",弹出虚拟键盘,输入锁屏密码后,点击虚拟键盘外部区域,则会隐藏虚拟键盘,再点击登录,成功进入主界面. 二.虚拟键盘-程序设计 2.1 frmNu ...
- 使用Apache Commons IO组件读取大文件
Apache Commons IO读取文件代码如下: Files.readLines(new File(path), Charsets.UTF_8); FileUtils.readLines(new ...
- CentOS 6.5上安装GlassFish4.0 过程笔记
CentOS 6.5上安装GlassFish4.0 过程笔记 1.安装JDK, 注意操作系统的位数, 64 or 32: [root@linuxidc ~]# mkdir /usr/java [ro ...
- AJPFX:关于面向对象的封装
1.回顾 面向对象 -- 注重的是结果,强调的是具备功能的对象. 面向过程 -- 强调的是函数,注重的实现的过程. 函数:对功能的封装. ...
- 解决::processDebugResourcesERROR: In<declare-styleable> FontFamilyFont编译报错
cordova编译时报错 错误信息 :processDebugResourcesERROR: In <declare-styleable> FontFamilyFont, unable t ...
- STM32 (基础语法)笔记
指针遍历: char *myitoa(int value, char *string, int radix){ int i, d; int flag = 0; char *ptr = string; ...