Pytorch构建超分辨率模型——常用模块
Import required libraries:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder
Define a simple convolutional block (Conv-BatchNorm-ReLU)
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(ConvBlock, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
Define a simple upscaling block using sub-pixel convolution
class UpscaleBlock(nn.Module):
def __init__(self, in_channels, scale_factor):
super(UpscaleBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1)
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.relu(x)
return x
Define a custom super-resolution model (e.g., using ConvBlocks and UpscaleBlocks)
class SuperResolutionModel(nn.Module):
def __init__(self, upscale_factor):
super(SuperResolutionModel, self).__init__()
self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
self.upscale = UpscaleBlock(32, upscale_factor)
self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.upscale(x)
x = self.conv3(x)
return x
Create a custom dataset for image super-resolution
class SuperResolutionDataset(torch.utils.data.Dataset):
def __init__(self, image_folder, input_transform, target_transform):
self.dataset = ImageFolder(image_folder)
self.input_transform = input_transform
self.target_transform = target_transform
def __getitem__(self, index):
img, _ = self.dataset[index]
target = self.target_transform(img)
input = self.input_transform(target)
return input, target
def __len__(self):
return len(self.dataset)
Instantiate the model, loss function, and optimizer
upscale_factor = 2
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
Define input and target transformations for data preprocessing
input_transform = transforms.Compose([
transforms.Resize((256 // upscale_factor, 256 // upscale_factor), interpolation=TF.InterpolationMode.BICUBIC),
transforms.ToTensor()
])
target_transform = transforms.Compose([
transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
transforms.ToTensor()
])
Create DataLoader for training and validation data
train_dataset = SuperResolutionDataset("path/to/train_data", input_transform, target_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_dataset = SuperResolutionDataset("path/to/val_data", input_transform, target_transform)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
Training loop
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")
Validation loop
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")
Pytorch构建超分辨率模型——常用模块的更多相关文章
- 【超分辨率】—(ESRGAN)增强型超分辨率生成对抗网络-解读与实现
一.文献解读 我们知道GAN 在图像修复时更容易得到符合视觉上效果更好的图像,今天要介绍的这篇文章——ESRGAN: Enhanced Super-Resolution Generative Adve ...
- 小米造最强超分辨率算法 | Fast, Accurate and Lightweight Super-Resolution with Neural Architecture Search
本篇是基于 NAS 的图像超分辨率的文章,知名学术性自媒体 Paperweekly 在该文公布后迅速跟进,发表分析称「属于目前很火的 AutoML / Neural Architecture Sear ...
- 腾讯QQ空间超分辨率技术TSR
腾讯QQ空间超分辨率技术TSR:为用户节省3/4流量,处理效果和速度超谷歌RAISR 雷锋网AI科技评论: 随着移动端屏幕分辨率越来越高,甚至像iPhone更有所谓的“视网膜屏”,人们对高清图片的诉求 ...
- 使用深度学习的超分辨率介绍 An Introduction to Super Resolution using Deep Learning
使用深度学习的超分辨率介绍 关于使用深度学习进行超分辨率的各种组件,损失函数和度量的详细讨论. 介绍 超分辨率是从给定的低分辨率(LR)图像恢复高分辨率(HR)图像的过程.由于较小的空间分辨率(即尺寸 ...
- 超分辨率论文CVPR-Kai Zhang
深度学习与传统方法结合的超分辨率:Kai Zhang 1. (CVPR, 2019) Deep Plug-and-Play Super-Resolution for Arbitrary https:/ ...
- PyTorch如何构建深度学习模型?
简介 每过一段时间,就会有一个深度学习库被开发,这些深度学习库往往可以改变深度学习领域的景观.Pytorch就是这样一个库. 在过去的一段时间里,我研究了Pytorch,我惊叹于它的操作简易.Pyto ...
- 【超分辨率】- CVPR2019中SR论文导读与剖析
CVPR2019超分领域出现多篇更接近于真实世界原理的低分辨率和高分辨率图像对应的新思路.具体来说,以前论文训练数据主要使用的是人为的bicubic下采样得到的,网络倾向于学习bicubic下采样的逆 ...
- 『超分辨率重建』从SRCNN到WDSR
超分辨率重建技术(Super-Resolution)是指从观测到的低分辨率图像重建出相应的高分辨率图像.SR可分为两类: 1. 从多张低分辨率图像重建出高分辨率图像 2. 从单张低分辨率图 ...
- 使用PyTorch构建神经网络模型进行手写识别
使用PyTorch构建神经网络模型进行手写识别 PyTorch是一种基于Torch库的开源机器学习库,应用于计算机视觉和自然语言处理等应用,本章内容将从安装以及通过Torch构建基础的神经网络,计算梯 ...
- Tengine 常用模块使用介绍
Tengine 和 Nginx Tengine简介 从2011年12月开始:Tengine是由淘宝网发起的Web服务器项目.它在Nginx的基础上,针对大访问量网站的需求,添加了很多高级功能 和特性. ...
随机推荐
- prefetch_related() 一对多、多对多查询优化,反向查询
prefetch_related() 一对多.多对多查询优化,反向查询 Student.objects.filter(age__lt=30).prefetch_related('course') # ...
- windows server 2012 2019启动 开机自动启动 项设置
1.第一种方法:打开运行功能,运行shell:startup,打开管理员用户启动项目录.将想要设置成开机自启的程序快捷方式添加到其中即可,或者删除其中快捷方式即可取消开机自启.2.第二种方法:打开系统 ...
- 利用简单的IO操作实现M3U8文件之间的合并
先上代码: 1 @SneakyThrows //合并操作,最终文件不包含结束标识,方便多次合并 2 private static void mergeM3U8File(String source, S ...
- WPF 入门笔记 - 01 - 入门基础以及常用布局
本篇为学习博客园大佬圣殿骑士的<WPF基础到企业应用系列>以及部分DotNet菜园的<WPF入门教程系列>所作笔记,对应圣殿骑士<WPF基础到企业应用系列>第 1 ...
- 为什么 GPU 能够极大地提高仿真速度?
这里的提速主要是针对时域电磁算法的.因为时域算法的蛙跳推进模式仅对大量存放在固定 位置的数据进行完全相同的且是简单的操作(移位相加),这正是 GPU 这类众核 SIMD 架构所进行的运算,即 ALU ...
- Python基础 - 解释性语言和编译性语言
什么是机器语言 计算机是不能理解高级语言,当然也就不能直接执行高级语言了.计算机只能直接理解机器语言,所以任何语言,都必须将其翻译成机器语言,计算机才能运行高级语言编写的程序. 如何把我们写的代码 ...
- javascript5 定时器功能
定时器功能: 定时器功能是window对象方法,涉及到 定时器和延时器,具体 看代码 定时器 timer=setInterval(function (){},300) 清除定时器: clearInte ...
- ODOO配置属性
2字段的属性 2.1 隐藏字段 <field name='model_name' invisible="True"/> 2.2 条件下隐藏 <field name ...
- 《Just For Fun》:学习即游戏
<Just For Fun>:学习即游戏 最近读完了 Linus 的自传<Just For Fun>,一直想写点东西,但始终苦于工作繁忙,无暇思考该从何写起.技术上自然不用废话 ...
- 隐藏Tomcat中间件名称及版本号
目的 防止黑客利用Tomcat中间件及版本号有针对性发起攻击. 处理方法 输入命令方式 # 进入tomcat/lib目录 cd Tomcat目录/lib # 解决catalina.jar,备份Serv ...