[深度学习] tf.keras入门1-基本函数介绍
目录
目前keras API 已经整合到 tensorflow最新版本1.9.0 中,在tensorflow中通过tf.keras就可以调用keras。
import tensorflow as tf
from tensorflow import keras
官方教程为:https://tensorflow.google.cn/guide/keras
tf.keras可以调用所有的keras编译代码,但是有两个限制:
- 版本问题,需要通过tf.keras.version确认版本。
- 模型保存问题,tf.keras默认使用 checkpoint format格式,而keras模型的保存格式HDF5需要借用函数save_format='h5'
构建一个简单的模型
序贯(Sequential)模型
序贯模型就是是多个网络层的线性堆叠,比如多层感知机,BP神经网络。
tf.keras构建一个简单的全连通网络(即多层感知器)代码如下:
#建立序贯模型
model = keras.Sequential()
#添加全连接层,节点数为64,激活函数为relu函数,dense表示标准的一维全连接层
model.add(keras.layers.Dense(64, activation='relu'))
#添加全连接层,节点数为64,激活函数为relu函数
model.add(keras.layers.Dense(64, activation='relu'))
#添加输出层,输出节点数为10
model.add(keras.layers.Dense(10, activation='softmax'))
其中激活函数详细信息见keras官方文档http://keras-cn.readthedocs.io/en/latest/other/activations/
网络层的构造
通常在tf.keras中,网络层的构造参数主要有以下几个:
- 激活函数activation function,默认是没有激活函数的。
- 参数初始化,默认通过正态分布初始化(Glorot uniform)
- 参数正则化,包括权值初始化和偏置的初始化。
#参数调整
#建立一个sigmoid层
layers.Dense(64, activation='sigmoid')
#或者
layers.Dense(64, activation=tf.sigmoid)
#权重L1正则化
layers.Dense(64, kernel_regularizer=keras.regularizers.l1(0.01))
#偏置L2正则化
layers.Dense(64, bias_regularizer=keras.regularizers.l2(0.01))
#权重正交矩阵的随机数初始化
layers.Dense(64, kernel_initializer='orthogonal')
#偏置常数初始化
layers.Dense(64, bias_initializer=keras.initializers.constant(2.0))
模型训练和参数评价
模型训练
模型建立后,通过compile模块确定模型的训练参数(tf.keras.Model.compile)
tf.keras.Model.compile有三个主要参数:
- 优化器optimizer:通过tf.train模块调用优化器,可用的优化器类型见:http://keras-cn.readthedocs.io/en/latest/other/optimizers/
- 损失函数loss:通过tf.keras.losses模块调用损失函数,可用的损失函数类型见:http://keras-cn.readthedocs.io/en/latest/other/objectives/
- 模型评估方法metrics:通过tf.keras.metrics调用评估参数,可用的模型评估方法见:http://keras-cn.readthedocs.io/en/latest/other/metrics/
具体例子如下:
# 配置均方误差回归模型
model.compile(optimizer=tf.train.AdamOptimizer(0.01),
loss='mse', # 均方差
metrics=['mae']) # 平均绝对误差
# 配置分类模型
model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
loss=keras.losses.categorical_crossentropy, #多类的对数损失
metrics=[keras.metrics.categorical_accuracy]) #多分类问题,所有预测值上的平均正确率
模型的训练
对于小数据集,使用numpy数组,通过tf.keras.Model.fit模块来训练和评估模型。
import numpy as np
#输入数据(1000,32)
data = np.random.random((1000, 32))
#输入标签(1000,10)
labels = np.random.random((1000, 10))
#模型训练
model.fit(data, labels, epochs=10, batch_size=32)
tf.keras.Model.fit模块有三个重要的参数:
- 训练轮数epochs:epochs指的就是训练过程中数据将被训练多少轮,一个epoch指的是当一个完整的数据集通过了神经网络一次并且返回了一次。
- 批训练大小batch_size:基本上现在的梯度下降都是基于mini-batch的,即将一个完整数据分为batch_size个批次进行训练。详见http://keras-cn.readthedocs.io/en/latest/for_beginners/concepts/#epochs。
- 验证集validation_data:通常一个模型训练,评估要有训练集,验证集和测试集。验证集就是模型调参时用来评估模型的数据集。
tf.data的数据集
对于大型数据集,常常通过tf.data模块来调用数据,详见https://tensorflow.google.cn/guide/datasets
# 数据实例化
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
dataset = dataset.repeat()
#模型训练,steps_per_epoch表示每次训练的数据大小类似与batch_size
model.fit(dataset, epochs=10, steps_per_epoch=30)
模型评估和预测
通过 tf.keras.Model.evaluate 和tf.keras.Model.predict可以实现模型的评估和预测。
model.evaluate(x, y, batch_size=32)
model.evaluate(dataset, steps=30)
model.predict(x, batch_size=32)
model.predict(dataset, steps=30)
基本模型的建立
网络层模型
通过f.keras.Sequential 可以实现各种的复杂模型,如:
- 多输入模型;
- 多输出模型;
- 参数共享层模型;
- 残差网络模型。
具体例子如下:
#输入参数
inputs = keras.Input(shape=(32,))
#网络层的构建
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(64, activation='relu')(x)
#预测
predictions = keras.layers.Dense(10, activation='softmax')(x)
#模型实例化
model = keras.Model(inputs=inputs, outputs=predictions)
#模型构建
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
#模型训练
model.fit(data, labels, batch_size=32, epochs=5)
模型子类函数构建
通常通过tf.keras.Model构建模型结构, __init__方法初始化模型,call方法进行参数传递。如下所示:
class MyModel(keras.Model):
#模型结构确定
def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
#网络层的定义
self.dense_1 = keras.layers.Dense(32, activation='relu')
self.dense_2 = keras.layers.Dense(num_classes, activation='sigmoid')
#参数调用
def call(self, inputs):
#前向传播过程确定
x = self.dense_1(inputs)
return self.dense_2(x)
def compute_output_shape(self, input_shape):
#输出参数确定
shape = tf.TensorShape(input_shape).as_list()
shape[-1] = self.num_classes
return tf.TensorShape(shape)
#模型初始化
model = MyModel(num_classes=10)
#模型构建
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
#模型训练
model.fit(data, labels, batch_size=32, epochs=5)
回调函数Callbacks
回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型fit()中,即可在给定的训练阶段调用该函数集中的函数。详见:http://keras-cn.readthedocs.io/en/latest/other/callbacks/。主要回调函数有:
- tf.keras.callbacks.ModelCheckpoint:模型保存
- tf.keras.callbacks.LearningRateScheduler:学习率调整
- tf.keras.callbacks.EarlyStopping:中断训练
- tf.keras.callbacks.TensorBoard:tensorboard的使用
模型保存和载入
tf.keras有两种模型保存方式
网络参数保存Weights only
#模型保存为tensorflow默认格式
model.save_weights('./my_model')
#载入模型
model.load_weights('my_model')
#模型保存为keras默认格式,包含其他优化参数
model.save_weights('my_model.h5', save_format='h5')
#载入模型
model.load_weights('my_model.h5')
配置参数保存Configuration only
保存一个没有模型参数只有配置参数的模型, Keras支持 JSON和YAML序列化格式:
# 模型保存
json_string = model.to_json()
yaml_string = model.to_yaml()
#模型载入
fresh_model = keras.models.from_json(json_string)
fresh_model = keras.models.from_yaml(yaml_string)
完整模型保存
将原来模型所用信息进行保存:
#模型建立
model = keras.Sequential([
keras.layers.Dense(10, activation='softmax', input_shape=(32,)),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, targets, batch_size=32, epochs=5)
#保存为keras格式文件
model.save('my_model.h5')
# 模型载入
model = keras.models.load_model('my_model.h5')
[深度学习] tf.keras入门1-基本函数介绍的更多相关文章
- [深度学习] tf.keras入门2-分类
目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...
- [深度学习] tf.keras入门4-过拟合和欠拟合
过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...
- [深度学习] tf.keras入门5-模型保存和载入
目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...
- [深度学习] tf.keras入门3-回归
目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...
- 深度学习:Keras入门(一)之基础篇
1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...
- 深度学习:Keras入门(一)之基础篇【转】
本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...
- 深度学习:Keras入门(一)之基础篇(转)
转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)
说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】
本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...
随机推荐
- 【博学谷学习记录】超强总结,用心分享|Linux修改文件权限方法总结
一.介绍 linux中"一切皆文件".每个文件都设定了针对不同用户的访问权限. 文件权限主要针对以下三种对象: 属主:拥有者 属组:所属的组 其他人:不属于上述两类 二.文件权限 ...
- Python-D4-语法入门2
目录 数据类型 数据类型之整型int 数据类型之浮点型float 数据类型之字符串str 数据类型之列表list 数据类型之字典dict 基本数据类型之布尔值bool 基本数据类型之元祖tuple 基 ...
- Resilience4J通过yml设置circuitBreaker
介绍 Resilience4j是一个轻量级.易于使用的容错库,其灵感来自Netflix Hystrix,但专为Java 8和函数式编程设计. springcloud2020升级以后Hystrix被官方 ...
- JUC(5)BlockingQueue四组API
1.读写锁ReadWriteLock package com.readlock; import java.util.HashMap; import java.util.Map; /** * ReadW ...
- 齐博x1标签之异步加载标签数据
为什么要异步加载标签?他有什么好处 如果一个页面的标签太多,又或者是页面中某一个标签调用数据太慢的话,就会拖慢整个页面的打开,非常影响用户体验.这个时候,用异步加载的话,就可以一块一块的显示,用户体验 ...
- letcode刷题记录-day02-回文数
回文数 题目描述 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标. 你可以假设每种输入只会对应一个答 ...
- jvm调优思路及调优案例
jvm调优思路及调优案例 我们说jvm调优,其实就是不断测试调整jvm的运行参数,尽可能让对象都在新生代(Eden)里分配和回收,尽量别让太多对象频繁进入老年代,避免频繁对老年代进行垃圾回收,同时 ...
- 1B踩坑大王
题目链接 题目大意: 人们常用的电子表格软件(比如: Excel)采用如下所述的坐标系统: 第一列被标为 A,第二列为 B,以此类推,第 262626 列为 Z.接下来为由两个字母构成的列号: 第 2 ...
- Centos7.6分区、格式化、自动挂载磁盘
个人名片: 对人间的热爱与歌颂,可抵岁月冗长 Github:念舒_C.ying CSDN主页️:念舒_C.ying 个人博客 :念舒_C.ying 目录 1. 添加硬盘 2. 执行fdisk -l ...
- HDC.Cloud Day | 全国首场上海站告捷,聚开发者力量造梦、探梦、筑梦
摘要:11月20日,首个华为云开发者日HDC.Cloud Day在上海成功举行. 本文分享自华为云社区<HDC.Cloud Day | 全国首场上海站告捷,聚开发者力量造梦.探梦.筑梦>, ...