torchnet+VGG16计算patch之间相似度

torch
VGG16
similarity

本来打算使用VGG实现siamese CNN的,但是没想明白怎么使用torchnet对模型进行微调。。。所以只好把VGG的卷积层单独做一个数据预处理模块,后面跟一个网络,将两个VGG输出的结果输入该网络中,仅训练这个浅层网络。

数据:使用了MOTChallenge数据库MOT16-02中的pedestrian

代码:

  1. -- --------------------------------------------------------------------------------------- 

  2. -- 读取MOT16-02数据集的groundtruth,分成训练集和测试集 

  3. -- --------------------------------------------------------------------------------------- 

  4. require 'torch' 

  5. require 'cutorch' 

  6. torch.setdefaulttensortype('torch.FloatTensor') 

  7. data_type = 'torch.CudaTensor' -- 设置数据类型,不适用GPU可以设置为torch.FloatTensor 


  8. require 'image' 

  9. local datapath = '/home/zwzhou/programFiles/2DMOT2015/MOT16/train/MOT16-02/' 

  10. local tmp = image.load(datapath .. 'img1/000001.jpg',3,'byte') 

  11. local width = tmp:size(3) 

  12. local height = tmp:size(2) 

  13. local num = 600 

  14. local imgs = torch.Tensor(num,3,height,width) 


  15. local file,_ = io.open('imgs.t7') 

  16. if not file then 

  17. for i=1,num do -- 读取视频帧 

  18. imgs[i]=image.load(datapath .. 'img1/' .. string.format('%06d.jpg',i)) 

  19. end 

  20. torch.save('imgs.t7',imgs) 

  21. else 

  22. imgs = torch.load('imgs.t7') 

  23. end 


  24. require'sys' 

  25. local gt_path = datapath .. 'gt/gt.txt' 

  26. local gt_info={} 

  27. local i=0 

  28. for line in io.lines(gt_path) do -- pedestrians的patch信息 

  29. local v=sys.split(line,',') 

  30. if tonumber(v[7]) ==1 and tonumber(v[9]) > 0.8 then -- 筛选有效的patch,是pedestrian且可见度>0.8 

  31. table.insert(gt_info,{tonumber(v[1]),tonumber(v[2]),tonumber(v[3]),tonumber(v[4]),tonumber(v[5]),tonumber(v[6])}) 

  32. -- 对应的是frame index,track index, x, y, w, h 

  33. end 

  34. end 

  35. -- 构建样本对,这里主要是为了正负样本个数相同,每个pedestrian选取25个相同id的patch,25个不同id的patch 

  36. local pairwise={} 

  37. for i=1,#gt_info do 

  38. local count=0 

  39. local iter=0 

  40. repeat  

  41. local j=torch.ceil(torch.rand(1)*(#gt_info))[1] 

  42. if gt_info[i][2] == gt_info[j][2] then  

  43. count=count+1 

  44. table.insert(pairwise,{i,j}) 

  45. end 

  46. iter=iter+1 

  47. until(count >25 or iter>100) 

  48. repeat  

  49. local j=torch.ceil(torch.rand(1)*#gt_info)[1] 

  50. if gt_info[i][2] ~= gt_info[j][2] then  

  51. count=count-1 

  52. table.insert(pairwise,{i,j}) 

  53. end 

  54. until(count <0) 

  55. end 


  56. local function cast(x) return x:type(data_type) end -- 类型转换 


  57. -- 加载pretrained VGG16 model 

  58. require 'nn' 

  59. require 'loadcaffe' 

  60. local function getPretrainedModel() 

  61. local proto = '/home/zwzhou/modelZoo/VGG_ILSVRC_16_layers_deploy.prototxt' 

  62. local caffemodel = '/home/zwzhou/modelZoo/VGG_ILSVRC_16_layers.caffemodel' 

  63. local VGG16 = loadcaffe.load(proto,caffemodel,'nn') 

  64. for i = 1,3 do 

  65. VGG16.modules[#VGG16.modules]=nil 

  66. end 

  67. return VGG16 

  68. end 


  69. -- 为了能够使用VGG,需要定义一些预处理方法 

  70. local loadSize = {3,256,256} 

  71. local sampleSize={3,224,224} 


  72. local function adjustScale(input) -- VGG需要先将输入图片的最小边缩放到256,另一边保持纵横比 

  73. if input:size(3) < input:size(2) then 

  74. input = image.scale(input,loadSize[2],loadSize[3]*input:size(2)/input:size(3)) 

  75. else 

  76. input = image.scale(input,loadSize[2]*input:size(3)/input:size(2),loadSize[3]) 

  77. end 

  78. return input 

  79. end 


  80. local bgr_means = {103.939,116.779,123.68} -- VGG使用的均值,注意是BGR通道,image.load()获得的是rgb 

  81. local function vggProcessing(img) 

  82. local img2 = img:clone() -- 深度拷贝 

  83. img2[{{1}}] = img[{{3}}] 

  84. img2[{{3}}] = img[{{1}}] -- rgb -> bgr 

  85. img2=img2:mul(255) 

  86. for i=1,3 do 

  87. img2[i]:add(-bgr_means[i]) 

  88. end 

  89. return img2 

  90. end 


  91. local function centerCrop(input) -- 截取224*224大小 

  92. local oH = sampleSize[2] 

  93. local oW = sampleSize[3] 

  94. local iW = input:size(3) 

  95. local iH = input:size(2) 

  96. local w1 = math.ceil((iW-oW)/2) 

  97. local h1 = math.ceil((iH-oH)/2) 

  98. local out = image.crop(input,w1,h1,w1+oW,h1+oH) 

  99. return out 

  100. end 


  101. local file,_ = io.open('vgg_info.t7') 

  102. local vgg_info={} 

  103. if not file then 

  104. local VGG16_model = getPretrainedModel() 

  105. if data_type:match'torch.Cuda.*Tensor' then 

  106. require 'cudnn' 

  107. require 'cunn' 

  108. cudnn.convert(VGG16_model,cudnn):cuda() 

  109. cudnn.benchmark = true 

  110. end 

  111. cast(VGG16_model) 

  112. for i=1, #gt_info do 

  113. local idx=gt_info[i] 

  114. local img = imgs[idx[1]] 

  115. local x1 = math.max(idx[3],1) 

  116. local y1 = math.max(idx[4],1) 

  117. local x2 = math.min(idx[3]+idx[5],width) 

  118. local y2 = math.min(idx[4]+idx[6],height) 

  119. local patch = image.crop(img,x1,y1,x2,y2) 

  120. patch = adjustScale(patch) 

  121. patch = vggProcessing(patch) 

  122. patch = centerCrop(patch) 

  123. patch=cast(patch) 

  124. table.insert(vgg_info,VGG16_model:forward(patch):float()) 

  125. end 

  126. torch.save('vgg_info.t7',vgg_info) 

  127. else  

  128. vgg_info=torch.load('vgg_info.t7') 

  129. end 


  130. local function getPatchPair(tmp) -- 获得patch 对 

  131. local pp = {} 

  132. pp[1] = vgg_info[tmp[1]] 

  133. pp[2] = vgg_info[tmp[2]] 

  134. local t=torch.cat(pp[1],pp[2],1) 

  135. return t 

  136. end 


  137. -- 定义datasetiterator 

  138. local tnt=require'torchnet' 

  139. local function getIterator(mode) 

  140. -- 创建model 

  141. local fc = nn.Sequential() 

  142. fc:add(nn.View(-1,4096*2)) 

  143. fc:add(nn.Linear(4096*2,500)) 

  144. fc:add(nn.ReLU(true)) 

  145. fc:add(nn.Normalize(2)) 

  146. fc:add(nn.Linear(500,500)) 

  147. fc:add(nn.ReLU(true)) 

  148. fc:add(nn.Linear(500,1)) 


  149. -- print(fc:forward(torch.randn(2,4096*2))) 

  150. if data_type:match'torch.Cuda.*Tensor' then 

  151. require 'cudnn' 

  152. require 'cunn' 

  153. cudnn.convert(fc,cudnn):cuda() 

  154. cudnn.benchmark = true 

  155. end 

  156. cast(fc) 


  157. -- 构建训练引擎,使用OptimEngine 

  158. require 'optim' 

  159. local engine = tnt.OptimEngine() 

  160. local criterion = cast(nn.MarginCriterion()) 


  161. -- 创建一些评估值 

  162. local train_timer = torch.Timer() 

  163. local test_timer = torch.Timer() 

  164. local data_timer = torch.Timer() 


  165. local meter = tnt.AverageValueMeter() -- 用于统计评估函数的输出 

  166. local confusion = optim.ConfusionMatrix(2) -- 2类混淆矩阵 

  167. local data_time_meter = tnt.AverageValueMeter() 

  168. -- log 

  169. local logtext=require 'torchnet.log.view.text' 

  170. log = tnt.Log{ 

  171. keys = {'train_loss','train_acc','data_loading_time','epoch','test_acc','train_time','test_time'}, 

  172. onFlush={ 

  173. logtext{keys={'train_loss','train_acc','data_loading_time','epoch','test_acc','train_time','test_time'}} 






  174. local inputs = cast(torch.Tensor()) 

  175. local targets = cast(torch.Tensor()) 


  176. -- 填一些hook函数,以便观察训练过程 

  177. engine.hooks.onSample = function(state) 

  178. if state.training then 

  179. data_time_meter:add(data_timer:time().real) 

  180. end 

  181. inputs:resize(state.sample.input:size()):copy(state.sample.input) 

  182. targets:resize(state.sample.target:size()):copy(state.sample.target) 

  183. state.sample.input = inputs 

  184. state.sample.target = targets 

  185. end 


  186. engine.hooks.onForwardCriterion = function(state) 

  187. meter:add(state.criterion.output) 

  188. confusion:batchAdd(state.network.output:gt(0):add(1),state.sample.target:gt(0):add(1)) 

  189. end 


  190. local function test() -- 用于测试 

  191. engine:test{ 

  192. network = fc, 

  193. iterator = getIterator('test'), 

  194. criterion=criterion,  



  195. confusion:updateValids() 

  196. end 


  197. engine.hooks.onStartEpoch = function(state) 

  198. local epoch = state.epoch + 1 

  199. print('===>' .. ' online epoch # ' .. epoch .. '[batchsize = 256]') 

  200. meter:reset() 

  201. confusion:zero() 

  202. train_timer:reset() 

  203. data_time_meter:reset() 

  204. end 


  205. engine.hooks.onEndEpoch = function(state) 

  206. local train_loss = meter:value() 

  207. confusion:updateValids() 

  208. local train_acc = confusion.totalValid*100 

  209. local train_time = train_timer:time().real 

  210. meter:reset() 

  211. print(confusion) 

  212. confusion:zero() 

  213. test_timer:reset() 


  214. local cache = state.params:clone() -- 保存现场 

  215. --state.params:copy(state.optim.ax) 

  216. test() 

  217. --state.params:copy(cache) -- 恢复现场 


  218. log:set{ 

  219. train_loss = train_loss, 

  220. train_acc = train_acc, 

  221. data_loading_time = data_time_meter:value(), 

  222. epoch = state.epoch, 

  223. test_acc = confusion.totalValid*100, 

  224. train_time = train_time, 

  225. test_time = test_timer:time().real, 



  226. log:flush() 

  227. end 


  228. engine.hooks.onUpdate = function(state) 

  229. data_timer:reset() 

  230. end 


  231. engine:train{ 

  232. network = fc, 

  233. criterion = criterion, 

  234. iterator = getIterator('train'), 

  235. optimMethod = optim.sgd, 

  236. config = {learningRate = 0.05, 

  237. --weightDecay = 0.05, 

  238. momentum = 0.9, 

  239. --t0 = 1e+4, 

  240. --eta0 =0.1 

  241. }, 

  242. maxepoch = 30,  




  243. -- 保存模型 

  244. local modelpath = 'SiaVGG16_model.t7' 

  245. print('Saving to ' .. modelpath) 

  246. torch.save(modelpath,fc:float():clearState()) 

  247. --]] 

输出:

1493386765674.jpg

发现网络太容易过拟合,主要一方面是数据太少,另一方面是视频中就那么几个人,所以patch之间的相关性太大,对网络提供的信息太少。所以使用更多的数据测试结果应该会好许多。

这个代码主要是为了熟悉torchnet package,感受呢,

  1. 对于数据的预处理,确实方便多了

  2. 如果使用提供的Engine,虽然训练过程简单了但是也太模块化了,比如某些层的微调,比如每层设置不同的学习率

  3. 使用Iterator时,尤其要小心

torchnet+VGG16计算patch之间相似度的更多相关文章

  1. (转)c# math 计算两点之间的角度公式

    计算两点之间的角度公式是: 假设点一(X1,Y1),点二(X2,Y2) double angleOfLine = Math.Atan2((Y2 - Y1), (X2 - X2)) * 180 / Ma ...

  2. python-Levenshtein几个计算字串相似度的函数解析

    linux环境下,没有首先安装python_Levenshtein,用法如下: 重点介绍几个该包中的几个计算字串相似度的几个函数实现. 1. Levenshtein.hamming(str1, str ...

  3. sql server2008根据经纬度计算两点之间的距离

    --通过经纬度计算两点之间的距离 create FUNCTION [dbo].[fnGetDistanceNew] --LatBegin 开始经度 --LngBegin 开始维度 --29.49029 ...

  4. C#面向对象思想计算两点之间距离

    题目为计算两点之间距离. 面向过程的思维方式,两点的横坐标之差,纵坐标之差,平方求和,再开跟,得到两点之间距离. using System; using System.Collections.Gene ...

  5. 2D和3D空间中计算两点之间的距离

    自己在做游戏的忘记了Unity帮我们提供计算两点之间的距离,在百度搜索了下. 原来有一个公式自己就写了一个方法O(∩_∩)O~,到僵尸到达某一个点之后就向另一个奔跑过去 /// <summary ...

  6. Jquery计算时间戳之间的差值,可返回年,月,日,小时等

    /** * 计算时间戳之间的差值 * @param startTime 开始时间戳 * @param endTime 结束时间戳 * @param type 返回指定类型差值(year, month, ...

  7. Levenshtein Distance莱文斯坦距离算法来计算字符串的相似度

    Levenshtein Distance莱文斯坦距离定义: 数学上,两个字符串a.b之间的莱文斯坦距离表示为levab(|a|, |b|). levab(i, j) = max(i, j)  如果mi ...

  8. <tf-idf + 余弦相似度> 计算文章的相似度

    背景知识: (1)tf-idf 按照词TF-IDF值来衡量该词在该文档中的重要性的指导思想:如果某个词比较少见,但是它在这篇文章中多次出现,那么它很可能就反映了这篇文章的特性,正是我们所需要的关键词. ...

  9. numpy :: 计算特征之间的余弦距离

    余弦距离在计算相似度的应用中经常使用,比如: 文本相似度检索 人脸识别检索 相似图片检索 原理简述 下面是余弦相似度的计算公式(图来自wikipedia): 但是,余弦相似度和常用的欧式距离的有所区别 ...

随机推荐

  1. js 格式验证大全

    1.身份证号码验证: var Common = { //身份证号验证 IsIdCardNo: function (IdCard) { var reg = /^\d{15}(\d{2}[0-9X])?$ ...

  2. Block Towers---cf626c(二分)

    题目链接:http://www.codeforces.com/contest/626/problem/C 题意是有一群小朋友在堆房子,现在有n个小孩每次可以放两个积木,m个小孩,每次可以放3个积木,最 ...

  3. Day20 javaWeb监听器和国际化

    day20 JavaWeb监听器 三大组件: Servlet Listener Filter   Listener:监听器 初次相见:AWT 二次相见:SAX   监听器: 它是一个接口,内容由我们来 ...

  4. 第1章 1.3计算机网络概述--规划IP地址介绍MAC地址

    IP地址的作用是:指定发送数据者和接收数据者. MAC地址的作用:指定数据包的下一跳转设备.就是说明数据下一步向谁发. 路由器的作用:在不同的网段中转发数据.路由器本质就是有2个网卡的设备. 网卡:用 ...

  5. zip和tgz以及exe的区别

    在下载东西的时候总是碰见后缀是.tar.gz和.zip的问题,搞不清楚是怎么回事,不晓得下载哪个文件才是对自己有用的. 后来才知道,其实这两个压缩文件里面包含的内容是一样的,只是压缩格式不一样, ta ...

  6. 2.8 The Object Model -- Enumerables

    在Ember.js中,枚举是包含许多子对象的任何对象,并允许你使用Ember.Enumerable API和那些子对象一起工作.在大部分应用程序中最常见的可枚举是本地JS数组,Ember.js扩展到符 ...

  7. NodeJS学习笔记五

    Promise简介 所谓Promise,就是一个对象,用来传递异步操作的消息. Promise对象有以下两个特点. (1)对象的状态不受外界影响.Promise对象代表一个异步操作,有三种状态:Pen ...

  8. BUG克星:几款优秀的BUG跟踪管理软件

    Bug管理是指对开发,测试,设计等过程中一系列活动过程中出现的bug问题给予纪录.审查.跟踪.分配.修改.验证.关闭.整理.分析.汇总以及删除等一系列活动状态的管理.,最后出相应图表统计,email通 ...

  9. 【android】ViewPager 大量内容页的内存优化

    总结:使用FragmentStatePagerAdapter 代替 FragmentPagerAdapter作为大批量内容页的适配器. 详细: 最近App里有一个场景,类似猿题库做题那种:有很多个题目 ...

  10. 《Java入门第二季》第二章 封装

    什么是java中的封装1.封装的概念:隐藏信息.隐藏具体的实现细节. 2.封装的实现步骤: 1)修改属性的可见性,private.2)创建修改器方法和访问器方法,getXXX/setXXX.(未必一定 ...