吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型
# 1. 自定义模型并训练。
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data tf.logging.set_verbosity(tf.logging.INFO) def lenet(x, is_training):
x = tf.reshape(x, shape=[-1, 28, 28, 1]) conv1 = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)
conv1 = tf.layers.max_pooling2d(conv1, 2, 2) conv2 = tf.layers.conv2d(conv1, 64, 3, activation=tf.nn.relu)
conv2 = tf.layers.max_pooling2d(conv2, 2, 2) fc1 = tf.contrib.layers.flatten(conv2)
fc1 = tf.layers.dense(fc1, 1024)
fc1 = tf.layers.dropout(fc1, rate=0.4, training=is_training)
return tf.layers.dense(fc1, 10) def model_fn(features, labels, mode, params):
predict = lenet(features["image"], mode == tf.estimator.ModeKeys.TRAIN) if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode,predictions={"result": tf.argmax(predict, 1)}) loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predict, labels=labels)) optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"]) train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()) eval_metric_ops = {"accuracy": tf.metrics.accuracy(tf.argmax(predict, 1), labels)} return tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,eval_metric_ops=eval_metric_ops) mnist = input_data.read_data_sets("F:\\TensorFlowGoogle\\201806-github\\datasets\\MNIST_data", one_hot=False) model_params = {"learning_rate": 0.01}
estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params) train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.train.images},y=mnist.train.labels.astype(np.int32),num_epochs=None,batch_size=128,shuffle=True) estimator.train(input_fn=train_input_fn, steps=30000)

# 2. 在测试数据上测试模型。
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images},y=mnist.test.labels.astype(np.int32),num_epochs=1,batch_size=128,shuffle=False) test_results = estimator.evaluate(input_fn=test_input_fn)
accuracy_score = test_results["accuracy"]
print("\nTest accuracy: %g %%" % (accuracy_score*100))
# 3. 预测过程。
predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images[:10]},num_epochs=1,shuffle=False) predictions = estimator.predict(input_fn=predict_input_fn)
for i, p in enumerate(predictions):
print("Prediction %s: %s" % (i + 1, p["result"]))
吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型的更多相关文章
- 吴裕雄--天生自然TensorFlow高层封装:Estimator-DNNClassifier
# 1. 模型定义. import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-TensorFlow API
# 1. 模型定义. import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist_ ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-多输入输出
# 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-返回值
# 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-CNN
# 1. 数据预处理 import keras from keras import backend as K from keras.datasets import mnist from keras.m ...
- 吴裕雄--天生自然TensorFlow高层封装:使用TFLearn处理MNIST数据集实现LeNet-5模型
# 1. 通过TFLearn的API定义卷机神经网络. import tflearn import tflearn.datasets.mnist as mnist from tflearn.layer ...
- 吴裕雄--天生自然TensorFlow高层封装:使用TensorFlow-Slim处理MNIST数据集实现LeNet-5模型
# 1. 通过TensorFlow-Slim定义卷机神经网络 import numpy as np import tensorflow as tf import tensorflow.contrib. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-RNN
# 1. 数据预处理. from keras.layers import LSTM from keras.datasets import imdb from keras.models import S ...
- 吴裕雄--天生自然TensorFlow高层封装:解决ImportError: cannot import name 'tf_utils'
将原来版本的keras卸载了,再安装2.1.5版本的keras就可以了.
随机推荐
- Python安装bs4
- 需要将pip源设置为国内源,阿里源.豆瓣源.网易源等 - windows (1)打开文件资源管理器(文件夹地址栏中) (2)地址栏上面输入 %appdata% (3)在这里面新建一个文件夹 pip ...
- ELK 介绍
章节 ELK 介绍 ELK 安装Elasticsearch ELK 安装Kibana ELK 安装Beat ELK 安装Logstash ELK是什么? ELK是3个开源产品的组合: Elastics ...
- javascript中的私有作用域
我们知道js中所有的块级作用域都是无效的,块级作用域内的变量,在外部仍然可以被读取,其实是申明在外部的.如何实现变量的私有化,只在块级作用域起效,避免污染全局的变量呢.而且,挂载在全局的变量很难被回收 ...
- Day 22:网络编程(3)
TCP通讯协议特点: 1. tcp是基于IO流进行数据 的传输的,面向连接. 2. tcp进行数据传输的时候是没有大小限制的. 3. tcp是面向连接,通过三次握手的机制保证数据的完整性.可靠协 ...
- 【Python】【Django】登录用户-链接Mysql
- UVA - 1606 Amphiphilic Carbon Molecules(两亲性分子)(扫描法)
题意:平面上有n(n <= 1000)个点,每个点为白点或者黑点.现在需放置一条隔板,使得隔板一侧的白点数加上另一侧的黑点数总数最大.隔板上的点可以看做是在任意一侧. 分析:枚举每个基准点i,将 ...
- 19 01 16 jquery 的 属性操作 循环 jquery 事件 和事件的绑定 解绑
jquery属性操作 1.html() 取出或设置html内容 // 取出html内容 var $htm = $('#div1').html(); // 设置html内容 $('#div1').htm ...
- Atomic系列类整体介绍
本博客系列是学习并发编程过程中的记录总结.由于文章比较多,写的时间也比较散,所以我整理了个目录贴(传送门),方便查阅. 并发编程系列博客传送门 本文是转载文章,原文请见此博客,文章主要对java.ut ...
- JS的BOM对象
BOM对象 (一)简介:BOM对象,即浏览器对象模型: 通过javascript的对象,操作和浏览器相关的操作 B: Browser,浏览器 O: Object,对象 M: Model,模型 (1) ...
- 好文推荐:终于有人把Elasticsearch原理讲透了
专注于Java领域优质技术,欢迎关注 作者:channingbreeze 转自公号:互联网侦察 小史是一个非科班的程序员,虽然学的是电子专业,但是通过自己的努力成功通过了面试,现在要开始迎接新生活了. ...