首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试。我的代码是堆积成的,被师兄嘲笑说写脚本。好吧!我的代码只有我懂,哈哈! 希望以后代码能写得工整点。现在还是让我先懂。这里,我做了一个简单的任务:0,1,2三个数字的分类。准确率:0.9806666666666667

(部分)代码分为:

1 train_net.py

 #import some module
import time
import os
import numpy as np
import sys
import cv2
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
#from prepare_data import DataConfig
#from data_config import DataConfig #configure GPU mode
''' uncommend below line to use gpu '''
caffe.set_mode_gpu() # about dataset
##dataset = Dataset('/home/wang/Downloads/object/extract/')
##dataset = dataset.Split('train')
##data_config = DataConfig(dataset)
##data_config.SetBatchSize(256)
data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/' #configure solve.prototxt
solver = caffe.SGDSolver('models/solver.prototxt') # load pretrain model
print('load pretrain model')
solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel') solver.net.layers[0].SetDataConfig(data_config) for i in range(1, 10000):
# Make one SGD update
solver.step(5)
if i % 100 == 0:
solver.net.save('tmp.caffemodel')
''' TODO: test code '''

2 test_net.py

 #import setup
import time
import os
import random
import sys
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
import cv2
import numpy as np
import random from utils import PrepareImage
#from dataset import Dataset
from test_data import test_data_pre test_num_once=10 ''' uncommend below line to use gpu '''
# caffe.set_mode_gpu() # dataset
#dataset = Dataset('/home/wang/Downloads/object/extract/')
#dataset = dataset.Split('test') # load net
net = caffe.Net('models/deploy.prototxt', caffe.TEST) # load train model
print('load pretrain model')
net.copy_from('tmp.caffemodel') #test all samples one by one
data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
#(imgPaths, gt_label) = dataset[int(random.random()*num_obj)]
(imgPaths, gt_label)=test_data_pre(data_pre)
num_img = len(imgPaths)
correct_num=0
for idx in range(num_img):
img = cv2.imread(imgPaths[idx])
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
tmp_img = img.copy() # for display
img = PrepareImage(img, (227, 227))
net.blobs['data'].reshape(test_num_once, 3, 227, 227)
net.blobs['data'].data[...] = img
#net.blobs['data'].data[i,:,:,:] = img
net.forward()
score = net.blobs['cls_prob'].data
if score.argmax()==gt_label[idx]:
correct_num=correct_num+1
if idx%100==0:
print("Please wait some minutes...")
correct_rate=correct_num*1.0/num_img
print('The correct rate is :',correct_rate)

3 test_data.py

 import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def test_data_pre(path):
img_list=[]
image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2'))
label = np.zeros(image_num, dtype=np.float32) i=0
for idf in range(3):
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1)
for idi in range(len(tmp_path)):
img_path=path1+'/'+tmp_path[idi]
img_list.append(img_path)
label[i]=idf
i=i+1
return ( img_list,label)

4 pre_data.py

 import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def prepare_data(path,batchsize):
#tmp_path=os.listdir(path)
img_list=[]
label = np.zeros(batchsize, dtype=np.float32)
for i in range(batchsize):
#randomly select one file
idf=randint(0,2)
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1) #randomly select one image
idi=randint(0,len(tmp_path)-1)
#img = cv2.imread(imgPaths[idx])
img_path=path1+'/'+tmp_path[idi]
img=cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
flip = randint(0, 1)>0
if flip > 0:
img = img[:, ::-1, :] # flip left to right img=PrepareImage(img, (227,227))
img_list.append(img)
label[i]=idf
imgData = CatImage(img_list)
return (imgData,label)

5 utils.py

 import os
import cv2
import numpy as np def PrepareImage(im, size):
im = cv2.resize(im, (size[0], size[1]))
im = im.transpose(2, 0, 1)
im = im.astype(np.float32, copy=False)
return im def CatImage(im_list):
max_shape = np.array([im.shape for im in im_list]).max(axis=0)
blob = np.zeros((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32)
# set to mean value
blob[:, 0, :, :] = 102.9801
blob[:, 1, :, :] = 115.9465
blob[:, 2, :, :] = 122.7717
for i, im in enumerate(im_list):
blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im
return blob

6 layer/data_layer.py

 import caffe
import numpy as np #import data_config
#import prepare_data
from pre_data import prepare_data class DataLayer(caffe.Layer): def SetDataConfig(self, data_config):
self._data_config = data_config def GetDataConfig(self):
return self._data_config def setup(self, bottom, top):
# data blob
top[0].reshape(1, 3, 227, 227)
#top[0].reshape(1, 3, 34, 44)
# label type
top[1].reshape(1, 1) def reshape(self, bootom, top):
pass def forward(self, bottom, top):
#(imgs, label) = self._data_config.next()
path=self.GetDataConfig()
(imgs,label)=prepare_data(path,128)
(N, C, W, H) = imgs.shape
# image data
top[0].reshape(N, C, W, H)
top[0].data[...] = imgs
# object type label
top[1].reshape(N)
top[1].data[...] = label def backward(self, top, propagate_down, bottom):
pass

7 layer/__init__.py

import data_layer

还有一些caffe中经典的东西没放进来。

代码和数据:

python caffe 在师兄的代码上修改成自己风格的代码的更多相关文章

  1. 用Python给你的代码上个进度条吧 | 【代码也要面子的】

    微信公众号:AI算法与图像处理如果你觉得对你有帮助,欢迎关注.转发以及点赞哦-( ̄▽ ̄-)~ 前言 最近在跑一些代码的时候,很烦...因为有时候不知道这段程序什么时候能执行完,现在执行哪里了,如果报错 ...

  2. Upsource——对已签入的代码进行分享、讨论和审查代码

    Upsource 一.Upsource简介 Upsource ,这是一个专门为软件开发团队所设计的源代码协作工具.Upsource能够与多种版本控制工具进行集成,包括Git.Mercurial.Sub ...

  3. python之模块ftplib(实现ftp上传下载代码)

    # -*- coding: utf-8 -*- #python 27 #xiaodeng #python之模块ftplib(实现ftp上传下载代码) #需求:实现ftp上传下载代码(不含错误处理) f ...

  4. 学习Git的一点心得以及如何把本地修改、删除的代码上传到github中

    一:学习Github的资料如下:https://git.oschina.net/progit/ 这是一个学习Git的中文网站,如果诸位能够静下心来阅读,不要求阅读太多,只需要阅读前三章,就可以掌握Gi ...

  5. 基于Caffe的DeepID2实现(上)

    小喵的唠叨话:小喵最近在做人脸识别的工作,打算将汤晓鸥前辈的DeepID,DeepID2等算法进行实验和复现.DeepID的方法最简单,而DeepID2的实现却略微复杂,并且互联网上也没有比较好的资源 ...

  6. 使用pycharm开发代码上传到GitLab和GitHub

    使用pycharm开发代码上传到GitLab和GitHub 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 我这里主要是针对局域网的自减的GitLab服务器,python开发工程师如 ...

  7. python 全栈开发,Day86(上传文件,上传头像,CBV,python读写Excel,虚拟环境virtualenv)

    一.上传文件 上传一个图片 使用input type="file",来上传一个文件.注意:form表单必须添加属性enctype="multipart/form-data ...

  8. 使用git工具将本地电脑上的代码上传至GitHub

    本文教你如果使用git工具将本地电脑上的代码上传至GitHub 1.安装git工具 安装git链接 2.使用git工具上传自己的代码到GitHub中 安装完git工具之后,我们会得到两个命令行工具,一 ...

  9. Dynamics AX 2012 R2 窗体系列 - 在窗体上修改字段时所触发的方法及其顺序

        在这个系列里,Reinhard将和大家一起探索在AX的窗体上执行操作时,都会触发窗体.窗体数据源和表上的哪些方法,并且是以怎样的顺序触发的.     这次,我们来看看在窗体上修改或录入数据的情 ...

随机推荐

  1. Docker-Compose 一键部署Ningx+.Net Core+Redis集群

    在看该文章前,你需要对Docker有所了解. 1.创建WebApp应用程序 我使用的是.Net Core 1.0.1版本,创建一个MVC应用程序,并添加对Redis的引用.因为这些很基础,也很简单,这 ...

  2. CentOS 7 Firewalld 常用操作

    1.简介 Zone 级别 drop: 丢弃所有进入的包,而不给出任何响应block: 拒绝所有外部发起的连接,允许内部发起的连接public: 允许指定的进入连接external: 同上,对伪装的进入 ...

  3. C# winfrom listview 多窗口调用

    Form1 private void button1_Click(object sender, EventArgs e) { Form f = new Form2(ref listView1); f. ...

  4. mac配置jenkins遇到的问题及解决办法

    写这篇博客的时候,我暂时放弃了mac配置jenkins,先记着遇到的坑吧.虽然无数次想砸电脑,但是回头想想,对于经常用windows系统和接触过linux的测试的我来说,这也是个熟悉mac系统的机会. ...

  5. JavaScript encodeURIComponent()

    ■ 把字符串作为 URI 组件进行编码.JavaScript中有三个可以对字符串编码的函数,分别是: escape,encodeURI,encodeURIComponent,相应3个解码函数:unes ...

  6. 为什么U盘在拔出之前需要“安全弹出”?

    前言 我们不知道从什么时候开始有一个观念:U盘一定要点击“安全弹出”才能拔.那么是不是在任何情况下都必须要这样呢? 介绍 U盘的传输策略有两种: 写入缓存:这种策略在windows中称为“更好的性能” ...

  7. Resource——资源的总结

    在xaml中,对于Style.DataTemplate.ControlTemplate.StoryBord等资源,可以放在UserControl.Resource.Windows.Resource.C ...

  8. hdu 1211 逆元

    RSA Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others)Total Submiss ...

  9. Oracle 反应太后知后觉了.

    很久已经提过一个SR,关于BES一个用户可以用两个密码登陆EBS系统的问题,但是SR解决太慢,而且一致强调你们的版本太低,需要升级到最新的版本,考虑客户化的问题,我们的版本没有升级(R2.1.1),无 ...

  10. leetcode 849. Maximize Distance to Closest Person

    In a row of seats, 1 represents a person sitting in that seat, and 0 represents that the seat is emp ...