使用Tensorflow搭建自编码器(Autoencoder)
自编码器是一种数据压缩算法,其中数据的压缩和解压缩函数是数据相关的、从样本中训练而来的。大部分自编码器中,压缩和解压缩的函数是通过神经网络实现的。
1. 使用卷积神经网络搭建自编码器
- 导入MNIST数据集(灰度图,像素范围0~1)
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', validation_size=0) - 搭建网络
inputs_ = tf.placeholder(tf.float32, (None, 28, 28, 1), name='inputs')
targets_ = tf.placeholder(tf.float32, (None, 28, 28, 1), name='targets')
### Encoder
conv1 = tf.layers.conv2d(inputs_, 16, (3,3), padding='same', activation=tf.nn.relu) # 28x28x16
maxpool1 = tf.layers.max_pooling2d(conv1, (2,2), (2,2), padding='same') # 14x14x16
conv2 = tf.layers.conv2d(maxpool1, 8, (3,3), padding='same', activation=tf.nn.relu) # 14x14x8
maxpool2 = tf.layers.max_pooling2d(conv2, (2,2), (2,2), padding='same') # 7x7x8
conv3 = tf.layers.conv2d(maxpool2, 8, (3,3), padding='same', activation=tf.nn.relu) # 7x7x8
encoded = tf.layers.max_pooling2d(conv3, (2,2), (2,2), padding='same') # 4x4x8
### Decoder
upsample1 = tf.image.resize_nearest_neighbor(encoded, (7,7)) # 7x7x8
conv4 = tf.layers.conv2d(upsample1, 8, (3,3), padding='same', activation=tf.nn.relu) # 7x7x8
upsample2 = tf.image.resize_nearest_neighbor(conv4, (14,14)) # 14x14x8
conv5 = tf.layers.conv2d(upsample2, 8, (3,3), padding='same', activation=tf.nn.relu) # 14x14x8
upsample3 = tf.image.resize_nearest_neighbor(conv5, (28,28)) # 28x28x8
conv6 = tf.layers.conv2d(upsample3, 16, (3,3), padding='same', activation=tf.nn.relu) # 28x28x16
logits = tf.layers.conv2d(conv6, 1, (3,3), padding='same', activation=None) # 28x28x1
decoded = tf.nn.sigmoid(logits, name='decoded') # 28x28x1
### Loss and Optimization:
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets_, logits=logits)
cost = tf.reduce_mean(loss)
opt = tf.train.AdamOptimizer(0.001).minimize(cost)模型在解码部分使用的是upsample+convolution而不是transposed convolution(参考文献)
- 训练网络
sess = tf.Session()
epochs = 20
batch_size = 200
sess.run(tf.global_variables_initializer())
for e in range(epochs):
for ii in range(mnist.train.num_examples//batch_size):
batch = mnist.train.next_batch(batch_size)
imgs = batch[0].reshape((-1, 28, 28, 1))
batch_cost, _ = sess.run([cost, opt], feed_dict={inputs_: imgs, targets_: imgs})
print("Epoch: {}/{}...".format(e+1, epochs), "Training loss: {:.4f}".format(batch_cost)) - 检验网络
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20,4))
in_imgs = mnist.test.images[:10]
reconstructed, compressed = sess.run([decoded, encoded], feed_dict={inputs_: in_imgs.reshape((10, 28, 28, 1))})
# plot
for images, row in zip([in_imgs, reconstructed], axes):
for img, ax in zip(images, row):
ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.tight_layout(pad=0.1)
sess.close()
2. 使用自编码器降噪
- 搭建网络(同上但feature map的个数由16-8-8-8-8-16变为32-32-16-16-32-32)
- 训练网络
sess = tf.Session()
epochs = 100
batch_size = 200
# Set's how much noise we're adding to the MNIST images
noise_factor = 0.5
sess.run(tf.global_variables_initializer())
for e in range(epochs):
for ii in range(mnist.train.num_examples//batch_size):
batch = mnist.train.next_batch(batch_size)
# Get images from the batch
imgs = batch[0].reshape((-1, 28, 28, 1))
# Add random noise to the input images
noisy_imgs = imgs + noise_factor * np.random.randn(*imgs.shape)
# Clip the images to be between 0 and 1
noisy_imgs = np.clip(noisy_imgs, 0., 1.)
# Noisy images as inputs, original images as targets
batch_cost, _ = sess.run([cost, opt], feed_dict={inputs_: noisy_imgs, targets_: imgs})
print("Epoch: {}/{}...".format(e+1, epochs), "Training loss: {:.4f}".format(batch_cost)) - 检验网络
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20,4))
in_imgs = mnist.test.images[:10]
noisy_imgs = in_imgs + noise_factor * np.random.randn(*in_imgs.shape)
noisy_imgs = np.clip(noisy_imgs, 0., 1.)
reconstructed = sess.run(decoded, feed_dict={inputs_: noisy_imgs.reshape((10, 28, 28, 1))})
for images, row in zip([noisy_imgs, reconstructed], axes):
for img, ax in zip(images, row):
ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.tight_layout(pad=0.1)
sess.close()
使用Tensorflow搭建自编码器(Autoencoder)的更多相关文章
- TensorFlow实现自编码器及多层感知机
1 自动编码机简介 传统机器学习任务在很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难 ...
- (转)一文学会用 Tensorflow 搭建神经网络
一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...
- [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)
3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...
- 用Tensorflow搭建神经网络的一般步骤
用Tensorflow搭建神经网络的一般步骤如下: ① 导入模块 ② 创建模型变量和占位符 ③ 建立模型 ④ 定义loss函数 ⑤ 定义优化器(optimizer), 使 loss 达到最小 ⑥ 引入 ...
- 一文学会用 Tensorflow 搭建神经网络
http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...
- 使用Tensorflow搭建回归预测模型之一:环境安装
方法1:快速包安装 一.安装Anaconda 1.官网地址:https://www.anaconda.com/distribution/,选择其中一个版本下载即可,最好安装3.7版本,因为2.7版本2 ...
- 使用Tensorflow搭建回归预测模型之二:数据准备与预处理
前言: 在前一篇中,已经搭建好了Tensorflow环境,本文将介绍如何准备数据与预处理数据. 正文: 在机器学习中,数据是非常关键的一个环节,在模型训练前对数据进行准备也预处理是非常必要的. 一.数 ...
- 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...
- 用TensorFlow搭建一个万能的神经网络框架(持续更新)
我一直觉得TensorFlow的深度神经网络代码非常困难且繁琐,对TensorFlow搭建模型也十分困惑,所以我近期阅读了大量的神经网络代码,终于找到了搭建神经网络的规律,各位要是觉得我的文章对你有帮 ...
随机推荐
- tomcat内容总结
tomcat的安装以及配置环境变量 1.tomcat的官网下载地址:http://tomcat.apache.org/ tomcat有很多版本,有解压版 和 安装版,还分windows (还分为32位 ...
- 事件循环 event loop 究竟是什么
事件循环 event loop 究竟是什么 一些概念 浏览器运行时是多进程,从任务管理器或者活动监视器上可以验证. 打开新标签页和增加一个插件都会增加一个进程,如下图:  浏览器渲染进程是多线程,包 ...
- Newbe.Claptrap 框架入门,第二步 —— 简单业务,清空购物车
接上一篇 Newbe.Claptrap 框架入门,第一步 —— 创建项目,实现简易购物车 ,我们继续要了解一下如何使用 Newbe.Claptrap 框架开发业务.通过本篇阅读,您便可以开始尝试使用 ...
- mysql 5.7.13 安装配置方法
linux环境Mysql 5.7.13安装教程分享给大家,供大家参考,具体内容如下: 1系统约定 安装文件下载目录:/data/software Mysql目录安装位置:/usr/local/mysq ...
- .Net Core缓存组件(MemoryCache)【缓存篇(二)】
一.前言 .Net Core缓存源码 1.上篇.NET Core ResponseCache[缓存篇(一)]中我们提到了使用客户端缓存.和服务端缓存.本文我们介绍MemoryCache缓存组件,说到服 ...
- PWN头秃之旅 - 4.Retrun-into-libc(攻防世界-level1)
Retrun-into-libc,也写作Retrun2libc.libc是Linux下的ANSI C的函数库,包含了C语言最基本的库函数. Retrun2libc的前提是NX开启,但ASLR关闭,NX ...
- Android应用内部实现多语言,一键切换语言,国际化适配
1.首先提供多语言对应的string值 如en对应英语, fr对应法语 两个文件中包含同样的key, 对应不同的语言的value 2.java代码相应用户切换语言动作 private static v ...
- 推荐IT经理/产品经理,常用工具和网站
一. 常用必备工具 1)文档工具 石墨文档,在线协作文档工具 https://shimo.im/ 2) 表格工具 麦客,在线问卷调查工具 http://www.mikecrm.com/ 3)脑图工具 ...
- RACTF-web C0llide?(js弱类型)
源码: const bodyParser = require("body-parser") const express = require("express") ...
- Nginx安全优化与性能调优
目录 Nginx基本安全优化 隐藏Nginx软件版本号信息 更改源码隐藏Nginx软件名及版本号 修改Nginx服务的默认用户 修改参数优化Nginx服务性能 优化Nginx服务的worker进程数 ...