TensorFlow自编码器(AutoEncoder)之MNIST实践
自编码器可以用于降维,添加噪音学习也可以获得去噪的效果。
以下使用单隐层训练mnist数据集,并且共享了对称的权重参数。
模型本身不难,调试的过程中有几个需要注意的地方:
- 模型对权重参数初始值敏感,所以这里对权重参数w做了一些限制
- 需要对数据标准化
- 学习率设置合理(Adam,0.001)
1,建立模型
import numpy as np
import tensorflow as tf class AutoEncoder(object):
'''
使用对称结构,解码器重用编码器的权重参数
'''
def __init__(self, input_shape, h1_size, lr):
tf.reset_default_graph()# 重置默认计算图,有时出错后内存还一团糟
with tf.variable_scope('auto_encoder', reuse=tf.AUTO_REUSE):
self.W1 = self.weights(shape=(input_shape, h1_size), name='h1')
self.b1 = self.bias(h1_size)
self.W2 = tf.transpose(tf.get_variable('h1')) # 共享参数,使用其转置
self.b2 = self.bias(input_shape)
self.lr = lr
self.input = tf.placeholder(shape=(None, input_shape),
dtype=tf.float32)
self.h1_out = tf.nn.softplus(tf.matmul(self.input, self.W1) + self.b1)# softplus,类relu
self.out = tf.matmul(self.h1_out, self.W2) + self.b2
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
self.loss = 0.1 * tf.reduce_sum(
tf.pow(tf.subtract(self.input, self.out), 2))
self.train_op = self.optimizer.minimize(self.loss)
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer()) def fit(self, X, epoches=100, batch_size=128, epoches_to_display=10):
batchs_per_epoch = X.shape[0] // batch_size
for i in range(epoches):
epoch_loss = []
for j in range(batchs_per_epoch):
X_train = X[j * batch_size:(j + 1) * batch_size]
loss, _ = self.sess.run([self.loss, self.train_op],
feed_dict={self.input: X_train})
epoch_loss.append(loss)
if i % epoches_to_display == 0:
print('avg_loss at epoch %d :%f' % (i, np.mean(epoch_loss)))
# return self.sess.run(W1) # 权重初始化参考别人的,这个居然很重要!用自己设定的截断正态分布随机没有效果
def weights(self, shape, name, constant=1):
fan_in = shape[0]
fan_out = shape[1]
low = -constant * np.sqrt(6.0 / (fan_in + fan_out))
high = constant * np.sqrt(6.0 / (fan_in + fan_out))
init = tf.random_uniform_initializer(minval=low, maxval=high)
return tf.get_variable(name=name,
shape=shape,
initializer=init,
dtype=tf.float32) def bias(self, size):
return tf.Variable(tf.constant(0, dtype=tf.float32, shape=[size])) def encode(self, X):
return self.sess.run(self.h1_out, feed_dict={self.input: X}) def decode(self, h):
return self.sess.run(self.out, feed_dict={self.h1_out: h}) def reconstruct(self, X):
return self.sess.run(self.out, feed_dict={self.input: X})
2,加载数据及预处理
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data() import random
X_train = X_train.reshape(-1, 784)
# 测试集里随机10个图片用做测试
test_idxs = random.sample(range(X_test.shape[0]), 10)
data_test = X_test[test_idxs].reshape(-1, 784)
# 标准化
import sklearn.preprocessing as prep
processer = prep.StandardScaler().fit(X_train) # 这里还是用全部数据好,这个也很关键!
X_train = processer.transform(X_train)
X_test = processer.transform(data_test) # 随机5000张图片用做训练
idxs = random.sample(range(X_train.shape[0]), 5000)
data_train = X_train[idxs]
3,训练
model = AutoEncoder(784, 200, 0.001) # 学习率对loss影响也有点大
model.fit(data_train, batch_size=128, epoches=200) # 200轮即可
4,测试,可视化对比图
decoded_test = model.reconstruct(X_test) import matplotlib.pyplot as plt
%matplotlib inline
shape = (28, 28)
fig, axes = plt.subplots(2,10,
figsize=(10, 2),
subplot_kw={
'xticks': [],
'yticks': []
},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(10):
axes[0][i].imshow(np.reshape(X_test[i], shape))
axes[1][i].imshow(np.reshape(decoded_test[i], shape))
plt.show()
结果如下:

以上,可以在输入中添加点高斯噪音,增加鲁棒性。
TensorFlow自编码器(AutoEncoder)之MNIST实践的更多相关文章
- 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...
- 吴裕雄 PYTHON 神经网络——TENSORFLOW 双隐藏层自编码器设计处理MNIST手写数字数据集并使用TENSORBORD描绘神经网络数据2
import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data os.envi ...
- 吴裕雄 PYTHON 神经网络——TENSORFLOW 单隐藏层自编码器设计处理MNIST手写数字数据集并使用TensorBord描绘神经网络数据
import os import numpy as np import tensorflow as tf import matplotlib.pyplot as plt from tensorflow ...
- 学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践
分布式TensorFlow由高性能gRPC库底层技术支持.Martin Abadi.Ashish Agarwal.Paul Barham论文<TensorFlow:Large-Scale Mac ...
- 深度学习之自编码器AutoEncoder
原文地址:https://blog.csdn.net/marsjhao/article/details/73480859 一.什么是自编码器(Autoencoder) 自动编码器是一种数据的压缩算法, ...
- Tesorflow-自动编码器(AutoEncoder)
直接附上代码: import numpy as np import sklearn.preprocessing as prep import tensorflow as tf from tensorf ...
- tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)
mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...
- tensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试
刚开始学习tf时,我们从简单的地方开始.卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始. 神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输 ...
- Tensorflow学习笔记(对MNIST经典例程的)的代码注释与理解
1 #coding:utf-8 # 日期 2017年9月4日 环境 Python 3.5 TensorFlow 1.3 win10开发环境. import tensorflow as tf from ...
随机推荐
- Oracle之:Function :getdate()
create or replace function getdate(sp_date varchar) return date is Result date; begin if LENGTH(sp_d ...
- BZOJ 2809: [Apio2012]dispatching(可并堆 左偏树板题)
这道题只要读懂题目一切好说. 给出nnn个点的一棵树,每一个点有一个费用vvv和一个领导力aaa,给出费用上限mmm.求下面这个式子的最大值ax∗∣S∣ ( S⊂x的子树, ∑iv[i]≤m )\la ...
- Luogu P4331 [BOI2004]Sequence 数字序列 (左偏树论文题)
清晰明了%%% Fairycastle的博客 个人习惯把size什么的存在左偏树结点内,这样在外面好写,在里面就是模板(只用修改update). 可以对比一下代码(好像也差不多-) MY CODE # ...
- golang web实战之三(基于iris框架的 web小应用,数据库采用 sqlite3 )
一.效果:一个图片应用 1.可上传图片到uploads目录. 2.可浏览和评论图片(用富文本编辑器输入) 二.梳理一下相关知识: 1.iris框架(模板输出,session) 2.富文本编辑器.sql ...
- c++回溯法求组合问题(取数,选取问题)从n个元素中选出m个的回溯算法
假如现在有n个数,分别从里面选择m个出来,那么一共有多少种不同的组合呢,分别是哪些呢? 利用计算机的计算力,采用回溯算法很容易求解 程序源代码如下: #include<iostream># ...
- easyUI datagrid中checkbox选中事件以及行点击事件,翻页之后还可以选中
DataGrid其中与选择,勾选相关 DataGrid属性:singleSelect boolean 如果为true,则只允许选择一行. false ctrlSelect boolean 在启用多行选 ...
- ngx_http_auth_request自用
server { listen 80; server_name www.php12.cn php12.mama1314.com; root /var/www/shf; location / { ind ...
- 突破大文件上传 和内网ip的端口转发
php上传大于2M文件的解决方法 2016年12月11日 :: katelyn9 阅读数 php上传大于2M文件的解决方法 如上传一个文件大于2m往往是上传不成功的解决方法: php.ini里查找 查 ...
- [BZOJ4033]:[HAOI2015]树上染色(树上DP)
题目传送门 题目描述 有一棵点数为N的树,树边有边权.给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并将其他的N-K个点染成白色.将所有点染色后,你会获得黑点两两之间的距离加 ...
- Function和Object 应该知道的
javascript有5种基础的内建对象(Fundamental Objects),Object.Function.Error.Symbol.Boolean,而Object/Function尤为特殊, ...