#classification 分类问题
#例子 分类手写数字0-9
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#数据包,如果没有自动下载 number 0 to 9 data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True) # 定义一个神经层
def add_layer(inputs, in_size, out_size, activation_function=None):
#add one more layer and return the output of the layer
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs #用测试集来评估神经网络的准确度
def computer_accuracy(v_xs,v_ys):
global prediction
y_pre = sess.run(prediction,feed_dict={xs:v_xs})
#tf.argmax()返回最大数值的下标
#tf.equal(A, B)是对比这两个矩阵或者向量的相等的元素,如果是相等的那就返回True,反正返回False,返回的值的矩阵维度和A是一样的
'''
tf.argmax(input, axis=None, name=None, dimension=None)
此函数是对矩阵按行或列计算最大值
参数
input:输入Tensor
axis:0表示按列,1表示按行
name:名称
dimension:和axis功能一样,默认axis取值优先。新加的字段
返回:Tensor 一般是行或列的最大值下标向量
'''
'''
A = [[1,3,4,5,6]]
B = [[1,3,4,3,2]]
with tf.Session() as sess:
print(sess.run(tf.equal(A, B)))
输出:[[ True True True False False]]
'''
correct_prediction = tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))#tf.argmax(y_pre,1)表示预测出的值,tf.argmax(v_ys,1)表示实际值
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#将correct_prediction的数据格式转换为tf.float32
result = sess.run(accuracy,feed_dict={xs: v_xs, ys: v_ys})
return result #define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 784]) # none表示无论给多少个例子都行,784=28*28
ys = tf.placeholder(tf.float32, [None, 10]) #表示10个需要识别的数字 # add output layer
prediction = add_layer(xs, 784, 10 , activation_function=tf.nn.softmax) #the error between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1])) #loss function
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=xs,labels=ys)) sess = tf.Session()
#important step
sess.run(tf.initialize_all_variables()) for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100) #由于计算能力有限,每次只提取数据集的一部分
sess.run(train_step,feed_dict={xs: batch_xs, ys: batch_ys})
if i % 50 == 0:
#打印计算准确度
print(computer_accuracy(mnist.test.images,mnist.test.labels))

tensorflow学习之(九)classification 分类问题之分类手写数字0-9的更多相关文章

  1. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  2. 吴裕雄 PYTHON 神经网络——TENSORFLOW 双隐藏层自编码器设计处理MNIST手写数字数据集并使用TENSORBORD描绘神经网络数据2

    import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data os.envi ...

  3. 吴裕雄 PYTHON 神经网络——TENSORFLOW 单隐藏层自编码器设计处理MNIST手写数字数据集并使用TensorBord描绘神经网络数据

    import os import numpy as np import tensorflow as tf import matplotlib.pyplot as plt from tensorflow ...

  4. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  5. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  6. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  7. MindSpore手写数字识别初体验,深度学习也没那么神秘嘛

    摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...

  8. 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)

    一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...

  9. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

随机推荐

  1. 排序大集合java

    今日面试被问到排序问题,发现自己的不足,特来查漏补缺: 首先是各大排序算法的总结表 排序算法大合集  排序算法 平均时间复杂度 最好情况 最坏情况 空间复杂度 稳定性 冒泡排序 Ο(n2) Ο(n) ...

  2. firewalld 防火墙配置

    1. firewalld-cmd 命令中使用的参数以及作用  --get-default-zone 查询默认的区域名称 --set-default-zone=<区域名称> 设置默认的区域, ...

  3. RabbitMQ消费者抛异常日志持续打印的问题

    场景 消费者接受消息,进行一系列处理,但是由于某些原因处理过程中该消费者的抛出了异常,并且不捕获(直接 throws IOException 抛出去): 由于抛出了IOException,那么这条消息 ...

  4. C/C++控制台接收不到鼠标消息-【解决办法】

    控制台编程中,使用了鼠标操作,遇到了控制台无法接收到鼠标消息的问题,可尝试一下办法解决 [win10系统] 在控制台标题栏右键-默认值-选项,将一下对勾取消 然后调用如下函数: HANDLE hIn ...

  5. unittest模块小结

    这次写的是unittest模块的测试用例,属于自动化的门槛,进去了基本算自动化入了门,测试内容很简单,模拟给url推送用户名.密码测试登录功能 先上代码: #login_test.py import ...

  6. 3D Math Keynote 2

    [3D Math Keynote 2] 1.方向(diretion),指的是前方朝向.方位(orientation),指的是head.pitch.roll. 2.欧拉角的缺点: 1)给定方位的表达式不 ...

  7. findbugs插件使用

    本文以idea的插件举例子 介绍 Findbugs是一个静态分析工具,它检查类或者JAR 文件,将字节码与一组缺陷模式进行对比以发现可能的问题. idea安装 自此,插件安装完毕,需要重启idea才生 ...

  8. I/O多路复用、协程、线程、进程

    select注册fd,阻塞,当有fd状态改变时返回,确认对应的fd,做下一步处理.简单来说就是先注册,注册完后休眠并设置一个定时器醒来查看,有事件就通知来取,进行后续动作,没事件就继续睡,再设闹钟.用 ...

  9. Lombok快速入门

    Lombok是简化开发的jar包 借用老师的图来说明

  10. PCIE4.0 简单介绍

    关于PCI-E的标准,可以从2003年说起,2003年推出了PCI-E 1.0标准,在三年之后就推出了PCI-E 2.0,而在4年后的2010年就推出了PCI-E 3.0,但是在2010年之后的6年里 ...