相比于之前写的ResNet18,下面的ResNet50写得更加工程化一点,这还适用与其他分类。

我的代码文件结构

  

1. 数据处理

  首先已经对数据做好了分类

  

  文件夹结构是这样

  开始划分数据集

  split_data.py

import os
import random
import shutil def move_file(target_path, save_train_path, save_val_pathm, scale=0.1): file_list = os.listdir(target_path)
random.shuffle(file_list) number = int(len(file_list) * scale)
train_list = file_list[number:]
val_list = file_list[:number] for file in train_list:
target_file_path = os.path.join(target_path, file)
save_file_path = os.path.join(save_train_path, file)
shutil.copyfile(target_file_path, save_file_path)
for file in val_list:
target_file_path = os.path.join(target_path, file)
save_file_path = os.path.join(save_val_pathm, file)
shutil.copyfile(target_file_path, save_file_path) def split_classify_data(base_path, save_path, scale=0.1):
folder_list = os.listdir(base_path)
for folder in folder_list:
target_path = os.path.join(base_path, folder)
save_train_path = os.path.join(save_path, 'train', folder)
save_val_path = os.path.join(save_path, 'val', folder)
if not os.path.exists(save_train_path):
os.makedirs(save_train_path)
if not os.path.exists(save_val_path):
os.makedirs(save_val_path)
move_file(target_path, save_train_path, save_val_path, scale)
print(folder, 'finish!') if __name__ == '__main__':
base_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\save_dir'
save_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\dog_cat'
# 验证集比例
scale = 0.1
split_classify_data(base_path, save_path, scale)

  运行完以上代码的到的文件夹结构

    

  一个训练集数据,一个验证集数据

  

2.数据集的导入

  我这个文件写了一个数据集的导入和一个学习率更新的函数。数据导入是通用的

  tools.py

import os
import time import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd.variable import Variable
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR, LambdaLR
from torchvision.models import ResNet50_Weights
from tqdm import tqdm
from classify_cfg import * mean = MEAN
std = STD def get_dataset(base_dir='', input_size=160):
dateset = dict()
transform_train = transforms.Compose([
# 分辨率重置为input_size
transforms.Resize(input_size),
transforms.RandomRotation(15),
# 对加载的图像作归一化处理, 并裁剪为[input_sizexinput_sizex3]大小的图像(因为这图片像素不一致直接统一)
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
]) transform_val = transforms.Compose([
transforms.Resize(input_size),
transforms.RandomRotation(15),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
base_dir_train = os.path.join(base_dir, 'train')
train_dataset = datasets.ImageFolder(root=base_dir_train, transform=transform_train)
# print("train_dataset=" + repr(train_dataset[1][0].size()))
# print("train_dataset.class_to_idx=" + repr(train_dataset.class_to_idx))
# print(train_dataset.classes)
classes = train_dataset.classes
# classes = train_dataset.class_to_idx
classes_num = len(train_dataset.classes) base_dir_val = os.path.join(base_dir, 'val')
val_dataset = datasets.ImageFolder(root=base_dir_val, transform=transform_val) dateset['train'] = train_dataset
dateset['val'] = val_dataset return dateset, classes, classes_num def update_lr(epoch, epochs):
"""
假设开始的学习率lr是0.001,训练次数epochs是100
当epoch<33时是lr * 1
当33<=epoch<=66 时是lr * 0.5
当66<epoch时是lr * 0.1
"""
if epoch == 0 or epochs // 3 > epoch:
return 1
elif (epochs // 3 * 2 >= epoch) and (epochs // 3 <= epoch):
return 0.5
else:
return 0.1

3.训练模型

  数据集导入好了以后,选择模型,选择优化器等等,然后开始训练。

  mytrain.py

import os
import time import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.autograd.variable import Variable
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR, LambdaLR
from torchvision.models import ResNet50_Weights
# from tqdm import tqdm
from classify_cfg import *
from tools import get_dataset, update_lr def train(model, dateset, epochs, batch_size, device, optimizer, scheduler, criterion, save_path):
train_loader = DataLoader(dateset.get('train'), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dateset.get('val'), batch_size=batch_size, shuffle=True) # 保存为tensorboard文件
write = SummaryWriter(save_path)
# 训练过程写入txt
f = open(os.path.join(save_path, 'log.txt'), 'w', encoding='utf-8') best_acc = 0
for epoch in range(epochs):
train_correct = 0.0
model.train()
sum_loss = 0.0
accuracy = -1
total_num = len(train_loader.dataset)
# print(total_num, len(train_loader))
# loop = tqdm(enumerate(train_loader), total=len(train_loader))
batch_count = 0
for batch_idx, (data, target) in enumerate(train_loader):
start_time = time.time()
data, target = Variable(data).to(device), Variable(target).to(device)
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step() print_loss = loss.data.item()
sum_loss += print_loss
train_predict = torch.max(output.data, 1)[1]
if torch.cuda.is_available():
train_correct += (train_predict.cuda() == target.cuda()).sum()
else:
train_correct += (train_predict == target).sum()
accuracy = (train_correct / total_num) * 100
# loop.set_description(f'Epoch [{epoch+1}/{epochs}]')
# loop.set_postfix(loss=loss.item(), acc='{:.3f}'.format(accuracy))
batch_count += len(data)
end_time = time.time()
s = f'Epoch:[{epoch+1}/{epochs}] Batch:[{batch_count}/{total_num}] train_acc: {"{:.2f}".format(accuracy)} ' \
f'train_loss: {"{:.3f}".format(loss.item())} time: {int((end_time-start_time)*1000)} ms'
# print(f'Epoch:[{epoch+1}/{epochs}]', f'Batch:[{batch_count}/{total_num}]',
# 'train_acc:', '{:.2f}'.format(accuracy), 'train_loss:', '{:.3f}'.format(loss.item()),
# 'time:', f'{int((end_time-start_time)*1000)} ms')
print(s)
f.write(s+'\n') write.add_scalar('train_acc', accuracy, epoch)
write.add_scalar('train_loss', loss.item(), epoch)
# print(optimizer.param_groups[0]['lr'])
scheduler.step()
if best_acc < accuracy:
best_acc = accuracy
torch.save(model, os.path.join(save_path, 'best.pt')) if epoch+1 == epochs:
torch.save(model, os.path.join(save_path, 'last.pt')) # 预测验证集
# if (epoch+1) % 5 == 0 or epoch+1 == epochs:
model.eval()
test_loss = 0.0
correct = 0.0
total_num = len(val_loader.dataset)
# print(total_num, len(val_loader))
with torch.no_grad():
for data, target in val_loader:
data, target = Variable(data).to(device), Variable(target).to(device)
output = model(data)
loss = criterion(output, target)
_, pred = torch.max(output.data, 1)
if torch.cuda.is_available():
correct += torch.sum(pred.cuda() == target.cuda())
else:
correct += torch.sum(pred == target)
print_loss = loss.data.item()
test_loss += print_loss
acc = correct / total_num * 100
avg_loss = test_loss / len(val_loader)
s = f"val acc: {'{:.2f}'.format(acc)} val loss: {'{:.3f}'.format(avg_loss)}"
# print('val acc: ', '{:.2f}'.format(acc), 'val loss: ', '{:.3f}'.format(avg_loss))
print(s)
f.write(s+'\n')
write.add_scalar('val_acc', acc, epoch)
write.add_scalar('val_loss', avg_loss, epoch)
# loop.set_postfix(val_loss='{:.3f}'.format(avg_loss), val_acc='{:.3f}'.format(acc)) f.close() if __name__ == '__main__':
device = DEVICE
epochs = EPOCHS
batch_size = BATCH_SIZE
input_size = INPUT_SIZE
lr = LR
# ---------------------------训练-------------------------------------
# 图片的路径
base_dir = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\dog_cat'
# 保存的路径
save_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\dog_cat_save'
dateset, classes, classes_num = get_dataset(base_dir, input_size=input_size)
# model = torchvision.models.resnet50(pretrained=True)
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, classes_num)
model.to(DEVICE)
# # 损失函数,交叉熵损失函数
criteon = nn.CrossEntropyLoss()
# 选择优化器
optimizer = optim.SGD(model.parameters(), lr=lr)
# 学习率更新
# scheduler = ExponentialLR(optimizer, gamma=0.9)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: update_lr(epoch, epochs))
# 开始训练
train(model, dateset, epochs, batch_size, device, optimizer, scheduler, criteon, save_path)
# 将label保存起来
with open(os.path.join(save_path, 'labels.txt'), 'w', encoding='utf-8') as f:
f.write(f'{classes_num} {classes}')

  训练结束以后,在保存路径下会得到下面的文件

  

  最好的模型,最后一次的模型,标签的列表,训练的记录和tensorboard记录

  在该路径下执行 tensorboard --logdir=.

  

  然后在浏览器打开给出的地址,即可看到数据训练过程的绘图

4.对图片进行预测

  考虑对于用户来说,用户是在网页或者手机上上传一张图片进行预测,所以这边是采用二进制数据。

  mypredict.py

  

import cv2
import numpy as np
import torch from classify_cfg import * def img_process(img_betys, img_size, device): img_arry = np.asarray(bytearray(img_betys), dtype='uint8')
# im0 = cv2.imread(img_betys)
im0 = cv2.imdecode(img_arry, cv2.IMREAD_COLOR)
image = cv2.resize(im0, (img_size, img_size))
image = np.float32(image) / 255.0
image[:, :, ] -= np.float32(mean)
image[:, :, ] /= np.float32(std)
image = image.transpose((2, 0, 1))
im = torch.from_numpy(image).unsqueeze(0)
im = im.to(device)
return im def predict(model_path, img, device):
model = torch.load(model_path)
model.to(device)
model.eval()
predicts = model(img)
# print(predicts)
_, preds = torch.max(predicts, 1)
pred = torch.squeeze(preds)
# print(pred)
return pred if __name__ == '__main__':
mean = MEAN
std = STD
device = DEVICE
classes = ['狗', '猫']
# # 预测
model_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\dog_cat_save\best.pt'
img_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\save_dir\狗\000000.jpg'
with open(img_path, 'rb') as f:
img_betys = f.read()
img =img_process(img_betys, 160, device)
# print(img.shape)
# print(img)
pred = predict(model_path, img, device)
print(classes[int(pred)])

ResNet50的猫狗分类训练及预测的更多相关文章

  1. paddlepaddle实现猫狗分类

    目录 1.预备工作 1.1 数据集准备 1.2 数据预处理 2.训练 2.1 模型 2.2 定义训练 2.3 训练 3.预测 4.参考文献 声明:这是我的个人学习笔记,大佬可以点评,指导,不喜勿喷.实 ...

  2. 人工智能——CNN卷积神经网络项目之猫狗分类

    首先先导入所需要的库 import sys from matplotlib import pyplot from tensorflow.keras.utils import to_categorica ...

  3. 用tensorflow迁移学习猫狗分类

    笔者这几天在跟着莫烦学习TensorFlow,正好到迁移学习(至于什么是迁移学习,看这篇),莫烦老师做的是预测猫和老虎尺寸大小的学习.作为一个有为的学生,笔者当然不能再预测猫啊狗啊的大小啦,正好之前正 ...

  4. 猫狗分类--Tensorflow实现

    贴一张自己画的思维导图  数据集准备 kaggle猫狗大战数据集(训练),微软的不需要FQ 12500张cat 12500张dog 生成图片路径和标签的List step1:获取D:/Study/Py ...

  5. 1.keras实现-->自己训练卷积模型实现猫狗二分类(CNN)

    原数据集:包含 25000张猫狗图像,两个类别各有12500 新数据集:猫.狗 (照片大小不一样) 训练集:各1000个样本 验证集:各500个样本 测试集:各500个样本 1= 狗,0= 猫 # 将 ...

  6. 使用pytorch完成kaggle猫狗图像识别

    kaggle是一个为开发商和数据科学家提供举办机器学习竞赛.托管数据库.编写和分享代码的平台,在这上面有非常多的好项目.好资源可供机器学习.深度学习爱好者学习之用.碰巧最近入门了一门非常的深度学习框架 ...

  7. Kaggle系列1:手把手教你用tensorflow建立卷积神经网络实现猫狗图像分类

    去年研一的时候想做kaggle上的一道题目:猫狗分类,但是苦于对卷积神经网络一直没有很好的认识,现在把这篇文章的内容补上去.(部分代码参考网上的,我改变了卷积神经网络的网络结构,其实主要部分我加了一层 ...

  8. pytorch实现kaggle猫狗识别

    参考:https://blog.csdn.net/weixin_37813036/article/details/90718310 kaggle是一个为开发商和数据科学家提供举办机器学习竞赛.托管数据 ...

  9. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  10. 【猫狗数据集】pytorch训练猫狗数据集之创建数据集

    猫狗数据集的分为训练集25000张,在训练集中猫和狗的图像是混在一起的,pytorch读取数据集有两种方式,第一种方式是将不同类别的图片放于其对应的类文件夹中,另一种是实现读取数据集类,该类继承tor ...

随机推荐

  1. css穿透

    https://www.cnblogs.com/linjiangxian/p/13183412.html

  2. RKO组——冲刺随笔(1)

    这个作业属于哪个课程 至诚软工实践F班 这个作业要求在哪里 第五次团队作业:项目冲刺 这个作业的目标 记录冲刺计划.要求包括当天会议照片.会议内容以及项目燃尽图(项目进度) 1.昨日进展 小组成员讨论 ...

  3. vim入门与快捷键使用

    1.移动 上下左右 jkhl 2.模式选择 命令模式 插入模式 字符选择模式 3.剪切复制 粘贴:p 复制 y 选择 v 进入选择模式 4. 撤销恢复 撤销 u 恢复 ctrl + r 5. 删除 d ...

  4. 【基础知识】C++算法基础(快速排序)

    快速排序: 1.执行流程(一趟快排): 2.一趟快排的结果:获得一个枢纽,在此左边皆小于此数,在此右边皆大于此数,因此可以继续使用递归获得最终的序列.

  5. Shell脚本实现模拟并发及并发数控制

    #!/bin/bash #by inmoonlight@163.com #下面的代码控制并发数.其实是利用令牌原理实现 #一个线程要运行,首先要拿到令牌在该代码中即read一行数据,读取不到就会暂停, ...

  6. Oracle查询优化经验

    1.ORACLE采用自下而上的顺序解析WHERE子句,根据这个原理,表之间的连接必须写在其他WHERE条件之前, 那些可以过滤掉最大数量记录的条件必须写在WHERE子句的末尾. (低效,执行时间156 ...

  7. 115、商城业务---分布式事务---使用Springboot提供的Seata解决分布式事务

    https://seata.io/zh-cn/ seata使用Seata AT模式控制分布式事务的步骤: 1.每一个想控制分布式事务的服务对应的数据库都需要创建一个UNDO_LOG 表 CREATE ...

  8. JSP基础语法笔记一

    JSP是一种脚本语言. 代码片段,方法内容: <% 代码片段 %> <jsp:scriptlet> 代码片段 </jsp:scriptlet> 设置编码格式,正常显 ...

  9. 痞子衡嵌入式:在i.MXRT1060-EVK上利用memtester程序给SDRAM做压力测试

    大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家介绍的是在i.MXRT1060-EVK上利用memtester程序给SDRAM做压力测试. 我们知道恩智浦i.MXRT1xxx系列是高性能MCU ...

  10. MySQL学习(十一)为什么不推荐使用uuid和雪花id

    参考博客:https://www.cnblogs.com/wyq178/p/12548864.html 自增的主键的值是顺序的,所以Innodb把每一条记录都存储在一条记录的后面.当达到页面的最大填充 ...