目录结构

dogsData.py

import json

import torch
import os, glob
import random, csv from PIL import Image
from torch.utils.data import Dataset, DataLoader from torchvision import transforms
from torchvision.transforms import InterpolationMode class Dogs(Dataset): def __init__(self, root, resize, mode):
super().__init__()
self.root = root
self.resize = resize
self.nameLable = {}
for name in sorted(os.listdir(os.path.join(root))):
if not os.path.isdir(os.path.join(root, name)):
continue
self.nameLable[name] = len(self.nameLable.keys()) if not os.path.exists(os.path.join(self.root, 'label.txt')):
with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
f.write(json.dumps(self.nameLable, ensure_ascii=False)) # print(self.nameLable)
self.images, self.labels = self.load_csv('images.csv')
# print(self.labels) if mode == 'train':
self.images = self.images[:int(0.8*len(self.images))]
self.labels = self.labels[:int(0.8*len(self.labels))]
elif mode == 'val':
self.images = self.images[int(0.8*len(self.images)):int(0.9*len(self.images))]
self.labels = self.labels[int(0.8*len(self.labels)):int(0.9*len(self.labels))]
else:
self.images = self.images[int(0.9*len(self.images)):]
self.labels = self.labels[int(0.9*len(self.labels)):] def load_csv(self, filename): if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.nameLable.keys():
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# print(len(images)) random.shuffle(images)
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)[-2]
label = self.nameLable[name]
writer.writerow([img, label])
print('csv write succesful') images, labels = [], []
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
img, label = row
label = int(label)
images.append(img)
labels.append(label) assert len(images) == len(labels) return images, labels def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x_hot = (x-mean)/std
# x = x_hat * std = mean
# x : [c, w, h]
# mean [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1) x = x_hat * std + mean
return x def __len__(self):
return len(self.images) def __getitem__(self, idx):
# print(idx, len(self.images), len(self.labels))
img, label = self.images[idx], self.labels[idx] # 将字符串路径转换为tensor数据
# print(self.resize, type(self.resize))
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'),
transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
transforms.RandomRotation(15),
transforms.CenterCrop(self.resize),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = tf(img) label = torch.tensor(label) return img, label def main(): import visdom
import time viz = visdom.Visdom() # func1 通用
db = Dogs('Images_Data_Dog', 224, 'train')
# 取一张
# x,y = next(iter(db))
# print(x.shape, y)
# viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x')) # 取一个batch
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
print(len(loader))
print(db.nameLable)
# for x, y in loader:
# # print(x.shape, y)
# viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
# time.sleep(10) # # fun2
# import torchvision
# tf = transforms.Compose([
# transforms.Resize((64, 64)),
# transforms.RandomRotation(15),
# transforms.ToTensor(),
# ])
# db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
# loader = DataLoader(db, batch_size=32, shuffle=True)
# print(len(loader))
# for x, y in loader:
# # print(x.shape, y)
# viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
# time.sleep(10) if __name__ == '__main__':
main()

utils.py

import torch
from torch import nn class Flatten(nn.Module):
def __init__(self):
super().__init__() def forward(self, x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)

train.py

import os
import sys
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)
import torch
import visdom
from torch import optim, nn
import torchvision from torch.utils.data import DataLoader from dogs_train.utils import Flatten
from dogsData import Dogs from torchvision.models import resnet18 viz = visdom.Visdom() batchsz = 32
lr = 1e-3
epochs = 20 device = torch.device('cuda')
torch.manual_seed(1234) train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test') train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2) def evalute(model, loader):
correct = 0
total = len(loader.dataset)
for x, y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
logist = model(x)
pred = logist.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
return correct/total def main(): # model = ResNet18(5).to(device)
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],
Flatten(), # [b, 512, 1, 1] => [b, 512]
nn.Linear(512, 27)
).to(device) x = torch.randn(2, 3, 224, 224).to(device)
print(model(x).shape) optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss() best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs): for step, (x, y) in enumerate(train_loader):
x = x.to(device)
y = y.to(device) logits = model(x)
loss = criteon(logits, y) optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
if epoch % 2 == 0:
val_acc = evalute(model, val_loader)
if val_acc > best_acc:
best_acc = val_acc
best_epoch = epoch
torch.save(model.state_dict(), 'best.mdl') viz.line([val_acc], [global_step], win='val_acc', update='append') print('best acc', best_acc, 'best epoch', best_epoch) model.load_state_dict(torch.load('best.mdl'))
print('loader from ckpt') test_acc = evalute(model, test_loader)
print(test_acc) if __name__ == '__main__':
main()

resnet18训练自定义数据集的更多相关文章

  1. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  2. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  3. [炼丹术]YOLOv5训练自定义数据集

    YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...

  4. yolov5训练自定义数据集

    yolov5训练自定义数据 step1:参考文献及代码 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535 githu ...

  5. Tensorflow2 自定义数据集图片完成图片分类任务

    对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...

  6. torch_13_自定义数据集实战

    1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...

  7. tensorflow从训练自定义CNN网络模型到Android端部署tflite

    网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型 ...

  8. Yolo训练自定义目标检测

    Yolo训练自定义目标检测 参考darknet:https://pjreddie.com/darknet/yolo/ 1. 下载darknet 在 https://github.com/pjreddi ...

  9. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  10. PyTorch 自定义数据集

    准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...

随机推荐

  1. 关于FPGA学习的疑惑

    记录心路历程--第一次真正意义上的找工作. 最近在学习小梅哥的FPGA的过程中,总是学习着前面的东西忘着后面的,进行了过一次复习,我准备是在每个章节学完之后,再复习之前的章节,可这样是真的很耗费时间. ...

  2. vue 组件之间事件触发($emit)与event Bus($on)的用法说明

    组件之间事件触发 新增按钮组件: 操作按钮组合组件: 此时有个需求就是,无论是哪个按钮,如果改变了列表中的数据,列表需要实时更新数据. 此时就需要用到组件间的事件触发. 父子组件之间事件触发可以使用$ ...

  3. jq的用法

    选择页面中的元素,得到jQuery实例对象 ID选择器$("#save") 类选择器$(".class") 标签选择器$("div") 复合 ...

  4. win 子系统导入centos7

    之前在应用商店安装过ubuntu的,有钱的建议从商店购买 window配置 , 准备一个centos系统,我是从已有系统导出的,导出命令 tar -cvf ./centos.tar ./ --excl ...

  5. 跨域获取iframe页面的url

    一:跨域获取iframe页面的url 1.在使用iframe页面的js添加以下内容 <script> var host = window.location.href; var histor ...

  6. Hbase操作与编程使用

    1.任务: 列出HBase所有的表的相关信息,例如表名: 3. 编程完成以下指定功能(教材P92下): (1)createTable(String tableName, String[] fields ...

  7. char值转换为int怎么才能不是ASCII值

    直接将char类型的变量强制转换为int类型是不行的,那样只会传递变量所对应的ASCII码 怎么才能将char类型转换为int类型呢?String类型的可以通过方法转换为int类型.那是不是可以将ch ...

  8. manjaro安装指导

    本文"指导"二字口气有点大,是说给自己听的,指导我下次的安装. 正文: 1.安装系统:在清华大学开源站上下载KDE版(本机适用19版54内核无驱动问题),用rufus烧制启动盘,以 ...

  9. Python学习笔记--数据可视化的开头

    JSON数据格式的转换 示例: 若是有中文数据,可以在data后面加上ensure_ascii=False pyecharts模块 网站:https://gallery.pyecharts.org(有 ...

  10. 报错的大怨种来了--java.sql.SQLSyntaxErrorException: You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near '' at line 1

    频繁爆出这样的错误:java.sql.SQLSyntaxErrorException: You have an error in your SQL syntax; check the manual t ...