用python实现数字图片识别神经网络--启动网络的自我训练流程,展示网络数字图片识别效果
上一节,我们完成了网络训练代码的实现,还有一些问题需要做进一步的确认。网络的最终目标是,输入一张手写数字图片后,网络输出该图片对应的数字。由于网络需要从0到9一共十个数字中挑选出一个,于是我们的网络最终输出层应该有十个节点,每个节点对应一个数字。假设图片对应的是数字0,那么输出层网络中,第一个节点应该输出一个高百分比,其他节点输出低百分比,如果图片对应的数字是9,那么输出层最后一个节点应该输出高百分比,其他节点输出低百分比,例如下图:
屏幕快照 2018-05-07 下午5.10.59.png
高比率,如果网络认为图片对应的数字是0,那么编号为0的节点输出0.95的高比率。最后一个例子是很有意思,编号为4和9的神经元都输出一个不低的比率,这表明图片对应的数字很像4和9,但神经网络认为是9的概率是4的概率的两倍以上。
还记得上一节我们准备好要输入网络的数据吗:
这里写图片描述
数据的第一个值代表图片对应的数字,我们需要把这种对应信息通过代码表现出来:
#最外层有10个输出节点onodes = 10
targets = numpy.zeros(onodes) + 0.01
targets[int(all_values[0])] = 0.99print(targets)
上面代码的输出结果为:
image.png
targets第8个元素的值是0.99,这表示图片对应的数字是7,记住数组是从编号0开始的。根据这种做法,我们就能把输入图片给对应的正确数字建立联系,这种联系就可以用于输入到网络中,进行训练。由于一张图片总共有28*28 = 764个数值,因此我们需要让网络的输入层具备764个输入节点,于是网络的初始化以及将数据输入网络进行训练的实现代码为:
#初始化网络input_nodes = 784hidden_nodes = 100output_nodes = 10learning_rate = 0.3n = NeuralNetWork(input_nodes, hidden_nodes, output_nodes, learning_rate)#读入训练数据#open函数里的路径根据数据存储的路径来设定training_data_file = open("/Users/chenyi/Documents/人工智能/mnist_train_100.csv")
trainning_data_list = training_data_file.readlines()
training_data_file.close()#把数据依靠','区分,并分别读入for record in trainning_data_list:
all_values = record.split(',')
inputs = (numpy.asfarray(all_values[1:]))/255.0 * 0.99 + 0.01
#设置图片与数值的对应关系
targets = numpy.zeros(output_nodes) + 0.01
targets[int(all_values[0])] = 0.99
n.train(inputs, targets)
这里需要注意的是,中间层的节点我们选择了100个神经元,这个选择其实是经验值,也就是中间层的节点数其实没有专门的办法去规定,其数量会根据不同的问题而变化,确定中间层神经元节点数最好的办法是实验,你不停的选取各种数量,看看那种数量能使得网络的表现最好就行。
上面代码把一百条数据输入网络进行训练,现在我们看看训练后的网络它的表现怎样。我们先从加载另一组数据,取出其中一张手写数字图片,将其输入到网络中,看看网络的判断结果如何:
test_data_file = open("/Users/chenyi/Documents/人工智能/mnist_test_10.csv")
test_data_list = test_data_file.readlines()
test_data_file.close()import numpyimport matplotlib.pyplot
%matplotlib inline#把数据依靠','区分,并分别读入all_values = data_list[0].split(',')#第一个值对应的是图片的表示的数字,所以我们读取图片数据时要去掉第一个数值image_array = numpy.asfarray(all_values[1:]).reshape((28, 28))
matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')
这段代码我们在上一节讲解过,我们把测试数据里面的第一张手写数字图片先绘制出来,代码运行结果如下:
这里写图片描述
通过人眼观察,我们基本确定这种图片对应的是数字7,那么网络识别它的结果如何呢,我们将这张图片的数字输入到网络看看其识别结果:
n.query(numpy.asfarray(all_values[1:]) / 255.0 * 0.99 + 0.01)
上面这行代码运行后结果如下:
这里写图片描述
前面我们讨论过最外层节点输出的意义,最外层节点有十个,分别对应0到9十个数字,哪个节点输出的数值高,那意味着网络认为图片对应哪个数字,我们看到网络输出中,对应编号为7的节点输出值最大,为0.68,也就是说网络把图片识别为数字7,这与我们的观察是一致的,这么说我们辛辛苦苦打造的网络是有效的,前面那么多的铺垫到现在终于有了收获。
我们原来给网络输入的训练数据来自trainning_set,而现在给网络判断的图片来自testing_set,因此网络从未见过这张图片,它能识别这张图片是数字7,这种能力是通过分析训练图片,不断改进链路权重值的结果。实现网络的Python代码不过百来行,他居然就能实现了我们所认为的人工智能,如此看来人工智能似乎并非那么神秘。
接着我们把所有测试图片都输入网络,看看它检测的效果如何,代码如下:
scores = []for record in test_data_list:
all_values = record.split(',')
correct_number = int(all_values[0]) print("该图片对应的数字为:",correct_number) #预处理数字图片
inputs = (numpy.asfarray(all_values[1:])) / 255.0 * 0.99 + 0.01 #让网络判断图片对应的数字
outputs = n.query(inputs) #找到数值最大的神经元对应的编号
label = numpy.argmax(outputs) print("out put reslut is : ", label) #print("网络认为图片的数字是:", label)
if label == correct_number:
scores.append(1) else:
scores.append(0)print(scores)
上面代码把测试数据集里的10张图片全部加载,然后输入到网络中,看看网络对每张数字图片的识别效果如何,上面代码运行后结果如下:
这里写图片描述
从输出结果看,有些图片网络还是识别错了,最后代码打印出一个数组,里面的1表示识别正确,0表示识别错误,从数组内容看,有4张图片网络给出了错误答案。这次的结果多少令人有些沮丧,我们计算一下图片判断的成功率:
scores_array = numpy.asarray(scores)print("perfermance = ", scores_array.sum() / scores_array.size)
代码运行后结果如下:
这里写图片描述
由此看来,网络识别的成功率只有六成。为了提升成功率,我们必须加大网络的训练力度,原来我们训练网络时只使用了100条数据,现在我们使用60000条数据,然后用10000条数据作为测试集,我们从以下两个链接获取相应的数据集:
http://www.pjreddie.com/media/files/mnist_train.csvhttp://www.pjreddie.com/media/files/mnist_test.csv
然后我们把原来代码做一点小修改,加载上面的数据来对网络进行训练和测试:
#初始化网络input_nodes = 784
hidden_nodes = 100
output_nodes = 10
learning_rate = 0.3
n = NeuralNetWork(input_nodes, hidden_nodes, output_nodes, learning_rate)#读入训练数据#open函数里的路径根据数据存储的路径来设定training_data_file = open("/Users/chenyi/Documents/人工智能/mnist_train.csv")
trainning_data_list = training_data_file.readlines()print(len(trainning_data_list))
training_data_file.close()#把数据依靠','区分,并分别读入for record in trainning_data_list:
all_values = record.split(',')
inputs = (numpy.asfarray(all_values[1:]))/255.0 * 0.99 + 0.01 #设置图片与数值的对应关系
targets = numpy.zeros(output_nodes) + 0.01
targets[int(all_values[0])] = 0.99
n.train(inputs, targets)
test_data_file = open("/Users/chenyi/Documents/人工智能/mnist_test.csv")
test_data_list = test_data_file.readlines()
test_data_file.close()
scores = []for record in test_data_list:
all_values = record.split(',')
correct_number = int(all_values[0]) #预处理数字图片
inputs = (numpy.asfarray(all_values[1:])) / 255.0 * 0.99 + 0.01 #让网络判断图片对应的数字
outputs = n.query(inputs) #找到数值最大的神经元对应的编号
label = numpy.argmax(outputs) if label == correct_number:
scores.append(1) else:
scores.append(0)
scores_array = numpy.asarray(scores)print("perfermance = ", scores_array.sum() / scores_array.size)
上面代码跟以前是一样的,只不过加载的数据文件不同而已,这次我们用60000条数据来训练网络,然后用10000条数据来检测网络的准确性,上面代码执行后结果如下:
这里写图片描述
从结果上看,当训练网络的数据流增大后,网络识别的正确性由原来的0.6提升到0.9,我们再次用新训练后的网络识别原来那十张数字图片,得到结果如下:
这里写图片描述
经过大数据训练后的网络,对图片的识别率达到了百分之百,这意味着当用于训练网络的数据越多,网络识别的效果就越好,这就是为何在某种程度上说,人工智能也是大公司的大杀器,因为只有大公司才能拥有足量的数据。
在整个过程,我们一直保持着学习率不变,实际上学习率的大小对网络的训练效果有很大影响,大家可以把该参数改成0.6,0.1等不同的值去看看结果,另外也可以修改中间层的节点数看看有什么效果。二手叉车哪家好
这里我们引入在第一节时提到的一个概念叫epocs,它表示网络进行几次训练循环,对其使用的代码如下:
#加入epocs,设定网络的训练循环次数epochs = 10for e in range(epochs): #把数据依靠','区分,并分别读入
for record in trainning_data_list:
all_values = record.split(',')
inputs = (numpy.asfarray(all_values[1:]))/255.0 * 0.99 + 0.01
#设置图片与数值的对应关系
targets = numpy.zeros(output_nodes) + 0.01
targets[int(all_values[0])] = 0.99
n.train(inputs, targets)
也就是在原来网络训练的基础上再加上一层外循环,上面代码运行后执行的对于普通电脑而言执行的时间会很长。一般来说,epochs 的数值越大,网络被训练的就越精准,但如果超过一个阈值,网络就会引发一个过渡拟合的问题,也就是网络会对老数据识别的很精准,但对新数据识别的效率反而变得越来越低,大家可以自行尝试一下不同的学习率和epochs组合,看看网络的识别精度是否有所提高,另外大家也可以修改中间层的节点数看看其对网络的识别精度是否有显著影响,在我电脑上把epochs设置成7时,成功率能提升到95%。
用python实现数字图片识别神经网络--启动网络的自我训练流程,展示网络数字图片识别效果的更多相关文章
- Python机器学习笔记:卷积神经网络最终笔记
这已经是我的第四篇博客学习卷积神经网络了.之前的文章分别是: 1,Keras深度学习之卷积神经网络(CNN),这是开始学习Keras,了解到CNN,其实不懂的还是有点多,当然第一次笔记主要是给自己心中 ...
- Python之TensorFlow的卷积神经网络-5
一.卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深度 ...
- 将Python项目打包成EXE可执行文件(单文件,多文件,包含图片)
解决 将Python项目打包成EXE可执行文件(单文件,多文件,包含图片) 1.当我们写了一个Python的项目时,特别是一个GUI项目,我们特备希望它能成为一个在Windows系统可执行的EXE文件 ...
- Python实现在给定整数序列中找到和为100的所有数字组合
摘要: 使用Python在给定整数序列中找到和为100的所有数字组合.可以学习贪婪算法及递归技巧. 难度: 初级 问题 给定一个整数序列,要求将这些整数的和尽可能拼成 100. 比如 [17, 1 ...
- 移动端禁止图片长按和vivo手机点击img标签放大图片,禁止长按识别二维码或保存图片【转载】
移动端禁止图片长按和vivo手机点击img标签放大图片,禁止长按识别二维码或保存图片 img{ pointer-events: none; } 源文地址:https://www.cnblogs.com ...
- Python爬虫实例(一)爬取百度贴吧帖子中的图片
程序功能说明:爬取百度贴吧帖子中的图片,用户输入贴吧名称和要爬取的起始和终止页数即可进行爬取. 思路分析: 一.指定贴吧url的获取 例如我们进入秦时明月吧,提取并分析其有效url如下 http:// ...
- iOS 实现启动屏动画(Swift实现,包含图片适配)
代码地址如下:http://www.demodashi.com/demo/12090.html 准备工作 首先我们需要确定作为宣传的图片的宽高比,这个一般是与 UI 确定的.一般启动屏展示会有上下两部 ...
- Python 数据处理之对 list 数据进行数据重排(为连续的数字序号)
Python 数据处理之对 list 数据进行数据重排(为连续的数字序号) # user ID 序号重新排,即,原来是 1,3,4,6 ,排为 1,2,3,4 # item ID 序号重新排,too ...
- Python使用numpy实现BP神经网络
Python使用numpy实现BP神经网络 本文完全利用numpy实现一个简单的BP神经网络,由于是做regression而不是classification,因此在这里输出层选取的激励函数就是f(x) ...
随机推荐
- 适合自己的adblock过滤列表
轻微完美主义,极简主义 已屏蔽广告: 1.CSDN的广告 2.百度侧栏热点搜索 3. 知乎广告 4.stackoverflow的推送广告 5.LeetCode的推送的是否见过这个题 bbs.csdn. ...
- c++getline()、get()等
1.cin 接受一个字符串,遇“空格”.“TAB”.“回车”都结束 2.cin.get() cin.get(字符变量名)可以用来接收字符 只能接收一个字符 cin.get(字符数组名,接收字符数目)用 ...
- HDU1003 最大子段和 线性dp
题目链接: http://acm.hdu.edu.cn/showproblem.php?pid=1003 Max Sum Time Limit: 2000/1000 MS (Java/Others) ...
- Python学习 :常用模块(二)
常用模块(二) 四.os模块 os模块是与操作系统交互的一个接口,用于对操作系统进行调用 os.getcwd() # 提供当前工作目录 os.chdir() # 改变当前工作目录 os.curdir( ...
- Rails 自定义验证的错误信息
Active Record 验证辅助方法的默认错误消息都是英文,为了提高用户体验,有时候我们经常会被要求按特定的文本展示错误信息.此时有两种实现方式. 1. 直接在:message添加文案 class ...
- 6.Exceptions-异常(Dart中文文档)
异常是用于标识程序发生未知异常.如果异常没有被捕获,If the exception isn't caught, the isolate that raised the exception is su ...
- Kubernetes学习之路(二)之ETCD集群二进制部署
ETCD集群部署 所有持久化的状态信息以KV的形式存储在ETCD中.类似zookeeper,提供分布式协调服务.之所以说kubenetes各个组件是无状态的,就是因为其中把数据都存放在ETCD中.由于 ...
- [BZOJ2742][HEOI2012]Akai的数学作业[推导]
题意 给定各项系数,求一元 \(n\) 次方程的有理数解. \(n\leq 100\). 分析 设答案为 \(\frac{p}{q}\) ,那么多项式可以写成 \(a_0\frac{p}{q}+a_1 ...
- SSIS 数据流的执行树和数据管道
数据流组件的设计愿景是快速处理海量的数据,为了实现该目标,SSIS数据源引擎需要创建执行树和数据管道这两个数据结构,而用户为了快速处理数据流,必须知道各个转换组件的阻塞性,充分利用流式处理流程,利用更 ...
- 5 行 Python 代码调用电脑摄像头
前提: 确保 python 中安装了 opencv-python 模块.如果没有安装,可以参考:https://pypi.org/project/opencv-python/ 进行安装.话不多少,直接 ...