autoencoder可以用于数据压缩、降维,预训练神经网络,生成数据等等

Auto-Encoder架构



需要完成的工作

需要完成Encoder和Decoder的训练

例如,Mnist的一张图片大小为784维,将图片放到Encoder中进行压缩,编码code使得维度小于784维度,之后可以将code放进Decoder中进行重建,可以产生同之前相似的图片。

Encoder和Decoder需要一起进行训练。



输入同样是一张图片,通过选择W,找到数据的主特征向量,压缩图片得到code,然后使用W的转置,恢复图片。

我们知道,PCA对数据的降维是线性的(linear),恢复数据会有一定程度的失真。上面通过PCA恢复的图片也是比较模糊的。

所以,我们也可以把PCA理解成为一个线性的autoencoder,W就是encode的作用,w的转置就是decode的作用,最后的目的是decode的结果和原始图片越接近越好。



现在来看真正意义上的Deep Auto-encoder的结构。通常encoder每层对应的W和decoder每层对应的W不需要对称(转置)



从上面可以看出,Auto-encoder产生的图片,比PCA还原的图片更加接近真实图片。

接下来我们就来实现这样的一个Auto-Encoder

实现

导入必要的第三方库,以及前期的处理

import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential,layers tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
assert tf.__version__.startswith('2.')

定义一个保存图片的方法,以便于将我们新生成的图片保存起来,为我们后面我们查看图片的效果带来持久化的数据

def save_images(imgs,name):
new_im=Image.new('L',(280,280))
index=0
for i in range(0,280,28):
for j in range(0,280,28):
im=imgs[index]
im=Image.fromarray(im,mode='L')
new_im.paste(im,(i,j))
index+=1
new_im.save(name)

这部分为数据集的加载和图片重建的预处理过程;我们这里将高的维度降为20,这个参数可以随意,读者也可以将其降为10也是可以的。同时这里我们不再使用label了

h_dim=20
batchsz=512
lr=1e-3
(x_train,y_train),(x_test,y_test)=keras.datasets.fashion_mnist.load_data()
x_train,x_test=x_train.astype(np.float32)/255.,x_test.astype(np.float32)/255.
train_data=tf.data.Dataset.from_tensor_slices(x_train)
train_data=train_data.shuffle(batchsz*5).batch(batchsz)
test_data=tf.data.Dataset.from_tensor_slices(x_test)
test_data=test_data.batch(batchsz)

接下来我们创建模型

这里我们使用keras的接口,再建立模型的时,我们需要继承Keras下的Model

我们先将网络结构搭建出来,这里有两个部分,一个是init的初始化方法;另一个是call前向传播的方法

class AE(keras.Model):
def __init__(self):
super(AE, self).__init__()
pass
def call(self,inputs,training=None):
pass

编写好上述后,我们完成init和call中的方法。

首先编写Encoder,这里Encoder将编辑为高维度、抽象的向量

 self.encoder=Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(h_dim)
])

我们再编写Decoders的方法,可以看到同Encoder是相反的过程

        self.decoder=Sequential([
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(784)
])

完成了init的方法后,我们再来写call中的方法了,

首先使用encoder将输入的高维度图片置为低维的,然后再使用decoder还原,

笔者这里由于上述设置的h_dim为10,同时使用的是FashionMNIST数据集(维度是784),所以encoder将[b,784]-->[b,10],

decoder将[b,10]-->[b,784]

    def call(self, inputs, training=None):
# encoder-->decoder [b,784]-->[b,10]
h=self.encoder(inputs)
# [b,10]-->[b,784]
x_hat=self.decoder(h)
return x_hat

接下来我们可以建立model,再看看model是怎样的

model=AE()
model.build(input_shape=(None,784))
model.summary()
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________

定义优化器

这里我们就使用Adam优化器,读者也可以使用SGD,这个无所谓。、

optimizer=tf.optimizers.Adam(lr=lr)

训练

for epoch in range(200):
for step,x in enumerate(train_data):
x=tf.reshape(x,[-1,784])
with tf.GradientTape() as tape:
x_rec_logits =model(x)
rec_loss =tf.losses.binary_crossentropy(x,x_rec_logits,from_logits=True)
rec_loss =tf.reduce_mean(rec_loss)
grads=tape.gradient(rec_loss,model.trainable_variables)
optimizer.apply_gradients(zip(grads,model.trainable_variables))
if step%100==0:
print(epoch,step,float(rec_loss))

验证

这里需要注意一下,image是一个文件夹,再训练前,我们需要在代码所在路径下手动添加

        x=next(iter(test_data))
logits=model(tf.reshape(x,[-1,784])) # trans [0,1]
x_hat=tf.sigmoid(logits)
x_hat=tf.reshape(x_hat,[-1,28,28])
x_concat=tf.concat([x,x_hat],axis=0)
x_concat=x_concat.numpy()*255
x_concat=x_concat.astype(np.uint8)
save_images(x_concat,'image/epoch_%d.png'%epoch)

结果展示:



建议大家动手实践实践,共同进步。

笔者水平有限,如有表述不准确的地方还请谅解,有错误的地方欢迎大家批评指正。

Tensorflow2.0实战之Auto-Encoder的更多相关文章

  1. Google老师亲授 TensorFlow2.0实战: 入门到进阶

    Google老师亲授 TensorFlow2.0 入门到进阶 课程以Tensorflow2.0框架为主体,以图像分类.房价预测.文本分类等项目为依托,讲解Tensorflow框架的使用方法,同时学习到 ...

  2. Google工程师亲授 Tensorflow2.0-入门到进阶

    第1章 Tensorfow简介与环境搭建 本门课程的入门章节,简要介绍了tensorflow是什么,详细介绍了Tensorflow历史版本变迁以及tensorflow的架构和强大特性.并在Tensor ...

  3. Auto Encoder用于异常检测

    对基于深度神经网络的Auto Encoder用于异常检测的一些思考 from:https://my.oschina.net/u/1778239/blog/1861724 一.前言 现实中,大部分数据都 ...

  4. 基于tensorflow2.0 使用tf.keras实现Fashion MNIST

    本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...

  5. TensorFlow2.0(1):基本数据结构—张量

    1 引言 TensorFlow2.0版本已经发布,虽然不是正式版,但预览版都发布了,正式版还会远吗?相比于1.X,2.0版的TensorFlow修改的不是一点半点,这些修改极大的弥补了1.X版本的反人 ...

  6. 『TensorFlow2.0正式版教程』极简安装TF2.0正式版(CPU&GPU)教程

    0 前言 TensorFlow 2.0,今天凌晨,正式放出了2.0版本. 不少网友表示,TensorFlow 2.0比PyTorch更好用,已经准备全面转向这个新升级的深度学习框架了. ​ 本篇文章就 ...

  7. TensorFlow2.0(9):TensorBoard可视化

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  8. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  9. 一文上手Tensorflow2.0(四)

    系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...

  10. 一文上手TensorFlow2.0(一)

    目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU) Te ...

随机推荐

  1. 《SQL与数据库基础》19. 日志

    目录 日志 错误日志 二进制日志 日志格式 日志查看 日志删除 查询日志 慢查询日志 本文以 MySQL 为例 日志 错误日志 错误日志是 MySQL 中最重要的日志之一,它记录了当 mysql 启动 ...

  2. [Python3] 初识py, 一个简单练手的小玩意. 快递查询

    有图有真相 脚本代码 最近刚入门py, 准备写点小玩意练练手. 于是决定拿快递100开刀. 因为它的api很简单. # 快递100 API # 作者: 剑齿虎 # 邮箱: yuxiaobo64@gma ...

  3. KRPANO 最新官方文档中文版(持续更新)

    KRPano最新官方文档中文版分享,后续持续更新: http://docs.krpano.tech/ 本博文发表于:http://www.krpano.tech/archives/849 发布者:屠龙 ...

  4. iperf 工具使用总结

    转载请注明出处: iperf是一个用于测量网络带宽的工具,可以通过客户端和服务器之间的数据传输来评估网络性能.下面详细介绍iperf的使用方法.常用命令和参数以及注意事项,并提供一些示例说明.在ipe ...

  5. nodejs实现的一个简单粗暴的洗牌算法

    据说名字长别人不一定看得到 之前用python,自带shuffle用的还是超爽的: 去年6月份自己动手用nodejs写一个21点扑克游戏的后台时,就需要一个洗牌算法,于是简单粗暴的实现了一个. 贴出来 ...

  6. python-微信

    wxpy/itchat已禁用 自从微信禁止网页版登陆之后,itchat 库实现的功能也就都不能用了: itchat现在叫wxpy 1.安装库wxpy: PS D:\01VSCodeScript\Pyt ...

  7. Pandas 读取 Excel 斜着读

    读取 Excel 斜着读数据 import pandas as pd def read_sideling(direction, sheet_name, row_start, col_start, ga ...

  8. Go 函数的健壮性、panic异常处理、defer 机制

    Go 函数的健壮性.panic异常处理.defer 机制 目录 Go 函数的健壮性.panic异常处理.defer 机制 一.函数健壮性的"三不要"原则 1.1 原则一:不要相信任 ...

  9. mybtis-plus 出现 Wrong namespace

    今天进行项目整合,刚开始代码搬的还挺快乐的,但是到后面调试起来,头晕眼花的.记录一个基本的错误. Cause: org.apache.ibatis.builder.BuilderException:  ...

  10. Xmind思维导图工具2023最新专业版破解思路

    工具介绍 XMind 是一款最为流行的专业级思维_导图_制作与编辑软件,它现在在全球范围内都已极具名气,可谓是办公.学习.团队交流必备工具之一. 准备工作 1,官方Xmind软件 2,一个心意的编辑器 ...