之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html

为了加深对SSD的理解,因此对SSD的源码进行了复现,主要参考的github项目是ssd.pytorch。同时,我自己对该项目增加了大量注释:https://github.com/Dengshunge/mySSD_pytorch

搭建SSD的项目,可以分成以下三个部分:

  1. 数据读取;
  2. 网络搭建
  3. 损失函数的构建
  4. 网络测试

接下来,本篇博客重点分析数据读取


一、整体框架

SSD的数据读取环节,同样适用于大部分目标检测的环节,具有通用性。为了方便理解,本项目以VOC2007+2012为例。因此,数据读取环节,通常是按照以下步骤展开进行:

  1. 函数入口;
  2. 图片的读取和xml文件的读取;
  3. 对GT框进行处理;
  4. 数据增强;
  5. 辅助函数。

二、具体实现细节

2.1 函数入口

数据读取的函数入口在train.py文件中:

if args.dataset == 'VOC':
train_dataset = VOCDetection(root=args.dataset_root)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, num_workers=4,
collate_fn=detection_collate, shuffle=True, pin_memory=True)

可以看到,首先通过函数 VOCDetection() 来对VOC数据集进行初始化,再使用函数 DataLoader() 来实现对数据集的读取。这一步与常见的分类网络形式相同,但不同的是,多了collate_fn这一参数,后续会对此进行说明。

2.2 图片与xml文件读取

首先,我们先看看函数VOCDetection() 的初始化函数__init__()。在__init__中包含了需要传入的几个参数,image_sets(表示VOC使用到的数据集),transform(数据增强的方式),target_transform(GT框的处理方式)。

class VOCDetection():
"""VOC Detection Dataset Object input is image, target is annotation Arguments:
root (string): filepath to VOCdevkit folder.
image_set (string): imageset to use (eg. 'train', 'val', 'test')
transform (callable, optional): transformation to perform on the input image
图片预处理的方式,这里使用了大量数据增强的方式
target_transform (callable, optional): transformation to perform on the
target `annotation`
(eg: take in caption string, return tensor of word indices)
真实框预处理的方式
""" def __init__(self, root,
image_sets=[('', 'trainval'), ('', 'trainval')],
transform=SSDAugmentation(size=config.voc['min_dim'], mean=config.MEANS),
target_transform=VOCAnnotationTransform()):
self.root = root
self.image_set = image_sets
self.transform = transform
self.target_transform = target_transform
self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
self.ids = []
# 使用VOC2007和VOC2012的train作为训练集
for (year, name) in self.image_set:
rootpath = os.path.join(self.root, 'VOC' + year)
for line in open(os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
self.ids.append([rootpath, line[:-1]])

首先,为什么image_sets是这样的形式呢?因为VOC具有固定的文件夹路径,利用这个参数和配合路径读取,可以读取到txt文件,该txt文件用于制定哪些图片用于训练。此外,还需要设置参数self.ids,这个list用于存储文件的路径,由两列组成,"VOC/2007"和图片名称。通过这两个参数,后续可以配合函数_annopath()和_imgpath()可以读取到对应图片的路径和xml文件。

在pytorch中,还需要相应的函数来对读取图片与返回结果,如下所示。其中,重点是pull_iterm函数。

    def __getitem__(self, index):
im, gt = self.pull_item(index)
return im, gt def __len__(self):
return len(self.ids) def pull_item(self, index):
img_id = tuple(self.ids[index])
# img_id里面有2个值
target = ET.parse(self._annopath % img_id).getroot() # 获得xml的内容,但这个是具有特殊格式的
img = cv2.imread(self._imgpath % img_id)
height, width, _ = img.shape if self.target_transform is not None:
# 真实框处理
target = self.target_transform(target, width, height) if self.transform is not None:
# 图像预处理,进行数据增强,只在训练进行数据增强,测试的时候不需要
target = np.array(target)
img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
# 转换格式
img = img[:, :, (2, 1, 0)] # to rbg
target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
return torch.from_numpy(img).permute(2, 0, 1), target

该函数pull_item(),首先读取图片和相应的xml文件;接着对使用类VOCAnnotationTransform来对GT框进行处理,即读取GT框坐标与将坐标归一化;然后通过函数SSDAugmentation()对图片进行数据增强;最后对对图片进行常规处理(交换通道等),返回图片与存有GT框的list。

2.3 对GT框进行处理

接着,需要讲一讲这个类VOCAnnotationTransform的作用,其定义如下。self.class_to_ind是一个map,其key是类别名称,value是编号,这个对象的作用是,读取xml时,能将对应的类别名称转换成label;在__call__()函数中,主要是xml读取的一些方式,值得一提的是,GT框的最表转换成了[0,1]之间,当图片尺寸变化了,GT框的坐标也能进行相应的变换。最后,res的每行由5个元素组成,分别是[x_min,y_min,x_max,y_max,label]。

class VOCAnnotationTransform():
'''
获取xml里面的坐标值和label,并将坐标值转换成0到1
''' def __init__(self, class_to_ind=None, keep_difficult=False):
# 将类别名字转换成数字label
self.class_to_ind = class_to_ind or dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
# 在xml里面,有个difficult的参数,这个表示特别难识别的目标,一般是小目标或者遮挡严重的目标
# 因此,可以通过这个参数,忽略这些目标
self.keep_difficult = keep_difficult def __call__(self, target, width, height):
'''
将一张图里面包含若干个目标,获取这些目标的坐标值,并转换成0到1,并得到其label
:param target: xml格式
:return: 返回List,每个目标对应一行,每行包括5个参数[xmin, ymin, xmax, ymax, label_ind]
'''
res = []
for obj in target.iter('object'):
difficult = int(obj.find('difficult').text) == 1 # 判断该目标是否为难例
# 判断是否跳过难例
if not self.keep_difficult and difficult:
continue
name = obj.find('name').text.lower().strip() # text是获得目标的名称,lower将字符转换成小写,strip去除前后空格
bbox = obj.find('bndbox') # 获得真实框坐标 pts = ['xmin', 'ymin', 'xmax', 'ymax']
bndbox = []
for i, pt in enumerate(pts):
cur_pt = int(bbox.find(pt).text) - 1 # 获得坐标值
# 将坐标转换成[0,1],这样图片尺寸发生变化的时候,真实框也随之变化,即平移不变形
cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
bndbox.append(cur_pt)
label_idx = self.class_to_ind[name] # 获得名字对应的label
bndbox.append(label_idx)
res.append(bndbox) # [xmin, ymin, xmax, ymax, label_ind]
return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]

2.4 数据增强

还有一个重要的函数,即函数SSDAugmentation(),该函数的作用是作数据增强。论文中也提及了,数据增强对最终的结果提升有着重大作用。博客1博客2具体讲述了数据增强的源码,讲得十分详细。在本项目中,SSDAugmentation()函数在data/augmentations.py中,如下所示。由于opencv读取读片的时候,取值范围是[0,255],是int类型,需要将其转换为float类型,计算其GT框的正式坐标。然后对图片进行光度变形,包含改变对比度,改变饱和度,改变色调、改变亮度和增加噪声等。接着有对图片进行扩张和裁剪等。在此操作中,会涉及到GT框坐标的变换。最后,当上述变化处理完后,再对GT框坐标归一化,和resize图片,减去均值等。具体细节,可以参考两篇博客进行解读。

class SSDAugmentation(object):
def __init__(self, size=300, mean=(104, 117, 123)):
self.mean = mean
self.size = size
self.augment = Compose([
ConvertFromInts(), # 将图片从int转换成float
ToAbsoluteCoords(), # 计算真实的锚点框坐标
PhotometricDistort(), # 光度变形
Expand(self.mean), # 随机扩张图片
RandomSampleCrop(), # 随机裁剪
RandomMirror(), # 随机镜像
ToPercentCoords(),
Resize(self.size),
SubtractMeans(self.mean)
]) def __call__(self, img, boxes, labels):
return self.augment(img, boxes, labels)

2.5 辅助函数

在一个batch中,每张图片的GT框数量是不等的,因此,需要定义一个函数来处理这种情况。函数detection_collate()就是用于处理这种情况,使得一张图片能对应一个list,这里list里面有所有GT框的信息组成。

def detection_collate(batch):
"""Custom collate fn for dealing with batches of images that have a different
number of associated object annotations (bounding boxes).
自定义处理在同一个batch,含有不同数量的目标框的情况 Arguments:
batch: (tuple) A tuple of tensor images and lists of annotations Return:
A tuple containing:
1) (tensor) batch of images stacked on their 0 dim
2) (list of tensors) annotations for a given image are stacked on
0 dim
"""
targets = []
imgs = []
for sample in batch:
imgs.append(sample[0])
targets.append(torch.FloatTensor(sample[1]))
return torch.stack(imgs, 0), targets

至此,已经将SSD的数据读取部分分析完。

SSD源码解读——数据读取的更多相关文章

  1. SSD源码解读——网络测试

    之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html. 为了加深对SSD的理解,因此对SSD的源码进 ...

  2. SSD源码解读——损失函数的构建

    之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html. 为了加深对SSD的理解,因此对SSD的源码进 ...

  3. SSD源码解读——网络搭建

    之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html. 为了加深对SSD的理解,因此对SSD的源码进 ...

  4. jQuery源码解读 - 数据缓存系统:jQuery.data

    jQuery在1.2后引入jQuery.data(数据缓存系统),主要的作用是让一组自定义的数据可以DOM元素相关联——浅显的说:就是让一个对象和一组数据一对一的关联. 一组和Element相关的数据 ...

  5. Jfinal-Plugin源码解读

    PS:cnxieyang@163.com/xieyang@e6yun.com 本文就Jfinal-plugin的源码进行分析和解读 Plugin继承及实现关系类图如下,常用的是Iplugin的三个集成 ...

  6. 『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理

    Fork版本项目地址:SSD 一.TFR数据读取 创建slim.dataset.Dataset对象 在train_ssd_network.py获取数据操作如下,首先需要slim.dataset.Dat ...

  7. Spark Streaming源码解读之流数据不断接收和全生命周期彻底研究和思考

    本节的主要内容: 一.数据接受架构和设计模式 二.接受数据的源码解读 Spark Streaming不断持续的接收数据,具有Receiver的Spark 应用程序的考虑. Receiver和Drive ...

  8. SDWebImage源码解读之SDWebImageDownloaderOperation

    第七篇 前言 本篇文章主要讲解下载操作的相关知识,SDWebImageDownloaderOperation的主要任务是把一张图片从服务器下载到内存中.下载数据并不难,如何对下载这一系列的任务进行设计 ...

  9. SDWebImage源码解读之SDWebImageCache(下)

    第六篇 前言 我们在SDWebImageCache(上)中了解了这个缓存类大概的功能是什么?那么接下来就要看看这些功能是如何实现的? 再次强调,不管是图片的缓存还是其他各种不同形式的缓存,在原理上都极 ...

随机推荐

  1. 关于 About

    关于我 我是 Ivy,目前武汉大学 GIS 专业在读硕士研究生,业余渣程序媛. 写了一些不起眼的代码(参看我的 GitHub),做了一些不起眼的小研究(参看我的 ResearchGate). 关于本站 ...

  2. pyQt点击事件和数据传输

    首先是PushButton点击事件,点击按钮之后发送textEdit框里输入的文字到后台. def retranslateUi(self, MainWindow): _translate = QtCo ...

  3. .NET制作滚动条

    今天,在工作的时候,刚好做到了滚动条,对这点不是很懂,所以,研究了一下,记录在这里,与大家分享. 对于前台页面,我们就只需要设置数据表的样式:style="overflow: auto; 即 ...

  4. office web apps安装部署,配置https,负载均衡(七)配置过程中遇到的问题详细解答

    该篇文章,是这个系列文章的最后一篇文章,该篇文章将详细解答owa在安装过程中常见的问题. 如果您没有搭建好office web apps,您可以查看前面的一系列文章,查看具体步骤: office we ...

  5. mysql数据库为什么要分表和分区?

    一般下载的源码都带了MySQL数据库的,做个真正意义上的网站没数据库肯定不行. 数据库主要存放用户信息(注册用户名密码,分组,等级等),配置信息(管理权限配置,模板配置等),内容链接(html ,图片 ...

  6. 【BZOJ4766】文艺计算姬

    让你求一个两边各有n和m个点的完全二分图有多少个生成树. 这是一道比较经典的利用prufer序列结论求解答案的计数题. 大致思路考虑一张二分图求解prufer序列,由于prufer序列求解时最后剩下的 ...

  7. NOIp2015D1T3 斗地主【暴搜】

    题目传送门 刚开始读到题目的时候,非常懵逼,非常崩溃,写着写着呢,也有点崩溃,细节有点多. 这个做法呢,只能过掉官方数据,洛谷上好像有加强版,只能得$86$分,就没有管了. 大概说一下思路: 暴力搜索 ...

  8. 系统的可用性用平均无故障时间( MTTF)

    计算机系统的可用性用平均无故障时间( MTTF)来度量,即计算机系统平均能够正常运行多长时间,才发生一次故障.系统的可用性越高,平均无故障时间越长. 可维护性用平均维修时间(MTTR)来度量,即系统发 ...

  9. PHP反序列化进阶寻找和构造

    POP链的构造 如果关键代码不在魔术方法中,而是在一个类的普通方法中. 这时候可以通过寻找相同的函数名将类的属性和敏感函数的属性联系起来 <?phpclass lemon {    protec ...

  10. Storm提交Topology报错:Found multiple defaults.yaml resources.

    Storm提交Topology运行方式分为本地和集群运行两种,其中集群运行需要将程序打包并把jar包复制到集群,通过以下方式执行: bin/storm jar /opt/run/storm-demo- ...