keras建模的3种方式——序列模型、函数模型、子类模型
1 前言
keras是Google公司于2016年发布的以tensorflow为后端的用于深度学习网络训练的高阶API,因接口设计非常人性化,深受程序员的喜爱。
keras建模有3种实现方式——序列模型、函数模型、子类模型。本文以MNIST手写数字为例,用3种建模方式实现。
关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类
笔者工作空间如下:

2 序列模型
sequential.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.models import load_model
from keras.layers import Dense
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#序列模型
def DNN(train_x,train_y,valid_x,valid_y):
#创建模型
model=Sequential()
model.add(Dense(64,input_dim=784,activation='relu'))
model.add(Dense(128,activation='relu'))
model.add(Dense(10,activation='softmax'))
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#保存模型
model.save('sequential.h5')
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y)
model=load_model('sequential.h5') #下载模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])
运行结果
Epoch 98/100
- 1s - loss: 4.8694e-04 - acc: 1.0000 - val_loss: 0.1331 - val_acc: 0.9776
Epoch 99/100
- 1s - loss: 4.7432e-04 - acc: 1.0000 - val_loss: 0.1336 - val_acc: 0.9778
Epoch 100/100
- 1s - loss: 4.6462e-04 - acc: 1.0000 - val_loss: 0.1343 - val_acc: 0.9774
test_loss: 0.13972217990085484 - test_acc: 0.9768999993801117
3 函数模型
fun_model.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.models import load_model
from keras.layers import Input,Dense
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#函数模型
def DNN(train_x,train_y,valid_x,valid_y):
#创建模型
inputs=Input(shape=(784,))
x=Dense(64,activation='relu')(inputs)
x=Dense(128,activation='relu')(x)
output=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=output)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#保存模型
model.save('fun_model.h5')
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y)
model=load_model('fun_model.h5') #下载模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])
4 子类模型
class_model.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import Dense
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#子类模型
class DNN(Model):
def __init__(self):
super(DNN,self).__init__()
#初始化网络结构
self.dense1=Dense(64,input_dim=784,activation='relu')
self.dense2=Dense(128,activation='relu')
self.dense3=Dense(10,activation='softmax')
def call(self,inputs): #回调顺序
x=self.dense1(inputs)
x=self.dense2(x)
x=self.dense3(x)
return x
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
model=DNN()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#查看网络结构
model.summary()
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])
5 注意事项
(1)只有序列模型和函数模型能够保存模型,子类模型不能保存模型,即不能调用 model.save()
(2)子类模型中,model.summary() 得放在 model.fit() 之后,否则会报错
ValueError: This model has not yet been built. Build the model first by calling build() or calling fit() with some data. Or specify input_shape or batch_input_shape in the first layer for automatic build.
(3)若想自定义学习率,可以引入优化器对象,如下:
from keras.optimizers import Adam
....
model.compile(optimizer=Adam(lr=0.001),loss='categorical_crossentropy',metrics=['accuracy'])
(4)常用损失函数
mse #均方差(回归)
mae #绝对误差(回归)
binary_crossentropy #二值交叉熵(二分类,逻辑回归)
categorical_crossentropy #交叉熵(多分类)
(5)model.fit( ) 和 model.evaluate( ) 中,属性 verbose 表示打印训练或评估信息是否详细
- 0:不打印进度和结果
- 1:打印进度和结果
Epoch 100/100
55000/55000 [==============================] - 1s 9us/step - loss: 6.0211e-05 - acc: 1.0000 - val_loss: 0.1405 - val_acc: 0.9766
10000/10000 [==============================] - 0s 26us/step
- 2:只打印结果
Epoch 100/100
- 1s - loss: 4.6462e-04 - acc: 1.0000 - val_loss: 0.1343 - val_acc: 0.9774
声明:本文转自keras建模的3种方式——序列模型、函数模型、子类模型
keras建模的3种方式——序列模型、函数模型、子类模型的更多相关文章
- python 零散记录(五) import的几种方式 序列解包 条件和循环 强调getattr内建函数
用import关键字导入模块的几种方式: #python是自解释的,不必多说,代码本身就是人可读的 import xxx from xxx import xxx from xxx import xx1 ...
- 增加收入的 6 种方式(很多公司的模型是:一份时间卖多次。比如网易、腾讯。个人赚取收入的本质是:出售时间)good
个人赚取收入的本质是:出售时间.从这个角度出发,下面的公式可以描述个人收入: 个人收入 = 每天可售时间数量 * 单位时间价格 * 单位时间出售次数 在这个公式里,有三个要素: 每天可出售的时间数量 ...
- Keras框架下的保存模型和加载模型
在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...
- Android-创建启动线程的两种方式
方式一:成为Thread的子类,然后在Thread的子类.start 缺点:存在耦合度(因为线程任务run方法里面的业务逻辑 和 线程启动耦合了) 缺点:Cat extends Thread {} 后 ...
- SpringBoot集成Mybatis实现多表查询的两种方式(基于xml)
下面将在用户和账户进行一对一查询的基础上进行介绍SpringBoot集成Mybatis实现多表查询的基于xml的两种方式. 首先我们先创建两个数据库表,分别是user用户表和account账户表 ...
- 【Keras篇】---Keras初始,两种模型构造方法,利用keras实现手写数字体识别
一.前述 Keras 适合快速体验 ,keras的设计是把大量内部运算都隐藏了,用户始终可以用theano或tensorflow的语句来写扩展功能并和keras结合使用. 二.安装 Pip insta ...
- keras embeding设置初始值的两种方式
随机初始化Embedding from keras.models import Sequential from keras.layers import Embedding import numpy a ...
- Keras中间层输出的两种方式,即特征图可视化
训练好的模型,想要输入中间层的特征图,有两种方式: 1. 通过model.get_layer的方式.创建新的模型,输出为你要的层的名字. 创建模型,debug状态可以看到模型中,base_model/ ...
- Windows10-UWP中设备序列显示不同XAML的三种方式[3]
阅读目录: 概述 DeviceFamily-Type文件夹 DeviceFamily-Type扩展 InitializeComponent重载 结论 概述 Windows10-UWP(Universa ...
- 三种方式实现观察者模式 及 Spring中的事件编程模型
观察者模式可以说是众多设计模式中,最容易理解的设计模式之一了,观察者模式在Spring中也随处可见,面试的时候,面试官可能会问,嘿,你既然读过Spring源码,那你说说Spring中运用的设计模式吧, ...
随机推荐
- 【VSCode】秒下vscode
有时从vscode官网下载速度奇慢甚至失败,介绍一种方法可以秒下 进入官网选择要下载的版本 像我的电脑,下载网址根本打不开 修改下载网址,替换下载地址中红框字符串:vscode.cdn.azure.c ...
- TiKV 服务部署的注意事项
TiKV 服务部署的注意事项 背景 最近发现tikv总是会掉线 不知道是哪里触发了啥样子的bug. 所以想着使用systemd 管理一下, 至少在tikv宕机的时候能够拉起来服务. 二进制文件 pd- ...
- [转帖]Redis命令详解:Keys
https://jackeyzhe.github.io/2018/09/22/Redis%E5%91%BD%E4%BB%A4%E8%AF%A6%E8%A7%A3%EF%BC%9AKeys/ 介绍完Re ...
- [转帖]适用于 Azure VM 的 TCP/IP 性能优化
https://learn.microsoft.com/zh-cn/azure/virtual-network/virtual-network-tcpip-performance-tuning?con ...
- 基于Spring Cache实现Caffeine、jimDB多级缓存实战
作者: 京东零售 王震 背景 在早期参与涅槃氛围标签中台项目中,前台要求接口性能999要求50ms以下,通过设计Caffeine.ehcache堆外缓存.jimDB三级缓存,利用内存.堆外.jimDB ...
- export default 和 export 这两种方式导出的区别
参考地址 https://blog.csdn.net/sleepwalker_1992/article/details/81461543 使用export向外暴露的成员,只能使用{ }的形式来接收,这 ...
- Fabric网络升级(一)
原文来自这里. 本章节主要介绍如何从之前的版本或其他长期支持版本升级至最新版. 从2.1升级到2.2 Fabric v2.1和v2.2都是稳定版,以bug修复和其它形式的代码加固位置.因此,升级不需要 ...
- C# 使用字典将枚举转换为String
枚举 public enum ColorType { Red = 10, Blue = 20, Green = 30, Yellow = 40, } String var A1 = "AAA ...
- go中channel源码剖析
channel 前言 设计的原理 共享内存 csp channel channel的定义 源码剖析 环形队列 创建 写入数据 读取数据 channel的关闭 优雅的关闭 M个receivers,一个s ...
- 基于知识图谱的《红楼梦》人物关系可视化及问答系统(含码源):命名实体识别、关系识别、LTP简单教学
基于知识图谱的<红楼梦>人物关系可视化及问答系统(含码源):命名实体识别.关系识别.LTP简单教学 文件树: app.py是整个系统的主入口 templates文件夹是HTML的页面 |- ...