caffe添加python数据层
在caffe中添加自定义层时,必须要实现这四个函数,在C++中是(LayerSetUp,Reshape,Forward_cpu,Backward_cpu),在python 中是(setup,reshape,forward_cpu,backword_cpu)。
prototxt
layer {
name: "data"
type: "Python"
top: "data"
top: "label"
include {
phase: TRAIN
}
python_param {
module: "src.data_layer.rank_layer_live" # 不能代目录形式
layer: "DataLayer"
param_str: " {\'pascal_root\': \'data\' ,\'split\': \'live_train\', \'im_shape\': [224, 224],\'batch_size\': 32}"
}
}
layer {
name: "data"
type: "Python"
top: "data"
top: "label"
include {
phase: TEST
}
python_param {
module: "src.data_layer.rank_layer_live"
layer: "DataLayer"
#batch_size: 160
param_str: " {\'pascal_root\': \'data\' ,\'split\': \'live_test\', \'im_shape\': [224, 224],\'batch_size\': 32}"
}
}
- 数据定义层:
import cv2
import sys
sys.path.append("/home/rjw/caffe/python")
import caffe
import numpy as np
import multiprocessing as mtp
import pdb
import os.path as osp ## 理解参考:https://blog.csdn.net/auto1993/article/details/78951849 class DataLayer(caffe.Layer): def setup(self, bottom, top): self._name_to_top_map = {}
self._name_to_top_map['data'] = 0
self._name_to_top_map['label'] = 1
# === Read input parameters ===
self.workers= mtp.Pool(10)
# params is a python dictionary with layer parameters.
params = eval(self.param_str) # Check the paramameters for validity.
check_params(params) # store input as class variables
self.batch_size = params['batch_size']
self.pascal_root = params['pascal_root']
self.im_shape = params['im_shape']
# get list of image indexes.
list_file = params['split'] + '.txt'
filename = [line.rstrip('\n') for line in open(
osp.join(self.pascal_root, list_file))]
self._roidb = []
self.scores =[]
for i in filename:
self._roidb.append(i.split()[0])
self.scores.append(float(i.split()[1]))
self._perm = None
self._cur = 0
self.num =0 top[0].reshape(
self.batch_size, 3, params['im_shape'][0], params['im_shape'][1]) top[1].reshape(self.batch_size, 1) def _get_next_minibatch_inds(self):
"""Return the roidb indices for the next minibatch."""
db_inds = []
dis = 4 # total number of distortions in live dataset
batch = 2 # number of images for each distortion level
level = 4 # distortion levels for each mini_batch = level * dis_mini*batch
#shuff = np.random.permutation(range(dis))
Num = len(self.scores)/dis/level
for k in range(dis):
for i in range(level):
temp = self.num
for j in range(batch):
db_inds.append(len(self.scores)/dis*k+i*Num+temp)
temp = temp +1
self.num = self.num+batch
if Num-self.num<batch:
self.num=0
db_inds = np.asarray(db_inds)
return db_inds def get_minibatch(self,minibatch_db):
"""Given a roidb, construct a minibatch sampled from it."""
# Get the input image blob, formatted for caffe jobs =self.workers.map(preprocess,minibatch_db)
#print len(jobs)
index = 0
images_train = np.zeros([self.batch_size,3,224,224],np.float32)
#pdb.set_trace()
for index_job in range(len(jobs)):
images_train[index,:,:,:] = jobs[index_job]
index += 1 blobs = {'data': images_train}
return blobs def forward(self, bottom, top):
"""Get blobs and copy them into this layer's top blob vector.""" db_inds = self._get_next_minibatch_inds()
minibatch_db = []
for i in range(len(db_inds)):
minibatch_db.append(self._roidb[int(db_inds[i])])
#minibatch_db = [self._roidb[i] for i in db_inds]
#print minibatch_db
scores = []
for i in range(len(db_inds)):
scores.append(self.scores[int(db_inds[i])])
blobs = self.get_minibatch(minibatch_db)
blobs ['label'] =np.asarray(scores)
for blob_name, blob in blobs.iteritems():
top_ind = self._name_to_top_map[blob_name]
# Reshape net's input blobs
top[top_ind].reshape(*(blob.shape))
# Copy data into net's input blobs
top[top_ind].data[...] = blob.astype(np.float32, copy=False) def backward(self, top, propagate_down, bottom):
"""This layer does not propagate gradients."""
pass def reshape(self, bottom, top):
"""Reshaping happens during the call to forward."""
pass def preprocess(data): sp = 224
im = np.asarray(cv2.imread(data))
x = im.shape[0]
y = im.shape[1]
x_p = np.random.randint(x-sp,size=1)[0]
y_p = np.random.randint(y-sp,size=1)[0]
#print x_p,y_p
images = im[x_p:x_p+sp,y_p:y_p+sp,:].transpose([2,0,1])
#print images.shape
return images def check_params(params):
"""
A utility function to check the parameters for the data layers.
"""
assert 'split' in params.keys(
), 'Params must include split (train, val, or test).' required = ['batch_size', 'pascal_root', 'im_shape']
for r in required:
assert r in params.keys(), 'Params must include {}'.format(r)
caffe添加python数据层的更多相关文章
- caffe 中 python 数据层
caffe中大多数层用C++写成. 但是对于自己数据的输入要写对应的输入层,比如你要去图像中的一部分,不能用LMDB,或者你的label 需要特殊的标记. 这时候就需要用python 写一个输入层. ...
- 在Caffe添加Python layer详细步骤
本文主要讨论的是在caffe中添加python layer的一般流程,自己设计的test_python_layer.py层只是起到演示作用,没有实际的功能. 1) Python layer 在caff ...
- 【撸码caffe 五】数据层搭建
caffe.cpp中的train函数内声明了一个类型为Solver类的智能指针solver: // Train / Finetune a model. int train() { -- shared_ ...
- [开源]OSharpNS 步步为营系列 - 2. 添加业务数据层
什么是OSharp OSharpNS全称OSharp Framework with .NetStandard2.0,是一个基于.NetStandard2.0开发的一个.NetCore快速开发框架.这个 ...
- caffe添加自己的层
首先修改src/caffe/proto/下的caffe.proto,修改好后需要编译 然后修改include/caffe/layers/logwxl_layer.hpp 然后修改src/caffe/l ...
- Caffe实现多标签输入,添加数据层(data layer)
因为之前遇到了sequence learning问题(CRNN),里面涉及到一张图对应多个标签.Caffe源码本身是不支持多类标签数据的输入的. 如果之前习惯调用脚本create_imagenet.s ...
- caffe添加自己编写的Python层
由于Python的灵活性,我们在caffe中添加自己定义的层时使用python层会更加方便,开发速速也会比C++更快,现在我就在这儿简单说一下如何在caffe中添加自定义的python层(使用的原网络 ...
- 【转】caffe数据层及参数
原文: 要运行caffe,需要先创建一个模型(model),如比较常用的Lenet,Alex等, 而一个模型由多个层(layer)构成,每一层又由许多参数组成.所有的参数都定义在caffe.proto ...
- 【转】Caffe初试(四)数据层及参数
要运行caffe,需要先创建一个模型(model),如比较常用的Lenet,Alex等,而一个模型由多个层(layer)构成,每一层又由许多参数组成.所有的参数都定义在caffe.proto这个文件中 ...
随机推荐
- [代码审计]云ec电商系统代码审计
0x00 前言 看了一下博客内最新的文章,竟然是3月28号的,一个多月没写文章了,博客都长草了. 主要是临近毕业,事情繁多,也没有啥时间和心情静下来写.. 不过现在的话,毕业的东西告一段落了,几乎没啥 ...
- BZOJ.1023.[SHOI2008]cactus仙人掌图(DP)
题目链接 类似求树的直径,可以用(类似)树形DP求每个点其子树(在仙人掌上就是诱导子图)最长链.次长链,用每个点子节点不同子树的 max{最长链}+max{次长链} 更新答案.(不需要存次长链,求解过 ...
- git 撤销本地修改
git checkout file 例如:git checkout app/views/carts/_index_m.html.erb 可以先用 git status 查看差异 然后 git chec ...
- 注册Goole 账户 成功注册
注册谷歌邮箱等Google帐号时提示:此电话号码无法用于进行验证怎么办? 相信很多网友在注册Google帐号的时候需要手机号码验证,比如在遇到过注册Google Gmail到最后一步“验证您的帐号” ...
- Cocos2d-x 3.0游戏开发之虚拟机IOS环境:匹配才是好,莫要随便升级软件
尊重开发人员的劳动成果.转载的时候请务必注明出处:http://blog.csdn.net/haomengzhu/article/details/34110449 做为一个买不起MAC的Coder,仅 ...
- BTrace housemd TProfiler
http://blog.csdn.net/y461517142/article/details/26269529 http://calvin1978.blogcn.com/articles/btrac ...
- AI 实验--v_JULY_v
http://blog.csdn.net/v_JULY_v http://www.julyedu.com/
- WCF消息传递
通过了解了WCF的一些基本概念并创建和编写WCF应用中的相应方法,实现了WCF服务和客户端之间的调用,就能够理解WCF应用是如何进行通信的.了解了一些基本的WCF概念后,还需要深入了解WCF消息的概念 ...
- android:activity活动的生命周期
掌握活动的生命周期对任何 Android 开发者来说都非常重要,当你深入理解活动的生命 周期之后,就可以写出更加连贯流畅的程序,并在如何合理管理应用资源方面,你会发挥的 游刃有余.你的应用程序将会拥有 ...
- Flume 1.5.0简单部署试用
================================================================================ 一.Flume简介 ========= ...