本文地址:https://www.cnblogs.com/tujia/p/13862357.html

系列文章:

【0】TensorFlow光速入门-序

【1】TensorFlow光速入门-tensorflow开发基本流程

【2】TensorFlow光速入门-数据预处理(得到数据集)

【3】TensorFlow光速入门-训练及评估

【4】TensorFlow光速入门-保存模型及加载模型并使用

【5】TensorFlow光速入门-图片分类完整代码

【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

【7】TensorFlow光速入门-总结

一、导入需要的包

import tensorflow as tf
from tensorflow import keras
import numpy as np

二、初始化模型并配置神经网络层

model = keras.Sequential([
# 展平数据,输入类型要和数据集保持一致,我这里是100*100的灰图
keras.layers.Flatten(input_shape=(100, 100, 1)),
# 第二层是神经元
keras.layers.Dense(128, activation='relu'),
# 第三层的参数很重要,2表示分两类,如果要分5类就传5,10类就传10
keras.layers.Dense(2, activation='softmax')
])

注:如果是图片分类,这三层网络是固定搭配,需要注意的是,input_shape要和数据集数据保持一致,第三层分几类就传几;其他模型的层选择和配置,我们后面再慢慢了解

三、模型编译

model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

注:同样,图片分类的优化器、损失函数及指标也是固定搭配,其他类型的模型我们后面再慢慢了解

四、训练

model.fit(ds, epochs=100, steps_per_epoch=10)

注1:ds 是上一节准备好的数据集;epochs 代表要训练多少次,steps_per_epoch 代表每一次分几步训练;因为我准备的数据比较少,所以设置的训练100次。数据多的话,不用训练那么多次。

注2:使用 ZipDataset 类型的数据集时,steps_per_epoch 参数为必填,其他情况,根据自己的情况可以不传。

五、评估(评估训练效果)

test_loss, test_acc = model.evaluate(ds, verbose=2, steps=10)

注1:正常情况下,训练要用训练集,评估要用测试集。因为偷懒的原故,这里我都用的同一个数据集。

注2:使用 ZipDataset 类型的数据集时,steps 参数为必填,其他情况,根据自己的情况可以不传。

六、预测

预测即使用的意思,评估通过的模型,可以直接使用了

predictions = model.predict(ds, steps=10)
label = np.argmax(predictions[0])
print(label_names[label])

注:这里批量预测,对整个数据集都进行预测,正式使用的时候,一般只预测一张图片就可以了,下一节会说。

重点 Api :

keras.Sequential        https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential

model.compile            https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential#compile

model.fit                       https://tensorflow.google.cn/api_docs/python/tf/keras/Model#fit

model.evaluate            https://tensorflow.google.cn/api_docs/python/tf/keras/Model#evaluate

model.predict               https://tensorflow.google.cn/api_docs/python/tf/keras/Model#predict

至此,我们的图片分类模型已经训练好了。可以使用了模型来做图片分类预测了。

下一节,让我们来说一下,怎么保存这个训练好的模型。以及如何加载已保存的模型并使用:

【4】TensorFlow光速入门-保存模型及加载模型并使用

本文链接:https://www.cnblogs.com/tujia/p/13862357.html


完。

【3】TensorFlow光速入门-训练及评估的更多相关文章

  1. 【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862365.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  2. 【0】TensorFlow光速入门-序

    本文地址:https://www.cnblogs.com/tujia/p/13863181.html 序言: 对于我这么一个技术渣渣来说,想学习TensorFlow机器学习,实在是太难了: 百度&qu ...

  3. 【1】TensorFlow光速入门-tensorflow开发基本流程

    本文地址:https://www.cnblogs.com/tujia/p/13862339.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  4. 【2】TensorFlow光速入门-数据预处理(得到数据集)

    本文地址:https://www.cnblogs.com/tujia/p/13862351.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  5. 【4】TensorFlow光速入门-保存模型及加载模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  6. 【5】TensorFlow光速入门-图片分类完整代码

    本文地址:https://www.cnblogs.com/tujia/p/13862364.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  7. Tensorflow高速入门2--实现手写数字识别

    Tensorflow高速入门2–实现手写数字识别 环境: 虚拟机ubuntun16.0.4 Tensorflow 版本号:0.12.0(仅使用cpu下) Tensorflow安装见: http://b ...

  8. TensorFlow学习——入门篇

    本文主要通过一个简单的 Demo 介绍 TensorFlow 初级 API 的使用方法,因为自己也是初学者,因此本文的目的主要是引导刚接触 TensorFlow 或者 机器学习的同学,能够从第一步开始 ...

  9. 音频标签化3:igor-8m项目的训练、评估与测试

    上一节介绍了youtube-8m项目,这个项目以youtube-8m dataset(简称8m-dataset)样本集为基础,进行训练.评估与测试.youtube-8m设计用于视频特征样本,但实际也适 ...

随机推荐

  1. 刷题[BJDCTF 2nd]简单注入

    解题思路 打开发现登陆框,随机输入一些,发现有waf,然后回显都是同样的字符串.fuzz一波,发现禁了挺多东西的. select union 等 这里猜测是布尔盲注,错误的话显示的是:You konw ...

  2. 渗透测试之信息收集(Web安全攻防渗透测试实战指南第1章)

    收集域名信息 获得对象域名之后,需要收集域名的注册信息,包括该域名的DNS服务器信息和注册人的联系方式等. whois查询 对于中小型站点而言,域名所有人往往就是管理员,因此得到注册人的姓名和邮箱信息 ...

  3. 详细分析 Java 中启动线程的正确和错误方式

    目录 启动线程的正确和错误方式 前文回顾 start 方法和 run 方法的比较 start 方法分析 start 方法的含义以及注意事项 start 方法源码分析 源码 源码中的流程 run 方法分 ...

  4. 手写“SpringBoot”近况:IoC模块已经完成

    jsoncat:https://github.com/Snailclimb/jsoncat (About 仿 Spring Boot 但不同于 Spring Boot 的一个轻量级的 HTTP 框架) ...

  5. 题目:写出一条SQL语句,查询工资高于10000,且与他所在部门的经理年龄相同的职工姓名。

    create table Emp( eid char(20) primary key, ename char(20), age integer check (age > 0), did char ...

  6. 《To C产品经理进阶》

    我所说的,都是错的. To C产品设计和To B产品设计对一个优秀的产品经理的洞察能力.架构能力有共通的要求. 实际产品设计过程中,To C产品往往是从商业思维思考,侧重用户研究,思考用户心智,由产品 ...

  7. 为自己的网页博客添加L2Dwidget.js看板娘

    如果是博客园,直接在设置-->页脚 HTML 代码,加上下面代码: 1 <!-- L2Dwidget.js L2D网页动画人物 --> 2 <script src=" ...

  8. CentOS 7安装docker和常用指令

    1.安装 yum -y install docker 2.启动 systemctl start docker // 启动 docker -v //查看版本号 systemctl stop docker ...

  9. JavaScript封装函数:获取下一个/上一个兄弟元素节点

    要求: 获得下一个/上一个兄弟元素节点,不包括文本节点等 解决IE兼容性问题 代码实现: 获得下一个兄弟元素节点: function getNextElement(element) { var el ...

  10. [WC 2011]最大Xor和路径

    题目大意: 给你一张n个点,m条边的无向图,每条边都有一个权值,求:1到n的路径权值和的最大值. 题解: 任意一条路径都能够由一条简单路径(任意一条),在接上若干个环构成(如果不与这条简单路径相连就走 ...