本文将会介绍如何利用Keras来实现模型的保存、读取以及加载。

  本文使用的模型为解决IRIS数据集的多分类问题而设计的深度神经网络(DNN)模型,模型的结构示意图如下:

具体的模型参数可以参考文章:Keras入门(一)搭建深度神经网络(DNN)解决多分类问题

模型保存

  Keras使用HDF5文件系统来保存模型。模型保存的方法很容易,只需要使用save()方法即可。

  以Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中的DNN模型为例,整个模型的变量为model,我们设置模型共训练10次,在原先的代码中加入Python代码即可保存模型:

    # save model
print("Saving model to disk \n")
mp = "E://logs/iris_model.h5"
model.save(mp)

保存的模型文件(iris_model.h5)如下:

模型读取

  保存后的iris_model.h5以HDF5文件系统的形式储存,在我们使用Python读取h5文件里面的数据之前,我们先用HDF5的可视化工具HDFView来查看里面的数据:

  我们感兴趣的是这个模型中的各个神经层之间的连接权重及偏重,也就是上图中的红色部分,model_weights里面包含了各个神经层之间的连接权重及偏重,分别位于dense_1,dense_2,dense_3中。蓝色部分为dense_3/dense_3/kernel:0的数据,即最后输出层的连接权重矩阵。

  有了对模型参数的直观认识,我们要做的下一步工作就是读取各个神经层之间的连接权重及偏重。我们使用Python的h5py这个模块来这个iris_model.h5这个文件。关于h5py的快速入门指南,可以参考文章:h5py快速入门指南

  使用以下Python代码可以读取各个神经层之间的连接权重及偏重数据:

import h5py

# 模型地址
MODEL_PATH = 'E://logs/iris_model.h5' # 获取每一层的连接权重及偏重
print("读取模型中...")
with h5py.File(MODEL_PATH, 'r') as f:
dense_1 = f['/model_weights/dense_1/dense_1']
dense_1_bias = dense_1['bias:0'][:]
dense_1_kernel = dense_1['kernel:0'][:] dense_2 = f['/model_weights/dense_2/dense_2']
dense_2_bias = dense_2['bias:0'][:]
dense_2_kernel = dense_2['kernel:0'][:] dense_3 = f['/model_weights/dense_3/dense_3']
dense_3_bias = dense_3['bias:0'][:]
dense_3_kernel = dense_3['kernel:0'][:] print("第一层的连接权重矩阵:\n%s\n"%dense_1_kernel)
print("第一层的连接偏重矩阵:\n%s\n"%dense_1_bias)
print("第二层的连接权重矩阵:\n%s\n"%dense_2_kernel)
print("第二层的连接偏重矩阵:\n%s\n"%dense_2_bias)
print("第三层的连接权重矩阵:\n%s\n"%dense_3_kernel)
print("第三层的连接偏重矩阵:\n%s\n"%dense_3_bias)

输出的结果如下:

读取模型中...
第一层的连接权重矩阵:
[[ 0.04141677 0.03080632 -0.02768146 0.14334357 0.06242227]
[-0.41209617 -0.77948487 0.5648218 -0.699587 -0.19246106]
[ 0.6856315 0.28241938 -0.91930366 -0.07989818 0.47165248]
[ 0.8655262 0.72175753 0.36529952 -0.53172135 0.26573092]] 第一层的连接偏重矩阵:
[-0.16441862 -0.02462054 -0.14060321 0. -0.14293939] 第二层的连接权重矩阵:
[[ 0.39296603 0.01864707 0.12538083 0.07935872 0.27940807 -0.4565802 ]
[-0.34312084 0.6446907 -0.92546445 -0.00538039 0.95466876 -0.32819661]
[-0.7593299 -0.07227057 0.20751365 0.40547106 0.35726753 0.8884158 ]
[-0.48096 0.11294878 -0.29462305 -0.410536 -0.23620337 -0.72703975]
[ 0.7666149 -0.41720924 0.29576775 -0.6328017 0.43118536 0.6589351 ]] 第二层的连接偏重矩阵:
[-0.1899569 0. -0.09710662 -0.12964155 -0.26443157 0.6050924 ] 第三层的连接权重矩阵:
[[-0.44450542 0.09977101 0.12196152]
[ 0.14334357 0.18546402 -0.23861367]
[-0.7284191 0.7859063 -0.878823 ]
[ 0.0876545 0.51531947 0.09671918]
[-0.7964963 -0.16435687 0.49531657]
[ 0.8645698 0.4439873 0.24599855]] 第三层的连接偏重矩阵:
[ 0.39192322 -0.1266532 -0.29631865]

值得注意的是,我们得到的这些矩阵的数据类型都是numpy.ndarray。

  OK,既然我们已经得到了各个神经层之间的连接权重及偏重的数据,那我们能做什么呢?当然是去做一些有趣的事啦,那就是用我们自己的方法来实现新数据的预测向量(softmax函数作用后的向量)。so, really?

  新的输入向量为[6.1, 3.1, 5.1, 1.1],使用以下Python代码即可输出新数据的预测向量:

import h5py
import numpy as np # 模型地址
MODEL_PATH = 'E://logs/iris_model.h5' # 获取每一层的连接权重及偏重
print("读取模型中...")
with h5py.File(MODEL_PATH, 'r') as f:
dense_1 = f['/model_weights/dense_1/dense_1']
dense_1_bias = dense_1['bias:0'][:]
dense_1_kernel = dense_1['kernel:0'][:] dense_2 = f['/model_weights/dense_2/dense_2']
dense_2_bias = dense_2['bias:0'][:]
dense_2_kernel = dense_2['kernel:0'][:] dense_3 = f['/model_weights/dense_3/dense_3']
dense_3_bias = dense_3['bias:0'][:]
dense_3_kernel = dense_3['kernel:0'][:] # 模拟每个神经层的计算,得到该层的输出
def layer_output(input, kernel, bias):
return np.dot(input, kernel) + bias # 实现ReLU函数
relu = np.vectorize(lambda x: x if x >=0 else 0) # 实现softmax函数
def softmax_func(arr):
exp_arr = np.exp(arr)
arr_sum = np.sum(exp_arr)
softmax_arr = exp_arr/arr_sum
return softmax_arr # 输入向量
unkown = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32) # 第一层的输出
print("模型计算中...")
output_1 = layer_output(unkown, dense_1_kernel, dense_1_bias)
output_1 = relu(output_1) # 第二层的输出
output_2 = layer_output(output_1, dense_2_kernel, dense_2_bias)
output_2 = relu(output_2) # 第三层的输出
output_3 = layer_output(output_2, dense_3_kernel, dense_3_bias)
output_3 = softmax_func(output_3) # 最终的输出的softmax值
np.set_printoptions(precision=4)
print("最终的预测值向量为: %s"%output_3)

其输出的结果如下:

读取模型中...
模型计算中...
最终的预测值向量为: [[0.0242 0.6763 0.2995]]

  额,这个输出的预测值向量会是我们的DNN模型的预测值向量吗?这时候,我们就需要回过头来看看Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中的代码了,注意,为了保证数值的可比较性,笔者已经将DNN模型的训练次数改为10次了。让我们来看看原来代码的输出结果吧:

Using model to predict species for features:
[[6.1 3.1 5.1 1.1]] Predicted softmax vector is:
[[0.0242 0.6763 0.2995]] Predicted species is:
Iris-versicolor

Yes,两者的预测值向量完全一致!因此,我们用自己的方法也实现了这个DNN模型的预测功能,棒!

模型加载

  当然,在实际的使用中,我们不需要再用自己的方法来实现模型的预测功能,只需使用Keras给我们提供好的模型导入功能(keras.models.load_model())即可。使用以下Python代码即可加载模型

    # 模型的加载及使用
from keras.models import load_model
print("Using loaded model to predict...")
load_model = load_model("E://logs/iris_model.h5")
np.set_printoptions(precision=4)
unknown = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)
predicted = load_model.predict(unknown)
print("Using model to predict species for features: ")
print(unknown)
print("\nPredicted softmax vector is: ")
print(predicted)
species_dict = {v: k for k, v in Class_dict.items()}
print("\nPredicted species is: ")
print(species_dict[np.argmax(predicted)])

输出结果如下:

Using loaded model to predict...
Using model to predict species for features:
[[6.1 3.1 5.1 1.1]] Predicted softmax vector is:
[[0.0242 0.6763 0.2995]] Predicted species is:
Iris-versicolor

总结

  本文主要介绍如何利用Keras来实现模型的保存、读取以及加载。

  本文将不再给出完整的Python代码,如需完整的代码,请参考Github地址:https://github.com/percent4/Keras_4_multiclass.

注意:本人现已开通微信公众号: Python爬虫与算法(微信号为:easy_web_scrape), 欢迎大家关注哦~~

Keras入门(二)模型的保存、读取及加载的更多相关文章

  1. keras模型的保存与重新加载

    # 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...

  2. pyspider 示例二 升级完整版绕过懒加载,直接读取图片

    pyspider 示例二 升级完整版绕过懒加载,直接读取图片,见[升级写法处] #!/usr/bin/env python # -*- encoding: utf-8 -*- # Created on ...

  3. esri-leaflet入门教程(5)- 动态要素加载

    esri-leaflet入门教程(5)- 动态要素加载 by 李远祥 在上一章节中已经说明了esr-leaflet是如何加载ArcGIS Server提供的各种服务,这些都是服务本身来决定的,API脚 ...

  4. Linux内核启动代码分析二之开发板相关驱动程序加载分析

    Linux内核启动代码分析二之开发板相关驱动程序加载分析 1 从linux开始启动的函数start_kernel开始分析,该函数位于linux-2.6.22/init/main.c  start_ke ...

  5. DB数据源之SpringBoot+MyBatis踏坑过程(二)手工配置数据源与加载Mapper.xml扫描

    DB数据源之SpringBoot+MyBatis踏坑过程(二)手工配置数据源与加载Mapper.xml扫描 liuyuhang原创,未经允许进制转载  吐槽之后应该有所改了,该方式可以作为一种过渡方式 ...

  6. spark SQL (四)数据源 Data Source----Parquet 文件的读取与加载

    spark SQL Parquet 文件的读取与加载 是由许多其他数据处理系统支持的柱状格式.Spark SQL支持阅读和编写自动保留原始数据模式的Parquet文件.在编写Parquet文件时,出于 ...

  7. 基于 Koa平台Node.js开发的KoaHub.js的控制器,模型,帮助方法自动加载

    koahub-loader koahub-loader是基于 Koa平台Node.js开发的KoaHub.js的koahub-loader控制器,模型,帮助方法自动加载 koahub loader I ...

  8. Unity3d-WWW实现图片资源显示以及保存和本地加载

    本文固定连接:http://blog.csdn.net/u013108312/article/details/52712844 WWW实现图片资源显示以及保存和本地加载 using UnityEngi ...

  9. tensorflow 模型保存后的加载路径问题

    import tensorflow as tf #保存模型 saver = tf.train.Saver() saver.save(sess, "e://code//python//test ...

随机推荐

  1. 浅谈如何检查Linux中开放端口列表

    给大家分享一篇关于如何检查Linux中的开放端口列表的详细介绍,首先如果你想检查远程Linux系统上的端口是否打开请点击链接浏览.如果你想检查多个远程Linux系统上的端口是否打开请点击链接浏览.如果 ...

  2. JavaScript基础学习笔记整理

    1.关于JS: (1)脚本语言——不需要编译的语言(常见有cmd,t-sql)----解释性语言; (2)动态类型的语言——1.代码只有执行到那个位置才知道那个变量中存储的是什么 2.对象中没有某个属 ...

  3. 关于Android 8.0java.lang.SecurityException: Permission Denial错误的解决方法

    背景 当我在Android 7.0及以下手机运行启动页,进行Activity跳转的时候,完美跳转到对应的目标Activity. 但当在Android 8.0及以上手机进行Activity跳转时,会爆如 ...

  4. Linux更新源汇总-18.9.7更新

    企业站 阿里云:https://opsx.alibaba.com/mirror 网易:http://mirrors.163.com/ 教育站 北京理工大学:http://mirror.bit.edu. ...

  5. 【CF429E】 Points and Segments(欧拉回路)

    传送门 CodeForces 洛谷 Solution 考虑欧拉回路有一个性质. 如果把点抽出来搞成一条直线,路径看成区间覆盖,那么一个点从左往右被覆盖的次数等于从右往左被覆盖的次数. 发现这个性质和本 ...

  6. Python学习笔记【第十篇】:Python面向对象进阶

    保护对象的属性 如果有一个对象,当需要对其进行修改属性时,有2种方法 对象名.属性名 = 数据 ---->直接修改 对象名.方法名() ---->间接修改 为了更好的保存属性安全,即不能随 ...

  7. Kali学习笔记17:OpenVAS安装部署

    正式介绍OpenVAS之前先说一些题外话 1.有一个网站记录了很多的漏洞: https://www.exploit-db.com/ 可以下载利用 2.如果觉得从网上寻找太麻烦,Kali自带工具:sea ...

  8. apollo入门(一)

    1. apollo入门(一) 1.1. 核心概念 1.1.1. 应用 注意:每个应用需要配置一个appid 1.1.2. 环境 dev 开发环境 fat 功能测试环境 uat 用户接受测试环境 pro ...

  9. SpringCache实战遇坑

    1. SpringCache实战遇坑 1.1. pom 主要是以下两个 <dependency> <groupId>org.springframework.boot</g ...

  10. 四、activiti工作流-第一个HelloWorld

    上一节已经把流程图画好,并且数据库也已经创建好了25张表,这节讲如何启动一个流程 先新建一个包,并新建一个类. /**然后定义一个成员属性,主要是因为每个方法都要用到这个引擎 * 获取默认流程引擎实例 ...