我想大部分程序员的第一个程序应该都是“hello world”,在深度学习领域,这个“hello world”程序就是手写字体识别程序。

这次我们详细的分析下手写字体识别程序,从而可以对深度学习建立一个基本的概念。

1.初始化权重和偏置矩阵,构建神经网络的架构

import numpy as np

class network():

  def __init__(self, sizes):

    self.num_layers = len(sizes)

    self.sizes = sizes

    self.biases = [ np.random.randn(y,1) for y in sizes[1:] ]

    self.weights = [ np.random.randn(y,x) for x,y in zip(sizes(:-1), sizes(1:)) ]

在实例化一个神经网络时,去初始化权重和偏置的矩阵,例如

  network0 = network([784, 30, 10])

可以初始化一个3层的神经网络, 各层神经元的个数分别为 784, 30 , 10

2. 如何去反向传播计算代价函数的梯度?

这个过程可以大概概括如下:

(1)正向传播,获得每个神经元的带权输出和激活因子(a)

(2)计算输出层的误差

(3)反向传播计算每一层的误差和梯度

用python实现的代码如下:

def backprop(self, x, y):

  delta_w = [ np.zeros(w.shape) for w in self.weights]

delta_b = [ np.zeros(b.shape)  for b in self.biases ]

#计算每个神经元的带权输入z及激活值

  zs = []

activation = x

activations = [x]

  for b,w in zip(self.biases, self.weights):

    z = np.dot(w, activation) + b

    zs.append(z)

    activation = sigmod(z)

    activations.append(activation)

#计算输出层误差(这里采用的是二次代价函数)

  delta = (activations[-1] - y) * sigmod_prime(zs[-1])

  delta_w[-1] = np.dot(delta, activations[-2].transpose())

  delta_b[-1] = delta

  #反向传播

  for l in xrange(2, self.num_layers):

    delta = np.dot(delta_w[-l+1].transpose(),delta)*sigmod_prime(zs[-l])

    delta_w[-l] = np.dot(delta, activations[-l-1].transpose())

    delta_b[-l] = delta

  return delta_w, delta_b

3.如何梯度下降,更新权重和偏置?

通过反向传播获得了更新权重和偏置的增量,进一步进行更新,梯度下降。

def update_mini_batch(self, mini_batch, eta):

  delta_w = [ np.zeros(w.shape) for w in self.weights ]

  delta_b = [ np.zeros(b.shape) for b in self.biases ]

  for x,y in mini_batch:

    (这里针对一个小批量内所有样本,应用反向传播,积累权重和偏置的变化)

    delta_w_p, delta_b_p = self.backprop(x,y)

    delta_w = [ dt_w + dt_w_p for dt_w,dt_w_p in zip(delta_w, delta_w_p)]

    delta_b = [ dt_b + dt_b_p for dt_b,dt_b_p in zip(delta_b, delta_b_p)]

  self.weights = [ w-(eta/len(mini_batch)*nw) for w,nw in zip(self.weights, delta_w)]

  self.biases = [ b-(eta/len(mini_batch)*nb) for b,nb in zip(self.biases, delta_b)]

def SGD(self, epochs, training_data,  mini_batch_size,eta, test_data=None):

  if test_data:

    n_tests = len(tast_data)

  n_training_data = len(training_data)

  for i in xrange(0, epochs):

    random.shuffle(training_data)

    mini_batches = [  training_data[k:k+mini_batch_size]

            for k in xrange(0, n_training_data, mini_batch_size)

            ]

    for mini_batch in mini_batches:

      self.update_mini_batch(mini_batch, eta)

  

深度学习---手写字体识别程序分析(python)的更多相关文章

  1. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  2. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  3. pytorch深度学习神经网络实现手写字体识别

    利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下: 其具体实现代码如下所示:import torch ...

  4. 【OpenCV】opencv3.0中的SVM训练 mnist 手写字体识别

    前言: SVM(支持向量机)一种训练分类器的学习方法 mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个 LibSVM 一个常用的SVM框架 OpenCV3.0 中的 ...

  5. 机器学习之路: python 支持向量机 LinearSVC 手写字体识别

    使用python3 学习sklearn中支持向量机api的使用 可以来到我的git下载源代码:https://github.com/linyi0604/MachineLearning # 导入手写字体 ...

  6. 基于kNN的手写字体识别——《机器学习实战》笔记

    看完一节<机器学习实战>,算是踏入ML的大门了吧!这里就详细讲一下一个demo:使用kNN算法实现手写字体的简单识别 kNN 先简单介绍一下kNN,就是所谓的K-近邻算法: [作用原理]: ...

  7. 第二节,mnist手写字体识别

    1.获取mnist数据集,得到正确的数据格式 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 2.定义网络大小:图片的大小是2 ...

  8. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  9. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

随机推荐

  1. SQL - 查询某一字段值相同而另一字段值最大的记录

    有时需要以某一字段作为分组,筛选每一组的另一字段值最大(或最小)的记录.例如,有如下表 app,存储了 app 的 ID.名称.版本号等信息.现在要筛选出每个 app 版本最大的记录. 方法一 SEL ...

  2. SpringCloud (十) Hystrix Dashboard单体监控、集群监控、与消息代理结合

    一.前言 Dashboard又称为仪表盘,是用来监控项目的执行情况的,本文旨在Dashboard的使用 分别为单体监控.集群监控.与消息代理结合. 代码请戳我的github 二.快速入门 新建一个Sp ...

  3. CentOS 7快速入门系列教程(一)

    基本命令 ls 列举当前目录下的所有文件夹 ls -l 查看文件还是文件夹   d表示文件夹   -表示文件 ls --help man ls 询问命令 man 3 malloc 查看函数 cd 跳转 ...

  4. 非常强力的reduce

    Array 的方法 reduce 是一个有非常多用处的函数. 它一个非常具有代表性的作用是将一个数组转换成一个值.但是你可以用它来做更多的事. 1.使用"reduce"代替&quo ...

  5. 微服务深入浅出(11)-- SpringBoot整合Docker

    添加Dockerfile 在目录src/main/resources目录下店家Dockerfile文件: From java MAINTAINER "Eric"<eric.l ...

  6. python初步学习-python 模块之 sys(持续补充)

    sys sys 模块包括了一组非常实用的服务,内含很多函数方法和变量 sys 模块重要函数变量 sys.stdin 标准输出流 sys.stdout 标准输出流 sys.stderr 标准错误流 sy ...

  7. Python文件操作-文件的增删改查

    需求:对文件进行增删改查 由于时间原因,本次代码没有增加任何注释,如有疑问,请联系编辑者:闫龙 其实我也是醉了,看着这些个代码,我脑袋也特么大了,没办法,大神说了,不让用新知识,只可以使用学过的,所以 ...

  8. 你需要知道的12个Git高级命令【转】

    转自:http://www.linuxidc.com/Linux/2016-01/128024.htm 众所周知,Git目前已经是分布式版本控制领域的翘楚,围绕着Git形成了完整的生态圈.学习Git, ...

  9. linux shell语言编程规范安全篇之通用原则【转】

    shell语言编程规范安全篇是针对bash语言编程中的数据校验.加密与解密.脚本执行.目录&文件操作等方面,描述可能导致安全漏洞或风险的常见编码错误.该规范基于业界最佳实践,并总结了公司内部的 ...

  10. java系统的优化

    1.tomcat.jboss.jetty的jvm内存,增大 2.数据库的优化,如MySQL的innodb_buffer_pool_size等参数,增大