#coding=utf-8
import tensorflow as tf
import numpy as np
import matplotlib .pyplot as plt
from tensorflow .examples .tutorials .mnist import input_data #define dataset mnist=input_data .read_data_sets ("/home/nvidia/Downloads/",one_hot= True ) #defien agruments batch_zize=20
iter=np.int(mnist .train.images.shape[0]/batch_zize )
print(iter ) #define learning_rate LEARNING_RATE_STEP=100
LEARNING_RATE_BASE=0.001
LEARNING_RATE_DECAY=0.99
global_step=tf.Variable (0,trainable= False )
learning_rate=tf.train.exponential_decay (learning_rate= LEARNING_RATE_BASE ,global_step= global_step ,decay_steps= LEARNING_RATE_STEP
,decay_rate= LEARNING_RATE_DECAY ,staircase= True ) #define tool def Weight_V(shape):
weight=tf.truncated_normal (shape=shape,stddev= 0.1)
return tf.Variable (weight ) def bias_V(shape):
bia_=tf.constant (shape=shape,value= 0.1)
return tf.Variable (bia_ ) def conv2d_(x,w):
return tf.nn.conv2d (x,filter= w,padding= "SAME",strides= [1,1,1,1]) def max_pool(x):
return tf.nn.max_pool (x,ksize= [1,2,2,1],strides=[1,2,2,1],padding="SAME") #define net x_input=tf.placeholder (shape=[None,784],dtype= tf.float32)
y_input=tf.placeholder (shape= [None,10],dtype= tf.float32) x =tf.reshape(x_input ,shape= [-1,28,28,1]) #
w_conv1=Weight_V(shape= [5,5,1,32])
b_conv1=bias_V(shape= [32])
c_conv1=tf.nn.relu (conv2d_(x ,w_conv1 )+b_conv1 )
m_conv1=max_pool(c_conv1 )
#14*14*32 w_conv2=Weight_V(shape= [5,5,32,64])
b_conv2=bias_V(shape= [64])
c_conv2=tf.nn.relu (conv2d_(m_conv1 ,w_conv2 )+b_conv2 )
m_conv2=max_pool(c_conv2 )
#7*7*64 w_fc1=Weight_V([7*7*64,1024])
b_fc1=bias_V(shape= [1024])
c_fc1=tf.reshape(m_conv2 ,[-1,7*7*64])
fc1=tf.nn.relu(tf.matmul(c_fc1 ,w_fc1 )+b_fc1 ) w_fc2=Weight_V(shape= [1024,10])
b_fc2=bias_V(shape= [10])
prediction=tf.nn.softmax (tf.matmul(fc1,w_fc2 )+b_fc2 ) #define # correct_accurcy=tf.equal(tf.argmax(prediction,axis=1),tf.argmax(y_input,axis=1))
# accurcy=tf.reduce_mean(tf.cast(correct_accurcy,dtype=tf.float32)) correct_accurcy=tf.equal (tf.argmax (prediction ,axis= 1),tf.argmax (y_input ,axis= 1)) accurcy=tf.reduce_mean (tf.cast(correct_accurcy ,dtype= tf.float32)) #traing backward
#
crosss_entropy =-tf.reduce_mean (y_input *tf.log(prediction ))
train_step=tf.train.GradientDescentOptimizer (learning_rate).minimize(crosss_entropy,global_step= global_step ) #initial global argumnets init=tf.global_variables_initializer () #SESS with tf.Session() as sess:
sess.run(init)
for i in range(21):
X,Y=mnist .test.next_batch(100)
for j in range(iter ):
xt,yt=mnist .train.next_batch (batch_zize )
sess.run(train_step ,feed_dict= {x_input :xt,y_input :yt}) acc=sess.run(accurcy ,feed_dict= {x_input :X,y_input :Y})
print(acc)

tensorflow-cnnn-mnist的更多相关文章

  1. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  2. Ubuntu16.04安装TensorFlow及Mnist训练

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com TensorFlow是Google开发的开源的深度学习框架,也是当前使用最广泛的深度学习框架. 一.安 ...

  3. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  4. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  5. 使用Tensorflow操作MNIST数据

    MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...

  6. TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架

    TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架 http://blog.sina.com.cn/s/blog_4b0020f30102wv4l.html

  7. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  8. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  9. Tensorflow之MNIST的最佳实践思路总结

    Tensorflow之MNIST的最佳实践思路总结   在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...

  10. TensorFlow训练MNIST报错ResourceExhaustedError

    title: TensorFlow训练MNIST报错ResourceExhaustedError date: 2018-04-01 12:35:44 categories: deep learning ...

随机推荐

  1. 时间转换(scanf的指定格式读入)

    给定一个12小时制的时间,请将其转换成24小时制的时间.说明:12小时制的午夜12:00:00AM,对应的24小时制时间为00:00:00.12小时制的中午12:00:00PM,对应的24小时制时间为 ...

  2. c++中比较好用的“黑科技”

    切入正题,上黑科技 一.黑科技函数(常用的我就不写了,例如sort函数) 1.next_permutation(a+1,a+1+n) a[1-n]全排列 2.reverse(a+1,a+1+n) 将a ...

  3. Unity ShaderLab 学习笔记(一)

    因为项目的问题,有个效果在iOS上面无法实现出来- 因为shader用的HardSurface的,在android上面跑起来没有问题- 以为在iOS上也不会有问题,但是悲剧啊,技能效果一片漆黑- 而且 ...

  4. SciPy 输入输出

    章节 SciPy 介绍 SciPy 安装 SciPy 基础功能 SciPy 特殊函数 SciPy k均值聚类 SciPy 常量 SciPy fftpack(傅里叶变换) SciPy 积分 SciPy ...

  5. HihoCoder第二周与POJ3630:Trie树的建立

    这又是两道一样的题,都是建立trie树的过程. HihoCoder第二周: 这里其实逻辑都很简单,主要在于数据结构struct的使用. #include <iostream> #inclu ...

  6. linux(centos6.9)下rpm方式安装mysql后mysql服务无法启动

    以下两种方式启动都报错:启动失败: [root@node03 ~]# service mysqld startMySQL Daemon failed to start.Starting mysqld: ...

  7. 编程题目:输入一个链表,输出该链表中倒数第k个节点

    两种方法 1.在链表的初始化数据中加入 num 数据, 每添加一个节点,num加1,每删除一个节点,num减1 查找倒数第k个元素,即 指向第一个节点的指针向后移动 num - k 步. 2.使用两个 ...

  8. Day5 - G - The Unique MST POJ - 1679

    Given a connected undirected graph, tell if its minimum spanning tree is unique. Definition 1 (Spann ...

  9. Day 4 -E - Catenyms POJ - 2337

    A catenym is a pair of words separated by a period such that the last letter of the first word is th ...

  10. String+、intern()、字符串常量池

    字符串连接符 "+"及字符串常量池实验.字符串final属性 结果预览 public class StrTest{ public static void main(String[] ...