直接附上代码:

 import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data def xavier_init(fan_in,fan_out,constant=1):
low=-constant*np.sqrt(6.0/(fan_in+fan_out))
high=constant*np.sqrt(6.0/(fan_in+fan_out))
return tf.random_uniform((fan_in,fan_out),minval=low,maxval=high,dtype=tf.float32) class AdditiveGaussianNoiseAutoencoder(object):
def __init__(self,n_input,n_hidden,transfer_function=tf.nn.softplus,optimizer=tf.train.AdamOptimizer(),scale=0.1):
self.n_input=n_input
self.n_hidden=n_hidden
self.transfer=transfer_function
self.scale=tf.placeholder(tf.float32)
self.training_scale=scale
network_weights=self._initialize_weights()
self.weights=network_weights self.x=tf.placeholder(tf.float32,[None,self.n_input])
self.hidden=self.transfer(tf.add(tf.matmul(self.x+scale*tf.random_normal((n_input,)),self.weights['w1']),self.weights['b1']))
self.reconstruction=tf.add(tf.matmul(self.hidden,self.weights['w2']),self.weights['b2'])
self.cost=0.5*tf.reduce_sum(tf.pow(tf.sub(self.reconstruction,self.x),2.0))
self.optimizer=optimizer.minimize(self.cost) init=tf.initialize_all_variables()
self.sess=tf.Session()
self.sess.run(init) def _initialize_weights(self):
all_weights=dict()
all_weights['w1']=tf.Variable(xavier_init(self.n_input,self.n_hidden))
all_weights['b1']=tf.Variable(tf.zeros([self.n_hidden],dtype=tf.float32))
all_weights['w2']=tf.Variable(tf.zeros([self.n_hidden,self.n_input],dtype=tf.float32))
all_weights['b2']=tf.Variable(tf.zeros([self.n_input],dtype=tf.float32)) return all_weights def partial_fit(self,X): cost,opt=self.sess.run((self.cost,self.optimizer),feed_dict={self.x:X,self.scale:self.training_scale}) return cost def calc_total_cost(self,X):
return self.sess.run(self.cost,feed_dict={self.x:X,self.scale:self.training_scale}) def transform(self,X):
return self.sess.run(self.hidden,feed_dict={self.x:X,self.scale:self.training_scale}) def generate(self,hidden=None):
if hidden is None:
hidden=np.random.normal(size=self.weights["b1"])
return self.sess.run(self.reconstruction,feed_dict={self.hidden:hidden}) def reconstruct(self,X):
return self.sess.run(self.reconstruction,feed_dict={self.x:X,self.scale:self.training_scale}) def getWeights(self):
return self.sess.run(self.weights['w1']) def getBiases(self):
return self.sess.run(self.weights['b1']) mnist=input_data.read_data_sets('MNIST_data',one_hot=True) def standard_scale(X_train,X_test):
preprocessor=prep.StandardScaler().fit(X_train)
X_train=preprocessor.transform(X_train)
X_test=preprocessor.transform(X_test)
return X_train,X_test def get_random_block_from_data(data,batch_size):
start_index=np.random.randint(0,len(data)-batch_size)
return data[start_index:(start_index+batch_size)] X_train,X_test=standard_scale(mnist.train.images,mnist.test.images)
n_samples=int(mnist.train.num_examples)
training_epochs=20
batch_size=128
diaplay_step=1
autoencoder=AdditiveGaussianNoiseAutoencoder(n_input=784,n_hidden=200,transfer_function=tf.nn.softplus,optimizer=tf.train.AdamOptimizer(learning_rate=0.001),scale=0.01)
for epoch in range(training_epochs):
avg_cost=0
total_batch=int(n_samples/batch_size)
for i in range(total_batch):
batch_xs=get_random_block_from_data(X_train,batch_size) cost=autoencoder.partial_fit(batch_xs)
avg_cost+=cost/n_samples*batch_size if epoch%diaplay_step==0:
print("Epoch:",'%04d'%(epoch+1),"cost=","{:.9f}".format(avg_cost)) print("Total cost: "+str(autoencoder.calc_total_cost(X_test)))

Tesorflow-自动编码器(AutoEncoder)的更多相关文章

  1. VAE--就是AutoEncoder的编码输出服从正态分布

    花式解释AutoEncoder与VAE 什么是自动编码器 自动编码器(AutoEncoder)最开始作为一种数据的压缩方法,其特点有: 1)跟数据相关程度很高,这意味着自动编码器只能压缩与训练数据相似 ...

  2. Machine Learning Algorithms Study Notes(6)—遗忘的数学知识

    机器学习中遗忘的数学知识 最大似然估计( Maximum likelihood ) 最大似然估计,也称为最大概似估计,是一种统计方法,它用来求一个样本集的相关概率密度函数的参数.这个方法最早是遗传学家 ...

  3. Machine Learning Algorithms Study Notes(4)—无监督学习(unsupervised learning)

    1    Unsupervised Learning 1.1    k-means clustering algorithm 1.1.1    算法思想 1.1.2    k-means的不足之处 1 ...

  4. Machine Learning Algorithms Study Notes(1)--Introduction

    Machine Learning Algorithms Study Notes 高雪松 @雪松Cedro Microsoft MVP 目 录 1    Introduction    1 1.1    ...

  5. 【Todo】【转载】深度学习&神经网络 科普及八卦 学习笔记 & GPU & SIMD

    上一篇文章提到了数据挖掘.机器学习.深度学习的区别:http://www.cnblogs.com/charlesblc/p/6159355.html 深度学习具体的内容可以看这里: 参考了这篇文章:h ...

  6. (zhuan) 深度学习全网最全学习资料汇总之模型介绍篇

    This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058& ...

  7. SIGAI深度学习第四集 深度学习简介

    讲授机器学习面临的挑战.人工特征的局限性.为什么选择神经网络.深度学习的诞生和发展.典型的网络结构.深度学习在机器视觉.语音识别.自然语言处理.推荐系统中的应用 大纲: 机器学习面临的挑战 特征工程的 ...

  8. AI面试必备/深度学习100问1-50题答案解析

    AI面试必备/深度学习100问1-50题答案解析 2018年09月04日 15:42:07 刀客123 阅读数 2020更多 分类专栏: 机器学习   转载:https://blog.csdn.net ...

  9. 栈式自动编码器(Stacked AutoEncoder)

    起源:自动编码器 单自动编码器,充其量也就是个强化补丁版PCA,只用一次好不过瘾. 于是Bengio等人在2007年的  Greedy Layer-Wise Training of Deep Netw ...

  10. 降噪自动编码器(Denoising Autoencoder)

    起源:PCA.特征提取.... 随着一些奇怪的高维数据出现,比如图像.语音,传统的统计学-机器学习方法遇到了前所未有的挑战. 数据维度过高,数据单调,噪声分布广,传统方法的“数值游戏”很难奏效.数据挖 ...

随机推荐

  1. 如何恢复VS2013代码实时校验功能

      VS2013在某一天突然无法进行实时代码校验了,只有在编译的时候,错误列表才显示语法错误 怎么来解决这个问题呢?试试环境重置吧. 首先:打开工具菜单,选择“导入和导出设置”. 其次:可以先导出选定 ...

  2. CSVHelper 导出CSV 格式

    public class CSVHelper { System.Windows.Forms.SaveFileDialog saveFileDialog1;//保存 private string hea ...

  3. Arduino I2C + 数字式环境光传感器BH1750FVI

    BH1750FVI是日本罗姆(ROHM)半导体生产的数字式环境光传感IC.其主要特性有: I2C数字接口,支持速率最大400Kbps 输出量为光照度(Illuminance) 测量范围1~65535 ...

  4. 【模板】二维凸包 / [USACO5.1]圈奶牛Fencing the Cows

    Problem surface 戳我 Meaning 坐标系内有若干个点,问把这些点都圈起来的最小凸包周长. 这道题就是一道凸包的模板题啊,只要求出凸包后在计算就好了,给出几个注意点 记得检查是否有吧 ...

  5. NSNotification 消息通知的3种方式

    1.Notification Center的概念: 它是一个单例对象,允许当事件发生时通知一些对象,让对象做出相应反应. 它允许我们在低程度耦合的情况下,满足控制器与一个任意的对象进行通信的目的. 这 ...

  6. nginx负载均衡tomcat和配置ssl

    目录 tomcat 组件功能 engine host context connector service server valve logger realm UserDatabaseRealm 工作流 ...

  7. Resurrectio-capserjs的自动化脚本录制工具

    [根据github上的文档说明整理] Phantom下的任何操作都可以录制 Resurrectio是一个Chrome插件,他可以记录浏览器的操作,并转化成对应的casperjs脚本 Resurrect ...

  8. 决策树--Python

    决策树 实验集数据: #coding:utf8 #关键词:决策树(desision tree).特征选择.信息增益(information gain).香农熵.熵(entropy).经验熵(H(D)) ...

  9. win10+anaconda环境下pyqt5+qt tools+eric6.18安装及汉化过程

    最近需要用python编写一个小程序的界面,选择了pyqt5+eric6的配套组合,安装过程中遇到一些坑,特此记录.参考书籍是电子工业出版社的<PyQt5快速开发与实战>. 因为我使用an ...

  10. wpa_supplicant

    一 函数接口介绍 wpa_ctrl_open接口用来打开wpa_supplicant的控制接口,在UNIX系统里使用UNIX domain sockets,而在Windows里则是使用UDP sock ...