图像风格迁移

最后要生成的图片是怎样的是难以想象的,所以朴素的监督学习方法可能不会生效,

Content Loss

根据输入图片和输出图片的像素差别可以比较损失

\(l_{content} = \frac{1}{2}\sum (C_c-T_c)^2\)

Style Loss

从中间提取多个特征层来衡量损失。

利用\(Gram\) \(Matrix\)(格拉姆矩阵)可以衡量风格的相关性,对于一个实矩阵\(X\),矩阵\(XX^T\)是\(X\)的行向量的格拉姆矩阵

\(l_{style}=\sum wi(Ts-Ss)^2\)

总的损失函数

\(L_{total(S,C,T)}=\alpha l_{content}(C,T)+\beta L_{style}(S,T)\)


代码
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np import torch
import torch.optim as optim
from torchvision import transforms, models vgg = models.vgg19(pretrained=True).features #使用预训练的VGG19,features表示只提取不包括全连接层的部分 for i in vgg.parameters():
i.requires_grad_(False) #不要求训练VGG的参数

定义一个显示图片的函数

def load_img(path, max_size=400,shape=None):
img = Image.open(path).convert('RGB') if(max(img.size)) > max_size: #规定图像的最大尺寸
size = max_size
else:
size = max(img.size) if shape is not None:
size = shape
transform = transforms.Compose([
transforms.Resize(size),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
])
'''删除alpha通道(jpg), 转为png,补足另一个维度-batch'''
img = transform(img)[:3,:,:].unsqueeze(0)
return img

载入图像

content  = load_img('./images/turtle.jpg')
style = load_img('./images/wave.jpg', shape=content.shape[-2:]) #让两张图尺寸一样 '''转换为plt可以画出来的形式'''
def im_convert(tensor):
img = tensor.clone().detach()
img = img.numpy().squeeze()
img = img.transpose(1,2,0)
img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
img = img.clip(0,1)
return img

使用的图像为(左边为Content Image,右边为Style Image):

定义几个待会要用到的函数

def get_features(img, model, layers=None):
'''获取特征层'''
if layers is None:
layers = {
'0':'conv1_1',
'5':'conv2_1',
'10':'conv3_1',
'19':'conv4_1',
'21':'conv4_2', #content层
'28':'conv5_1'
} features = {}
x = img
for name, layer in model._modules.items():
x = layer(x)
if name in layers:
features[layers[name]] = x return features def gram_matrix(tensor):
'''计算Gram matrix'''
_, d, h, w = tensor.size() #第一个是batch_size tensor = tensor.view(d, h*w) gram = torch.mm(tensor, tensor.t()) return gram content_features = get_features(content, vgg)
style_features = get_features(style, vgg) style_grams = {layer:gram_matrix(style_features[layer]) for layer in style_features} target = content.clone().requires_grad_(True) '''定义不同层的权重'''
style_weights = {
'conv1_1': 1,
'conv2_1': 0.8,
'conv3_1': 0.5,
'conv4_1': 0.3,
'conv5_1': 0.1,
}
'''定义2种损失对应的权重'''
content_weight = 1
style_weight = 1e6

训练过程

show_every = 400
optimizer = optim.Adam([target], lr=0.003)
steps = 2000 for ii in range(steps):
target_features = get_features(target, vgg) content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
style_loss = 0
'''加上每一层的gram_matrix矩阵的损失'''
for layer in style_weights:
target_feature = target_features[layer]
target_gram = gram_matrix(target_feature)
_, d, h, w = target_feature.shape
style_gram = style_grams[layer]
layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
style_loss += layer_style_loss/(d*h*w) #加到总的style_loss里,除以大小 total_loss = content_weight * content_loss + style_weight * style_loss optimizer.zero_grad()
total_loss.backward()
optimizer.step() if ii % show_every == 0 :
print('Total Loss:',total_loss.item())
plt.imshow(im_convert(target))
plt.show()

将输入的图像和最后得到的混合图作比较:

没有达到最好的效果,还有可以优化的空间√

参考:
  1. Image Style Transfer Using Convolutional Neural Networks论文
  2. Udacity——PyTorch Scholarship Challenge

图像风格迁移(Pytorch)的更多相关文章

  1. keras图像风格迁移

    风格迁移: 在内容上尽量与基准图像保持一致,在风格上尽量与风格图像保持一致. 1. 使用预训练的VGG19网络提取特征 2. 损失函数之一是"内容损失"(content loss) ...

  2. fast neural style transfer图像风格迁移基于tensorflow实现

    引自:深度学习实践:使用Tensorflow实现快速风格迁移 一.风格迁移简介 风格迁移(Style Transfer)是深度学习众多应用中非常有趣的一种,如图,我们可以使用这种方法把一张图片的风格“ ...

  3. Distill详述「可微图像参数化」:神经网络可视化和风格迁移利器!

    近日,期刊平台 Distill 发布了谷歌研究人员的一篇文章,介绍一个适用于神经网络可视化和风格迁移的强大工具:可微图像参数化.这篇文章从多个方面介绍了该工具. 图像分类神经网络拥有卓越的图像生成能力 ...

  4. 使用 PyTorch 进行 风格迁移(Neural-Transfer)

    1.简介 本教程主要讲解如何实现由 Leon A. Gatys,Alexander S. Ecker和Matthias Bethge提出的Neural-Style 算法.Neural-Style 或者 ...

  5. Gram格拉姆矩阵在风格迁移中的应用

    Gram定义 n维欧式空间中任意k个向量之间两两的内积所组成的矩阵,称为这k个向量的格拉姆矩阵(Gram matrix) 根据定义可以看到,每个Gram矩阵背后都有一组向量,Gram矩阵就是由这一组向 ...

  6. 『cs231n』通过代码理解风格迁移

    『cs231n』卷积神经网络的可视化应用 文件目录 vgg16.py import os import numpy as np import tensorflow as tf from downloa ...

  7. Keras实现风格迁移

    风格迁移 风格迁移算法经历多次定义和更新,现在应用在许多智能手机APP上. 风格迁移在保留目标图片内容的基础上,将图片风格引用在目标图片上. 风格本质上是指在各种空间尺度上图像中的纹理,颜色和视觉图案 ...

  8. ng-深度学习-课程笔记-14: 人脸识别和风格迁移(Week4)

    1 什么是人脸识别( what is face recognition ) 在相关文献中经常会提到人脸验证(verification)和人脸识别(recognition). verification就 ...

  9. [DeeplearningAI笔记]卷积神经网络4.6-4.10神经网络风格迁移

    4.4特殊应用:人脸识别和神经网络风格转换 觉得有用的话,欢迎一起讨论相互学习~Follow Me 4.6什么是神经网络风格转换neural style transfer 将原图片作为内容图片Cont ...

随机推荐

  1. .Net 委托 delegate 学习

    一.什么是委托: 委托是寻址方法的.NET版本,使用委托可以将方法作为参数进行传递.委托是一种特殊类型的对象,其特殊之处在于委托中包含的只是一个活多个方法的地址,而不是数据.   二.使用委托: 关键 ...

  2. vs2015安装编辑神器:resharper10.0

    在平时的开发工作中,作为一名程序员,难免会想办法找到适合自己的开发编辑器.这款插件来自JetBrains公司.接下来就来教大家如何对这款软件进行安装与破解. 1:首先下载与安装.如果没有找到适合的资源 ...

  3. .NET Framework框架介绍

    1.内容 .net framework c#和.net关系 掌握C#中命名空间2..net 就是微软提供的一个开发平台 版本: vs2008 3.5 vs2010 4.0 vs2012 2013 20 ...

  4. SQL Server 一列或多列重复数据的查询,删除(转载)

    转载来源:https://www.cnblogs.com/sunxi/p/4572332.html 业务需求 最近给公司做一个小工具,把某个数据库(数据源)的数据导进另一个数据(目标数据库).要求导入 ...

  5. 1.常用turtle功能函数

    #turtle常用命令汇总,括号中的参数仅仅作为举例使用,可根据需要修改 #设置画面背景色 turtle.bgcolor("black") #设置窗口大小和在屏幕上的坐标 turt ...

  6. .net 笔试面试总结(1)

    趁着在放假时候,给大家总结一点笔试面试上的东西,也刚好为年后跳槽做一点小积累. 下面的参考解答只是帮助大家理解,不用背,面试题.笔试题千变万化,不要梦想着把题覆盖了,下面的题是供大家查漏补缺用的,真正 ...

  7. JavaScript中的十个难点,你有必要知道。

    1. 立即执行函数 立即执行函数,即Immediately Invoked Function Expression (IIFE),正如它的名字,就是创建函数的同时立即执行.它没有绑定任何事件,也无需等 ...

  8. Odoo薪酬管理 公式配置

    薪酬计算的一般原理是:在基本工资的基础上,加上各种津贴,减去社保.公积金.个税等各种扣除项之后,得出最终的实发工资.此外,还要计算社保.公积金等公司应该承担的部分. 在同一公司中,针对不同的地区.不同 ...

  9. 碰到了通过Movie显示gif图片,有部分图片的duration为0导致gif只显示第一帧

    解决办法,改为使用android-gif-drawable.jar来显示gif图片(需要配合com.android.support:support-v4:18.0.0使用) GifImageView ...

  10. 【Linux】【MySQL】CentOS7、MySQL8.0.13 骚操作速查笔记——专治各种忘词水土不服

    1.前言 [Linux][MySQL]CentOS7安装最新版MySQL8.0.13(最新版MySQL从安装到运行) 专治各种忘词,各种水土不服. - -,就是一个健忘贵的速查表:(当然不包括SQL的 ...