TensorFlow高层次机器学习API (tf.contrib.learn)

1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格式数据

2.tf.contrib.learn.DNNClassifier 建立DNN模型(classifier)

3.classifer.fit 训练模型

4.classifier.evaluate 评价模型

5.classifier.predict 预测新样本

完整代码:

 1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4
5 import tensorflow as tf
6 import numpy as np
7
8 # Data sets
9 IRIS_TRAINING = "iris_training.csv"
10 IRIS_TEST = "iris_test.csv"
11
12 # Load datasets.
13 training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
14 filename=IRIS_TRAINING,
15 target_dtype=np.int,
16 features_dtype=np.float32)
17 test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
18 filename=IRIS_TEST,
19 target_dtype=np.int,
20 features_dtype=np.float32)
21
22 # Specify that all features have real-value data
23 feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
24
25 # Build 3 layer DNN with 10, 20, 10 units respectively.
26 classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
27 hidden_units=[10, 20, 10],
28 n_classes=3,
29 model_dir="/tmp/iris_model")
30
31 # Fit model.
32 classifier.fit(x=training_set.data,
33 y=training_set.target,
34 steps=2000)
35
36 # Evaluate accuracy.
37 accuracy_score = classifier.evaluate(x=test_set.data,
38 y=test_set.target)["accuracy"]
39 print('Accuracy: {0:f}'.format(accuracy_score))
40
41 # Classify two new flower samples.
42 new_samples = np.array(
43 [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
44 y = list(classifier.predict(new_samples, as_iterable=True))
45 print('Predictions: {}'.format(str(y)))

结果:

Accuracy:0.966667

TensorFlow高层次机器学习API (tf.contrib.learn)的更多相关文章

  1. TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用

    一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...

  2. tf.contrib.learn.preprocessing.VocabularyProcessor()

    tf.contrib.learn.preprocessing.VocabularyProcessor (max_document_length, min_frequency=0, vocabulary ...

  3. tensorflow中slim模块api介绍

    tensorflow中slim模块api介绍 翻译 2017年08月29日 20:13:35   http://blog.csdn.net/guvcolie/article/details/77686 ...

  4. tf.contrib.layers.fully_connected参数笔记

    tf.contrib.layers.fully_connected 添加完全连接的图层. tf.contrib.layers.fully_connected(    inputs,    num_ou ...

  5. TensorFlow——tf.contrib.layers库中的相关API

    在TensorFlow中封装好了一个高级库,tf.contrib.layers库封装了很多的函数,使用这个高级库来开发将会提高效率,卷积函数使用tf.contrib.layers.conv2d,池化函 ...

  6. tensorflow笔记3:CRF函数:tf.contrib.crf.crf_log_likelihood()

    在分析训练代码的时候,遇到了,tf.contrib.crf.crf_log_likelihood,这个函数,于是想简单理解下: 函数的目的:使用crf 来计算损失,里面用到的优化方法是:最大似然估计 ...

  7. tensorflow教程:tf.contrib.rnn.DropoutWrapper

    tf.contrib.rnn.DropoutWrapper Defined in tensorflow/python/ops/rnn_cell_impl.py. def __init__(self, ...

  8. TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同

    tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...

  9. 关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题

    这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎: class TRNNConfig(obje ...

随机推荐

  1. 2014.04.17,转帖,关于FFT的结果为什么要除以N

    http://www.chinavib.com/forum/viewthread.php?tid=23665&highlight= 关于这个问题,我看到的书好像都没有进行解释,这里我试着解释下 ...

  2. 杂项-建模:BIM

    ylbtech-杂项-建模:BIM 建筑信息模型是建筑学.工程学及土木工程的新工具.建筑信息模型或建筑资讯模型一词由Autodesk所创的.它是来形容那些以三维图形为主.物件导向.建筑学有关的电脑辅助 ...

  3. oracle故障处理之删除大表空间hang住

    背景 数据库分区表数据越来越大,需要对过期话的数据进行迁移,以及大的分区表需要进行数据的清理和删除,达到释放磁盘空间的目的. 问题说明 环境:linux 6.X 数据库:oracle 11.2.0.4 ...

  4. Linux 系列- 基本命令

    Linux 基本命令 转自:http://www.taobaotest.com/blogs/qa?bid=353 Linux是一个基于命令的系统,它有很多很强的命令. 但它也有桌面系统,比如KDE, ...

  5. 制作ubuntu的U盘启动盘

    在制作U盘启动盘之前,请各位先格式化你的U盘. 制作U盘启动盘的工具有很多种,我们这里为大家介绍的是用软碟通制作.所有我们需要有这个软件,如果大家没有可以百度“软碟通”,下载安装一个.然后点击打开.在 ...

  6. Oracle Access和filter的区别

    在查看Oracle执行计划的时候经常会遇到Access和filter,脑容量太小,总是分不清两者的区别...稍作整理. Access:表示对应的谓词条件会影响数据的访问路径(是按照索引还是表) Fil ...

  7. 微信小程序中获取高度及设备的方法

    由于js中可以采用操纵dom的方法来获取页面元素的高度,可是在微信小程序中不能操纵dom,经过查找之后发现仅仅只有以下几个方法可以获取到高度 wx.getSystemInfoSync().window ...

  8. vue-cli搭建项目结构及引用bootstrap

    vue-cli脚手架工具快速构建项目架构: 1.首先默认了有已经安装了node,然后依次执行以下命令: npm install -g vue-cli                   全局安装vue ...

  9. Windows 安装 MySQL8

    MySQL8下载地址:https://dev.mysql.com/downloads/mysql/ 解压到安装目录 新建配置文件my.ini [mysqld]# 设置mysql的安装目录basedir ...

  10. SQL的几个路径

    这个是主数据库文件存放的地方 C:\Program Files\Microsoft SQL Server\MSSQL12.MSSQLSERVER2014\MSSQL\DATA