一、TensorFlow简介

TensorFlow是由谷歌开发的一套机器学习的工具,使用方法很简单,只需要输入训练数据位置,设定参数和优化方法等,TensorFlow就可以将优化结果显示出来,节省了很大量的编程时间,TensorFlow的功能很多很强大,这边挑选了一个比较简单实现的方法,就是利用TensorFlow的逻辑回归算法对数据库中的手写数字做识别,让机器找出规律,然后再导入新的数字让机器识别。

二、流程介绍

上图是TensorFlow的流程,可以看到一开始要先将参数初始化,然后导入训练数据,计算偏差,然后修正参数,再导入新的训练数据,不断重复,当数据量越大,理论上参数就会越准确,不过也要注意不可训练过度。

三、导入数据

数据可进入MNIST数据库 (Mixed National Institute of Standards and Technology database),这是一个开放的数据库,里面有许多免费的训练数据可以提供下载,这次我们要下载的是手写的阿拉伯数字,为什么要阿拉伯数字呢?1、因为结果少,只有十个,比较好训练 2、图片的容量小,不占空间,下面是部分的训练数据案例

TensorFlow可以直接下载MNIST上的训练数据,并将它导入使用,下面为导入数据的代码

from tensorflow.examples.tutorials.mnist import input_data
MNIST = input_data.read_data_sets("/data/mnist", one_hot=True)

四、设定参数

接下来就是在TensorFlow里设定逻辑回归的参数,我们知道回归的公式为Y=w*X+b,X为输入,Y为计算结果,w为权重参数,b为修正参数,其中w和b就是我们要训练修正的参数,但训练里要怎么判断计算结果好坏呢?就是要判断计算出来的Y和实际的Y损失值(loss)是多少,并尽量减少loss,这边我们使用softmax函数来计算,softmax函数在计算多类别分类上的表现比较好,有兴趣可以百度一下,这边就不展开说明了,下面为参数设定

X = tf.placeholder(tf.float32, [batch_size, 784], name="image")
Y = tf.placeholder(tf.float32, [batch_size, 10], name="label")

X为输入的图片,图片大小为784K,Y为实际结果,总共有十个结果(数字0-9)

w = tf.Variable(tf.random_normal(shape=[784, 10], stddev=0.01), name="weights")
b = tf.Variable(tf.zeros([1, 10]), name="bias")

w初始值为一个随机的变数,标准差为0.01,b初始值为0。

logits = tf.matmul(X, w) + b
entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y)
loss = tf.reduce_sum(entropy)

TensorFlow里面已经有softmax的函数,只要把他叫出来就可以使用。

optimizer =
tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
n_batches = int(MNIST.train.num_examples/batch_size)
for i in range(n_epochs): # train the model n_epochs times
for _ in range(n_batches):
X_batch, Y_batch = MNIST.train.next_batch(batch_size)
sess.run([optimizer, loss], feed_dict={X: X_batch, Y:Y_batch})

接着就是设定优化方式,这边是使用梯度降下发,然后将参数初始化,接着就运行了,这边要提一下,我们的训练方式是每次从训练数据里面抓取一个batch的数据,然后进行计算,这样可以预防过度训练,也比较可以进行事后的验证,运行完后再用下面的代码进行验证

    n_batches = int(MNIST.test.num_examples/batch_size)
total_correct_preds = 0
for i in range(n_batches):
X_batch, Y_batch = MNIST.test.next_batch(batch_size)
_, loss_batch, logits_batch = sess.run([optimizer, loss, logits],
feed_dict={X: X_batch, Y:Y_batch})
preds = tf.nn.softmax(logits_batch)
correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y_batch, 1))
accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))
total_correct_preds += sess.run(accuracy)
print ("Accuracy {0}".format(total_correct_preds/MNIST.test.num_examples))

最后shell跑出来的结果是0.916,虽然看上去还算是不错的结果,但其实准确率是很低的,因为他验证的方式是判断一个图片是否为某个数字(单输出),所以假如机器随便猜也会有0.82左右的命中几率(0.9*0.9+0.1*0.1),想要更准确的话目前想到有两个方向,一个是提高训练量和增加神经网络的层数。

用TensorFlow做图像识别(python)的更多相关文章

  1. 03 使用Tensorflow做计算题

    我们使用Tensorflow,计算((a+b)*c)^2/a,然后求平方根.看代码: import tensorflow as tf # 输入储存容器 a = tf.placeholder(tf.fl ...

  2. 【转载】 深度学习总结:用pytorch做dropout和Batch Normalization时需要注意的地方,用tensorflow做dropout和BN时需要注意的地方,

    原文地址: https://blog.csdn.net/weixin_40759186/article/details/87547795 ------------------------------- ...

  3. 【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862365.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  4. Tensorflow学习(练习)—使用inception做图像识别

    import osimport tensorflow as tfimport numpy as npimport re from PIL import Imageimport matplotlib.p ...

  5. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  6. [树莓派(raspberry pi)] 02、PI3安装openCV开发环境做图像识别(详细版)

    前言 上一篇我们讲了在linux环境下给树莓派安装系统及入门各种资料 ,今天我们更进一步,尝试在PI3上安装openCV开发环境. 博主在做的过程中主要参考一个国外小哥的文章(见最后链接1),不过其教 ...

  7. Tensorflow做阅读理解与完形填空

    catalogue . 前言 . 使用的数据集 . 数据预处理 . 训练 . 测试模型运行结果: 进行实际完形填空 0. 前言 开始写这篇文章的时候是晚上12点,突然想到几点新的理解,赶紧记下来.我们 ...

  8. tensorflow 学习1——tensorflow 做线性回归

    . 首先 Numpy: Numpy是Python的科学计算库,提供矩阵运算. 想想list已经提供了矩阵的形式,为啥要用Numpy,因为numpy提供了更多的函数. 使用numpy,首先要导入nump ...

  9. TensorFlow 安装以及python虚拟环境

    python虚拟环境 由于TensorFlow只支持某些版本的python解释器,如Python3.6.如果其他版本用户要使用TensorFlow就必须安装受支持的python版本.为了方便在不同项目 ...

随机推荐

  1. java文件操作 之 创建文件夹路径和新文件

    一:问题 (1)java 的如果文件夹路径不存在,先创建: (2)如果文件名 的文件不存在,先创建再读写;存在的话直接追加写,关键字true表示追加 (3)File myPath = new File ...

  2. H3C UDP封装

  3. Jenkins安装总结

    Jenkins官方文档说的安装步骤,http://jenkins.io/zh/doc/pipeline/tour/getting-started/ 相关安装资源可在官方文档下载 安装Jenkins之前 ...

  4. H3C MP-Group方式配置PPP MP

  5. PHP IF判断 简写

    第一种:IF 条件语句 第二种:三元运算 第三种:&& .|| 组成的条件语句 第一种: IF 基础,相信绝大多数人都会: 第二种:  c=a>b ? true:false  / ...

  6. 【b804】双栈排序

    Time Limit: 1 second Memory Limit: 50 MB [问题描述] Tom最近在研究一个有趣的排序问题.如图所示,通过2个栈S1和S2,Tom希望借助以下4种操作实现将输入 ...

  7. ThinkPHP 模版中的内置标签

    内置标签就是模版引擎提供的一组可以完成控制.循环和判断功能的类似HTML语法的标签.   一.判断比较:   1.if标签进行条件判断 //if语句的完整格式 <if condition=&qu ...

  8. The bind() Method

    The bind() method was added in ESMAScript 5, but it is easy to simulate in ESMAScrpt 3. As its name ...

  9. H3C 802.1X典型配置举例

  10. H3C通过桥ID决定端口角色