caffe添加python数据层
在caffe中添加自定义层时,必须要实现这四个函数,在C++中是(LayerSetUp,Reshape,Forward_cpu,Backward_cpu),在python 中是(setup,reshape,forward_cpu,backword_cpu)。
prototxt
layer {
name: "data"
type: "Python"
top: "data"
top: "label"
include {
phase: TRAIN
}
python_param {
module: "src.data_layer.rank_layer_live" # 不能代目录形式
layer: "DataLayer"
param_str: " {\'pascal_root\': \'data\' ,\'split\': \'live_train\', \'im_shape\': [224, 224],\'batch_size\': 32}"
}
}
layer {
name: "data"
type: "Python"
top: "data"
top: "label"
include {
phase: TEST
}
python_param {
module: "src.data_layer.rank_layer_live"
layer: "DataLayer"
#batch_size: 160
param_str: " {\'pascal_root\': \'data\' ,\'split\': \'live_test\', \'im_shape\': [224, 224],\'batch_size\': 32}"
}
}
- 数据定义层:
import cv2
import sys
sys.path.append("/home/rjw/caffe/python")
import caffe
import numpy as np
import multiprocessing as mtp
import pdb
import os.path as osp ## 理解参考:https://blog.csdn.net/auto1993/article/details/78951849 class DataLayer(caffe.Layer): def setup(self, bottom, top): self._name_to_top_map = {}
self._name_to_top_map['data'] = 0
self._name_to_top_map['label'] = 1
# === Read input parameters ===
self.workers= mtp.Pool(10)
# params is a python dictionary with layer parameters.
params = eval(self.param_str) # Check the paramameters for validity.
check_params(params) # store input as class variables
self.batch_size = params['batch_size']
self.pascal_root = params['pascal_root']
self.im_shape = params['im_shape']
# get list of image indexes.
list_file = params['split'] + '.txt'
filename = [line.rstrip('\n') for line in open(
osp.join(self.pascal_root, list_file))]
self._roidb = []
self.scores =[]
for i in filename:
self._roidb.append(i.split()[0])
self.scores.append(float(i.split()[1]))
self._perm = None
self._cur = 0
self.num =0 top[0].reshape(
self.batch_size, 3, params['im_shape'][0], params['im_shape'][1]) top[1].reshape(self.batch_size, 1) def _get_next_minibatch_inds(self):
"""Return the roidb indices for the next minibatch."""
db_inds = []
dis = 4 # total number of distortions in live dataset
batch = 2 # number of images for each distortion level
level = 4 # distortion levels for each mini_batch = level * dis_mini*batch
#shuff = np.random.permutation(range(dis))
Num = len(self.scores)/dis/level
for k in range(dis):
for i in range(level):
temp = self.num
for j in range(batch):
db_inds.append(len(self.scores)/dis*k+i*Num+temp)
temp = temp +1
self.num = self.num+batch
if Num-self.num<batch:
self.num=0
db_inds = np.asarray(db_inds)
return db_inds def get_minibatch(self,minibatch_db):
"""Given a roidb, construct a minibatch sampled from it."""
# Get the input image blob, formatted for caffe jobs =self.workers.map(preprocess,minibatch_db)
#print len(jobs)
index = 0
images_train = np.zeros([self.batch_size,3,224,224],np.float32)
#pdb.set_trace()
for index_job in range(len(jobs)):
images_train[index,:,:,:] = jobs[index_job]
index += 1 blobs = {'data': images_train}
return blobs def forward(self, bottom, top):
"""Get blobs and copy them into this layer's top blob vector.""" db_inds = self._get_next_minibatch_inds()
minibatch_db = []
for i in range(len(db_inds)):
minibatch_db.append(self._roidb[int(db_inds[i])])
#minibatch_db = [self._roidb[i] for i in db_inds]
#print minibatch_db
scores = []
for i in range(len(db_inds)):
scores.append(self.scores[int(db_inds[i])])
blobs = self.get_minibatch(minibatch_db)
blobs ['label'] =np.asarray(scores)
for blob_name, blob in blobs.iteritems():
top_ind = self._name_to_top_map[blob_name]
# Reshape net's input blobs
top[top_ind].reshape(*(blob.shape))
# Copy data into net's input blobs
top[top_ind].data[...] = blob.astype(np.float32, copy=False) def backward(self, top, propagate_down, bottom):
"""This layer does not propagate gradients."""
pass def reshape(self, bottom, top):
"""Reshaping happens during the call to forward."""
pass def preprocess(data): sp = 224
im = np.asarray(cv2.imread(data))
x = im.shape[0]
y = im.shape[1]
x_p = np.random.randint(x-sp,size=1)[0]
y_p = np.random.randint(y-sp,size=1)[0]
#print x_p,y_p
images = im[x_p:x_p+sp,y_p:y_p+sp,:].transpose([2,0,1])
#print images.shape
return images def check_params(params):
"""
A utility function to check the parameters for the data layers.
"""
assert 'split' in params.keys(
), 'Params must include split (train, val, or test).' required = ['batch_size', 'pascal_root', 'im_shape']
for r in required:
assert r in params.keys(), 'Params must include {}'.format(r)
caffe添加python数据层的更多相关文章
- caffe 中 python 数据层
caffe中大多数层用C++写成. 但是对于自己数据的输入要写对应的输入层,比如你要去图像中的一部分,不能用LMDB,或者你的label 需要特殊的标记. 这时候就需要用python 写一个输入层. ...
- 在Caffe添加Python layer详细步骤
本文主要讨论的是在caffe中添加python layer的一般流程,自己设计的test_python_layer.py层只是起到演示作用,没有实际的功能. 1) Python layer 在caff ...
- 【撸码caffe 五】数据层搭建
caffe.cpp中的train函数内声明了一个类型为Solver类的智能指针solver: // Train / Finetune a model. int train() { -- shared_ ...
- [开源]OSharpNS 步步为营系列 - 2. 添加业务数据层
什么是OSharp OSharpNS全称OSharp Framework with .NetStandard2.0,是一个基于.NetStandard2.0开发的一个.NetCore快速开发框架.这个 ...
- caffe添加自己的层
首先修改src/caffe/proto/下的caffe.proto,修改好后需要编译 然后修改include/caffe/layers/logwxl_layer.hpp 然后修改src/caffe/l ...
- Caffe实现多标签输入,添加数据层(data layer)
因为之前遇到了sequence learning问题(CRNN),里面涉及到一张图对应多个标签.Caffe源码本身是不支持多类标签数据的输入的. 如果之前习惯调用脚本create_imagenet.s ...
- caffe添加自己编写的Python层
由于Python的灵活性,我们在caffe中添加自己定义的层时使用python层会更加方便,开发速速也会比C++更快,现在我就在这儿简单说一下如何在caffe中添加自定义的python层(使用的原网络 ...
- 【转】caffe数据层及参数
原文: 要运行caffe,需要先创建一个模型(model),如比较常用的Lenet,Alex等, 而一个模型由多个层(layer)构成,每一层又由许多参数组成.所有的参数都定义在caffe.proto ...
- 【转】Caffe初试(四)数据层及参数
要运行caffe,需要先创建一个模型(model),如比较常用的Lenet,Alex等,而一个模型由多个层(layer)构成,每一层又由许多参数组成.所有的参数都定义在caffe.proto这个文件中 ...
随机推荐
- Orleans介绍
一.介绍 Orleans是一个框架,提供了一个直接的方法来构建分布式高规模计算应用程序 默认可扩展 -> Orleans处理构建分布式系统的复杂性,使您的应用程序能够扩展到数百台服务器.低延迟 ...
- Android activity之间数据传递和共享的方式之Application
1.基于消息的通信机制 Intent ---bundle ,extra 数据类型有限,比如遇到不可序列化的数据Bitmap,InputStream,或者LinkedList链表等等数据类型就不太好用 ...
- [USACO07JAN]Balanced Lineup
OJ题号:洛谷2880 思路1: 线段树维护区间最大最小值. #include<cstdio> #include<cctype> #include<utility> ...
- ZOJ 2975 Kinds of Fuwas
K - Kinds of Fuwas Time Limit:2000MS Memory Limit:65536KB 64bit IO Format:%lld & %llu De ...
- WAP2.0(XHTML MP)基础介绍
(一)XHTML MP 介绍XHTML MP(eXtensible HyperText Markup Language Mobile Profile)WAP2.0与WCSS(WAP CSS /WAP ...
- mysql故障
1.服务器上是的电不要随边乱断,一定要保存,然后断电,不要在服务器插座版上乱插其他电器,导致非法断电, 2.出现断电后,检查MYSQL数据库文件是否损坏,可以看WINDOWS 应用程序程序管理日志,提 ...
- C#编程(四十四)----------string和stringbuilder
System.String类 首先string类是静态的,System.String是最常用的字符串操作类,可以帮助开发者完成绝大部分的字符串操作功能,使用方便. 1.比较字符串 比较字符串是指按照字 ...
- Oracle 导入导出 dmp 文件
导入dmp文件,需要知道这个dmp文件创建的用户.因此需要先创建用户,并授权给它. (1)用户的创建 首先,以system用户登录Oracle SQL Developer 其次,在sql工作表(可以用 ...
- 【python】python安装步骤
1.官网下载python 官网地址:https://www.python.org/getit/ 2.下载完成后点击安装 勾选Add python to PATH 是可以自己去配置环境变量的 注意:这里 ...
- Python:日期和时间类型学习
背景 在非开发环境经常需要做一下日期计算,就准备使用Python,顺便记下来学习的痕迹. 代码 1 # coding = utf-8 2 3 from datetime import * 4 5 ## ...