基本分类

官网示例:https://www.tensorflow.org/tutorials/keras/basic_classification
主要步骤:

  1. 加载Fashion MNIST数据集
  2. 探索数据:了解数据集格式
  3. 预处理数据
  4. 构建模型:设置层、编译模型
  5. 训练模型
  6. 评估准确率
  7. 做出预测:可视化

Fashion MNIST数据集

tf.keras

  • Keras是一个用于构建和训练深度学习模型的高级API
  • TensorFlow中的tf.keras是Keras API规范的TensorFlow实现,可以运行任何与Keras兼容的代码,保留了一些细微的差别
  • 最新版TensorFlow中的tf.keras版本可能与PyPI中的最新Keras版本不同
  • https://www.tensorflow.org/api_docs/python/tf/keras/

过拟合

如果机器学习模型在新数据上的表现不如在训练数据上的表现,就表示出现过拟合

示例

脚本内容

GitHub:https://github.com/anliven/Hello-AI/blob/master/Google-Learn-and-use-ML/1_basic_classification.py

 # coding=utf-8
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
print("TensorFlow version: {} - tf.keras version: {}".format(tf.VERSION, tf.keras.__version__)) # 查看版本 # ### 加载数据集
# 网络畅通的情况下,可以从 TensorFlow 直接访问 Fashion MNIST,只需导入和加载数据即可
# 或者手工下载文件,并存放在“~/.keras/datasets”下的fashion-mnist目录
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# 训练集:train_images 和 train_labels 数组,用于学习的数据
# 测试集:test_images 和 test_labels 数组,用于测试模型
# 图像images为28x28的NumPy数组,像素值介于0到255之间
# 标签labels是整数数组,介于0到9之间,对应于图像代表的服饰所属的类别,每张图像都映射到一个标签 class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] # 类别名称 # ### 探索数据:了解数据格式
print("train_images.shape: {}".format(train_images.shape)) # 训练集中有60000张图像,每张图像都为28x28像素
print("train_labels len: {}".format(len(train_labels))) # 训练集中有60000个标签
print("train_labels: {}".format(train_labels)) # 每个标签都是一个介于 0 到 9 之间的整数
print("test_images.shape: {}".format(test_images.shape)) # 测试集中有10000张图像,每张图像都为28x28像素
print("test_labels len: {}".format(len(test_labels))) # 测试集中有10000个标签
print("test_labels: {}".format(test_labels)) # ### 预处理数据
# 必须先对数据进行预处理,然后再训练网络
plt.figure(num=1) # 创建图形窗口,参数num是图像编号
plt.imshow(train_images[0]) # 绘制图片
plt.colorbar() # 渐变色度条
plt.grid(False) # 显示网格
plt.savefig("./outputs/sample-1-figure-1.png", dpi=200, format='png') # 保存文件,必须在plt.show()前使用,否则将是空白内容
plt.show() # 显示
plt.close() # 关闭figure实例,如果要创建多个figure实例,必须显示调用close方法来释放不再使用的figure实例 # 值缩小为0到1之间的浮点数
train_images = train_images / 255.0
test_images = test_images / 255.0 # 显示训练集中的前25张图像,并在每张图像下显示类别名称
plt.figure(num=2, figsize=(10, 10)) # 参数figsize指定宽和高,单位为英寸
for i in range(25): # 前25张图像
plt.subplot(5, 5, i + 1)
plt.xticks([]) # x坐标轴刻度
plt.yticks([]) # y坐标轴刻度
plt.grid(False)
plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[i]]) # x坐标轴名称
plt.savefig("./outputs/sample-1-figure-2.png", dpi=200, format='png')
plt.show()
plt.close() # ### 构建模型
# 构建神经网络需要先配置模型的层,然后再编译模型
# 设置层
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)), # 将图像格式从二维数组(28x28像素)转换成一维数组(784 像素)
keras.layers.Dense(128, activation=tf.nn.relu), # 全连接神经层,具有128个节点(或神经元)
keras.layers.Dense(10, activation=tf.nn.softmax)]) # 全连接神经层,具有10个节点的softmax层
# 编译模型
model.compile(optimizer=tf.train.AdamOptimizer(), # 优化器:根据模型看到的数据及其损失函数更新模型的方式
loss='sparse_categorical_crossentropy', # 损失函数:衡量模型在训练期间的准确率。
metrics=['accuracy']) # 指标:用于监控训练和测试步骤;这里使用准确率(图像被正确分类的比例) # ### 训练模型
# 将训练数据馈送到模型中,模型学习将图像与标签相关联
model.fit(train_images, # 训练数据
train_labels, # 训练数据
epochs=5, # 训练周期(训练模型迭代轮次)
verbose=2 # 日志显示模式:0为安静模式, 1为进度条(默认), 2为每轮一行
) # 调用model.fit 方法开始训练,使模型与训练数据“拟合 # ### 评估准确率
# 比较模型在测试数据集上的表现
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test loss: {} - Test accuracy: {}'.format(test_loss, test_acc)) # ### 做出预测
predictions = model.predict(test_images) # 使用predict()方法进行预测
print("The first prediction: {}".format(predictions[0])) # 查看第一个预测结果(包含10个数字的数组,分别对应10种服饰的“置信度”
label_number = np.argmax(predictions[0]) # 置信度值最大的标签
print("label: {} - class name: {}".format(label_number, class_names[label_number]))
print("Result true or false: {}".format(test_labels[0] == label_number)) # 对比测试标签,查看该预测是否正确 # 可视化:将该预测绘制成图来查看全部10个通道
def plot_image(m, predictions_array, true_label, img):
predictions_array, true_label, img = predictions_array[m], true_label[m], img[m]
plt.grid(False)
plt.xticks([])
plt.yticks([])
plt.imshow(img, cmap=plt.cm.binary)
predicted_label = np.argmax(predictions_array)
if predicted_label == true_label:
color = 'blue' # 正确的预测标签为蓝色
else:
color = 'red' # 错误的预测标签为红色
plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
100 * np.max(predictions_array),
class_names[true_label]),
color=color) def plot_value_array(n, predictions_array, true_label):
predictions_array, true_label = predictions_array[n], true_label[n]
plt.grid(False)
plt.xticks([])
plt.yticks([])
thisplot = plt.bar(range(10), predictions_array, color="#777777")
plt.ylim([0, 1])
predicted_label = np.argmax(predictions_array)
thisplot[predicted_label].set_color('red')
thisplot[true_label].set_color('blue') # 查看第0张图像、预测和预测数组
i = 0
plt.figure(num=3, figsize=(8, 5))
plt.subplot(1, 2, 1)
plot_image(i, predictions, test_labels, test_images)
plt.subplot(1, 2, 2)
plot_value_array(i, predictions, test_labels)
plt.xticks(range(10), class_names, rotation=45) # x坐标轴刻度,参数rotation表示label旋转显示角度
plt.savefig("./outputs/sample-1-figure-3.png", dpi=200, format='png')
plt.show()
plt.close() # 查看第12张图像、预测和预测数组
i = 12
plt.figure(num=4, figsize=(8, 5))
plt.subplot(1, 2, 1)
plot_image(i, predictions, test_labels, test_images)
plt.subplot(1, 2, 2)
plot_value_array(i, predictions, test_labels)
plt.xticks(range(10), class_names, rotation=45) # range(10)作为x轴的刻度,class_names作为对应的标签
plt.savefig("./outputs/sample-1-figure-4.png", dpi=200, format='png')
plt.show()
plt.close() # 绘制图像:正确的预测标签为蓝色,错误的预测标签为红色,数字表示预测标签的百分比(总计为 100)
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(num=5, figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
plot_image(i, predictions, test_labels, test_images)
plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
plot_value_array(i, predictions, test_labels)
plt.xticks(range(10), class_names, rotation=45)
plt.savefig("./outputs/sample-1-figure-5.png", dpi=200, format='png')
plt.show()
plt.close() # 使用经过训练的模型对单个图像进行预测
image = test_images[0] # 从测试数据集获得一个图像
print("img shape: {}".format(image.shape)) # 图像的shape信息
image = (np.expand_dims(image, 0)) # 添加到列表中
print("img shape: {}".format(image.shape))
predictions_single = model.predict(image) # model.predict返回一组列表,每个列表对应批次数据中的每张图像
print("prediction_single: {}".format(predictions_single)) # 查看预测,预测结果是一个具有10个数字的数组,分别对应10种不同服饰的“置信度” plt.figure(num=6)
plot_value_array(0, predictions_single, test_labels)
plt.xticks(range(10), class_names, rotation=45)
plt.savefig("./outputs/sample-1-figure-6.png", dpi=200, format='png')
plt.show()
plt.close() prediction_result = np.argmax(predictions_single[0]) # 获取批次数据中相应图像的预测结果(置信度值最大的标签)
print("prediction_result: {}".format(prediction_result))

运行结果

common line

C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\python.exe D:/Anliven/Anliven-Code/PycharmProjects/TempTest/TempTest.py
TensorFlow version: 1.12.
train_images.shape: (, , )
train_labels len:
train_labels: [ ... ]
test_images.shape: (, , )
test_labels len:
test_labels: [ ... ]
Epoch /
- 3s - loss: 0.5077 - acc: 0.8211
Epoch /
- 3s - loss: 0.3790 - acc: 0.8632
Epoch /
- 3s - loss: 0.3377 - acc: 0.8755
Epoch /
- 3s - loss: 0.3120 - acc: 0.8855
Epoch /
- 3s - loss: 0.2953 - acc: 0.8914 / [..............................] - ETA: 15s
/ [=====>........................] - ETA: 0s
/ [============>.................] - ETA: 0s
/ [====================>.........] - ETA: 0s
/ [===========================>..] - ETA: 0s
/ [==============================] - 0s 30us/step
Test loss: 0.3584352566242218 - Test accuracy: 0.8711
The first prediction: [4.9706377e-06 2.2675355e-09 1.3649772e-07 3.6149192e-08 4.7982059e-08
8.5262489e-03 1.5245891e-05 3.2628113e-03 1.6874857e-05 9.8817366e-01]
label: - class name: Ankle boot
Result true or false: True
img shape: (, )
img shape: (, , )
prediction_single: [[4.9706327e-06 2.2675313e-09 1.3649785e-07 3.6149192e-08 4.7982059e-08
8.5262526e-03 1.5245891e-05 3.2628146e-03 1.6874827e-05 9.8817366e-01]]
prediction_result: Process finished with exit code

Figure1

Figure2

Figure3

Figure4

Figure5

Figure6

问题处理

问题1:执行fashion_mnist.load_data()失败

错误提示
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
......
Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz: None -- [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond

处理方法1

选择一个链接,

手工下载下面四个文件,并存放在“~/.keras/datasets”下的fashion-mnist目录。

  • train-labels-idx1-ubyte.gz
  • train-images-idx3-ubyte.gz
  • t10k-labels-idx1-ubyte.gz
  • t10k-images-idx3-ubyte.gz
guowli@5CG450158J MINGW64 ~/.keras/datasets
$ pwd
/c/Users/guowli/.keras/datasets guowli@5CG450158J MINGW64 ~/.keras/datasets
$ ls -l
total
drwxr-xr-x guowli Mar : fashion-mnist/ guowli@5CG450158J MINGW64 ~/.keras/datasets
$ ls -l fashion-mnist/
total
-rw-r--r-- guowli Mar : t10k-images-idx3-ubyte.gz
-rw-r--r-- guowli Mar : t10k-labels-idx1-ubyte.gz
-rw-r--r-- guowli Mar : train-images-idx3-ubyte.gz
-rw-r--r-- guowli Mar : train-labels-idx1-ubyte.gz

处理方法2

手工下载文件,存放在指定目录。
改写“tensorflow\python\keras\datasets\fashion_mnist.py”定义的load_data()函数。

from tensorflow.python.keras.utils import get_file
import numpy as np
import pathlib
import gzip def load_data(): # 改写“tensorflow\python\keras\datasets\fashion_mnist.py”定义的load_data()函数
base = "file:///" + str(pathlib.Path.cwd()) + "\\" # 当前目录 files = [
'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
] paths = []
for fname in files:
paths.append(get_file(fname, origin=base + fname)) with gzip.open(paths[0], 'rb') as lbpath:
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) with gzip.open(paths[1], 'rb') as imgpath:
x_train = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) with gzip.open(paths[2], 'rb') as lbpath:
y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8) with gzip.open(paths[3], 'rb') as imgpath:
x_test = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28) return (x_train, y_train), (x_test, y_test) (train_images, train_labels), (test_images, test_labels) = load_data()

问题2:使用gzip.open()打开.gz文件失败

错误提示

“OSError: Not a gzipped file (b'\n\n')”

处理方法

对于损坏的、不完整的.gz文件,zip.open()将无法打开。检查.gz文件是否完整无损。

参考信息

https://github.com/tensorflow/tensorflow/issues/170

AI - TensorFlow - 示例01:基本分类的更多相关文章

  1. AI - TensorFlow - 示例02:影评文本分类

    影评文本分类 文本分类(Text classification):https://www.tensorflow.org/tutorials/keras/basic_text_classificatio ...

  2. AI - TensorFlow - 示例03:基本回归

    基本回归 回归(Regression):https://www.tensorflow.org/tutorials/keras/basic_regression 主要步骤:数据部分 获取数据(Get t ...

  3. AI - TensorFlow - 示例05:保存和恢复模型

    保存和恢复模型(Save and restore models) 官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_mo ...

  4. AI - TensorFlow - 示例04:过拟合与欠拟合

    过拟合与欠拟合(Overfitting and underfitting) 官网示例:https://www.tensorflow.org/tutorials/keras/overfit_and_un ...

  5. 【5】TensorFlow光速入门-图片分类完整代码

    本文地址:https://www.cnblogs.com/tujia/p/13862364.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  6. 在 TensorFlow 中实现文本分类的卷积神经网络

    在TensorFlow中实现文本分类的卷积神经网络 Github提供了完整的代码: https://github.com/dennybritz/cnn-text-classification-tf 在 ...

  7. springmvc 项目完整示例01 需求与数据库表设计 简单的springmvc应用实例 web项目

    一个简单的用户登录系统 用户有账号密码,登录ip,登录时间 打开登录页面,输入用户名密码 登录日志,可以记录登陆的时间,登陆的ip 成功登陆了的话,就更新用户的最后登入时间和ip,同时记录一条登录记录 ...

  8. Tensorflow&CNN:裂纹分类

    版权声明:本文为博主原创文章,转载 请注明出处:https://blog.csdn.net/sc2079/article/details/90478551 - 写在前面 本科毕业设计终于告一段落了.特 ...

  9. 在TensorFlow中实现文本分类的卷积神经网络

    在TensorFlow中实现文本分类的卷积神经网络 Github提供了完整的代码: https://github.com/dennybritz/cnn-text-classification-tf 在 ...

随机推荐

  1. +function ($) { "use strict";}(window.jQuery);全面分析

    +function ($) { "use strict"; }(window.jQuery); 怎么理解? 匿名函数闭包 我们先来理一理函数表达式和函数声明的区别 函数表达式: 函 ...

  2. 浅谈前端中的mvvm与mvc

    用了vue这么久,却没有认真的关注mvvm与mvc,着实汗颜.趁着周末刚好看了一下网上的文章还有书籍,简单的谈一下我的理解. -以下图片均摘自网络. 一.MVC 特点:单项通讯 视图(View):用户 ...

  3. spring(一)--spring/springmvc/spring+hibernate(mybatis)配置文件

    这篇文章用来总结一下spring,springmvc,spring+mybatis,spring+hibernate的配置文件 1.web.xml 要使用spring,必须在web.xml中定义分发器 ...

  4. 9.app后端选择什么服务器

    对于很多刚入行的朋友来说,不清楚应该选择什么样的服务器提供商,是选择传统的IDC, 租用服务器租用机柜,还是选择现在很火的云服务器呢?在本文中,通过对比传统的IDC和云服务,简单阐述一下服务器的选择. ...

  5. Android软键盘弹出,覆盖h5页面输入框问题

    之前我们在使用vue进行 h5 表单录入的过程中,遇到了Android软键盘弹出,覆盖 h5页面 输入框 问题,在此进行回顾并分享给大家: 系统:Android 条件:当输入框在可视区底部或者偏下的位 ...

  6. driver匹配元素定位用法大全

    # -*- coding:utf-8 -*- from selenium import webdriver from selenium.webdriver.common.by import By fr ...

  7. Java Script 学习笔记 (一) 基础

    1. 设置变量 const: 赋常量,不可更改. let :设置可更改变量. ES6 中推荐使用let 而不是var. Let 和var的区别 : let 将变量的作用域限定在当前{}中, var 定 ...

  8. compact_op.go

    package clientv3 import (     pb "github.com/coreos/etcd/etcdserver/etcdserverpb" ) // Com ...

  9. Centos 7 Linux系统修改网卡名称为ethx

    一.Centos7 系统安装完成后更改网卡名称方法 1.查看Centos7系统默认的网卡配置(eno16777736) [root@server ~]# ifconfig eno16777736: f ...

  10. H5移动端项目案例、web手机微商城实战开发

    自微信生态圈一步步强大后,关于移动端购物的趋势,逐渐成为大众关心的内容,目前市场上关于移动商城的制定就有大量版本,比如.微商城.移动商城.移动webAPP.微信商城各等各种定义层出不穷,这就对于移动端 ...