TensorFlow Basics Notes (13): Training and Testing MNIST with MobileNet
The implementation consists of four files: mnist_train.py, mnist_eval.py, mnist_inference.py, and mobilenet_v1.py.
mnist_train.py
#coding: utf-8
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference

BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 10000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./mobilenet_v1_model/"
MODEL_NAME = "model.ckpt"
channels = 1


def train_MLP(mnist):
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = mnist_inference.inference_MLP(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step,
        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)


def train_mobilenet(mnist):
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    # MobileNet expects image-shaped input, so reshape the flat 784-dim
    # vector to NHWC and pad the 28x28 digit up to 112x112.
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    x_image = tf.image.resize_image_with_crop_or_pad(x_image, 28 * 4, 28 * 4)
    y = mnist_inference.inference_mobilenet(x_image, regularizer)
    global_step = tf.Variable(0, trainable=False)

    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean  # + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step,
        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
            else:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))


def main(argv=None):
    mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
    train_mobilenet(mnist)


if __name__ == '__main__':
    tf.app.run()
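Before moving on to evaluation, it is worth confirming that training actually wrote checkpoint files. A minimal sketch (assuming training has been run at least once) that locates the newest checkpoint the same way mnist_eval.py does:

#coding: utf-8
import tensorflow as tf

# mnist_eval.py uses this same call to find the newest model file.
ckpt = tf.train.get_checkpoint_state("./mobilenet_v1_model/")
if ckpt and ckpt.model_checkpoint_path:
    print("latest checkpoint:", ckpt.model_checkpoint_path)
else:
    print("no checkpoint written yet")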
mnist_eval.py
#coding: utf-8
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference
import mnist_train

# reload the newest model every 10 seconds
EVAL_INTERVAL_SECS = 10


def evaluate_MLP(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

        y = mnist_inference.inference_MLP(x, None)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # while True:
        if 1:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    # load the model
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            # time.sleep(EVAL_INTERVAL_SECS)


def evaluate_mobilenet(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        # MobileNet expects image-shaped input: reshape to NHWC and pad the
        # 28x28 digit up to 112x112, exactly as in training.
        x_image = tf.reshape(x, [-1, 28, 28, 1])
        x_image = tf.image.resize_image_with_crop_or_pad(x_image, 28 * 4, 28 * 4)
        y = mnist_inference.inference_mobilenet(x_image, None)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        images = mnist.validation.images
        labels = mnist.validation.labels
        batch_size = 100
        TEST_STEPS = images.shape[0] // batch_size
        sum_accuracy = 0.0

        # while True:
        if 1:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    # load the model
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    # evaluate the validation set one batch at a time
                    for i in range(int(TEST_STEPS)):
                        input_batch = images[i * batch_size:(i + 1) * batch_size, :]
                        label_batch = labels[i * batch_size:(i + 1) * batch_size, :]
                        validate_feed = {x: input_batch, y_: label_batch}
                        accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                        sum_accuracy += accuracy_score
                        print("test %s batch steps, validation accuracy = %g" % (i, accuracy_score))
                    print("After %s training steps, all validation accuracy = %g" % (global_step, sum_accuracy / TEST_STEPS))
                else:
                    print('No checkpoint file found')
                    return
            # time.sleep(EVAL_INTERVAL_SECS)


def main(argv=None):
    mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
    evaluate_mobilenet(mnist)


if __name__ == '__main__':
    tf.app.run()
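One subtlety in evaluate_mobilenet: averaging the per-batch accuracies only yields the true overall accuracy because every batch has the same size (the 5,000 validation images split evenly into 50 batches of 100, with no remainder). A small sketch with hypothetical numbers illustrating why:

import numpy as np

# Three equal-sized batches of 4 predictions each (1 = correct).
correct = np.array([[1, 1, 0, 1],   # batch 0: 0.75
                    [1, 0, 0, 1],   # batch 1: 0.50
                    [1, 1, 1, 1]])  # batch 2: 1.00
print(correct.mean(axis=1).mean())  # 0.75 -- mean of per-batch accuracies
print(correct.mean())               # 0.75 -- overall accuracy, identical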
mnist_inference.py
#coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

import mobilenet_v1

slim = tf.contrib.slim

# network size constants
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500


def get_weight_variable(shape, regularizer):
    weights = tf.get_variable(
        "weights", shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights


# forward pass of the MLP baseline
def inference_MLP(input_tensor, regularizer):
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases
    return layer2


# forward pass of mobilenet_v1
def inference_mobilenet(input_tensor, regularizer):
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                        normalizer_fn=slim.batch_norm):
        # Note: is_training is hard-coded to True here; for evaluation it
        # should be False so batch norm and dropout run in inference mode.
        logits, end_points = mobilenet_v1.mobilenet_v1(
            input_tensor,
            num_classes=OUTPUT_NODE,
            dropout_keep_prob=0.8,
            is_training=True,
            min_depth=8,
            depth_multiplier=1.0,
            conv_defs=None,
            prediction_fn=tf.contrib.layers.softmax,
            spatial_squeeze=True,
            reuse=None,
            scope='MobilenetV1',
            global_pool=False)
    return logits
mobilenet_v1.py
Download it from:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
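Once mobilenet_v1.py is placed alongside the other three files, a minimal sketch (random input, shapes matching the padded MNIST images above) can verify that the network builds and emits 10-class logits:

#coding: utf-8
import tensorflow as tf

import mobilenet_v1

# Random batch of two 112x112 single-channel images, the same shape the
# training script feeds after reshaping and padding MNIST digits.
inputs = tf.random_uniform((2, 112, 112, 1))
logits, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes=10)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(logits).shape)  # expected: (2, 10)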