Import required libraries:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder

Define a simple convolutional block (Conv-BatchNorm-ReLU)

class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(ConvBlock, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
) def forward(self, x):
return self.conv(x)

Define a simple upscaling block using sub-pixel convolution

class UpscaleBlock(nn.Module):
def __init__(self, in_channels, scale_factor):
super(UpscaleBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1)
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
self.relu = nn.ReLU(inplace=True) def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.relu(x)
return x

Define a custom super-resolution model (e.g., using ConvBlocks and UpscaleBlocks)

class SuperResolutionModel(nn.Module):
def __init__(self, upscale_factor):
super(SuperResolutionModel, self).__init__()
self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
self.upscale = UpscaleBlock(32, upscale_factor)
self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4) def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.upscale(x)
x = self.conv3(x)
return x

Create a custom dataset for image super-resolution

class SuperResolutionDataset(torch.utils.data.Dataset):
def __init__(self, image_folder, input_transform, target_transform):
self.dataset = ImageFolder(image_folder)
self.input_transform = input_transform
self.target_transform = target_transform def __getitem__(self, index):
img, _ = self.dataset[index]
target = self.target_transform(img)
input = self.input_transform(target)
return input, target def __len__(self):
return len(self.dataset)

Instantiate the model, loss function, and optimizer

upscale_factor = 2
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Define input and target transformations for data preprocessing

input_transform = transforms.Compose([
transforms.Resize((256 // upscale_factor, 256 // upscale_factor), interpolation=TF.InterpolationMode.BICUBIC),
transforms.ToTensor()
]) target_transform = transforms.Compose([
transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
transforms.ToTensor()
])

Create DataLoader for training and validation data

train_dataset = SuperResolutionDataset("path/to/train_data", input_transform, target_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4) val_dataset = SuperResolutionDataset("path/to/val_data", input_transform, target_transform)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

Training loop

model.eval()
val_loss = 0.0 with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device) outputs = model(inputs)
loss = criterion(outputs, targets) val_loss += loss.item() val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")

Validation loop

model.eval()
val_loss = 0.0 with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device) outputs = model(inputs)
loss = criterion(outputs, targets) val_loss += loss.item() val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")

Pytorch构建超分辨率模型——常用模块的更多相关文章

  1. 【超分辨率】—(ESRGAN)增强型超分辨率生成对抗网络-解读与实现

    一.文献解读 我们知道GAN 在图像修复时更容易得到符合视觉上效果更好的图像,今天要介绍的这篇文章——ESRGAN: Enhanced Super-Resolution Generative Adve ...

  2. 小米造最强超分辨率算法 | Fast, Accurate and Lightweight Super-Resolution with Neural Architecture Search

    本篇是基于 NAS 的图像超分辨率的文章,知名学术性自媒体 Paperweekly 在该文公布后迅速跟进,发表分析称「属于目前很火的 AutoML / Neural Architecture Sear ...

  3. 腾讯QQ空间超分辨率技术TSR

    腾讯QQ空间超分辨率技术TSR:为用户节省3/4流量,处理效果和速度超谷歌RAISR 雷锋网AI科技评论: 随着移动端屏幕分辨率越来越高,甚至像iPhone更有所谓的“视网膜屏”,人们对高清图片的诉求 ...

  4. 使用深度学习的超分辨率介绍 An Introduction to Super Resolution using Deep Learning

    使用深度学习的超分辨率介绍 关于使用深度学习进行超分辨率的各种组件,损失函数和度量的详细讨论. 介绍 超分辨率是从给定的低分辨率(LR)图像恢复高分辨率(HR)图像的过程.由于较小的空间分辨率(即尺寸 ...

  5. 超分辨率论文CVPR-Kai Zhang

    深度学习与传统方法结合的超分辨率:Kai Zhang 1. (CVPR, 2019) Deep Plug-and-Play Super-Resolution for Arbitrary https:/ ...

  6. PyTorch如何构建深度学习模型?

    简介 每过一段时间,就会有一个深度学习库被开发,这些深度学习库往往可以改变深度学习领域的景观.Pytorch就是这样一个库. 在过去的一段时间里,我研究了Pytorch,我惊叹于它的操作简易.Pyto ...

  7. 【超分辨率】- CVPR2019中SR论文导读与剖析

    CVPR2019超分领域出现多篇更接近于真实世界原理的低分辨率和高分辨率图像对应的新思路.具体来说,以前论文训练数据主要使用的是人为的bicubic下采样得到的,网络倾向于学习bicubic下采样的逆 ...

  8. 『超分辨率重建』从SRCNN到WDSR

    超分辨率重建技术(Super-Resolution)是指从观测到的低分辨率图像重建出相应的高分辨率图像.SR可分为两类:    1. 从多张低分辨率图像重建出高分辨率图像    2. 从单张低分辨率图 ...

  9. 使用PyTorch构建神经网络模型进行手写识别

    使用PyTorch构建神经网络模型进行手写识别 PyTorch是一种基于Torch库的开源机器学习库,应用于计算机视觉和自然语言处理等应用,本章内容将从安装以及通过Torch构建基础的神经网络,计算梯 ...

  10. Tengine 常用模块使用介绍

    Tengine 和 Nginx Tengine简介 从2011年12月开始:Tengine是由淘宝网发起的Web服务器项目.它在Nginx的基础上,针对大访问量网站的需求,添加了很多高级功能 和特性. ...

随机推荐

  1. flask之数据模型flask-sqlalchemy

    一.安装数据库连接依赖包 pip install flask-sqlalchemy pip install pymysql 二.项目配置 app/__init__.py from flask_sqla ...

  2. js 关于 replace 取值、替换第几个匹配项

    〇.前言 在日常开发中,经常遇到针对字符串的替换.截取,知识点比较碎容易混淆,特此总结一下,仅供参考. 一.替换第一个匹配项 字符串替换 let strtest = "0123测试repla ...

  3. 设置nginx允许服务端跨域

    目前项目大多使用前后端分离的模式进行开发,跨域请求当然就是必不可少了,很多时候我们会使用在客户端的ajax 请求中设置跨域请求,也有的在服务端设置跨域.但是有时候会遇到不使用ajax也没有使用后端服务 ...

  4. 有管django使用orm 字段报错问题

    直接删除表,重新生成,首先删除:migrations 中,上传记录,然后django_migrations,

  5. 瞄准程序员招聘痛点,ShowMeBug让面试代码操作可“回放”

    程序员虽然是建设互联网的职业之一,但他们的招聘工作的线上化却有不少难题. 疫情加速了市场对远程办公.远程面试.远程教学等模式的接受程度,但程序员招聘涉及到代码能力测试,甚至不同企业有不同的产品代码基础 ...

  6. 24 式加速你的 Python

    一,分析代码运行时间 第1式,测算代码运行时间 平凡方法 快捷方法(jupyter环境) 第2式,测算代码多次运行平均时间 平凡方法 快捷方法(jupyter环境) 第3式,按调用函数分析代码运行时间 ...

  7. .Net全网最简RabbitMQ操作【强烈推荐】

    [前言] 本文自1年前的1.0版本推出以来,已被业界大量科技公司采用.同时也得到了.Net圈内多位大佬的关注+推荐,文章也被多家顶级.Net/C#公众号转载. 现在更新到了7.0版本,更好的服务各位. ...

  8. 关于Unity3D第一视角下镜头穿墙的问题解决方法

    昨天做室内模型的时候,遇到一个非常棘手的问题,那就是第一视角在室内运行的时候,会出现穿墙的效果.类似下图效果,在靠近墙壁的时候,出现了镜头看见了墙壁外的情况,很显然这是不符合逻辑的.我们要做的就是避免 ...

  9. Curl 中 关于PUT, POST, DELETE, UPDATE 的使用

    POST curl -H "Content-Type:application/json" -X POST --data '{"id":1, "text ...

  10. Java求数组元素的最大值和最小值

    代码如下: public static void main(String[] args) { int [] a = {1,2,3,88,2,90}; int max = a[0]; int min = ...