Tensorflow2.0实战之Auto-Encoder
autoencoder可以用于数据压缩、降维,预训练神经网络,生成数据等等
Auto-Encoder架构

需要完成的工作
需要完成Encoder和Decoder的训练
例如,Mnist的一张图片大小为784维,将图片放到Encoder中进行压缩,编码code使得维度小于784维度,之后可以将code放进Decoder中进行重建,可以产生同之前相似的图片。
Encoder和Decoder需要一起进行训练。

输入同样是一张图片,通过选择W,找到数据的主特征向量,压缩图片得到code,然后使用W的转置,恢复图片。
我们知道,PCA对数据的降维是线性的(linear),恢复数据会有一定程度的失真。上面通过PCA恢复的图片也是比较模糊的。
所以,我们也可以把PCA理解成为一个线性的autoencoder,W就是encode的作用,w的转置就是decode的作用,最后的目的是decode的结果和原始图片越接近越好。

现在来看真正意义上的Deep Auto-encoder的结构。通常encoder每层对应的W和decoder每层对应的W不需要对称(转置)

从上面可以看出,Auto-encoder产生的图片,比PCA还原的图片更加接近真实图片。
接下来我们就来实现这样的一个Auto-Encoder
实现
导入必要的第三方库,以及前期的处理
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential,layers
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
assert tf.__version__.startswith('2.')
定义一个保存图片的方法,以便于将我们新生成的图片保存起来,为我们后面我们查看图片的效果带来持久化的数据
def save_images(imgs,name):
new_im=Image.new('L',(280,280))
index=0
for i in range(0,280,28):
for j in range(0,280,28):
im=imgs[index]
im=Image.fromarray(im,mode='L')
new_im.paste(im,(i,j))
index+=1
new_im.save(name)
这部分为数据集的加载和图片重建的预处理过程;我们这里将高的维度降为20,这个参数可以随意,读者也可以将其降为10也是可以的。同时这里我们不再使用label了
h_dim=20
batchsz=512
lr=1e-3
(x_train,y_train),(x_test,y_test)=keras.datasets.fashion_mnist.load_data()
x_train,x_test=x_train.astype(np.float32)/255.,x_test.astype(np.float32)/255.
train_data=tf.data.Dataset.from_tensor_slices(x_train)
train_data=train_data.shuffle(batchsz*5).batch(batchsz)
test_data=tf.data.Dataset.from_tensor_slices(x_test)
test_data=test_data.batch(batchsz)
接下来我们创建模型
这里我们使用keras的接口,再建立模型的时,我们需要继承Keras下的Model
我们先将网络结构搭建出来,这里有两个部分,一个是init的初始化方法;另一个是call前向传播的方法
class AE(keras.Model):
def __init__(self):
super(AE, self).__init__()
pass
def call(self,inputs,training=None):
pass
编写好上述后,我们完成init和call中的方法。
首先编写Encoder,这里Encoder将编辑为高维度、抽象的向量
self.encoder=Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(h_dim)
])
我们再编写Decoders的方法,可以看到同Encoder是相反的过程
self.decoder=Sequential([
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(784)
])
完成了init的方法后,我们再来写call中的方法了,
首先使用encoder将输入的高维度图片置为低维的,然后再使用decoder还原,
笔者这里由于上述设置的h_dim为10,同时使用的是FashionMNIST数据集(维度是784),所以encoder将[b,784]-->[b,10],
decoder将[b,10]-->[b,784]
def call(self, inputs, training=None):
# encoder-->decoder [b,784]-->[b,10]
h=self.encoder(inputs)
# [b,10]-->[b,784]
x_hat=self.decoder(h)
return x_hat
接下来我们可以建立model,再看看model是怎样的
model=AE()
model.build(input_shape=(None,784))
model.summary()
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
定义优化器
这里我们就使用Adam优化器,读者也可以使用SGD,这个无所谓。、
optimizer=tf.optimizers.Adam(lr=lr)
训练
for epoch in range(200):
for step,x in enumerate(train_data):
x=tf.reshape(x,[-1,784])
with tf.GradientTape() as tape:
x_rec_logits =model(x)
rec_loss =tf.losses.binary_crossentropy(x,x_rec_logits,from_logits=True)
rec_loss =tf.reduce_mean(rec_loss)
grads=tape.gradient(rec_loss,model.trainable_variables)
optimizer.apply_gradients(zip(grads,model.trainable_variables))
if step%100==0:
print(epoch,step,float(rec_loss))
验证
这里需要注意一下,image是一个文件夹,再训练前,我们需要在代码所在路径下手动添加
x=next(iter(test_data))
logits=model(tf.reshape(x,[-1,784])) # trans [0,1]
x_hat=tf.sigmoid(logits)
x_hat=tf.reshape(x_hat,[-1,28,28])
x_concat=tf.concat([x,x_hat],axis=0)
x_concat=x_concat.numpy()*255
x_concat=x_concat.astype(np.uint8)
save_images(x_concat,'image/epoch_%d.png'%epoch)
结果展示:






建议大家动手实践实践,共同进步。
笔者水平有限,如有表述不准确的地方还请谅解,有错误的地方欢迎大家批评指正。
Tensorflow2.0实战之Auto-Encoder的更多相关文章
- Google老师亲授 TensorFlow2.0实战: 入门到进阶
Google老师亲授 TensorFlow2.0 入门到进阶 课程以Tensorflow2.0框架为主体,以图像分类.房价预测.文本分类等项目为依托,讲解Tensorflow框架的使用方法,同时学习到 ...
- Google工程师亲授 Tensorflow2.0-入门到进阶
第1章 Tensorfow简介与环境搭建 本门课程的入门章节,简要介绍了tensorflow是什么,详细介绍了Tensorflow历史版本变迁以及tensorflow的架构和强大特性.并在Tensor ...
- Auto Encoder用于异常检测
对基于深度神经网络的Auto Encoder用于异常检测的一些思考 from:https://my.oschina.net/u/1778239/blog/1861724 一.前言 现实中,大部分数据都 ...
- 基于tensorflow2.0 使用tf.keras实现Fashion MNIST
本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...
- TensorFlow2.0(1):基本数据结构—张量
1 引言 TensorFlow2.0版本已经发布,虽然不是正式版,但预览版都发布了,正式版还会远吗?相比于1.X,2.0版的TensorFlow修改的不是一点半点,这些修改极大的弥补了1.X版本的反人 ...
- 『TensorFlow2.0正式版教程』极简安装TF2.0正式版(CPU&GPU)教程
0 前言 TensorFlow 2.0,今天凌晨,正式放出了2.0版本. 不少网友表示,TensorFlow 2.0比PyTorch更好用,已经准备全面转向这个新升级的深度学习框架了. 本篇文章就 ...
- TensorFlow2.0(9):TensorBoard可视化
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- TensorFlow2.0(11):tf.keras建模三部曲
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- 一文上手Tensorflow2.0(四)
系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...
- 一文上手TensorFlow2.0(一)
目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU) Te ...
随机推荐
- 《Kali渗透基础》01. 介绍
@ 目录 1:渗透测试 1.1:安全问题的根源 1.2:安全目标 1.3:渗透测试 1.4:标准 2:Kali 2.1:介绍 2.2:策略 2.3:安装 3:Kali 初步设置 3.1:远程连接 3. ...
- 《SQLi-Labs》03. Less 11~15
@ 目录 索引 Less-11 题解 原理 Less-12 题解 Less-13 题解 Less-14 题解 Less-15 题解 原理 sqli.开启新坑. 索引 Less-11:POST 回显注入 ...
- 《SQL与数据库基础》04. SQL-DQL
目录 DQL 基础查询 条件查询 分组聚合 聚合函数 分组查询 结果排序 分页限制 总结 本文以 MySQL 为例 DQL 语法结构: SELECT 字段列表 FROM 表名列表 WHERE 条件列表 ...
- Java 设计模式实战系列—单例模式
本文首发公众号:小码A梦 单例模式是设计模式中最简单一个设计模式,该模式属于创建型模式,它提供了一种创建实例的最佳方式. 单例模式的定义也比较简单:一个类只能允许创建一个对象或者实例,那么这个类就是单 ...
- lattice crosslink开发板mipi核心板csi测试dsi屏lif md6000 fpga
1. 概述 CrossLink开发板,是用Lattice的芯片CrossLink 家族系列的,LIF-MD6000-6JM80I.该芯片用于桥接视频接口功能,自带2路MIPI硬核的功能,4 LANE ...
- 从软件工程师角度聊聊 Kubernetes
作为软件工程师,我们应该熟悉 K8s,尽管它有点像 DevOps,但它能让我们更好地了解幕后发生的事情,让我们与部署工作更密切相关,更有责任感.本文将从软件工程师的角度探讨 Kubernetes (K ...
- 文心一言 VS 讯飞星火 VS chatgpt (90)-- 算法导论8.3 3题
三.用go语言,利用归纳法来证明基数排序是正确的.在你所给出的证明中,在哪里需要假设所用的底层排序算法是稳定的? 文心一言: 基数排序是一种非比较型整数排序算法,其通过在每一位上进行比较来排序.基数排 ...
- HiAI Foundation助力端侧音视频AI能力,高性能低功耗释放云侧成本
过去三年是端侧AI高速发展的几年,华为在2020年预言了端侧AI的发展潮流,2021年通过提供端云协同的方式使我们的HiAI Foundation应用性更进一个台阶,2022年提供视频超分端到端的解决 ...
- 当你使用Taro时,你需要了解的一些事儿
2017 年 1 月 9 日凌晨,万众期待的微信小程序正式上线,前有跳一跳等爆圈小游戏的带动,后有特殊时期下各类健康码小程序的加持,小程序成为了国内技术圈独树一帜的存在.但随着小程序的迅猛发展,其实在 ...
- 探析ElasticSearch Kibana在测试工作中的实践应用
一. 为什么使用ES Kibana 离线数据测试中最重要的就是数据验证,一部分需要测试es存储数据的正确性,另一部分就需要验证接口从es取值逻辑的正确性.而为了验证es取值逻辑的正确性,就需要用到Ki ...