TensorFlow高层次机器学习API (tf.contrib.learn)

1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格式数据

2.tf.contrib.learn.DNNClassifier 建立DNN模型(classifier)

3.classifer.fit 训练模型

4.classifier.evaluate 评价模型

5.classifier.predict 预测新样本

完整代码:

 1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4
5 import tensorflow as tf
6 import numpy as np
7
8 # Data sets
9 IRIS_TRAINING = "iris_training.csv"
10 IRIS_TEST = "iris_test.csv"
11
12 # Load datasets.
13 training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
14 filename=IRIS_TRAINING,
15 target_dtype=np.int,
16 features_dtype=np.float32)
17 test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
18 filename=IRIS_TEST,
19 target_dtype=np.int,
20 features_dtype=np.float32)
21
22 # Specify that all features have real-value data
23 feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
24
25 # Build 3 layer DNN with 10, 20, 10 units respectively.
26 classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
27 hidden_units=[10, 20, 10],
28 n_classes=3,
29 model_dir="/tmp/iris_model")
30
31 # Fit model.
32 classifier.fit(x=training_set.data,
33 y=training_set.target,
34 steps=2000)
35
36 # Evaluate accuracy.
37 accuracy_score = classifier.evaluate(x=test_set.data,
38 y=test_set.target)["accuracy"]
39 print('Accuracy: {0:f}'.format(accuracy_score))
40
41 # Classify two new flower samples.
42 new_samples = np.array(
43 [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
44 y = list(classifier.predict(new_samples, as_iterable=True))
45 print('Predictions: {}'.format(str(y)))

结果:

Accuracy:0.966667

TensorFlow高层次机器学习API (tf.contrib.learn)的更多相关文章

  1. TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用

    一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...

  2. tf.contrib.learn.preprocessing.VocabularyProcessor()

    tf.contrib.learn.preprocessing.VocabularyProcessor (max_document_length, min_frequency=0, vocabulary ...

  3. tensorflow中slim模块api介绍

    tensorflow中slim模块api介绍 翻译 2017年08月29日 20:13:35   http://blog.csdn.net/guvcolie/article/details/77686 ...

  4. tf.contrib.layers.fully_connected参数笔记

    tf.contrib.layers.fully_connected 添加完全连接的图层. tf.contrib.layers.fully_connected(    inputs,    num_ou ...

  5. TensorFlow——tf.contrib.layers库中的相关API

    在TensorFlow中封装好了一个高级库,tf.contrib.layers库封装了很多的函数,使用这个高级库来开发将会提高效率,卷积函数使用tf.contrib.layers.conv2d,池化函 ...

  6. tensorflow笔记3:CRF函数:tf.contrib.crf.crf_log_likelihood()

    在分析训练代码的时候,遇到了,tf.contrib.crf.crf_log_likelihood,这个函数,于是想简单理解下: 函数的目的:使用crf 来计算损失,里面用到的优化方法是:最大似然估计 ...

  7. tensorflow教程:tf.contrib.rnn.DropoutWrapper

    tf.contrib.rnn.DropoutWrapper Defined in tensorflow/python/ops/rnn_cell_impl.py. def __init__(self, ...

  8. TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同

    tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...

  9. 关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题

    这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎: class TRNNConfig(obje ...

随机推荐

  1. Caused by: java.lang.NoClassDefFoundError: org/apache/neethi/AssertionBuilderFactory

    转自:https://blog.csdn.net/iteye_8264/article/details/82641058 1.错误描述 严重: StandardWrapper.Throwable or ...

  2. POJ 1155 树形DP

    题意:电视台发送信号给很多用户,每个用户有愿意出的钱,电视台经过的路线都有一定费用,求电视台不损失的情况下最多给多少用户发送信号. 转自:http://www.cnblogs.com/andre050 ...

  3. Memcache相关面试题

    1)memcached的cache机制是怎样的? 懒惰算法 +最近最少使用原则 2)memcached如何实现冗余机制? 冗余:就是有好多好多不经常使用的. 可以不用实现冗余机制,如果非要实现.那就搞 ...

  4. 织梦Fatal error: Call to a member function GetInnerText()

    问题:织梦修改或者添加了自定义表单后在后台修改文章的时候出现如下错误:Fatal error: Call to a member function GetInnerText() on a non-ob ...

  5. 脚本_统计固定时间段服务器的访问量.sh

    #!bin/bash#功能:统计 1:30 到 4:30 所有访问 apache 服务器的请求有多少个#作者:liusingbon#awk 使用-F 选项指定文件内容的分隔符是/或者:#条件判断$7: ...

  6. 路飞学城Python-Day43

    前端                                                                                                  ...

  7. poj 2954 Triangle 三角形内的整点数

    poj 2954 Triangle 题意 给出一个三角形的三个点,问三角形内部有多少个整点. 解法 pick's law 一个多边形如果每个顶点都由整点构成,该多边形的面积为\(S\),该多边形边上的 ...

  8. BZOJ 2119 股市的预测 (后缀数组+RMQ)

    题目大意:求一个字符串中形如$ABA$的串的数量,其中$B$的长度是给定的 有点像[NOI2016]优秀的拆分这道题 先对序列打差分,然后离散,再正反跑$SA$,跑出$st$表 进入正题 $ABA$串 ...

  9. python学习(一):python基础

    python两种执行方式: python解释器:py文件路径 python进入解释器:实时输入并获取执行结果 解释器路径: 在linux系统中,python文件在头部加上#!/usr/bin/env ...

  10. 洛谷P5269 欧稳欧再次学车

    正常模拟就好~ 首先初始化:转速=l, 档位=1 然后读入数据 由于先要处理换挡操作,所以我们先按照x处理,再按照y处理 当x=0时,档位+1,转速=l 当x=1时,档位-1,转速=r 当y=1时,转 ...