ContentLoss

首先是要定义一个内容差异损失函数,这里直接调用functional.mse_loss(input,self.target)就可以计算出其内容差异损失。

注意这里一般是定义一个网络模型,输入和输出一直,这样才在后面方便直接求出 loss

class ContentLoss(nn.Module):

    def __init__(self, target,):
super(ContentLoss, self).__init__()
# we 'detach' the target content from the tree used
# to dynamically compute the gradient: this is a stated value,
# not a variable. Otherwise the forward method of the criterion
# will throw an error.
self.target = target.detach() def forward(self, input):
self.loss = F.mse_loss(input, self.target)
return input

detach是使他不能求梯度

StyleLoss

这里是定义风格差异损失函数,首先就先定一个一个gram函数来求 c*c c为层数

def gram_matrix(input):
a, b, c, d = input.size() # a=batch size(=1)
# b=number of feature maps
# (c,d)=dimensions of a f. map (N=c*d) features = input.view(a * b, c * d) # resise F_XL into \hat F_XL G = torch.mm(features, features.t()) # compute the gram product # we 'normalize' the values of the gram matrix
# by dividing by the number of element in each feature maps.
return G.div(a * b * c * d)

然后利用gram函数来求loss

class StyleLoss(nn.Module):
def __init__(self, target_feature):
super(StyleLoss, self).__init__()
self.target = gram_matrix(target_feature).detach() def forward(self, input):
G = gram_matrix(input)
self.loss = F.mse_loss(G, self.target)
return input

新定义网络

然后将 ContentLoss和 StyleLoss 加入新定义的迁移网络(normalization见下面)

cnn = models.vgg19(pretrained=True).features.to(device).eval()

content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
style_img, content_img,
content_layers=content_layers_default,
style_layers=style_layers_default):
cnn = copy.deepcopy(cnn) # normalization module
normalization = Normalization(normalization_mean, normalization_std).to(device) # just in order to have an iterable access to or list of content/syle
# losses
content_losses = []
style_losses = [] # assuming that cnn is a nn.Sequential, so we make a new nn.Sequential
# to put in modules that are supposed to be activated sequentially
model = nn.Sequential(normalization)
# model = nn.Sequential()
i = 0 # increment every time we see a conv
for layer in cnn.children():
if isinstance(layer, nn.Conv2d):
i += 1
name = 'conv_{}'.format(i)
elif isinstance(layer, nn.ReLU):
name = 'relu_{}'.format(i)
# The in-place version doesn't play very nicely with the ContentLoss
# and StyleLoss we insert below. So we replace with out-of-place
# ones here. #********注意**************不能掉 vgg默认inplace=ture
layer = nn.ReLU(inplace=False) elif isinstance(layer, nn.MaxPool2d):
name = 'pool_{}'.format(i)
elif isinstance(layer, nn.BatchNorm2d):
name = 'bn_{}'.format(i)
else:
raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__)) model.add_module(name, layer) if name in content_layers:
# add content loss:
target = model(content_img).detach()
content_loss = ContentLoss(target)
model.add_module("content_loss_{}".format(i), content_loss)
content_losses.append(content_loss) if name in style_layers:
# add style loss:
target_feature = model(style_img).detach()
style_loss = StyleLoss(target_feature)
model.add_module("style_loss_{}".format(i), style_loss)
style_losses.append(style_loss) # now we trim off the layers after the last content and style losses
for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break model = model[:(i + 1)] return model, style_losses, content_losses

normalization模型

cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) # create a module to normalize input image so we can easily put it in a
# nn.Sequential class Normalization(nn.Module):
def __init__(self, mean, std):
super(Normalization, self).__init__()
# .view the mean and std to make them [C x 1 x 1] so that they can
# directly work with image Tensor of shape [B x C x H x W].
# B is batch size. C is number of channels. H is height and W is width.
self.mean = torch.tensor(mean).view(-1, 1, 1)
self.std = torch.tensor(std).view(-1, 1, 1) def forward(self, img):
# normalize img
return (img - self.mean) / self.std

输入参数定义

input_img = content_img.clone()
optimizer = optim.LBFGS([input_img.requires_grad_()])
两种方式
#input_param = nn.Parameter(input_img.data)
#optimizer = optim.LBFGS([input_param])

更新input_img

注意梯度共轭和LBFGS 方法更新梯度的时候 要用闭包的形式

def run_style_transfer(cnn, normalization_mean, normalization_std,
content_img, style_img, input_img, num_steps=300,
style_weight=500000, content_weight=1):
"""Run the style transfer."""
print('Building the style transfer model..')
model, style_losses, content_losses = get_style_model_and_losses(cnn,
normalization_mean, normalization_std, style_img, content_img)
# optimizer = get_input_optimizer(input_img) print('Optimizing..')
run = [0]
while run[0] <= num_steps: def closure():
# correct the values of updated input image
input_img.data.clamp_(0, 1) optimizer.zero_grad()
model(input_img)
style_score = 0
content_score = 0 for sl in style_losses:
style_score += sl.loss
for cl in content_losses:
content_score += cl.loss
style_score *= style_weight
content_score *= content_weight loss = style_score + content_score
loss.backward() run[0] += 1
if run[0] % 50 == 0:
print("run {}:".format(run))
print('Style Loss : {:4f} Content Loss: {:4f}'.format(
style_score.item(), content_score.item()))
print() return style_score + content_score
optimizer.step(closure) # a last correction...
input_img.data.clamp_(0, 1) return input_img

其余代码参考 https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#sphx-glr-advanced-neural-style-tutorial-py

neural_transfer风格迁移的更多相关文章

  1. 图像风格迁移(Pytorch)

    图像风格迁移 最后要生成的图片是怎样的是难以想象的,所以朴素的监督学习方法可能不会生效, Content Loss 根据输入图片和输出图片的像素差别可以比较损失 \(l_{content} = \fr ...

  2. keras图像风格迁移

    风格迁移: 在内容上尽量与基准图像保持一致,在风格上尽量与风格图像保持一致. 1. 使用预训练的VGG19网络提取特征 2. 损失函数之一是"内容损失"(content loss) ...

  3. Gram格拉姆矩阵在风格迁移中的应用

    Gram定义 n维欧式空间中任意k个向量之间两两的内积所组成的矩阵,称为这k个向量的格拉姆矩阵(Gram matrix) 根据定义可以看到,每个Gram矩阵背后都有一组向量,Gram矩阵就是由这一组向 ...

  4. 『cs231n』通过代码理解风格迁移

    『cs231n』卷积神经网络的可视化应用 文件目录 vgg16.py import os import numpy as np import tensorflow as tf from downloa ...

  5. Keras实现风格迁移

    风格迁移 风格迁移算法经历多次定义和更新,现在应用在许多智能手机APP上. 风格迁移在保留目标图片内容的基础上,将图片风格引用在目标图片上. 风格本质上是指在各种空间尺度上图像中的纹理,颜色和视觉图案 ...

  6. fast neural style transfer图像风格迁移基于tensorflow实现

    引自:深度学习实践:使用Tensorflow实现快速风格迁移 一.风格迁移简介 风格迁移(Style Transfer)是深度学习众多应用中非常有趣的一种,如图,我们可以使用这种方法把一张图片的风格“ ...

  7. Distill详述「可微图像参数化」:神经网络可视化和风格迁移利器!

    近日,期刊平台 Distill 发布了谷歌研究人员的一篇文章,介绍一个适用于神经网络可视化和风格迁移的强大工具:可微图像参数化.这篇文章从多个方面介绍了该工具. 图像分类神经网络拥有卓越的图像生成能力 ...

  8. ng-深度学习-课程笔记-14: 人脸识别和风格迁移(Week4)

    1 什么是人脸识别( what is face recognition ) 在相关文献中经常会提到人脸验证(verification)和人脸识别(recognition). verification就 ...

  9. [DeeplearningAI笔记]卷积神经网络4.6-4.10神经网络风格迁移

    4.4特殊应用:人脸识别和神经网络风格转换 觉得有用的话,欢迎一起讨论相互学习~Follow Me 4.6什么是神经网络风格转换neural style transfer 将原图片作为内容图片Cont ...

随机推荐

  1. 个人永久性免费-Excel催化剂功能第89波-批量多图片转PDF

    前一篇展示了从PDF中提取到有用信息如图片.文本.表格等功能,部分人可能对自己手中的转PDF格式的保护性有所顾虑,此篇从反向角度,提供数据保护作用,让PDF文件的数据保护更彻底,让文本型的PDF文件彻 ...

  2. Jmeter(1):使用TCP取样器与socket接口进行简单通信

    一个小任务:服务器与客户端连接,每次发送50个随机生成的字符,两秒发送一次 失败过太多次,然后昨晚终于跑通了,心情激动,于是清均第一篇博客就诞生了. 之前不了解jmeter,想过单纯用java编写服务 ...

  3. css 图片裁剪显示

    用object-fit:cover object-fit属性详解 object-fit:CSS 属性指定替换元素的内容应该如何适应到其使用的高度和宽度确定的框. object-fit:fill 被替换 ...

  4. 一位 iOS 大牛的 BAT面试心得与经验总结,送给正在迷茫 的你!

    前言: 目前形势,参加到 iOS 队伍的人是越来越多,可以说是已经达到了供过于求的地步了. 今年,找过工作人可能会更深刻地体会到今年的就业形势不容乐观,之前实习的时候就想着写一篇面经,后来忙就给忘了, ...

  5. jmter快速安装

    一.简介 Apache JMeter是Apache组织开发的基于Java的压力测试工具,用于接口和压力测试,所以前提是一定更要安装jdk. 二.下载安装 下载:官网下载 下载完成后运行包里的jmete ...

  6. 基础篇-1.2Java世界的规章制度(下)

    1 Java运算符 Java世界中的运算其实就是数学运算,而运算符就是其中的媒介. 算术运算符 操作符 描述 + 加法,对符号两边的数值相加 - 减法,符号左边的数减去右边的数 * 乘法,符号两边的数 ...

  7. Redis(四)--- Redis的命令参考

    1.简述 数据类型也称数据对象,包含字符串对象(string).列表对象(list).哈希对象(hash).集合对象(set).有序集合对象(zset). 2.String数据类型命令 string  ...

  8. Python基础总结之第六天开始【认识List:列表】【认识Tuple:元组】【还有他们基本的操作】(新手可相互督促)

    早,在北京的周六,热到不行~~~ 今天更新笔记列表(List).元组(Tuple)以及它们的操作方法 在列表中会经常用到List列表,前面我们认识到的有字符串,字符串数据是不能修改当前字符串里面的任意 ...

  9. 剖析std::function接口与实现

    目录 前言 一.std::function的原理与接口 1.1 std::function是函数包装器 1.2 C++注重运行时效率 1.3 用函数指针实现多态 1.4 std::function的接 ...

  10. jsp数据交互(一).1

    一.jsp中java小脚本1.<% java代码段%>2.<% =java表达式%>不能有分号3.<%!成员变量和函数声明%>二.注释1.<!--html注释 ...