吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型
# 1. 自定义模型并训练。
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data tf.logging.set_verbosity(tf.logging.INFO) def lenet(x, is_training):
x = tf.reshape(x, shape=[-1, 28, 28, 1]) conv1 = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)
conv1 = tf.layers.max_pooling2d(conv1, 2, 2) conv2 = tf.layers.conv2d(conv1, 64, 3, activation=tf.nn.relu)
conv2 = tf.layers.max_pooling2d(conv2, 2, 2) fc1 = tf.contrib.layers.flatten(conv2)
fc1 = tf.layers.dense(fc1, 1024)
fc1 = tf.layers.dropout(fc1, rate=0.4, training=is_training)
return tf.layers.dense(fc1, 10) def model_fn(features, labels, mode, params):
predict = lenet(features["image"], mode == tf.estimator.ModeKeys.TRAIN) if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode,predictions={"result": tf.argmax(predict, 1)}) loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predict, labels=labels)) optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"]) train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()) eval_metric_ops = {"accuracy": tf.metrics.accuracy(tf.argmax(predict, 1), labels)} return tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,eval_metric_ops=eval_metric_ops) mnist = input_data.read_data_sets("F:\\TensorFlowGoogle\\201806-github\\datasets\\MNIST_data", one_hot=False) model_params = {"learning_rate": 0.01}
estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params) train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.train.images},y=mnist.train.labels.astype(np.int32),num_epochs=None,batch_size=128,shuffle=True) estimator.train(input_fn=train_input_fn, steps=30000)

# 2. 在测试数据上测试模型。
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images},y=mnist.test.labels.astype(np.int32),num_epochs=1,batch_size=128,shuffle=False) test_results = estimator.evaluate(input_fn=test_input_fn)
accuracy_score = test_results["accuracy"]
print("\nTest accuracy: %g %%" % (accuracy_score*100))
# 3. 预测过程。
predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images[:10]},num_epochs=1,shuffle=False) predictions = estimator.predict(input_fn=predict_input_fn)
for i, p in enumerate(predictions):
print("Prediction %s: %s" % (i + 1, p["result"]))
吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型的更多相关文章
- 吴裕雄--天生自然TensorFlow高层封装:Estimator-DNNClassifier
# 1. 模型定义. import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-TensorFlow API
# 1. 模型定义. import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist_ ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-多输入输出
# 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-返回值
# 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-CNN
# 1. 数据预处理 import keras from keras import backend as K from keras.datasets import mnist from keras.m ...
- 吴裕雄--天生自然TensorFlow高层封装:使用TFLearn处理MNIST数据集实现LeNet-5模型
# 1. 通过TFLearn的API定义卷机神经网络. import tflearn import tflearn.datasets.mnist as mnist from tflearn.layer ...
- 吴裕雄--天生自然TensorFlow高层封装:使用TensorFlow-Slim处理MNIST数据集实现LeNet-5模型
# 1. 通过TensorFlow-Slim定义卷机神经网络 import numpy as np import tensorflow as tf import tensorflow.contrib. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-RNN
# 1. 数据预处理. from keras.layers import LSTM from keras.datasets import imdb from keras.models import S ...
- 吴裕雄--天生自然TensorFlow高层封装:解决ImportError: cannot import name 'tf_utils'
将原来版本的keras卸载了,再安装2.1.5版本的keras就可以了.
随机推荐
- 167-PHP 文本分割函数str_split(二)
<?php $str='PHP is a very good programming language'; //定义一个字符串 $arr=explode(' ',$str,-3); //使用空格 ...
- 063-PHP函数按地址传参,交换数值函数
<?php function swap(&$x,&$y){ //定义交换数值函数 $temp=$x; $x=$y; $y=$temp; } $m=5; $n=15; echo & ...
- 【CS224n-2019学习笔记】Lecture 1: Introduction and Word Vectors
附上斯坦福cs224n-2019链接:https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1194/ 文章目录 1.课程简单介绍 1.1 本 ...
- nidlist 问题
错误问题如下: 解决方案: Dao文件 boolean DeleteList(String nidList); 改为: boolean DeleteList(@Param("nidList& ...
- List中bean某属性值转换为list
List<类> lst = new ArrayList<>() ; lst.stream().map(类::get需要取得仠的属性名).collect(Collectors.t ...
- POJ1833 & POJ3187 & POJ3785
要是没有next_permutation这个函数,这些题觉得还不算特别水,不过也不一定,那样可能就会有相应的模板了.反正正是因为next_permutation这个函数,这些题包括之前的POJ1226 ...
- 04-String——课后动手动脑
1.请运行以下示例代码StringPool.java,查看输出结果.如何解释这样的输出结果?从中你能总结出什么? public class StringPool { public static voi ...
- docker创建redis容器
1.拉取最新的redis镜像 docker pull redis; 2.创建存放redis数据的目录 mkdir /redis/data 3.查询redis镜像id docker images; RE ...
- DP背包问题学习笔记及系列练习题
01 背包: 01背包:在M件物品中取出若干件物品放到背包中,每件物品对应的体积v1,v2,v3,....对应的价值为w1,w2,w3,,,,,每件物品最多拿一件. 和很多DP题一样,对于每一个物品, ...
- 【每日Scrum】第四天冲刺
一.计划会议内容 连接数据库报错,解决问题中. 二.任务看板 三.scrum讨论照片 四.产品的状态 无 五.任务燃尽图