yolov5训练自定义数据

step1:参考文献及代码

  1. 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535
  2. github代码 https://github.com/DataXujing/YOLO-v5
  3. 官方代码 https://github.com/ultralytics/yolov5
  4. 官方教程 https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data

step2:准备数据集

  • --yolov5需要的数据集格式为txt格式的(即一个图片对应一个txt文件)
  • 参考文献1利用其将xml格式的代码转换成txt格式的代码

+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

  • 更新:2021/2/6 0:42
  • 找到了跟好的转换数据集的github库---->可应用与yolov3和yolov5的训练
  • github地址:https://github.com/pprp/voc2007_for_yolo_torch
  • 如果自己的图片格式不是.jpg需要修改tools/make_for_yolov3_torch.py里面的代码

+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

训练

  • 利用参考文献而将参考文献1中的labels中的txt数据集和images中的图片放入到参考文献二中

附录:

用于抽取训练集和测试集

  1. 抽取图片:抽取.py
import os
import random
import shutil # source_file:源路径, target_ir:目标路径
def cover_files(source_dir, target_ir):
for file in os.listdir(source_dir):
source_file = os.path.join(source_dir, file) if os.path.isfile(source_file):
shutil.copy(source_file, target_ir) def ensure_dir_exists(dir_name):
"""Makes sure the folder exists on disk.
Args:
dir_name: Path string to the folder we want to create.
"""
if not os.path.exists(dir_name):
os.makedirs(dir_name) def moveFile(file_dir, save_dir):
ensure_dir_exists(save_dir)
path_dir = os.listdir(file_dir)
filenumber = len(path_dir)
rate = 0.1 # 自定义抽取图片的比例,比方说100张抽10张,那就是0.1
picknumber = int(filenumber * rate) # 按照rate比例从文件夹中取一定数量图片
sample = random.sample(path_dir, picknumber) # 随机选取picknumber数量的样本图片
# print (sample)
for name in sample:
shutil.move(file_dir + name, save_dir + name) #切记win10路径D:你的路径\\,最后一定要有\\才能进入目标文件
if __name__ == '__main__':
file_dir = 'G:\\ECANet-master\\train\\0\\' # 源图片文件夹路径
save_dir = 'G:\\ECANet-master\\train\\00\\' # 移动到目标文件夹路径
moveFile(file_dir, save_dir)

json2xml:(json格式转换成xml格式)

  • 将下面三个文件放入到json_to_xml文件夹下
  1. create_xml_anno.py
# -*- coding: utf-8 -*-
from xml.dom.minidom import Document class CreateAnno:
def __init__(self,):
self.doc = Document() # 创建DOM文档对象
self.anno = self.doc.createElement('annotation') # 创建根元素
self.doc.appendChild(self.anno) self.add_folder()
self.add_path()
self.add_source()
self.add_segmented() # self.add_filename()
# self.add_pic_size(width_text_str=str(width), height_text_str=str(height), depth_text_str=str(depth)) def add_folder(self, floder_text_str='JPEGImages'):
floder = self.doc.createElement('floder') ##建立自己的开头
floder_text = self.doc.createTextNode(floder_text_str) ##建立自己的文本信息
floder.appendChild(floder_text) ##自己的内容
self.anno.appendChild(floder) def add_filename(self, filename_text_str='00000.jpg'):
filename = self.doc.createElement('filename')
filename_text = self.doc.createTextNode(filename_text_str)
filename.appendChild(filename_text)
self.anno.appendChild(filename) def add_path(self, path_text_str="None"):
path = self.doc.createElement('path')
path_text = self.doc.createTextNode(path_text_str)
path.appendChild(path_text)
self.anno.appendChild(path) def add_source(self, database_text_str="Unknow"):
source = self.doc.createElement('source')
database = self.doc.createElement('database')
database_text = self.doc.createTextNode(database_text_str) # 元素内容写入
database.appendChild(database_text)
source.appendChild(database)
self.anno.appendChild(source) def add_pic_size(self, width_text_str="0", height_text_str="0", depth_text_str="3"):
size = self.doc.createElement('size')
width = self.doc.createElement('width')
width_text = self.doc.createTextNode(width_text_str) # 元素内容写入
width.appendChild(width_text)
size.appendChild(width) height = self.doc.createElement('height')
height_text = self.doc.createTextNode(height_text_str)
height.appendChild(height_text)
size.appendChild(height) depth = self.doc.createElement('depth')
depth_text = self.doc.createTextNode(depth_text_str)
depth.appendChild(depth_text)
size.appendChild(depth) self.anno.appendChild(size) def add_segmented(self, segmented_text_str="0"):
segmented = self.doc.createElement('segmented')
segmented_text = self.doc.createTextNode(segmented_text_str)
segmented.appendChild(segmented_text)
self.anno.appendChild(segmented) def add_object(self,
name_text_str="None",
xmin_text_str="0",
ymin_text_str="0",
xmax_text_str="0",
ymax_text_str="0",
pose_text_str="Unspecified",
truncated_text_str="0",
difficult_text_str="0"):
object = self.doc.createElement('object')
name = self.doc.createElement('name')
name_text = self.doc.createTextNode(name_text_str)
name.appendChild(name_text)
object.appendChild(name) pose = self.doc.createElement('pose')
pose_text = self.doc.createTextNode(pose_text_str)
pose.appendChild(pose_text)
object.appendChild(pose) truncated = self.doc.createElement('truncated')
truncated_text = self.doc.createTextNode(truncated_text_str)
truncated.appendChild(truncated_text)
object.appendChild(truncated) difficult = self.doc.createElement('difficult')
difficult_text = self.doc.createTextNode(difficult_text_str)
difficult.appendChild(difficult_text)
object.appendChild(difficult) bndbox = self.doc.createElement('bndbox')
xmin = self.doc.createElement('xmin')
xmin_text = self.doc.createTextNode(xmin_text_str)
xmin.appendChild(xmin_text)
bndbox.appendChild(xmin) ymin = self.doc.createElement('ymin')
ymin_text = self.doc.createTextNode(ymin_text_str)
ymin.appendChild(ymin_text)
bndbox.appendChild(ymin) xmax = self.doc.createElement('xmax')
xmax_text = self.doc.createTextNode(xmax_text_str)
xmax.appendChild(xmax_text)
bndbox.appendChild(xmax) ymax = self.doc.createElement('ymax')
ymax_text = self.doc.createTextNode(ymax_text_str)
ymax.appendChild(ymax_text)
bndbox.appendChild(ymax)
object.appendChild(bndbox) self.anno.appendChild(object) def get_anno(self):
return self.anno def get_doc(self):
return self.doc def save_doc(self, save_path):
with open(save_path, "w") as f:
self.doc.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
  1. main.py
import os
from tqdm import tqdm from read_json import ReadAnno
from create_xml_anno import CreateAnno def json_transform_xml(json_path, xml_path, process_mode="polygon"):
json_path = json_path
json_anno = ReadAnno(json_path, process_mode=process_mode)
width, height = json_anno.get_width_height()
filename = json_anno.get_filename()
coordis = json_anno.get_coordis() xml_anno = CreateAnno()
xml_anno.add_filename(filename)
xml_anno.add_pic_size(width_text_str=str(width), height_text_str=str(height), depth_text_str=str(3))
for xmin,ymin,xmax,ymax,label in coordis:
xml_anno.add_object(name_text_str=str(label),
xmin_text_str=str(int(xmin)),
ymin_text_str=str(int(ymin)),
xmax_text_str=str(int(xmax)),
ymax_text_str=str(int(ymax))) xml_anno.save_doc(xml_path) if __name__ == "__main__":
root_json_dir = r"/home/aibc/ouyang/temp_dataset/jjson"
root_save_xml_dir = r"/home/aibc/ouyang/temp_dataset/xml"
for json_filename in tqdm(os.listdir(root_json_dir)):
json_path = os.path.join(root_json_dir, json_filename)
save_xml_path = os.path.join(root_save_xml_dir, json_filename.replace(".json", ".xml"))
json_transform_xml(json_path, save_xml_path, process_mode="polygon")
  1. read_json.py
# -*- coding: utf-8 -*-
import numpy as np
import json class ReadAnno:
def __init__(self, json_path, process_mode="rectangle"):
self.json_data = json.load(open(json_path))
self.filename = self.json_data['imagePath']
self.width = self.json_data['imageWidth']
self.height = self.json_data['imageHeight'] self.coordis = []
assert process_mode in ["rectangle", "polygon"]
if process_mode == "rectangle":
self.process_polygon_shapes()
elif process_mode == "polygon":
self.process_polygon_shapes() def process_rectangle_shapes(self):
for single_shape in self.json_data['shapes']:
bbox_class = single_shape['label']
xmin = single_shape['points'][0][0]
ymin = single_shape['points'][0][1]
xmax = single_shape['points'][1][0]
ymax = single_shape['points'][1][1]
self.coordis.append([xmin,ymin,xmax,ymax,bbox_class]) def process_polygon_shapes(self):
for single_shape in self.json_data['shapes']:
bbox_class = single_shape['label']
temp_points = []
for couple_point in single_shape['points']:
x = float(couple_point[0])
y = float(couple_point[1])
temp_points.append([x,y])
temp_points = np.array(temp_points)
xmin, ymin = temp_points.min(axis=0)
xmax, ymax = temp_points.max(axis=0)
self.coordis.append([xmin,ymin,xmax,ymax,bbox_class]) def get_width_height(self):
return self.width, self.height def get_filename(self):
return self.filename def get_coordis(self):
return self.coordis

yolov5训练自定义数据集的更多相关文章

  1. [炼丹术]YOLOv5训练自定义数据集

    YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...

  2. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  3. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  4. PyTorch 自定义数据集

    准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...

  5. torch_13_自定义数据集实战

    1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...

  6. tensorflow从训练自定义CNN网络模型到Android端部署tflite

    网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型 ...

  7. Tensorflow2 自定义数据集图片完成图片分类任务

    对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...

  8. Yolo训练自定义目标检测

    Yolo训练自定义目标检测 参考darknet:https://pjreddie.com/darknet/yolo/ 1. 下载darknet 在 https://github.com/pjreddi ...

  9. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

随机推荐

  1. 简单才是美! SpringBoot+JPA

    SpringBoot 急速构建项目,真的是用了才知道,搭配JPA作为持久层,一简到底!下面记录项目的搭建,后续会添加NOSQL redis,搜索引擎elasticSearch,等等,什么不过时就加什么 ...

  2. python根据日期判断星期几(超简洁)

    代码: from datetime import datetime def getWeek(week): print(date + "是星期" + str(week + 1)) d ...

  3. hibernate 联合主键 composite-id

    如果表使用联合主键(一个表有两个以上的主键),你可以映射类的多个属性为标识符属性.如:<composite-id>元素接受<key-property> 属性映射(单表映射)和& ...

  4. EMS查看及修改邮箱发送和接受邮件大小的方法

    默认情况下,新建用户邮箱没有进行单独设置,故用户邮箱默认值为"Unlimited"(未限制),即遵从全局设置(继承邮箱数据库策略).通过EMS查看用户邮箱发送和接受邮件大小的默认值 ...

  5. tracert命令简述

    1. 路由跟踪在线Tracert工具说明 Tracert(跟踪路由)是路由跟踪实用程序,用于确定 IP 数据报访问目标所采取的路径.Tracert 命令用 IP 生存时间 (TTL) 字段和 ICMP ...

  6. Windows中Nginx配置nginx.conf不生效解决方法(路径映射)

    Windows中Nginx配置nginx.conf不生效解决方法 今天在做Nginx项目的时候,要处理一个路径映射问题, location /evaluate/ { proxy_pass http:/ ...

  7. 机器学习系列:LightGBM 可视化调参

    大家好,在100天搞定机器学习|Day63 彻底掌握 LightGBM一文中,我介绍了LightGBM 的模型原理和一个极简实例.最近我发现Huggingface与Streamlit好像更配,所以就开 ...

  8. Shiro+springboot+mybatis(md5+salt+散列)认证与授权-01

    这个小项目包含了注册与登录,使用了springboot+mybatis+shiro的技术栈:当用户在浏览器登录时发起请求时,首先这一系列的请求会被拦截器进行拦截(ShiroFilter),然后拦截器根 ...

  9. 不用关闭重启cad及不用更改快捷方式或者版本号c#调试cad插件

    c#开发的cad插件需要重启cad才能进行调试,然而高版本的cad启动比较慢特别是一些古董电脑,而且cad有重启次数限制.针对不用重启cad调试已经有成熟的方案了,但是需要调试一次修改一次快捷方式或者 ...

  10. 安卓记账本开发学习day6之进度

    完成了基本的收入与支出添加,支持输入备注 以及备注的输入和金额的遮挡显示切换