tensorflow对鸢尾花进行分类——人工智能入门篇
tensorflow之对鸢尾花进行分类
任务目标
- 对鸢尾花数据集分析
- 建立鸢尾花的模型
- 利用模型预测鸢尾花的类别
环境搭建
pycharm编辑器搭建python3.*
第三方库
- tensorflow1.*
- numpy
- pandas
- sklearn
- keras
处理鸢尾花数据集
了解数据集
鸢尾花数据集是一个经典的机器学习数据集,非常适合用来入门。
鸢尾花数据集链接:下载鸢尾花数据集
鸢尾花数据集包含四个特征和一个标签。这四个特征确定了单株鸢尾花的下列植物学特征:
- 花萼长度
- 花萼宽度
- 花瓣长度
- 花瓣宽度
该表确定了鸢尾花品种,品种必须是下列任意一种:
- 山鸢尾 Iris-Setosa(0)
- 杂色鸢尾 Iris-versicolor(1)
- 维吉尼亚鸢尾 Iris-virginica(2)
数据集中三类鸢尾花各含有50个样本,共150各样本
下面显示了数据集中的样本:

机器学习中,为了保证测试结果的准确性,一般会从数据集中抽取一部分数据专门留作测试,其余数据用于训练。所以我将数据集按7:3(训练集:测试集)的比例进行划分。
数据集处理具体代码
def dealIrisData(IrisDatapath):
    """
    :param IrisDatapath:传入数据集路径
    :return: 返回 训练特征集,测试特征集,训练标签集,测试标签集
    """
    # 读取数据集
    iris = pd.read_csv(IrisDatapath, header=None)
    # 数据集转化成数组
    iris = np.array(iris)
    # 提取特征集
    X = iris[:, 0:4]
    # 提取标签集
    Y = iris[:, 4]
    # One-Hot编码
    encoder = LabelEncoder()
    Y = encoder.fit_transform(Y)
    Y = np_utils.to_categorical(Y)
    x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)
    return x_train,x_test,y_train,y_test
什么是one-hot编码?
  One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。
  One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。
  One-Hot编码是将类别变量转换为机器学习算法易于利用的一种形式的过程。
  比如:["山鸢尾","杂色鸢尾","维吉尼亚鸢尾"]---->[[1,0,0][0,1,0][0,0,1]]
模型建立
  由于结构简单并没有建立隐藏层。
建立模型代码
def getIrisModel(saveModelPath,step):
    """
    :param saveModelPath: 模型保存路径
    :param step: 训练步数
    :return: None
    """
    x_train, x_test, y_train, y_test = dealIrisData("iris.data")
    # 输入层
    with tf.variable_scope("data"):
        x = tf.placeholder(tf.float32,[None,4])
        y_true = tf.placeholder(tf.int32,[None,3])
        # placeholder()函数是在神经网络构建graph的时候在模型中的占位,此时并没有把要输入的数据传入模型,
        # 它只会分配必要的内存。等建立session,在会话中,运行模型的时候通过feed_dict()函数向占位符喂入数据。
    # 无隐藏层
    # 输出层
    with tf.variable_scope("fc_model"):
        weight = tf.Variable(tf.random_normal([4,3],mean=0.0,stddev=1.0)) # 创建一个形状为[4,3],均值为0,方差为1的正态分布随机值变量
        bias = tf.Variable(tf.constant(0.0,shape=[3])) # 创建 张量为0,形状为3变量
        y_predict = tf.matmul(x,weight)+bias # 矩阵相乘
        # Variable()创建一个变量
    # 误差
    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
    # 优化器
    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    # 准确率
    with tf.variable_scope("acc"):
        equal_list = tf.equal(tf.arg_max(y_true,1),tf.arg_max(y_predict,1))
        accuracy = tf.reduce_mean(tf.cast(equal_list,tf.float32))
    # 开始训练
    with tf.Session() as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        for i in range(step):
            _train = sess.run(train_op, feed_dict={x: x_train, y_true: y_train})
            _acc = sess.run(accuracy, feed_dict={x: x_train, y_true: y_train})
            print("训练%d步,准确率为%.2f" % (i + 1, _acc))
        print("测试集的准确率为%.2f" %sess.run(accuracy, feed_dict={x: x_test, y_true: y_test}))
        saver.save(sess, saveModelPath)
载入模型—预测鸢尾花
- saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如 modelIris.meta,modelIris.index, modelIris.data-00000-of-00001
 在import_meta_graph时填的就是meta文件名,我们知道权值都保存在modelIris.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。
- 模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。
 代码实现
def predictIris(modelPath,data):
    """
    :param modelPath: 载入模型路径
    :param data: 预测数据
    :return: None
    """
    with tf.Session() as sess:
        #
        new_saver = tf.train.import_meta_graph("model/iris_model.meta")
        new_saver.restore(sess,"model/iris_model")
        graph = tf.get_default_graph()
        x = graph.get_operation_by_name('data/x_pred').outputs[0]
        y = tf.get_collection("pred_network")[0]
        predict = np.argmax(sess.run(y,feed_dict={x:data}))
        if predict == 0:
            print("山鸢尾 Iris-Setosa")
        elif predict == 1:
            print("杂色鸢尾 Iris-versicolor")
        else:
            print("维吉尼亚鸢尾 Iris-virginica")
整体代码
import tensorflow as tf
import numpy as np
import pandas as pd
from keras.utils import np_utils
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # 不启动GPU
def dealIrisData(IrisDatapath):
    """
    :param IrisDatapath:传入数据集路径
    :return: 返回 训练特征集,测试特征集,训练标签集,测试标签集
    """
    # 读取数据集
    iris = pd.read_csv(IrisDatapath, header=None)
    # 数据集转化成数组
    iris = np.array(iris)
    # 提取特征集
    X = iris[:, 0:4]
    # 提取标签集
    Y = iris[:, 4]
    # One-Hot编码
    encoder = LabelEncoder()
    Y = encoder.fit_transform(Y)
    Y = np_utils.to_categorical(Y)
    x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)
    return x_train,x_test,y_train,y_test
def getIrisModel(saveModelPath,step):
    """
    :param saveModelPath: 模型保存路径
    :param step: 训练步数
    :return: None
    """
    x_train, x_test, y_train, y_test = dealIrisData("iris.data")
    # 输入层
    with tf.variable_scope("data"):
        x = tf.placeholder(tf.float32,[None,4],name='x_pred')
        y_true = tf.placeholder(tf.int32,[None,3])
        # placeholder()函数是在神经网络构建graph的时候在模型中的占位,此时并没有把要输入的数据传入模型,
        # 它只会分配必要的内存。等建立session,在会话中,运行模型的时候通过feed_dict()函数向占位符喂入数据。
    # 无隐藏层
    # 输出层
    with tf.variable_scope("fc_model"):
        weight = tf.Variable(tf.random_normal([4,3],mean=0.0,stddev=1.0)) # 创建一个形状为[4,3],均值为0,方差为1的正态分布随机值变量
        bias = tf.Variable(tf.constant(0.0,shape=[3])) # 创建 张量为0,形状为3变量
        y_predict = tf.matmul(x,weight)+bias # 矩阵相乘
        tf.add_to_collection('pred_network', y_predict)  # 用于加载模型获取要预测的网络结构
        # Variable()创建一个变量
    # 误差
    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
    # 优化器
    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    # 准确率
    with tf.variable_scope("acc"):
        equal_list = tf.equal(tf.arg_max(y_true,1),tf.arg_max(y_predict,1))
        accuracy = tf.reduce_mean(tf.cast(equal_list,tf.float32))
    # 开始训练
    with tf.Session() as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        for i in range(step):
            _train = sess.run(train_op, feed_dict={x: x_train, y_true: y_train})
            _acc = sess.run(accuracy, feed_dict={x: x_train, y_true: y_train})
            print("训练%d步,准确率为%.2f" % (i + 1, _acc))
        print("测试集的准确率为%.2f" %sess.run(accuracy, feed_dict={x: x_test, y_true: y_test}))
        saver.save(sess, saveModelPath)
def predictIris(modelPath,data):
    """
    :param modelPath: 载入模型路径
    :param data: 预测数据
    :return: None
    """
    with tf.Session() as sess:
        #
        new_saver = tf.train.import_meta_graph("model/iris_model.meta")
        new_saver.restore(sess,"model/iris_model")
        graph = tf.get_default_graph()
        x = graph.get_operation_by_name('data/x_pred').outputs[0]
        y = tf.get_collection("pred_network")[0]
        predict = np.argmax(sess.run(y,feed_dict={x:data}))
        if predict == 0:
            print("山鸢尾 Iris-Setosa")
        elif predict == 1:
            print("杂色鸢尾 Iris-versicolor")
        else:
            print("维吉尼亚鸢尾 Iris-virginica")
if __name__ == '__main__':
    model_path = "model/iris_model"
    # 模型训练
    # model = getIrisModel(model_path,1000)
    # 模型预测
    # predictData = [[5.0,3.4,1.5,0.2]] # 填入数据集
    # predictIris(model_path,predictData)
tensorflow对鸢尾花进行分类——人工智能入门篇的更多相关文章
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
		http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ... 
- 腾讯QQ会员技术团队:人人都可以做深度学习应用:入门篇(下)
		四.经典入门demo:识别手写数字(MNIST) 常规的编程入门有"Hello world"程序,而深度学习的入门程序则是MNIST,一个识别28*28像素的图片中的手写数字的程序 ... 
- Unity3D大风暴之入门篇(海量教学视频版)
		智画互动开发团队 编 ISBN 978-7-121-22242-9 2014年2月出版 定价:79.00元 328页 16开 编辑推荐 长达800分钟的高清教学视频,手把手教会初学者 数个开发案例 ... 
- PC游戏编程(入门篇)(前言写的很不错)
		PC游戏编程(入门篇) 第一章 基石 1. 1 BOSS登场--GAF简介 第二章 2D图形程式初体验 2.l 饮水思源--第一个"游戏"程式 2.2 知其所以然一一2D图形学基础 ... 
- 私有仓库GitLab快速入门篇
		私有仓库GitLab快速入门篇 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 安装文档请参考官网:https://about.gitlab.com/installation/#ce ... 
- 《Unity3D大风暴之入门篇(海量教学视频版)》
		<Unity3D大风暴之入门篇(海量教学视频版)> 基本信息 作者: 智画互动开发团队 出版社:电子工业出版社 ISBN:9787121222429 上架时间:2014-1-13 出版日期 ... 
- XTU | 人工智能入门复习总结
		写在前面 本文严禁转载,只限于学习交流. 课件分享在这里了. 还有人工智能标准化白皮书(2018版)也一并分享了. 绪论 人工智能的定义与发展 定义 一般解释:人工智能就是用 人工的方法在 **机器( ... 
- Python入门篇-基础语法
		Python入门篇-基础语法 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.编程基础 1>.程序 一组能让计算机识别和执行的指令. 程序 >.算法+ 数据结构= 程 ... 
- Membership三步曲之入门篇 - Membership基础示例
		Membership 三步曲之入门篇 - Membership基础示例 Membership三步曲之入门篇 - Membership基础示例 Membership三步曲之进阶篇 - 深入剖析Pro ... 
随机推荐
- 一、Adobe Premiere Pro CC概述
			一.Adobe Premiere Pro CC概述 使用建议 一.开始 二.在Adobe Premiere Pro CC执行非线性编辑 1.标准的视频剪辑工作流 2.使用Premiere增强工作流 p ... 
- PHP字符串函数总结
			字符串函数 addcslashes — 为字符串里面的部分字符添加反斜线转义字符 addslashes — 用指定的方式对字符串里面的字符进行转义 bin2hex — 将二进制数据转换成十六进制表示 ... 
- 使用docker创建mysql容器
			1.拉取mysql容器 docker pull mysql:5.7 
- 深入探究ASP.NET Core异常处理中间件
			前言 全局异常处理是我们编程过程中不可或缺的重要环节.有了全局异常处理机制给我们带来了很多便捷,首先我们不用满屏幕处理程序可能出现的异常,其次我们可以对异常进行统一的处理,比如收集异常信息或者 ... 
- Java 线程基础,从这篇开始
			线程作为操作系统中最少调度单位,在当前系统的运行环境中,一般都拥有多核处理器,为了更好的充分利用 CPU,掌握其正确使用方式,能更高效的使程序运行.同时,在 Java 面试中,也是极其重要的一个模块. ... 
- linux根据进程查端口,根据端口查进程
			[root@test_environment src]# netstat -tnllup 能显示对应端口和进程 Active Internet connections (only servers) ... 
- 化繁就简,如何利用Spring AOP快速实现系统日志
			1.引言 有关Spring AOP的概念就不细讲了,网上这样的文章一大堆,要讲我也不会比别人讲得更好,所以就不啰嗦了. 为什么要用Spring AOP呢?少写代码.专注自身业务逻辑实现(关注本身的业务 ... 
- Write a program to copy its input to its output, replacing each tab by \t, each backspace by \b, and each backslash by \\. This makes tabs and backspa
			#include <stdio.h> #define DBS '\\' void main() { int c; while((c=getchar())!=EOF) { if(c=='\t ... 
- C program Language  'EOF'  and   'getchar()'
			#include <stdio.h> void main() { int c; c=getchar(); while(c!=EOF) { putchar(c); c=getchar(); ... 
- 学习 Java 网站推荐给你
			推荐几个非常不错的 Java 学习网站 LearnJava 在线 这是一个非常不错的学习 Java 的在线网站,纯免费.这是一个个人项目,旨在通过简单有效的在浏览器中进行练习让你快速掌握 Java 编 ... 
