吴裕雄 python 神经网络——TensorFlow 卷积神经网络手写数字图片识别
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500 def get_weight_variable(shape, regularizer):
weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
if(regularizer != None):
tf.add_to_collection('losses', regularizer(weights))
return weights def inference(input_tensor, regularizer):
with tf.variable_scope('layer1'):
weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) with tf.variable_scope('layer2'):
weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
layer2 = tf.matmul(layer1, weights) + biases
return layer2 BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "F:\\TensorFlowGoogle\\201806-github\\datasets\\MNIST_data\\"
MODEL_NAME = "mnist_model" def train(mnist):
# 定义输入输出placeholder。
x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
y = inference(x, regularizer)
global_step = tf.Variable(0, trainable=False) # 定义损失函数、学习率、滑动平均操作以及训练过程。
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,staircase=True)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
with tf.control_dependencies([train_step, variables_averages_op]):
train_op = tf.no_op(name='train')
# 初始化TensorFlow持久化类。
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(TRAINING_STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 1000 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step) def main(argv=None):
mnist = input_data.read_data_sets("F:\\TensorFlowGoogle\\201806-github\\datasets\\MNIST_data\\", one_hot=True)
train(mnist) if __name__ == '__main__':
main()

吴裕雄 python 神经网络——TensorFlow 卷积神经网络手写数字图片识别的更多相关文章
- 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)
		# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ... 
- 吴裕雄 python神经网络 手写数字图片识别(5)
		import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ... 
- 吴裕雄--天生自然 Tensorflow卷积神经网络:花朵图片识别
		import os import numpy as np import matplotlib.pyplot as plt from PIL import Image, ImageChops from ... 
- 吴裕雄--天生自然TensorFlow2教程:手写数字问题实战
		import tensorflow as tf from tensorflow import keras from keras import Sequential,datasets, layers, ... 
- 用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)
		# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) from keras.datasets import mnist fro ... 
- caffe+opencv3.3dnn模块 完成手写数字图片识别
		最近由于项目需要用到caffe,学习了下caffe的用法,在使用过程中也是遇到了些问题,通过上网搜索和问老师的方法解决了,在此记录下过程,方便以后查看,也希望能为和我一样的新手们提供帮助. 顺带附上老 ... 
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
		Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ... 
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
		http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ... 
- 基于TensorFlow的MNIST手写数字识别-初级
		一:MNIST数据集 下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ... 
随机推荐
- VMware-Workstation-Full-12.5.9
			https://download3.vmware.com/software/wkst/file/VMware-Workstation-Full-12.5.9-7535481.x86_64.bundle ... 
- html+css 通信课上 2019。3.22
			数据通信 http协议:无状态.无连接.单向的应用层协议:采用请求/响应模型:通信请求只能由客户端发起,服务端对请求做出应答处理 服务器推送数据的解决方案:轮询( ajax) :让浏览器几秒就发送一次 ... 
- 吴裕雄--天生自然C++语言学习笔记:C++简介
			C++ 是一种中级语言,它是由 Bjarne Stroustrup 于 年在贝尔实验室开始设计开发的.C++ 进一步扩充和完善了 C 语言,是一种面向对象的程序设计语言.C++ 可运行于多种平台上,如 ... 
- JavaScript—封装animte动画函数
			封装Animte 动画函数 虽然可能以后的开发中可能根本不需要自己写,Jquery 给我们封装好了,或者用CSS3的一些属性达到这样的效果可能更简单. 我比较喜欢底层的算法实现,万变不离其中,这个逻辑 ... 
- 解决Android Studio的安装问题
			今天开始了android studio的下载与安装,我再官网上下载了Android studio,下载不难,运行出来可需要一定的时间,在中途中我遇到了一些问题 一:Build错误: 在我最开始下载完A ... 
- 覆盖(重写)&隐藏
			成员函数被重载的特征(1)相同的范围(在同一个类中): (2)函数名字相同: (3)参数不同: (4)virtual 关键字可有可无. 覆盖是指派生类函数覆盖基类函数,特征是(1)不同的范围(分别位于 ... 
- goahead调试经验
			一.参考网址 1.源码的github地址 2.Web开发之Goahead 二.技术细节 1.默认网页的存放目录和名称 1)目录:在main.c文件中有*rootWeb定义,如: 2)网页名:在mai ... 
- RedHat无法使用yum源问题
			RedHat下的yum是需要注册才能使用的 使用的话会提示: [root@test ~]# yum clean all Loaded plugins: product-id, refresh-pack ... 
- Heavy Light Decomposition
			Note 1.DFS1 mark all the depth mark fathers mark the heavy/light children mark the size of each subt ... 
- gitKraken取消/关闭全屏
			如果你找不到在哪里设置的 这是配置文件 注意 fullScreen 字段,改这个字段可以改变是不是全屏,改变之前先关闭软件, 文件目录 第二张图 
