博文主要内容有:

1.softmax regression的TensorFlow实现代码(教科书级的代码注释)

2.该实现中的函数总结

平台:

1.windows 10 64位

2.Anaconda3-4.2.0-Windows-x86_64.exe (当时TF还不支持python3.6,又懒得在高版本的anaconda下配置多个python环境,于是装了一个3-4.2.0(默认装python3.5),建议装anaconda3的最新版本,TF1.2.0版本已经支持python3.6!)

3.TensorFlow1.1.0

先贴代码,函数

# -*- coding: utf-8 -*-
"""
Created on Mon Jun 12 16:36:43 2017 @author: ASUS
"""
#import itchat
#from PIL import Image
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data/', one_hot = True) # 是一个tensorflow内部的变量
print(mnist.train.images.shape, mnist.train.labels.shape) # 训练集形状, 标签形状 sess = tf.InteractiveSession() # sess被注册为默认的session
x = tf.placeholder(tf.float32, [None, 784]) # Placeholder是输入数据的地方 #-------------给weights和bias创建Variable对象-------------
# Variable是用来存储模型参数,与存储数据的tensor不同,tensor一旦使用掉就消失
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# 计算 softmax 输出 y 其中x形状是[None, 784],None是为batch数而准备的
y = tf.nn.softmax(tf.matmul(x, W) + b) #-------------交叉熵损失函数-------------
# y_存放真实标签
y_ = tf.placeholder(tf.float32, [None, 10])
# recude_mean和reduce_sum意思缩减维度的均值,以及缩减维度的求和
# reduce_mean在这里是对一个batch进行求均值
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices = [1])) #-------------优化算法设置-------------
# 采用梯度下降的优化方法,
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) #---------------全局参数初始化-------------------
tf.global_variables_initializer().run() #---------------迭代地执行训练操作-------------------
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x: batch_xs, y_: batch_ys}) #---------------准确率验证-------------------
# tf.argmax是寻找tensor中值最大的元素的序号 ,在此用来判断类别
# tf.equal是判断两个变量是否相等, 返回的是bool值
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_, 1))
# tf.cast用于数据类型转换
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy)
print(accuracy.eval({x:mnist.test.images, y_: mnist.test.labels}))

其中用到的函数总结:

1. sess = tf.InteractiveSession() 将sess注册为默认的session

2. tf.placeholder() , Placeholder是输入数据的地方,也称为占位符,通俗的理解就是给输入数据(此例中的图片x)和真实标签(y_)提供一个入口,或者是存放地。(个人理解,可能不太正确,后期对TF有深入认识的话再回来改~~)

3. tf.Variable() Variable是用来存储模型参数,与存储数据的tensor不同,tensor一旦使用掉就消失

4. tf.matmul() 矩阵相乘函数

5. tf.reduce_mean 和tf.reduce_sum 是缩减维度的计算均值,以及缩减维度的求和

6. tf.argmax() 是寻找tensor中值最大的元素的序号 ,此例中用来判断类别

7. tf.cast() 用于数据类型转换

(ps: 可将每篇博文最后的函数总结复制到一个word文档,便于日后查找)

【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)的更多相关文章

  1. TensorFlow 之 手写数字识别MNIST

    官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...

  2. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  3. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

  4. TensorFlow(十二):使用RNN实现手写数字识别

    上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist ...

  5. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  6. Tensorflow - Tutorial (7) : 利用 RNN/LSTM 进行手写数字识别

    1. 经常使用类 class tf.contrib.rnn.BasicLSTMCell BasicLSTMCell 是最简单的一个LSTM类.没有实现clipping,projection layer ...

  7. 吴裕雄 python 神经网络——TensorFlow实现AlexNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf # 输入数据 from tensorflow.examples.tutorials.mnist import input_data mnist = in ...

  8. Tensorflow手写数字识别---MNIST

    MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000

  9. Tensorflow笔记——神经网络图像识别(五)手写数字识别

随机推荐

  1. 【CF505D】Mr. Kitayuta's Technology

    题目大意: 在一个有向图中,有n个顶点,给出m对数字(u,v)表示顶点u和顶点v必须直接或者间接相连,让你构造一个这样的图,输出最少需要多少条边. 挖坑待填 官方题解链接:http://codefor ...

  2. “百度杯”CTF比赛 十月场_GetFlag(验证码爆破+注入+绝对路径文件下载)

    题目在i春秋ctf大本营 页面给出了验证码经过md5加密后前6位的值,依照之前做题的套路,首先肯定是要爆破出验证码,这里直接给我写的爆破代码 #coding:utf-8 import hashlib ...

  3. 【Visual Studio】VS2013的Release模式下进行调试(转)

    原文转自 http://blog.csdn.net/haizimin/article/details/50262901 在有的情况下,我们可能不能直接利用Debug模式进行程序调试,那么如何在Rele ...

  4. 47深入理解C指针之---指针与硬件

    一.size_t:用于安全表示长度,所有平台和系统都会解析成自己对应的长度 1.定义:size_t类型表示C中任何对象所能表示的最大长度,是个无符号整数:常常定义在stdio.h或stdlib.h中 ...

  5. spring boot--日志、开发和生产环境切换、自定义配置(环境变量)

    Spring Boot日志常用配置: # 日志输出的地址:Spring Boot默认并没有进行文件输出,只在控制台中进行了打印 logging.file=/home/zhou # 日志级别 debug ...

  6. [Web Tools] 实用的Web开发工具

    模拟http请求:Postman https://www.getpostman.com 生成Json数据 http://www.json-generator.com/

  7. Codeforces 946 B.Weird Subtraction Process

    B. Weird Subtraction Process   time limit per test 1 second memory limit per test 256 megabytes inpu ...

  8. [WPF自定义控件库]以Button为例谈谈如何模仿Aero2主题

    1. 为什么选择Aero2 除了以外观为卖点的控件库,WPF的控件库都默认使用"素颜"的外观,然后再提供一些主题包.这样做的最大好处是可以和原生控件或其它控件库兼容,而且对于大部分 ...

  9. luogu P2056 采花

    题目描述 萧芸斓是 Z国的公主,平时的一大爱好是采花. 今天天气晴朗,阳光明媚,公主清晨便去了皇宫中新建的花园采花.花园足够大,容纳了 n 朵花,花有 c 种颜色(用整数 1-c 表示) ,且花是排成 ...

  10. 第4章 使用 Spring Boot

    使用 Spring Boot 本部分将详细介绍如何使用Spring Boot. 这部分涵盖诸如构建系统,自动配置以及如何运行应用程序等主题. 我们还介绍了一些Spring Boot的最佳实践(best ...