VAE变分自编码器Keras实现
变分自编码器(variational autoencoder, VAE)是一种生成模型,训练模型分为编码器和解码器两部分。
编码器将输入样本映射为某个低维分布,这个低维分布通常是不同维度之间相互独立的多元高斯分布,因此编码器的输出为这个高斯分布的均值与对数方差(因为方差总是大于0,为了将它映射到$(-\infty,\infty)$,所以加了对数)。在编码器的分布中抽样后,解码器做的事是将从这个低维抽样重新解码,生成与输入样本相似的数据。数据可以是图像、文字、音频等。
VAE模型的结构不难理解,关键在于它的损失函数的定义。我们要让解码器的输出与编码器的输入尽量相似,这个损失可以由这二者之间的二元交叉熵(binary crossentropy)来定义。但是仅由这个作为最终的目标函数是不够的。在这样的目标函数下,不断的梯度下降,会使编码器在不同输入下的输出均值之间的差别越来越大,输出方差则会不断地趋向于0,也就是对数方差趋向于负无穷。因为只有这样才会使从生成分布获取的抽样更加明确,从而让解码器能生成与输入数据更接近的数据,以使损失变得更小。但是这就与生成器的初衷有悖了,生成器的初衷实际上是为了生成更多“全新”的数据,而不是为了生成与输入数据“更像”的数据。所以,我们还要再给目标函数加上编码器生成分布的“正则化损失”:生成分布与标准正态分布之间的KL散度(相对熵)。让生成分布不至于“太极端、太确定”,从而让不同输入数据的生成分布之间有交叉 。于是解码器通过这些交叉的“缓冲带”上的抽样,能够生成“中间数据”,产生意想不到的效果。
详细的分析请看:变分自编码器VAE:原来是这么一回事 - 知乎
以下使用Keras实现VAE生成图像,数据集是MNIST。
代码实现
编码器
编码器将MNIST的数字图像转换为2维的正态分布均值与对数方差。简单堆叠卷积层与全连接层即可,代码如下:
#%%编码器
import numpy as np
import keras
from keras import layers,Model,models,utils
from keras import backend as K
from keras.datasets import mnist img_shape = (28,28,1)
latent_dim = 2 input_img = layers.Input(shape=img_shape)
x = layers.Conv2D(32,3,padding='same',activation='relu')(input_img)
x = layers.Conv2D(64,3,padding='same',activation='relu',strides=2)(x)
x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
inter_shape = K.int_shape(x)
x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x) encode_mean = layers.Dense(2,name = 'encode_mean')(x) #分布均值
encode_log_var = layers.Dense(2,name = 'encode_logvar')(x) #分布对数方差 encoder = Model(input_img,[encode_mean,encode_log_var],name = 'encoder')
解码器
解码器接受2维向量,将这个向量“解码”为图像。同样也是简单的堆叠卷积层、逆卷积层与全连接层即可,代码如下:
#%%解码器
input_code = layers.Input(shape=[2])
x = layers.Dense(np.prod(inter_shape[1:]),activation='relu')(input_code)
x = layers.Reshape(target_shape=inter_shape[1:])(x)
x = layers.Conv2DTranspose(32,3,padding='same',activation='relu',strides=2)(x)
x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x) decoder = Model(input_code,x,name = 'decoder')
整体待训练模型
整个待训练模型包括编码器、抽样层、解码器。中间的抽样操作在获取编码器传出的均值与方差后,通过一个自定义的lambda层来实现。这个抽样是先从标准正态分布中抽样,再通过乘生成分布的标准差,加上均值来获得。因此这个操作并不会把反向传播中断,可以将编码器与解码器的张量流连接起来。
定义好模型后是损失的定义,如前面所说,最终损失(目标函数)是生成图像与原图像之间的二元交叉熵和生成分布的正则化的平均值。使用add_loss方法来添加模型的损失,具体的自定义损失方法看链接。
代码如下:
#%%整体待训练模型
def sampling(arg):
mean = arg[0]
logvar = arg[1]
epsilon = K.random_normal(shape=K.shape(mean),mean=0.,stddev=1.) #从标准正态分布中抽样
return mean + K.exp(0.5*logvar) * epsilon #获取生成分布的抽样
input_img = layers.Input(shape=img_shape,name = 'img_input')
code_mean, code_log_var = encoder(input_img) #获取生成分布的均值与方差
x = layers.Lambda(sampling,name = 'sampling')([code_mean, code_log_var])
x = decoder(x)
training_model = Model(input_img,x,name = 'training_model') decode_loss = keras.metrics.binary_crossentropy(K.flatten(input_img), K.flatten(x))
kl_loss = -5e-4*K.mean(1+code_log_var-K.square(code_mean)-K.exp(code_log_var))
training_model.add_loss(K.mean(decode_loss+kl_loss)) #新出的方法,方便得很
training_model.compile(optimizer='rmsprop')
训练
因为损失函数并没有定义真实数据与预测数据直接的损失,因此fit方法只需传入输入即可(不用输出)。代码如下:
#%%读取数据集训练
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train = x_train.astype('float32')/255
x_train = x_train[:,:,:,np.newaxis] training_model.fit(
x_train,
batch_size=512,
epochs=100,
validation_data=(x_train[:2],None))
生成测试
使用scipy.stats中的norm.ppf方法在概率区间(0.01,0.99)内生成20*20个解码器输入,这个方法类似在标准正态分布中抽样,但并不是随机的,是正态分布下的等概率。生成的二维点分布如下图:
这样抽样而不均匀抽样为了和编码器的生成分布契合,因为编码器正则化后生成的分布是靠近标准正态分布的。然后用解码器生成图片,这一部分的代码如下:
#%%测试
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt
n = 20
x = y = norm.ppf(np.linspace(0.01,0.99,n)) #生成标准正态分布数
X,Y = np.meshgrid(x,y) #形成网格
X = X.reshape([-1,1]) #数组展平
Y = Y.reshape([-1,1])
input_points = np.concatenate([X,Y],axis=-1)#连接为输入
for i in input_points:
plt.scatter(i[0],i[1])
plt.show() img_size = 28
predict_img = decoder.predict(input_points)
pic = np.empty([img_size*n,img_size*n,1])
for i in range(n):
for j in range(n):
pic[img_size*i:img_size*(i+1), img_size*j:img_size*(j+1)] = predict_img[i*n+j]
plt.figure(figsize=(10,10))
plt.axis('off')
pic = np.squeeze(pic)
plt.imshow(pic,cmap='bone')
plt.show()
生成的400张图:
可以看出来,二维坐标系中某个方向的编码是可以使解码器的输出从一个数字变换到另一个数字的。
VAE变分自编码器Keras实现的更多相关文章
- VAE变分自编码器实现
变分自编码器(VAE)组合了神经网络和贝叶斯推理这两种最好的方法,是最酷的神经网络,已经成为无监督学习的流行方法之一. 变分自编码器是一个扭曲的自编码器.同自编码器的传统编码器和解码器网络一起,具有附 ...
- VAE变分自编码器
我在学习VAE的时候遇到了很多问题,很多博客写的不太好理解,因此将很多内容重新进行了整合. 我自己的学习路线是先学EM算法再看的变分推断,最后学VAE,自我感觉这个线路比较好理解. 一.首先我们来宏观 ...
- VAE变分自编码器公式推导
VAE变分推导依赖数学公式 (1)贝叶斯公式:\(p(z|x) = \frac{p(x|z)p(z)}{p(x)}\) (2)边缘概率公式:\(p(x) =\int{p(x,z)}dz\) (3)KL ...
- Variational Auto-encoder(VAE)变分自编码器-Pytorch
import os import torch import torch.nn as nn import torch.nn.functional as F import torchvision from ...
- 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)
变分自编码器(VAE,variatinal autoencoder) VS 生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...
- (转) 变分自编码器(Variational Autoencoder, VAE)通俗教程
变分自编码器(Variational Autoencoder, VAE)通俗教程 转载自: http://www.dengfanxin.cn/?p=334&sukey=72885186ae5c ...
- 变分自编码器(Variational Autoencoder, VAE)通俗教程
原文地址:http://www.dengfanxin.cn/?p=334 1. 神秘变量与数据集 现在有一个数据集DX(dataset, 也可以叫datapoints),每个数据也称为数据点.我们假定 ...
- 变分自编码器(Variational auto-encoder,VAE)
参考: https://www.cnblogs.com/huangshiyu13/p/6209016.html https://zhuanlan.zhihu.com/p/25401928 https: ...
- 基于变分自编码器(VAE)利用重建概率的异常检测
本文为博主翻译自:Jinwon的Variational Autoencoder based Anomaly Detection using Reconstruction Probability,如侵立 ...
- 变分推断到变分自编码器(VAE)
EM算法 EM算法是含隐变量图模型的常用参数估计方法,通过迭代的方法来最大化边际似然. 带隐变量的贝叶斯网络 给定N 个训练样本D={x(n)},其对数似然函数为: 通过最大化整个训练集的对数边际似然 ...
随机推荐
- failed to copy: httpReadSeeker: failed open: unexpected status code xxx 403
ack上pull镜像的时候,报的错 非运行脚本的问题,由负责ack相关设定的人员调整即可
- 【转】ElasticSearch报错FORBIDDEN/12/index read-only / allow delete (api) ,read_only_allow_delete 设置 windows
仅供自己记录使用,原文链接:ElasticSearch报错FORBIDDEN/12/index read-only / allow delete (api)_sinat_22387459的博客-CSD ...
- 2023.7.2-3-4Mssql xp_cmdshell提权
1.概念 Mssql和SQL sever的一个产品的不同名称.都属于微软公司旗下.而上述Mssql xp_cmdshell提权也属于数据库提权的一种. 主要依赖于sql server自带的存储过程. ...
- eclipse真的落后了嘛?这几点优势其他IDE比不上
序言 各位好啊,我是会编程的蜗牛,作为java开发者,我们每天都要和开发工具打交道.我以前一开始入门java开发的时候,就是用的eclipse,虽然感觉有点繁琐,但好在还能用.后来偶然间发现了IDEA ...
- 这才是批量update的正确姿势!
前言 最近我有位小伙伴问我,在实际工作中,批量更新的代码要怎么写. 这个问题挺有代表性的,今天拿出来给大家一起分享一下,希望对你会有所帮助. 1 案发现场 有一天上午,在我的知识星球群里,有位小伙伴问 ...
- 暑假集训CSP提高模拟8
一看见题目列表就吓晕了,还好我是体育生,后面忘了 唉这场比赛没啥好写的,要不就是太难要不就是太简单要不就是拉出去写在专题里了 A. 基础的生成函数练习题 考虑到只有奇偶性相同才能尝试加二,因此先用加一 ...
- USB gadget configfs
概述 USB Linux Gadget是一种具有UDC (USB设备控制器)的设备,可以连接到USB主机,以扩展其附加功能,如串口或大容量存储能力. 一个gadget被它的主机视为一组配置,每个配置都 ...
- 进程D 状态的产生及原因解释
在 Linux 系统中,进程的 D 状态表示进程处于不可中断的睡眠状态 (Uninterruptible Sleep).这种状态通常由进程等待某些资源或事件引起,这些资源或事件无法立即可用.以下是一些 ...
- kotlin更多语言结构——>作用域函数
作用域函数 Kotlin 标准库包含几个函数,它们的唯一目的是在对象的上下文中执行代码块.当对一个对象调用这样的函数 并提供一个 lambda 表达式时,它会形成一个临时作用域.在此作用域中,可以访问 ...
- 云原生周刊:CNCF 宣布 Falco 毕业|2024.3.4
开源项目推荐 ldap-operator 用于部署和管理 LDAP 目录的 Kubernetes Operator. Updatecli Updatecli 是一个用于应用文件更新策略的工具.每个应用 ...