目录

构建一个简单的模型

序贯(Sequential)模型

网络层的构造

模型训练和参数评价

模型训练

模型的训练

tf.data的数据集

模型评估和预测

基本模型的建立

网络层模型

模型子类函数构建

回调函数Callbacks

模型保存和载入

网络参数保存Weights only

配置参数保存Configuration only

完整模型保存


目前keras API 已经整合到 tensorflow最新版本1.9.0 中,在tensorflow中通过tf.keras就可以调用keras。

import tensorflow as tf
from tensorflow import keras

官方教程为:https://tensorflow.google.cn/guide/keras

tf.keras可以调用所有的keras编译代码,但是有两个限制:

  1. 版本问题,需要通过tf.keras.version确认版本。
  2. 模型保存问题,tf.keras默认使用 checkpoint format格式,而keras模型的保存格式HDF5需要借用函数save_format='h5'

构建一个简单的模型

序贯(Sequential)模型

序贯模型就是是多个网络层的线性堆叠,比如多层感知机,BP神经网络。

tf.keras构建一个简单的全连通网络(即多层感知器)代码如下:

#建立序贯模型
model = keras.Sequential()
#添加全连接层,节点数为64,激活函数为relu函数,dense表示标准的一维全连接层
model.add(keras.layers.Dense(64, activation='relu'))
#添加全连接层,节点数为64,激活函数为relu函数
model.add(keras.layers.Dense(64, activation='relu'))
#添加输出层,输出节点数为10
model.add(keras.layers.Dense(10, activation='softmax'))

其中激活函数详细信息见keras官方文档http://keras-cn.readthedocs.io/en/latest/other/activations/

网络层的构造

通常在tf.keras中,网络层的构造参数主要有以下几个:

  1. 激活函数activation function,默认是没有激活函数的。
  2. 参数初始化,默认通过正态分布初始化(Glorot uniform)
  3. 参数正则化,包括权值初始化和偏置的初始化。
#参数调整
#建立一个sigmoid层
layers.Dense(64, activation='sigmoid')
#或者
layers.Dense(64, activation=tf.sigmoid) #权重L1正则化
layers.Dense(64, kernel_regularizer=keras.regularizers.l1(0.01))
#偏置L2正则化
layers.Dense(64, bias_regularizer=keras.regularizers.l2(0.01)) #权重正交矩阵的随机数初始化
layers.Dense(64, kernel_initializer='orthogonal')
#偏置常数初始化
layers.Dense(64, bias_initializer=keras.initializers.constant(2.0))

模型训练和参数评价

模型训练

模型建立后,通过compile模块确定模型的训练参数(tf.keras.Model.compile)

tf.keras.Model.compile有三个主要参数:

  1. 优化器optimizer:通过tf.train模块调用优化器,可用的优化器类型见:http://keras-cn.readthedocs.io/en/latest/other/optimizers/
  2. 损失函数loss:通过tf.keras.losses模块调用损失函数,可用的损失函数类型见:http://keras-cn.readthedocs.io/en/latest/other/objectives/
  3. 模型评估方法metrics:通过tf.keras.metrics调用评估参数,可用的模型评估方法见:http://keras-cn.readthedocs.io/en/latest/other/metrics/

具体例子如下:

# 配置均方误差回归模型
model.compile(optimizer=tf.train.AdamOptimizer(0.01),
loss='mse', # 均方差
metrics=['mae']) # 平均绝对误差 # 配置分类模型
model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
loss=keras.losses.categorical_crossentropy, #多类的对数损失
metrics=[keras.metrics.categorical_accuracy]) #多分类问题,所有预测值上的平均正确率

模型的训练

对于小数据集,使用numpy数组,通过tf.keras.Model.fit模块来训练和评估模型。

import numpy as np
#输入数据(1000,32)
data = np.random.random((1000, 32))
#输入标签(1000,10)
labels = np.random.random((1000, 10))
#模型训练
model.fit(data, labels, epochs=10, batch_size=32)

tf.keras.Model.fit模块有三个重要的参数:

  1. 训练轮数epochs:epochs指的就是训练过程中数据将被训练多少轮,一个epoch指的是当一个完整的数据集通过了神经网络一次并且返回了一次。
  2. 批训练大小batch_size:基本上现在的梯度下降都是基于mini-batch的,即将一个完整数据分为batch_size个批次进行训练。详见http://keras-cn.readthedocs.io/en/latest/for_beginners/concepts/#epochs
  3. 验证集validation_data:通常一个模型训练,评估要有训练集,验证集和测试集。验证集就是模型调参时用来评估模型的数据集。

tf.data的数据集

对于大型数据集,常常通过tf.data模块来调用数据,详见https://tensorflow.google.cn/guide/datasets

# 数据实例化
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
dataset = dataset.repeat() #模型训练,steps_per_epoch表示每次训练的数据大小类似与batch_size
model.fit(dataset, epochs=10, steps_per_epoch=30)

模型评估和预测

通过 tf.keras.Model.evaluate 和tf.keras.Model.predict可以实现模型的评估和预测。

model.evaluate(x, y, batch_size=32)
model.evaluate(dataset, steps=30) model.predict(x, batch_size=32)
model.predict(dataset, steps=30)

基本模型的建立

网络层模型

通过f.keras.Sequential 可以实现各种的复杂模型,如:

  1. 多输入模型;
  2. 多输出模型;
  3. 参数共享层模型;
  4. 残差网络模型。

具体例子如下:

#输入参数
inputs = keras.Input(shape=(32,)) #网络层的构建
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(64, activation='relu')(x)
#预测
predictions = keras.layers.Dense(10, activation='softmax')(x) #模型实例化
model = keras.Model(inputs=inputs, outputs=predictions) #模型构建
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy']) #模型训练
model.fit(data, labels, batch_size=32, epochs=5)

模型子类函数构建

通常通过tf.keras.Model构建模型结构, __init__方法初始化模型,call方法进行参数传递。如下所示:

class MyModel(keras.Model):
#模型结构确定
def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
#网络层的定义
self.dense_1 = keras.layers.Dense(32, activation='relu')
self.dense_2 = keras.layers.Dense(num_classes, activation='sigmoid')
#参数调用
def call(self, inputs):
#前向传播过程确定
x = self.dense_1(inputs)
return self.dense_2(x) def compute_output_shape(self, input_shape):
#输出参数确定
shape = tf.TensorShape(input_shape).as_list()
shape[-1] = self.num_classes
return tf.TensorShape(shape) #模型初始化
model = MyModel(num_classes=10) #模型构建
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy']) #模型训练
model.fit(data, labels, batch_size=32, epochs=5)

回调函数Callbacks

回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型fit()中,即可在给定的训练阶段调用该函数集中的函数。详见:http://keras-cn.readthedocs.io/en/latest/other/callbacks/。主要回调函数有:

  1. tf.keras.callbacks.ModelCheckpoint:模型保存
  2. tf.keras.callbacks.LearningRateScheduler:学习率调整
  3. tf.keras.callbacks.EarlyStopping:中断训练
  4. tf.keras.callbacks.TensorBoard:tensorboard的使用

模型保存和载入

tf.keras有两种模型保存方式

网络参数保存Weights only

#模型保存为tensorflow默认格式
model.save_weights('./my_model') #载入模型
model.load_weights('my_model') #模型保存为keras默认格式,包含其他优化参数
model.save_weights('my_model.h5', save_format='h5') #载入模型
model.load_weights('my_model.h5')

配置参数保存Configuration only

保存一个没有模型参数只有配置参数的模型, Keras支持 JSON和YAML序列化格式:

# 模型保存
json_string = model.to_json()
yaml_string = model.to_yaml()
#模型载入
fresh_model = keras.models.from_json(json_string)
fresh_model = keras.models.from_yaml(yaml_string)

完整模型保存

将原来模型所用信息进行保存:

#模型建立
model = keras.Sequential([
keras.layers.Dense(10, activation='softmax', input_shape=(32,)),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, targets, batch_size=32, epochs=5) #保存为keras格式文件
model.save('my_model.h5') # 模型载入
model = keras.models.load_model('my_model.h5')

[深度学习] tf.keras入门1-基本函数介绍的更多相关文章

  1. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  2. [深度学习] tf.keras入门4-过拟合和欠拟合

    过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...

  3. [深度学习] tf.keras入门5-模型保存和载入

    目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...

  4. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. 【博学谷学习记录】超强总结,用心分享|Linux修改文件权限方法总结

    一.介绍 linux中"一切皆文件".每个文件都设定了针对不同用户的访问权限. 文件权限主要针对以下三种对象: 属主:拥有者 属组:所属的组 其他人:不属于上述两类 二.文件权限 ...

  2. Python-D4-语法入门2

    目录 数据类型 数据类型之整型int 数据类型之浮点型float 数据类型之字符串str 数据类型之列表list 数据类型之字典dict 基本数据类型之布尔值bool 基本数据类型之元祖tuple 基 ...

  3. Resilience4J通过yml设置circuitBreaker

    介绍 Resilience4j是一个轻量级.易于使用的容错库,其灵感来自Netflix Hystrix,但专为Java 8和函数式编程设计. springcloud2020升级以后Hystrix被官方 ...

  4. JUC(5)BlockingQueue四组API

    1.读写锁ReadWriteLock package com.readlock; import java.util.HashMap; import java.util.Map; /** * ReadW ...

  5. 齐博x1标签之异步加载标签数据

    为什么要异步加载标签?他有什么好处 如果一个页面的标签太多,又或者是页面中某一个标签调用数据太慢的话,就会拖慢整个页面的打开,非常影响用户体验.这个时候,用异步加载的话,就可以一块一块的显示,用户体验 ...

  6. letcode刷题记录-day02-回文数

    回文数 题目描述 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target  的那 两个 整数,并返回它们的数组下标. 你可以假设每种输入只会对应一个答 ...

  7. jvm调优思路及调优案例

    jvm调优思路及调优案例 ​ 我们说jvm调优,其实就是不断测试调整jvm的运行参数,尽可能让对象都在新生代(Eden)里分配和回收,尽量别让太多对象频繁进入老年代,避免频繁对老年代进行垃圾回收,同时 ...

  8. 1B踩坑大王

    题目链接 题目大意: 人们常用的电子表格软件(比如: Excel)采用如下所述的坐标系统: 第一列被标为 A,第二列为 B,以此类推,第 262626 列为 Z.接下来为由两个字母构成的列号: 第 2 ...

  9. Centos7.6分区、格式化、自动挂载磁盘

    个人名片: 对人间的热爱与歌颂,可抵岁月冗长 Github‍:念舒_C.ying CSDN主页️:念舒_C.ying 个人博客 :念舒_C.ying 目录 1. 添加硬盘 2. 执行fdisk -l ...

  10. HDC.Cloud Day | 全国首场上海站告捷,聚开发者力量造梦、探梦、筑梦

    摘要:11月20日,首个华为云开发者日HDC.Cloud Day在上海成功举行. 本文分享自华为云社区<HDC.Cloud Day | 全国首场上海站告捷,聚开发者力量造梦.探梦.筑梦>, ...