Pytorch构建超分辨率模型——常用模块
Import required libraries:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder
Define a simple convolutional block (Conv-BatchNorm-ReLU)
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(ConvBlock, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
Define a simple upscaling block using sub-pixel convolution
class UpscaleBlock(nn.Module):
def __init__(self, in_channels, scale_factor):
super(UpscaleBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1)
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.relu(x)
return x
Define a custom super-resolution model (e.g., using ConvBlocks and UpscaleBlocks)
class SuperResolutionModel(nn.Module):
def __init__(self, upscale_factor):
super(SuperResolutionModel, self).__init__()
self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
self.upscale = UpscaleBlock(32, upscale_factor)
self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.upscale(x)
x = self.conv3(x)
return x
Create a custom dataset for image super-resolution
class SuperResolutionDataset(torch.utils.data.Dataset):
def __init__(self, image_folder, input_transform, target_transform):
self.dataset = ImageFolder(image_folder)
self.input_transform = input_transform
self.target_transform = target_transform
def __getitem__(self, index):
img, _ = self.dataset[index]
target = self.target_transform(img)
input = self.input_transform(target)
return input, target
def __len__(self):
return len(self.dataset)
Instantiate the model, loss function, and optimizer
upscale_factor = 2
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
Define input and target transformations for data preprocessing
input_transform = transforms.Compose([
transforms.Resize((256 // upscale_factor, 256 // upscale_factor), interpolation=TF.InterpolationMode.BICUBIC),
transforms.ToTensor()
])
target_transform = transforms.Compose([
transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
transforms.ToTensor()
])
Create DataLoader for training and validation data
train_dataset = SuperResolutionDataset("path/to/train_data", input_transform, target_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_dataset = SuperResolutionDataset("path/to/val_data", input_transform, target_transform)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
Training loop
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")
Validation loop
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")
Pytorch构建超分辨率模型——常用模块的更多相关文章
- 【超分辨率】—(ESRGAN)增强型超分辨率生成对抗网络-解读与实现
一.文献解读 我们知道GAN 在图像修复时更容易得到符合视觉上效果更好的图像,今天要介绍的这篇文章——ESRGAN: Enhanced Super-Resolution Generative Adve ...
- 小米造最强超分辨率算法 | Fast, Accurate and Lightweight Super-Resolution with Neural Architecture Search
本篇是基于 NAS 的图像超分辨率的文章,知名学术性自媒体 Paperweekly 在该文公布后迅速跟进,发表分析称「属于目前很火的 AutoML / Neural Architecture Sear ...
- 腾讯QQ空间超分辨率技术TSR
腾讯QQ空间超分辨率技术TSR:为用户节省3/4流量,处理效果和速度超谷歌RAISR 雷锋网AI科技评论: 随着移动端屏幕分辨率越来越高,甚至像iPhone更有所谓的“视网膜屏”,人们对高清图片的诉求 ...
- 使用深度学习的超分辨率介绍 An Introduction to Super Resolution using Deep Learning
使用深度学习的超分辨率介绍 关于使用深度学习进行超分辨率的各种组件,损失函数和度量的详细讨论. 介绍 超分辨率是从给定的低分辨率(LR)图像恢复高分辨率(HR)图像的过程.由于较小的空间分辨率(即尺寸 ...
- 超分辨率论文CVPR-Kai Zhang
深度学习与传统方法结合的超分辨率:Kai Zhang 1. (CVPR, 2019) Deep Plug-and-Play Super-Resolution for Arbitrary https:/ ...
- PyTorch如何构建深度学习模型?
简介 每过一段时间,就会有一个深度学习库被开发,这些深度学习库往往可以改变深度学习领域的景观.Pytorch就是这样一个库. 在过去的一段时间里,我研究了Pytorch,我惊叹于它的操作简易.Pyto ...
- 【超分辨率】- CVPR2019中SR论文导读与剖析
CVPR2019超分领域出现多篇更接近于真实世界原理的低分辨率和高分辨率图像对应的新思路.具体来说,以前论文训练数据主要使用的是人为的bicubic下采样得到的,网络倾向于学习bicubic下采样的逆 ...
- 『超分辨率重建』从SRCNN到WDSR
超分辨率重建技术(Super-Resolution)是指从观测到的低分辨率图像重建出相应的高分辨率图像.SR可分为两类: 1. 从多张低分辨率图像重建出高分辨率图像 2. 从单张低分辨率图 ...
- 使用PyTorch构建神经网络模型进行手写识别
使用PyTorch构建神经网络模型进行手写识别 PyTorch是一种基于Torch库的开源机器学习库,应用于计算机视觉和自然语言处理等应用,本章内容将从安装以及通过Torch构建基础的神经网络,计算梯 ...
- Tengine 常用模块使用介绍
Tengine 和 Nginx Tengine简介 从2011年12月开始:Tengine是由淘宝网发起的Web服务器项目.它在Nginx的基础上,针对大访问量网站的需求,添加了很多高级功能 和特性. ...
随机推荐
- cnpm : 无法加载文件 D:\nodejs\node_global\cnpm.ps1,因为在此系统上禁止运行 脚本。
错误信息: cnpm : 无法加载文件 D:\nodejs\node_global\cnpm.ps1,因为在此系统上禁止运行 脚本.有关详细信息,请参阅 https:/go.microsoft.com ...
- Linux 创建 Python 虚拟环境
Linux 创建 Python 虚拟环境 0. 前言 网上教程太杂太乱,要么排版不好看,要么讲半天讲不到重点,故做此篇,精简干练. 1. 安装virtualenv 先用pip安装virtualenv第 ...
- 【汇编】DOS系统功能调用(INT 21H)
前言 最近又听了听汇编的课程,发现代码里的MOV xxxxx INT 21H,老师都是一句话带过,而不讲讲其中的原因(也可能前面讲了我没有听QAQ). 顺便夸一下老师,老师懒省事录的视频画质已经成功从 ...
- 南洋才女,德艺双馨,孙燕姿本尊回应AI孙燕姿(基于Sadtalker/Python3.10)
孙燕姿果然不愧是孙燕姿,不愧为南洋理工大学的高材生,近日她在个人官方媒体博客上写了一篇英文版的长文,正式回应现在满城风雨的"AI孙燕姿"现象,流行天后展示了超人一等的智识水平,行文 ...
- 2023-06-09:什么是Redis事务?原理是什么?
2023-06-09:什么是Redis事务?原理是什么? 答案2023-06-09: Redis中的事务是以一组命令的形式出现的,这些命令被认为是最小的执行单位.事务可以保证在一个单独独立的隔离操作中 ...
- 国标GB28181协议客户端开发(二)程序架构和注册
国标GB28181协议客户端开发(二)程序架构和注册 本系列文章旨在探讨国标GB28181协议设备端的开发过程.本文将聚焦于架构设计和设备注册,并详细介绍了设备端的程序架构设计.exosip库介绍和接 ...
- PHP curl提交参数到某个网址,然后获取数据
<?php $data = '你的每个参数'; $url = 'https://www.bz80.vip/'; //举例 $html = post_data($url,$data); echo ...
- k3s 基础 —— 配置 loki
官方文档 核心组件 3 个 chart: promtail 这是一个 agent 代理客户端,用于收集日志,将日志传送给 loki loki 核心组件,主要功能是日志数据的写入与分析.包含 gatew ...
- selenium元素定位---ElementNotInteractableException(元素不可交互异常)解决方法
方法一: 增加强制等待时间 方法二: 使用js点击 element = self.browser.find_element(By.XPATH, "//td[@class='el-table_ ...
- pod setup 慢 的问题
由于更换了硬盘,重装了系统,需要重新配置环境,发现现在安装cocapods比之前坑更深了, 装环境时遇到pod setup才几kb的下载速度(即使用梯子也是巨慢),实在是没法用在网上尝试了各种方法,但 ...