TensorFlow搭建模型方式总结
引言
TensorFlow提供了多种API,使得入门者和专家可以根据自己的需求选择不同的API搭建模型。
基于Keras Sequential API搭建模型
Sequential适用于线性堆叠的方式搭建模型,即每层只有一个输入和输出。
import tensorflow as tf # 导入手写数字数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data() # 数据标准化
x_train, x_test = x_train/255, x_test/255 # 使用Sequential搭建模型
# 方式一
model = tf.keras.models.Sequential([ # 加入CNN层(2D), 使用了3个卷积核, 卷积核的尺寸为3X3, 步长为1, 输入图像的维度为28X28X1
tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)), # 加入激活函数
tf.keras.layers.Activation('relu'), # 加入2X2池化层, 步长为2
tf.keras.layers.MaxPool2D(pool_size=2, strides=2), # 把图像数据平铺
tf.keras.layers.Flatten(), # 加入全连接层, 设置神经元为128个, 设置relu激活函数
tf.keras.layers.Dense(128, activation='relu'), # 加入全连接层(输出层), 设置输出数量为10, 设置softmax激活函数
tf.keras.layers.Dense(10, activation='softmax')
]) # 方式二
model2 = tf.keras.models.Sequential()
model2.add(tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)))
model2.add(tf.keras.layers.Activation('relu'))
model2.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
model2.add(tf.keras.layers.Flatten())
model2.add(tf.keras.layers.Dense(128, activation='relu'))
model2.add(tf.keras.layers.Dense(10, activation='softmax')) # 模型概览
model.summary() """
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 26, 26, 3) 30 activation (Activation) (None, 26, 26, 3) 0 max_pooling2d (MaxPooling2D (None, 13, 13, 3) 0
) flatten (Flatten) (None, 507) 0 dense (Dense) (None, 128) 65024 dense_1 (Dense) (None, 10) 1290 =================================================================
Total params: 66,344
Trainable params: 66,344
""" # 编译 为模型加入优化器, 损失函数, 评估指标
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
) # 训练模型, 2个epoch, batch size为100
model.fit(x_train, y_train, epochs=2, batch_size=100)
基于Keras 函数API搭建模型
由于Sequential是线性堆叠的,只有一个输入和输出,但是当我们需要搭建多输入模型时,如输入图片、文本描述等,这几类信息可能需要分别使用CNN,RNN模型提取信息,然后汇总信息到最后的神经网络中预测输出。或者是多输出任务,如根据音乐预测音乐类型和发行时间。亦或是一些非线性的拓扑网络结构模型,如使用残差链接、Inception等。上述这些情况的网络都不是线性搭建,要搭建如此复杂的网络,需要使用函数API来搭建。
简单实例
import tensorflow as tf # 导入手写数字数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data() # 数据标准化
x_train, x_test = x_train/255, x_test/255 input_tensor = tf.keras.layers.Input(shape=(28, 28, 1)) # CNN层(2D), 使用了3个卷积核, 卷积核的尺寸为3X3, 步长为1, 输入图像的维度为28X28X1
x = tf.keras.layers.Conv2D(3, kernel_size=3, strides=1)(input_tensor) # 激活函数
x = tf.keras.layers.Activation('relu')(x) # 2X2池化层, 步长为2
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)(x) # 把图像数据平铺
x = tf.keras.layers.Flatten()(x) # 全连接层, 设置神经元为128个, 设置relu激活函数
x = tf.keras.layers.Dense(128, activation='relu')(x) # 全连接层(输出层), 设置输出数量为10, 设置softmax激活函数
output = tf.keras.layers.Dense(10, activation='softmax')(x) model = tf.keras.models.Model(input_tensor, output) # 模型概览
model.summary() """
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 28, 28, 1)] 0 conv2d (Conv2D) (None, 26, 26, 3) 30 activation (Activation) (None, 26, 26, 3) 0 max_pooling2d (MaxPooling2D (None, 13, 13, 3) 0
) flatten (Flatten) (None, 507) 0 dense (Dense) (None, 128) 65024 dense_1 (Dense) (None, 10) 1290 =================================================================
Total params: 66,344
Trainable params: 66,344
Non-trainable params: 0
_________________________________________________________________ """ # 编译 为模型加入优化器, 损失函数, 评估指标
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
) # 训练模型, 2个epoch, batch size为100
model.fit(x_train, y_train, epochs=2, batch_size=100)
多输入实例
import tensorflow as tf # 输入1
input_tensor1 = tf.keras.layers.Input(shape=(28,))
x1 = tf.keras.layers.Dense(16, activation='relu')(input_tensor1)
output1 = tf.keras.layers.Dense(32, activation='relu')(x1) # 输入2
input_tensor2 = tf.keras.layers.Input(shape=(28,))
x2 = tf.keras.layers.Dense(16, activation='relu')(input_tensor2)
output2 = tf.keras.layers.Dense(32, activation='relu')(x2) # 合并输入1和输入2
concat = tf.keras.layers.concatenate([output1, output2]) # 顶层分类模型
output = tf.keras.layers.Dense(10, activation='relu')(concat) model = tf.keras.models.Model([input_tensor1, input_tensor2], output) # 编译
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
多输出实例
import tensorflow as tf # 输入
input_tensor = tf.keras.layers.Input(shape=(28,))
x = tf.keras.layers.Dense(16, activation='relu')(input_tensor)
output = tf.keras.layers.Dense(32, activation='relu')(x) # 多个输出
output1 = tf.keras.layers.Dense(10, activation='relu')(output)
output2 = tf.keras.layers.Dense(1, activation='sigmoid')(output) model = tf.keras.models.Model(input_tensor, [output1, output2]) # 编译
model.compile(
optimizer='adam',
loss=['sparse_categorical_crossentropy', 'binary_crossentropy'],
metrics=['accuracy']
)
子类化API
相较于上述使用高阶API,使用子类化API的方式来搭建模型,可以根据需求对模型中的任何一部分进行修改。
import tensorflow as tf # 导入手写数字数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data() # 数据标准化
x_train, x_test = x_train / 255, x_test / 255 train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=10).batch(32)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32) class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.flatten = tf.keras.layers.Flatten()
self.hidden_layer1 = tf.keras.layers.Dense(16, activation='relu')
self.hidden_layer2 = tf.keras.layers.Dense(10, activation='softmax') # 定义模型
def call(self, x):
h = self.flatten(x)
h = self.hidden_layer1(h)
y = self.hidden_layer2(h)
return y model = MyModel() # 损失函数 和 优化器
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam() # 评估指标
train_loss = tf.keras.metrics.Mean() # 一个epoch的loss
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy() # 一个epoch的准确率 test_loss = tf.keras.metrics.Mean()
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy() @tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
y_pre = model(x)
loss = loss_function(y, y_pre)
grad = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grad, model.trainable_variables)) train_loss(loss)
train_accuracy(y, y_pre) @tf.function
def test_step(x, y):
y_pre = model(x)
te_loss = loss_function(y, y_pre) test_loss(te_loss)
test_accuracy(y, y_pre) epoch = 2 for i in range(epoch): # 重置评估指标
train_loss.reset_states()
train_accuracy.reset_states() # 按照batch size 进行训练
for x, y in train_data:
train_step(x, y) print(f'epoch {i+1} train loss {train_loss.result()} train accuracy {train_accuracy.result()}')
参考
TensorFlow搭建模型方式总结的更多相关文章
- 用TensorFlow搭建一个万能的神经网络框架(持续更新)
我一直觉得TensorFlow的深度神经网络代码非常困难且繁琐,对TensorFlow搭建模型也十分困惑,所以我近期阅读了大量的神经网络代码,终于找到了搭建神经网络的规律,各位要是觉得我的文章对你有帮 ...
- (转)一文学会用 Tensorflow 搭建神经网络
一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...
- tensorflow机器学习模型的跨平台上线
在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法 ...
- 一文学会用 Tensorflow 搭建神经网络
http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...
- [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)
3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...
- TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人
简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...
- 用Tensorflow搭建神经网络的一般步骤
用Tensorflow搭建神经网络的一般步骤如下: ① 导入模块 ② 创建模型变量和占位符 ③ 建立模型 ④ 定义loss函数 ⑤ 定义优化器(optimizer), 使 loss 达到最小 ⑥ 引入 ...
- 『TensorFlow』模型保存和载入方法汇总
『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
随机推荐
- innerHTML 和 innertext 以及 outerHTML
今天在制作firefox下支持复制的js代码的时候,用到了innerText,测试发现原来firefox支持innerHTML但不支持innerText. test.innerHTML: 也就是从对象 ...
- Python3.7+Tornado5.1.1+Celery3.1+Rabbitmq3.7.16实现异步队列任务
原文转载自「刘悦的技术博客」https://v3u.cn/a_id_99 在之前的一篇文章中提到了用Django+Celery+Redis实现了异步任务队列,只不过消息中间件使用了redis,redi ...
- MySQL主从复制之GTID模式介绍
GreatSQL社区原创内容未经授权不得随意使用,转载请联系小编并注明来源. GTID概述 MySQL5.6 在原有主从复制的基础上增加了一个新的复制方式,即基于GTID的复制方式,它由UUID和事务 ...
- MySQL入门笔记一
MySQL应用笔记 一MySQL关系型数据库.开源,中小型公司常用类型的数据库Oracle 大型公司常用数据库 MySQL基本的命令一. 创建.删除.查看数据库(database)创建库creat ...
- NOI2022游记,Au
前言 8.19: 说实话,我在这里说几句话还不如水群, 新番把我心态搞炸了,我现在急需快乐 所以像游记这种吹水+回忆的文章让我现在非常痛苦. Day -1 (8.19) 上午是信心赛,太好辣,坐等D3 ...
- 3-14 Python处理XML文件
xml文件处理 什么是xml文件? xml即可扩展标记语言,它可以用来标记数据.定义数据类型,是一种允许用户对自己的标记语言进行定义的源语言. 从结构上,很像HTML超文本标记语言.但他们被设计的目的 ...
- PostgreSQL 与 Oracle 访问分区表执行计划差异
熟悉Oracle 的DBA都知道,Oracle 访问分区表时,对于没有提供分区条件的,也就是在无法使用分区剪枝情况下,优化器会根据全局的统计信息制定执行计划,该执行计划针对所有分区适用.在分析利弊之前 ...
- IIS 实现http重定向https(亲测有效:解决URL重写模块配置https重定向不生效的问题)
前言 以前部署网站的时候,都是通过代码来实现http重定向https,最近在部署个人网站的时候,突发奇想可不可通过IIS来实现无代码的重定向呢? 在一番操作猛如虎的搜索引擎操作后,发现只有google ...
- 采云端&采云链:从订单协同到采购供应链,让采购供应链互联互通
采购供应链安全从来没有像现在这样显得如此重要和紧迫,也从来没有像现在这样复杂和敏感,对企业的经营产生决定性的影响.尤其在疫情期间,采购供应链更加牵一发而动全身,成为"运筹帷幄,决胜于千里之外 ...
- Python数据科学手册-Numpy数组的计算:比较、掩码和布尔逻辑,花哨的索引
Numpy的通用函数可以用来替代循环, 快速实现数组的逐元素的 运算 同样,使用其他通用函数实现数组的逐元素的 比较 < > 这些运算结果 是一个布尔数据类型的数组. 有6种标准的比较操作 ...