本文地址:https://www.cnblogs.com/tujia/p/13862357.html

系列文章:

【0】TensorFlow光速入门-序

【1】TensorFlow光速入门-tensorflow开发基本流程

【2】TensorFlow光速入门-数据预处理(得到数据集)

【3】TensorFlow光速入门-训练及评估

【4】TensorFlow光速入门-保存模型及加载模型并使用

【5】TensorFlow光速入门-图片分类完整代码

【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

【7】TensorFlow光速入门-总结

一、导入需要的包

import tensorflow as tf
from tensorflow import keras
import numpy as np

二、初始化模型并配置神经网络层

model = keras.Sequential([
# 展平数据,输入类型要和数据集保持一致,我这里是100*100的灰图
keras.layers.Flatten(input_shape=(100, 100, 1)),
# 第二层是神经元
keras.layers.Dense(128, activation='relu'),
# 第三层的参数很重要,2表示分两类,如果要分5类就传5,10类就传10
keras.layers.Dense(2, activation='softmax')
])

注:如果是图片分类,这三层网络是固定搭配,需要注意的是,input_shape要和数据集数据保持一致,第三层分几类就传几;其他模型的层选择和配置,我们后面再慢慢了解

三、模型编译

model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

注:同样,图片分类的优化器、损失函数及指标也是固定搭配,其他类型的模型我们后面再慢慢了解

四、训练

model.fit(ds, epochs=100, steps_per_epoch=10)

注1:ds 是上一节准备好的数据集;epochs 代表要训练多少次,steps_per_epoch 代表每一次分几步训练;因为我准备的数据比较少,所以设置的训练100次。数据多的话,不用训练那么多次。

注2:使用 ZipDataset 类型的数据集时,steps_per_epoch 参数为必填,其他情况,根据自己的情况可以不传。

五、评估(评估训练效果)

test_loss, test_acc = model.evaluate(ds, verbose=2, steps=10)

注1:正常情况下,训练要用训练集,评估要用测试集。因为偷懒的原故,这里我都用的同一个数据集。

注2:使用 ZipDataset 类型的数据集时,steps 参数为必填,其他情况,根据自己的情况可以不传。

六、预测

预测即使用的意思,评估通过的模型,可以直接使用了

predictions = model.predict(ds, steps=10)
label = np.argmax(predictions[0])
print(label_names[label])

注:这里批量预测,对整个数据集都进行预测,正式使用的时候,一般只预测一张图片就可以了,下一节会说。

重点 Api :

keras.Sequential        https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential

model.compile            https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential#compile

model.fit                       https://tensorflow.google.cn/api_docs/python/tf/keras/Model#fit

model.evaluate            https://tensorflow.google.cn/api_docs/python/tf/keras/Model#evaluate

model.predict               https://tensorflow.google.cn/api_docs/python/tf/keras/Model#predict

至此,我们的图片分类模型已经训练好了。可以使用了模型来做图片分类预测了。

下一节,让我们来说一下,怎么保存这个训练好的模型。以及如何加载已保存的模型并使用:

【4】TensorFlow光速入门-保存模型及加载模型并使用

本文链接:https://www.cnblogs.com/tujia/p/13862357.html


完。

【3】TensorFlow光速入门-训练及评估的更多相关文章

  1. 【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862365.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  2. 【0】TensorFlow光速入门-序

    本文地址:https://www.cnblogs.com/tujia/p/13863181.html 序言: 对于我这么一个技术渣渣来说,想学习TensorFlow机器学习,实在是太难了: 百度&qu ...

  3. 【1】TensorFlow光速入门-tensorflow开发基本流程

    本文地址:https://www.cnblogs.com/tujia/p/13862339.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  4. 【2】TensorFlow光速入门-数据预处理(得到数据集)

    本文地址:https://www.cnblogs.com/tujia/p/13862351.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  5. 【4】TensorFlow光速入门-保存模型及加载模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  6. 【5】TensorFlow光速入门-图片分类完整代码

    本文地址:https://www.cnblogs.com/tujia/p/13862364.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  7. Tensorflow高速入门2--实现手写数字识别

    Tensorflow高速入门2–实现手写数字识别 环境: 虚拟机ubuntun16.0.4 Tensorflow 版本号:0.12.0(仅使用cpu下) Tensorflow安装见: http://b ...

  8. TensorFlow学习——入门篇

    本文主要通过一个简单的 Demo 介绍 TensorFlow 初级 API 的使用方法,因为自己也是初学者,因此本文的目的主要是引导刚接触 TensorFlow 或者 机器学习的同学,能够从第一步开始 ...

  9. 音频标签化3:igor-8m项目的训练、评估与测试

    上一节介绍了youtube-8m项目,这个项目以youtube-8m dataset(简称8m-dataset)样本集为基础,进行训练.评估与测试.youtube-8m设计用于视频特征样本,但实际也适 ...

随机推荐

  1. 【MySQL】面试官:如何添加新数据库到MySQL主从复制环境?

    写在前面 今天,一名读者反馈说:自己出去面试,被面试官一顿虐啊!为什么呢?因为这名读者面试的是某大厂的研发工程师,偏技术型的.所以,在面试过程中,面试官比较偏向于问技术型的问题.不过,技术终归还是要服 ...

  2. Linux服务器内存监控—每小时检查&超出发送邮件&重启占用最高的Java程式

    简介与优点 使用该脚本能自行判断系统内存使用情况是否超出设定百分比 能在超出预警值时执行重启程式的操作 能记录重启过程,并将具体LOG邮件发送给指定收信人 可以设定Crontab排程,达成每隔一段时间 ...

  3. Kafka处理请求的全流程分析

    大家好,我是 yes. 这是我的第三篇Kafka源码分析文章,前两篇讲了日志段的读写和二分算法在kafka索引上的应用 今天来讲讲 Kafka Broker端处理请求的全流程,剖析下底层的网络通信是如 ...

  4. javascript内置对象的innerText、innerHTML、join方法的认识

    innerText语法规范:HTMLElement.innerText = string ;//后面的赋值是一个字符串形式 innerText是一个非标准形式,不识别HTML标签 返回值会去除空格和换 ...

  5. 082 01 Android 零基础入门 02 Java面向对象 01 Java面向对象基础 02 构造方法介绍 01 构造方法-无参构造方法

    082 01 Android 零基础入门 02 Java面向对象 01 Java面向对象基础 02 构造方法介绍 01 构造方法-无参构造方法 本文知识点:构造方法-无参构造方法 说明:因为时间紧张, ...

  6. Open CV leaning

    刚接触Open CV 几个比较好的介绍: OpenCV学习笔记:https://blog.csdn.net/yang_xian521/column/info/opencv-manual/3 OpenC ...

  7. 安装 Windows 10 系统时分区选择 MBR 还是 GUID?

    一.MBR 和 GUID 的概述 MBR 分区表 MBR:Master Boot Record,即硬盘主引导记录分区表,指支持容量在2.1TB以下的硬盘,超过2.1TB的硬盘只能管理2.1TB,最多只 ...

  8. How to install the NVIDIA drivers on Fedora 32

    https://linuxconfig.org/how-to-install-the-nvidia-drivers-on-fedora-32 The NVIDIA Driver is a progra ...

  9. 热力图 vue 项目中使用热力图插件 “heatmap.js”(保姆式教程)

    我现在写的这项目是用CDN引入 heatmap.js, 可根据自己项目情况使用哪种方式引入插件. 官网地址 "https://www.patrick-wied.at/static/heatm ...

  10. 这类注解都不知道,还好意思说会Spring Boot ?

    前言 不知道大家在使用Spring Boot开发的日常中有没有用过@Conditionalxxx注解,比如@ConditionalOnMissingBean.相信看过Spring Boot源码的朋友一 ...