Implementing MobileNet with Keras, using the MNIST dataset as a small recognition example. Environment: tensorflow-gpu 2.0, Python 3.7, and a GTX 2070 GPU.

1. Importing the data

  • First, two IPython magic commands so a cell can display multiple outputs.
%config InteractiveShell.ast_node_interactivity="all"
%pprint
  • Load the MNIST dataset that ships with Keras.
import tensorflow as tf
import keras

tf.debugging.set_log_device_placement(True)

mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

The call tf.debugging.set_log_device_placement(True) above logs which device each operation is placed on, so you can confirm the model is actually training on the GPU.

  • Converting the data

    The MNIST images are 28×28, while MobileNet was designed for ImageNet, whose images are 224×224, so first resize each image to 224×224.
from PIL import Image
import numpy as np

def convert_mnist_224pix(X):
    img = Image.fromarray(X)
    x = np.zeros((224, 224))
    img = np.array(img.resize((224, 224)))
    x[:, :] = img
    return x

iteration = iter(x_train)
new_train = np.zeros((len(x_train), 224, 224), dtype=np.float32)
for i in range(len(x_train)):
    data = next(iteration)
    new_train[i] = convert_mnist_224pix(data)
    if i % 5000 == 0:
        print(i)
new_train.shape
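Since 224 is exactly 8 × 28, the resize can also be done without PIL by nearest-neighbour repetition, which makes a quick sanity check on the conversion. A NumPy-only sketch (the function name is mine; `np.repeat` duplicates each pixel 8 times along each axis):

```python
import numpy as np

def upscale_28_to_224(img):
    """Nearest-neighbour upscale of a 28x28 array to 224x224 (factor 8)."""
    assert img.shape == (28, 28)
    return img.repeat(8, axis=0).repeat(8, axis=1)

demo = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)
big = upscale_28_to_224(demo)
print(big.shape)  # (224, 224)
# each 8x8 output block holds one original pixel value
print(big[10, 10] == demo[1, 1])  # True
```

PIL's default resize interpolates instead of repeating pixels, so the values differ slightly, but the shapes and overall content are the same.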

Note that new_train must be created with dtype=np.float32; the default is float64, which doubles the memory footprint and may not fit in RAM. The final shape is (60000, 224, 224).
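The memory point is easy to verify with a little arithmetic: 60000 × 224 × 224 elements at 8 bytes each (float64) versus 4 bytes (float32):

```python
elements = 60000 * 224 * 224   # 3,010,560,000 values
gib = 1024 ** 3

print(elements * 8 / gib)  # float64: ~22.4 GiB
print(elements * 4 / gib)  # float32: ~11.2 GiB
```

Even the float32 version is large, which is why the data generator below feeds the model in batches rather than all at once.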

2. Building the model

  • Import all the required functions and layers
from keras.layers import Conv2D,DepthwiseConv2D,Dense,AveragePooling2D,BatchNormalization,Input
from keras import Model
from keras import Sequential
from keras.layers.advanced_activations import ReLU
from keras.utils import to_categorical
  • Define the reusable intermediate blocks yourself and keep them in one place, to avoid repeating code when building the network.
def depth_point_conv2d(x, s=[1, 1, 2, 1], channel=[64, 128]):
    """
    s: the strides of each conv
    channel: the output depths of the pointwise convolutions
    """
    dw1 = DepthwiseConv2D((3, 3), strides=s[0], padding='same')(x)
    bn1 = BatchNormalization()(dw1)
    relu1 = ReLU()(bn1)
    pw1 = Conv2D(channel[0], (1, 1), strides=s[1], padding='same')(relu1)
    bn2 = BatchNormalization()(pw1)
    relu2 = ReLU()(bn2)
    dw2 = DepthwiseConv2D((3, 3), strides=s[2], padding='same')(relu2)
    bn3 = BatchNormalization()(dw2)
    relu3 = ReLU()(bn3)
    pw2 = Conv2D(channel[1], (1, 1), strides=s[3], padding='same')(relu3)
    bn4 = BatchNormalization()(pw2)
    relu4 = ReLU()(bn4)
    return relu4

def repeat_conv(x, s=[1, 1], channel=512):
    dw1 = DepthwiseConv2D((3, 3), strides=s[0], padding='same')(x)
    bn1 = BatchNormalization()(dw1)
    relu1 = ReLU()(bn1)
    pw1 = Conv2D(channel, (1, 1), strides=s[1], padding='same')(relu1)
    bn2 = BatchNormalization()(pw1)
    relu2 = ReLU()(bn2)
    return relu2

Build the model following the architecture table in the MobileNet paper.

In the fifth row from the bottom of that table, Conv dw / s2, I never understood why the output size stays unchanged if strides=2; I suspect it is a typo in the paper. I set the stride to 1 here, because only then do the subsequent output shapes work out.
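The parameter savings of a depthwise-separable block can be checked against the model summary below. For a 3×3 depthwise conv on 512 channels followed by a 1×1 pointwise conv back to 512 channels (biases included, as Keras counts them):

```python
c_in, c_out, k = 512, 512, 3

dw_params = k * k * c_in + c_in            # depthwise: one kxk filter per channel, plus biases
pw_params = c_in * c_out + c_out           # pointwise 1x1 conv, plus biases
std_params = k * k * c_in * c_out + c_out  # an ordinary kxk conv, for comparison

print(dw_params)   # 5120   (matches depthwise_conv2d_67 in the summary)
print(pw_params)   # 262656 (matches conv2d_73 in the summary)
print(std_params)  # 2359808, roughly 9x more than dw + pw combined
```

This is the core idea of MobileNet: splitting spatial filtering (depthwise) from channel mixing (pointwise) cuts the parameter and FLOP count by close to a factor of k², here 9.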

  • Build the network
h0=Input(shape=(224,224,1))
h1=Conv2D(32,(3,3),strides = 2,padding="same")(h0)
h2= BatchNormalization()(h1)
h3=ReLU()(h2)
h4 = depth_point_conv2d(h3,s=[1,1,2,1],channel=[64,128])
h5 = depth_point_conv2d(h4,s=[1,1,2,1],channel=[128,256])
h6 = depth_point_conv2d(h5,s=[1,1,2,1],channel=[256,512])
h7 = repeat_conv(h6)
h8 = repeat_conv(h7)
h9 = repeat_conv(h8)
h10 = repeat_conv(h9)
h11 = depth_point_conv2d(h10,s=[1,1,2,1],channel=[512,1024])
h12 = repeat_conv(h11,channel=1024)
h13 = AveragePooling2D((7,7))(h12)
h14 = Dense(10,activation='softmax')(h13)
model = Model(inputs=h0, outputs=h14)
model.summary()
Model: "model_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_11 (InputLayer) (None, 224, 224, 1) 0
_________________________________________________________________
conv2d_63 (Conv2D) (None, 112, 112, 32) 320
_________________________________________________________________
batch_normalization_120 (Bat (None, 112, 112, 32) 128
_________________________________________________________________
re_lu_120 (ReLU) (None, 112, 112, 32) 0
_________________________________________________________________
depthwise_conv2d_58 (Depthwi (None, 112, 112, 32) 320
_________________________________________________________________
batch_normalization_121 (Bat (None, 112, 112, 32) 128
_________________________________________________________________
re_lu_121 (ReLU) (None, 112, 112, 32) 0
_________________________________________________________________
conv2d_64 (Conv2D) (None, 112, 112, 64) 2112
_________________________________________________________________
batch_normalization_122 (Bat (None, 112, 112, 64) 256
_________________________________________________________________
re_lu_122 (ReLU) (None, 112, 112, 64) 0
_________________________________________________________________
depthwise_conv2d_59 (Depthwi (None, 56, 56, 64) 640
_________________________________________________________________
batch_normalization_123 (Bat (None, 56, 56, 64) 256
_________________________________________________________________
re_lu_123 (ReLU) (None, 56, 56, 64) 0
_________________________________________________________________
conv2d_65 (Conv2D) (None, 56, 56, 128) 8320
_________________________________________________________________
batch_normalization_124 (Bat (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_124 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
depthwise_conv2d_60 (Depthwi (None, 56, 56, 128) 1280
_________________________________________________________________
batch_normalization_125 (Bat (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_125 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
conv2d_66 (Conv2D) (None, 56, 56, 128) 16512
_________________________________________________________________
batch_normalization_126 (Bat (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_126 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
depthwise_conv2d_61 (Depthwi (None, 28, 28, 128) 1280
_________________________________________________________________
batch_normalization_127 (Bat (None, 28, 28, 128) 512
_________________________________________________________________
re_lu_127 (ReLU) (None, 28, 28, 128) 0
_________________________________________________________________
conv2d_67 (Conv2D) (None, 28, 28, 256) 33024
_________________________________________________________________
batch_normalization_128 (Bat (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_128 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
depthwise_conv2d_62 (Depthwi (None, 28, 28, 256) 2560
_________________________________________________________________
batch_normalization_129 (Bat (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_129 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
conv2d_68 (Conv2D) (None, 28, 28, 256) 65792
_________________________________________________________________
batch_normalization_130 (Bat (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_130 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
depthwise_conv2d_63 (Depthwi (None, 14, 14, 256) 2560
_________________________________________________________________
batch_normalization_131 (Bat (None, 14, 14, 256) 1024
_________________________________________________________________
re_lu_131 (ReLU) (None, 14, 14, 256) 0
_________________________________________________________________
conv2d_69 (Conv2D) (None, 14, 14, 512) 131584
_________________________________________________________________
batch_normalization_132 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_132 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_64 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_133 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_133 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_70 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_134 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_134 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_65 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_135 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_135 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_71 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_136 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_136 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_66 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_137 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_137 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_72 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_138 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_138 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_67 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_139 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_139 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_73 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_140 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_140 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_68 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_141 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_141 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_74 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_142 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_142 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_69 (Depthwi (None, 7, 7, 512) 5120
_________________________________________________________________
batch_normalization_143 (Bat (None, 7, 7, 512) 2048
_________________________________________________________________
re_lu_143 (ReLU) (None, 7, 7, 512) 0
_________________________________________________________________
conv2d_75 (Conv2D) (None, 7, 7, 1024) 525312
_________________________________________________________________
batch_normalization_144 (Bat (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_144 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
depthwise_conv2d_70 (Depthwi (None, 7, 7, 1024) 10240
_________________________________________________________________
batch_normalization_145 (Bat (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_145 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
conv2d_76 (Conv2D) (None, 7, 7, 1024) 1049600
_________________________________________________________________
batch_normalization_146 (Bat (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_146 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
average_pooling2d_5 (Average (None, 1, 1, 1024) 0
_________________________________________________________________
dense_4 (Dense) (None, 1, 1, 10) 10250
=================================================================
Total params: 3,249,482
Trainable params: 3,227,594
Non-trainable params: 21,888
_________________________________________________________________

Because there are only 10 classes here, the final output layer has 10 neurons; the original MobileNet classifies 1000 ImageNet categories, so it ends with 1000 neurons.

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

The line above sets the optimizer and the loss function.
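As a reminder of what categorical_crossentropy computes, here is the per-sample formula −Σᵢ yᵢ·log(pᵢ) in plain NumPy (a sketch; Keras additionally averages this over the batch):

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-sample cross-entropy between one-hot labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0)  # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

y_true = np.array([[0.0, 0.0, 1.0]])
y_pred = np.array([[0.1, 0.1, 0.8]])
print(categorical_crossentropy(y_true, y_pred))  # ~0.223, i.e. -log(0.8)
```

With one-hot labels only the predicted probability of the true class matters, which is why the loss falls quickly once the model becomes confident.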

3. Preparing the training data and training

Add a channel dimension to the training data, one-hot encode the labels, and reshape the labels to match the model's (None, 1, 1, 10) output.

x_train = np.expand_dims(new_train, 3)
y_train = to_categorical(y_train)
y = np.expand_dims(y_train, 1)
y = np.expand_dims(y, 1)
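to_categorical is equivalent to indexing an identity matrix, and the two expand_dims calls then give the labels the same (N, 1, 1, 10) rank as the Dense output (the Dense layer here acts on the (None, 1, 1, 1024) pooled feature map, since no Flatten layer was used). A NumPy sketch:

```python
import numpy as np

labels = np.array([3, 0, 9])
one_hot = np.eye(10)[labels]   # same result as keras.utils.to_categorical(labels)
y = np.expand_dims(one_hot, 1)
y = np.expand_dims(y, 1)

print(one_hot.shape)  # (3, 10)
print(y.shape)        # (3, 1, 1, 10)
print(one_hot[0])     # 1.0 at index 3, zeros elsewhere
```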
  • Define the data generator
def data_generate(x_train, y_train, batch_size, epochs):
    for i in range(epochs):
        batch_num = len(x_train) // batch_size
        shuffle_index = np.arange(batch_num)
        np.random.shuffle(shuffle_index)
        for j in shuffle_index:
            begin = j * batch_size
            end = begin + batch_size
            x = x_train[begin:end]
            y = y_train[begin:end]
            yield ({"input_11": x}, {"dense_4": y})

The dict keys above must match the names of the model's first and last layers ("input_11" and "dense_4" in the summary); otherwise Keras raises an error.
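The shuffling and batching logic is easy to verify on toy arrays. A pure-NumPy sketch of the same generator, with the layer-name dicts omitted:

```python
import numpy as np

def batches(x, y, batch_size, epochs):
    """Yield shuffled (x, y) batches for the given number of epochs."""
    for _ in range(epochs):
        batch_num = len(x) // batch_size
        order = np.arange(batch_num)
        np.random.shuffle(order)  # shuffle batch order each epoch
        for j in order:
            begin = j * batch_size
            yield x[begin:begin + batch_size], y[begin:begin + batch_size]

x = np.arange(10).reshape(10, 1)
y = np.arange(10)
out = list(batches(x, y, batch_size=2, epochs=3))
print(len(out))         # 3 epochs x 5 batches = 15
print(out[0][0].shape)  # (2, 1)
```

Note that this shuffles the order of whole batches, not individual samples; x and y slices stay aligned because both use the same begin index.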

  • Start training
model.fit_generator(data_generate(x_train, y, 100, 11), steps_per_epoch=600, epochs=10)

The training log looks like this:

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 1/10
Executing op __inference_keras_scratch_graph_22639 in device /job:localhost/replica:0/task:0/device:GPU:0
600/600 [==============================] - 411s 684ms/step - loss: 0.1469 - accuracy: 0.9529
Epoch 2/10
600/600 [==============================] - 398s 663ms/step - loss: 0.0375 - accuracy: 0.9884
Epoch 3/10
600/600 [==============================] - 401s 668ms/step - loss: 0.0283 - accuracy: 0.9909
Epoch 4/10
600/600 [==============================] - 399s 665ms/step - loss: 0.0211 - accuracy: 0.9936
Epoch 5/10
600/600 [==============================] - 400s 666ms/step - loss: 0.0216 - accuracy: 0.9932
Epoch 6/10
600/600 [==============================] - 401s 668ms/step - loss: 0.0208 - accuracy: 0.9935
Epoch 7/10
600/600 [==============================] - 401s 669ms/step - loss: 0.0174 - accuracy: 0.9945
Epoch 8/10
131/600 [=====>........................] - ETA: 5:13 - loss: 0.0091 - accuracy: 0.9973

The model has many convolution layers, so each epoch takes a while, but the parameter count is modest, so updates are cheap and the loss converges quickly.
