Caffe-python interface 学习|网络训练、部署、測试
继续python接口的学习。剩下还有solver、deploy文件的生成和模型的測试。
网络训练
solver文件生成
事实上我认为用python生成solver并不如直接写个配置文件,它不像net配置一样有非常多反复的东西。
对于一下的solver配置文件:
base_lr: 0.001
display: 782
gamma: 0.1
lr_policy: “step”
max_iter: 78200 #训练样本迭代次数=max_iter/782(训练完一次所有样本的迭代数)
momentum: 0.9
snapshot: 7820
snapshot_prefix: "snapshot"
solver_mode: GPU
solver_type: SGD
stepsize: 26067
test_interval: 782 #test_interval=训练样本数(50000)/batch_size(train:64)
test_iter: 313 #test_iter=測试样本数(10000)/batch_size(test:32)
test_net: "/home/xxx/data/val.prototxt"
train_net: "/home/xxx/data/proto/train.prototxt"
weight_decay: 0.0005
能够用以下方式实现生成:
from caffe.proto import caffe_pb2
s = caffe_pb2.SolverParameter()
path='/home/xxx/data/'
solver_file=path+'solver1.prototxt'
s.train_net = path+'train.prototxt'
s.test_net.append(path+'val.prototxt')
s.test_interval = 782
s.test_iter.append(313) #这里用的是append,码风不太一样
s.max_iter = 78200
s.base_lr = 0.001
s.momentum = 0.9
s.weight_decay = 5e-4
s.lr_policy = 'step'
s.stepsize=26067
s.gamma = 0.1
s.display = 782
s.snapshot = 7820
s.snapshot_prefix = 'shapshot'
s.type = “SGD”
s.solver_mode = caffe_pb2.SolverParameter.GPU
with open(solver_file, 'w') as f:
f.write(str(s))
并没有简单多少。
须要注意的是有些參数须要计算得到:
- test_interval:
假设我们有50000个训练样本。batch_size为64。即每批次处理64个样本,那么须要迭代50000/64=782次才处理完一次所有的样本。我们把处理完一次所有的样本,称之为一代,即epoch。所以。这里的test_interval设置为782,即处理完一次所有的训练数据后。才去进行測试。假设我们想训练100代。则须要设置max_iter为78200.
- test_iter:
同理,假设有10000个測试样本,batch_size设为32,那么须要迭代10000/32=313次才完整地測试完一次。所以设置test_iter为313. - lr_rate:
学习率变化规律我们设置为随着迭代次数的添加,慢慢变低。总共迭代78200次,我们将变化lr_rate三次。所以stepsize设置为78200/3=26067。即每迭代26067次,我们就减少一次学习率。
模型训练
完整依照定义的网络和solver去训练,就像命令行一样:
solver = caffe.SGDSolver('/home/xxx/solver.prototxt')
solver.solve()
只是也能够分得更细一些,比方先载入模型:
solver = caffe.get_solver('models/bvlc_reference_caffenet/solver.prototxt')
这里用的是.get_solver。默认依照SGD方法求解。
向前传播一次网络。即从输入层到loss层,计算net.blobs[k].data。
solver.net.forward() # train net
反向传播一次网络,即从loss层到输入层,计算net.blobs[k].diff and net.params[k][j].diff。
solver.net.backward()
假设须要一次完整的计算,正向、反向、更新权重(net.params[k][j].data)。能够使用
solver.step(1)
改变数字进行多次计算。
网络部署
部署即生成一个deploy文件,用于以下的模型測试。
这里既能够用python,也能够直接改动net文件。
from caffe import layers as L,params as P,to_proto
root='/home/xxx/'
deploy=root+'mnist/deploy.prototxt' #文件保存路径
def create_deploy():
#少了第一层。data层
conv1=L.Convolution(bottom='data', kernel_size=5, stride=1,num_output=20, pad=0,weight_filler=dict(type='xavier'))
pool1=L.Pooling(conv1, pool=P.Pooling.MAX, kernel_size=2, stride=2)
conv2=L.Convolution(pool1, kernel_size=5, stride=1,num_output=50, pad=0,weight_filler=dict(type='xavier'))
pool2=L.Pooling(conv2, pool=P.Pooling.MAX, kernel_size=2, stride=2)
fc3=L.InnerProduct(pool2, num_output=500,weight_filler=dict(type='xavier'))
relu3=L.ReLU(fc3, in_place=True)
fc4 = L.InnerProduct(relu3, num_output=10,weight_filler=dict(type='xavier'))
#最后没有accuracy层,但有一个Softmax层
prob=L.Softmax(fc4)
return to_proto(prob)
def write_deploy():
with open(deploy, 'w') as f:
f.write('name:"Lenet"\n')
f.write('input:"data"\n')
f.write('input_dim:1\n')
f.write('input_dim:3\n')
f.write('input_dim:28\n')
f.write('input_dim:28\n')
f.write(str(create_deploy()))
if __name__ == '__main__':
write_deploy()
假设自己改动net。须要改动数据输入:
layer {
name: "data"
type: "Input"
top: "data"
input_param { shape: { dim: 1 dim: 3 dim: 100 dim: 100 } }
}
而且添加一个softmax。对于原来的softmaxwithloss直接换掉即可。
网络測试
训练好之后得到模型。实际使用是须要用模型进行预測。
这时须要用到deploy文件和caffemodel。
#coding=utf-8
import caffe
import numpy as np
root='/home/xxx/' #根文件夹
deploy=root + 'mnist/deploy.prototxt' #deploy文件
caffe_model=root + 'mnist/lenet_iter_9380.caffemodel' #训练好的 caffemodel
img=root+'mnist/test/5/00008.png' #随机找的一张待測图片
labels_filename = root + 'mnist/test/labels.txt' #类别名称文件,将数字标签转换回类别名称
net = caffe.Net(deploy,caffe_model,caffe.TEST) #载入model和network
#图片预处理设置
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) #设定图片的shape格式(1,3,28,28)
transformer.set_transpose('data', (2,0,1)) #改变维度的顺序,由原始图片(28,28,3)变为(3,28,28)
#transformer.set_mean('data', np.load(mean_file).mean(1).mean(1)) #减去均值。前面训练模型时没有减均值,这儿就不用
transformer.set_raw_scale('data', 255) # 缩放到【0。255】之间
transformer.set_channel_swap('data', (2,1,0)) #交换通道,将图片由RGB变为BGR
im=caffe.io.load_image(img) #载入图片
net.blobs['data'].data[...] = transformer.preprocess('data',im) #运行上面设置的图片预处理操作,并将图片载入到blob中
#运行測试
out = net.forward()
labels = np.loadtxt(labels_filename, str, delimiter='\t') #读取类别名称文件
prob= net.blobs['Softmax1'].data[0].flatten() #取出最后一层(Softmax)属于某个类别的概率值,并打印
print prob
order=prob.argsort()[-1] #将概率值排序,取出最大值所在的序号
print 'the class is:',labels[order] #将该序号转换成相应的类别名称,并打印
总结
利用python接口,对网络的详细參数能够有更全面的认识和理解。只是也有几点须要注意:
- 数据格式的转换
caffe的数据blob shape是N*C*H*W。通道数在前。而python图像处理时shape是H*W*C。通道数在后。因此须要转换一下。
- 图片显示与保存
因为没有图形界面,非常方便的jupyter notebook不能使用,仅仅好保存图片查看。
caffe的python接口学习(2):生成solver文件
caffe的python接口学习(5):生成deploy文件
caffe的python接口学习(6):用训练好的模型(caffemodel)来分类新的图片
Deep learning tutorial on Caffe technology : basic commands, Python and C++ code.
Multilabel classification on PASCAL using python data-layers
Caffe-python interface 学习|网络训练、部署、測试的更多相关文章
- OpenGL学习脚印:深度測试(depth testing)
写在前面 上一节我们使用AssImp载入了3d模型,效果已经令人激动了.可是绘制效率和场景真实感还存在不足,接下来我们还是要保持耐心,继续学习一些高级主题,等学完后面的高级主题,我们再次来改进我们载入 ...
- Python 基于学习 网络小爬虫
<span style="font-size:18px;"># # 百度贴吧图片网络小爬虫 # import re import urllib def getHtml( ...
- caffe Python API 之Model训练
# 训练设置 # 使用GPU caffe.set_device(gpu_id) # 若不设置,默认为0 caffe.set_mode_gpu() # 使用CPU caffe.set_mode_cpu( ...
- Python入门学习:网络刷博器爬虫
1.比较有趣,可以不断刷新指定的网址 2.源码: #!/usr/bin/env python3 # -*- coding: utf-8 -*- import webbrowser as web imp ...
- [Python]threading local 线程局部变量小測试
概念 有个概念叫做线程局部变量.一般我们对多线程中的全局变量都会加锁处理,这样的变量是共享变量,每一个线程都能够读写变量,为了保持同步我们会做枷锁处理.可是有些变量初始化以后.我们仅仅想让他们在每一个 ...
- Selenium2 Python 自己主动化測试实战学习笔记(五)
7.1 自己主动化測试用例 无论是功能測试.性能測试和自己主动化測试时都须要编写測试用例,測试用例的好坏能准确的体现了測试人员的经验.能力以及对项目的深度理解. 7.1.1 手工測试用例与自己主动化測 ...
- 【目录】Python模块学习系列
目录:Python模块学习笔记 1.Python模块学习 - Paramiko - 主机管理 2.Python模块学习 - Fileinput - 读取文件 3.Python模块学习 - Confi ...
- Maven实现Web应用集成測试自己主动化 -- 部署自己主动化(WebTest Maven Plugin)
上篇:Maven实现Web应用集成測试自己主动化 -- 測试自己主动化(WebTest Maven Plugin) 之前介绍了怎样在maven中使用webtest插件实现web的集成測试,这里有个遗留 ...
- 移动App測试实战:顶级互联网企业软件測试和质量提升最佳实践
这篇是计算机类的优质预售推荐>>>><移动App測试实战:顶级互联网企业软件測试和质量提升最佳实践> 国内顶级互联网公司測试实战经验总结.阿里.腾讯.京东.携程.百 ...
随机推荐
- JS实现缓存运动
JS ...
- 【codeforces 505D】Mr. Kitayuta's Technology
[题目链接]:http://codeforces.com/problemset/problem/505/D [题意] 让你构造一张有向图; n个点; 以及所要求的m对联通关系(xi,yi) 即要求这张 ...
- shell脚本学习之ubuntu删除多余内核
#!/bin/bash #定期删除内核 #存储命令输出cmd_output=`commands` uname_output=$(uname -r) kernel_output=`dpkg --list ...
- spark源代码action系列-foreach与foreachPartition
RDD.foreachPartition/foreach的操作 在这个action的操作中: 这两个action主要用于对每一个partition中的iterator时行迭代的处理.通过用户传入的fu ...
- [Maven实战](5)Archetype生成项目骨架
Hello World项目中有一些Maven的约定:在项目根文件夹中放置pom.xml,在src/main/java文件夹下放置项目的主代码,在sc/test/java中放置项目的測试代码.之所以一步 ...
- Android BLE与终端通信(三)——client与服务端通信过程以及实现数据通信
Android BLE与终端通信(三)--client与服务端通信过程以及实现数据通信 前面的终究仅仅是小知识点.上不了台面,也仅仅能算是起到一个科普的作用.而同步到实际的开发上去,今天就来延续前两篇 ...
- Ansible@一个高效的配置管理工具--Ansible configure management--翻译(七)
如无书面授权,请勿转载 Larger Projects Until now, we have been looking at single plays in one playbook file. Th ...
- FireEye APT检测——APT业务占比过重,缺乏其他安全系统的查杀和修复功能
摘自:https://zhidao.baidu.com/question/1694626564301467468.html火眼,APT威胁下快速成长 FireEye的兴起开始于2012年,这时段正好迎 ...
- ORA-01658无法为表空间中的段创建INITIAL区
导出空表设置时,提示错误是: ORA-01658无法为表空间中的段创建INITIAL区 查找解决方案为 表空间已满 设置表空间自动增长 即可 例: alter database datafil ...
- rman备份工具简介
RMAN工具简介: 备份的文件: 数据文件 归档日志 控制文件(当前控制文件) spfile 自动管理备份相关元数据 文件名称 完成备份的scn 以数据块为单位,只备份使用过的数据块(物理层面判断是否 ...