python caffe 在师兄的代码上修改成自己风格的代码
首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试。我的代码是堆积成的,被师兄嘲笑说写脚本。好吧!我的代码只有我懂,哈哈! 希望以后代码能写得工整点。现在还是让我先懂。这里,我做了一个简单的任务:0,1,2三个数字的分类。准确率:0.9806666666666667
(部分)代码分为:
1 train_net.py
#import some module
import time
import os
import numpy as np
import sys
import cv2
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
#from prepare_data import DataConfig
#from data_config import DataConfig #configure GPU mode
''' uncommend below line to use gpu '''
caffe.set_mode_gpu() # about dataset
##dataset = Dataset('/home/wang/Downloads/object/extract/')
##dataset = dataset.Split('train')
##data_config = DataConfig(dataset)
##data_config.SetBatchSize(256)
data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/' #configure solve.prototxt
solver = caffe.SGDSolver('models/solver.prototxt') # load pretrain model
print('load pretrain model')
solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel') solver.net.layers[0].SetDataConfig(data_config) for i in range(1, 10000):
# Make one SGD update
solver.step(5)
if i % 100 == 0:
solver.net.save('tmp.caffemodel')
''' TODO: test code '''
2 test_net.py
#import setup
import time
import os
import random
import sys
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
import cv2
import numpy as np
import random from utils import PrepareImage
#from dataset import Dataset
from test_data import test_data_pre test_num_once=10 ''' uncommend below line to use gpu '''
# caffe.set_mode_gpu() # dataset
#dataset = Dataset('/home/wang/Downloads/object/extract/')
#dataset = dataset.Split('test') # load net
net = caffe.Net('models/deploy.prototxt', caffe.TEST) # load train model
print('load pretrain model')
net.copy_from('tmp.caffemodel') #test all samples one by one
data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
#(imgPaths, gt_label) = dataset[int(random.random()*num_obj)]
(imgPaths, gt_label)=test_data_pre(data_pre)
num_img = len(imgPaths)
correct_num=0
for idx in range(num_img):
img = cv2.imread(imgPaths[idx])
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
tmp_img = img.copy() # for display
img = PrepareImage(img, (227, 227))
net.blobs['data'].reshape(test_num_once, 3, 227, 227)
net.blobs['data'].data[...] = img
#net.blobs['data'].data[i,:,:,:] = img
net.forward()
score = net.blobs['cls_prob'].data
if score.argmax()==gt_label[idx]:
correct_num=correct_num+1
if idx%100==0:
print("Please wait some minutes...")
correct_rate=correct_num*1.0/num_img
print('The correct rate is :',correct_rate)
3 test_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def test_data_pre(path):
img_list=[]
image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2'))
label = np.zeros(image_num, dtype=np.float32) i=0
for idf in range(3):
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1)
for idi in range(len(tmp_path)):
img_path=path1+'/'+tmp_path[idi]
img_list.append(img_path)
label[i]=idf
i=i+1
return ( img_list,label)
4 pre_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def prepare_data(path,batchsize):
#tmp_path=os.listdir(path)
img_list=[]
label = np.zeros(batchsize, dtype=np.float32)
for i in range(batchsize):
#randomly select one file
idf=randint(0,2)
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1) #randomly select one image
idi=randint(0,len(tmp_path)-1)
#img = cv2.imread(imgPaths[idx])
img_path=path1+'/'+tmp_path[idi]
img=cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
flip = randint(0, 1)>0
if flip > 0:
img = img[:, ::-1, :] # flip left to right img=PrepareImage(img, (227,227))
img_list.append(img)
label[i]=idf
imgData = CatImage(img_list)
return (imgData,label)
5 utils.py
import os
import cv2
import numpy as np def PrepareImage(im, size):
im = cv2.resize(im, (size[0], size[1]))
im = im.transpose(2, 0, 1)
im = im.astype(np.float32, copy=False)
return im def CatImage(im_list):
max_shape = np.array([im.shape for im in im_list]).max(axis=0)
blob = np.zeros((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32)
# set to mean value
blob[:, 0, :, :] = 102.9801
blob[:, 1, :, :] = 115.9465
blob[:, 2, :, :] = 122.7717
for i, im in enumerate(im_list):
blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im
return blob
6 layer/data_layer.py
import caffe
import numpy as np #import data_config
#import prepare_data
from pre_data import prepare_data class DataLayer(caffe.Layer): def SetDataConfig(self, data_config):
self._data_config = data_config def GetDataConfig(self):
return self._data_config def setup(self, bottom, top):
# data blob
top[0].reshape(1, 3, 227, 227)
#top[0].reshape(1, 3, 34, 44)
# label type
top[1].reshape(1, 1) def reshape(self, bootom, top):
pass def forward(self, bottom, top):
#(imgs, label) = self._data_config.next()
path=self.GetDataConfig()
(imgs,label)=prepare_data(path,128)
(N, C, W, H) = imgs.shape
# image data
top[0].reshape(N, C, W, H)
top[0].data[...] = imgs
# object type label
top[1].reshape(N)
top[1].data[...] = label def backward(self, top, propagate_down, bottom):
pass
7 layer/__init__.py
import data_layer
还有一些caffe中经典的东西没放进来。
代码和数据:
python caffe 在师兄的代码上修改成自己风格的代码的更多相关文章
- 用Python给你的代码上个进度条吧 | 【代码也要面子的】
微信公众号:AI算法与图像处理如果你觉得对你有帮助,欢迎关注.转发以及点赞哦-( ̄▽ ̄-)~ 前言 最近在跑一些代码的时候,很烦...因为有时候不知道这段程序什么时候能执行完,现在执行哪里了,如果报错 ...
- Upsource——对已签入的代码进行分享、讨论和审查代码
Upsource 一.Upsource简介 Upsource ,这是一个专门为软件开发团队所设计的源代码协作工具.Upsource能够与多种版本控制工具进行集成,包括Git.Mercurial.Sub ...
- python之模块ftplib(实现ftp上传下载代码)
# -*- coding: utf-8 -*- #python 27 #xiaodeng #python之模块ftplib(实现ftp上传下载代码) #需求:实现ftp上传下载代码(不含错误处理) f ...
- 学习Git的一点心得以及如何把本地修改、删除的代码上传到github中
一:学习Github的资料如下:https://git.oschina.net/progit/ 这是一个学习Git的中文网站,如果诸位能够静下心来阅读,不要求阅读太多,只需要阅读前三章,就可以掌握Gi ...
- 基于Caffe的DeepID2实现(上)
小喵的唠叨话:小喵最近在做人脸识别的工作,打算将汤晓鸥前辈的DeepID,DeepID2等算法进行实验和复现.DeepID的方法最简单,而DeepID2的实现却略微复杂,并且互联网上也没有比较好的资源 ...
- 使用pycharm开发代码上传到GitLab和GitHub
使用pycharm开发代码上传到GitLab和GitHub 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 我这里主要是针对局域网的自减的GitLab服务器,python开发工程师如 ...
- python 全栈开发,Day86(上传文件,上传头像,CBV,python读写Excel,虚拟环境virtualenv)
一.上传文件 上传一个图片 使用input type="file",来上传一个文件.注意:form表单必须添加属性enctype="multipart/form-data ...
- 使用git工具将本地电脑上的代码上传至GitHub
本文教你如果使用git工具将本地电脑上的代码上传至GitHub 1.安装git工具 安装git链接 2.使用git工具上传自己的代码到GitHub中 安装完git工具之后,我们会得到两个命令行工具,一 ...
- Dynamics AX 2012 R2 窗体系列 - 在窗体上修改字段时所触发的方法及其顺序
在这个系列里,Reinhard将和大家一起探索在AX的窗体上执行操作时,都会触发窗体.窗体数据源和表上的哪些方法,并且是以怎样的顺序触发的. 这次,我们来看看在窗体上修改或录入数据的情 ...
随机推荐
- mysql数据库无法连接(JDBC)java.net.ConnectException: Connection timed out
数据库无法连接(JDBC) 用户名密码正确,但是一直报错:Connection timed out 后来知道了原因:我用的是BAE提供的云mysql数据库,对访问的IP有限制 ,所以在本机上无法连接. ...
- 关于Bonobo Git Server的安装
1.关于安装 参考官网:https://bonobogitserver.com/ 实际上就是在IIS上搭建一个MVC程序.安装教程:https://bonobogitserver.com/instal ...
- FASTQ format
FASTQ format 每个FASTQ文件中每个序列通常有四行信息: 1: 以 '@' 字符开头,后面紧接着的是序列标识符和可选字段的描述(类似FASTA title line). 2: 序列 3: ...
- c# iText 生成PDF 有文字,图片,表格,文字样式,对齐方式,页眉页脚,等等等,
#region 下载说明书PDF protected void lbtnDownPDF_Click(object sender, EventArgs e) { int pid = ConvertHel ...
- HDU5324 cqd分治
HDU5324 cqd分治 标签(空格分隔): 未分类 给你两个长度相同数列,求第一个不上升,第二个不下降的最长子序列长度. 这里要求的子序列对第一个和第二个来说是相同的.即如果你在第一个序列里选了第 ...
- 关于javascript以及jquery如何打开文件
其实很简单, <input type="file" id="file" mce_style="display:none"> 这个 ...
- vector的坑——C++primer练习6.33总结
说来惭愧,一道简单的对vector递归的题目写了一个多小时,最后还是请教了大神才改出来. 首先贴上原代码: void return_vector(vector<int>::iterator ...
- Swoft 快速上手小贴士
IDE一定要装注解插件PHP Annotations Request和Response里的with...开头的方法会clone $this, 而不是修改本实体, 所以设置Cookie之类的时候要$re ...
- 重新学习MySQL数据库8:MySQL的事务隔离级别实战
重新学习Mysql数据库8:MySQL的事务隔离级别实战 在Mysql中,事务主要有四种隔离级别,今天我们主要是通过示例来比较下,四种隔离级别实际在应用中,会出现什么样的对应现象. Read unco ...
- (2) iOS开发之UI处理-UILabel篇
我们经常要根据内容去动态计算控件的高度,比如一个UILabel控件,常常要显示多行内容,并且计算出总高度,如果每个UILabel要多行显示,都要写这么一段代码是非常痛苦的,看代码如下: 我想大 ...