【深度学习基础】:VGG实战篇(图像风格迁移)
前言
本篇来带大家看看VGG的实战篇,这次来带大家看看计算机视觉中一个有趣的小任务,图像风格迁移。
可运行代码位于: Style_transfer (可直接下载运行)
我们所进行的图像风格迁移是基于VGG实现的哦,如果对VGG网络还不太了解的,可以先去看看这边文章:【深度学习基础】:VGG原理篇 - carpell - 博客园
style transfer原理
原理解析
图像风格迁移是一种通过将一张图像的风格应用到另一张图像的内容上,从而生成一张新图像的技术。这种技术通常基于深度学习模型,特别是卷积神经网络(CNN)。
实现步骤:
- 首先,我们需要两张图像:内容图像和风格图像。内容图像是我们想要保持内容的图像,而风格图像是我们想要应用到内容图像上的风格的图像。接下来,我们通过使用预训练的卷积神经网络(如VGG19)来提取图像的特征。
- 在提取特征之后,我们计算内容损失和风格损失。内容损失度量生成图像与内容图像在内容上的差异,通常通过比较两个图像在浅层特征上的差异来计算。风格损失度量生成图像与风格图像在风格上的差异,风格通常通过计算图像特征之间的相关性来表示,即风格图像的特征图之间的gram矩阵。(这个等会会详细讲述)
其实图像风格迁移的核心是一个优化问题,目标是生成一张图像,使得内容损失和风格损失都最小化。我们初始化一张随机生成的图像。然后,我们使用卷积神经网络提取这张图像的特征,并计算内容损失和风格损失。接着,我们计算总损失,并进行反向传播和优化,以更新生成图像的像素值。这个过程会重复进行多次,直到生成图像在内容和风格上都接近目标图像。
损失函数
好,其实大概的原理基本上都清楚了,但是我们来看,我们是该如何去保证图片的内容是我们所选的内容图片,但是风格确实我们所选的风格图片呢?
首先来看内容函数的损失函数,其是这个还是很好理解的,我们所使用的就是平方差损失函数,我们希望我们的生成图片能够和内容图片具有相同的内容,所以最小化两者的差距即可。但是这里有个细节,我们该选择那个层的输出作为我们的标准呢?其实看最上面的图也能看出,我们的网络越深,提取到的内容图片损失的细节就会越多,所以我们会更多的选择浅层的内容图片特征作为标准。
def calculate_content_loss(original_feat, generated_feat) -> torch.Tensor:
"""计算内容损失,即生成特征图与标准特征图的规范化误差平方和"""
b, c, h, w = original_feat.shape
x = 2. * c * h * w # 规范化系数
return torch.sum((generated_feat - original_feat) ** 2) / x
然后我们在来看风格损失函数,其实风格损失函数的关键在于gram矩阵。那么我们回到最初,我们为什么能够提取出来图像的风格特征呢?其实就是依靠Gram矩阵了,其用于捕捉图像中不同通道之间的相关性,这些相关性可以反映出图像的纹理、颜色分布等风格特征。通俗说就是出现尖尖的形状的时候边上应该出现些什么。
然后风格损失函数如下,同样的我们会选择深层的风格图像特征作为我们的标准,因为深层网络提取出来的更多的是抽象的语义信息,能够代表其本质的特征,抽象的特征。
def gram_matrix(feature):
"""
计算特征图的格莱姆矩阵,作为风格特征表示
:param feature: 输入特征图
:return: 格莱姆矩阵
"""
b, c, h, w = feature.size()
feature = feature.view(b, c, h * w) # 拉平空间维度
G = torch.bmm(feature, feature.transpose(1, 2))
return G
def calculate_style_loss(style_feat, generated_feat) -> torch.Tensor:
"""计算风格损失,即生成特征图与标准特征图的格拉姆矩阵的规范化误差平方和"""
b, c, h, w = style_feat.shape
G = gram_matrix(generated_feat)
A = gram_matrix(style_feat)
x = 4. * ((h * w) ** 2) * (c ** 2) # 规范化系数
return torch.sum((G - A) ** 2) / x
所以最终的损失函数就是风格损失和内容损失的和,其中会乘上不同的系数。
style transfer代码
这里只展现了部分的核心代码,所有代码位于Style_transfer ,下载可直接使用。
我们通过处理目标图片,其一开始就是一张白噪声图,然后通过网络损失不断修改其像素,使得目标图片内容与内容图片相近,风格与风格图片相近。
import argparse
import os
import torch
import torch.optim as optim
from tqdm import tqdm
from models import VGG
from utils import make_transform, load_image, save_image, calculate_style_loss, calculate_content_loss
def parse_args():
parser = argparse.ArgumentParser(description='Style Transfer=')
parser.add_argument('--content_image', type=str, default='./data/content1.jpg', help='Path to the content image')
parser.add_argument('--style_image', type=str, default='./data/style2.jpg', help='Path to the style image')
parser.add_argument('--output_dir', type=str, default='./output/iterative_style_transfer', help='Output directory')
parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='Image size (height, width)')
parser.add_argument('--content_weight', type=float, default=1, help='Content weight')
parser.add_argument('--style_weight', type=float, default=15, help='Style weight')
parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
parser.add_argument('--steps_per_epoch', type=int, default=100, help='Number of steps per epoch')
parser.add_argument('--learning_rate', type=float, default=0.03, help='Learning rate')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# 检查文件路径
assert os.path.exists(args.content_image), f"content image is not exist: {args.content_image}"
assert os.path.exists(args.style_image), f"style image is not exist: {args.style_image}"
# 内容特征层及loss加权系数
content_layers = {'5': 0.5, '10': 0.5}
# 风格特征层及loss加权系数
style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}
# ----------------训练即推理过程----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = make_transform(args.image_size, normalize=True)
content_img = load_image(args.content_image, transform).to(device)
style_img = load_image(args.style_image, transform).to(device)
# 用小随机噪声初始化图像(避免标准正态分布过大震荡)
generated_img = torch.rand_like(content_img).mul(0.1).requires_grad_().to(device)
save_image(generated_img, args.output_dir, 'noise_init.jpg')
vgg_model = VGG(content_layers, style_layers).to(device).eval()
# 关闭梯度追踪,节省显存
with torch.no_grad():
content_features, _ = vgg_model(content_img)
_, style_features = vgg_model(style_img)
optimizer = optim.Adam([generated_img], lr=args.learning_rate)
for epoch in range(args.epochs):
p_bar = tqdm(range(args.steps_per_epoch), desc=f'Epoch {epoch + 1}/{args.epochs}')
for step in p_bar:
generated_content, generated_style = vgg_model(generated_img)
content_loss = sum(
args.content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content)
for name, gen_content in generated_content.items()
)
style_loss = sum(
args.style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style)
for name, gen_style in generated_style.items()
)
total_loss = style_loss + content_loss
optimizer.zero_grad()
total_loss.backward()
# 梯度裁剪(可选但稳健)
torch.nn.utils.clip_grad_norm_([generated_img], max_norm=1.0)
optimizer.step()
p_bar.set_postfix(style_loss=style_loss.item(), content_loss=content_loss.item(), total_loss=total_loss.item())
# 保存中间结果
save_image(generated_img, args.output_dir, f'generated_epoch_{epoch + 1}.jpg')
效果图
fast style transfer 代码
刚才我们看了基于迭代的图像风格迁移,那么能不能有个方法,我们训练一个专用于一种风格的图像迁移的网络,其网络训练好之后,我们就不再需要老是输入风格图像,这个训练好的网络可以处理所有输入的图像输出其已经训练好的风格。所以fast style transfer 来了。其不仅可以处理图片,同样可以处理视频哦!
这里只展现了部分的核心代码,所有代码位于Style_transfer ,下载可直接使用。
import argparse
import os
from datetime import datetime, timedelta
from typing import List, Iterable
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import cv2
import numpy as np
from models import VGG, TransNet
from datasets import ImageDataset
from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \
denormalize
def parse_args():
parser = argparse.ArgumentParser(description='Fast Style Transfer')
parser.add_argument('--mode', type=str, choices=['train', 'image', 'video'], default='train',
help='Operation mode: train, image, video')
parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='Image size (height, width)')
parser.add_argument('--style_image', type=str, default='./data/style3.jpg', help='Style image path (training mode only)')
# The directory data/train2017/default_class contains several images. Due to the format required by ImageFolder, we need to use a class folder to contain the images, even though the class label is not used.
parser.add_argument('--content_dataset', type=str, default='data/train2014', help='Content image dataset path (training mode only)')
parser.add_argument('--content_weight', type=float, default=1., help='Content weight (training mode only)')
parser.add_argument('--style_weight', type=float, default=15., help='Style weight (training mode only)')
parser.add_argument('--model_save_path', type=str, default=f'./checkpoint/style2.pth', help='Model save path (training mode)')
parser.add_argument('--pretrained_model_path', type=str, help='Pretrained model load path (training mode only)')
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (training mode)')
parser.add_argument('--save_interval', type=int, default=120, help='Save interval (seconds) (training mode)')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (training mode)')
parser.add_argument('--output_dir', type=str, default='./output/realtime_transfer', help='Output directory (training mode)')
parser.add_argument('--model_path', type=str, default='./checkpoint/style1.pth', help='Model load path (image and video mode)')
parser.add_argument('--input_images_dir', type=str, help='Input images root directory (image mode only)')
parser.add_argument('--output_images_dir', type=str, default='./output/fast_style_transfer/image_generated.jpg',
help='Output images root directory (image mode only)')
parser.add_argument('--video_input', type=str, default='data/maigua.mp4', help='Input video path (video mode only)')
parser.add_argument('--video_output', type=str, default='output/videos/maigua.mp4',
help='Output video path (video mode only)')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
return parser.parse_args()
def train(model,
vgg,
lr, epochs, batch_size, style_weight, content_weight, style_layers, content_layers,
device, transform,
image_style,
content_dataset_root,
save_path, output_dir,
save_interval=timedelta(seconds=120)):
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 对目标网络进行优化
dataset = ImageDataset(content_dataset_root, transform=transform)
dataloader = DataLoader(dataset, batch_size, shuffle=True)
_, style_features = vgg(image_style)
p_bar = tqdm(range(epochs))
last_save_time = datetime.now() - save_interval
for epoch in p_bar:
running_content_loss, running_style_loss = 0.0, 0.0
for i, content_img in enumerate(dataloader):
content_img = content_img.to(device)
image_generated = model(content_img) # 只使用内容图像进行风格迁移
generated_content, generated_style = vgg(image_generated)
style_loss = sum(
style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style) for
name, gen_style in generated_style.items())
content_features, _ = vgg(content_img) # 计算内容图的内容特征
content_loss = sum(
content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content) for
name, gen_content in generated_content.items())
total_loss = style_loss + content_loss
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) # 梯度裁剪
optimizer.step()
running_content_loss += content_loss.item()
running_style_loss += style_loss.item()
p_bar.set_postfix(progress=f'{(i + 1) / len(dataloader) * 100:.3f}%',
style_loss=f"{style_loss.item():.3f}",
content_loss=f"{content_loss.item():.3f}",
last_save_time=last_save_time)
if datetime.now() - last_save_time > save_interval:
last_save_time = datetime.now()
#writer.add_images('image_generated', denormalize(image_generated), epoch * len(dataloader) + i)
save_model(model, save_path) # 'fast_style_transfer.pth'
save_image(torch.cat((image_generated, content_img), 3), output_dir, f'{epoch}_{i}.jpg')
def process_images(images: Iterable[Image.Image], transform, model, device) -> List[Image.Image]:
images = torch.stack([transform(image) for image in images]).to(device)
model.to(device)
batch_generated = model(images)
batch_generated = denormalize(batch_generated).detach().cpu()
batch_generated = [transforms.ToPILImage()(image) for image in batch_generated]
return batch_generated
def process_video(video_path, output_path, transform, model, device, batch_size=4):
# 打开视频文件
cap = cv2.VideoCapture(video_path)
output_dir, filename = os.path.split(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
# 获取视频属性
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# 定义视频编码器和创建 VideoWriter 对象
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
# 初始化 tqdm 进度条
pbar = tqdm(total=total_frames, desc="Processing Video")
# 读取视频并批量处理
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
if len(frames) == batch_size:
batch_generated = process_images(frames, transform, model, device)
for gen_frame in batch_generated:
gen = cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR)
cv2.imshow("output", gen)
gen = cv2.resize(gen, (frame_width, frame_height))
out.write(gen)
if cv2.waitKey(1) & 0xFF == ord('s'):
break
frames.clear()
pbar.update(batch_size)
if frames:
batch_generated = process_images(frames, transform, model, device)
for gen_frame in batch_generated:
out.write(cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR))
pbar.update(len(frames))
print(f'video successfully saved to: {output_path}')
pbar.close()
cap.release()
out.release()
if __name__ == '__main__':
args = parse_args()
# print(args)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ----------------路径参数----------------
# 内容特征层及loss加权系数
content_layers = {'5': 0.5, '10': 0.5} # 使用vgg的较浅层特征作为内容特征,保证生成图片内容结构相似性
# 风格特征层及loss加权系数
style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2} # 使用vgg不同深度的风格特征,生成风格更加层次丰富
transform = make_transform(size=args.image_size, normalize=True) # 图像变换
image_style = load_image(args.style_image, transform=transform).to(device) # 风格图像
vgg = VGG(content_layers, style_layers).to(device) # 特征提取网络,只用来提取特征,不进行训练
model = TransNet(input_size=args.image_size).to(device) # 内容生成网络,用于生成风格图片,进行训练
if args.mode != 'train' and getattr(args, 'model_path'):
if not os.path.exists(args.model_path):
raise FileNotFoundError(f'{args.model_path}不存在!')
model.load_state_dict(torch.load(args.model_path))
elif args.mode == 'train' and getattr(args, 'pretrained_model_path') and os.path.exists(args.pretrained_model_path):
if not os.path.exists(args.pretrained_model_path):
raise FileNotFoundError(f'{args.pretrained_model_path}不存在!')
model.load_state_dict(torch.load(args.pretrained_model_path))
if args.mode == 'train':
# 训练模式
# 使用大规模内容图像数据训练快速图像风格迁移网络,比如COCO2017数据集
train(model, vgg, args.learning_rate, args.epochs, args.batch_size, args.style_weight, args.content_weight,
style_layers,
content_layers, device, transform, image_style, args.content_dataset, args.model_save_path,
args.output_dir,
save_interval=timedelta(seconds=args.save_interval))
elif args.mode == 'image':
# 使用训练好的风格迁移模型演示批量处理图片
if not os.path.exists(args.output_images_dir):
os.makedirs(args.output_images_dir)
for filename in tqdm(os.listdir(args.input_images_dir), desc='Processing Images'):
try:
filepath = os.path.join(args.input_images_dir, filename)
images_generated = process_images([Image.open(filepath)], transform, model, device)
images_generated[0].save(os.path.join(args.output_images_dir, filename))
except Exception as e:
pass
elif args.mode == 'video':
# 视频处理模式
process_video(args.video_input, args.video_output, transform, model, device,
batch_size=args.batch_size)
else:
raise ValueError("未知的运行模式")
效果图
【深度学习基础】:VGG实战篇(图像风格迁移)的更多相关文章
- 【神经网络与深度学习】neural-style、chainer-fast-neuralstyle图像风格转换使用
neural-style 官方地址:这个是使用torch7实现的;torch7安装比较麻烦.我这里使用的是大神使用TensorFlow实现的https://github.com/anishathaly ...
- 深度学习之PyTorch实战(4)——迁移学习
(这篇博客其实很早之前就写过了,就是自己对当前学习pytorch的一个教程学习做了一个学习笔记,一直未发现,今天整理一下,发出来与前面基础形成连载,方便初学者看,但是可能部分pytorch和torch ...
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
- TensorFlow深度学习基础与应用实战高清视频教程
TensorFlow深度学习基础与应用实战高清视频教程,适合Python C++ C#视觉应用开发者,基于TensorFlow深度学习框架,讲解TensorFlow基础.图像分类.目标检测训练与测试以 ...
- 腾讯QQ会员技术团队:人人都可以做深度学习应用:入门篇(下)
四.经典入门demo:识别手写数字(MNIST) 常规的编程入门有"Hello world"程序,而深度学习的入门程序则是MNIST,一个识别28*28像素的图片中的手写数字的程序 ...
- 人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练
人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练 MXNet 是一个轻量级.可移植.灵活的分布式深度学习框架,2017 年 1 月 23 日,该项目进入 Apache 基金会,成为 ...
- 算法工程师<深度学习基础>
<深度学习基础> 卷积神经网络,循环神经网络,LSTM与GRU,梯度消失与梯度爆炸,激活函数,防止过拟合的方法,dropout,batch normalization,各类经典的网络结构, ...
- 深度学习基础系列(九)| Dropout VS Batch Normalization? 是时候放弃Dropout了
Dropout是过去几年非常流行的正则化技术,可有效防止过拟合的发生.但从深度学习的发展趋势看,Batch Normalizaton(简称BN)正在逐步取代Dropout技术,特别是在卷积层.本文将首 ...
- 参考《深度学习之PyTorch实战计算机视觉》PDF
计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...
- 机器学习python*(深度学习)核心技术实战
Python实战及机器学习(深度学习)技术 一,时间地点:2020年01月08日-11日 北京(机房上课,每人一台电脑进行实际案例操作,赠送 U盘拷贝资料及课件和软件)二.课程目标:1.python基 ...
随机推荐
- Rec out
"为你守夜" 在我这里蓝和灰是一种颜色 深蓝等于星灰. 当时看到星灰这个说法的时候,我说 星星怎么会是灰色的呢 在刻板印象里,星星和月亮一样是黄的 月亮有时候确实是黄的 但它颜色很 ...
- Flink-cdc同步mysql到iceberg丢失数据排查
一.获取任务信息 任务id:i01f51582-d8be-4262-aefa-000000 任务名称:ods_test1234 丢失的数据时间:2024-09-16 09:28:47 二.数据同步查看 ...
- 让我们从零开始使用PyTorch构建一个轻量级的词嵌入模型
图片来源:Phil Hearing(Unsplash) 在我之前写的一篇文章中,我们学习了如何使用PyTorch的nn.Embedding层将单词转换为稠密向量.但由于该嵌入层是未经训练的,这些向量并 ...
- KUKA库卡机器人KR120维修故障参考方案
随着智能制造的飞速发展,KUKA库卡机器人KR120以其稳定的特点,在自动化生产线上扮演着举足轻重的角色.然而,任何机械设备在长期运行过程中都难免会遇到故障.本文将针对KUKA库卡机器人KR120维修 ...
- Azure - [01] 订阅管理
题记部分 001 || 核心功能 (1)访问控制 Azure订阅通过基于角色的访问控制(RBAC)系统,允许管理员精细管理用户.组和应用程序对资源的访问权限.RBAC系统通过将权限分配给角色,再将 ...
- Hbase - hbase hbck介绍
原文地址:https://bbs.huaweicloud.com/blogs/353332 HBaseFsck(hbck)是一种命令行工具,可检查hbase集群的region一致性和表完整性的问题,同 ...
- try catch异常捕获工具类
异常捕获工具类 package com.example.multiThreadTransaction_demo.utils; import lombok.extern.slf4j.Slf4j; imp ...
- swoole(5)信号监听、热重启
一:信号监听 信号:由用户.系统或者进程发给目标进程的信息,以通知目标进程某个状态的改变或系统异常 信号查看:kill -l SIGHUP 终止进程 终端线路挂断 SIGINT ...
- 关于jsp的MySQL数据库连接问题
<%@ page language="java" contentType="text/html; charset=utf-8" pageEncoding= ...
- Web前端入门第 12 问:HTML 常用属性一览
HELLO,这里是大熊学习前端开发的入门笔记. 本系列笔记基于 windows 系统. HTML 常用属性大约 70 个,是否又头大了?脸上笑嘻嘻,心里嘛...嘿嘿... 温馨提示:属性不用死记硬背, ...