代码比较简单,没啥好说的,就做个记录而已。大致就是现建立graph,再通过session运行即可。需要注意的就是Variable要先初始化再使用。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt # 把下载的MNIST数据集放到mnist_link目录下,用TF提供的接口解析数据集
MNIST = input_data.read_data_sets('../mnist_link',one_hot = True) learning_rate = 0.01
epoch_num = 25
batch_size = 128 X = tf.placeholder(tf.float32, [batch_size, 784], name = 'input')
Y = tf.placeholder(tf.float32, [batch_size, 10], name = 'label')
w = tf.Variable(tf.random_normal(shape = [784, 10], stddev = 0.01), name = 'weights')
b = tf.Variable(tf.zeros([1, 10]), name = 'bias') logits = tf.matmul(X, w) + b
entropy = tf.nn.softmax_cross_entropy_with_logits(labels = Y, logits = logits)
loss = tf.reduce_mean(entropy) optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(loss) init = tf.global_variables_initializer()
loss_array = []
with tf.Session() as sess:
sess.run(init)
# train
batch_num = int(MNIST.train.num_examples/batch_size)
for _ in range(epoch_num):
for _ in range(batch_num):
X_batch, Y_batch = MNIST.train.next_batch(batch_size)
_, v = sess.run([optimizer, loss], {X: X_batch, Y: Y_batch})
loss_array.append(v) # test
total_correct_preds = 0
batch_num = int(MNIST.test.num_examples/batch_size)
for i in range(batch_num):
X_batch, Y_batch = MNIST.test.next_batch(batch_size)
_, loss_batch, logits_batch = sess.run([optimizer, loss, logits], {X: X_batch, Y: Y_batch})
preds = tf.nn.softmax(logits_batch)
correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y_batch, 1))
accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))
total_correct_preds += sess.run(accuracy)
print("accuracy rate is {}".format(total_correct_preds/MNIST.test.num_examples)) x_axis = range(len(loss_array))
plt.plot(x_axis, loss_array)
plt.title('loss for each batch')
plt.show()

最终准确率在90%左右。学习曲线如下:

TensorFlow学习笔记2:逻辑回归实现手写字符识别的更多相关文章

  1. tensorflow学习笔记五----------逻辑回归

    在逻辑回归中使用mnist数据集.导入相应的包以及数据集. import numpy as np import tensorflow as tf import matplotlib.pyplot as ...

  2. 10分钟搞懂Tensorflow 逻辑回归实现手写识别

    1. Tensorflow 逻辑回归实现手写识别 1.1. 逻辑回归原理 1.1.1. 逻辑回归 1.1.2. 损失函数 1.2. 实例:手写识别系统 1.1. 逻辑回归原理 1.1.1. 逻辑回归 ...

  3. 学习笔记TF020:序列标注、手写小写字母OCR数据集、双向RNN

    序列标注(sequence labelling),输入序列每一帧预测一个类别.OCR(Optical Character Recognition 光学字符识别). MIT口语系统研究组Rob Kass ...

  4. Python学习笔记之逻辑回归

    # -*- coding: utf-8 -*- """ Created on Wed Apr 22 17:39:19 2015 @author: 90Zeng " ...

  5. Tensorflow学习练习-卷积神经网络应用于手写数字数据集训练

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data mn ...

  6. 仅用200个样本就能得到当前最佳结果:手写字符识别新模型TextCaps

    由于深度学习近期取得的进展,手写字符识别任务对一些主流语言来说已然不是什么难题了.但是对于一些训练样本较少的非主流语言来说,这仍是一个挑战性问题.为此,本文提出新模型TextCaps,它每类仅用200 ...

  7. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  8. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  9. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

随机推荐

  1. 牛飞盘队Cow Frisbee Team

    老唐最近迷上了飞盘,约翰想和他一起玩,于是打算从他家的N头奶牛中选出一支队伍. 每只奶牛的能力为整数,第i头奶牛的能力为R i .飞盘队的队员数量不能少于 .大于N. 一支队伍的总能力就是所有队员能力 ...

  2. 特征提取算法(2)——HOG特征提取算法

    histogram of oriented gradient(方向梯度直方图)特征是一种在计算机视觉和图像处理中用来进行物体检测的特征描述子.它通过计算和统计图像局部区域的梯度方向直方图来构成特征.H ...

  3. E. Natasha, Sasha and the Prefix Sums

    http://codeforces.com/contest/1204/problem/E 给定n个 1 m个 -1的全排 求所有排列的$f(a) = max(0,max_{1≤i≤l} \sum_{j ...

  4. RabbitMQ消费端自定义监听(九)

    场景: 我们一般在代码中编写while循环,进行consumer.nextDelivery方法进行获取下一条消息,然后进行消费处理. 实际环境: 我们使用自定义的Consumer更加的方便,解耦性更强 ...

  5. Redis之Java客户端Jedis

    导读 Redis不仅使用命令客户端来操作,而且可以使用程序客户端操作. 现在基本上主流的语言都有客户端支持,比如Java.C.C#.C++.php.Node.js.Go等. 在官方网站里列一些Java ...

  6. js方法返回多值如何取值demo

    js方法返回,如何取值?下面demo两种方法 new array 和 json 返回值 取值示例. 方法一:  new array <html> <head> <meta ...

  7. vue 使用props 实现父组件向子组件传数据

    刚自学vue不久遇到很多问题,刚好用到的分组件,所以就用到传递数据 弄了好久终于搞定了,不多说直接上代码 父组件: <template> <headers :inputName=&q ...

  8. Sublime如何设置背景透明

    Sublime如何设置背景透明 下载sublime 透明背景插件 我用的是git下载插件: git clone https://github.com/vhanla/SublimeTextTrans.g ...

  9. 用Vue来实现音乐播放器(十四):歌手数据接口抓取

    第一步:在api文件夹下创建一个singer.js文件 返回一个getSingerList()方法  使他能够在singer.vue中调用 import jsonp from '../common/j ...

  10. 重拾SQL——表中索值

    2016.10.23 + 2016.11.02 1.选择所有数据(查看整表) MariaDB [tianyuan]> select * from pet; +----------+------- ...