深度学习的第一个实例一般都是mnist,只要这个例子完全弄懂了,其它的就是举一反三的事了。由于篇幅原因,本文不具体介绍配置文件里面每个参数的具体函义,如果想弄明白的,请参看我以前的博文:

数据层及参数

视觉层及参数

solver配置文件及参数

一、数据准备

官网提供的mnist数据并不是图片,但我们以后做的实际项目可能是图片。因此有些人并不知道该怎么办。在此我将mnist数据进行了转化,变成了一张张的图片,我们练习就从图片开始。mnist图片数据我放在了百度云盘。

mnist图片数据下载:http://pan.baidu.com/s/1pLMV4Kz

数据分成了训练集(60000张共10类)和测试集(共10000张10类),每个类别放在一个单独的文件夹里。并且将所有的图片,都生成了txt列表清单(train.txt和test.txt)。大家下载下来后,直接解压到当前用户根目录下就可以了。由于我是在windows下压缩的,因此是winrar文件。如果大家要在linux下解压缩,需要安装rar的linux版本,也是十分简单

sudo apt-get install rar

二、导入caffe库,并设定文件路径

我是将mnist直接放在根目录下的,所以代码如下:

# -*- coding: utf-8 -*-

import caffe
from caffe import layers as L,params as P,proto,to_proto
#设定文件的保存路径
root='/home/xxx/' #根目录
train_list=root+'mnist/train/train.txt' #训练图片列表
test_list=root+'mnist/test/test.txt' #测试图片列表
train_proto=root+'mnist/train.prototxt' #训练配置文件
test_proto=root+'mnist/test.prototxt' #测试配置文件
solver_proto=root+'mnist/solver.prototxt' #参数文件

其中train.txt 和test.txt文件已经有了,其它三个文件,我们需要自己编写。

此处注意:一般caffe程序都是先将图片转换成lmdb文件,但这样做有点麻烦。因此我就不转换了,我直接用原始图片进行操作,所不同的就是直接用图片操作,均值很难计算,因此可以不减均值。

二、生成配置文件

配置文件实际上就是一些txt文档,只是后缀名是prototxt,我们可以直接到编辑器里编写,也可以用代码生成。此处,我用python来生成。

#编写一个函数,生成配置文件prototxt
def Lenet(img_list,batch_size,include_acc=False):
#第一层,数据输入层,以ImageData格式输入
data, label = L.ImageData(source=img_list, batch_size=batch_size, ntop=2,root_folder=root,
transform_param=dict(scale= 0.00390625))
#第二层:卷积层
conv1=L.Convolution(data, kernel_size=5, stride=1,num_output=20, pad=0,weight_filler=dict(type='xavier'))
#池化层
pool1=L.Pooling(conv1, pool=P.Pooling.MAX, kernel_size=2, stride=2)
#卷积层
conv2=L.Convolution(pool1, kernel_size=5, stride=1,num_output=50, pad=0,weight_filler=dict(type='xavier'))
#池化层
pool2=L.Pooling(conv2, pool=P.Pooling.MAX, kernel_size=2, stride=2)
#全连接层
fc3=L.InnerProduct(pool2, num_output=500,weight_filler=dict(type='xavier'))
#激活函数层
relu3=L.ReLU(fc3, in_place=True)
#全连接层
fc4 = L.InnerProduct(relu3, num_output=10,weight_filler=dict(type='xavier'))
#softmax层
loss = L.SoftmaxWithLoss(fc4, label) if include_acc: # test阶段需要有accuracy层
acc = L.Accuracy(fc4, label)
return to_proto(loss, acc)
else:
return to_proto(loss) def write_net():
#写入train.prototxt
with open(train_proto, 'w') as f:
f.write(str(Lenet(train_list,batch_size=64))) #写入test.prototxt
with open(test_proto, 'w') as f:
f.write(str(Lenet(test_list,batch_size=100, include_acc=True)))

配置文件里面存放的,就是我们所说的network。我这里生成的network,可能和原始的Lenet不太一样,不过影响不大。

三、生成参数文件solver

同样,可以在编辑器里面直接书写,也可以用代码生成。

#编写一个函数,生成参数文件
def gen_solver(solver_file,train_net,test_net):
s=proto.caffe_pb2.SolverParameter()
s.train_net =train_net
s.test_net.append(test_net)
s.test_interval = 938 #60000/64,测试间隔参数:训练完一次所有的图片,进行一次测试
s.test_iter.append(100) #10000/100 测试迭代次数,需要迭代100次,才完成一次所有数据的测试
s.max_iter = 9380 #10 epochs , 938*10,最大训练次数
s.base_lr = 0.01 #基础学习率
s.momentum = 0.9 #动量
s.weight_decay = 5e-4 #权值衰减项
s.lr_policy = 'step' #学习率变化规则
s.stepsize=3000 #学习率变化频率
s.gamma = 0.1 #学习率变化指数
s.display = 20 #屏幕显示间隔
s.snapshot = 938 #保存caffemodel的间隔
s.snapshot_prefix =root+'mnist/lenet' #caffemodel前缀
s.type ='SGD' #优化算法
s.solver_mode = proto.caffe_pb2.SolverParameter.GPU #加速
#写入solver.prototxt
with open(solver_file, 'w') as f:
f.write(str(s))

四、开始训练模型

训练过程中,也在不停的测试。

#开始训练
def training(solver_proto):
caffe.set_device(0)
caffe.set_mode_gpu()
solver = caffe.SGDSolver(solver_proto)
solver.solve()

最后,调用以上的函数就可以了。

if __name__ == '__main__':
write_net()
gen_solver(solver_proto,train_proto,test_proto)
training(solver_proto)

五、完成的python文件

mnist.py

# -*- coding: utf-8 -*-

import caffe
from caffe import layers as L,params as P,proto,to_proto
#设定文件的保存路径
root='/home/xxx/' #根目录
train_list=root+'mnist/train/train.txt' #训练图片列表
test_list=root+'mnist/test/test.txt' #测试图片列表
train_proto=root+'mnist/train.prototxt' #训练配置文件
test_proto=root+'mnist/test.prototxt' #测试配置文件
solver_proto=root+'mnist/solver.prototxt' #参数文件 #编写一个函数,生成配置文件prototxt
def Lenet(img_list,batch_size,include_acc=False):
#第一层,数据输入层,以ImageData格式输入
data, label = L.ImageData(source=img_list, batch_size=batch_size, ntop=2,root_folder=root,
transform_param=dict(scale= 0.00390625))
#第二层:卷积层
conv1=L.Convolution(data, kernel_size=5, stride=1,num_output=20, pad=0,weight_filler=dict(type='xavier'))
#池化层
pool1=L.Pooling(conv1, pool=P.Pooling.MAX, kernel_size=2, stride=2)
#卷积层
conv2=L.Convolution(pool1, kernel_size=5, stride=1,num_output=50, pad=0,weight_filler=dict(type='xavier'))
#池化层
pool2=L.Pooling(conv2, pool=P.Pooling.MAX, kernel_size=2, stride=2)
#全连接层
fc3=L.InnerProduct(pool2, num_output=500,weight_filler=dict(type='xavier'))
#激活函数层
relu3=L.ReLU(fc3, in_place=True)
#全连接层
fc4 = L.InnerProduct(relu3, num_output=10,weight_filler=dict(type='xavier'))
#softmax层
loss = L.SoftmaxWithLoss(fc4, label) if include_acc: # test阶段需要有accuracy层
acc = L.Accuracy(fc4, label)
return to_proto(loss, acc)
else:
return to_proto(loss) def write_net():
#写入train.prototxt
with open(train_proto, 'w') as f:
f.write(str(Lenet(train_list,batch_size=64))) #写入test.prototxt
with open(test_proto, 'w') as f:
f.write(str(Lenet(test_list,batch_size=100, include_acc=True))) #编写一个函数,生成参数文件
def gen_solver(solver_file,train_net,test_net):
s=proto.caffe_pb2.SolverParameter()
s.train_net =train_net
s.test_net.append(test_net)
s.test_interval = 938 #60000/64,测试间隔参数:训练完一次所有的图片,进行一次测试
s.test_iter.append(500) #50000/100 测试迭代次数,需要迭代500次,才完成一次所有数据的测试
s.max_iter = 9380 #10 epochs , 938*10,最大训练次数
s.base_lr = 0.01 #基础学习率
s.momentum = 0.9 #动量
s.weight_decay = 5e-4 #权值衰减项
s.lr_policy = 'step' #学习率变化规则
s.stepsize=3000 #学习率变化频率
s.gamma = 0.1 #学习率变化指数
s.display = 20 #屏幕显示间隔
s.snapshot = 938 #保存caffemodel的间隔
s.snapshot_prefix = root+'mnist/lenet' #caffemodel前缀
s.type ='SGD' #优化算法
s.solver_mode = proto.caffe_pb2.SolverParameter.GPU #加速
#写入solver.prototxt
with open(solver_file, 'w') as f:
f.write(str(s)) #开始训练
def training(solver_proto):
caffe.set_device(0)
caffe.set_mode_gpu()
solver = caffe.SGDSolver(solver_proto)
solver.solve()
#
if __name__ == '__main__':
write_net()
gen_solver(solver_proto,train_proto,test_proto)
training(solver_proto)

我将此文件放在根目录下的mnist文件夹下,因此可用以下代码执行

sudo python mnist/mnist.py

在训练过程中,会保存一些caffemodel。多久保存一次,保存多少次,都可以在solver参数文件里进行设置。

我设置为训练10 epoch,9000多次,测试精度可以达到99%

caffe的python接口学习(4):mnist实例---手写数字识别的更多相关文章

  1. caffe的python接口学习(4)mnist实例手写数字识别

    以下主要是摘抄denny博文的内容,更多内容大家去看原作者吧 一 数据准备 准备训练集和测试集图片的列表清单; 二 导入caffe库,设定文件路径 # -*- coding: utf-8 -*- im ...

  2. keras实现mnist数据集手写数字识别

    一. Tensorflow环境的安装 这里我们只讲CPU版本,使用 Anaconda 进行安装 a.首先我们要安装 Anaconda 链接:https://pan.baidu.com/s/1AxdGi ...

  3. NN:利用深度学习之神经网络实现手写数字识别(数据集50000张图片)—Jason niu

    import mnist_loader import network training_data, validation_data, test_data = mnist_loader.load_dat ...

  4. 分类-MNIST(手写数字识别)

    这是学习<Hands-On Machine Learning with Scikit-Learn and TensorFlow>的笔记,如果此笔记对该书有侵权内容,请联系我,将其删除. 这 ...

  5. CNN完成mnist数据集手写数字识别

    # coding: utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data d ...

  6. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  7. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  8. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  9. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

随机推荐

  1. Docker化运维方式讲解

    应用迁移需求 应用运维需要考虑的一个重要问题就是迁移, 在不同机器.机房.环境间迁移.迁移的原因有很多, 比如硬件过保(硬件故障), 机房迁移, 应用扩缩容等. 应用迁移的核心需求是: 简单.迁移操作 ...

  2. Asp.net mvc5开源项目"超级冷笑话"

    业务时间做了个小网站,超级冷笑话,地址:http://www.superjokes.cn/ 开发技术: asp.net mvc5 +SQLServer2012 ORM:NPoco 用了简单的三层结构 ...

  3. 使用AngularJS实现简单:全选和取消全选功能

    这里用到AngularJS四大特性之二----双向数据绑定 注意:没写一行DOM代码!这就是ng的优点,bootstrap.css为了布局,JS代码也只是简单创建ng模块和ng控制器 效果: < ...

  4. android gradle NDK简介

    本章介绍在Android开发中,关于NDK,gradle相关的知识点. 1.NDK简介 (1)NDK是一系列工具的集合 NDK提供了一系列的工具,帮助开发者快速开发C(或C++)的动态库,并能自动将s ...

  5. 深入理解Linux修改hostname

    当我觉得对Linux系统下修改hostname已经非常熟悉的时候,今天碰到了几个个问题,这几个问题给我好好上了一课,很多知识点,当你觉得你已经掌握的时候,其实你了解的还只是皮毛.技术活,切勿浅尝则止! ...

  6. .NET并行编程实践(一:.NET并行计算基本介绍、并行循环使用模式)

    阅读目录: 1.开篇介绍 2.NET并行计算基本介绍 3.并行循环使用模式 3.1并行For循环 3.2并行ForEach循环 3.3并行LINQ(PLINQ) 1]开篇介绍 最近这几天在捣鼓并行计算 ...

  7. SQL Server 2008 R2——根据数据查找表名和字段名 根据脏数据定位表和字段

    =================================版权声明================================= 版权声明:原创文章 谢绝转载  请通过右侧公告中的“联系邮 ...

  8. springMVC基础controller类

    此文章是基于 搭建SpringMVC+Spring+Hibernate平台 功能:设置请求.响应对象:session.cookie操作:ajax访问返回json数据: 创建springMVC基础con ...

  9. php时间

    date_default_timezone_set('PRC'); //默认时区 //当前的时间增加5天 $date1 = "2014-11-11"; echo date('Y-m ...

  10. mysqld设置密码

    用root 进入mysql后 mysql>set password =password('你的密码'); mysql>flush privileges;   登录: mysql -u ro ...