本文来自网易云社区

作者:汪洋

前言

新手学习可以点击参考Google的教程。开始前,我们先在本地安装好 TensorFlow机器学习框架。

  1. 首先我们在本地window下安装好python环境,约定安装3.6版本;

  2. 安装Anaconda工具集后,创建名为 tensorflow 的conda 环境:conda create -n tensorflow pip python=3.6;

  3. conda切换环境:activate tensorflow;

  4. 我们安装支持CPU的TensorFlow版本(快速):pip install --ignore-installed --upgrade tensorflow;

  5. 最后验证安装是否成功,进入 python dos命名,输入以下代码校验:

    import tensorflow as tf
    hello = tf.constant('Hello, TensorFlow')
    sess = tf.Session()
    print(sess.run(hello))

    输出Hello, TensorFlow,表示成功了。如果失败的话,就选择低版本重新安装如:pip install --ignore-installed --upgrade tensorflow==1.5.0。
    其它安装方式点击参考教程

监督学习实践

官方针对新手演示了一个入门示例,点击教程可查看,本文就围绕这个教程分享。

1.分类

官方示例里讲解了分类鸢尾花问题的解决,我们想到的就是用监督学习训练机器模型。采用这种学习方式后,我们需要确定用鸢尾花的哪些特征来分类,鸢尾花的特征还是蛮多的,官方示例里用的是花萼和花瓣的长度和宽度。
鸢尾花种类非常多,官方也仅是针对三种进行分类:

expected = ['Setosa', 'Versicolor', 'Virginica']

接下来就是获取大量数据,进行预处理,官方示例里直接引用了他人整理的数据源,省略了前期数据处理步骤,前5条数据结构如下:

SepalLength SepalWidth PetalLength PetalWidth Species
0 6.4 2.8 5.6 2.2 2
1 5.0 2.3 3.3 1.0 1
2 4.9 2.5 4.5 1.7 2
3 4.9 3.1 1.5 0.1 0
4 5.7 3.8 1.7 0.1 0

说明:

  1. 最后一列代表着鸢尾花的品种,也就是说它是监督学习中的标签;

  2. 中间四列从左到右表示花萼的长度和宽度、花瓣的长度和宽度;

  3. 表格数据代表了从120个样本的数据集中抽集的5个样本;
    机器学习一般依赖数值,因此当前数据集中标签值都为数字,对应关系:

0 1 2
Setosa Versicolor Virginica

接下来将编写代码,先复习下概念,模型指特征和标签之间的关系;训练指机器学习阶段,这个阶段模型不断优化。示例里选择的监督试学习方式,模型通过包含标签的样本进行训练。

2. 导入和解析数据集

首先我们要获取训练集和测试集,其中训练集是训练模型的样本,测试集是评估训练后模型效果的样本。
首先设置我们选择的数据集地址

 """训练集"""TRAN_URL = "http://download.tensorflow.org/data/iris_training.csv""""测试集"""TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

使用tensorflow.keras.utils.get_file函数下载数据集,该方法第一个参数为文件名称,第二个参数为下载地址,点击查看详细)。

import tensorflow as tfdef download():
    train_path = tf.keras.utils.get_file('iris_training.csv', TRAN_URL)
    test_path = tf.keras.utils.get_file('iris_test.csv', TEST_URL)    return train_path, test_path

然后用pandas.read_csv函数解析下载的数据,解析后生成的格式是一个表格,然后再分成特征列表和标签列表,返回训练集和测试集

import pandas as pd
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',                    'PetalLength', 'PetalWidth', 'Species']def load_data(y_species='Species'):
    train_path, test_path = download()
    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
    train_x, train_y = train, train.pop(y_species)     test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
    test_x, test_y = test, test.pop(y_species)    return (train_x, train_y), (test_x, test_y)

3. 特征列-数值列

我们已经获取到数据集,在tensorflow中需要将数据转换为模型(Estimator)可以使用的数据结构,这时候调用tf.feature_column模块中的函数来转换。鸢尾花例子中,需将特征数据转换为浮点数,调用tf.feature_column.numeric_column方法。

import iris_data

(train_x, train_y), (test_x, test_y) = iris_data.load_data()
my_feature_columns = []for key in train_x.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

其中key是 ['SepalLength' , 'SepalWidth' , 'PetalLength' , 'PetalWidth'] 其中之一。

4. 模型选择

官方例子中选择全连接神经网络解决鸢尾花问题,用神经网络来发现特征与标签之间的复杂关系。tensorflow中,通过实例化一个Estimator类指定模型类型,这里我们使用官方提供的预创建的Estimator类,tf.estimator.DNNClassifier,此Estimator会构建一个对样本进行分类的神经网络。

classifier = tf.estimator.DNNClassifier(
    feature_columns = my_feature_columns,
    hidden_units = [10,10],
    n_classes = 3)

feature_columns 参数指训练的特征列(这里是数值列);
hidden_units 参数定义神经网络内每个隐藏层中的神经元数量,这里设置了2个隐藏层,每个隐藏层中神经元数量都是10个;
n_classes 参数表示要预测的标签数量,这里我们需要预测3个品种;
其它参数点击查看

5. 训练模型

上一步我们已经创建了一个学习模型,接下来将数据导入到模型中进行训练。tensorflow中,调用Estimator对象的train方法训练。

classifier.train(
    input_fn = lambda:iris_data.train_input_fn(train_x, train_y, 100)
    steps = 1000)

input_fn 参数表示提供训练数据的函数; steps 参数表示训练迭代次数;
在train_input_fn函数里,我们将数据转换为 train方法所需的格式。

dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

为了保证训练效果,训练样本需随机排序。buffer_size 设置大于样本数(120),可确保数据得到充分的随机化处理。

dataset = dataset.shuffle(1000)

为了保证训练期间,有无限量的训练样本,需调用 tf.data.Dataset.repeat。

dataset = dataset.repeat()

train方法一次处理一批样本, tf.data.Dataset.batch 方法通过组合多个样本创建一个批次,这里组合多个包含100个样本的批次。

dataset = dataset.batch(100)

6. 模型评估

接下来我们将训练好的模型预测效果。tensorflow中,每个Estimator对象提供了evaluate方法。

eval_result = classifier.evaluate(
    input_fn = lambda:iris_data.eval_input_fn(test_x, test_y, 100)
)

在eval_input_fn函数里,我们将数据转换为 evaluate方法所需的格式。实现跟训练一样,只是无需随机化处理和无限量重复使用测试集。

dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
dataset.batch(100);return dataset

7. 预测

接下来将该模型对无标签样本进行预测。官方手动提供了三个无标签样本。

predict_x = {    'SepalLength': [5.1, 5.9, 6.9],    'SepalWidth': [3.3, 3.0, 3.1],    'PetalLength': [1.7, 4.2, 5.4],    'PetalWidth': [0.5, 1.5, 2.1],
}

tensorflow中,每个Estimator对象提供了predict方法。

predictions = classifier.predict(
    input_fn = lambda:iris_data.eval_input_fn(predict_x, labels=None, 100)
)

改造下eval_input_fn方法,使其能够接受 labels = none 情况

features=dict(features)if labels is None:
    inputs = featureselse:
    inputs = (features, labels)
dataset = tf.data.Dataset.from_tensor_slices(inputs)

接下来打印下预测结果, predictions 中 class_ids表示可能性最大的品种,probabilities 表示每个品种的概率

for pred_dict in predictions:
    class_id = pred_dict['class_ids'][0]
    probability = pred_dict['probabilities'][class_id]
    print(class_id, probability)

结果如下:

0 0.99706334
1 0.997407
2 0.97377485

结尾

通过官方例子,新手可初步了解其使用,当然更深入的使用还得学习理论和多使用API。本文是根据官方例子,作为新手重新梳理了一遍。

网易云免费体验馆,0成本体验20+款云产品

更多网易研发、产品、运营经验分享请访问网易云社区

相关文章:
【推荐】 云计算交互设计师的正确出装姿势

机器学习tensorflow框架初试的更多相关文章

  1. python机器学习TensorFlow框架

    TensorFlow框架 关注公众号"轻松学编程"了解更多. 一.简介 ​ TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运 ...

  2. TensorFlow框架(3)之MNIST机器学习入门

    1. MNIST数据集 1.1 概述 Tensorflow框架载tensorflow.contrib.learn.python.learn.datasets包中提供多个机器学习的数据集.本节介绍的是M ...

  3. TensorFlow框架(5)之机器学习实践

    1. Iris data set Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理.Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集.数据集包含150个数据集,分为3类, ...

  4. 【TensorFlow篇】--Tensorflow框架初始,实现机器学习中多元线性回归

    一.前述 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,T ...

  5. 人工智能 tensorflow框架-->简介及安装01

    简介:Tensorflow是google于2015年11月开源的第二代机器学习框架. Tensorflow名字理解:图形边中流动的数据叫张量(Tensor),因此叫Tensorflow 既 张量流动 ...

  6. (第二章第二部分)TensorFlow框架之读取图片数据

    系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html 本文概述: 目标 说明图片 ...

  7. 【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集

    一.前述 本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类. 同时对模型的保存和恢复做下示例. 二.具体原理 代码一:实现代码 #!/usr/bin/python ...

  8. .NET数据挖掘与机器学习开源框架

    1.    数据挖掘与机器学习开源框架 1.1 框架概述 1.1.1 AForge.NET AForge.NET是一个专门为开发者和研究者基于C#框架设计的,他包括计算机视觉与人工智能,图像处理,神经 ...

  9. 跟我学算法-吴恩达老师(超参数调试, batch归一化, softmax使用,tensorflow框架举例)

    1. 在我们学习中,调试超参数是非常重要的. 超参数的调试可以是a学习率,(β1和β2,ε)在Adam梯度下降中使用, layers层数, hidden units 隐藏层的数目, learning_ ...

随机推荐

  1. vos设置可呼出手机或固话

    问题: 默认公司只让呼出手机号码,但有的客户要求能打固话,怎么办? 落地网关——补充设置——落地前缀——落地被叫改写规则 在改写规则里添加固话号段即可 具体案例: 5201——1表示让520号段只能拨 ...

  2. 如何在 ubuntu linux 一行中执行多条指令

    cd /my_folder rm *.jar svn co path to repo mvn compile package install 使用&& 运算符连接指令 cd /my_f ...

  3. April 13 2017 Week 15 Thursday

    Happiness takes no account of time. 幸福不觉光阴过. Do you know the theory of relativity? If you know about ...

  4. 如何将Twitter的内容导入到SAP CRM和C4C

    Twitter的内容导入SAP CRM Interaction Center呼叫中心 具体步骤查看我的博客Twitter(also Facebook) is official integrated i ...

  5. veritas.com常用资源汇总

    NetBackup 8.1.2文档(合集) https://www.veritas.com/support/en_US/article.100044086   NetBackup产品组停止支持生命周期 ...

  6. 【转】Android开发学习笔记(一)——初识Android

    对于一名程序员来说,“自顶向下”虽然是一种最普通不过的分析问题和解决问题的方式,但其却是简单且较为有效的一种.所以,将其应用到Android的学习中来,不至于将自己的冲动演变为一种盲目和不知所措. 根 ...

  7. 将xml转换成Json,数组,对象格式转换方法

    xml字符串:$simplexml 转换成Json格式:json_encode($simplexml) 转换成数组格式:json_decode(json_encode($simplexml),TRUE ...

  8. 2017.11.12 web中JDBC 方式访问数据库的技术

    JavaWeb------ 第四章 JDBC数据库访问技术 在JavaWeb应用程序中数据库访问是通过Java数据库连接(JavaDateBase Connectivity简称JDBC)数据库的链接一 ...

  9. Ubuntu 10.04上安装MongoDB

    MongoDB是一个可扩展.高性能的下一代数据库.MongoDB中的数据以文档形式存储,这样就能在单个数据对象中表示复杂的关系.文档可能由 以下几 部分组成:独立的基本类型属性.“内嵌文档”或文档数组 ...

  10. 奇异值分解(SVD)和最小二乘解在解齐次线性超定方程中的应用

    奇异值分解,是在A不为方阵时的对特征值分解的一种拓展.奇异值和特征值的重要意义相似,都是为了提取出矩阵的主要特征. 对于齐次线性方程 A*X =0;当A的秩大于列数时,就需要求解最小二乘解,在||X| ...