Outline

  • Auto-Encoder

  • Variational Auto-Encoders

Auto-Encoder

创建编解码器

import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.') def save_images(imgs, name):
new_im = Image.new('L', (280, 280)) index = 0
for i in range(0, 280, 28):
for j in range(0, 280, 28):
im = imgs[index]
im = Image.fromarray(im, mode='L')
new_im.paste(im, (i, j))
index += 1 new_im.save(name) h_dim = 20 # 784降维20维
batchsz = 512
lr = 1e-3 (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(
np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz) print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape) class AE(keras.Model):
def __init__(self):
super(AE, self).__init__() # Encoders
self.encoder = Sequential([
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(h_dim)
]) # Decoders
self.decoder = Sequential([
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(784)
]) def call(self, inputs, training=None):
# [b,784] ==> [b,19]
h = self.encoder(inputs) # [b,10] ==> [b,784]
x_hat = self.decoder(h) return x_hat model = AE()
model.build(input_shape=(None, 784)) # tensorflow尽量用元组
model.summary()
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________

训练

optimizer = tf.optimizers.Adam(lr=lr)

for epoch in range(10):

    for step, x in enumerate(train_db):

        # [b,28,28]==>[b,784]
x = tf.reshape(x, [-1, 784]) with tf.GradientTape() as tape:
x_rec_logits = model(x) rec_loss = tf.losses.binary_crossentropy(x,
x_rec_logits,
from_logits=True)
rec_loss = tf.reduce_min(rec_loss) grads = tape.gradient(rec_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables)) if step % 100 == 0:
print(epoch, step, float(rec_loss)) # evaluation x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
x_hat = tf.sigmoid(logits)
# [b,784]==>[b,28,28]
x_hat = tf.reshape(x_hat, [-1, 28, 28]) # [b,28,28] ==> [2b,28,28]
x_concat = tf.concat([x, x_hat], axis=0)
# x_concat = x # 原始图片
x_concat = x_hat
x_concat = x_concat.numpy() * 255.
x_concat = x_concat.astype(np.uint8) # 保存为整型
if not os.path.exists('ae_images'):
os.mkdir('ae_images')
save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703

Auto-Encoders实战的更多相关文章

  1. [Python] 机器学习库资料汇总

    声明:以下内容转载自平行宇宙. Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: ...

  2. python数据挖掘领域工具包

    原文:http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块:Numpy和Sc ...

  3. Theano3.1-练习之初步介绍

    来自 http://deeplearning.net/tutorial/,虽然比较老了,不过觉得想系统的学习theano,所以需要从python--numpy--theano的顺序学习.这里的资料都很 ...

  4. [resource]Python机器学习库

    reference: http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块: ...

  5. 机器学习——深度学习(Deep Learning)

    Deep Learning是机器学习中一个非常接近AI的领域,其动机在于建立.模拟人脑进行分析学习的神经网络,近期研究了机器学习中一些深度学习的相关知识,本文给出一些非常实用的资料和心得. Key W ...

  6. Deep Learning Tutorial - Classifying MNIST digits using Logistic Regression

    Deep Learning Tutorial 由 Montreal大学的LISA实验室所作,基于Theano的深度学习材料.Theano是一个python库,使得写深度模型更容易些,也可以在GPU上训 ...

  7. [转]Python机器学习工具箱

    原文在这里  Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: 比较成熟的(广播 ...

  8. 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks

    In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...

  9. Deep Learning(4)

    四.拓展学习推荐 Deep Learning 经典阅读材料: The monograph or review paper Learning Deep Architectures for AI (Fou ...

  10. 深度学习教程Deep Learning Tutorials

    Deep Learning Tutorials Deep Learning is a new area of Machine Learning research, which has been int ...

随机推荐

  1. iOS端样式错位

    在iOS端上点击的时候触发点会在当前元素上方,原因是在最外层使用了fixed定位,换成绝对或相对定位解决问题

  2. ssh 公钥登录远程主机 并禁止密码登录

    https://www.digitalocean.com/community/tutorials/how-to-set-up-ssh-keys-on-centos7 如果在新的机器上,得先用密码登录一 ...

  3. Android Dialogs(2)最好用DialogFragment创建Dialog

    Creating a Dialog Fragment You can accomplish a wide variety of dialog designs—including custom layo ...

  4. 转-关于UIView的autoresizingMask属性的研究

    在 UIView 中有一个autoresizingMask的属性,它对应的是一个枚举的值(如下),属性的意思就是自动调整子控件与父控件中间的位置,宽高. 1 2 3 4 5 6 7 8 9 enum  ...

  5. Kali linux 2016.2(Rolling)里的枚举服务

    前言 枚举是一类程序,它允许用户从一个网络中收集某一类的所有相关服务.

  6. 移动web开发填坑(一)

    上周开始接触移动web开发,默默的掉进了很多坑里面.本文主要总结本周遇到的坑以及如何填坑. 1.px与rem换算. 设计稿的宽度一般是640px,而iphone是320px,所以测量设计稿的结果首先要 ...

  7. [ USACO 2013 OPEN ] Photo

    \(\\\) Description 有一个长度为 \(n\) 的奶牛队列,奶牛颜色为黑或白. 现给出 \(m\) 个区间 \([L_i,R_i]\) ,要求:每个区间里 有且只有一只黑牛 . 问满足 ...

  8. hihocoder offer收割编程练习赛11 B 物品价值

    思路: 状态压缩 + dp. 实现: #include <iostream> #include <cstdio> #include <cstring> #inclu ...

  9. 微信小程序button授权页面,用户拒绝后仍可再次授权

    微信小程序授权页面,进入小程序如果没授权跳转到授权页面,授权后跳转到首页,如果用户点拒绝下次进入小程序还是能跳转到授权页面,授权页面如下 app.js  中的 onLaunch或onShow中加如下代 ...

  10. NavigationView的使用

    代码已经分享至github:https://github.com/YanYoJun/NavigationDemo 转载请注明原文链接:http://www.cnblogs.com/yanyojun/p ...