前言

本文与前文对手写数字识别分类基本类似的,同样图像作为输入,类别作为输出。这里不同的是,不仅仅是使用简单的卷积神经网络加上全连接层的模型。卷积神经网络大火以来,发展出来许多经典的卷积神经网络模型,包括VGG、ResNet、AlexNet等等。下面将针对CIFAR-10数据集,对图像进行分类。

1、CIFAR-10数据集、Reader创建

CIFAR-10数据集分为5个batch的训练集和1个batch的测试集,每个batch包含10,000张图片。每张图像尺寸为32*32的RGB图像,且包含有标签。一共有10个标签:airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck十个类别。

我在CIFAR-10网站中下载的是[CIFAR-10 python version](http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)。数据集完成后,解压得到上述六个文件。上述六个文件都是字典文件,使用cPickle模块即可读入。字典中‘data’需要重新定义维度为1000*32*32*3,维度分别代表[N H W C],即10,000张32*32尺寸的三通道(RGB)图像,再经过转换成为paddlepaddle读取的[N C H W ]维度形式;而字典‘labels’为10000个标签。如此一来,可以建立读取CIFAR-10的reader(与官方例程不同),如下:

def reader_creator(ROOT,istrain=True,cycle=False):
def load_CIFAR_batch(filename):
""" load single batch of cifar """
with open(filename,'rb') as f:
datadict = Pickle.load(f)
X = datadict['data']
Y = datadict['labels']
""" (N C H W) transpose to (N H W C) """
X = X.reshape(10000,3,32,32).transpose(0,2,3,1).astype('float')
Y = np.array(Y)
return X,Y
def reader():
while True:
if istrain:
for b in range(1,6):
f = os.path.join(ROOT,'data_batch_%d'%(b))
X,Y = load_CIFAR_batch(f)
length = X.shape[0]
for i in range(length):
yield X[i],Y[i]
if not cycle:
break
else:
f = os.path.join(ROOT,'test_batch')
X,Y = load_CIFAR_batch(f)
length = X.shape[0]
for i in range(length):
yield X[i],Y[i]
if not cycle:
break
return reader

2、VGG网络

VGG网络采用“减小卷积核大小,增加卷积核数量”的思想改造而成,这里直接采用paddlepaddle例程中的VGG网络了,值得提醒的是paddlepaddle中直接有函数img_conv_group提供卷积、池化、dropout一组操作,所以根据VGG的模型,前面卷积层可以划分为5组,然后再经过3层的全连接层得到结果。

PaddlePaddle例程中根据上图D网络,加入dorpout:

def vgg_bn_drop(input):
def conv_block(ipt, num_filter, groups, dropouts):
return fluid.nets.img_conv_group(
input=ipt,
#一组的卷积层的卷积核总数,组成list[num_filter num_filter ...]
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act='relu',
conv_with_batchnorm=True,
#每组卷积层各层的droput概率
conv_batchnorm_drop_rate=dropouts,
pool_size=2,
pool_stride=2,
pool_type='max') conv1 = conv_block(input, 64, 2, [0.3, 0]) #[0.3 0]即为第一组两层的dorpout概率,下同
conv2 = conv_block(conv1, 128, 2, [0.4, 0])
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
fc1 = fluid.layers.fc(input=drop, size=512, act=None) bn = fluid.layers.batch_norm(input=fc1, act='relu') drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
fc2 = fluid.layers.fc(input=drop2, size=512, act=None) predict = fluid.layers.fc(input=fc2, size=10, act='softmax')
return predict

3、训练

训练程序与上一节例程一样,同样是选取交叉熵作为损失函数,不多累赘讲述。

def train_network():
predict = inference_network()
label = fluid.layers.data(name='label',shape=[1],dtype='int64')
cost = fluid.layers.cross_entropy(input=predict,label=label)
avg_cost = fluid.layers.mean(cost)
accuracy = fluid.layers.accuracy(input=predict,label=label)
return [avg_cost,accuracy] def optimizer_program():
return fluid.optimizer.Adam(learning_rate=0.001) def train(data_path,save_path):
BATCH_SIZE = 128
EPOCH_NUM = 2
train_reader = paddle.batch(
paddle.reader.shuffle(reader_creator(data_path),buf_size=50000),
batch_size = BATCH_SIZE)
test_reader = paddle.batch(
reader_creator(data_path,False),
batch_size=BATCH_SIZE)
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
if event.step % 100 == 0:
print("\nPass %d, Epoch %d, Cost %f, Acc %f" %
(event.step, event.epoch, event.metrics[0],
event.metrics[1]))
else:
sys.stdout.write('.')
sys.stdout.flush()
if isinstance(event, fluid.EndEpochEvent):
avg_cost, accuracy = trainer.test(
reader=test_reader, feed_order=['image', 'label'])
print('\nTest with Pass {0}, Loss {1:2.2}, Acc {2:2.2}'.format(
event.epoch, avg_cost, accuracy))
if save_path is not None:
trainer.save_params(save_path)
place = fluid.CUDAPlace(0)
trainer = fluid.Trainer(
train_func=train_network, optimizer_func=optimizer_program, place=place)
trainer.train(
reader=train_reader,
num_epochs=EPOCH_NUM,
event_handler=event_handler,
feed_order=['image', 'label'])

4、测试接口

测试接口也类似,需要特别注意的是图像维度要改为[N C H W]的顺序!

def infer(params_dir):
place = fluid.CUDAPlace(0)
inferencer = fluid.Inferencer(
infer_func=inference_network, param_path=params_dir, place=place)
# Prepare testing data.
from PIL import Image
import numpy as np
import os def load_image(file):
im = Image.open(file)
im = im.resize((32, 32), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
"""transpose [H W C] to [C H W]"""
im = im.transpose((2, 0, 1))
im = im / 255.0 # Add one dimension, [N C H W] N=1
im = np.expand_dims(im, axis=0)
return im
cur_dir = os.path.dirname(os.path.realpath(__file__))
img = load_image(cur_dir + '/dog.png')
# inference
results = inferencer.infer({'image': img})
print(results)
lab = np.argsort(results) # probs and lab are the results of one batch data
print("infer results: ", cifar_classes[lab[0][0][-1]])

5、运行结果

由于笔者没有GPU服务器,所以只迭代了50次,已经用了8个多小时,但是准确率只有15.6%,测试集方面准确率有17%,效果不理想,用于验证的结果也是错的!

Pass , Epoch , Cost 2.261115, Acc 0.156250
.........................................................................................
Test with Pass , Loss 2.2, Acc 0.17 Classify the cifar10 images...
[array([[0.05997971, 0.13485196, 0.096842 , 0.09973737, 0.11053724,
0.08180068, 0.13847008, 0.08627985, 0.06851784, 0.12298328]],
dtype=float32)]
infer results: frog

结语

网络比较深,且数据集比较大,训练时间比较长,普通笔记本上面的GT840M聊以胜无吧。

本文代码:02_cifar

参考:book/03.image_classification/

【PaddlePaddle系列】CIFAR-10图像分类的更多相关文章

  1. 【深度学习系列】用PaddlePaddle和Tensorflow进行图像分类

    上个月发布了四篇文章,主要讲了深度学习中的"hello world"----mnist图像识别,以及卷积神经网络的原理详解,包括基本原理.自己手写CNN和paddlepaddle的 ...

  2. ABP(现代ASP.NET样板开发框架)系列之10、ABP领域层——实体

    点这里进入ABP系列文章总目录 基于DDD的现代ASP.NET开发框架--ABP系列之10.ABP领域层——实体 ABP是“ASP.NET Boilerplate Project (ASP.NET样板 ...

  3. JVM基础系列第10讲:垃圾回收的几种类型

    我们经常会听到许多垃圾回收的术语,例如:Minor GC.Major GC.Young GC.Old GC.Full GC.Stop-The-World 等.但这些 GC 术语到底指的是什么,它们之间 ...

  4. Mysql高手系列 - 第10篇:常用的几十个函数详解,收藏慢慢看

    这是Mysql系列第10篇. 环境:mysql5.7.25,cmd命令中进行演示. MySQL 数值型函数 函数名称 作 用 abs 求绝对值 sqrt 求二次方根 mod 求余数 ceil 和 ce ...

  5. java高并发系列 - 第10天:线程安全和synchronized关键字

    这是并发系列第10篇文章. 什么是线程安全? 当多个线程去访问同一个类(对象或方法)的时候,该类都能表现出正常的行为(与自己预想的结果一致),那我们就可以所这个类是线程安全的. 看一段代码: pack ...

  6. 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow

    原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...

  7. ShoneSharp语言(S#)的设计和使用介绍系列(10)— 富家子弟“语句“不炫富

    ShoneSharp语言(S#)的设计和使用介绍 系列(10)— 富家子弟“语句“不炫富 作者:Shone 声明:原创文章欢迎转载,但请注明出处,https://www.cnblogs.com/Sho ...

  8. RabbitMQ 入门系列:10、扩展内容:延时队列:延时队列插件及其有限的适用场景(系列大结局)。

    系列目录 RabbitMQ 入门系列:1.MQ的应用场景的选择与RabbitMQ安装. RabbitMQ 入门系列:2.基础含义:链接.通道.队列.交换机. RabbitMQ 入门系列:3.基础含义: ...

  9. 深度学习与计算机视觉系列(2)_图像分类与KNN

    作者: 寒小阳 &&龙心尘 时间:2015年11月. 出处: http://blog.csdn.net/han_xiaoyang/article/details/49949535 ht ...

随机推荐

  1. Devexpress VCL Build v2014 vol 14.2.4 发布

    What's New in 14.2.4 (VCL Product Line)   New Major Features in 14.2 What's New in VCL Products 14.2 ...

  2. hdu-1147(跨立实验)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1147 思路:判断每条线段,如果将要输入的线段和已经有的线段相交,则这条线段不算. 参考文章:https ...

  3. Linux下编译与调试

    gcc/g++编译器 对于.c格式的C文件,可以采用gcc或g++编译 对于 .cc..cpp格式的C++文件,应该采用g++进行编译 常用的选项: -c  表示编译源文件 -o  表示输出目标文件 ...

  4. gj10 python socket编程

    10.1 HTTP.Socket.TCP这几个概念 五层网络模型 socket 不属于任何协议,是一个API,通过socket 可以和传输层的打交道,然后在之上可以实现自己的功能和协议 10.2 cl ...

  5. 使用Volley上传文件

    使用浏览器上传文件,然后通过Wireshark抓包分析,发现发送的数据大概是这个样子. MIME Multipart Media Encapsulation, Type: multipart/form ...

  6. 20155320 2016-2017-2 《Java程序设计》第五周学习总结

    20155320 2016-2017-2 <Java程序设计>第五周学习总结 教材学习内容总结 错误处理 java中所有错误都会被打包为对象,可以通过try catch 代表错误的对象后做 ...

  7. 移动端与PC端的viewport

    第一种解析: 设备像素,就是我们直觉上觉得"靠谱"的像素,这些像素为所使用的各种设备提供了正规的分辨率,并且其值可以通过(通常情况下)从screen.width/height属性中 ...

  8. 【笔记】virtualbox+arch+kde5安装流水账

    正常安装就是RTFD就行了,不行辅助这几个链接也行: 我先把整个脚本[1]放这里: loadkeys us parted mkfs.ext4 /dev/sda1mkfs.ext4 /dev/sda3 ...

  9. OSLab课堂作业2

      日期:2019/3/23 内容: 实现内容 要求 mysys.c 实现函数mysys,用于执行一个系统命令. mysys的功能与系统函数system相同,要求用进程管理相关系统调用自己实现一遍 使 ...

  10. Restframework 认证authentication 组件实例-1

    1. 创建用户表和 token表 class User(models.Model): user =models.CharField(max_length=) pwd =models.CharField ...