本文介绍如何使用keras作图片分类(2分类与多分类,其实就一个参数的区别。。。呵呵)

先来看看解决的问题:从一堆图片中分出是不是书本,也就是最终给图片标签上:“书本“、“非书本”,简单吧。

先来看看网络模型,用到了卷积和全连接层,最后套上SOFTMAX算出各自概率,输出ONE-HOT码,主要部件就是这些,下面的nb_classes就是用来控制分类数的,本文是2分类:

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD def Net_model(nb_classes, lr=0.001,decay=1e-6,momentum=0.9):
model = Sequential()
model.add(Convolution2D(filters=10, kernel_size=(5,5),
padding='valid',
input_shape=(200, 200, 3)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Convolution2D(filters=20, kernel_size=(10,10)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) model.add(Flatten())
model.add(Dense(1000))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax')) sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd) return model

上面的input_shape=(200, 200, 3)代表图片像素大小为宽高为200,200,并且包含RGB 3通道的图片,不是灰度图片(只要1个通道)

也就是说送入此网络的图片宽高必须200*200*3;如果不是这个shape就需要resize到这个shape

下面来看看训练程序,首先肯定是要收集些照片,书本、非书本的照片,我是分别放在了0文件夹和1文件夹下了,再带个验证用途的文件夹validate:

  

训练程序涉及到几个地方:照片文件的读取、模型加载训练与保存、可视化训练过程中的损失函数value

照片文件的读取

import cv2
import os
import numpy as np
import keras def loadImages():
imageList=[]
labelList=[] rootdir="d:\\books\\0"
list =os.listdir(rootdir)
for item in list:
path=os.path.join(rootdir,item)
if(os.path.isfile(path)):
f=cv2.imread(path)
f=cv2.resize(f, (200, 200))#resize到网络input的shape
imageList.append(f)
labelList.append(0)#类别0 rootdir="d:\\books\\1"
list =os.listdir(rootdir)
for item in list:
path=os.path.join(rootdir,item)
if(os.path.isfile(path)):
f=cv2.imread(path)
f=cv2.resize(f, (200, 200))#resize到网络input的shape
imageList.append(f)
labelList.append(1)#类别1 return np.asarray(imageList), keras.utils.to_categorical(labelList, 2)

关于(200,200)这个shape怎么得来的,只是几月前开始玩opencv时随便写了个数值,后来想利用那些图片,就适应到这个shape了

keras.utils.to_categorical函数类似numpy.onehot、tf.one_hot这些,只是one hot的keras封装

模型加载训练与保存

nb_classes = 2
nb_epoch = 30
nb_step = 6
batch_size = 3 x,y=loadImages() from keras.preprocessing.image import ImageDataGenerator
dataGenerator=ImageDataGenerator()
dataGenerator.fit(x)
data_generator=dataGenerator.flow(x, y, batch_size, True)#generator函数,用来生成批处理数据(从loadImages中) model=NetModule.Net_model(nb_classes=nb_classes, lr=0.0001) #加载网络模型 history=model.fit_generator(data_generator, epochs=nb_epoch, steps_per_epoch=nb_step, shuffle=True)#训练网络,并且返回每次epoch的损失value model.save_weights('D:\\Documents\\Visual Studio 2017\\Projects\\ConsoleApp9\\PythonApplication1\\书本识别\\trained_model_weights.h5')#保存权重
print("DONE, model saved in path-->D:\\Documents\\Visual Studio 2017\\Projects\\ConsoleApp9\\PythonApplication1\\书本识别\\trained_model_weights.h5")

ImageDataGenerator构造函数有很多参数,主要用来提升数据质量,比如要不要标准化数字

lr=0.001这个参数要看经验,大了会导致不收敛,训练的时候经常由于这个参数的问题导致重复训练,这在没有GPU的情况下很是痛苦。。痛苦。。。痛苦。。。

model.save_weights是保存权重,但是不保存网络模型 ,对应的是model.load_weights方法

model.save是保存网络+权重,只是。。。。此例中用save_weights保存的h5文件是125M,但用save方法保存后,h5文件就增大为280M了。。。

上面2个save方法都能finetune,只是灵活度不一样。

可视化训练过程中的损失函数value

import matplotlib.pyplot as plt

plt.plot(history.history['loss'])
plt.show()

  

貌似没啥好补充的。。。

AND。。。。看看预测部分吧,这部分加载图片、加载模型,似乎都和训练部分雷同:

def loadImages():
imageList=[] rootdir="d:\\books\\validate"
list =os.listdir(rootdir)
for item in list:
path=os.path.join(rootdir,item)
if(os.path.isfile(path)):
f=cv2.imread(path)
f=cv2.resize(f, (200, 200))
imageList.append(f) return np.asarray(imageList) x=loadImages() x=np.asarray(x) model=NetModule.Net_model(nb_classes=2, lr=0.0001)
model.load_weights('D:\\Documents\\Visual Studio 2017\\Projects\\ConsoleApp9\\PythonApplication1\\书本识别\\trained_model_weights.h5') print(model.predict(x))
print(model.predict_classes(x))
y=convert2label(model.predict_classes(x))
print(y)

predict的返回其实是softmax层返回的概率数值,是<=1的float

predict_classes返回的是经过one-hot处理后的数值,此时只有0、1两种数值(最大的value会被返回称为1,其他都为0)  

convert2label:

def convert2label(vector):
string_array=[]
for v in vector:
if v==1:
string_array.append('BOOK')
else:
string_array.append('NOT BOOK')
return string_array

这个函数是用来把0、1转换成文本的,小插曲:

本来这里是中文的“书本”、“非书本”,后来和女儿一起调试时发现都显示成了问号,应该是中文字符问题,就改成了英文显示,和女儿一起写代码是种乐趣啊!

本来只是显示文本,感觉太无聊了,因此加上了opencv显示图片+分类文本的代码段:

for i in range(len(x)):
cv2.putText(x[i], y[i], (50,50), cv2.FONT_HERSHEY_SIMPLEX, 1, 255, 2)
cv2.imshow('image'+str(i), x[i]) cv2.waitKey(-1)

  

OK, 2018年继续学习,继续科学信仰。

用keras作CNN卷积网络书本分类(书本、非书本)的更多相关文章

  1. 1. CNN卷积网络-初识

    1. CNN卷积网络-初识 2. CNN卷积网络-前向传播算法 3. CNN卷积网络-反向更新 1. 前言 卷积神经网络是一种特殊的深层的神经网络模型,它的特殊性体现在两个方面, 它的神经元间的连接是 ...

  2. 3. CNN卷积网络-反向更新

    1. CNN卷积网络-初识 2. CNN卷积网络-前向传播算法 3. CNN卷积网络-反向更新 1. 前言 如果读者详细的了解了DNN神经网络的反向更新,那对我们今天的学习会有很大的帮助.我们的CNN ...

  3. 2. CNN卷积网络-前向传播算法

    1. CNN卷积网络-初识 2. CNN卷积网络-前向传播算法 3. CNN卷积网络-反向更新 1. 前言 我们已经了解了CNN的结构,CNN主要结构有输入层,一些卷积层和池化层,后面是DNN全连接层 ...

  4. Deeplearning 两层cnn卷积网络详解

    https://blog.csdn.net/u013203733/article/details/79074452 转载地址: https://www.cnblogs.com/sunshineatno ...

  5. 用keras的cnn做人脸分类

    keras介绍 Keras是一个简约,高度模块化的神经网络库.采用Python / Theano开发. 使用Keras如果你需要一个深度学习库: 可以很容易和快速实现原型(通过总模块化,极简主义,和可 ...

  6. keras搭建密集连接网络/卷积网络/循环网络

    输入模式与网络架构间的对应关系: 向量数据:密集连接网络(Dense层) 图像数据:二维卷积神经网络 声音数据(比如波形):一维卷积神经网络(首选)或循环神经网络 文本数据:一维卷积神经网络(首选)或 ...

  7. 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  8. Keras(四)CNN 卷积神经网络 RNN 循环神经网络 原理及实例

    CNN 卷积神经网络 卷积 池化 https://www.cnblogs.com/peng8098/p/nlp_16.html 中有介绍 以数据集MNIST构建一个卷积神经网路 from keras. ...

  9. 机器学习-计算机视觉和卷积网络CNN

    概述 对于计算机视觉的应用现在是非常广泛的,但是它背后的原理其实非常简单,就是将每一个像素的值pixel输入到一个DNN中,然后让这个神经网络去学习这个模型,最后去应用这个模型就可以了.听起来是不是很 ...

随机推荐

  1. 11. 配置ZooKeeper ensemble

    一个ZooKeeper集群或复制的ZooKeeper服务器集群应该优化配置,以避免出现脑裂(split-brain)等情况. 由于网络分割,同一ensemble的两个不同服务器可能构成领导者不一致,因 ...

  2. unity插件开发

    1.简单的svn集成: 查询svn的文档可以知道svn提供各种命令符操作.因此,原理非常简单,利用命令符操作调用svn即可.代码也非常简单: 更新:Process.Start("Tortoi ...

  3. 教你3分钟读懂HTML5语言的特点

    HTML5的跨平台技术 HTML5技术跨平台,适配多终端.传统移动终端上的Native App,开发者的研发工作必须针对不同的操作系统进行,成本相对较高.Native App对于用户还存在着管理成本. ...

  4. 数字三角形-poj

    题目要求: 7 3 8 8 1 0 2 7 4 4 4 5 2 6 5 在上面的数字三角形中寻找在上面的数字三角形中寻找一条从顶部到底边的路径,使得路径上所经过的数字之和最大.路径上的每一步都只能往左 ...

  5. 微服务配置内容《网上copy》=========》如何创建一个高可用的服务注册中心

    前言:首先要知道什么是一个高可用的服务注册中心,基于spring boot建成的服务注册中心是一个单节点的服务注册中心,这样一旦发生了故障,那么整个服务就会瘫痪,所以我们需要一个高可用的服务注册中心, ...

  6. 【正则表达式】--python(表示字符)

    [前修知识] match :匹配    span:范围 match 是从头往后开始匹配,search不按照顺序,直接获取自己想要的,有就显示,没有就None r 代表反转义,前面也提到过这个知识,如果 ...

  7. Process Doppelgänging

    Process Doppelgänging -- 新的代码注入技术,通杀windows系统的所有版本,并且能绕过绝大多数的安全软件. 介绍 今天(2017-12-07),在伦敦举行的2017年黑帽欧洲 ...

  8. 企业实战Nginx+Tomcat动静分离架构的技术分享

    Nginx动静分离简单来说就是把动态跟静态请求分开,不能理解成只是单纯的把动态页面和静态页面物理分离.严格意义上说应该是动态请求跟静态请求分开,可以理解成使用Nginx处理静态页面,Tomcat.Re ...

  9. 51Nod 2006 飞行员配对(二分图最大匹配)

    链接:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=2006 思路: 二分匹配 注意n m的关系 代码: #include ...

  10. oracle触发器 调用 web接口

    最近要求开发当数据表发生变化的时候调用web接口的需求,上网找了好几篇文章看着都觉得不是很好,也根据别人的思路去实现了下,感觉都不太理想,最后使用URLConnection实现了调用.具体查看一下代码 ...