写在开头:这个实验和matlab手写神经网络实现识别手写数字一样。


实验说明

一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架。恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手写数字图片,于是我就尝试用matlab写一个网络。

  • 实验数据:5000张手写数字图片(.jpg),图片命名为1.jpg,2.jpg…5000.jpg。还有一个放着标签的excel文件。

  • 数据处理:前4000张作为训练样本,后1000张作为测试样本。

  • 图片处理:用matlab的imread()函数读取图片的灰度值矩阵(28,28),然后把每张图片的灰度值矩阵reshape为(28*28,1),然后把前4000张图片的灰度值矩阵合并为x_train,把后1000张图片的灰度值矩阵合并为x_test。



神经网络设计

  • 网络层设计:一层隐藏层,一层输出层

  • 输入层:一张图片的灰度值矩阵reshape后的784个数,也就是x_train中的某一列

  • 输出层:(10,1)的列向量,其中列向量中最大的数所在的索引+1就是预测的数字

  • 激励函数:sigmoid函数(公式)

  • 更新法则:后向传播算法(参考

  • 测试:统计预测正确的个数

网络实现

  • 函数说明:读图片的函数(read_photo() )、读excel的函数(read_excel(path) )、修正函数(layerout(w,b,x) )、训练函数(mytrain(x_train,y_train) )、测试函数(mytest(x_test,y_test,w,b,w_h,b_h) )、主函数(main() )

具体代码如下:

# -*- coding: utf-8 -*-

from PIL import Image

from pylab import *

import numpy as np

import xlrd

#读取图片的灰度值矩阵
def read_photo():
for i in range(5000):
j = i+1
j = str(j)
st = '.jpg'
j = j+st
im1 = array(Image.open(j))
#(28,28)-->(28*28,1)
im1 = im1.reshape((784,1))
#把所有的图片灰度值放到一个矩阵中
#一列代表一张图片的信息
if i == 0:
im = im1
else:
im = np.hstack((im,im1))
return im #读取excel文件内容(path为文件路径)
def read_excel(path):
# 获取所有sheet
workbook = xlrd.open_workbook(path)
sheet_names = workbook.sheet_names() # 根据sheet索引或者名称获取sheet内容
for sheet_name in sheet_names:
isheet = workbook.sheet_by_name(sheet_name) #获取该sheet的列数
ncols = isheet.ncols
#获取每一列的内容
for i in range(ncols):
if i == 0:
xl1 = isheet.col_values(i)
xl1 = np.array(xl1)
xl1 = xl1.reshape((10,1))
xl = xl1
else:
xl1 = isheet.col_values(i)
xl1 = np.array(xl1)
xl1 = xl1.reshape((10,1))
xl = np.hstack((xl,xl1))
return xl #layerout函数
def layerout(w,b,x):
y = np.dot(w,x) + b
t = -1.0*y
# n = len(y)
# for i in range(n):
# y[i]=1.0/(1+exp(-y[i]))
y = 1.0/(1+exp(t))
return y #训练函数
def mytrain(x_train,y_train):
'''
设置一个隐藏层,784-->隐藏层神经元个数-->10
''' step=int(input('mytrain迭代步数:'))
a=double(input('学习因子:'))
inn = 784 #输入神经元个数
hid = int(input('隐藏层神经元个数:'))#隐藏层神经元个数
out = 10 #输出层神经元个数 w = np.random.randn(out,hid)
w = np.mat(w)
b = np.mat(np.random.randn(out,1))
w_h = np.random.randn(hid,inn)
w_h = np.mat(w_h)
b_h = np.mat(np.random.randn(hid,1)) for i in range(step):
#打乱训练样本
r=np.random.permutation(4000)
x_train = x_train[:,r]
y_train = y_train[:,r]
#mini_batch
for j in range(400):
#取batch为10 更新取10次的平均值
x = np.mat(x_train[:,j])
x = x.reshape((784,1))
y = np.mat(y_train[:,j])
y = y.reshape((10,1))
hid_put = layerout(w_h,b_h,x)
out_put = layerout(w,b,hid_put) #更新公式的实现
o_update = np.multiply(np.multiply((y-out_put),out_put),(1-out_put))
h_update = np.multiply(np.multiply(np.dot((w.T),np.mat(o_update)),hid_put),(1-hid_put)) outw_update = a*np.dot(o_update,(hid_put.T))
outb_update = a*o_update
hidw_update = a*np.dot(h_update,(x.T))
hidb_update = a*h_update w = w + outw_update
b = b+ outb_update
w_h = w_h +hidw_update
b_h =b_h +hidb_update return w,b,w_h,b_h #test函数
def mytest(x_test,y_test,w,b,w_h,b_h):
'''
统计1000个测试样本中有多少个预测正确了
预测结果表示:10*1的列向量中最大的那个数的索引+1就是预测结果了
'''
sum = 0
for k in range(1000):
x = np.mat(x_test[:,k])
x = x.reshape((784,1)) y = np.mat(y_test[:,k])
y = y.reshape((10,1)) yn = np.where(y ==(np.max(y)))
# print(yn)
# print(y)
hid = layerout(w_h,b_h,x);
pre = layerout(w,b,hid);
#print(pre)
pre = np.mat(pre)
pre = pre.reshape((10,1))
pren = np.where(pre ==(np.max(pre)))
# print(pren)
# print(pre)
if yn == pren:
sum += 1 print('1000个样本,正确的有:',sum) def main():
#获取图片信息
im = read_photo()
immin = im.min()
immax = im.max() im = (im-immin)/(immax-immin) #前4000张图片作为训练样本
x_train = im[:,0:4000]
#后1000张图片作为测试样本
x_test = im[:,4000:5000] #获取label信息
xl = read_excel('./label.xlsx') y_train = xl[:,0:4000]
y_test = xl[:,4000:5000] print("---------------------------------------------------------------")
w,b,w_h,b_h = mytrain(x_train,y_train)
mytest(x_test,y_test,w,b,w_h,b_h)
print("---------------------------------------------------------------") if __name__ == '__main__':
main()

实验结果

---------------------------------------------------------------
mytrain迭代步数:300
学习因子:0.3
隐藏层神经元个数:28
1000个样本,正确的有: 933
---------------------------------------------------------------

迭代300步,正确率就有93.3%啦,还不错的正确率~

python手写神经网络实现识别手写数字的更多相关文章

  1. matlab手写神经网络实现识别手写数字

    实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手写数字图片,于是我就尝试用matlab写一个网络. 实验数据:500 ...

  2. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  3. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

  4. python机器学习使用PCA降维识别手写数字

    PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...

  5. 解决python中import时无法识别自己写的包和模块的方法

    我们用pycharm打开自己写的代码,当多个文件之间有相互依赖的关系的时候,import无法识别自己写的文件,但是我们写的文件又确实在同一个文件夹中, 这种问题可以用下面的方法解决: 1)打开File ...

  6. Python实现神经网络算法识别手写数字集

    最近忙里偷闲学习了一点机器学习的知识,看到神经网络算法时我和阿Kun便想到要将它用Python代码实现.我们用了两种不同的方法来编写它.这里只放出我的代码. MNIST数据集基于美国国家标准与技术研究 ...

  7. NN:神经网络实现识别手写的1~9的10个数字—Jason niu

    import numpy as np from sklearn.datasets import load_digits from sklearn.metrics import confusion_ma ...

  8. Tensorflow搭建卷积神经网络识别手写英语字母

    更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...

  9. 用BP人工神经网络识别手写数字

    http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...

随机推荐

  1. EasyUI Pagination 分页分页布局定义 显示按钮布局

    //分页布局定义.该属性自版本 1.3.5 起可用.//布局项目包括一个或多个下列值://1.list:页面尺寸列表.//2.sep:页面按钮分割.//3.first:第一个按钮.//4.prev:前 ...

  2. C语言错误: CRT detected that the application wrote to memory after end of heap buffer

    CRT detected that the application wrote to memory after end of heap buffer 多是中间对其进行了一些操作,在程序结束处,释放内存 ...

  3. Lifecycle for overriding binding, validation, etc,易于同其它View框架(Tiles等)无缝集成,采用IOC便于测试。

    Lifecycle for overriding binding, validation, etc,易于同其它View框架(Tiles等)无缝集成,采用IOC便于测试. 它是一个典型的教科书式的mvc ...

  4. 【BZOJ】1682: [Usaco2005 Mar]Out of Hay 干草危机(kruskal)

    http://www.lydsy.com/JudgeOnline/problem.php?id=1682 最小生成树裸题.. #include <cstdio> #include < ...

  5. 使用spring + ActiveMQ 总结

    使用spring + ActiveMQ 总结   摘要 Spring 整合JMS 基于ActiveMQ 实现消息的发送接收 目录[-] Spring 整合JMS 基于ActiveMQ 实现消息的发送接 ...

  6. Boost库初见

    Boost库是一个功能强大.构造精巧.跨平台.开源并且完全免费的C++库,有C++"准"标准库的美称! Boost有着与其它程序库(如MFC等)无法比拟的优点. Boost库采用了 ...

  7. CSS 属性的默认值

    最近在看到一篇关于如何实现水平垂直居中,发现有许多属性值,自己并不了解,特此Google一番,查到,摘抄过来,方便以后查阅,下面是如何实现水平垂直居中的博文. 解读CSS布局之-水平垂直居中 html ...

  8. XDocument简单入门

    [+]   1.什么是XML? 2.XDocument和XmlDocument的区别? 3.XDocument 4.XmlDocument 5.LINQ to XML 6.XML序列化与反序列化 因为 ...

  9. ListView中的方法

    getCount(); getItem(); getItemId(); getView(); getViewCountType();

  10. poj 1386

    Play on Words Time Limit: 1000MS   Memory Limit: 10000K Total Submissions: 11312   Accepted: 3862 De ...