# 1. 自定义模型并训练。
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data tf.logging.set_verbosity(tf.logging.INFO) def lenet(x, is_training):
x = tf.reshape(x, shape=[-1, 28, 28, 1]) conv1 = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)
conv1 = tf.layers.max_pooling2d(conv1, 2, 2) conv2 = tf.layers.conv2d(conv1, 64, 3, activation=tf.nn.relu)
conv2 = tf.layers.max_pooling2d(conv2, 2, 2) fc1 = tf.contrib.layers.flatten(conv2)
fc1 = tf.layers.dense(fc1, 1024)
fc1 = tf.layers.dropout(fc1, rate=0.4, training=is_training)
return tf.layers.dense(fc1, 10) def model_fn(features, labels, mode, params):
predict = lenet(features["image"], mode == tf.estimator.ModeKeys.TRAIN) if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode,predictions={"result": tf.argmax(predict, 1)}) loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predict, labels=labels)) optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"]) train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()) eval_metric_ops = {"accuracy": tf.metrics.accuracy(tf.argmax(predict, 1), labels)} return tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,eval_metric_ops=eval_metric_ops) mnist = input_data.read_data_sets("F:\\TensorFlowGoogle\\201806-github\\datasets\\MNIST_data", one_hot=False) model_params = {"learning_rate": 0.01}
estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params) train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.train.images},y=mnist.train.labels.astype(np.int32),num_epochs=None,batch_size=128,shuffle=True) estimator.train(input_fn=train_input_fn, steps=30000)

# 2. 在测试数据上测试模型。
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images},y=mnist.test.labels.astype(np.int32),num_epochs=1,batch_size=128,shuffle=False) test_results = estimator.evaluate(input_fn=test_input_fn)
accuracy_score = test_results["accuracy"]
print("\nTest accuracy: %g %%" % (accuracy_score*100))
# 3. 预测过程。
predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images[:10]},num_epochs=1,shuffle=False) predictions = estimator.predict(input_fn=predict_input_fn)
for i, p in enumerate(predictions):
print("Prediction %s: %s" % (i + 1, p["result"]))

吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型的更多相关文章

  1. 吴裕雄--天生自然TensorFlow高层封装:Estimator-DNNClassifier

    # 1. 模型定义. import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...

  2. 吴裕雄--天生自然TensorFlow高层封装:Keras-TensorFlow API

    # 1. 模型定义. import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist_ ...

  3. 吴裕雄--天生自然TensorFlow高层封装:Keras-多输入输出

    # 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...

  4. 吴裕雄--天生自然TensorFlow高层封装:Keras-返回值

    # 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...

  5. 吴裕雄--天生自然TensorFlow高层封装:Keras-CNN

    # 1. 数据预处理 import keras from keras import backend as K from keras.datasets import mnist from keras.m ...

  6. 吴裕雄--天生自然TensorFlow高层封装:使用TFLearn处理MNIST数据集实现LeNet-5模型

    # 1. 通过TFLearn的API定义卷机神经网络. import tflearn import tflearn.datasets.mnist as mnist from tflearn.layer ...

  7. 吴裕雄--天生自然TensorFlow高层封装:使用TensorFlow-Slim处理MNIST数据集实现LeNet-5模型

    # 1. 通过TensorFlow-Slim定义卷机神经网络 import numpy as np import tensorflow as tf import tensorflow.contrib. ...

  8. 吴裕雄--天生自然TensorFlow高层封装:Keras-RNN

    # 1. 数据预处理. from keras.layers import LSTM from keras.datasets import imdb from keras.models import S ...

  9. 吴裕雄--天生自然TensorFlow高层封装:解决ImportError: cannot import name 'tf_utils'

    将原来版本的keras卸载了,再安装2.1.5版本的keras就可以了.

随机推荐

  1. 腾讯云服务器上搭建Jenkins配置邮箱通知

    1,Jenkins 点击 系统管理 2,点击系统管理 3,配置系统管理员邮件地址 5,配置 Extended E-main Notification,(用户名不需要邮箱后缀“@163.com”, SS ...

  2. 吴裕雄--天生自然JAVA SPRING框架开发学习笔记:Spring JDBCTemplate简介

    Spring 框架针对数据库开发中的应用提供了 JDBCTemplate 类,该类是 Spring 对 JDBC 支持的核心,它提供了所有对数据库操作功能的支持. Spring 框架提供的JDBC支持 ...

  3. Spring-ResolvableType可解决的数据类型

    ResolvableType,可解决的数据类型.它为java语言中的所有类型提供了相同的数据结构,其内部封装了一个java.lang.reflect.Type类型的对象. 在讲解这个数据结构之前,首先 ...

  4. M: Mysterious Conch 字符串哈希

    Problem Description小明有一个神奇的海螺,你对海螺说一段字符串,海螺就会返回一个单词,有字符串里面的所有字符组成如告诉海螺“lloeh”海螺则会告诉你“hello”如果有多个单词对应 ...

  5. Android 为控件添加点击涟漪效果

    Android在5.0版为Button默认添加了点击时的涟漪效果,而且在其他的控件上也可以轻松的实现这种炫酷的效果.涟漪效果可以分为两种,一种时有边界的涟漪,另一种时无边界的涟漪.所谓的有边界,即涟漪 ...

  6. 三星首款折叠屏手机Galaxy Fold上架中国官网

    2 月 28 日,在三星 Galaxy S10 系列新品发布会上,备受期待的三星首款可折叠屏手机 Galaxy Fold 也在中国正式亮相.目前,Galaxy Fold 已正式上架三星中国官网,可以预 ...

  7. sql拆分列 时间拆分 日、月、年

    我想 查看今日访问人数 需要分组查询 就得 时间拆分 这两天百度 方法有很多 substring ... 但是 我这一列 是时间类型: 可以直接用 DATEPART() 函数用于返回日期/时间的单独部 ...

  8. 封装localStorage设置,获取,移除方法

    export const local = { set(key, value) { localStorage.setItem(key, JSON.stringify(value)); }, get(ke ...

  9. 10 —— node —— 获取文件在前台遍历

    思想 : 前台主动发起获取 => ajax 1,前台文件 index.html <!DOCTYPE html> <html lang="en"> &l ...

  10. node —— 静态资源文件管理

    var http = require("http"); var url = require("url"); var fs = require("fs& ...