Keras实现autoencoder
Keras使我们搭建神经网络变得异常简单,之前我们使用了Sequential来搭建LSTM:keras实现LSTM。
我们要使用Keras的functional API搭建更加灵活的网络结构,比如说本文的autoencoder,关于autoencoder的介绍可以在这里找到:deep autoencoder。
现在我们就开始。
step 0 导入需要的包
import keras
from keras.layers import Dense, Input
from keras.datasets import mnist
from keras.models import Model
import numpy as np
step 1 数据预处理
这里需要说明一下,导入的原始数据shape为(60000,28,28),autoencoder使用(60000,28*28),而且autoencoder属于无监督学习,所以只需要导入x_train和x_test.
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32')/255.0
x_test = x_test.astype('float32')/255.0
#print(x_train.shape)
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)
#print(x_train.shape)
step 2 向图片添加噪声
添加噪声是为了让autoencoder更robust,不容易出现过拟合。
#add random noise
x_train_nosiy = x_train + 0.3 * np.random.normal(loc=0., scale=1., size=x_train.shape)
x_test_nosiy = x_test + 0.3 * np.random.normal(loc=0, scale=1, size=x_test.shape)
x_train_nosiy = np.clip(x_train_nosiy, 0., 1.)
x_test_nosiy = np.clip(x_test_nosiy, 0, 1.)
print(x_train_nosiy.shape, x_test_nosiy.shape)
step 3 搭建网络结构
分别构建encoded和decoded,然后将它们链接起来构成整个autoencoder。使用Model建模。
#build autoencoder model
input_img = Input(shape=(28*28,))
encoded = Dense(500, activation='relu')(input_img)
decoded = Dense(784, activation='sigmoid')(encoded) autoencoder = Model(input=input_img, output=decoded)
step 4 compile
因为这里是让解压后的图片和原图片做比较, loss使用的是binary_crossentropy。
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
autoencoder.summary()

step 5 train
指定epochs,batch_size,可以使用validation_data,keras训练的时候不会使用它,而是用来做模型评价。
autoencoder.fit(x_train_nosiy, x_train, epochs=20, batch_size=128, verbose=1, validation_data=(x_test, x_test))
step 6 对比一下解压缩后的图片和原图片
%matplotlib inline
import matplotlib.pyplot as plt #decoded test images
decoded_img = autoencoder.predict(x_test_nosiy) n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
#noisy data
ax = plt.subplot(3, n, i+1)
plt.imshow(x_test_nosiy[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
#predict
ax = plt.subplot(3, n, i+1+n)
plt.imshow(decoded_img[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
#original
ax = plt.subplot(3, n, i+1+2*n)
plt.imshow(x_test[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
plt.show()
这样的结果,你能分出哪个是压缩解压缩后的图片哪个是原图片吗?

reference:
https://keras.io/getting-started/functional-api-guide/
Keras实现autoencoder的更多相关文章
- keras使用AutoEncoder对mnist数据降维
import keras import matplotlib.pyplot as plt from keras.datasets import mnist (x_train, _), (x_test, ...
- tlflearn 编码解码器 ——数据降维用
# -*- coding: utf-8 -*- """ Auto Encoder Example. Using an auto encoder on MNIST hand ...
- Keras(六)Autoencoder 自编码 原理及实例 Save&reload 模型的保存和提取
Autoencoder 自编码 压缩与解压 原来有时神经网络要接受大量的输入信息, 比如输入信息是高清图片时, 输入信息量可能达到上千万, 让神经网络直接从上千万个信息源中学习是一件很吃力的工作. 所 ...
- 深度学习Keras框架笔记之AutoEncoder类
深度学习Keras框架笔记之AutoEncoder类使用笔记 keras.layers.core.AutoEncoder(encoder, decoder,output_reconstruction= ...
- 用Keras搭建神经网络 简单模版(六)——Autoencoder 自编码
import numpy as np np.random.seed(1337) from keras.datasets import mnist from keras.models import Mo ...
- CNN autoencoder 进行异常检测——TODO,使用keras进行测试
https://sefiks.com/2018/03/23/convolutional-autoencoder-clustering-images-with-neural-networks/ http ...
- 深度学习中的Data Augmentation方法(转)基于keras
在深度学习中,当数据量不够大时候,常常采用下面4中方法: 1. 人工增加训练集的大小. 通过平移, 翻转, 加噪声等方法从已有数据中创造出一批"新"的数据.也就是Data Augm ...
- 深度学习之自编码器AutoEncoder
原文地址:https://blog.csdn.net/marsjhao/article/details/73480859 一.什么是自编码器(Autoencoder) 自动编码器是一种数据的压缩算法, ...
- (zhuan) Variational Autoencoder: Intuition and Implementation
Agustinus Kristiadi's Blog TECH BLOG TRAVEL BLOG PORTFOLIO CONTACT ABOUT Variational Autoencoder: In ...
随机推荐
- spring 和 struts 整合遇到的问题(学习中)
一大早就报错 org.hibernate.TransactionException: Transaction not successfully started at org.hibernate.eng ...
- ubuntu16.04搭建jdk1.8运行环境
搭建环境:Ubuntu 16.04 ×64 JDK :jdk-8u171-linux-x64.tar.gz 首先下载linux对应的安装包下载地址:http://www.oracle.com/tech ...
- 巨蟒python全栈开发-第16天 核能来袭-初识面向对象
一.今日内容总览(上帝视角,大象自己进冰箱,控制时机) #转换思想(从面向过程到面向对象) 1.初识面向对象 面向过程: 一切以事物的发展流程为中心. 面向对象: 一切以对象为中心,一切皆为对象,具体 ...
- contos7 mongodb安装教程
通过yum安装mongodb 1.创建文件mongodb.repo文件, cd /etc/yum.repos.d/ vi mongodb.repo 复制如下代码: [mongodb-org-3.4] ...
- phpmail发送邮件
---恢复内容开始--- 首先.需要phpmailer的包. 地址:https://github.com/Synchro/PHPMailer 解开压缩包,将class.phpmailer.php,cl ...
- jd算法大赛 一个user_id只需映射到一个sku_id, 但是一个sku_id能否映射到多个user_id
0-购买预测 w 任务目标 提交user_id-->sku_id (1->1,但反向呢?) 实际操作: 0- 行为表中的type值加权为一个购买意向值0(性别.年龄,不干扰/或加权) 品类 ...
- Java基础语法 - 面向对象 - static 关键字
使用static关键字修饰的变量.常量和方法分别被称作静态变量.静态常量和静态方法,也被称作类的静态成员 静态变量 使用static修饰过的类变量称为静态变量 该变量需要使用类名.变量名进行调用,不能 ...
- SOE不能进入断点调试
一.前言 任何程序开发,如果不能进入断点调试,是非常的痛苦的. 如果有过SOE开发经验的人都知道,SOE开发过程中调试是非常麻烦的.任何在SOE开发模板中的修改都需要重新编译工程,重新生成.soe 文 ...
- linux里的CPU负载
昨天查看Nagios警报信息,发现其中一台服务器CPU负载过重,机器为CentOS系统.信息如下: 2011-2-15 (星期二) 17:50 WARNING - load average: 9.73 ...
- 用java求一个整数各位数字之和
/* * 用java求一个整数各位数字之和 */ public class Test02 { public static void main(String[] args) { System.out.p ...