VAEs最早由“Diederik P. Kingma and Max Welling, “Auto-Encoding Variational Bayes, arXiv (2013)”和“Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra, “Stochastic Backpropagation and Approximate Inference in Deep Generative Models,” arXiv (2014)”同时发现。

原理:

对自编码器来说,它只是将输入数据投影到隐空间中,这些数据在隐空间中的位置是离散的,因此在此空间中进行采样,解码后的输出很可能是毫无意义的。

而对VAEs来说,它将输入数据转换成2个分布,一个是平均值的分布,一个是方差的分布(这就像高斯混合型了),添加上一些噪音,组合后,再进行解码。

如图(网上找的,应该是论文里的,暂时没看论文)

为什么分为2个分布?

可以这么理解:假设均值和方差都有n个,那么编码部分相当于用n个高斯分布(每个输入是不同权重的n个分布的组合)去模拟输入。

再通过一系列变换,转化为隐空间的若干维度,其每个维度可能具有某种意义。比如下面代码使用2维隐空间,可以看作是均值和方差维度。

方差部分指数化,保证非负。添加噪音让隐空间更具有意义的连续性。

然后我们从隐空间采样,由于隐空间具有意义上的连续性,那么解码后的东东就可能类似输入。

损失loss如何定义?为什么?

loss由2部分构成,第一部分就是解码输出与原始输入的loss,可以定义为交叉熵或者均方误差等。

第二部分是约束项。如上图黄色框,m平方作为L2正则化项,前2项可以看做方差减去其泰勒展开,当σ趋近0时,方差也即e^σ为1。那么最小化前2项必然使得σ趋近0(求导即可知)。

由此,这第二部分,m平方项约束使得均值为0,前2项约束使得方差为1。这样约束使得隐空间具有连续性,且强制输入数据在隐空间中的表示范围收拢。

这样在隐空间中2个数据表示的中间,就有一种过渡区域。如果仅以第一部分约束,效果可能就和自编码器一样了,模型会过拟合。


下面进入代码部分

以MNIST数据集作为训练样本。

from keras import backend as K

from keras.models import Model

from keras.metrics import binary_crossentropy

import numpy as np

from keras.layers import Conv2D,Flatten,Dense,Input,Lambda,Reshape,Conv2DTranspose,Layer

from keras.datasets import mnist

from keras.callbacks import EarlyStopping

编码器使用卷积层,输出2个部分

img_shape=(28,28,1)
batch_size=16
latent_dim=2 input_img=Input(shape=img_shape)
x=Conv2D(32,3,padding='same',activation='relu')(input_img)# 28,28,32
x=Conv2D(64,3,padding='same',activation='relu',strides=(2,2))(x)# 14,14,64
x=Conv2D(64,3,padding='same',activation='relu')(x)#14,14,64
x=Conv2D(64,3,padding='same',activation='relu')(x)#14,14,64
# 保存Flatten之前的shape
shape_before_flattening=K.int_shape(x)
x=Flatten()(x)#14*14*64
x=Dense(32,activation='relu')(x)#
# 将输入图像拆分为2个向量
z_mean=Dense(latent_dim)(x)#
z_log_var=Dense(latent_dim)(x)

定义采样方法

def sampling(args):
z_mean,z_log_var=args
# 得到一个平均值为0,方差为1的正态分布,shape为(?,2)
epsilon=K.random_normal(shape=(K.shape(z_mean)[0],latent_dim),mean=0,stddev=1.)#K.shape返回仍是tensor
# tensor*tensor为elementwise操作
return z_mean+K.exp(z_log_var)*epsilon
z=Lambda(sampling)([z_mean,z_log_var])# 采样

解码

# 解码过程,逆操作
decode_input=Input(K.int_shape(z)[1:])
# np.prod表示对数组某个axis进行乘法操作,如果axis不指定,则将所有的元素乘积返回一个值
x=Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decode_input)#14*14*64
# 逆Flatten操作
x=Reshape(shape_before_flattening[1:])(x)#14,14,64
# 反卷积,strides=2将14*14变为28*28,跟Conv2D相反
x=Conv2DTranspose(32,3,padding='same',activation='relu',strides=2)(x)#28,28,32
# 注意这里的激活函数
x=Conv2D(1,3,padding='same',activation='sigmoid')(x)#28,28,1
# 解码model
decoder=Model(decode_input,x)
# 解码后的图片数据
z_decoded=decoder(z)

定义loss,使用一个自定义layer实现

class CustomVariationalLayer(Layer):
def vae_loss(self,x,z_decoded):
x=K.flatten(x)
z_decoded=K.flatten(z_decoded)
# loss为原始输入和编码-解码后的输出比较
xent_loss=binary_crossentropy(x,z_decoded)
# 约束
# mean部分表示L2正则损失,K.exp(z_log_var)-(1+z_log_var)保证方差为1,如果不约束,网络可能偷懒
kl_loss=5e-4*K.mean(K.exp(z_log_var)-(1+z_log_var)+K.square(z_mean),axis=-1)
return K.mean(xent_loss+kl_loss) def call(self,inputs):
x=inputs[0]
z_decoded=inputs[1]
loss=self.vae_loss(x,z_decoded)
# 继承方法
self.add_loss(loss,inputs=inputs)#将根据inputs计算的损失loss加到本layer
return x #不用,但是需要返回点啥 y=CustomVariationalLayer()([input_img,z_decoded])

加载数据,定义、训练模型

(x_train,y_train),(x_test,y_test)=mnist.load_data()

x_train=x_train.astype('float32')/255.
# 表示添加一个通道维度,通道数为1(颜色只有一种模式)
x_train=x_train.reshape(x_train.shape+(1,))
x_test=x_test.astype('float32')/255.
x_test=x_test.reshape(x_test.shape+(1,))
vae=Model(input_img,y)
# 自定义层y里面已经包含了loss,这里不需要指定
vae.compile(optimizer='rmsprop',loss=None)
# 不需要标签,所以y为None,我们只需要知道一个图片的原始输入是否和编码-解码后的输出一致
vae.fit(x=x_train,y=None,shuffle=True,epochs=10,batch_size=batch_size,validation_data=(x_test,None),callbacks=[EarlyStopping(patience=2)],verbose=2)

测试

import matplotlib.pyplot as plt
from scipy.stats import norm # 潜空间中任意矢量可以解码成数字
n = 10
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# norm.ppf([v1,v2...])表示正态分布积分值为vi时,对应的x轴坐标值xi
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))#可以看作均值
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))#方差
for i, yi in enumerate(grid_x):
for j, xi in enumerate(grid_y):
z_sample = np.array([[xi, yi]])
# np.tile将数组重复n次,如[1,2]->[1,2,1,2]。然后reshape到输入格式
z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
x_decoded = decoder.predict(z_sample, batch_size=batch_size)
# 因为x_decoded为16个相同矢量得到的推导,取第一个就行,再将 28*28*1 reshape到 28*28
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i + 1) * digit_size,
j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

结果如下,可以看到,图片是连续变化的。

VAEs(变分自编码)之keras实践的更多相关文章

  1. Keras实践:模型可视化

    Keras实践:模型可视化 安装Graphviz 官方网址为:http://www.graphviz.org/.我使用的是mac系统,所以我分享一下我使用时遇到的坑. Mac安装时在终端中执行: br ...

  2. Keras实践:实现非线性回归

    Keras实践:实现非线性回归 代码 import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" import ke ...

  3. GAN(生成对抗网络)之keras实践

    GAN由论文<Ian Goodfellow et al., “Generative Adversarial Networks,” arXiv (2014)>提出. GAN与VAEs的区别 ...

  4. 分享几个 PHP 编码的最佳实践

    对于初学者而言,可能很难理解为什么某些做法更安全. 但是,以下一些技巧可能超出了 PHP 的范围. 始终使用大括号 让我们看下面的代码: if (isset($condition) && ...

  5. 2.keras实现-->字符级或单词级的one-hot编码 VS 词嵌入

    1. one-hot编码 # 字符集的one-hot编码 import string samples = ['zzh is a pig','he loves himself very much','p ...

  6. ​结合异步模型,再次总结Netty多线程编码最佳实践

    更多技术分享可关注我 前言 本文重点总结Netty多线程的一些编码最佳实践和注意事项,并且顺便对Netty的线程调度模型,和异步模型做了一个汇总.原文:​​结合异步模型,再次总结Netty多线程编码最 ...

  7. 文本离散表示(二):新闻语料的one-hot编码

    上一篇博客介绍了文本离散表示的one-hot.TF-IDF和n-gram方法,在这篇文章里,我做了一个对新闻文本进行one-hot编码的小实践. 文本的one-hot相对而言比较简单,我用了两种方法, ...

  8. 通过keras例子理解LSTM 循环神经网络(RNN)

    博文的翻译和实践: Understanding Stateful LSTM Recurrent Neural Networks in Python with Keras 正文 一个强大而流行的循环神经 ...

  9. 算术编码Arithmetic Coding-高质量代码实现详解

    关于算术编码的具体讲解我不多细说,本文按照下述三个部分构成. 两个例子分别说明怎么用算数编码进行编码以及解码(来源:ARITHMETIC CODING FOR DATA COIUPRESSION): ...

随机推荐

  1. Java笔记(第五篇)

    抛出异常 使用throws声明抛出异常 Throws 通常用于方法声明,当方法中可能存在异常,却不想在方法中对异常进行处理时,就可以在声明方法时使用throws声明抛出的异常,然后再调用该方法的其他方 ...

  2. 包 ,模块(time、datetime、random、hashlib、typing、requests、re)

    目录 1. 包 1. 优先掌握 2. 了解 3. datetime模块 1. 优先掌握 4. random模块 1. 优先掌握 2. 了解 5. hashlib模块和hmac模块 6. typing模 ...

  3. Lighting Techinology of the Last Of Us (2013 SIGGRAPH)

    Lighting Techinology of the Last Of Us(2013 SIGGRAPH) or "Old Lightmaps - New Tricks" 原作:M ...

  4. Python程序设计《集美大学各省成绩分析》

    分析文件‘集美大学各省录取分数.xlsx’,完成以下功能: 1)集美大学2015-2018年间不同省份在本一批的平均分数,柱状图展示排名前10的省份, 2)分析福建省这3年各批次成绩情况,使用折线图展 ...

  5. JSONOjbect,对各种属性的处理

    import com.alibaba.fastjson.JSONObject; public class JsonTest { public static void main(String[] arg ...

  6. Easily use UUIDs in Laravel

    Easily use UUIDs in Laravel  Wilbur PoweryOct 29 '18 Updated on Oct 30, 2018 ・1 min read #php #larav ...

  7. 【csp模拟赛1】T1 心有灵犀

    [题目描述] 爱玩游戏的小 Z 最近又换了一个新的游戏.这个游戏有点特别,需要两位玩 家心有灵犀通力合作才能拿到高分. 游戏开始时,两位玩家会得到同一个数字 N,假设这个数字共有 t 位数码, 然后两 ...

  8. 【线性代数】4-1:四个正交子空间(Orthogonality of the Four Subspace)

    title: [线性代数]4-1:四个正交子空间(Orthogonality of the Four Subspace) categories: Mathematic Linear Algebra k ...

  9. 如何使用PHP排序key为字母+数字的数组

    你还在为如何使用PHP排序字母+数字的数组而烦恼吗? 今天有个小伙伴在群里问: 如何将一个key为字母+数字的数组按升序排序呢? 举个例子: $test = [ 'n1' => 22423, ' ...

  10. CDialog::DoModal()问题和_WIN32_WINNT

    1.从CDialogEx派生自己的CMyDialog,到DoModal()时总提示 error C2039: "DoModal": 不是"CMyDialog"的 ...