python caffe 在师兄的代码上修改成自己风格的代码
首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试。我的代码是堆积成的,被师兄嘲笑说写脚本。好吧!我的代码只有我懂,哈哈! 希望以后代码能写得工整点。现在还是让我先懂。这里,我做了一个简单的任务:0,1,2三个数字的分类。准确率:0.9806666666666667
(部分)代码分为:
1 train_net.py
#import some module
import time
import os
import numpy as np
import sys
import cv2
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
#from prepare_data import DataConfig
#from data_config import DataConfig #configure GPU mode
''' uncommend below line to use gpu '''
caffe.set_mode_gpu() # about dataset
##dataset = Dataset('/home/wang/Downloads/object/extract/')
##dataset = dataset.Split('train')
##data_config = DataConfig(dataset)
##data_config.SetBatchSize(256)
data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/' #configure solve.prototxt
solver = caffe.SGDSolver('models/solver.prototxt') # load pretrain model
print('load pretrain model')
solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel') solver.net.layers[0].SetDataConfig(data_config) for i in range(1, 10000):
# Make one SGD update
solver.step(5)
if i % 100 == 0:
solver.net.save('tmp.caffemodel')
''' TODO: test code '''
2 test_net.py
#import setup
import time
import os
import random
import sys
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
import cv2
import numpy as np
import random from utils import PrepareImage
#from dataset import Dataset
from test_data import test_data_pre test_num_once=10 ''' uncommend below line to use gpu '''
# caffe.set_mode_gpu() # dataset
#dataset = Dataset('/home/wang/Downloads/object/extract/')
#dataset = dataset.Split('test') # load net
net = caffe.Net('models/deploy.prototxt', caffe.TEST) # load train model
print('load pretrain model')
net.copy_from('tmp.caffemodel') #test all samples one by one
data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
#(imgPaths, gt_label) = dataset[int(random.random()*num_obj)]
(imgPaths, gt_label)=test_data_pre(data_pre)
num_img = len(imgPaths)
correct_num=0
for idx in range(num_img):
img = cv2.imread(imgPaths[idx])
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
tmp_img = img.copy() # for display
img = PrepareImage(img, (227, 227))
net.blobs['data'].reshape(test_num_once, 3, 227, 227)
net.blobs['data'].data[...] = img
#net.blobs['data'].data[i,:,:,:] = img
net.forward()
score = net.blobs['cls_prob'].data
if score.argmax()==gt_label[idx]:
correct_num=correct_num+1
if idx%100==0:
print("Please wait some minutes...")
correct_rate=correct_num*1.0/num_img
print('The correct rate is :',correct_rate)
3 test_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def test_data_pre(path):
img_list=[]
image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2'))
label = np.zeros(image_num, dtype=np.float32) i=0
for idf in range(3):
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1)
for idi in range(len(tmp_path)):
img_path=path1+'/'+tmp_path[idi]
img_list.append(img_path)
label[i]=idf
i=i+1
return ( img_list,label)
4 pre_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def prepare_data(path,batchsize):
#tmp_path=os.listdir(path)
img_list=[]
label = np.zeros(batchsize, dtype=np.float32)
for i in range(batchsize):
#randomly select one file
idf=randint(0,2)
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1) #randomly select one image
idi=randint(0,len(tmp_path)-1)
#img = cv2.imread(imgPaths[idx])
img_path=path1+'/'+tmp_path[idi]
img=cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
flip = randint(0, 1)>0
if flip > 0:
img = img[:, ::-1, :] # flip left to right img=PrepareImage(img, (227,227))
img_list.append(img)
label[i]=idf
imgData = CatImage(img_list)
return (imgData,label)
5 utils.py
import os
import cv2
import numpy as np def PrepareImage(im, size):
im = cv2.resize(im, (size[0], size[1]))
im = im.transpose(2, 0, 1)
im = im.astype(np.float32, copy=False)
return im def CatImage(im_list):
max_shape = np.array([im.shape for im in im_list]).max(axis=0)
blob = np.zeros((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32)
# set to mean value
blob[:, 0, :, :] = 102.9801
blob[:, 1, :, :] = 115.9465
blob[:, 2, :, :] = 122.7717
for i, im in enumerate(im_list):
blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im
return blob
6 layer/data_layer.py
import caffe
import numpy as np #import data_config
#import prepare_data
from pre_data import prepare_data class DataLayer(caffe.Layer): def SetDataConfig(self, data_config):
self._data_config = data_config def GetDataConfig(self):
return self._data_config def setup(self, bottom, top):
# data blob
top[0].reshape(1, 3, 227, 227)
#top[0].reshape(1, 3, 34, 44)
# label type
top[1].reshape(1, 1) def reshape(self, bootom, top):
pass def forward(self, bottom, top):
#(imgs, label) = self._data_config.next()
path=self.GetDataConfig()
(imgs,label)=prepare_data(path,128)
(N, C, W, H) = imgs.shape
# image data
top[0].reshape(N, C, W, H)
top[0].data[...] = imgs
# object type label
top[1].reshape(N)
top[1].data[...] = label def backward(self, top, propagate_down, bottom):
pass
7 layer/__init__.py
import data_layer
还有一些caffe中经典的东西没放进来。
代码和数据:
python caffe 在师兄的代码上修改成自己风格的代码的更多相关文章
- 用Python给你的代码上个进度条吧 | 【代码也要面子的】
微信公众号:AI算法与图像处理如果你觉得对你有帮助,欢迎关注.转发以及点赞哦-( ̄▽ ̄-)~ 前言 最近在跑一些代码的时候,很烦...因为有时候不知道这段程序什么时候能执行完,现在执行哪里了,如果报错 ...
- Upsource——对已签入的代码进行分享、讨论和审查代码
Upsource 一.Upsource简介 Upsource ,这是一个专门为软件开发团队所设计的源代码协作工具.Upsource能够与多种版本控制工具进行集成,包括Git.Mercurial.Sub ...
- python之模块ftplib(实现ftp上传下载代码)
# -*- coding: utf-8 -*- #python 27 #xiaodeng #python之模块ftplib(实现ftp上传下载代码) #需求:实现ftp上传下载代码(不含错误处理) f ...
- 学习Git的一点心得以及如何把本地修改、删除的代码上传到github中
一:学习Github的资料如下:https://git.oschina.net/progit/ 这是一个学习Git的中文网站,如果诸位能够静下心来阅读,不要求阅读太多,只需要阅读前三章,就可以掌握Gi ...
- 基于Caffe的DeepID2实现(上)
小喵的唠叨话:小喵最近在做人脸识别的工作,打算将汤晓鸥前辈的DeepID,DeepID2等算法进行实验和复现.DeepID的方法最简单,而DeepID2的实现却略微复杂,并且互联网上也没有比较好的资源 ...
- 使用pycharm开发代码上传到GitLab和GitHub
使用pycharm开发代码上传到GitLab和GitHub 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 我这里主要是针对局域网的自减的GitLab服务器,python开发工程师如 ...
- python 全栈开发,Day86(上传文件,上传头像,CBV,python读写Excel,虚拟环境virtualenv)
一.上传文件 上传一个图片 使用input type="file",来上传一个文件.注意:form表单必须添加属性enctype="multipart/form-data ...
- 使用git工具将本地电脑上的代码上传至GitHub
本文教你如果使用git工具将本地电脑上的代码上传至GitHub 1.安装git工具 安装git链接 2.使用git工具上传自己的代码到GitHub中 安装完git工具之后,我们会得到两个命令行工具,一 ...
- Dynamics AX 2012 R2 窗体系列 - 在窗体上修改字段时所触发的方法及其顺序
在这个系列里,Reinhard将和大家一起探索在AX的窗体上执行操作时,都会触发窗体.窗体数据源和表上的哪些方法,并且是以怎样的顺序触发的. 这次,我们来看看在窗体上修改或录入数据的情 ...
随机推荐
- Spring Boot CRUD+分页(基于JPA规范)
步骤一:JPA概念 JPA(Java Persistence API)是Sun官方提出的Java持久化规范,用来方便大家操作数据库. 真正干活的可能是Hibernate,TopLink等等实现了JPA ...
- Maven 一段时间知识小结
二种打包命令生成后的jar包比较 1.clean install -P dev 2.clean package -Dmaven.test.skip=true -P dev //clean packa ...
- Java并发之synchronized深入
一句话总结synchronized: JVM会自动通过使用monitor来加锁和解锁,保证了同时只有一个线程可以执行指定代码,从而保证了线程安全,同时具有可重入和不可中断的性质. 一.synchron ...
- 【平台中间件】Nginx安装配置,实现版本更新不影响服务访问
为什么要做负载均衡? 当你网站是一个企业站.个人博客的时候,或者访问量比较小的时候,一台服务器完全应付的了,那就完全没必要做负载均衡.但是,如果你的网站是平台级别,用户达到十万百万级别了,一台服务器明 ...
- windows系统下,安装多个版本的jdk,java -version
问题描述: 开始安装了 jdk8 后来装了jdk9,可以为项目配置不同的jdk,相安无事: 今天发现软件需要jdk8的环境,结果我的java -version始终是jdk9.0.1: 解决办法:使ja ...
- 安装 python 数据分析插件 pandas
一上午试验了各种方法,发现利用pycharm是最快的.可以抛弃版本,命令和兼容问题的烦恼.纯粹傻瓜式 方法是 pycharm, 直接在settings里面,搜索pandas,添加即可,他会把所有之前需 ...
- HDU - 5988The 2016 ACM-ICPC Asia Qingdao Regional ContestG - Coding Contest 最小费用流
很巧妙的建边方式 题意:有n个区域,每个区域有一些人数si和食物bi,区域之间有m条定向路径,每条路径有人数通过上限ci.路径之间铺了电线,每当有人通过路径时有pi的概率会触碰到电线,但是第一个通过的 ...
- 在Intellij Idea中使用Maven创建Spring&SpringMVC项目
环境及版本 Jetbrains Intellij Idea 15.0.6 Spring 4.1.6 JDK 1.8.0_20 Tomcat 8 Windows 10 从 Maven archetype ...
- 三十二 Python分布式爬虫打造搜索引擎Scrapy精讲—scrapy的暂停与重启
scrapy的每一个爬虫,暂停时可以记录暂停状态以及爬取了哪些url,重启时可以从暂停状态开始爬取过的URL不在爬取 实现暂停与重启记录状态 1.首先cd进入到scrapy项目里 2.在scrapy项 ...
- Linux命令 ls -l 输出内容含义详解
Linux命令 ls -l s输出内容含义详解 1. ls 只显示文件名或者文件目录 2. ls -l(这个参数是字母L的小写,不是数字1) 用来查看详细的文件资料 在某个目录下键入ls -l可 ...