keras实现MobileNet
利用keras实现MobileNet,并以mnist数据集作为一个小例子进行识别。使用的环境是:tensorflow-gpu 2.0,python=3.7 , GTX-2070的GPU
1.导入数据
- 首先是导入两行魔法命令,可以多行显示.
%config InteractiveShell.ast_node_interactivity="all"
%pprint
- 加载keras中自带的mnist数据
import tensorflow as tf
import keras
tf.debugging.set_log_device_placement(True)
mnist = keras.datasets.mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
上述tf.debugging.set_log_device_placement(True)的作用是将模型放在GPU上进行训练。
- 数据的转换
在mnist上下载的数据的分辨率是2828的,mobilenet用来训练的数据是ImageNet ,其图片的分辨率是224224,所以先将图片的维度调整为224*224.
from PIL import Image
import numpy as np
def convert_mnist_224pix(X):
img=Image.fromarray(X)
x=np.zeros((224,224))
img=np.array(img.resize((224,224)))
x[:,:]=img
return x
iteration = iter(x_train)
new_train =np.zeros((len(x_train),224,224),dtype=np.float32)
for i in range(len(x_train)):
data = next(iteration)
new_train[i]=convert_mnist_224pix(data)
if i%5000==0:
print(i)
new_train.shape
这里要注意一下,new_train中一定要注明dtype=np.float32,不然默认的是float64,这样数据就太大了,没有那么多存储空间装。最后输出的维度是(60000,224,224)
2.搭建模型
- 导入所有需要的函数和库
from keras.layers import Conv2D,DepthwiseConv2D,Dense,AveragePooling2D,BatchNormalization,Input
from keras import Model
from keras import Sequential
from keras.layers.advanced_activations import ReLU
from keras.utils import to_categorical
- 自己定义中间可以重复利用的层,将其放在一起,简化搭建网络的重复代码。
def depth_point_conv2d(x,s=[1,1,2,1],channel=[64,128]):
"""
s:the strides of the conv
channel: the depth of pointwiseconvolutions
"""
dw1 = DepthwiseConv2D((3,3),strides=s[0],padding='same')(x)
bn1 = BatchNormalization()(dw1)
relu1 = ReLU()(bn1)
pw1 = Conv2D(channel[0],(1,1),strides=s[1],padding='same')(relu1)
bn2 = BatchNormalization()(pw1)
relu2 = ReLU()(bn2)
dw2 = DepthwiseConv2D((3,3),strides=s[2],padding='same')(relu2)
bn3 = BatchNormalization()(dw2)
relu3 = ReLU()(bn3)
pw2 = Conv2D(channel[1],(1,1),strides=s[3],padding='same')(relu3)
bn4 = BatchNormalization()(pw2)
relu4 = ReLU()(bn4)
return relu4
def repeat_conv(x,s=[1,1],channel=512):
dw1 = DepthwiseConv2D((3,3),strides=s[0],padding='same')(x)
bn1 = BatchNormalization()(dw1)
relu1 = ReLU()(bn1)
pw1 = Conv2D(channel,(1,1),strides=s[1],padding='same')(relu1)
bn2 = BatchNormalization()(pw1)
relu2 = ReLU()(bn2)
return relu2
根据mobilenet论文中的结构进行模型的搭建
在倒数第5行Conv/dw/s2中,我一直不理解如果strides=2,为什么最后生成图片尺寸没有变化,我感觉可能是笔误?,不过我这里将strides定义为1,因为这样才符合后面的整个输出。
- 搭建网络
h0=Input(shape=(224,224,1))
h1=Conv2D(32,(3,3),strides = 2,padding="same")(h0)
h2= BatchNormalization()(h1)
h3=ReLU()(h2)
h4 = depth_point_conv2d(h3,s=[1,1,2,1],channel=[64,128])
h5 = depth_point_conv2d(h4,s=[1,1,2,1],channel=[128,256])
h6 = depth_point_conv2d(h5,s=[1,1,2,1],channel=[256,512])
h7 = repeat_conv(h6)
h8 = repeat_conv(h7)
h9 = repeat_conv(h8)
h10 = repeat_conv(h9)
h11 = depth_point_conv2d(h10,s=[1,1,2,1],channel=[512,1024])
h12 = repeat_conv(h11,channel=1024)
h13 = AveragePooling2D((7,7))(h12)
h14 = Dense(10,activation='softmax')(h13)
model =Model(input=h0,output =h14)
model.summary()
Model: "model_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_11 (InputLayer) (None, 224, 224, 1) 0
_________________________________________________________________
conv2d_63 (Conv2D) (None, 112, 112, 32) 320
_________________________________________________________________
batch_normalization_120 (Bat (None, 112, 112, 32) 128
_________________________________________________________________
re_lu_120 (ReLU) (None, 112, 112, 32) 0
_________________________________________________________________
depthwise_conv2d_58 (Depthwi (None, 112, 112, 32) 320
_________________________________________________________________
batch_normalization_121 (Bat (None, 112, 112, 32) 128
_________________________________________________________________
re_lu_121 (ReLU) (None, 112, 112, 32) 0
_________________________________________________________________
conv2d_64 (Conv2D) (None, 112, 112, 64) 2112
_________________________________________________________________
batch_normalization_122 (Bat (None, 112, 112, 64) 256
_________________________________________________________________
re_lu_122 (ReLU) (None, 112, 112, 64) 0
_________________________________________________________________
depthwise_conv2d_59 (Depthwi (None, 56, 56, 64) 640
_________________________________________________________________
batch_normalization_123 (Bat (None, 56, 56, 64) 256
_________________________________________________________________
re_lu_123 (ReLU) (None, 56, 56, 64) 0
_________________________________________________________________
conv2d_65 (Conv2D) (None, 56, 56, 128) 8320
_________________________________________________________________
batch_normalization_124 (Bat (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_124 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
depthwise_conv2d_60 (Depthwi (None, 56, 56, 128) 1280
_________________________________________________________________
batch_normalization_125 (Bat (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_125 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
conv2d_66 (Conv2D) (None, 56, 56, 128) 16512
_________________________________________________________________
batch_normalization_126 (Bat (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_126 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
depthwise_conv2d_61 (Depthwi (None, 28, 28, 128) 1280
_________________________________________________________________
batch_normalization_127 (Bat (None, 28, 28, 128) 512
_________________________________________________________________
re_lu_127 (ReLU) (None, 28, 28, 128) 0
_________________________________________________________________
conv2d_67 (Conv2D) (None, 28, 28, 256) 33024
_________________________________________________________________
batch_normalization_128 (Bat (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_128 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
depthwise_conv2d_62 (Depthwi (None, 28, 28, 256) 2560
_________________________________________________________________
batch_normalization_129 (Bat (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_129 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
conv2d_68 (Conv2D) (None, 28, 28, 256) 65792
_________________________________________________________________
batch_normalization_130 (Bat (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_130 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
depthwise_conv2d_63 (Depthwi (None, 14, 14, 256) 2560
_________________________________________________________________
batch_normalization_131 (Bat (None, 14, 14, 256) 1024
_________________________________________________________________
re_lu_131 (ReLU) (None, 14, 14, 256) 0
_________________________________________________________________
conv2d_69 (Conv2D) (None, 14, 14, 512) 131584
_________________________________________________________________
batch_normalization_132 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_132 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_64 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_133 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_133 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_70 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_134 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_134 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_65 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_135 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_135 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_71 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_136 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_136 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_66 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_137 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_137 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_72 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_138 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_138 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_67 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_139 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_139 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_73 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_140 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_140 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_68 (Depthwi (None, 14, 14, 512) 5120
_________________________________________________________________
batch_normalization_141 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_141 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_74 (Conv2D) (None, 14, 14, 512) 262656
_________________________________________________________________
batch_normalization_142 (Bat (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_142 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_69 (Depthwi (None, 7, 7, 512) 5120
_________________________________________________________________
batch_normalization_143 (Bat (None, 7, 7, 512) 2048
_________________________________________________________________
re_lu_143 (ReLU) (None, 7, 7, 512) 0
_________________________________________________________________
conv2d_75 (Conv2D) (None, 7, 7, 1024) 525312
_________________________________________________________________
batch_normalization_144 (Bat (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_144 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
depthwise_conv2d_70 (Depthwi (None, 7, 7, 1024) 10240
_________________________________________________________________
batch_normalization_145 (Bat (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_145 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
conv2d_76 (Conv2D) (None, 7, 7, 1024) 1049600
_________________________________________________________________
batch_normalization_146 (Bat (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_146 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
average_pooling2d_5 (Average (None, 1, 1, 1024) 0
_________________________________________________________________
dense_4 (Dense) (None, 1, 1, 10) 10250
=================================================================
Total params: 3,249,482
Trainable params: 3,227,594
Non-trainable params: 21,888
_________________________________________________________________
因为这里的类别只有10类,所以最后的输出层只有10个神经元,原始的mobilenet要进行1000个类别分类,所以最后是1000个神经元。
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
上述代码定义优化算法和损失函数。
3、训练数据的整理与训练
将训练数据进行维度变换,标签进行one-hot编码并进行维度变换。
x_train = np.expand_dims(new_train,3)
y_train = to_categorical(y_train)
y=np.expand_dims(y_train,1)
y = np.expand_dims(y,1)
- 定义数据生成函数
def data_generate(x_train,y_train,batch_size,epochs):
for i in range(epochs):
batch_num = len(x_train)//batch_size
shuffle_index = np.arange(batch_num)
np.random.shuffle(shuffle_index)
for j in shuffle_index:
begin = j*batch_size
end =begin+batch_size
x = x_train[begin:end]
y = y_train[begin:end]
yield ({"input_11":x},{"dense_4":y})
上述命名和model中的第一层和最后一层名字一样,不然会报错。
- 开始训练
model.fit_generator(data_generate(x_train,y,100,11),step_per_epoch=600,epochs=10)
训练过程图如下:
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 1/10
Executing op __inference_keras_scratch_graph_22639 in device /job:localhost/replica:0/task:0/device:GPU:0
600/600 [==============================] - 411s 684ms/step - loss: 0.1469 - accuracy: 0.9529
Epoch 2/10
600/600 [==============================] - 398s 663ms/step - loss: 0.0375 - accuracy: 0.9884
Epoch 3/10
600/600 [==============================] - 401s 668ms/step - loss: 0.0283 - accuracy: 0.9909
Epoch 4/10
600/600 [==============================] - 399s 665ms/step - loss: 0.0211 - accuracy: 0.9936
Epoch 5/10
600/600 [==============================] - 400s 666ms/step - loss: 0.0216 - accuracy: 0.9932
Epoch 6/10
600/600 [==============================] - 401s 668ms/step - loss: 0.0208 - accuracy: 0.9935
Epoch 7/10
600/600 [==============================] - 401s 669ms/step - loss: 0.0174 - accuracy: 0.9945
Epoch 8/10
131/600 [=====>........................] - ETA: 5:13 - loss: 0.0091 - accuracy: 0.9973
模型卷积比较多,需要训练的时间有点长,参数不多,所以更新较快,收敛速度也很快。
keras实现MobileNet的更多相关文章
- keras中使用预训练模型进行图片分类
keras中含有多个网络的预训练模型,可以很方便的拿来进行使用. 安装及使用主要参考官方教程:https://keras.io/zh/applications/ https://keras-cn. ...
- 我的Keras使用总结(4)——Application中五款预训练模型学习及其应用
本节主要学习Keras的应用模块 Application提供的带有预训练权重的模型,这些模型可以用来进行预测,特征提取和 finetune,上一篇文章我们使用了VGG16进行特征提取和微调,下面尝试一 ...
- Keras读取保存的模型时, 产生错误[ValueError: Unknown activation function:relu6]
Solution: from keras.utils.generic_utils import CustomObjectScope with CustomObjectScope({'relu6': k ...
- 卷积神经网络学习笔记——轻量化网络MobileNet系列(V1,V2,V3)
完整代码及其数据,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 这里结合网络的资料和Mo ...
- Keras RetinaNet github项目安装
在存储库目录/keras-retinanet/中,执行pip install . --user 后,出现错误: D:\>cd D:\JupyterWorkSpace\keras-retinane ...
- Keras学习笔记(完结)
使用Keras中文文档学习 基本概念 Keras的核心数据结构是模型,也就是一种组织网络层的方式,最主要的是序贯模型(Sequential).创建好一个模型后就可以用add()向里面添加层.模型搭建完 ...
- [Tensorflow] Object Detection API - retrain mobileNet
前言 一.专注话题 重点话题 Retrain mobileNet (transfer learning). Train your own Object Detector. 这部分讲理论,下一篇讲实践. ...
- 使用keras导入densenet模型
从keras的keras_applications的文件夹内可以找到内置模型的源代码 Kera的应用模块Application提供了带有预训练权重的Keras模型,这些模型可以用来进行预测.特征提取和 ...
- 【Keras学习】资源
Keras项目github源码(python):keras-team/keras: Deep Learning for humans 里面的docs包含说明文档 中文文档:Keras中文文档 预训练模 ...
随机推荐
- JS删除微博
昨天晚上找回了10年注册的微博,现在瞅瞅,转发过很多傻吊的微博,关注了一堆营销号,不忍直视,动手删吧!开玩笑的,怎么可能手动! 查看自己的所有微博,F12----->console,负责下面代码 ...
- (or type Control-D to continue):
(or type Control-D to continue): 很多小伙伴学习使用Linux时可能经常遇到这个问题 (大部分原因是磁盘挂载等问题) 如下图: 具体解决方法 1.直接输入root用户的 ...
- Kubernetes 搭建 ES 集群(存储使用 local pv)
一.集群规划 由于当前环境中没有分布式存储,所以只能使用本地 PV 的方式来实现数据持久化. ES 集群的 master 节点至少需要三个,防止脑裂. 由于 master 在配置过程中需要保证主机名固 ...
- ASP.NET Core Authentication系列(二)实现认证、登录和注销
前言 在上一篇文章介绍ASP.NET Core Authentication的三个重要概念,分别是Claim, ClaimsIdentity, ClaimsPrincipal,以及claims-bas ...
- java查询elasticsearch聚合
java查es多分组聚合: SearchRequestBuilder requestBuilderOfLastMonth = transportClient.prepareSearch(TYPE_NA ...
- [Luogu P2824] [HEOI2016/TJOI2016]排序 (线段树+二分答案)
题面 传送门:https://www.luogu.org/problemnew/show/P2824 Solution 这题极其巧妙. 首先,如果直接做m次排序,显然会T得起飞. 注意一点:我们只需要 ...
- NOIP 2018 D1 解题报告(Day_1)
总分 205分 T1 100分 T2 95分 T3 10分 T1: 题目描述 春春是一名道路工程师,负责铺设一条长度为 nn 的道路. 铺设道路的主要工作是填平下陷的地表.整段道路可以看作是 ...
- MySQL各版本connector net msi
从其他博主那里扒来的! 链接:https://pan.baidu.com/s/1C1fYepBFKfxU0NJS0aRyJw 提取码:awsl
- NOIP 2012 P1081 开车旅行
倍增 这道题最难的应该是预处理... 首先用$set$从后往前预处理出每一个点海拔差绝对值得最大值和次大值 因为当前城市的下标只能变大,对于点$i$,在$set$中二分找出与其值最接近的下标 然后再$ ...
- Kubernetes-17:Kubernets包管理工具—>Helm介绍与使用
Kubernets包管理工具->Helm 什么是Helm? 我们都知道,Linux系统各发行版都有自己的包管理工具,比如Centos的YUM,再如Ubuntu的APT. Kubernetes也有 ...