昨天有群友,在交流群【群号:392784757】里提到了这个需求,进行实现一下

V10 官方代码结构相较于 V8 稍微复杂一些

yolov10 是基于 v8 的代码完成开发,yolov10 进行了继承来简化代码开发

因此 V10 的代码修改 基本和 V8 这篇一致

https://blog.csdn.net/csy1021/article/details/134406419

但存在一些不同,会在下面提到

版本环境

YOLOv10 2024.07.01 版本

修改

trainer.py

1 添加 save_metrics_per_class()

在 save_metrics 函数后面,添加下面的 save_metrics_per_class 函数

def save_metrics_per_class(self, box):

    """Saves training metrics per class to a CSV file."""

    # ap ap50 p r 提示作用
keys = ['ap', 'ap50', 'p', 'r']
n = 4 + 1 # number of cols for i in box.ap_class_index:
cur_class = self.model.names[box.ap_class_index[i]]
save_path = self.save_dir.joinpath("result_" + cur_class + ".csv")
vals = [box.ap[i], box.ap50[i], box.p[i], box.r[i]]
s = '' if save_path.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header with open(save_path, 'a') as f:
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')

2 validate() 修改

def validate(self):
"""
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
"""
# metrics = self.validator(self)
metrics,box = self.validator(self)
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = fitness
# return metrics, fitness
return metrics, fitness,box

找到【这里比 v8 的判断要多】

if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
or final_epoch or self.stopper.possible_stop or self.stop:
self.metrics, self.fitness = self.validate()

修改为

if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
or final_epoch or self.stopper.possible_stop or self.stop:
# self.metrics, self.fitness = self.validate()
self.metrics, self.fitness,box = self.validate()

3 找到 self.save_metrics



self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})

后面添加调用

self.save_metrics_per_class(box)

validator.py

找到 stats = self.get_stats()

改为 stats,box = self.get_stats()

找到 return {k: round(float(v), 5) for k, v in results.items()}

改为 return {k: round(float(v), 5) for k, v in results.items()}, box

val.py

get_stats() 【注意与 v8 不同】

def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
# if len(stats) and stats["tp"].any():
# if len(stats) and stats[0].any():
if len(stats) :
self.metrics.process(**stats)
self.nt_per_class = np.bincount(
stats["target_cls"].astype(int), minlength=self.nc
) # number of targets per class
# return self.metrics.results_dict
return self.metrics.results_dict,self.metrics.box

save_metrics_per_class() 函数 【注意与 v8 不同】


可以看到支持的指标有 all_ap (可用来计算其他ap指标),map,map50,f1,p ap,r mr ...
我在函数中使用的是 ap,ap50,p,r,需要其他的可以再添加
==注意:添加指标,使用的是 . 而不是 ["xxxx"] 如 box.ap[i] 而不是 box['ap'][i]==

def save_metrics_per_class(self, box):

    """Saves training metrics per class to a CSV file."""

    # ap ap50 p r 提示作用
keys = ['ap', 'ap50', 'p', 'r']
n = 4 + 1 # number of cols for i in box.ap_class_index:
cur_class = self.model.names[box.ap_class_index[i]]
save_path = self.save_dir.joinpath("result_" + cur_class + ".csv")
vals = [box.ap[i], box.ap50[i], box.p[i], box.r[i]]
s = '' if save_path.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header with open(save_path, 'a') as f:
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')

注意!不同点

def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
# if len(stats) and stats["tp"].any(): # v10
# if len(stats) and stats[0].any(): # v8
if len(stats) : # 修改后
self.metrics.process(**stats)
self.nt_per_class = np.bincount(
stats["target_cls"].astype(int), minlength=self.nc
) # number of targets per class
# return self.metrics.results_dict
return self.metrics.results_dict,self.metrics.box

v10



v8

如果不修改 这个判断条件

if len(stats) and stats["tp"].any(): # v10
# if len(stats) and stats[0].any(): # v8 仅作对比
if len(stats) : # 修改后

可能会出现 前几次 epoch 数据不记录的问题 【这里也可能是和我的数据集有关,我测试了几次,增加 batch-size 发现仍然 stats["tp"] 仍然全为 false 过不了,后面 epoch 会正常 】这里大家可以自行测试后决定,如果正常,就不需要改

其他

增加训练过程各类指标打印(可选,默认开启是有条件的)

val.py 找到 print_results() 函数 在

LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) 后面

添加

for i, c in enumerate(self.metrics.ap_class_index):
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))

有问题,欢迎留言、进群讨论或私聊:【群号:392784757】

YOLOv10添加输出各类别训练过程指标的更多相关文章

  1. (转)理解YOLOv2训练过程中输出参数含义

    最近有人问起在YOLOv2训练过程中输出在终端的不同的参数分别代表什么含义,如何去理解这些参数?本篇文章中我将尝试着去回答这个有趣的问题. 刚好现在我正在训练一个YOLOv2模型,拿这个真实的例子来讨 ...

  2. 理解YOLOv2训练过程中输出参数含义

    原英文地址: https://timebutt.github.io/static/understanding-yolov2-training-output/ 最近有人问起在YOLOv2训练过程中输出在 ...

  3. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  4. 深度学习笔记之关于基本思想、浅层学习、Neural Network和训练过程(三)

    不多说,直接上干货! 五.Deep Learning的基本思想 假设我们有一个系统S,它有n层(S1,…Sn),它的输入是I,输出是O,形象地表示为: I =>S1=>S2=>….. ...

  5. 交叉熵代价函数——当我们用sigmoid函数作为神经元的激活函数时,最好使用交叉熵代价函数来替代方差代价函数,以避免训练过程太慢

    交叉熵代价函数 machine learning算法中用得很多的交叉熵代价函数. 1.从方差代价函数说起 代价函数经常用方差代价函数(即采用均方误差MSE),比如对于一个神经元(单输入单输出,sigm ...

  6. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  7. DL4J实战之六:图形化展示训练过程

    欢迎访问我的GitHub 这里分类和汇总了欣宸的全部原创(含配套源码):https://github.com/zq2599/blog_demos 本篇概览 本篇是<DL4J实战>系列的第六 ...

  8. 深度学习训练过程中的学习率衰减策略及pytorch实现

    学习率是深度学习中的一个重要超参数,选择合适的学习率能够帮助模型更好地收敛. 本文主要介绍深度学习训练过程中的6种学习率衰减策略以及相应的Pytorch实现. 1. StepLR 按固定的训练epoc ...

  9. 从零搭建Pytorch模型教程(四)编写训练过程--参数解析

    ​  前言 训练过程主要是指编写train.py文件,其中包括参数的解析.训练日志的配置.设置随机数种子.classdataset的初始化.网络的初始化.学习率的设置.损失函数的设置.优化方式的设置. ...

  10. (原)torch的训练过程

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221622.html 参考网址: http://ju.outofmemory.cn/entry/284 ...

随机推荐

  1. Zynq 7000的3种IO

    概念 MIO MIO:多功能IO接口(分配在 GPIO 的 Bank0 和Bank1),属于Zynq的PS部分,在芯片外部有54个引脚.这些引脚可以用在GPIO.SPI.UART.TIMER.Ethe ...

  2. 使用jsp+servlet+mysql用户管理系统之用户注册-----------使用简单三层结构分析页面显示层(view),业务逻辑层(service),数据持久层(dao)

    View层:jsp+servlet: jsp: <%@ page language="java" contentType="text/html; charset=U ...

  3. 小组合作实现的基于 jsp,servlet,mysql 编写的学校管理系统

    基本完成的页面--源代码在<文件>中可下载 文件地址:https://i.cnblogs.com/Files.aspx 学生管理模块各功能已实现 百度网盘下载地址: 链接:https:// ...

  4. Spark3 学习【基于Java】4. Spark-Sql数据源

    通过DF,Spark可以跟大量各型的数据源(文件/数据库/大数据)进行交互.前面我们已经看到DF可以生成视图,这就是一个非常使用的功能. 简单的读写流程如下: 通过read方法拿到DataFrameR ...

  5. TCP/UDP 协议和 HTTP/FTP/SMTP 协议之间的区别

    前言 我们经常会听到HTTP协议.TCP/IP协议.UDP协议.Socket.Socket长连接.Socket连接池等字眼,然而它们之间的关系.区别及原理并不是所有人都能理解清楚. 计算机网络体系结构 ...

  6. Netcode for Entities里如何对Ghost进行可见性筛选(1.2.3版本)

    一行代码省流:SystemAPI.GetSingleton() 当你需要按照区域.距离或者场景对Ghost进行筛选的时候,Netcode for Entities里并没有类似FishNet那样方便的过 ...

  7. mysql Using join buffer (Block Nested Loop) join连接查询优化

    最近在优化链表查询的时候发现就算链接的表里面不到1w的数据链接查询也需要10多秒,这个速度简直不能忍受 通过EXPLAIN发现,extra中有数据是Using join buffer (Block N ...

  8. [oeasy]python0093_电子游戏起源_视频游戏_达特茅斯_Basic_家酿俱乐部

    编码进化 回忆上次内容 Ed Robert 的 创业之路 从 售卖 diy 组装配件 到进军 计算器市场 最后 发布 牛郎星8800 intel 8080 的出现 让 人人都有 自己的 个人电脑 Bi ...

  9. 题解:AT_abc359_c [ABC359C] Tile Distance 2

    背景 去中考了,比赛没打,来补一下题. 分析 这道题让我想起了这道题(连题目名称都是连着的),不过显然要简单一些. 这道题显然要推一些式子.我们发现,和上面提到的那道题目一样,沿着对角线走台阶,纵坐标 ...

  10. 从零开始写 Docker(十九)---增加 cgroup v2 支持

    本文为从零开始写 Docker 系列第十九篇,添加对 cgroup v2 的支持. 完整代码见:https://github.com/lixd/mydocker 欢迎 Star 推荐阅读以下文章对 d ...