目录结构

dogsData.py

import json
import os
import glob
import random
import csv

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode


class Dogs(Dataset):
    """Image-folder dataset: one sub-directory per class under `root`.

    On first use it writes `label.txt` (class-name -> integer label, JSON)
    and `images.csv` (shuffled image-path,label index) into `root`, then
    serves resized / augmented / ImageNet-normalized image tensors.
    """

    def __init__(self, root, resize, mode):
        """
        root   : dataset root directory, one sub-directory per class
        resize : side length of the square output image
        mode   : 'train' -> first 80%, 'val' -> 80%-90%, anything else -> last 10%
        """
        super().__init__()
        self.root = root
        self.resize = resize

        # Map each class directory to an integer label; sorted() makes the
        # mapping deterministic across runs.
        self.nameLable = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.nameLable[name] = len(self.nameLable.keys())

        # Persist the mapping once so labels can be decoded later.
        if not os.path.exists(os.path.join(self.root, 'label.txt')):
            with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
                f.write(json.dumps(self.nameLable, ensure_ascii=False))

        self.images, self.labels = self.load_csv('images.csv')

        # Split: 80% train / 10% val / 10% test (csv order is pre-shuffled).
        if mode == 'train':
            self.images = self.images[:int(0.8 * len(self.images))]
            self.labels = self.labels[:int(0.8 * len(self.labels))]
        elif mode == 'val':
            self.images = self.images[int(0.8 * len(self.images)):int(0.9 * len(self.images))]
            self.labels = self.labels[int(0.8 * len(self.labels)):int(0.9 * len(self.labels))]
        else:
            self.images = self.images[int(0.9 * len(self.images)):]
            self.labels = self.labels[int(0.9 * len(self.labels)):]

    def load_csv(self, filename):
        """Return (image_paths, labels); build and cache the csv on first call."""
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.nameLable.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = img.split(os.sep)[-2]
                    label = self.nameLable[name]
                    writer.writerow([img, label])
                print('csv write succesful')

        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

    def denormalize(self, x_hat):
        """Invert the ImageNet normalization for visualization.

        x_hat = (x - mean) / std  =>  x = x_hat * std + mean
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat: [c, h, w]; broadcast mean/std from [3] to [3, 1, 1].
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Convert the stored path string into an augmented image tensor.
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


def main():
    import visdom
    import time

    viz = visdom.Visdom()

    db = Dogs('Images_Data_Dog', 224, 'train')
    # Fetch a single sample:
    # x, y = next(iter(db))
    # print(x.shape, y)
    # viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    # Fetch a batch:
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    print(len(loader))
    print(db.nameLable)
    # for x, y in loader:
    #     viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #     time.sleep(10)


if __name__ == '__main__':
    main()

utils.py

import torch
from torch import nn


class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Multiply out all non-batch dims to get the flattened width,
        # then reshape to [batch, flat_dim].
        flat_dim = 1
        for d in x.shape[1:]:
            flat_dim *= d
        return x.view(-1, flat_dim)

train.py

import os
import sys

# Make both this directory and its parent importable so the local
# dogsData / dogs_train modules resolve regardless of the working dir.
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)

import torch
import visdom
from torch import optim, nn
import torchvision
from torch.utils.data import DataLoader
from dogs_train.utils import Flatten
from dogsData import Dogs
from torchvision.models import resnet18

viz = visdom.Visdom()

batchsz = 32
lr = 1e-3
epochs = 20
device = torch.device('cuda')
torch.manual_seed(1234)

train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


def evalute(model, loader):
    """Return classification accuracy of `model` over `loader` on `device`."""
    # FIX: switch to eval mode so batch-norm (and dropout) layers behave
    # deterministically during evaluation; the original left the model in
    # train mode, skewing val/test accuracy. Restore the previous mode after.
    was_training = model.training
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            logist = model(x)
            pred = logist.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    if was_training:
        model.train()
    return correct / total


def main():
    # Transfer learning: pretrained resnet18 backbone + fresh 27-way head.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # drop the original fc layer
        Flatten(),                             # [b, 512, 1, 1] => [b, 512]
        nn.Linear(512, 27)
    ).to(device)
    # Sanity-check the forward shape before training.
    x = torch.randn(2, 3, 224, 224).to(device)
    print(model(x).shape)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        # Validate every other epoch and checkpoint the best model.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')
    print('best acc', best_acc, 'best epoch', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loader from ckpt')
    test_acc = evalute(model, test_loader)
    print(test_acc)


if __name__ == '__main__':
    main()

resnet18训练自定义数据集的更多相关文章

  1. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  2. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  3. [炼丹术]YOLOv5训练自定义数据集

    YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...

  4. yolov5训练自定义数据集

    yolov5训练自定义数据 step1:参考文献及代码 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535 githu ...

  5. Tensorflow2 自定义数据集图片完成图片分类任务

    对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...

  6. torch_13_自定义数据集实战

    1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...

  7. tensorflow从训练自定义CNN网络模型到Android端部署tflite

    网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型 ...

  8. Yolo训练自定义目标检测

    Yolo训练自定义目标检测 参考darknet:https://pjreddie.com/darknet/yolo/ 1. 下载darknet 在 https://github.com/pjreddi ...

  9. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  10. PyTorch 自定义数据集

    准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...

随机推荐

  1. 软件工程作业:个人项目—wc项目

    软件工程作业:个人项目-WC项目 项目相关要求 wc.exe 是一个常见的工具,它能统计文本文件的字符数.单词数和行数.这个项目要求写一个命令行程序,模仿已有wc.exe 的功能,并加以扩充,给出某程 ...

  2. 面向对象的练习总结(java)

    三次作业总结博客 l  前言 第一次题目集是我刚刚接触java所做的第一套习题,本次题目难度不大,题量较多,涉及的知识点主要是基础的语法知识,出题人的意图是让我们尽快熟悉java的语法,由于事先有c语 ...

  3. django_静态文件

    **************************************************************************************************** ...

  4. 探测域名解析依赖关系(运行问题解决No module named 'DNS')

    探测域名解析依赖关系 最近很懒,今天晚上才开始这个任务,然后发现我原来能跑起来的程序跑不起来了. 一直报错 ModuleNotFoundError: No module named 'DNS' 这个应 ...

  5. Windows11右键改Win10

    Win11改Win10右键模式 1.以管理员身份运行CMD控制台 2.在控制台中输入下列代码后回车执行 reg add "HKCU\Software\Classes\CLSID\{86ca1 ...

  6. idea连接服务器发包配置插件AlibabaCloudExplorer

    添加配置信息: 启动项选择:Edit Configurations,添加插件选择插件Deploy to Host

  7. What is UDS Service 0x10 - Diagnostic Session Control ?

    Why need the UDS Service 0x10? ECU在正常工作时会处于某一个会话模式下,上电后会自动进入默认会话模式,所以ECU启动后我们不需要输入0x10 01来进入该会话模式.EC ...

  8. msfvenom的使用

    msfvenom也只metasploit中的一个很强的工具,msfvenom生成木马非常的迅速可以各种版本的木马 该工具由msfpaylod和msfencodes的组成 生成木马是需要在本地监听,等待 ...

  9. 快速部署LAMP黄金架构,搭建disuz论坛

    快速部署LAMP架构 [root@zhanghuan ~]# iptables -F[root@zhanghuan ~]# systemctl stop firewalld[root@zhanghua ...

  10. 设计师必备:免费素材管理工具Billfish v3.0更新了!

    ​​Billfish是专门为设计师打造的图片收藏管理工具,可以轻松管理您的各种素材文件.Billfish是一个免费的软件,支持对大量的图片素材进行管理,提供多种快速的检索筛选功能,如颜色,格式,方向, ...