caffe添加python数据层
在caffe中添加自定义层时,必须要实现这四个函数,在C++中是(LayerSetUp,Reshape,Forward_cpu,Backward_cpu),在python 中是(setup,reshape,forward_cpu,backword_cpu)。
prototxt
layer {
name: "data"
type: "Python"
top: "data"
top: "label"
include {
phase: TRAIN
}
python_param {
module: "src.data_layer.rank_layer_live" # 不能代目录形式
layer: "DataLayer"
param_str: " {\'pascal_root\': \'data\' ,\'split\': \'live_train\', \'im_shape\': [224, 224],\'batch_size\': 32}"
}
}
layer {
name: "data"
type: "Python"
top: "data"
top: "label"
include {
phase: TEST
}
python_param {
module: "src.data_layer.rank_layer_live"
layer: "DataLayer"
#batch_size: 160
param_str: " {\'pascal_root\': \'data\' ,\'split\': \'live_test\', \'im_shape\': [224, 224],\'batch_size\': 32}"
}
}
- 数据定义层:
import cv2
import sys
sys.path.append("/home/rjw/caffe/python")
import caffe
import numpy as np
import multiprocessing as mtp
import pdb
import os.path as osp ## 理解参考:https://blog.csdn.net/auto1993/article/details/78951849 class DataLayer(caffe.Layer): def setup(self, bottom, top): self._name_to_top_map = {}
self._name_to_top_map['data'] = 0
self._name_to_top_map['label'] = 1
# === Read input parameters ===
self.workers= mtp.Pool(10)
# params is a python dictionary with layer parameters.
params = eval(self.param_str) # Check the paramameters for validity.
check_params(params) # store input as class variables
self.batch_size = params['batch_size']
self.pascal_root = params['pascal_root']
self.im_shape = params['im_shape']
# get list of image indexes.
list_file = params['split'] + '.txt'
filename = [line.rstrip('\n') for line in open(
osp.join(self.pascal_root, list_file))]
self._roidb = []
self.scores =[]
for i in filename:
self._roidb.append(i.split()[0])
self.scores.append(float(i.split()[1]))
self._perm = None
self._cur = 0
self.num =0 top[0].reshape(
self.batch_size, 3, params['im_shape'][0], params['im_shape'][1]) top[1].reshape(self.batch_size, 1) def _get_next_minibatch_inds(self):
"""Return the roidb indices for the next minibatch."""
db_inds = []
dis = 4 # total number of distortions in live dataset
batch = 2 # number of images for each distortion level
level = 4 # distortion levels for each mini_batch = level * dis_mini*batch
#shuff = np.random.permutation(range(dis))
Num = len(self.scores)/dis/level
for k in range(dis):
for i in range(level):
temp = self.num
for j in range(batch):
db_inds.append(len(self.scores)/dis*k+i*Num+temp)
temp = temp +1
self.num = self.num+batch
if Num-self.num<batch:
self.num=0
db_inds = np.asarray(db_inds)
return db_inds def get_minibatch(self,minibatch_db):
"""Given a roidb, construct a minibatch sampled from it."""
# Get the input image blob, formatted for caffe jobs =self.workers.map(preprocess,minibatch_db)
#print len(jobs)
index = 0
images_train = np.zeros([self.batch_size,3,224,224],np.float32)
#pdb.set_trace()
for index_job in range(len(jobs)):
images_train[index,:,:,:] = jobs[index_job]
index += 1 blobs = {'data': images_train}
return blobs def forward(self, bottom, top):
"""Get blobs and copy them into this layer's top blob vector.""" db_inds = self._get_next_minibatch_inds()
minibatch_db = []
for i in range(len(db_inds)):
minibatch_db.append(self._roidb[int(db_inds[i])])
#minibatch_db = [self._roidb[i] for i in db_inds]
#print minibatch_db
scores = []
for i in range(len(db_inds)):
scores.append(self.scores[int(db_inds[i])])
blobs = self.get_minibatch(minibatch_db)
blobs ['label'] =np.asarray(scores)
for blob_name, blob in blobs.iteritems():
top_ind = self._name_to_top_map[blob_name]
# Reshape net's input blobs
top[top_ind].reshape(*(blob.shape))
# Copy data into net's input blobs
top[top_ind].data[...] = blob.astype(np.float32, copy=False) def backward(self, top, propagate_down, bottom):
"""This layer does not propagate gradients."""
pass def reshape(self, bottom, top):
"""Reshaping happens during the call to forward."""
pass def preprocess(data): sp = 224
im = np.asarray(cv2.imread(data))
x = im.shape[0]
y = im.shape[1]
x_p = np.random.randint(x-sp,size=1)[0]
y_p = np.random.randint(y-sp,size=1)[0]
#print x_p,y_p
images = im[x_p:x_p+sp,y_p:y_p+sp,:].transpose([2,0,1])
#print images.shape
return images def check_params(params):
"""
A utility function to check the parameters for the data layers.
"""
assert 'split' in params.keys(
), 'Params must include split (train, val, or test).' required = ['batch_size', 'pascal_root', 'im_shape']
for r in required:
assert r in params.keys(), 'Params must include {}'.format(r)
caffe添加python数据层的更多相关文章
- caffe 中 python 数据层
caffe中大多数层用C++写成. 但是对于自己数据的输入要写对应的输入层,比如你要去图像中的一部分,不能用LMDB,或者你的label 需要特殊的标记. 这时候就需要用python 写一个输入层. ...
- 在Caffe添加Python layer详细步骤
本文主要讨论的是在caffe中添加python layer的一般流程,自己设计的test_python_layer.py层只是起到演示作用,没有实际的功能. 1) Python layer 在caff ...
- 【撸码caffe 五】数据层搭建
caffe.cpp中的train函数内声明了一个类型为Solver类的智能指针solver: // Train / Finetune a model. int train() { -- shared_ ...
- [开源]OSharpNS 步步为营系列 - 2. 添加业务数据层
什么是OSharp OSharpNS全称OSharp Framework with .NetStandard2.0,是一个基于.NetStandard2.0开发的一个.NetCore快速开发框架.这个 ...
- caffe添加自己的层
首先修改src/caffe/proto/下的caffe.proto,修改好后需要编译 然后修改include/caffe/layers/logwxl_layer.hpp 然后修改src/caffe/l ...
- Caffe实现多标签输入,添加数据层(data layer)
因为之前遇到了sequence learning问题(CRNN),里面涉及到一张图对应多个标签.Caffe源码本身是不支持多类标签数据的输入的. 如果之前习惯调用脚本create_imagenet.s ...
- caffe添加自己编写的Python层
由于Python的灵活性,我们在caffe中添加自己定义的层时使用python层会更加方便,开发速速也会比C++更快,现在我就在这儿简单说一下如何在caffe中添加自定义的python层(使用的原网络 ...
- 【转】caffe数据层及参数
原文: 要运行caffe,需要先创建一个模型(model),如比较常用的Lenet,Alex等, 而一个模型由多个层(layer)构成,每一层又由许多参数组成.所有的参数都定义在caffe.proto ...
- 【转】Caffe初试(四)数据层及参数
要运行caffe,需要先创建一个模型(model),如比较常用的Lenet,Alex等,而一个模型由多个层(layer)构成,每一层又由许多参数组成.所有的参数都定义在caffe.proto这个文件中 ...
随机推荐
- JVM简介堆中新生代老年代浅析
一.JVM内存结构由程序计数器.堆.栈.本地方法栈.方法区等部分组成.1)程序计数器 几乎不占有内存.用于取下一条执行的指令.2)堆 所有通过new创建的对象的内存都在堆中分配,其大小可以通过-Xmx ...
- WatermarkMaker
using System; using System.Collections.Generic; using System.Linq; using System.Web; using System.Dr ...
- BZOJ.2007.[NOI2010]海拔(最小割 对偶图最短路)
题目链接 想一下能猜出,最优解中海拔只有0和1,且海拔相同的点都在且只在1个连通块中. 这就是个平面图最小割.也可以转必须转对偶图最短路,不然只能T到90分了..边的方向看着定就行. 不能忽略回去的边 ...
- 20172308《Java软件结构与数据结构》第四周学习总结
教材学习内容总结 第 6 章 列表 一. 列表集合 列表集合:一种概念性表示法,思想是使事物以线性列表的方式进行组织 特点: 列表集合没有内在的容量大小,它可以随着需要而增大 列表集合更具一般化,可以 ...
- 喵哈哈村的魔法考试 Round #7 (Div.2) 题解
喵哈哈村的魔法考试 Round #7 (Div.2) 注意!后四道题来自于周日的hihocoder offer收割赛第九场. 我建了个群:欢迎加入qscoj交流群,群号码:540667432 大概作为 ...
- webdings 和 wingdings 字体
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...
- LPC43xx OTP
- delphi TOnFormVisibleChangeEvent 事件应用
TGQIFileMgrForm = class(TForm) 定义 property OnVisibleChange: TOnFormVisibleChangeEvent read FOnVisibl ...
- 对数据集“dsArea”执行查询失败。 (rsErrorExecutingCommand),Query execution failed for dataset 'dsArea'. (rsErrorExecutingCommand),Manually process the TFS data warehouse and analysis services cube
错误提示: 处理报表时出错. (rsProcessingAborted)对数据集“dsArea”执行查询失败. (rsErrorExecutingCommand)Team System 多维数据集或者 ...
- In-Place upgrade to Team Foundation Server (TFS) 2015 from TFS 2013Team Foundation Server TFS TFS 2015 TFS upgrade TFS with Sharepoint
This upgrade document gives detailed step by step procedure for the In-Place upgrade from TFS 2013 t ...