跟 Google 学 machineLearning [2] -- 关于 classifier.fit 的 warning
tensorfllow 的进化有点快。学习的很多例子已经很快的过时了,这里记录一些久的例子里被淘汰的方法,供后面参考。
我系统现在安装的是 tensorflow 1.4.1。
主要是使用了下面的代码后,出现 warning:
from tensorflow.contrib import learn myclassifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3) myclassifier.fit(x_train_array, y_train_array)
warning:
calling fit whith x is deprecated and will be removed after ...
解决方法,按照 warning 里的提示,搜了一下,发现,引入 SKCompat,并通过它来调用 classifier,即可使用原来的 fit 函数:
from tensorflow.contrib.learn.python import SKCompat feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] classifier = SKCompat( learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3) )
但是,使用 SKCompat 并没有真正的让 classifier 变成原来那个,只是改变了数据输入方式而已。从 pydoc 看到 SKCompat 共重写了三个函数:
1. fit,可以像原来一样,使用两个 array list 来进行数据填充。
2. predict,并不是原来的 predict,而是新 tensorflow.contrib.learn.Estimator 中的 predict,同样是使用 array 来喂数据。它的返回值也不是一个 array,反正我还没看懂到底它是个啥。
3. score,事实上就是新的 ensorflow.contrib.learn.Evaluable 中的 evaluate,同上,使用 array 来喂数据。
所以,即使使用过 SKCompat 之后,也还是没法用原来 predict 取得 y_test_prediction, 然后与 y_test 做比较。但是,你可以调用 score 得到一个 dic,其中 ["accuracy"]就是准确度评分。
accuracy_score = classifier.score(x_test, y_test)["accuracy"]
使用 predict ,要用下面的方法打印出可以看懂的结果(最新的手册上说 predict 的返回值是个 intertor,要用下面的方式取结果;我实验的结果是,我这里的返回值是个 dict, key 为 'classes'的就是我们要的内容了,具体的见最后的代码,这是我今天实验的最终代码;所以,tensor 又进化了):
y=classifier.predict(x_test)
predictions = list(p["predictions"] for p in itertools.islice(y, 6))
print("Predictions: {}".format(str(predictions)))
上面的 6 是 x_test 元素的个数。
===================================================
分割线
===================================================
新的 classifer 中,输入全部用的是 input_func 。这是上面报错的根本原因。
为什么要用 input_func 呢?官方给出的说法大概是,array 只适合小数据量时候使用。。。毕竟 array 的大小是有限的。这看起来完全没什么毛病。
官方给出的最新的方法(2017-12-25)是:
import numpy as np training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array(training_set.data)},
y=np.array(training_set.target),
num_epochs=None,
shuffle=True) classifier.train(input_fn=train_input_fn, steps=)
载入一个 datasets 之后,直接调用 estimator.inputs 中的 numpy.input_fn 来生成需要的 input_fn,后面给 classifier 喂数据,就喂这个 train_input_fn 就可以了。需要注意的是,这里传入的是函数 input_fn=train_input_fn, 而不是函数的返回值 input_fn=train_input_fn()。闭包?
或者,你想使用一个可以传递参数的 input_func,官方给出了三种方法(茴香豆的茴字也有三种写法,mmp):
A)写个 wrapper
def my_input_fn(data_set):
... def my_input_fn_training_set():
return my_input_fn(training_set) classifier.train(input_fn=my_input_fn_training_set, steps=)
B)使用 functools.partial
classifier.train(
input_fn=functools.partial(my_input_fn, data_set=training_set),
steps=)
C) 使用 lamda
classifier.train(input_fn=lambda: my_input_fn(training_set), steps=2000)
反正,在我看来,是越来越麻烦了,但是,现在它毕竟是一个有用的工具,还是要用的。
============
from sklearn import metrics
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.contrib import learn
import numpy as np
from tensorflow.contrib.learn.python import SKCompat
import itertools iris = learn.datasets.load_dataset('iris') print iris.data
print iris.target x_train, x_test, y_train, y_test = train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42) feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] classifier = SKCompat( learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3) ) classifier.fit(x_train, y_train, steps=200)
accuracy_score = classifier.score(x_test, y_test)["accuracy"]
print('Accuracy:{0:f}'.format(accuracy_score)) predictions=classifier.predict(x_test)['classes']
print("Predictions: {}".format(str(predictions)))
跟 Google 学 machineLearning [2] -- 关于 classifier.fit 的 warning的更多相关文章
- 跟 Google 学 machineLearning [1] -- hello sklearn
时至今日,我才发现 machineLearning 的应用门槛已经被降到了这么低,简直唾手可得.我实在找不到任何理由不对它进入深入了解.如标题,感谢 Google 为这项技术发展作出的贡献.当然,可能 ...
- Google机器学习课程基于TensorFlow : https://developers.google.cn/machine-learning/crash-course
Google机器学习课程基于TensorFlow : https://developers.google.cn/machine-learning/crash-course https ...
- 学习笔记之Machine Learning Crash Course | Google Developers
Machine Learning Crash Course | Google Developers https://developers.google.com/machine-learning/c ...
- Google机器学习笔记(七)TF.Learn 手写文字识别
转载请注明作者:梦里风林 Google Machine Learning Recipes 7 官方中文博客 - 视频地址 Github工程地址 https://github.com/ahangchen ...
- 机器学习入门 - Google的机器学习速成课程
1 - MLCC 通过机器学习,可以有效地解读数据的潜在含义,甚至可以改变思考问题的方式,使用统计信息而非逻辑推理来处理问题. Google的机器学习速成课程(MLCC,machine-learnin ...
- 【机器学习】Google机器学习工程的43条最佳实践
https://blog.csdn.net/ChenVast/article/details/81449509 本文档旨在帮助那些掌握机器学习基础知识的人从Google机器学习的最佳实践中获益.它提供 ...
- 使用Google Colab训练神经网络(二)
Colaboratory 是一个 Google 研究项目,旨在帮助传播机器学习培训和研究成果.它是一个 Jupyter 笔记本环境,不需要进行任何设置就可以使用,并且完全在云端运行.Colaborat ...
- 【阿里聚安全·安全周刊】Google“手枪”替换 | 伊朗中央银行禁止加密货币
本周七个关键词:Google"手枪"替换丨IOS 漏洞影响工业交换机丨伊朗中央银行禁止加密货币丨黑客针对医疗保健丨付费DDoS攻击丨数据获利的8种方式丨MySQL 8.0 正式版 ...
- google学习
https://developers.google.com/machine-learning/crash-course/ https://developers.google.com/machine-l ...
随机推荐
- 教你摆脱低级程序猿 项目中cocopads的安装使用
小农今天聊聊一款作为iOS开发者必备的第三方管理软件.希望程序猿朋友们看到小农的这篇文章后.可以真正的学会怎样灵活管理你项目中的第三方. (一)CocoaPods是什么? 首先我们来认识一下这款第三方 ...
- Mysql中的条件语句if、case
Mysql中的条件语句在我们对数据进行转换的时候比较有用,这样就不需要创建中转表. IF 函数 IF(expr1,expr2,expr3) 如果 expr1 是TRUE (expr1 <> ...
- VS Code搭建.NetCore开发环境(一)
一.使用命令创建并运行.Net Core程序 1.dotnet new xxx:创建指定类型的项目console,mvc,webapi 等 2.dotnet restore :加载依赖项 dotne ...
- Java NIO 的前生今世 之四 NIO Selector 详解
Selector Selector 允许一个单一的线程来操作多个 Channel. 如果我们的应用程序中使用了多个 Channel, 那么使用 Selector 很方便的实现这样的目的, 但是因为在一 ...
- cannot be resolved. It is indirectly referenced from required .class files
缺少引用. 把缺少的引用在导入一下...如果是mavan 在当前moudle里也要把 dependency加进来
- 一致性哈希算法(consistent hashing)(转)
原文链接:每天进步一点点——五分钟理解一致性哈希算法(consistent hashing) 一致性哈希算法在1997年由麻省理工学院提出的一种分布式哈希(DHT)实现算法,设计目标是为了解决因特网 ...
- iOS:三种常见计时器(NSTimer、CADisplayLink、dispatch_source_t)的使用
一.介绍 在iOS中,计时器是比较常用的,用于统计累加数据或者倒计时等,例如手机号获取验证码.计时器大概有那么三种,分别是:NSTimer.CADisplayLink.dispatch_source_ ...
- 【Scala】Scala-调用Java-集合
Scala-调用Java-集合 sacla 遍历 java list_百度搜索 13.11 Scala混用Java的集合类调用scala的foreach遍历问题 - 简书
- 条件随机场CRF HMM,MEMM的区别
http://blog.sina.com.cn/s/blog_605f5b4f010109z3.html 首先,CRF,HMM(隐马模型),MEMM(最大熵隐马模型)都常用来做序列标注的建模,像词性标 ...
- 谢宝友:会说话的Linux内核
我们本次开源专访的对象是一位认真钻研技术的工程师,谢宝友,他目前任职中兴通讯操作系统团队,他个人在业余时间前后共花费了6年时间完成了对Linux内核Linux 2.6.12内核源代码注释工作. 我们本 ...