Keras使我们搭建神经网络变得异常简单,之前我们使用了Sequential来搭建LSTM:keras实现LSTM

我们要使用Keras的functional API搭建更加灵活的网络结构,比如说本文的autoencoder,关于autoencoder的介绍可以在这里找到:deep autoencoder

现在我们就开始。

step 0 导入需要的包

 import keras
from keras.layers import Dense, Input
from keras.datasets import mnist
from keras.models import Model
import numpy as np

step 1 数据预处理

这里需要说明一下,导入的原始数据shape为(60000,28,28),autoencoder使用(60000,28*28),而且autoencoder属于无监督学习,所以只需要导入x_train和x_test.

 (x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32')/255.0
x_test = x_test.astype('float32')/255.0
#print(x_train.shape)
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)
#print(x_train.shape)

step 2 向图片添加噪声

添加噪声是为了让autoencoder更robust,不容易出现过拟合。

 #add random noise
x_train_nosiy = x_train + 0.3 * np.random.normal(loc=0., scale=1., size=x_train.shape)
x_test_nosiy = x_test + 0.3 * np.random.normal(loc=0, scale=1, size=x_test.shape)
x_train_nosiy = np.clip(x_train_nosiy, 0., 1.)
x_test_nosiy = np.clip(x_test_nosiy, 0, 1.)
print(x_train_nosiy.shape, x_test_nosiy.shape)

step 3 搭建网络结构

分别构建encoded和decoded,然后将它们链接起来构成整个autoencoder。使用Model建模。

 #build autoencoder model
input_img = Input(shape=(28*28,))
encoded = Dense(500, activation='relu')(input_img)
decoded = Dense(784, activation='sigmoid')(encoded) autoencoder = Model(input=input_img, output=decoded)

step 4 compile

因为这里是让解压后的图片和原图片做比较, loss使用的是binary_crossentropy。

 autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
autoencoder.summary()

step 5 train

指定epochs,batch_size,可以使用validation_data,keras训练的时候不会使用它,而是用来做模型评价。

autoencoder.fit(x_train_nosiy, x_train, epochs=20, batch_size=128, verbose=1, validation_data=(x_test, x_test))

step 6 对比一下解压缩后的图片和原图片

 %matplotlib inline
import matplotlib.pyplot as plt #decoded test images
decoded_img = autoencoder.predict(x_test_nosiy) n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
#noisy data
ax = plt.subplot(3, n, i+1)
plt.imshow(x_test_nosiy[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
#predict
ax = plt.subplot(3, n, i+1+n)
plt.imshow(decoded_img[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
#original
ax = plt.subplot(3, n, i+1+2*n)
plt.imshow(x_test[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
plt.show()

这样的结果,你能分出哪个是压缩解压缩后的图片哪个是原图片吗?

reference:

https://keras.io/getting-started/functional-api-guide/

Keras实现autoencoder的更多相关文章

  1. keras使用AutoEncoder对mnist数据降维

    import keras import matplotlib.pyplot as plt from keras.datasets import mnist (x_train, _), (x_test, ...

  2. tlflearn 编码解码器 ——数据降维用

    # -*- coding: utf-8 -*- """ Auto Encoder Example. Using an auto encoder on MNIST hand ...

  3. Keras(六)Autoencoder 自编码 原理及实例 Save&reload 模型的保存和提取

    Autoencoder 自编码 压缩与解压 原来有时神经网络要接受大量的输入信息, 比如输入信息是高清图片时, 输入信息量可能达到上千万, 让神经网络直接从上千万个信息源中学习是一件很吃力的工作. 所 ...

  4. 深度学习Keras框架笔记之AutoEncoder类

    深度学习Keras框架笔记之AutoEncoder类使用笔记 keras.layers.core.AutoEncoder(encoder, decoder,output_reconstruction= ...

  5. 用Keras搭建神经网络 简单模版(六)——Autoencoder 自编码

    import numpy as np np.random.seed(1337) from keras.datasets import mnist from keras.models import Mo ...

  6. CNN autoencoder 进行异常检测——TODO,使用keras进行测试

    https://sefiks.com/2018/03/23/convolutional-autoencoder-clustering-images-with-neural-networks/ http ...

  7. 深度学习中的Data Augmentation方法(转)基于keras

    在深度学习中,当数据量不够大时候,常常采用下面4中方法: 1. 人工增加训练集的大小. 通过平移, 翻转, 加噪声等方法从已有数据中创造出一批"新"的数据.也就是Data Augm ...

  8. 深度学习之自编码器AutoEncoder

    原文地址:https://blog.csdn.net/marsjhao/article/details/73480859 一.什么是自编码器(Autoencoder) 自动编码器是一种数据的压缩算法, ...

  9. (zhuan) Variational Autoencoder: Intuition and Implementation

    Agustinus Kristiadi's Blog TECH BLOG TRAVEL BLOG PORTFOLIO CONTACT ABOUT Variational Autoencoder: In ...

随机推荐

  1. iOS --转载 NSRange 和 NSString 详解

    一.NSRange 1.NSRange的介绍 NSRange是Foundation框架中比较常用的结构体, 它的定义如下: typedef struct _NSRange { NSUInteger l ...

  2. 怎么查看mac系统是32位还是64位的操作系统

    如何查看mac系统是32位还是64位的操作系统 (一)点击工具栏左上角点击 (苹果Logo)标志,关于本机  -->  更多信息 --> 系统报告  -->(左侧栏中)软件 (二) ...

  3. libevent(1)

    很多时候,除了响应事件之外,应用还希望做一定的数据缓冲.比如说,写入数据的时候,通常的运行模式是: l 决定要向连接写入一些数据,把数据放入到缓冲区中 l 等待连接可以写入 l 写入尽量多的数据 l  ...

  4. android应用安全——数据安全

    数据安全包含数据库数据安全.SD卡数据(外部存储)安全.RAM数据(内部存储)安全. android中操作数据库可使用SQLiteOpenHelper或ContentProvider的方式.使用SQL ...

  5. 【BZOJ4421】[Cerc2015] Digit Division 动态规划

    [BZOJ4421][Cerc2015] Digit Division Description 给出一个数字串,现将其分成一个或多个子串,要求分出来的每个子串能Mod M等于0. 将方案数(mod 1 ...

  6. java的double类型如何精确到一位小数?

    java的double类型如何精确到一位小数? //分钟转小时vacationNum = (double)Math.round(vacationNum/60*10)/10.0;overTimeNum ...

  7. 【IDEA】启动项目报错:3 字节的 UTF-8 序列的字节 3 无效

    一.报错和原因: 项目起服务出错.具体报错就不贴了,报错主要是"3 字节的 UTF-8 序列的字节 3 无效". 分析:主要就是项目编码问题,IDEA中估计就是配置不对,没必要纠结 ...

  8. Phonetic Symbols:2个半元音:[w] ,[j]

    2个半元音音标发音技巧与单词举例 原文地址:http://www.hlyy.in/1243.html 2个半元音音标发音技巧与半元音单词举例 [w]  发音技巧: 嘴唇张开到刚好可以含住一根吸管的程度 ...

  9. MongoDB-5: 查询(游标操作、游标信息)

    一.简介 db.collection.find()可以实现根据条件查询和指定使用投影运算符返回的字段省略此参数返回匹配文档中的所有字段.并返回到匹配文档的游标,可以随意修改查询限制.跳跃.和排序顺序的 ...

  10. Python面向对象高级编程-@property

    使用@property 在绑定属性时,如果直接把属性暴露出去,虽然写起来简单,但是没法检查参数,导致可以把成绩随便改: >>> class Student(object): pass ...