import tensorflow as tf
import numpy as np
# const = tf.constant(2.0, name='const')
# b = tf.placeholder(tf.float32, [None, 1], name='b')
# # b = tf.Variable(2.0, dtype=tf.float32, name='b')
# c = tf.Variable(1.0, dtype=tf.float32, name='c')
#
# d = tf.add(b, c, name='d')
# e = tf.add(c, const, name='e')
# a = tf.multiply(d, e, name='a')
# init = tf.global_variables_initializer()
#
# print(a)
# with tf.Session() as sess:
# sess.run(init)
# ans = sess.run(a, feed_dict={b: np.arange(0, 10)[:, np.newaxis]})
# print(a)
# print(ans) from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 载入数据集 learning_rate = 0.5 # 学习率
epochs = 10 # 训练10次所有的样本
batch_size = 100 # 每批训练的样本数 x = tf.placeholder(tf.float32, [None, 784]) # 为训练集的特征提供占位符
y = tf.placeholder(tf.float32, [None, 10]) # 为训练集的标签提供占位符 W1 = tf.Variable(tf.random_normal([784, 300], stddev=0.03), name='W1') # 初始化隐藏层的W1参数
b1 = tf.Variable(tf.random_normal([300]), name='b1') # 初始化隐藏层的b1参数
W2 = tf.Variable(tf.random_normal([300, 10], stddev=0.03), name='W2') # 初始化全连接层的W1参数
b2 = tf.Variable(tf.random_normal([10]), name='b2') # 初始化全连接层的b1参数 hidden_out = tf.add(tf.matmul(x, W1), b1) # 定义隐藏层的第一步运算
hidden_out = tf.nn.relu(hidden_out) # 定义隐藏层经过激活函数后的运算 y_ = tf.nn.softmax(tf.add(tf.matmul(hidden_out, W2), b2)) # 定义全连接层的输出运算 y_clipped = tf.clip_by_value(y_, 1e-10, 0.9999999)
cross_entropy = -tf.reduce_mean(tf.reduce_sum(y * tf.log(y_clipped) + (1 - y) * tf.log(1 - y_clipped), axis=1))
# 交叉熵 optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cross_entropy)
# 梯度下降优化器,传入的参数是交叉熵 init = tf.global_variables_initializer() # 所有参数初始化 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 返回true|false
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 将true转化为1,false转化为0 # 开始训练
with tf.Session() as sess:
sess.run(init)
total_batch = int(len(mnist.train.labels) / batch_size) # 计算每个epoch要迭代几次
for epoch in range(epochs):
avg_cost = 0
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
_, c = sess.run([optimizer, cross_entropy], feed_dict={x: batch_x, y: batch_y})
# 其实上面这一步只需要跑optimizer这个优化器就好了,因为交叉熵也会同时跑。
# 但是我们想要得到交叉熵的值来作为损失函数,所以还需要跑一个交叉熵。
avg_cost += c / total_batch
print("Epoch:", (epoch + 1), "cost = ", "{:.3f}".format(avg_cost)) # 这是每训练完所有样本得到的损失值
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# 因为之前的计算已经把中间参数计算出来了,所以这里只用最后的计算测试集就行了

tensorflow手写数字识别(有注释)的更多相关文章

  1. Tensorflow手写数字识别(交叉熵)练习

    # coding: utf-8import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #pr ...

  2. Tensorflow手写数字识别训练(梯度下降法)

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...

  3. tensorflow 手写数字识别

    https://www.kaggle.com/kakauandme/tensorflow-deep-nn 本人只是负责将这个kernels的代码整理了一遍,具体还是请看原链接 import numpy ...

  4. Tensorflow手写数字识别---MNIST

    MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000

  5. 卷积神经网络应用于tensorflow手写数字识别(第三版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  6. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  7. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

  8. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  9. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

随机推荐

  1. Java之数据类型讲解

    Java数据类型关系图 基本数据类型 从小到大的关系图: 图中从左向右的转换都是隐式转换,无需再代码中进行强制转换 : byte i = 12; System.out.println("by ...

  2. 在ASP.NET MVC中加载部分视图的方法及差别

    在视图里有多种方法可以加载部分视图,包括Partial() .Action().RenderPartial().RenderAction().RenderPage()方法.下面说明一下这些方法的差别. ...

  3. Tensorflow在python3.7版本的运行

    安装tensorflow pip install tensorflow==1.13.1 -i https://pypi.tuna.tsinghua.edu.cn/simple 可以在命令行 或者在py ...

  4. Python的字符串函数

    今天用了将近一天的时间去学习Python字符串函数 上午学了17个,下午学了23个(共计40) 详细内容请见菜鸟教程--Python3字符串--Python的字符串内建函数

  5. 2019 东软java面试笔试题 (含面试题解析)

    本人3年开发经验.18年年底开始跑路找工作,在互联网寒冬下成功拿到阿里巴巴.今日头条.东软等公司offer,岗位是Java后端开发,最终选择去了东软. 面试了很多家公司,感觉大部分公司考察的点都差不多 ...

  6. 在Centos6.5上部署kvm虚拟化技术

    KVM是什么? KVM 全称是 基于内核的虚拟机(Kernel-based Virtual Machine),它是一个 Linux 的一个内核模块,该内核模块使得 Linux 变成了一个 Hyperv ...

  7. OO第三单元作业总结

    OO第三单元作业总结--JML 第三单元的主题是JML规格的学习,其中的三次作业也是围绕JML规格的实现所展开的(虽然感觉作业中最难的还是如何正确适用数据结构以及如何正确地对于时间复杂度进行优化). ...

  8. 浏览器提示:源映射错误:request failed with status 404 源 URL:http://xxx.js 源映射 URL:jquery.min.map

    浏览器 jquery1.9.1min.js 报脚本错误 无jquery.min.map 文件 最近在浏览个人网站的时候就遇到了这个问题 我先说一下什么是source map文件. source map ...

  9. php xml解析

    XML处理是开发过程中经常遇到的,PHP对其也有很丰富的支持,本文只是对其中某几种解析技术做简要说明,包括:Xml parser, SimpleXML, XMLReader, DOMDocument. ...

  10. java开发手册-总结与补充

    1.分层领域模型规约 1.DO( Data Object):与数据库表结构一一对应,通过 DAO 层向上传输数据源对象. 2.DTO( Data Transfer Object):数据传输对象, Se ...