一、前述

本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类。

同时对模型的保存和恢复做下示例。

二、具体原理

代码一:实现代码

#!/usr/bin/python
# -*- coding: UTF-8 -*-
# 文件名: 12_Softmax_regression.py from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)#从本地路径加载进来 # The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)#训练集图片
# 10,000 points of test data (mnist.test), and#测试集图片
# 5,000 points of validation data (mnist.validation).#验证集图片 # Each image is 28 pixels by 28 pixels # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
# 所以输入的矩阵是None乘以784二维矩阵
x = tf.placeholder(dtype=tf.float32, shape=(None, 784)) #x矩阵是m行*784列
# 初始化都是0,二维矩阵784乘以10个W值 #初始值最好不为0
W = tf.Variable(tf.zeros([784, 10]))#W矩阵是784行*10列
b = tf.Variable(tf.zeros([10]))#bias也必须有10个 y = tf.nn.softmax(tf.matmul(x, W) + b)# x*w 即为m行10列的矩阵就是y #预测值 # 训练
# labels是每张图片都对应一个one-hot的10个值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))#真实值 m行10列
# 定义损失函数,交叉熵损失函数
# 对于多分类问题,通常使用交叉熵损失函数
# reduction_indices等价于axis,指明按照每行加,还是按照每列加
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
reduction_indices=[1]))#指明按照列加和 一列是一个类别
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#将损失函数梯度下降 #0.5是学习率 # 初始化变量
sess = tf.InteractiveSession()#初始化Session
tf.global_variables_initializer().run()#初始化所有变量
for _ in range(1000):
batch_xs, batch_ys = my_mnist.train.next_batch(100)#每次迭代取100行数据
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#每次迭代内部就是求梯度,然后更新参数
# 评估 # tf.argmax()是一个从tensor中寻找最大值的序号 就是分类号,tf.argmax就是求各个预测的数字中概率最大的那一个
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 测试
print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels})) # 总结
# 1,定义算法公式,也就是神经网络forward时的计算
# 2,定义loss,选定优化器,并指定优化器优化loss
# 3,迭代地对数据进行训练
# 4,在测试集或验证集上对准确率进行评测

代码二:保存模型

# 有时候需要把模型保持起来,有时候需要做一些checkpoint在训练中
# 以致于如果计算机宕机,我们还可以从之前checkpoint的位置去继续
# TensorFlow使得我们去保存和加载模型非常方便,仅需要去创建Saver节点在构建阶段最后
# 然后在计算阶段去调用save()方法 from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True) # The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)
# 10,000 points of test data (mnist.test), and
# 5,000 points of validation data (mnist.validation). # Each image is 28 pixels by 28 pixels # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
# 所以输入的矩阵是None乘以784二维矩阵
x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
# 初始化都是0,二维矩阵784乘以10个W值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) # 训练
# labels是每张图片都对应一个one-hot的10个值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))
# 定义损失函数,交叉熵损失函数
# 对于多分类问题,通常使用交叉熵损失函数
# reduction_indices等价于axis,指明按照每行加,还是按照每列加
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 初始化变量
init = tf.global_variables_initializer()
# 创建Saver()节点
saver = tf.train.Saver()#在运算之前,初始化之后 n_epoch = 1000 with tf.Session() as sess:
sess.run(init)
for epoch in range(n_epoch):
if epoch % 100 == 0:
save_path = saver.save(sess, "./my_model.ckpt")#每跑100次save一次模型,可以保证容错性
#直接保存session即可。 batch_xs, batch_ys = my_mnist.train.next_batch(100)#每一批次跑的数据 用m行数据/迭代次数来计算出来。
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) best_theta = W.eval()
save_path = saver.save(sess, "./my_model_final.ckpt")#保存最后的模型,session实际上保存的上面所有的数据

代码三:恢复模型

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True) # The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)
# 10,000 points of test data (mnist.test), and
# 5,000 points of validation data (mnist.validation). # Each image is 28 pixels by 28 pixels # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
# 所以输入的矩阵是None乘以784二维矩阵
x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
# 初始化都是0,二维矩阵784乘以10个W值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b)
# labels是每张图片都对应一个one-hot的10个值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10)) saver = tf.train.Saver() with tf.Session() as sess:
saver.restore(sess, "./my_model_final.ckpt")#把路径下面所有的session的数据加载进来 y y_head还有模型都保存下来了。 # 评估
# tf.argmax()是一个从tensor中寻找最大值的序号,tf.argmax就是求各个预测的数字中概率最大的那一个
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 测试
print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))

【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集的更多相关文章

  1. 李宏毅 Keras手写数字集识别(优化篇)

    在之前的一章中我们讲到的keras手写数字集的识别中,所使用的loss function为‘mse’,即均方差.那我们如何才能知道所得出的结果是不是overfitting?我们通过运行结果中的trai ...

  2. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

  3. Python实现神经网络算法识别手写数字集

    最近忙里偷闲学习了一点机器学习的知识,看到神经网络算法时我和阿Kun便想到要将它用Python代码实现.我们用了两种不同的方法来编写它.这里只放出我的代码. MNIST数据集基于美国国家标准与技术研究 ...

  4. Pytorch卷积神经网络识别手写数字集

    卷积神经网络目前被广泛地用在图片识别上, 已经有层出不穷的应用, 如果你对卷积神经网络充满好奇心,这里为你带来pytorch实现cnn一些入门的教程代码 #首先导入包 import torchfrom ...

  5. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

  6. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  7. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  8. [机器学习] keras:MNIST手写数字体识别(DeepLearning 的 HelloWord程序)

    深度学习界的Hello Word程序:MNIST手写数字体识别 learn from(仍然是李宏毅老师<机器学习>课程):http://speech.ee.ntu.edu.tw/~tlka ...

  9. TensorFlow下利用MNIST训练模型识别手写数字

    本文将参考TensorFlow中文社区官方文档使用mnist数据集训练一个多层卷积神经网络(LeNet5网络),并利用所训练的模型识别自己手写数字. 训练MNIST数据集,并保存训练模型 # Pyth ...

随机推荐

  1. Canvas的基本用法

    canvas没有设置宽度和高度的时候,会初始化宽度:300像素和高度:150像素.可以使用CSS来定义大小,但在绘制时图像会伸缩以适应它的框架尺寸:如果CSS的尺寸与初始画布的比例不一致,它会出现扭曲 ...

  2. python之12306自动查票

      一.导读 本篇文章所采用的技术仅用于学习.研究,任何其他用途请自行承担后果. 12306自动查票使用到的python库主要是splinter,同时也涉及到查票的城市编码,具体的城市编码请在网络上搜 ...

  3. redis基础操作~~数据备份与恢复、数据安全、性能测试、客户端连接、分区

    数据备份与恢复 数据备份redis save 命令用于创建当前数据库的备份. redis 127.0.0.1:6379> SAVE OK 该命令将在 redis 安装目录中创建dump.rdb文 ...

  4. BZOJ_2738_矩阵乘法_整体二分

    BZOJ_2738_矩阵乘法_整体二分 Description 给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数. Input 第一行两个数N,Q,表示矩阵大小和询问组数: 接下 ...

  5. ./configure、make、make install

    这些都是典型的使用GNU的AUTOCONF和AUTOMAKE产生的程序的安装步骤 一.基本信息 1../configure 是用来检测你的安装平台的目标特征的.比如它会检测你是不是有CC或GCC,并不 ...

  6. MYSQL——数据库存储引擎!

    本人安装mysql版本为:mysql  Ver 14.14 Distrib 5.7.18, for Win64 (x86_64),查看mysql的版本号方式:cmd-->mysql --vers ...

  7. 【公告】MIP组件审核平台故障-影响说明

    故障通报 2017年8月8日 下午14:11,由于机器故障原因,MIP组件审核平台无法提供服务. 2017年8月8日 下午16:46,服务恢复. 故障影响 2017年8月8日下午13:00-14:00 ...

  8. eShopOnContainers 知多少[9]:Ocelot gateways

    引言 客户端与微服务的通信问题永远是一个绕不开的问题,对于小型微服务应用,客户端与微服务可以使用直连的方式进行通信,但对于对于大型的微服务应用我们将不得不面对以下问题: 如何降低客户端到后台的请求数量 ...

  9. 聊聊真实的 Android TV 开发技术栈

    智能电视越来越普及了,华为说四月发布智能电视跳票了,一加也说今后要布局智能电视,在智能电视方向,小米已经算是先驱了.但是还有不少开发把智能电视简单的理解成手机屏幕的放大,其实这两者并不一样. 一.序 ...

  10. Android版数据结构与算法(七):赫夫曼树

    版权声明:本文出自汪磊的博客,未经作者允许禁止转载. 近期忙着新版本的开发,此外正在回顾C语言,大部分时间没放在数据结构与算法的整理上,所以更新有点慢了,不过既然写了就肯定尽力将这部分完全整理好分享出 ...