resnet18训练自定义数据集
目录结构

dogsData.py
import json import torch
import os, glob
import random, csv from PIL import Image
from torch.utils.data import Dataset, DataLoader from torchvision import transforms
from torchvision.transforms import InterpolationMode class Dogs(Dataset): def __init__(self, root, resize, mode):
super().__init__()
self.root = root
self.resize = resize
self.nameLable = {}
for name in sorted(os.listdir(os.path.join(root))):
if not os.path.isdir(os.path.join(root, name)):
continue
self.nameLable[name] = len(self.nameLable.keys()) if not os.path.exists(os.path.join(self.root, 'label.txt')):
with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
f.write(json.dumps(self.nameLable, ensure_ascii=False)) # print(self.nameLable)
self.images, self.labels = self.load_csv('images.csv')
# print(self.labels) if mode == 'train':
self.images = self.images[:int(0.8*len(self.images))]
self.labels = self.labels[:int(0.8*len(self.labels))]
elif mode == 'val':
self.images = self.images[int(0.8*len(self.images)):int(0.9*len(self.images))]
self.labels = self.labels[int(0.8*len(self.labels)):int(0.9*len(self.labels))]
else:
self.images = self.images[int(0.9*len(self.images)):]
self.labels = self.labels[int(0.9*len(self.labels)):] def load_csv(self, filename): if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.nameLable.keys():
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# print(len(images)) random.shuffle(images)
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)[-2]
label = self.nameLable[name]
writer.writerow([img, label])
print('csv write succesful') images, labels = [], []
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
img, label = row
label = int(label)
images.append(img)
labels.append(label) assert len(images) == len(labels) return images, labels def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x_hot = (x-mean)/std
# x = x_hat * std = mean
# x : [c, w, h]
# mean [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1) x = x_hat * std + mean
return x def __len__(self):
return len(self.images) def __getitem__(self, idx):
# print(idx, len(self.images), len(self.labels))
img, label = self.images[idx], self.labels[idx] # 将字符串路径转换为tensor数据
# print(self.resize, type(self.resize))
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'),
transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
transforms.RandomRotation(15),
transforms.CenterCrop(self.resize),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = tf(img) label = torch.tensor(label) return img, label def main(): import visdom
import time viz = visdom.Visdom() # func1 通用
db = Dogs('Images_Data_Dog', 224, 'train')
# 取一张
# x,y = next(iter(db))
# print(x.shape, y)
# viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x')) # 取一个batch
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
print(len(loader))
print(db.nameLable)
# for x, y in loader:
# # print(x.shape, y)
# viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
# time.sleep(10) # # fun2
# import torchvision
# tf = transforms.Compose([
# transforms.Resize((64, 64)),
# transforms.RandomRotation(15),
# transforms.ToTensor(),
# ])
# db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
# loader = DataLoader(db, batch_size=32, shuffle=True)
# print(len(loader))
# for x, y in loader:
# # print(x.shape, y)
# viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
# time.sleep(10) if __name__ == '__main__':
main()
utils.py
import torch
from torch import nn class Flatten(nn.Module):
def __init__(self):
super().__init__() def forward(self, x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)
train.py
import os
import sys
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)
import torch
import visdom
from torch import optim, nn
import torchvision from torch.utils.data import DataLoader from dogs_train.utils import Flatten
from dogsData import Dogs from torchvision.models import resnet18 viz = visdom.Visdom() batchsz = 32
lr = 1e-3
epochs = 20 device = torch.device('cuda')
torch.manual_seed(1234) train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test') train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2) def evalute(model, loader):
correct = 0
total = len(loader.dataset)
for x, y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
logist = model(x)
pred = logist.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
return correct/total def main(): # model = ResNet18(5).to(device)
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],
Flatten(), # [b, 512, 1, 1] => [b, 512]
nn.Linear(512, 27)
).to(device) x = torch.randn(2, 3, 224, 224).to(device)
print(model(x).shape) optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss() best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs): for step, (x, y) in enumerate(train_loader):
x = x.to(device)
y = y.to(device) logits = model(x)
loss = criteon(logits, y) optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
if epoch % 2 == 0:
val_acc = evalute(model, val_loader)
if val_acc > best_acc:
best_acc = val_acc
best_epoch = epoch
torch.save(model.state_dict(), 'best.mdl') viz.line([val_acc], [global_step], win='val_acc', update='append') print('best acc', best_acc, 'best epoch', best_epoch) model.load_state_dict(torch.load('best.mdl'))
print('loader from ckpt') test_acc = evalute(model, test_loader)
print(test_acc) if __name__ == '__main__':
main()
resnet18训练自定义数据集的更多相关文章
- MMDetection 快速开始,训练自定义数据集
本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...
- Scaled-YOLOv4 快速开始,训练自定义数据集
代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...
- [炼丹术]YOLOv5训练自定义数据集
YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...
- yolov5训练自定义数据集
yolov5训练自定义数据 step1:参考文献及代码 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535 githu ...
- Tensorflow2 自定义数据集图片完成图片分类任务
对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...
- torch_13_自定义数据集实战
1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...
- tensorflow从训练自定义CNN网络模型到Android端部署tflite
网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型 ...
- Yolo训练自定义目标检测
Yolo训练自定义目标检测 参考darknet:https://pjreddie.com/darknet/yolo/ 1. 下载darknet 在 https://github.com/pjreddi ...
- pytorch加载语音类自定义数据集
pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...
- PyTorch 自定义数据集
准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...
随机推荐
- 【原创】windows环境下Java串口编程
由于工作中遇到需要读取SBG Ellipse N系列的惯导模块数据,为了方便操作,我选择在Windows下进行串口开发.串口使用RS232. Ellipse-N RS232的引脚定义 开始我尝试使用的 ...
- 点击事件触发count自增
1.vue3合成API :(即为什么要用setup() :为了数据更加关联) vue3 引入setup()合成API语法,它可以将某数据关联的内容整合到一个部分,即使setup里的内容越来越多也会围绕 ...
- Github高效搜索方式
Github高效搜索方式 文章目录 Github高效搜索方式 0.写在前面 1.常用的搜索功能 1.1 直接搜索 1.2 寻找指定用户|大小的仓库 1.3 搜索仓库 1.4 查找特定star范围的仓库 ...
- CVE-2016-2183(SSL/TLS)漏洞的办法
运行gpedit.msc,打开"本地组策略编辑器" 启用"SSL密码套件顺序" TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256_ ...
- 集群与iptables
Iptables 五链四表执行关系如图所示,容器环境最常用的就是filter和nat表 加上各种自定义的链插入到各个环节,拦截流量做各种控制 filter表:匹配数据包以进行过滤 nat表:修改数据包 ...
- 后端008_配置Security登录授权过滤器
------------恢复内容开始------------ 现在我们就可以去进行springscurity的配置了.首先我们新建一个配置类.然后该类需要添加@Configuration注解,然后还要 ...
- Qt非主线程显示窗口的简易思路
首发于我的个人博客:xie-kang.com 博客内有更多文章,欢迎大家访问 原文地址 Qt 在非主线程是无法显示UI窗口的,如果我们在其它线程中处理完数据,需要使用窗口展示,或者在其它线程需要使用到 ...
- CanvasScaler的三种适配模式——缩放模式(Scale with Screen Size)
一.含义 根据屏幕尺寸进行缩放,随着屏幕尺寸进行放大缩小 二.参数介绍 第一个参数一般是美术人员根据游戏主要面向的手机市场,比如安卓市场,用市场上最常用的分辨率作为制作UI图片的标准.这里填的数就是美 ...
- IP rDNS(PTR)信息从理解到情报挖掘
什么是IP的rdns信息? 过去很多人,将IP的rDNS信息理解为解析到IP的反查域名信息.IP的rDNS信息和IP反查域名信息完全是两个不同的信息.IP的rdns信息被称之为反向DNS解析(rDNS ...
- Javaweb学习笔记第六弹
本章节的存在意义是:学到PreparedStatement反应较慢,理解不透彻,来做个比较,加深印象 详细讲述PrepareStatement 与 Statement 连接数据库的部分区别 在我学习的 ...