1 前言

keras是Google公司于2016年发布的以tensorflow为后端的用于深度学习网络训练的高阶API,因接口设计非常人性化,深受程序员的喜爱。

keras建模有3种实现方式——序列模型、函数模型、子类模型。本文以MNIST手写数字为例,用3种建模方式实现。

关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下:

2 序列模型

sequential.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.models import load_model
from keras.layers import Dense #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #序列模型
def DNN(train_x,train_y,valid_x,valid_y):
#创建模型
model=Sequential()
model.add(Dense(64,input_dim=784,activation='relu'))
model.add(Dense(128,activation='relu'))
model.add(Dense(10,activation='softmax'))
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#保存模型
model.save('sequential.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y) model=load_model('sequential.h5') #下载模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])

运行结果

Epoch 98/100
- 1s - loss: 4.8694e-04 - acc: 1.0000 - val_loss: 0.1331 - val_acc: 0.9776
Epoch 99/100
- 1s - loss: 4.7432e-04 - acc: 1.0000 - val_loss: 0.1336 - val_acc: 0.9778
Epoch 100/100
- 1s - loss: 4.6462e-04 - acc: 1.0000 - val_loss: 0.1343 - val_acc: 0.9774
test_loss: 0.13972217990085484 - test_acc: 0.9768999993801117

3 函数模型

fun_model.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.models import load_model
from keras.layers import Input,Dense #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #函数模型
def DNN(train_x,train_y,valid_x,valid_y):
#创建模型
inputs=Input(shape=(784,))
x=Dense(64,activation='relu')(inputs)
x=Dense(128,activation='relu')(x)
output=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=output)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#保存模型
model.save('fun_model.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y) model=load_model('fun_model.h5') #下载模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])

4 子类模型

class_model.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import Dense #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #子类模型
class DNN(Model):
def __init__(self):
super(DNN,self).__init__()
#初始化网络结构
self.dense1=Dense(64,input_dim=784,activation='relu')
self.dense2=Dense(128,activation='relu')
self.dense3=Dense(10,activation='softmax') def call(self,inputs): #回调顺序
x=self.dense1(inputs)
x=self.dense2(x)
x=self.dense3(x)
return x train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
model=DNN()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#查看网络结构
model.summary() pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])

5 注意事项

(1)只有序列模型函数模型能够保存模型,子类模型不能保存模型,即不能调用 model.save()

(2)子类模型中,model.summary() 得放在 model.fit() 之后,否则会报错

ValueError: This model has not yet been built. Build the model first by calling build() or calling fit() with some data. Or specify input_shape or batch_input_shape in the first layer for automatic build.

(3)若想自定义学习率,可以引入优化器对象,如下:

from keras.optimizers import Adam
....
model.compile(optimizer=Adam(lr=0.001),loss='categorical_crossentropy',metrics=['accuracy'])

(4)常用损失函数

mse  #均方差(回归)
mae #绝对误差(回归)
binary_crossentropy #二值交叉熵(二分类,逻辑回归)
categorical_crossentropy #交叉熵(多分类)

(5)model.fit( ) 和 model.evaluate( ) 中,属性 verbose 表示打印训练或评估信息是否详细

  • 0:不打印进度和结果
  • 1:打印进度和结果
Epoch 100/100
55000/55000 [==============================] - 1s 9us/step - loss: 6.0211e-05 - acc: 1.0000 - val_loss: 0.1405 - val_acc: 0.9766
10000/10000 [==============================] - 0s 26us/step
  • 2:只打印结果
Epoch 100/100
- 1s - loss: 4.6462e-04 - acc: 1.0000 - val_loss: 0.1343 - val_acc: 0.9774

​ 声明:本文转自keras建模的3种方式——序列模型、函数模型、子类模型

keras建模的3种方式——序列模型、函数模型、子类模型的更多相关文章

  1. python 零散记录(五) import的几种方式 序列解包 条件和循环 强调getattr内建函数

    用import关键字导入模块的几种方式: #python是自解释的,不必多说,代码本身就是人可读的 import xxx from xxx import xxx from xxx import xx1 ...

  2. 增加收入的 6 种方式(很多公司的模型是:一份时间卖多次。比如网易、腾讯。个人赚取收入的本质是:出售时间)good

    个人赚取收入的本质是:出售时间.从这个角度出发,下面的公式可以描述个人收入: 个人收入 = 每天可售时间数量 * 单位时间价格 * 单位时间出售次数 在这个公式里,有三个要素: 每天可出售的时间数量 ...

  3. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

  4. Android-创建启动线程的两种方式

    方式一:成为Thread的子类,然后在Thread的子类.start 缺点:存在耦合度(因为线程任务run方法里面的业务逻辑 和 线程启动耦合了) 缺点:Cat extends Thread {} 后 ...

  5. SpringBoot集成Mybatis实现多表查询的两种方式(基于xml)

     下面将在用户和账户进行一对一查询的基础上进行介绍SpringBoot集成Mybatis实现多表查询的基于xml的两种方式.   首先我们先创建两个数据库表,分别是user用户表和account账户表 ...

  6. 【Keras篇】---Keras初始,两种模型构造方法,利用keras实现手写数字体识别

    一.前述 Keras 适合快速体验 ,keras的设计是把大量内部运算都隐藏了,用户始终可以用theano或tensorflow的语句来写扩展功能并和keras结合使用. 二.安装 Pip insta ...

  7. keras embeding设置初始值的两种方式

    随机初始化Embedding from keras.models import Sequential from keras.layers import Embedding import numpy a ...

  8. Keras中间层输出的两种方式,即特征图可视化

    训练好的模型,想要输入中间层的特征图,有两种方式: 1. 通过model.get_layer的方式.创建新的模型,输出为你要的层的名字. 创建模型,debug状态可以看到模型中,base_model/ ...

  9. Windows10-UWP中设备序列显示不同XAML的三种方式[3]

    阅读目录: 概述 DeviceFamily-Type文件夹 DeviceFamily-Type扩展 InitializeComponent重载 结论 概述 Windows10-UWP(Universa ...

  10. 三种方式实现观察者模式 及 Spring中的事件编程模型

    观察者模式可以说是众多设计模式中,最容易理解的设计模式之一了,观察者模式在Spring中也随处可见,面试的时候,面试官可能会问,嘿,你既然读过Spring源码,那你说说Spring中运用的设计模式吧, ...

随机推荐

  1. 类外static函数定义要不要加static关键字?

    类外static函数定义要不要加static关键字? 先说答案:不需要. 错误代码: #include<iostream> #include<memory> using nam ...

  2. [转帖]PostgreSQL数据库的版本历史及关键变化

    https://cloud.tencent.com/developer/article/2311843 举报 PostgreSQL是一个强大的开源关系型数据库,它的发展历程充满了创新和卓越的设计.让我 ...

  3. goofys 鲲鹏上面编译挂载与性能测试

    goofys 鲲鹏上面编译挂载与性能测试 介质 使用go进行编译. 官网上面有 amd64的介质,但是没有aarch64的介质 需要自行编译 前几天一直编译失败. 周天在家自己测试了一把,根据gith ...

  4. 【转帖】Linux 调优篇 :虚拟化调优(irqbalance 网卡中断绑定)* 贰

    一.网络流量上不去二.中断绑定2.1 关闭中断平衡守护进程2.2 脱离中断平衡守护进程2.3 手动设置中断的CPU亲和性三. 总结 一.网络流量上不去 在Linux的网络调优方面,如果你发现网络流量上 ...

  5. [转帖]关于 AREX

    https://arextest.github.io/website/zh-Hans/docs/intro/ AREX 介绍​ 背景​ 对于一个初上线的简单服务,只需通过常规的自动化测试加上人工即可解 ...

  6. HTTPS下tomcat与nginx的前端性能比较

    HTTPS下tomcat与nginx的前端性能比较 摘要 之前比较http的web服务器的性能. 发现nginx 比 tomcat 要好 50% 然后想到, https的情况下不知道两者有什么区别 所 ...

  7. Grafana 监控 PG数据库的操作过程

    Grafana 监控 PG数据库的操作过程 容器化运行 postgres-exporter 进行处理 1. 镜像运行 exporter docker run -p 9187:9187 -e DATA_ ...

  8. Nginx 大并发 调优设置

    为了性能测试,放弃部分功能,保证绝对性能. 注意可能不能用于生产环境. 下面开始简单讲解. 1. worker_processes 工作线程数. 发现不用太多 一定不能多于操作系统的CPU核数. 2. ...

  9. [Python] 基于RapidFuzz库实现字符串模糊匹配

    RapidFuzz是一个用于快速字符串模糊匹配的Python库,它能够快速计算两个字符串之间的相似度,并提供与Fuzzywuzzy(已停用)和TheFuzz(Fuzzywuzzy的升级版)类似的接口. ...

  10. 文盘Rust -- 安全连接 TiDB/Mysql

    作者:京东科技 贾世闻 最近在折腾rust与数据库集成,为了偷懒,选了Tidb Cloud Serverless Tier 作为数据源.Tidb 无疑是近五年来最优秀的国产开源分布式数据库,Tidb ...