代码比较简单,没啥好说的,就做个记录而已。大致就是现建立graph,再通过session运行即可。需要注意的就是Variable要先初始化再使用。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt # 把下载的MNIST数据集放到mnist_link目录下,用TF提供的接口解析数据集
MNIST = input_data.read_data_sets('../mnist_link',one_hot = True) learning_rate = 0.01
epoch_num = 25
batch_size = 128 X = tf.placeholder(tf.float32, [batch_size, 784], name = 'input')
Y = tf.placeholder(tf.float32, [batch_size, 10], name = 'label')
w = tf.Variable(tf.random_normal(shape = [784, 10], stddev = 0.01), name = 'weights')
b = tf.Variable(tf.zeros([1, 10]), name = 'bias') logits = tf.matmul(X, w) + b
entropy = tf.nn.softmax_cross_entropy_with_logits(labels = Y, logits = logits)
loss = tf.reduce_mean(entropy) optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(loss) init = tf.global_variables_initializer()
loss_array = []
with tf.Session() as sess:
sess.run(init)
# train
batch_num = int(MNIST.train.num_examples/batch_size)
for _ in range(epoch_num):
for _ in range(batch_num):
X_batch, Y_batch = MNIST.train.next_batch(batch_size)
_, v = sess.run([optimizer, loss], {X: X_batch, Y: Y_batch})
loss_array.append(v) # test
total_correct_preds = 0
batch_num = int(MNIST.test.num_examples/batch_size)
for i in range(batch_num):
X_batch, Y_batch = MNIST.test.next_batch(batch_size)
_, loss_batch, logits_batch = sess.run([optimizer, loss, logits], {X: X_batch, Y: Y_batch})
preds = tf.nn.softmax(logits_batch)
correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y_batch, 1))
accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))
total_correct_preds += sess.run(accuracy)
print("accuracy rate is {}".format(total_correct_preds/MNIST.test.num_examples)) x_axis = range(len(loss_array))
plt.plot(x_axis, loss_array)
plt.title('loss for each batch')
plt.show()

最终准确率在90%左右。学习曲线如下:

TensorFlow学习笔记2:逻辑回归实现手写字符识别的更多相关文章

  1. tensorflow学习笔记五----------逻辑回归

    在逻辑回归中使用mnist数据集.导入相应的包以及数据集. import numpy as np import tensorflow as tf import matplotlib.pyplot as ...

  2. 10分钟搞懂Tensorflow 逻辑回归实现手写识别

    1. Tensorflow 逻辑回归实现手写识别 1.1. 逻辑回归原理 1.1.1. 逻辑回归 1.1.2. 损失函数 1.2. 实例:手写识别系统 1.1. 逻辑回归原理 1.1.1. 逻辑回归 ...

  3. 学习笔记TF020:序列标注、手写小写字母OCR数据集、双向RNN

    序列标注(sequence labelling),输入序列每一帧预测一个类别.OCR(Optical Character Recognition 光学字符识别). MIT口语系统研究组Rob Kass ...

  4. Python学习笔记之逻辑回归

    # -*- coding: utf-8 -*- """ Created on Wed Apr 22 17:39:19 2015 @author: 90Zeng " ...

  5. Tensorflow学习练习-卷积神经网络应用于手写数字数据集训练

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data mn ...

  6. 仅用200个样本就能得到当前最佳结果:手写字符识别新模型TextCaps

    由于深度学习近期取得的进展,手写字符识别任务对一些主流语言来说已然不是什么难题了.但是对于一些训练样本较少的非主流语言来说,这仍是一个挑战性问题.为此,本文提出新模型TextCaps,它每类仅用200 ...

  7. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  8. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  9. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

随机推荐

  1. 浅谈IPv4至IPv6演进的实施路径

    作者:个推运维平台网络工程师 宗堂   1 业务背景 在互联网呈现爆炸式发展的今天, IPv4网络地址数量匮乏等问题将会影响到我国的互联网发展与应用,制约物联网.5G等新业务开展.今年4月国家工信部发 ...

  2. spring-boot-devtools热加载不起作用

    在开发过程中,希望修改时能够及时更新修改,即热加载,但是spring-boot-devtools不起作用.这主要是两个原因导致. 一.spring-boot-maven-plugin插件没有配置,如下 ...

  3. angular 发送ajax

    在使用angular发送ajax的时候get和post一样的,就是method改一下. ajax的js: <script> var app = angular.module('emialV ...

  4. 10.1 ‘The server's host key is not cached in the registry’

    10.1 ‘The server's host key is not cached in the registry’ This error message occurs when PuTTY conn ...

  5. 架构-层-Model:Model

    ylbtech-架构-层-Model:Model 1.返回顶部 1. Model,意思是模特儿,模特儿是英文“model”的音译.模特一般来说要五官端正,身材良好,有气质,展示能力强,另外身高要具备一 ...

  6. javascript处理json字符串

    字符串转JSON格式 var obj = JSON.parse(json字符串); 判断字段值是否存在,返回true或false obj.hasOwnProperty("error" ...

  7. VMware 虚拟化编程(14) — VDDK 的高级传输模式详解

    目录 目录 前文列表 虚拟磁盘数据的传输方式 Transport Methods Local File Access NBD and NBDSSL Transport SAN Transport Ho ...

  8. 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_06 Properties集合_2_Properties集合中的方法store

    第一行是注释,第二行是时间,时间是自动加的 使用FileOutputStream. 写入中文会乱码

  9. 阶段1 语言基础+高级_1-3-Java语言高级_04-集合_04 数据结构_2_数据结构_队列

    先进先出 队列 队列:queue,简称队,它同堆栈一样,也是一种运算受限的线性表,其限制是仅允许在表的一端进行插入, 而在表的另一端进行删除. 简单的说,采用该结构的集合,对元素的存取有如下的特点: ...

  10. docker进阶——数据管理与网络

    一.数据卷管理 用户在使用 Docker 的过程中,势必需要查看容器内应用产生的数据,或者 需要将容器内数据进行备份,甚至多个容器之间进行数据共享,这必然会涉及 到容器的数据管理 (1)Data Vo ...