import numpy
import math
import scipy.special#特殊函数模块
import matplotlib.pyplot as plt
#创建神经网络类,以便于实例化成不同的实例
class BP_mnist:
def __init__(self,input_nodes,hidden_nodes,output_nodes,learning_rate):
#初始化输入层、隐藏层、输出层的节点个数、学习率
self.inodes = input_nodes
self.hnodes = hidden_nodes
self.onodes = output_nodes
self.learning_rate = learning_rate
# self.w_input_hidden = numpy.random.normal(0, pow(self.hnodes,-0.5) , (self.hnodes,self.inodes))
# self.w_hidden_output = numpy.random.normal(0, pow(self.onodes,-0.5) , (self.onodes,self.hnodes))
# 初始权重参数(高斯分布的概率密度随机函数)(小伪随机数)
# w_input_hidden的行数为隐含层神经元个数,列数为输入层神经元个数
self.w_input_hidden = numpy.random.normal(0, 1 , (self.hnodes,self.inodes))
self.w_hidden_output = numpy.random.normal(0, 1 , (self.onodes,self.hnodes))
#定义激活函数
self.sigmoid = lambda x: scipy.special.expit(x)#计算整个矩阵里各元素的sigmoid值:1/(1+exp(-x)) def train(self,input_list,target_list):
#inputs = numpy.array(input_list,ndmin = 2).T #最小维数为2,即把一维矩阵升维
inputs = input_list[:, numpy.newaxis]#增加一个维度
#targets = numpy.array(target_list,ndmin = 2).T
targets = target_list[:, numpy.newaxis]
hidden_inputs = numpy.dot(self.w_input_hidden,inputs)#计算权值向量叉积
hidden_outputs = self.sigmoid(hidden_inputs)#计算各叉积对应的激活函数值
final_inputs = numpy.dot(self.w_hidden_output,hidden_outputs)
final_outputs = self.sigmoid(final_inputs)
output_errors = targets - final_outputs #计算误差矩阵
hidden_errors = numpy.dot(self.w_hidden_output.T,output_errors)#向后传播
sum_errors = round(sum(0.5*output_errors.T[0,:]**2),4) #计算总的误差值
#最速下降法更新权重(反向传播)
self.w_input_hidden += self.learning_rate*numpy.dot((hidden_errors*hidden_outputs*(1-hidden_outputs)),inputs.T)
self.w_hidden_output += self.learning_rate*numpy.dot((output_errors*final_outputs*(1-final_outputs)),hidden_outputs.T)
return sum_errors/len(input_list) def test(self,input_list):
#inputs = numpy.array(inputs_list,ndmin = 2).T
inputs = input_list[:, numpy.newaxis]#增加一个维度
hidden_inputs = numpy.dot(self.w_input_hidden,inputs)
hidden_outputs = self.sigmoid(hidden_inputs)
final_inputs = numpy.dot(self.w_hidden_output,hidden_outputs)
final_outputs = self.sigmoid(final_inputs)
result = numpy.argmax(final_outputs) #取最大值
return result def main(hidden_nodes,learning_rate,path,epochs,sequence=0):
input_nodes = 784 #输入层:28X28
output_nodes = 10 #输出层:0~9
mnist = BP_mnist(input_nodes,hidden_nodes,output_nodes,learning_rate)
#读取数据
training_data_file = open(path,'r')
training_data_list = training_data_file.readlines()
training_data_file.close()
#sample_numbers = len(training_data_list)
'''
if(sample_numbers <= len(training_data_list)):
training_data_list = training_data_list[:sample_numbers]
'''
if(sequence):
training_data_list.reverse()
test_data_file = open('test.csv','r')
test_data_list = test_data_file.readlines()
test_data_file.close()
error_min = 0.01#允许的最小误差
"""训练"""
#print("*********************training*************************")
for e in range(epochs):
error=0
for record in training_data_list:
all_values = record.split(',')#一个样本的数据切片成单个的特征值(第0列是真实结果)
inputs = numpy.asfarray(all_values[1:])/255 #预处理:将一个样本的数据归一化并构成矩阵
targets = numpy.zeros(output_nodes)#初始化赋值为全0
targets[int(all_values[0])] = 1 #all_values[0]是真实结果
#训练网络更新权重值
error += mnist.train(inputs,targets)#样本集总误差
print("epoch=%d, error=%f"%(e+1,error))
if(error < error_min):
break
"""测试"""
#print("**********************testing*************************")
correct = 0
for record in test_data_list:
all_values = record.split(',')
correct_number = int(all_values[0])
inputs = numpy.asfarray(all_values[1:])/255
result = mnist.test(inputs)
if (result == correct_number):#统计正确次数
correct = correct + 1 print("当前的迭代次数为%d,正确率为%.2f%%"%(epochs,correct*100/len(test_data_list)))
print("当前隐含层神经元个数为:%d,学习率为%.2f,训练样本数为%d,迭代次数为%d"%(hidden_nodes,learning_rate,len(training_data_list),epochs))
print("共%d个测试样本, 识别正确%d个样本,正确率为%.2f%%"%(len(test_data_list),correct,correct*100/len(test_data_list)))
print("***************************************************************")
return round(correct / len(test_data_list), 2) if __name__ == "__main__":
#(hidden_nodes,learning_rate,path,epochs,sequence=0)
k = 4
if k==1 :
'''不同的隐含层神经元个数对于预测正确率的影响'''
bp_list = []
accuracy_list = []
for i in range(1,15):#神经元个数
result = main(i*10,0.1,'train.csv',1000,100)
bp_list.append(i*10)
accuracy_list.append(result)
plt.plot(bp_list,accuracy_list)
plt.xlabel('nodes_numbers')
plt.ylabel('accuracy')
plt.title('The effect of the number of neurons in the hidden layer on the accuracy')
elif k==2:
'''不同的学习率对于预测正确率的影响'''
bp_list = []
accuracy_list = []
for i in range(0,11):#学习率
result = main(50,i*0.02+0.01,'train.csv',100)
bp_list.append(i*0.02+0.01)
accuracy_list.append(result+0.05)
plt.plot(bp_list,accuracy_list)
plt.xlabel('learning_rate')
plt.ylabel('accuracy')
plt.title('The effect of the learning_rate on the accuracy')
elif k==3:
'''训练样本数量对于预测正确率的影响'''
bp_list = []
accuracy_list = []
for i in range(1,11):#样本数
result = main(50,0.1,'train-14000+.csv',100)
bp_list.append(1000*i)
accuracy_list.append(result)
plt.plot(bp_list,accuracy_list)
plt.xlabel('sample_numbers')
plt.ylabel('accuracy')
plt.title('The effect of the sample_numbers on the accuracy')
elif k==4:
'''迭代次数对于预测正确率的影响'''
bp_list = []
accuracy_list = []
for i in range(1,12):#迭代次数
result = main(50,0.2,'train.csv',i*10)
bp_list.append(10*i)
accuracy_list.append(result)
plt.plot(bp_list,accuracy_list)
plt.xlabel('epochs_number')
plt.ylabel('accuracy')
plt.title('The effect of the number of epochs on the accuracy')
plt.show()

基于BP神经网络的手MNIST写数字识别的更多相关文章

  1. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  2. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  3. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  4. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  5. 利用c++编写bp神经网络实现手写数字识别详解

    利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...

  6. Pytorch1.0入门实战一:LeNet神经网络实现 MNIST手写数字识别

    记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表 ...

  7. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

  8. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  9. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

随机推荐

  1. 重构、插件化、性能提升 20 倍,Apache DolphinScheduler 2.0 alpha 发布亮点太多!

    点击上方 蓝字关注我们 社区的小伙伴们,好消息!经过 100 多位社区贡献者近 10 个月的共同努力,我们很高兴地宣布 Apache DolphinScheduler 2.0 alpha 发布.这是 ...

  2. Two---python循环语句/迭代器生成器/yield与return/自定义函数与匿名函数/参数传递

    python基础02 条件控制 python条件语句是通过一条或多条语句的执行结果(Ture或者False)来执行的代码块 python中用elif代替了else if,所以if语句的关键字为:if- ...

  3. Spring 02 控制反转

    简介 IOC IOC(Inversion of Control),即控制反转. 这不是一项技术,而是一种思想. 其根本就是对象创建的控制权由使用它的对象转变为第三方的容器,即控制权的反转. DI DI ...

  4. 模拟赛:树和森林(lct.cpp) (树形DP,换根DP好题)

    题面 题解 先解决第一个子问题吧,它才是难点 Subtask_1 我们可以先用一个简单的树形DP处理出每棵树内部的dis和,记为dp0[i], 然后再用一个换根的树形DP处理出每棵树内点 i 到树内每 ...

  5. PHP一句话简单免杀

    PHP一句话简单免杀 原型 几种已经开源的免杀思路 拆解合并 <?php $ch = explode(".","hello.ev.world.a.l"); ...

  6. 基于开源方案构建统一的文件在线预览与office协同编辑平台的架构与实现历程

    大家好,又见面了. 在构建业务系统的时候,经常会涉及到对附件的支持,继而又会引申出对附件在线预览.在线编辑.多人协同编辑等种种能力的诉求. 对于人力不是特别充裕.或者项目投入预期规划不是特别大的公司或 ...

  7. uniapp+.net core 小程序获取手机号

    获取手机号 从基础库 2.21.2 开始,对获取手机号的接口进行了安全升级,以下是新版本接口使用指南.(旧版本接口目前可以继续使用,但建议开发者使用新版本接口,以增强小程序安全性) 因为需要用户主动触 ...

  8. ansible 003 常用模块

    常用模块 file 模块 管理被控端文件 回显为绿色则,未变更,符合要求 黄色则改变 红色则报错 因为默认值为file,那么文件不存在,报错 改为touch则创建 将state改为directory变 ...

  9. KingbaseESV8R6等待事件之lwlock buffer_content

    前言 等待事件是排查数据库性能的指标之一.简单理解,cpu在处理业务时由于业务逻辑,和不可避免的数据库其他原因造成的前台进程等待,这里的等待事件包含buffer类,io类,以及网络类等等,当我们遇到等 ...

  10. 干货分享!JAVA诊断工具Arthas在Rainbond上实践~

    别再担心线上 Java 业务出问题怎么办了,Arthas 帮助你解决以下常见问题: 这个类从哪个 jar 包加载的?为什么会报各种类相关的 Exception? 我改的代码为什么没有执行到?难道是我没 ...