# 手写数字识别  ----Softmax回归模型
# regression
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets("/tmp/data/", one_hot=True) # 获取数据 mnist是一个轻量级的类,其中以Numpy数组的形式中存储着训练集、验证集、测试集。 # 一个对手写数字进行识别的模型。
# 思路:
# 1、将训练集中获取的手写数字图像进行某一统一方式(全部按行或全部按列)的展开,
# 得到一个长向量(这是为了利用softmax做一维的回归,不过损失了二维信息),
# 用一个二维张量来索引某一个样本中的某一像素。
# 2、softmax模型:用来给不同的对象分配概率(即使在更精细的模型中,最后一步,往往也需要用softmax来分配概率)
# 两步:
# ① 加权求和,并引入偏置
# 对于给定输入图片x,其代表图像为数字i的证据为
# evidencei =∑i(wi, jxj) + bi
# evidence_i =∑_i(w_{i, j}x_j)+b_ievidencei=∑i​(wi, jxj) + bi
# ② 用softmax函数将evidence转换成概率,即
# y = softmax(evidence)
# y = softmax(evidence)
# y = softmax(evidence)
# 将输入值当成幂指数求值,再正则化这些结果
# 更紧凑的写法为
# y = softmax(Wx + b)
# y = softmax(Wx + b)
# y = softmax(Wx + b)
# 3、为了节省在python外使用别的语言进行复杂矩阵运算带来的开销,TensorFlow做出的优化为,先用图描述一系列可交互的操作,最后统一放在python外执行。
# 用占位符placeholder来描述这些可交互的单元:
# ---------------------
# 作者:Crystal 
# 来源:CSDN
# 原文:https: // blog.csdn.net / weixin_43226400 / article / details / 82749769
# 版权声明:本文为博主原创文章,转载请附上博文链接! #http://www.cnblogs.com/rgvb178/p/6052541.html 相关说明
# Softmax Regression Model Softmax回归模型
def regression(x):
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.zeros([10]), name="b")
y = tf.nn.softmax(tf.matmul(x, W) + b)
# print(y)
return y, [W, b] # model 声明占位符
with tf.variable_scope("regression"):
x = tf.placeholder(tf.float32, [None, 784])
y, variables = regression(x) # 用交叉熵(cross - entropy)来评判模型的好坏,其表达式为
# Hy′(y) =−∑iy′ilog(yi)
# 其中y是预测的概率分布,y’是实际的概率分布(即训练集对应的真实标签,是一个one - hotvector)定义 # train 开始训练模型
y_ = tf.placeholder("float", [None, 10])
# 计算交叉熵
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
# tensorflow可以自动利用反向传播算法,根据选择的优化器来最小化你的目标函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# tf.argmax给出对象在某一维度上最大值所对应的索引值,可以用来判断预测是否准确,即
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# equal函数返回布尔值,用cast函数转化为浮点数后求均值来计算正确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver(variables)
with tf.Session() as sess:
# 初始化操作
sess.run(tf.global_variables_initializer())
for _ in range(10000):
batch_xs, batch_ys = data.train.next_batch(100)
# 此为随机梯度下降训练,每次训练随机抓取训练集中的100个数据作为一个batch
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # 计算学习到的模型在训练集上的准确率
print(sess.run(accuracy, feed_dict={x: data.test.images, y_: data.test.labels})) # 保存训练结果
# print(os.path.join(os.path.dirname(__file__), 'data', 'regression.ckpt'))
# //绝对路径包含中文字符可能导致路径不可用 相对路径:'mnist/data/regression.ckpt'
path = saver.save(
sess, 'mnist/data/regression.ckpt',
write_meta_graph=False, write_state=False)
print("Saved:", path) # path = saver.save(
# sess, os.path.join(os.path.dirname(__file__), 'mnist\data', 'regression.ckpt'),write_meta_graph=False, write_state=False)
# print("Saved:", path) write_meta_graph=False, write_state=False)
print("Saved:", path) # path = saver.save(
# sess, os.path.join(os.path.dirname(__file__), 'mnist\data', 'regression.ckpt'),write_meta_graph=False, write_state=False)
# print("Saved:", path)

手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)的更多相关文章

  1. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

  2. mnist手写数字识别(Logistic回归)

    import numpy as np from sklearn.neural_network import MLPClassifier from sklearn.linear_model import ...

  3. CNN:人工智能之神经网络算法进阶优化,六种不同优化算法实现手写数字识别逐步提高,应用案例自动驾驶之捕捉并识别周围车牌号—Jason niu

    import mnist_loader from network3 import Network from network3 import ConvPoolLayer, FullyConnectedL ...

  4. 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)

    通过: 手写数字识别  ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别  ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...

  5. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  6. Softmax用于手写数字识别(Tensorflow实现)-个人理解

    softmax函数的作用   对于分类方面,softmax函数的作用是从样本值计算得到该样本属于各个类别的概率大小.例如手写数字识别,softmax模型从给定的手写体图片像素值得出这张图片为数字0~9 ...

  7. 【百度飞桨】手写数字识别模型部署Paddle Inference

    从完成一个简单的『手写数字识别任务』开始,快速了解飞桨框架 API 的使用方法. 模型开发 『手写数字识别』是深度学习里的 Hello World 任务,用于对 0 ~ 9 的十类数字进行分类,即输入 ...

  8. 使用L2正则化和平均滑动模型的LeNet-5MNIST手写数字识别模型

    使用L2正则化和平均滑动模型的LeNet-5MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: T ...

  9. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

随机推荐

  1. 细说Cookie--转

    Cookie虽然是个很简单的东西,但它又是WEB开发中一个很重要的客户端数据来源,而且它可以实现扩展性很好的会话状态, 所以我认为每个WEB开发人员都有必要对它有个清晰的认识.本文将对Cookie这个 ...

  2. 077、跨主机使用Rex-Ray volume (2019-04-24 周三)

    参考https://www.cnblogs.com/CloudMan6/p/7630205.html   上一节我们在docker1上创建mysql容器,并使用了 Rex-Ray volume mys ...

  3. [笔记]JS flat and flatMap

    原文 flat()接收一个数组(这个数组中的某些item本身也是一个数组),返回一个新的一维数组(如果没有特别指定depth参数的话返回一维数组). const nestedArraysOhMy = ...

  4. Spring ElasticsearchTemplate 经纬度按距离排序

    es实体,用 @GeoPointField 注解,值为:中间逗号隔开,如 29.477000,119.278536(经度, 维度) @Document(indexName = "v_inte ...

  5. Keras的一些功能函数

    摘自: https://www.cnblogs.com/Anita9002/p/8136357.html 1.模型的信息提取 # 节点信息提取 config = model.get_config() ...

  6. python中元组/列表/字典/集合

    转自:https://blog.csdn.net/lobo_seeworld/article/details/79404566

  7. 解决微信小程序wepy真机预览跟本地表现不一样,数据变化了视图没变化

    当时搜了很多相关问题都没找到相似的 只看到有这个相似的描述wepy在onLoad里修改data-object的值页面不渲染 ,通过setData解决的. 但是这个还不是根本的解决办法,有些地方用set ...

  8. Java继承详解

    目录 前言 继承的格式: 继承的特点: 继承的优缺点 继承的注意点(重要) 继承的使用 前言 类是对对象的抽象,具有共同属性和行为的许多对象抽象出一个类. 例如:有三个学生小明,小红,小李都有姓名,年 ...

  9. 医学图像数据(一)——TCIA基本介绍

    1.介绍 The Cancer Imaging Archive (TCIA)是癌症研究的医学图像的开放获取数据库.该网站由国家癌症研究所(NCI)癌症影像计划资助,合同由阿肯色大学医学科学院管理.存档 ...

  10. Visual Studio 2013 SDK 扩展之简介

    Release Notes:[发行说明]启动记事本的扩展,以管理员身份运行验证通过. Getting Started Guide:[入门]使用[Ctrl + 1]更快捷打开记事本 More Info ...