变分自动编码器生成图片

从隐图像空间进行采样以创建全新的图像或编辑现有图像是目前创作AI最受欢迎和最成功的应用方式。

图像隐空间取样

图像生成的关键思想是开发表示的低维潜在空间(自然是矢量空间),其中任何点都可以映射到逼真的图像上。 能够实现该映射的模块,将潜在点作为输入并输出图像(像素网格),被称为生成器(在GAN的情况下)或解码器(在VAE的情况下)。一旦开发出这样的潜在空间,可以有意或无意地从中采样点,并通过将它们映射到图像空间,生成以前从未见过的图像。

GAN和VAE是用于学习图像表示的潜在空间的两种不同策略,每种都具有其自身的特征。VAE非常适合学习结构良好的潜在空间,其中特定方向编码数据能产生有意义的变化轴。GAN生成的图像可能非常逼真,但它们来自潜在的空间可能没有那么多的结构和连续性

图片编辑的概念向量

给定潜在的表示空间或嵌入空间,空间中的某些方向可以编码原始数据中有趣的变化轴。例如,在面部图像的潜在空间中,可能存在微笑矢量s,使得如果潜在点z是某个面部的嵌入表示,则潜在点z+s是同一面部的嵌入表示,面带微笑。一旦确定了这样的矢量,就可以通过将图像投影到潜在空间中来编辑图像,以有意义的方式移动它们的表示,然后将它们解码回图像空间。

变分自动编码器

变分自动编码器,是一种生成模型,特别适用于通过概念向量进行图像编辑的任务。它们是自动编码器的现代版本 - 一种旨在将输入编码到低维潜在空间然后将其解码回来的网络 - 将来自深度学习的想法与贝叶斯推理混合在一起.

经典图像自动编码器通过编码器模块拍摄图像,将其映射到潜在的矢量空间,然后通过解码器模块将其解码回与原始图像具有相同尺寸的输出。然后通过使用与输入图像相同的图像作为目标数据来训练,这意味着自动编码器学习重建原始输入。通过对代码(编码器的输出)施加各种约束,可以使自动编码器学习或多或少有趣的数据潜在表示。最常见的是,将限制代码为低维和稀疏(大多数为零),在这种情况下,编码器可以将输入数据压缩为更少的信息位。

在实践中,这种经典的自动编码器不会导致特别有用或结构良好的潜在空间,也不太擅长数据压缩。由于这些原因,他们已经基本上不再流行。然而,VAE用统计方法增强了自动编码器,迫使他们学习连续的,高度结构化的潜在空间。它们已成为图像生成的强大工具

VAE不是将其输入图像压缩为潜在空间中的固定代码,而是将图像转换为统计分布的参数:均值和方差。从本质上讲,这意味着假设输入图像是由统计过程生成的,并且此过程的随机性应在编码和解码期间用于计算。然后,VAE使用均值和方差参数随机采样分布的一个元素,并将该元素解码回原始输入。该过程的随机性提高了鲁棒性并迫使潜在空间在任何地方编码有意义的表示:在潜在空间中采样的每个点被解码为有效输出

数学描述,VAE工作过程:

  1. 编码器模块将输入样本input_img转换为表示的隐空间的z_mean和z_log_variance两个参数;
  2. 通过z=z_mean + exp(z_log_variance)*epsilon 从假定生成输入图像的潜在正态分布中随机采样点z,其中epsilon是小值的随机张量;
  3. 解码器模块将隐空间中的z点映射回原始输入图像。

因为epsilon是随机的,所以该过程确保接近编码input_img(z-mean)的潜在位置的每个点都可以被解码为类似于input_img的东西,从而迫使潜在空间持续有意义。潜在空间中的任何两个闭合点将解码为高度相似的图像。连续性与潜在空间的低维度相结合,迫使潜在空间中的每个方向编码有意义的数据变化轴,使得潜在空间非常结构化,因此非常适合通过概念向量进行操纵

VAE的参数通过两个损失函数进行训练:强制解码样本与初始输入匹配的重建损失函数,以及有助于学习良好的隐空间并减少过度拟合训练数据的正则化损失函数。让我们快速了解一下VAE的Keras实现。原理上,它看起来像这样:

z_mean,z_log_variance = encoder(input_img)#输入编码成均值、方法参数

z = z_mean + exp(z_log_variance)*epsilon#隐空间通过epsilon取样

reconstructed_img = decoder(z)#取样点生成新图片

model = Model(input_img,reconstructed_img)#实例化模型:输入图片映射到新建图片上,之后训练

模型定义后,使用重建损失函数和正则损失训练模型。

使用一个简单的convnet将输入图片映射到隐空间的概率分布上,得到两个向量z_mean,z_log_var。

VAE Encoder网络

import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np img_shape = (28,28,1)
batch_size = 16
latent_dim = 2 input_img = keras.Input(shape=img_shape)
x = layers.Conv2D(32,3,padding='same',activation='relu')(input_img)
x = layers.Conv2D(64,3,padding='same',activation='relu',stride=(2,2))(x)
x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
shape_before_flattening = K.int_shape(x) x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)
#输入图片最终 编码成 两个参数
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

之后使用输入图片的假设空间分布特征z_mean和z_log_var得到隐空间取样点z。在这里,将一些任意代码(构建在Keras后端基元之上)包装到Lambda层中。在Keras中,一切都需要是一个层,因此不属于内置层的代码应该包装在Lambda(或自定义层)中.

隐空间取样函数

def sampling(args):
z_mean,z_log_var = args
epsilon=K.random_normal(shape=(K.shape(z_mean)[0],
latent_dim),mean=0.,stddev=1.)
return z_mean + K.exp(z_log_var)*epsilon z = layers.Lambda(sampling)([z_mean,z_log_var])

解码器部分实现。将向量z reshape到图片尺寸,最后经过几个卷积层得到最终的图片输出。

VAE decoder网络:隐变量空间到图片

decoder_input = layers.Input(K.int_shape(z)[1:])#输入z向量

x = layers.Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flattening[1:])(x)
x = layers.Conv2DTranspose(32,3,padding='same',activation='relu',strides=(2,2))(x)
x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x) decoder = Model(decoder_input, x)#实例化模型,模型将输入decoder_input转换成图片
z_decoded = decoder(z)#输入z,得到最终转换后的输出图片

VAE的双重损失函数不符合传统形式损失函数(输入,目标)的预期。因此,将通过编写内部使用内置add_loss图层方法来创建任意损失的自定义图层来设置损失函数。

定义图层计算损失函数

class CustomVariationalLayer(keras.layers.Layer):
def vae_loss(self, x, z_decoded):
x = K.flatten(x)
z_decoded = K.flatten(z_decoded)
xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)#重构损失
kl_loss = -5e-4 * K.mean(1+z_log_var-K.square(z_mean)-
K.exp(z_log_var), axis=-1)#encoder损失
return K.mean(xent_loss + kl_loss) def call(self,inputs):
x = inputs[0]
z_decoded = inputs[1]
loss = self.vae_loss(x,z_decoded)
self.add_loss(loss,inputs=inputs)
return x y = CustomVariationalLayer()([input_img, z_decoded])

最后,实例化模型并训练。由于损失函数是在自定义层中处理的,因此不会在编译时指定外部损失(loss=None),这反过来意味着不会在训练期间传递目标数据(如所见,只能将x_train传递给模型在fit函数中)。

VAE训练

from keras.datasets import mnist

vae = Model(input_img,y)#通过定义输入和输出 Model模型
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()
(x_train, _), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))
vae.fit(x=x_train, y=None,shuffle=True,epochs=10,batch_size=batch_size,
validation_data=(x_test, None))

模型训练完成后,可以使用decoder模块将任意隐变量空间点转换生成图片。

2D隐变量空间点取样,生成图片

import matplotlib.pyplot as plt
from scipy.stats import norm n = 15#15*15 225个数字图片
digit_size = 28
figure = np.zeros((digit_size*n,digit_sie*n))#最终图片
grid_x = norm.ppf(np.linspace(0.05,0.95,n))#假设隐变量空间符合高斯分布
grid_y = norm.ppf(np.linspace(0.05,0.95,n))#ppf随机取样 for i,yi in enumerate(grid_x):
for j, xi in enumerate(grid_y):
z_sample = np.array([[xi, yi]])
#重复z_sample多次,形成一个完整的batch
z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
x_decoded = decoder.predict(z_sample, batch_size=batch_size)
digit=x_decoded[0].reshape(digit_size, digit_size)#28*28*1->28*28
figure[i*digit_size:(i+1)*digit_size,j*digit_size:(j+1)*digit_size] = digit plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

小结

  • 深度学习的图像生成是通过学习捕获有关图像数据集的统计信息的潜在空间来完成的。通过对潜在空间中的点进行采样和解码,可以生成前所未见的图像。有两个主要工具:VAE和GAN
  • VAE导致高度结构化,连续的潜在表征。出于这个原因,它们适用于在潜在空间中进行各种图像编辑:面部交换,将皱眉脸变成笑脸,等等。它们也可以很好地用于基于潜在空间的动画,例如沿着潜在空间的横截面动画制作动画,显示起始图像以连续的方式慢慢变形为不同的图像。
  • GAN可以生成逼真的单帧图像,但可能不会引入具有坚固结构和高连续性的潜在空间。

[图片生成]使用VAEs生成新图片的更多相关文章

  1. C#图片裁切,生成新图片

    /// 图片裁剪,生成新图,保存在同一目录下,名字加_new,格式1.png 新图1_new.png /// </summary> /// <param name="pic ...

  2. C# 图片的裁剪,两个图片合成一个图片

    图片的裁剪,两个图片合成一个图片(这是从网上摘的) /// <summary>         /// 图片裁剪,生成新图,保存在同一目录下,名字加_new,格式1.png  新图1_ne ...

  3. IOS 截取图片 部分 并生成新图片

    /** * 从图片中按指定的位置大小截取图片的一部分 * * @param image UIImage image 原始的图片 * @param rect CGRect rect 要截取的区域 * * ...

  4. Highcharts结合PhantomJS在服务端生成高质量的图表图片

    项目背景 最近忙着给部门开发一套交互式的报表系统,来替换原有的静态报表系统. 老系统是基于dotnetCHARTING开发的,dotnetCHARTING的优势是图表类型丰富,接口调用简单,使用时只需 ...

  5. Org mode无法生成LaTeX公式预览图片

    最近需要在Cygwin平台下的Emacs Org mode中生成LaTeX数学公式的预览图片,从而得到图文并貌的笔记与任务管理文档.但当我执行org-toggle-latex-fragment命令后却 ...

  6. java图片裁剪和java生成缩略图

    一.缩略图 在浏览相冊的时候.可能须要生成相应的缩略图. 直接上代码: public class ImageUtil { private Logger log = LoggerFactory.getL ...

  7. thinkphp图片上传+validate表单验证+图片木马检测+缩略图生成

    目录 1.案例 1.1图片上传  1.2进行图片木马检测   1.3缩略图生成   1.4控制器中调用缩略图生成方法 1.案例 前言:在thinkphp框架的Thinkphp/Library/Thin ...

  8. 图片url地址的生成获取方法

    在写博客插入图片时,许多时候需要提供图片的url地址.作为菜鸡的我,自然是一脸懵逼.那么什么是所谓的url地址呢?又该如何获取图片的url地址呢? 首先来看一下度娘对url地址的解释:url是统一资源 ...

  9. 复制图片链接和标题生成Markdown文本

    写Markdown的时候常常会需要复制图片链接和标题以插入图片,不借助其他工具的话,一般需要先在Markdown文件中输入插入图片的格式,然后在浏览器中复制图片链接和标题将其依次粘贴到Markdown ...

随机推荐

  1. MySQL——通过EXPLAIN分析SQL的执行计划

    在MySQL中,我们可以通过EXPLAIN命令获取MySQL如何执行SELECT语句的信息,包括在SELECT语句执行过程中表如何连接和连接的顺序. 下面分别对EXPLAIN命令结果的每一列进行说明: ...

  2. Kafka 处理器客户端介绍

    [编者按]本文作者为 Bill Bejeck,主要介绍如何有效利用新的 Apache Kafka 客户端来满足数据处理需求.文章系国内 ITOM 管理平台 OneAPM 编译呈现,以下为正文. 如果你 ...

  3. NFS 系统的搭建

    问题: 由于工作,需要,不断得进行挂在硬盘重装系统,NFS 系统给了我一个很好的解决方案.于是决定写一篇博客,防止以后再次使用的时候,能够很快得重新建立NFS 文件系统. 调研: NFS(Networ ...

  4. 解决iPhone滑动时滑到另一个层级导致卡顿问题

    问题概览: 两个div都可以滑动时,会造成滑动顶层div时,底层div也会跟着滑动.如图示. 解决方法: 添加CSS即可. 代码如下 * { -webkit-overflow-scrolling: t ...

  5. 【转】Java学习---Java Web基础面试题整理

    [原文]https://www.toutiao.com/i6592359948632457731/ 1.什么是Servlet? 可以从两个方面去看Servlet: a.API:有一个接口servlet ...

  6. 【转】MaBatis学习---源码分析MyBatis缓存原理

    [原文]https://www.toutiao.com/i6594029178964673027/ 源码分析MyBatis缓存原理 1.简介 在 Web 应用中,缓存是必不可少的组件.通常我们都会用 ...

  7. kettle 启动spoon一闪而过

    Kettle是Pentaho的一个组件,主要用于数据库间的数据迁移(ETL). Kettle有三个主要组件:Spoon,Kitchen,Pan.其中Spoon是一个图形化的界面. 一.安装kettle ...

  8. 4星|《行为设计学:掌控关键决策》:影响决策质量的四大思维陷阱及WRAP应对法

    行为设计学:掌控关键决策 两位作者认为,有四大思维陷阱让人做出错误的决策:思维狭隘.证实倾向.短期情绪.过度自信.两位作者提出WRAP决策流程来应对:Widen your options(拓宽选择空间 ...

  9. PyQt5--QProgressBar

    # -*- coding:utf-8 -*- ''' Created on Sep 20, 2018 @author: SaShuangYiBing Comment: ''' import sys f ...

  10. Solr建立索引时,过滤HTML标签

    原文地址  http://www.joyphper.net/article/201306/188.html 1.在数据库的读取文件data-config.xml 中的entity 标记里边添加 tra ...