1 前言

keras是Google公司于2016年发布的以tensorflow为后端的用于深度学习网络训练的高阶API,因接口设计非常人性化,深受程序员的喜爱。

keras建模有3种实现方式——序列模型、函数模型、子类模型。本文以MNIST手写数字为例,用3种建模方式实现。

关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下:

2 序列模型

sequential.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.models import load_model
from keras.layers import Dense #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #序列模型
def DNN(train_x,train_y,valid_x,valid_y):
#创建模型
model=Sequential()
model.add(Dense(64,input_dim=784,activation='relu'))
model.add(Dense(128,activation='relu'))
model.add(Dense(10,activation='softmax'))
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#保存模型
model.save('sequential.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y) model=load_model('sequential.h5') #下载模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])

运行结果

Epoch 98/100
- 1s - loss: 4.8694e-04 - acc: 1.0000 - val_loss: 0.1331 - val_acc: 0.9776
Epoch 99/100
- 1s - loss: 4.7432e-04 - acc: 1.0000 - val_loss: 0.1336 - val_acc: 0.9778
Epoch 100/100
- 1s - loss: 4.6462e-04 - acc: 1.0000 - val_loss: 0.1343 - val_acc: 0.9774
test_loss: 0.13972217990085484 - test_acc: 0.9768999993801117

3 函数模型

fun_model.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.models import load_model
from keras.layers import Input,Dense #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #函数模型
def DNN(train_x,train_y,valid_x,valid_y):
#创建模型
inputs=Input(shape=(784,))
x=Dense(64,activation='relu')(inputs)
x=Dense(128,activation='relu')(x)
output=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=output)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#保存模型
model.save('fun_model.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y) model=load_model('fun_model.h5') #下载模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])

4 子类模型

class_model.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import Dense #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #子类模型
class DNN(Model):
def __init__(self):
super(DNN,self).__init__()
#初始化网络结构
self.dense1=Dense(64,input_dim=784,activation='relu')
self.dense2=Dense(128,activation='relu')
self.dense3=Dense(10,activation='softmax') def call(self,inputs): #回调顺序
x=self.dense1(inputs)
x=self.dense2(x)
x=self.dense3(x)
return x train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
model=DNN()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#查看网络结构
model.summary() pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])

5 注意事项

(1)只有序列模型函数模型能够保存模型,子类模型不能保存模型,即不能调用 model.save()

(2)子类模型中,model.summary() 得放在 model.fit() 之后,否则会报错

ValueError: This model has not yet been built. Build the model first by calling build() or calling fit() with some data. Or specify input_shape or batch_input_shape in the first layer for automatic build.

(3)若想自定义学习率,可以引入优化器对象,如下:

from keras.optimizers import Adam
....
model.compile(optimizer=Adam(lr=0.001),loss='categorical_crossentropy',metrics=['accuracy'])

(4)常用损失函数

mse  #均方差(回归)
mae #绝对误差(回归)
binary_crossentropy #二值交叉熵(二分类,逻辑回归)
categorical_crossentropy #交叉熵(多分类)

(5)model.fit( ) 和 model.evaluate( ) 中,属性 verbose 表示打印训练或评估信息是否详细

  • 0:不打印进度和结果
  • 1:打印进度和结果
Epoch 100/100
55000/55000 [==============================] - 1s 9us/step - loss: 6.0211e-05 - acc: 1.0000 - val_loss: 0.1405 - val_acc: 0.9766
10000/10000 [==============================] - 0s 26us/step
  • 2:只打印结果
Epoch 100/100
- 1s - loss: 4.6462e-04 - acc: 1.0000 - val_loss: 0.1343 - val_acc: 0.9774

​ 声明:本文转自keras建模的3种方式——序列模型、函数模型、子类模型

keras建模的3种方式——序列模型、函数模型、子类模型的更多相关文章

  1. python 零散记录(五) import的几种方式 序列解包 条件和循环 强调getattr内建函数

    用import关键字导入模块的几种方式: #python是自解释的,不必多说,代码本身就是人可读的 import xxx from xxx import xxx from xxx import xx1 ...

  2. 增加收入的 6 种方式(很多公司的模型是:一份时间卖多次。比如网易、腾讯。个人赚取收入的本质是:出售时间)good

    个人赚取收入的本质是:出售时间.从这个角度出发,下面的公式可以描述个人收入: 个人收入 = 每天可售时间数量 * 单位时间价格 * 单位时间出售次数 在这个公式里,有三个要素: 每天可出售的时间数量 ...

  3. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

  4. Android-创建启动线程的两种方式

    方式一:成为Thread的子类,然后在Thread的子类.start 缺点:存在耦合度(因为线程任务run方法里面的业务逻辑 和 线程启动耦合了) 缺点:Cat extends Thread {} 后 ...

  5. SpringBoot集成Mybatis实现多表查询的两种方式(基于xml)

     下面将在用户和账户进行一对一查询的基础上进行介绍SpringBoot集成Mybatis实现多表查询的基于xml的两种方式.   首先我们先创建两个数据库表,分别是user用户表和account账户表 ...

  6. 【Keras篇】---Keras初始,两种模型构造方法,利用keras实现手写数字体识别

    一.前述 Keras 适合快速体验 ,keras的设计是把大量内部运算都隐藏了,用户始终可以用theano或tensorflow的语句来写扩展功能并和keras结合使用. 二.安装 Pip insta ...

  7. keras embeding设置初始值的两种方式

    随机初始化Embedding from keras.models import Sequential from keras.layers import Embedding import numpy a ...

  8. Keras中间层输出的两种方式,即特征图可视化

    训练好的模型,想要输入中间层的特征图,有两种方式: 1. 通过model.get_layer的方式.创建新的模型,输出为你要的层的名字. 创建模型,debug状态可以看到模型中,base_model/ ...

  9. Windows10-UWP中设备序列显示不同XAML的三种方式[3]

    阅读目录: 概述 DeviceFamily-Type文件夹 DeviceFamily-Type扩展 InitializeComponent重载 结论 概述 Windows10-UWP(Universa ...

  10. 三种方式实现观察者模式 及 Spring中的事件编程模型

    观察者模式可以说是众多设计模式中,最容易理解的设计模式之一了,观察者模式在Spring中也随处可见,面试的时候,面试官可能会问,嘿,你既然读过Spring源码,那你说说Spring中运用的设计模式吧, ...

随机推荐

  1. [转帖]docker 最新版本升级

    文章目录 前言 一.卸载低版本docker 1.1 检查docker版本 1.2 删除docker 二.开始安装 2.1 安装所需依赖 2.2 设置docker yum源 2.3 查看所有可用版本 2 ...

  2. [转帖]OutOfMemory自动重启程序

    OutOfMemory以后程序已经假死,无法再提供服务,最好的做法是dump内存,发送警告,然后重启服务 我的方案:利用at命令延迟启动 但有一个问题,at最多支持分钟操作,也就是说要1分钟以后才能启 ...

  3. 微软Windows Sever系统也将强制要求TPM及CPU兼容

    https://www.cnbeta.com/articles/tech/1238647.htm 去年微软推出Win11系统时,TPM安全模块以及Intel 8代酷睿/AMD锐龙2代及以上的硬件要求引 ...

  4. Redis6.x 在Windows上面编译安装的过程

    背景说明 在github上面仅能够找到 redis3.2.100的Windows安装文件 比较新的版本比较难以找到, 同事经常出现这个版本的redis卡死的情况, 所以想尝试进行一下升级. 第一部分下 ...

  5. TypeScript类的修饰符 public private protected的详细讲解

    简单介绍一下public private protected public:当一个类的成员变量没有修饰的时候,默认的就是 public 进行修饰.外部是可以进行访问的. private属性只能够在父类 ...

  6. 你不知道的Promise状态变化机制

    1.Promise中PromiseStatus的三种状态 var p = new Promise((resolve, reject) => { // resolve 既是函数也是参数,它用于处理 ...

  7. 【解决一个小问题】proto文件中的enum,去掉长长的重复的enum名字

    在proto中定义的enum,通常类型名字都会带上enum的前缀,很丑陋,如何去掉呢? enum DataSourceType{ NotUse = 0; MySQL = 1; ElasticSearc ...

  8. Gorm 关联关系介绍与基本使用

    目录 一 Belongs To(一对一) 1.1 Belongs To 1.2 重写外键 1.3 重写引用(一般不用) 1.4 Belongs to 的 CRUD 1.5 预加载 1.6 外键约束 二 ...

  9. 复原docker中容器的启动命令

    复原 docker 容器的启动命令 前言 查看 docker 容器的启动命令 参考 复原 docker 容器的启动命令 前言 不规范的操作,在启动 docker 容器,没有留命令脚本,或者没有使用 d ...

  10. 设计模式学习-使用go实现解释器模式

    解释器模式 定义 优点 缺点 适用范围 代码实现 参考 解释器模式 定义 解释器模式(interpreter):给定一种语言,定义它的文法的一种表示,并定一个解释器,这个解释器使用该表示来解释语言中的 ...