Tesorflow-自动编码器(AutoEncoder)
直接附上代码:
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data def xavier_init(fan_in,fan_out,constant=1):
low=-constant*np.sqrt(6.0/(fan_in+fan_out))
high=constant*np.sqrt(6.0/(fan_in+fan_out))
return tf.random_uniform((fan_in,fan_out),minval=low,maxval=high,dtype=tf.float32) class AdditiveGaussianNoiseAutoencoder(object):
def __init__(self,n_input,n_hidden,transfer_function=tf.nn.softplus,optimizer=tf.train.AdamOptimizer(),scale=0.1):
self.n_input=n_input
self.n_hidden=n_hidden
self.transfer=transfer_function
self.scale=tf.placeholder(tf.float32)
self.training_scale=scale
network_weights=self._initialize_weights()
self.weights=network_weights self.x=tf.placeholder(tf.float32,[None,self.n_input])
self.hidden=self.transfer(tf.add(tf.matmul(self.x+scale*tf.random_normal((n_input,)),self.weights['w1']),self.weights['b1']))
self.reconstruction=tf.add(tf.matmul(self.hidden,self.weights['w2']),self.weights['b2'])
self.cost=0.5*tf.reduce_sum(tf.pow(tf.sub(self.reconstruction,self.x),2.0))
self.optimizer=optimizer.minimize(self.cost) init=tf.initialize_all_variables()
self.sess=tf.Session()
self.sess.run(init) def _initialize_weights(self):
all_weights=dict()
all_weights['w1']=tf.Variable(xavier_init(self.n_input,self.n_hidden))
all_weights['b1']=tf.Variable(tf.zeros([self.n_hidden],dtype=tf.float32))
all_weights['w2']=tf.Variable(tf.zeros([self.n_hidden,self.n_input],dtype=tf.float32))
all_weights['b2']=tf.Variable(tf.zeros([self.n_input],dtype=tf.float32)) return all_weights def partial_fit(self,X): cost,opt=self.sess.run((self.cost,self.optimizer),feed_dict={self.x:X,self.scale:self.training_scale}) return cost def calc_total_cost(self,X):
return self.sess.run(self.cost,feed_dict={self.x:X,self.scale:self.training_scale}) def transform(self,X):
return self.sess.run(self.hidden,feed_dict={self.x:X,self.scale:self.training_scale}) def generate(self,hidden=None):
if hidden is None:
hidden=np.random.normal(size=self.weights["b1"])
return self.sess.run(self.reconstruction,feed_dict={self.hidden:hidden}) def reconstruct(self,X):
return self.sess.run(self.reconstruction,feed_dict={self.x:X,self.scale:self.training_scale}) def getWeights(self):
return self.sess.run(self.weights['w1']) def getBiases(self):
return self.sess.run(self.weights['b1']) mnist=input_data.read_data_sets('MNIST_data',one_hot=True) def standard_scale(X_train,X_test):
preprocessor=prep.StandardScaler().fit(X_train)
X_train=preprocessor.transform(X_train)
X_test=preprocessor.transform(X_test)
return X_train,X_test def get_random_block_from_data(data,batch_size):
start_index=np.random.randint(0,len(data)-batch_size)
return data[start_index:(start_index+batch_size)] X_train,X_test=standard_scale(mnist.train.images,mnist.test.images)
n_samples=int(mnist.train.num_examples)
training_epochs=20
batch_size=128
diaplay_step=1
autoencoder=AdditiveGaussianNoiseAutoencoder(n_input=784,n_hidden=200,transfer_function=tf.nn.softplus,optimizer=tf.train.AdamOptimizer(learning_rate=0.001),scale=0.01)
for epoch in range(training_epochs):
avg_cost=0
total_batch=int(n_samples/batch_size)
for i in range(total_batch):
batch_xs=get_random_block_from_data(X_train,batch_size) cost=autoencoder.partial_fit(batch_xs)
avg_cost+=cost/n_samples*batch_size if epoch%diaplay_step==0:
print("Epoch:",'%04d'%(epoch+1),"cost=","{:.9f}".format(avg_cost)) print("Total cost: "+str(autoencoder.calc_total_cost(X_test)))
Tesorflow-自动编码器(AutoEncoder)的更多相关文章
- VAE--就是AutoEncoder的编码输出服从正态分布
花式解释AutoEncoder与VAE 什么是自动编码器 自动编码器(AutoEncoder)最开始作为一种数据的压缩方法,其特点有: 1)跟数据相关程度很高,这意味着自动编码器只能压缩与训练数据相似 ...
- Machine Learning Algorithms Study Notes(6)—遗忘的数学知识
机器学习中遗忘的数学知识 最大似然估计( Maximum likelihood ) 最大似然估计,也称为最大概似估计,是一种统计方法,它用来求一个样本集的相关概率密度函数的参数.这个方法最早是遗传学家 ...
- Machine Learning Algorithms Study Notes(4)—无监督学习(unsupervised learning)
1 Unsupervised Learning 1.1 k-means clustering algorithm 1.1.1 算法思想 1.1.2 k-means的不足之处 1 ...
- Machine Learning Algorithms Study Notes(1)--Introduction
Machine Learning Algorithms Study Notes 高雪松 @雪松Cedro Microsoft MVP 目 录 1 Introduction 1 1.1 ...
- 【Todo】【转载】深度学习&神经网络 科普及八卦 学习笔记 & GPU & SIMD
上一篇文章提到了数据挖掘.机器学习.深度学习的区别:http://www.cnblogs.com/charlesblc/p/6159355.html 深度学习具体的内容可以看这里: 参考了这篇文章:h ...
- (zhuan) 深度学习全网最全学习资料汇总之模型介绍篇
This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058& ...
- SIGAI深度学习第四集 深度学习简介
讲授机器学习面临的挑战.人工特征的局限性.为什么选择神经网络.深度学习的诞生和发展.典型的网络结构.深度学习在机器视觉.语音识别.自然语言处理.推荐系统中的应用 大纲: 机器学习面临的挑战 特征工程的 ...
- AI面试必备/深度学习100问1-50题答案解析
AI面试必备/深度学习100问1-50题答案解析 2018年09月04日 15:42:07 刀客123 阅读数 2020更多 分类专栏: 机器学习 转载:https://blog.csdn.net ...
- 栈式自动编码器(Stacked AutoEncoder)
起源:自动编码器 单自动编码器,充其量也就是个强化补丁版PCA,只用一次好不过瘾. 于是Bengio等人在2007年的 Greedy Layer-Wise Training of Deep Netw ...
- 降噪自动编码器(Denoising Autoencoder)
起源:PCA.特征提取.... 随着一些奇怪的高维数据出现,比如图像.语音,传统的统计学-机器学习方法遇到了前所未有的挑战. 数据维度过高,数据单调,噪声分布广,传统方法的“数值游戏”很难奏效.数据挖 ...
随机推荐
- 2.Books
Books示例说明了Qt中SQL类如何被Model/View框架使用,使用数据库中存储的信息,创建丰富的用户界面. 首先介绍使用SQL我们需要了解的类: 1.QSqlDatabase: QSqlDat ...
- orcad找不到dll
如果运行Capture.exe找不到cdn_sfl401as.dll,如果运行allegro.exe找不到cnlib.dll,(上面俩个库文件都在C:/Cadence/SPB_16.3/tools/b ...
- [转]Replace all UUIDs in an ATL COM DLL.
1. Introduction. 1.1 Recently, a friend asked me for advise on a very unusual requirement. 1.2 He ne ...
- Windows中与系统关联自己开发的程序(默认打开方式、图标、右击菜单等)
1. 默认打开方式 1.1. 代码支持 在Windows下,某个特定后缀名类型的文件,如果要双击时默认用某个程序(比如自己开发的WinForm程序)打开,代码中首先肯定要支持直接根据这个文件进行下一步 ...
- 201621123012 《Java程序设计》第7周学习总结
1. 本周学习总结 1.1 思维导图:Java图形界面总结 答: 1.2 可选:使用常规方法总结其他上课内容. 2.书面作业 1. GUI中的事件处理 1.1 写出事件处理模型中最重要的几个关键词. ...
- 分享一个利用HTML5制作的海浪效果代码
在前面简单讲述了一下HTML里的Canvas,这次根据Canvas完成了“海浪效果”(水波上升). (O(∩_∩)O哈哈哈~作者我能看这个动画看一下午) 上升水波.gif 动画分析构成:贝塞尔曲线画布 ...
- hbase安装 配置报错 zookeeper启动报错
zookeeper安装问题,使用独立安装的zookeeper export HBASE_MANAGES_ZK=false #如果使用独立安装的zookeeper这个地方就是false 创建zook ...
- Visual odometry and zed's IMU fusion on RTAB-Map
"When using /camera/odom, you don't need to use visual_odometry node. rtabmap should be subscri ...
- Linux环境下mysql安装并配置远程访问
环境:centOS 1.下载mysql安装文件 [root@localhost ~]# wget http://dev.mysql.com/get/mysql-community-release-el ...
- 【转】C#中静态方法和非静态方法的区别
源地址:https://www.cnblogs.com/amoshu/p/7477757.html 备注:静态方法不需要类的实例化就能调用,因为它是一直保存在内存中,不像非静态方法一样要放在实例化类时 ...