我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD # 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10) # 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
]) # 定义优化器
sgd = SGD(lr=0.2) # 定义优化器,loss function,训练过程中计算准确率
model.compile(
optimizer = sgd,
loss = 'mse',
metrics=['accuracy'],
) # 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5) # 评估模型
loss,accuracy = model.evaluate(x_test,y_test) print('\ntest loss',loss)
print('accuracy',accuracy) # 保存模型
model.save('model.h5') # HDF5文件,pip install h5py

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10) # 载入模型
model = load_model('model.h5') # 评估模型
loss,accuracy = model.evaluate(x_test,y_test) print('\ntest loss',loss)
print('accuracy',accuracy) # 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2) # 评估模型
loss,accuracy = model.evaluate(x_test,y_test) print('\ntest loss',loss)
print('accuracy',accuracy) # 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string) print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

Keras保存模型并载入模型继续训练的更多相关文章

  1. keras 保存模型

    转自:https://blog.csdn.net/u010159842/article/details/54407745,感谢分享! 我们不推荐使用pickle或cPickle来保存Keras模型 你 ...

  2. TensorFlow保存和载入模型

    首先定义一个tf.train.Saver类: saver = tf.train.Saver(max_to_keep=1) 其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后 ...

  3. (原+译)pytorch中保存和载入模型

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/8108466.html 参考网址: http://pytorch.org/docs/master/not ...

  4. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

  5. Keras入门(六)模型训练实时可视化

      在北京做某个项目的时候,客户要求能够对数据进行训练.预测,同时能导出模型,还有在页面上显示训练的进度.前面的几个要求都不难实现,但在页面上显示训练进度当时笔者并没有实现.   本文将会分享如何在K ...

  6. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  7. keras系列︱Sequential与Model模型、keras基本结构功能(一)

    引自:http://blog.csdn.net/sinat_26917383/article/details/72857454 中文文档:http://keras-cn.readthedocs.io/ ...

  8. 三分钟快速上手TensorFlow 2.0 (下)——模型的部署 、大规模训练、加速

    前文:三分钟快速上手TensorFlow 2.0 (中)——常用模块和模型的部署 TensorFlow 模型导出 使用 SavedModel 完整导出模型 不仅包含参数的权值,还包含计算的流程(即计算 ...

  9. 【Keras篇】---Keras初始,两种模型构造方法,利用keras实现手写数字体识别

    一.前述 Keras 适合快速体验 ,keras的设计是把大量内部运算都隐藏了,用户始终可以用theano或tensorflow的语句来写扩展功能并和keras结合使用. 二.安装 Pip insta ...

随机推荐

  1. 基于Python的SQL Server数据库对象同步轻量级实现

    缘由 日常工作中经常遇到类似的问题:把某个服务器上的某些指定的表同步到另外一台服务器.类似需求用SSIS或者其他ETL工作很容易实现,比如用SSIS的话就可以,但会存在相当一部分反复的手工操作.建源的 ...

  2. MySQL数据库~~~~~创建用户和授权、备份和还原

    一 MySQL创建用户和授权 1.1 对新用户增删改 1.创建用户: # 指定ip:192.118.1.1的chao用户登录 create user 'chao'@'192.118.1.1' iden ...

  3. 计算机网络基础笔记 运输层协议UDP/TCP

    目录 UDP 首部结构 主要特点 TCP 首部结构 主要特点 TCP 可靠性实现 停止等待ARQ协议 连续ARQ协议&滑动窗口协议 拥塞控制 TCP 运输连接管理 连接建立:三次握手 连接释放 ...

  4. openstack Glance安装与配置

    一.实验目的: 1.理解glance镜像服务在OpenStack框架中的作用 2.掌握glance服务安装的基本方法 3.掌握glance的配置基本方法 二.实验步骤: 1.在controller节点 ...

  5. Github(第一次尝试)

    重要提示:项目中的文件最好最好不要出现中文,尤其是复杂的中文文件名. 前提:本地已经用git 管理 一个测试项目(项目一),分支为master. 1.注册 github: http://git.osc ...

  6. JavaScript-----12.对象

    1. 对象 万物皆对象,但是对象必须是一个具体的事物.例如:"明星"不是对象,"周星驰"是对象:"苹果"不是对象"这个苹果&quo ...

  7. 在浏览器地址栏输入www.baidu.com到打开百度首页这期间到底发生了什么?

    刚才无意间看到这么一个面试题,觉得有点意思,我想从五层网络模型的角度说说我的看法. 1.首先通过DNS域名系统向域名服务器发送域名解析请求来得到百度的IP地址39.156.69.79:2.系统通过AR ...

  8. dotnet core 调用electron来开发UI的探索

    先上仓库地址 https://github.com/lightszero/webwindow.netcore dotnet core 很喜欢,问题dotnet core 不包含GUI,经过一些尝试,觉 ...

  9. 解决tail命令提示“tail: inotify 资源耗尽,无法使用 inotify 机制,回归为 polling 机制”

    报错的原因是 inotify 跟踪的文件数量超出了系统设置的上限值,要是这个问题不经常出现可以使用临时解决方法,或者写入配置文件来永久解决. 临时解决方法: # 查看 inotify 的相关配置 $ ...

  10. PlayJava Day012

    今日所学: /* 2019.08.19开始学习,此为补档. */ JPanel和JFrame 1.JFrame是最底层,JPanel是置于其面上,同一个界面只有一个JFrame,一个JFrame可以放 ...