import tensorflow as tf
import numpy as np
# const = tf.constant(2.0, name='const')
# b = tf.placeholder(tf.float32, [None, 1], name='b')
# # b = tf.Variable(2.0, dtype=tf.float32, name='b')
# c = tf.Variable(1.0, dtype=tf.float32, name='c')
#
# d = tf.add(b, c, name='d')
# e = tf.add(c, const, name='e')
# a = tf.multiply(d, e, name='a')
# init = tf.global_variables_initializer()
#
# print(a)
# with tf.Session() as sess:
# sess.run(init)
# ans = sess.run(a, feed_dict={b: np.arange(0, 10)[:, np.newaxis]})
# print(a)
# print(ans) from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 载入数据集 learning_rate = 0.5 # 学习率
epochs = 10 # 训练10次所有的样本
batch_size = 100 # 每批训练的样本数 x = tf.placeholder(tf.float32, [None, 784]) # 为训练集的特征提供占位符
y = tf.placeholder(tf.float32, [None, 10]) # 为训练集的标签提供占位符 W1 = tf.Variable(tf.random_normal([784, 300], stddev=0.03), name='W1') # 初始化隐藏层的W1参数
b1 = tf.Variable(tf.random_normal([300]), name='b1') # 初始化隐藏层的b1参数
W2 = tf.Variable(tf.random_normal([300, 10], stddev=0.03), name='W2') # 初始化全连接层的W1参数
b2 = tf.Variable(tf.random_normal([10]), name='b2') # 初始化全连接层的b1参数 hidden_out = tf.add(tf.matmul(x, W1), b1) # 定义隐藏层的第一步运算
hidden_out = tf.nn.relu(hidden_out) # 定义隐藏层经过激活函数后的运算 y_ = tf.nn.softmax(tf.add(tf.matmul(hidden_out, W2), b2)) # 定义全连接层的输出运算 y_clipped = tf.clip_by_value(y_, 1e-10, 0.9999999)
cross_entropy = -tf.reduce_mean(tf.reduce_sum(y * tf.log(y_clipped) + (1 - y) * tf.log(1 - y_clipped), axis=1))
# 交叉熵 optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cross_entropy)
# 梯度下降优化器,传入的参数是交叉熵 init = tf.global_variables_initializer() # 所有参数初始化 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 返回true|false
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 将true转化为1,false转化为0 # 开始训练
with tf.Session() as sess:
sess.run(init)
total_batch = int(len(mnist.train.labels) / batch_size) # 计算每个epoch要迭代几次
for epoch in range(epochs):
avg_cost = 0
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
_, c = sess.run([optimizer, cross_entropy], feed_dict={x: batch_x, y: batch_y})
# 其实上面这一步只需要跑optimizer这个优化器就好了,因为交叉熵也会同时跑。
# 但是我们想要得到交叉熵的值来作为损失函数,所以还需要跑一个交叉熵。
avg_cost += c / total_batch
print("Epoch:", (epoch + 1), "cost = ", "{:.3f}".format(avg_cost)) # 这是每训练完所有样本得到的损失值
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# 因为之前的计算已经把中间参数计算出来了,所以这里只用最后的计算测试集就行了

tensorflow手写数字识别(有注释)的更多相关文章

  1. Tensorflow手写数字识别(交叉熵)练习

    # coding: utf-8import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #pr ...

  2. Tensorflow手写数字识别训练(梯度下降法)

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...

  3. tensorflow 手写数字识别

    https://www.kaggle.com/kakauandme/tensorflow-deep-nn 本人只是负责将这个kernels的代码整理了一遍,具体还是请看原链接 import numpy ...

  4. Tensorflow手写数字识别---MNIST

    MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000

  5. 卷积神经网络应用于tensorflow手写数字识别(第三版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  6. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  7. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

  8. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  9. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

随机推荐

  1. 小米9安装charles证书

    一.打开你 mac 中对应的 charles 二.点击右上角的help按钮,打开帮助弹窗 三.点击帮助弹窗中的SSL Proxying,选择save charles root certificatio ...

  2. CentOS7 搭建 NFS 服务器

    环境: 系统版本:CentOS 7.5 一.服务端配置 1.配置环境 关闭防火墙服务 # 停止并禁用防火墙 $ systemctl stop firewalld $ systemctl disable ...

  3. (转)消息队列 Kafka 的基本知识及 .NET Core 客户端

    原文地址:https://www.cnblogs.com/savorboard/p/dotnetcore-kafka.html 前言 最新项目中要用到消息队列来做消息的传输,之所以选着 Kafka 是 ...

  4. MVC视图中 TextBoxFor 数据格式化

    @Html.TextBoxFor(m => m.Birthday,"{0:yyyy-MM-dd}", new { @class = "m-wrap small&qu ...

  5. 【开发工具】- Java开发必知工具

    压力测试工具_JMeter 作用 1.能够对HTTP和FTP服务器进行压力和性能测试, 也可以对任何数据库进行同样的测试(通过JDBC). 2.完全的可移植性和100% 纯java. 3.完全 Swi ...

  6. 内核加载错误module license

    出现如下错误: module_name: Unknown symbol "symbol_name" tail /var/log/messages查看具体错误 出现如下错误: Unk ...

  7. xadmin集成DjangoUeditor,以及编辑器的视频路径配置

    稍微讲一下DjangoUeditor的配置,因为之前去找配置的时候东拼西凑的,所以自己写一下自己一步步配置的过程.首先我是再github上去下载下来,因为是当作第三方插件集成到xadmin中,所以不用 ...

  8. SpringCloud-Zuul源码分析和路由改造

    在使用SpringCloud的时候准备使用Zuul作为微服务的网关,Zuul的默认路由方式主要是两种,一种是在配置 文件里直接指定静态路由,另一种是根据注册在Eureka的服务名自动匹配.比如如果有一 ...

  9. Xshell连接虚拟机中的Ubuntu

    虚拟机中安装好Ubuntu系统后使用cmd测试ping 设置xshell的连接ip 连接 连接失败 安装openssh-server sudo apt install openssh-server 再 ...

  10. Python面向对象三要素-继承(Inheritance)

    Python面向对象三要素-继承(Inheritance) 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.继承概述 1>.基本概念 前面我们学习了Python的面向对象三 ...