首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试。我的代码是堆积成的,被师兄嘲笑说写脚本。好吧!我的代码只有我懂,哈哈! 希望以后代码能写得工整点。现在还是让我先懂。这里,我做了一个简单的任务:0,1,2三个数字的分类。准确率:0.9806666666666667

(部分)代码分为:

1 train_net.py

 #import some module
import time
import os
import numpy as np
import sys
import cv2
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
#from prepare_data import DataConfig
#from data_config import DataConfig #configure GPU mode
''' uncommend below line to use gpu '''
caffe.set_mode_gpu() # about dataset
##dataset = Dataset('/home/wang/Downloads/object/extract/')
##dataset = dataset.Split('train')
##data_config = DataConfig(dataset)
##data_config.SetBatchSize(256)
data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/' #configure solve.prototxt
solver = caffe.SGDSolver('models/solver.prototxt') # load pretrain model
print('load pretrain model')
solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel') solver.net.layers[0].SetDataConfig(data_config) for i in range(1, 10000):
# Make one SGD update
solver.step(5)
if i % 100 == 0:
solver.net.save('tmp.caffemodel')
''' TODO: test code '''

2 test_net.py

 #import setup
import time
import os
import random
import sys
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
import cv2
import numpy as np
import random from utils import PrepareImage
#from dataset import Dataset
from test_data import test_data_pre test_num_once=10 ''' uncommend below line to use gpu '''
# caffe.set_mode_gpu() # dataset
#dataset = Dataset('/home/wang/Downloads/object/extract/')
#dataset = dataset.Split('test') # load net
net = caffe.Net('models/deploy.prototxt', caffe.TEST) # load train model
print('load pretrain model')
net.copy_from('tmp.caffemodel') #test all samples one by one
data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
#(imgPaths, gt_label) = dataset[int(random.random()*num_obj)]
(imgPaths, gt_label)=test_data_pre(data_pre)
num_img = len(imgPaths)
correct_num=0
for idx in range(num_img):
img = cv2.imread(imgPaths[idx])
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
tmp_img = img.copy() # for display
img = PrepareImage(img, (227, 227))
net.blobs['data'].reshape(test_num_once, 3, 227, 227)
net.blobs['data'].data[...] = img
#net.blobs['data'].data[i,:,:,:] = img
net.forward()
score = net.blobs['cls_prob'].data
if score.argmax()==gt_label[idx]:
correct_num=correct_num+1
if idx%100==0:
print("Please wait some minutes...")
correct_rate=correct_num*1.0/num_img
print('The correct rate is :',correct_rate)

3 test_data.py

 import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def test_data_pre(path):
img_list=[]
image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2'))
label = np.zeros(image_num, dtype=np.float32) i=0
for idf in range(3):
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1)
for idi in range(len(tmp_path)):
img_path=path1+'/'+tmp_path[idi]
img_list.append(img_path)
label[i]=idf
i=i+1
return ( img_list,label)

4 pre_data.py

 import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def prepare_data(path,batchsize):
#tmp_path=os.listdir(path)
img_list=[]
label = np.zeros(batchsize, dtype=np.float32)
for i in range(batchsize):
#randomly select one file
idf=randint(0,2)
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1) #randomly select one image
idi=randint(0,len(tmp_path)-1)
#img = cv2.imread(imgPaths[idx])
img_path=path1+'/'+tmp_path[idi]
img=cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
flip = randint(0, 1)>0
if flip > 0:
img = img[:, ::-1, :] # flip left to right img=PrepareImage(img, (227,227))
img_list.append(img)
label[i]=idf
imgData = CatImage(img_list)
return (imgData,label)

5 utils.py

 import os
import cv2
import numpy as np def PrepareImage(im, size):
im = cv2.resize(im, (size[0], size[1]))
im = im.transpose(2, 0, 1)
im = im.astype(np.float32, copy=False)
return im def CatImage(im_list):
max_shape = np.array([im.shape for im in im_list]).max(axis=0)
blob = np.zeros((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32)
# set to mean value
blob[:, 0, :, :] = 102.9801
blob[:, 1, :, :] = 115.9465
blob[:, 2, :, :] = 122.7717
for i, im in enumerate(im_list):
blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im
return blob

6 layer/data_layer.py

 import caffe
import numpy as np #import data_config
#import prepare_data
from pre_data import prepare_data class DataLayer(caffe.Layer): def SetDataConfig(self, data_config):
self._data_config = data_config def GetDataConfig(self):
return self._data_config def setup(self, bottom, top):
# data blob
top[0].reshape(1, 3, 227, 227)
#top[0].reshape(1, 3, 34, 44)
# label type
top[1].reshape(1, 1) def reshape(self, bootom, top):
pass def forward(self, bottom, top):
#(imgs, label) = self._data_config.next()
path=self.GetDataConfig()
(imgs,label)=prepare_data(path,128)
(N, C, W, H) = imgs.shape
# image data
top[0].reshape(N, C, W, H)
top[0].data[...] = imgs
# object type label
top[1].reshape(N)
top[1].data[...] = label def backward(self, top, propagate_down, bottom):
pass

7 layer/__init__.py

import data_layer

还有一些caffe中经典的东西没放进来。

代码和数据:

python caffe 在师兄的代码上修改成自己风格的代码的更多相关文章

  1. 用Python给你的代码上个进度条吧 | 【代码也要面子的】

    微信公众号:AI算法与图像处理如果你觉得对你有帮助,欢迎关注.转发以及点赞哦-( ̄▽ ̄-)~ 前言 最近在跑一些代码的时候,很烦...因为有时候不知道这段程序什么时候能执行完,现在执行哪里了,如果报错 ...

  2. Upsource——对已签入的代码进行分享、讨论和审查代码

    Upsource 一.Upsource简介 Upsource ,这是一个专门为软件开发团队所设计的源代码协作工具.Upsource能够与多种版本控制工具进行集成,包括Git.Mercurial.Sub ...

  3. python之模块ftplib(实现ftp上传下载代码)

    # -*- coding: utf-8 -*- #python 27 #xiaodeng #python之模块ftplib(实现ftp上传下载代码) #需求:实现ftp上传下载代码(不含错误处理) f ...

  4. 学习Git的一点心得以及如何把本地修改、删除的代码上传到github中

    一:学习Github的资料如下:https://git.oschina.net/progit/ 这是一个学习Git的中文网站,如果诸位能够静下心来阅读,不要求阅读太多,只需要阅读前三章,就可以掌握Gi ...

  5. 基于Caffe的DeepID2实现(上)

    小喵的唠叨话:小喵最近在做人脸识别的工作,打算将汤晓鸥前辈的DeepID,DeepID2等算法进行实验和复现.DeepID的方法最简单,而DeepID2的实现却略微复杂,并且互联网上也没有比较好的资源 ...

  6. 使用pycharm开发代码上传到GitLab和GitHub

    使用pycharm开发代码上传到GitLab和GitHub 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 我这里主要是针对局域网的自减的GitLab服务器,python开发工程师如 ...

  7. python 全栈开发,Day86(上传文件,上传头像,CBV,python读写Excel,虚拟环境virtualenv)

    一.上传文件 上传一个图片 使用input type="file",来上传一个文件.注意:form表单必须添加属性enctype="multipart/form-data ...

  8. 使用git工具将本地电脑上的代码上传至GitHub

    本文教你如果使用git工具将本地电脑上的代码上传至GitHub 1.安装git工具 安装git链接 2.使用git工具上传自己的代码到GitHub中 安装完git工具之后,我们会得到两个命令行工具,一 ...

  9. Dynamics AX 2012 R2 窗体系列 - 在窗体上修改字段时所触发的方法及其顺序

        在这个系列里,Reinhard将和大家一起探索在AX的窗体上执行操作时,都会触发窗体.窗体数据源和表上的哪些方法,并且是以怎样的顺序触发的.     这次,我们来看看在窗体上修改或录入数据的情 ...

随机推荐

  1. CommonLang3 --StringUtils使用指南

    转载自(http://blog.csdn.net/xuxiaoxie/article/details/52095930)public static boolean isEmpty(CharSequen ...

  2. pt-table-checksum校验mysql主从数据一致性

    主从数据的一致性校验是个头疼的问题,偶尔被业务投诉主从数据不一致,或者几个从库之间的数据不一致,这会令人沮丧.通常我们仅有一种办法,热备主库,然后替换掉所有的从库.这不仅代价非常大,而且类似治标不治本 ...

  3. C# 往string [] arr 数组插入元素

    string [] arr ; List<string> _list = new List<string>(arr ); for(int i ;i<10;i++) { _ ...

  4. [源码解读] ResNet源码解读(pytorch)

    自己看读完pytorch封装的源码后,自己又重新写了一边(模仿其书写格式), 一些问题在代码中说明. import torch import torchvision import argparse i ...

  5. 自学Jav测试代码三 Math类 & Date & GregorianCalendar类

    2017-08-23 20:30:08 writer: pprp package test; import java.util.Date; import java.util.*; public cla ...

  6. Eclipse安卓项目导入android.support.design报错的解决办法

    导入android.support.design出错:1.项目除了需要依赖appcompat_v7包外还要design包2.design包就是在安卓sdk下Extras中的android.suppor ...

  7. Java 类的构造器的调用顺序

    规则如下: 对于一个复杂的对象,构建器的调用遵照下面的顺序: (1) 调用父类构建器.这个步骤会不断重复下去,首先得到构建的是分级结构的根部,然后是下一个子类,等等.直到抵达最深一层的子类. (2) ...

  8. atom的初次尝试,activate-power-mode 插件和做gif

    编辑器是github 和sublime 的综合,作为一个经常逛github的人,还很喜欢sublime的开发,还有什么好不尝试的理由呢. 好吧,我承认,编辑器有很多,但是像它那样炫酷的很少,作为喜欢一 ...

  9. VS2017 MVC项目,新建控制器提示未能加载文件或程序集“Dapper.Contrib解决方法

    VS2017中MVC项目中,右键新建控制器时,提示 未能加载文件或程序集“Dapper.Contrib, Version=1.50.0.0, Culture=neutral, PublicKeyTok ...

  10. tinyxml解析xml

    基于tinyxml做的简单的xml解析. 1.创建xml bool CreateXmlFile(string& szFileName) {//创建xml文件,szFilePath为文件保存的路 ...