python手写神经网络实现识别手写数字
写在开头:这个实验和matlab手写神经网络实现识别手写数字一样。
实验说明
一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架。恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手写数字图片,于是我就尝试用matlab写一个网络。
实验数据:5000张手写数字图片(.jpg),图片命名为1.jpg,2.jpg…5000.jpg。还有一个放着标签的excel文件。
数据处理:前4000张作为训练样本,后1000张作为测试样本。
图片处理:用matlab的imread()函数读取图片的灰度值矩阵(28,28),然后把每张图片的灰度值矩阵reshape为(28*28,1),然后把前4000张图片的灰度值矩阵合并为x_train,把后1000张图片的灰度值矩阵合并为x_test。
 
神经网络设计
网络层设计:一层隐藏层,一层输出层
输入层:一张图片的灰度值矩阵reshape后的784个数,也就是x_train中的某一列
输出层:(10,1)的列向量,其中列向量中最大的数所在的索引+1就是预测的数字
激励函数:sigmoid函数(公式)
更新法则:后向传播算法(参考)
测试:统计预测正确的个数
网络实现
- 函数说明:读图片的函数(read_photo() )、读excel的函数(read_excel(path) )、修正函数(layerout(w,b,x) )、训练函数(mytrain(x_train,y_train) )、测试函数(mytest(x_test,y_test,w,b,w_h,b_h) )、主函数(main() )
 
具体代码如下:
# -*- coding: utf-8 -*-
from PIL import Image
from pylab import *
import numpy as np
import xlrd
#读取图片的灰度值矩阵
def read_photo():
    for i in range(5000):
        j = i+1
        j = str(j)
        st = '.jpg'
        j = j+st
        im1 = array(Image.open(j))
        #(28,28)-->(28*28,1)
        im1 = im1.reshape((784,1))
        #把所有的图片灰度值放到一个矩阵中
        #一列代表一张图片的信息
        if i == 0:
            im = im1
        else:
            im = np.hstack((im,im1))
    return im
#读取excel文件内容(path为文件路径)
def read_excel(path):
    # 获取所有sheet
    workbook = xlrd.open_workbook(path)
    sheet_names = workbook.sheet_names()
    # 根据sheet索引或者名称获取sheet内容
    for sheet_name in sheet_names:
        isheet = workbook.sheet_by_name(sheet_name)
        #获取该sheet的列数
        ncols = isheet.ncols
        #获取每一列的内容
        for i in range(ncols):
            if i == 0:
                xl1 = isheet.col_values(i)
                xl1 = np.array(xl1)
                xl1 = xl1.reshape((10,1))
                xl = xl1
            else:
                xl1 = isheet.col_values(i)
                xl1 = np.array(xl1)
                xl1 = xl1.reshape((10,1))
                xl = np.hstack((xl,xl1))
    return xl
#layerout函数
def layerout(w,b,x):
    y = np.dot(w,x) + b
    t = -1.0*y
    # n = len(y)
    # for i in range(n):
        # y[i]=1.0/(1+exp(-y[i]))
    y = 1.0/(1+exp(t))
    return y
#训练函数
def mytrain(x_train,y_train):
    '''
    设置一个隐藏层,784-->隐藏层神经元个数-->10
    '''
    step=int(input('mytrain迭代步数:'))
    a=double(input('学习因子:'))
    inn = 784  #输入神经元个数
    hid = int(input('隐藏层神经元个数:'))#隐藏层神经元个数
    out = 10  #输出层神经元个数
    w = np.random.randn(out,hid)
    w = np.mat(w)
    b = np.mat(np.random.randn(out,1))
    w_h = np.random.randn(hid,inn)
    w_h = np.mat(w_h)
    b_h = np.mat(np.random.randn(hid,1)) 
    for i in range(step):
        #打乱训练样本
        r=np.random.permutation(4000)
        x_train = x_train[:,r]
        y_train = y_train[:,r]
        #mini_batch
        for j in range(400):
            #取batch为10  更新取10次的平均值
            x = np.mat(x_train[:,j])
            x = x.reshape((784,1))
            y = np.mat(y_train[:,j])
            y = y.reshape((10,1))
            hid_put = layerout(w_h,b_h,x)
            out_put = layerout(w,b,hid_put) 
            #更新公式的实现
            o_update = np.multiply(np.multiply((y-out_put),out_put),(1-out_put))
            h_update = np.multiply(np.multiply(np.dot((w.T),np.mat(o_update)),hid_put),(1-hid_put)) 
            outw_update = a*np.dot(o_update,(hid_put.T))
            outb_update = a*o_update
            hidw_update = a*np.dot(h_update,(x.T))
            hidb_update = a*h_update 
            w = w + outw_update
            b = b+ outb_update
            w_h = w_h +hidw_update
            b_h =b_h +hidb_update 
    return w,b,w_h,b_h
#test函数
def mytest(x_test,y_test,w,b,w_h,b_h):
    '''
    统计1000个测试样本中有多少个预测正确了
    预测结果表示:10*1的列向量中最大的那个数的索引+1就是预测结果了
    '''
    sum = 0
    for k in range(1000):
        x = np.mat(x_test[:,k])
        x = x.reshape((784,1))
        y = np.mat(y_test[:,k])
        y = y.reshape((10,1))
        yn = np.where(y ==(np.max(y)))
        # print(yn)
        # print(y)
        hid = layerout(w_h,b_h,x);
        pre = layerout(w,b,hid);
        #print(pre)
        pre = np.mat(pre)
        pre = pre.reshape((10,1))
        pren = np.where(pre ==(np.max(pre)))
        # print(pren)
        # print(pre)
        if yn == pren:
            sum += 1
    print('1000个样本,正确的有:',sum)
def main():
    #获取图片信息
    im = read_photo()
    immin = im.min()
    immax = im.max()
    im = (im-immin)/(immax-immin)
    #前4000张图片作为训练样本
    x_train = im[:,0:4000]
    #后1000张图片作为测试样本
    x_test = im[:,4000:5000]
    #获取label信息
    xl = read_excel('./label.xlsx')
    y_train = xl[:,0:4000]
    y_test = xl[:,4000:5000]
    print("---------------------------------------------------------------")
    w,b,w_h,b_h = mytrain(x_train,y_train)
    mytest(x_test,y_test,w,b,w_h,b_h)
    print("---------------------------------------------------------------")
if __name__ == '__main__':
    main()
实验结果
---------------------------------------------------------------
mytrain迭代步数:300
学习因子:0.3
隐藏层神经元个数:28
1000个样本,正确的有: 933
---------------------------------------------------------------
迭代300步,正确率就有93.3%啦,还不错的正确率~
python手写神经网络实现识别手写数字的更多相关文章
- matlab手写神经网络实现识别手写数字
		
实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手写数字图片,于是我就尝试用matlab写一个网络. 实验数据:500 ...
 - 使用神经网络来识别手写数字【译】(三)- 用Python代码实现
		
实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...
 - 如何用卷积神经网络CNN识别手写数字集?
		
前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...
 - python机器学习使用PCA降维识别手写数字
		
PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...
 - 解决python中import时无法识别自己写的包和模块的方法
		
我们用pycharm打开自己写的代码,当多个文件之间有相互依赖的关系的时候,import无法识别自己写的文件,但是我们写的文件又确实在同一个文件夹中, 这种问题可以用下面的方法解决: 1)打开File ...
 - Python实现神经网络算法识别手写数字集
		
最近忙里偷闲学习了一点机器学习的知识,看到神经网络算法时我和阿Kun便想到要将它用Python代码实现.我们用了两种不同的方法来编写它.这里只放出我的代码. MNIST数据集基于美国国家标准与技术研究 ...
 - NN:神经网络实现识别手写的1~9的10个数字—Jason niu
		
import numpy as np from sklearn.datasets import load_digits from sklearn.metrics import confusion_ma ...
 - Tensorflow搭建卷积神经网络识别手写英语字母
		
更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...
 - 用BP人工神经网络识别手写数字
		
http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...
 
随机推荐
- TensorFlow基础笔记(5) VGGnet_test
			
参考 http://blog.csdn.net/jsond/article/details/72667829 资源: 1.相关的vgg模型下载网址 http://www.vlfeat.org/matc ...
 - php -- 判断文件是否存在
			
file_exists is_file is_dir 基本上,PHP的 file_exists = is_dir + is_file 写程序验证一下: 分别执行1000次,记录所需时间. ------ ...
 - 【BZOJ】1025: [SCOI2009]游戏(置换群+dp+特殊的技巧+lcm)
			
http://www.lydsy.com/JudgeOnline/problem.php?id=1025 首先根据置换群可得 $$排数=lcm\{A_i, A_i表示循环节长度\}, \sum_{i= ...
 - 【BZOJ】1626: [Usaco2007 Dec]Building Roads 修建道路(kruskal)
			
http://www.lydsy.com/JudgeOnline/problem.php?id=1626 依旧是水题..太水了.. #include <cstdio> #include & ...
 - 关于 Apache 的 25 个初中级面试题
			
关于 Apache 的 25 个初中级面试题 出自:http://blog.jobbole.com/60471/
 - PHP正则表达式 /i, /s, /x,/u, /U, /A, /D, /S等模式修饰符
			
i (PCRE_CASELESS) 如果设置了这个修饰符, 模式中的字母会进行大小写不敏感匹配. m (PCRE_MULTILINE) 默认情况下, PCRE认为目标字符串是由单行字符组成的(然而实际 ...
 - 编程之美 最长递增子序列 LIS
			
1. O(N*logN) 解法 先对序列排序, 然后寻找两个序列的最长公共子序列 2. O(N*N) 的动态规划解法 令 LIST[i] 表示以 i 为结尾的最长子序列的长度, 那么 LIST[J] ...
 - 《SQL Server 2000设计与T-SQL编程》
			
<SQL Server 2000设计与T-SQL编程> <SQL Server 2000设计与T-SQL编程>笔记1 http://dukedingding.blog.sohu ...
 - M451例程讲解之按键
			
/**************************************************************************//** * @file main.c * @ve ...
 - SharedPreferences小技巧
			
相信Android的这个最简单的存储方式大家都很熟悉了,但是有一个小小技巧,也许你没有用过,今天就跟大家分享一下,我们可以把SharedPreferences封装在一个工具类中,当我们需要写数据和读数 ...