import tensorflow as tf
import numpy as np
# const = tf.constant(2.0, name='const')
# b = tf.placeholder(tf.float32, [None, 1], name='b')
# # b = tf.Variable(2.0, dtype=tf.float32, name='b')
# c = tf.Variable(1.0, dtype=tf.float32, name='c')
#
# d = tf.add(b, c, name='d')
# e = tf.add(c, const, name='e')
# a = tf.multiply(d, e, name='a')
# init = tf.global_variables_initializer()
#
# print(a)
# with tf.Session() as sess:
# sess.run(init)
# ans = sess.run(a, feed_dict={b: np.arange(0, 10)[:, np.newaxis]})
# print(a)
# print(ans) from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 载入数据集 learning_rate = 0.5 # 学习率
epochs = 10 # 训练10次所有的样本
batch_size = 100 # 每批训练的样本数 x = tf.placeholder(tf.float32, [None, 784]) # 为训练集的特征提供占位符
y = tf.placeholder(tf.float32, [None, 10]) # 为训练集的标签提供占位符 W1 = tf.Variable(tf.random_normal([784, 300], stddev=0.03), name='W1') # 初始化隐藏层的W1参数
b1 = tf.Variable(tf.random_normal([300]), name='b1') # 初始化隐藏层的b1参数
W2 = tf.Variable(tf.random_normal([300, 10], stddev=0.03), name='W2') # 初始化全连接层的W1参数
b2 = tf.Variable(tf.random_normal([10]), name='b2') # 初始化全连接层的b1参数 hidden_out = tf.add(tf.matmul(x, W1), b1) # 定义隐藏层的第一步运算
hidden_out = tf.nn.relu(hidden_out) # 定义隐藏层经过激活函数后的运算 y_ = tf.nn.softmax(tf.add(tf.matmul(hidden_out, W2), b2)) # 定义全连接层的输出运算 y_clipped = tf.clip_by_value(y_, 1e-10, 0.9999999)
cross_entropy = -tf.reduce_mean(tf.reduce_sum(y * tf.log(y_clipped) + (1 - y) * tf.log(1 - y_clipped), axis=1))
# 交叉熵 optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cross_entropy)
# 梯度下降优化器,传入的参数是交叉熵 init = tf.global_variables_initializer() # 所有参数初始化 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 返回true|false
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 将true转化为1,false转化为0 # 开始训练
with tf.Session() as sess:
sess.run(init)
total_batch = int(len(mnist.train.labels) / batch_size) # 计算每个epoch要迭代几次
for epoch in range(epochs):
avg_cost = 0
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
_, c = sess.run([optimizer, cross_entropy], feed_dict={x: batch_x, y: batch_y})
# 其实上面这一步只需要跑optimizer这个优化器就好了,因为交叉熵也会同时跑。
# 但是我们想要得到交叉熵的值来作为损失函数,所以还需要跑一个交叉熵。
avg_cost += c / total_batch
print("Epoch:", (epoch + 1), "cost = ", "{:.3f}".format(avg_cost)) # 这是每训练完所有样本得到的损失值
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# 因为之前的计算已经把中间参数计算出来了,所以这里只用最后的计算测试集就行了

tensorflow手写数字识别(有注释)的更多相关文章

  1. Tensorflow手写数字识别(交叉熵)练习

    # coding: utf-8import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #pr ...

  2. Tensorflow手写数字识别训练(梯度下降法)

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...

  3. tensorflow 手写数字识别

    https://www.kaggle.com/kakauandme/tensorflow-deep-nn 本人只是负责将这个kernels的代码整理了一遍,具体还是请看原链接 import numpy ...

  4. Tensorflow手写数字识别---MNIST

    MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000

  5. 卷积神经网络应用于tensorflow手写数字识别(第三版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  6. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  7. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

  8. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  9. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

随机推荐

  1. Kafka理解

    1. 引言 最近使用Kafka做消息队列时,完成了基本的消息发送与接收,已上线运行.一方面防止出现Bug时自己不能及时定位问题,一方面网上的配置可能还可以更加优化,决定去了解下Kafka. 2. 配置 ...

  2. java转换编码报错java.lang.IllegalArgumentException: URLDecoder: Illegal hex characters in escape (%) pattern

    Exception in thread "main" java.lang.IllegalArgumentException: URLDecoder: Illegal hex cha ...

  3. 1 matplotlib绘制折线图

    from matplotlib import pyplot as plt #设置图形大小 plt.figure(figsize=(20,8),dpi=80) plt.plot(x,y,color=&q ...

  4. SpringDataRedis

    一.简介 1.SpringData和Redis Redis将数据存储到内存的,速度快.可以解决请求mysql数据库过多而导致mysql崩溃的问题. SpringData是专门用来控制Redis的工具, ...

  5. Vue父组件向子组件传值以及data和props的区别

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/xukongjing1/article/ ...

  6. js判断数组中是否有重复元素

    方法一:正则 var ary = new Array("111","ff","222","aa","222&q ...

  7. 【等待事件】等待事件系列(5.1)--Enqueue(队列等待)

    [等待事件]等待事件系列(5.1)--Enqueue(队列等待)   1  BLOG文档结构图   2  前言部分   2.1  导读和注意事项 各位技术爱好者,看完本文后,你可以掌握如下的技能,也可 ...

  8. Redis中如何发现并优化big key?

    Redis中的大key一直是重点需要优化的对象,big key既占用比较多的内存,也可能占用比较多的网卡资源,造成redis阻塞,因此我们需要找到这些big key进行优化 一.寻找big key 通 ...

  9. js 判断浏览器是pc端还是移动端

    if(/Android|webOS|iPhone|iPod|BlackBerry/i.test(navigator.userAgent)) { //说明是移动端 } else { //说明是pc端 }

  10. 云计算与大数据实验:Hbase shell操作成绩表

    [实验目的] 1)了解hbase服务 2)学会hbase shell命令操作成绩表 [实验原理] HBase是一个分布式的.面向列的开源数据库,它利用Hadoop HDFS作为其文件存储系统,利用Ha ...