NN入门,手把手教你用Numpy手撕NN(3)

这是一篇包含极少数学的CNN入门文章

上篇文章中简单介绍了NN的反向传播,并利用反向传播实现了一个简单的NN,在这篇文章中将介绍一下CNN。

CNN

CV(计算机视觉)作为AI的一大研究方向,越来越多的人选择了这个方向,其中使用的深度学习的方法基本以卷积神经网络(CNN)为基础。因此,这篇文章将介绍CNN的实现。

CNN与我们之前介绍的NN的相比,出现了卷积层(Convolution层)和池化层(Pooling层)。其网络架构大致如下图所示

卷积层

与全连接神经网络的对比

为什么在有存在全连接神经网络的情况下还会出现卷积神经网络呢?

这就得来看看全连接神经网络存在的问题了,全连接神经网络中存在的问题是数据的形状被“忽略了”。比如,输入的数据是图像时,图像通常是高、长、通道方向上的3为形状。

但是,全连接层输入时,需要将3维数据拉平为1维数据,因此,我们可能会丢失图像数据中存在有的空间信息(如空间上临近像素为相似的值、RGB的各个通道之间的关联性、相距较远像素之间的关联性等),这些信息都会被全连接层丢失。

而卷积层可以保持形状不变,将输入的数据以相同的维度输出,因此,可能可以正确理解图像等具有形状的数据。

卷积运算

卷积层进行的处理就是卷积运算,运算方式如下图所示

一般在计算中也会加上偏置

填充

在进行卷积层的处理之前,有时要向输入数据的周围填入固定的数据(如0),称为填充(padding)。如下图所示

向输入数据的周围填入0,将大小为(4, 4)的输入数据变成了(6, 6)的形状。

为什么要进行填充操作?

在对大小为(4, 4)的输入数据使用(3, 3)的滤波器时,输出的大小会变成(2, 2),如果反复进行多次卷积运算,在某个时刻输出大小就有可能变成1,导致无法再应用卷积运算。为了避免这种情况,就要使用填充,使得卷积运算可以在保持空间大小不变的情况下将数据传给下一层。

步幅

应用滤波器的位置间隔称为步幅(stride)。在上面的例子中,步幅都为1,如果将步幅设为2,则如下图所示

可以发现,增大步幅后,输出大小会变小,增大填充后,输出大小会变大。

假设输入大小为(H, W),滤波器大小为(FH, FW),输出大小为(OH, OW),填充为P,步幅为S,则输出大小可以表示为

\[OH=\frac{H+2P-FH}{S}+1 \\
OW=\frac{W+2P-FW}{S}+1
\]

三维数据卷积运算

上面卷积运算的例子都是二维的数据,但是,一般来说我们的图像数据都是三维的,除了高、宽之外,还需要处理通道方向的数据。如下图所示

其计算方式为每个通道处的数据与对应通道的卷积核相乘,最后将各个通道得到的结果相加,从而得到输出。

这里需要注意的是,一般情况下,卷积核的通道数需要与输入数据的通道数相同。(有时会使用1x1卷积核来对通道数进行降/升维操作)。可参考这篇文章

上面给出的例子输出的结果还是一个通道的,如果我们想要输出的结果在通道上也有多个输出,该怎么做呢?如下图所示

即使用多个卷积核

池化层

池化是缩小高、长方向上的空间的运算。比如下图所示的最大池化

除了上图所示的Max池化之外,还有Average池化。一般来说,池化的窗口大小会和步幅设定成相同的值。

im2col

从前面的例子来看,会发现,如果完全按照计算过程来写代码的话,要用上好几层for循环,这样的话不仅写起来麻烦,估计在运行的时候计算速度也很慢。这里将介绍im2col的方法。

im2col将输入数据展开以适合卷积核,如下图所示

对3维的输入数据应用im2col之后,数据转化维2维矩阵。

使用im2col展开输入数据后,之后就只需将卷积层的卷积核纵向展开为1列,并计算2个矩阵的乘积即可。这和全连接层的Affine层进行的处理基本相同。

代码实现

讲了这么多,这里将给出代码实现

import numpy as np

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
"""
input_data : 由(数据量,通道,高,长)的4维数组构成的输入数据
filter_h : 滤波器的高
filter_w : 滤波器的长
stride : 步幅
pad : 填充
"""
N, C, H, W = input_data.shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1 img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w)) for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride] col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
N, C, H, W = input_shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1
col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2) img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))
for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :] return img[:, :, pad:H + pad, pad:W + pad] class Convolution:
def __init__(self, W, b, stride=1, pad=0):
self.W = W
self.b = b
self.stride = stride
self.pad = pad self.x = None
self.col = None
self.col_W = None self.dW = None
self.db = None def forward(self, x):
# [N, C, H, W]
FN, C, FH, FW = self.W.shape
N, C, H, W = x.shape
out_h = int(1 + (H + 2 * self.pad - FH) / self.stride)
out_w = int(1 + (W + 2 * self.pad - FW) / self.stride) col = im2col(x, FH, FW, self.stride, self.pad)
col_W = self.W.reshape(FN, -1).T # 滤波器展开
out = np.dot(col, col_W) + self.b out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2) self.x = x
self.col = col
self.col_W = col_W return out def backward(self, dout):
FN, C, FH, FW = self.W.shape
dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN) self.db = np.sum(dout, axis=0)
self.dW = np.dot(self.col.T, dout)
self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW) dcol = np.dot(dout, self.col_W.T)
dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad) return dx class Pooling:
def __init__(self, pool_h, pool_w, stride=1, pad=0):
self.pool_h = pool_h
self.pool_w = pool_w
self.stride = stride
self.pad = pad self.x = None
self.arg_max = None def forward(self, x):
N, C, H, W = x.shape
out_h = int(1 + (H - self.pool_h) / self.stride)
out_w = int(1 + (W - self.pool_w) / self.stride) col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
col = col.reshape(-1, self.pool_h*self.pool_w) arg_max = np.argmax(col, axis=1)
out = np.max(col, axis=1)
out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2) self.x = x
self.arg_max = arg_max return out def backward(self, dout):
dout = dout.transpose(0, 2, 3, 1) pool_size = self.pool_h * self.pool_w
dmax = np.zeros((dout.size, pool_size))
dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
dmax = dmax.reshape(dout.shape + (pool_size,)) dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad) return dx

小节

这篇文章断断续续地写了好久,中间还顺便在学tensorflow 2.0 ,还是框架用的舒服 orz。。。这几天还是决定把这篇文章写完,坑挖了还是得填,numpy手撕NN系列也算是暂时完成了,RNN后面再考虑。。。这之后准备再补补一些学过的算法的总结以及前段时间看的一些论文的总结。

本文首发于我的知乎

NN入门,手把手教你用Numpy手撕NN(三)的更多相关文章

  1. NN入门,手把手教你用Numpy手撕NN(一)

    前言 这是一篇包含极少数学推导的NN入门文章 大概从今年4月份起就想着学一学NN,但是无奈平时时间不多,而且空闲时间都拿去做比赛或是看动漫去了,所以一拖再拖,直到这8月份才正式开始NN的学习. 这篇文 ...

  2. NN入门,手把手教你用Numpy手撕NN(2)

    这是一篇包含较少数学推导的NN入门文章 上篇文章中简单介绍了如何手撕一个NN,但其中仍有可以改进的地方,将在这篇文章中进行完善. 误差反向传播 之前的NN计算梯度是利用数值微分法,虽容易实现,但是计算 ...

  3. 《手把手教你》系列技巧篇(三十八)-java+ selenium自动化测试-日历时间控件-下篇(详解教程)

    1.简介 理想很丰满现实很骨感,在应用selenium实现web自动化时,经常会遇到处理日期控件点击问题,手工很简单,可以一个个点击日期控件选择需要的日期,但自动化执行过程中,完全复制手工这样的操作就 ...

  4. 《手把手教你》系列技巧篇(三十)-java+ selenium自动化测试- Actions的相关操作下篇(详解教程)

    1.简介 本文主要介绍两个在测试过程中可能会用到的功能:Actions类中的拖拽操作和Actions类中的划取字段操作.例如:需要在一堆log字符中随机划取一段文字,然后右键选择摘取功能. 2.拖拽操 ...

  5. 《手把手教你》系列技巧篇(三十一)-java+ selenium自动化测试- Actions的相关操作-番外篇(详解教程)

    1.简介 上一篇中,宏哥说的宏哥在最后提到网站的反爬虫机制,那么宏哥在自己本地做一个网页,没有那个反爬虫的机制,谷歌浏览器是不是就可以验证成功了,宏哥就想验证一下自己想法,于是写了这一篇文章,另外也是 ...

  6. 《手把手教你》系列技巧篇(三十二)-java+ selenium自动化测试-select 下拉框(详解教程)

    1.简介 在实际自动化测试过程中,我们也避免不了会遇到下拉选择的测试,因此宏哥在这里直接分享和介绍一下,希望小伙伴或者童鞋们在以后工作中遇到可以有所帮助. 2.select 下拉框 2.1Select ...

  7. 《手把手教你》系列技巧篇(三十三)-java+ selenium自动化测试-单选和多选按钮操作-上篇(详解教程)

    1.简介 在实际自动化测试过程中,我们同样也避免不了会遇到单选和多选的测试,特别是调查问卷或者是答题系统中会经常碰到.因此宏哥在这里直接分享和介绍一下,希望小伙伴或者童鞋们在以后工作中遇到可以有所帮助 ...

  8. 《手把手教你》系列技巧篇(三十四)-java+ selenium自动化测试-单选和多选按钮操作-中篇(详解教程)

    1.简介 今天这一篇宏哥主要是讲解一下,如何使用list容器来遍历单选按钮.大致两部分内容:一部分是宏哥在本地弄的一个小demo,另一部分,宏哥是利用JQueryUI网站里的单选按钮进行实战. 2.d ...

  9. 《手把手教你》系列技巧篇(三十五)-java+ selenium自动化测试-单选和多选按钮操作-下篇(详解教程)

    1.简介 今天这一篇宏哥主要是讲解一下,如何使用list容器来遍历多选按钮.大致两部分内容:一部分是宏哥在本地弄的一个小demo,另一部分,宏哥是利用JQueryUI网站里的多选按钮进行实战. 2.d ...

随机推荐

  1. Java IO编程——File文件操作类

    在Java语言里面提供有对于文件操作系统操作的支持,而这个支持就在java.io.File类中进行了定义,也就是说在整个java.io包里面,File类是唯一 一个与文件本身操作(创建.删除.重命名等 ...

  2. Linux下终端字体颜色设置方法

    颜色=\033[代码;前景;背景m 如:\033[1;32;40m表示高亮显示字体为绿色,背景色为黑色 颜色=\[\033[代码;前景;背景m\] echo -e "this is a \0 ...

  3. const var let 三者的区别

    1.const定义的变量不可以修改,而且必须初始化. ;//正确 const b;//错误,必须初始化 console.log('函数外const定义b:' + b);//有输出值 b = ; con ...

  4. MAVEN(一) 安装和环境变量配置

    一.安装步骤 1.安装maven之前先安装jdk,并配置好环境变量.确保已安装JDK,并 “JAVA_HOME” 变量已加入到 Windows 环境变量. 2.下载maven 进入官方网站下载网址如下 ...

  5. Linux tar命令解压时提示时间戳异常的处理办法

    在Linux服务器上的文件会有3个时间戳信息 访问时间(Access).修改时间(Modify).改变时间(Change),都是存放在该文件的Inode里面 问题描述: 公司网站是前后端分离的,所有的 ...

  6. vue cli 4.0.5 的使用

    vue cli 4.0.5 的使用 现在的 vue 脚手架已经升级到4.0的版本了,前两日vue 刚发布了3.0版本,我看了一下cli 4 和cli 3 没什么区别,既然这样,就只总结一下vue cl ...

  7. 爬虫之scrapy简介

    原始的爬虫流程:效率低.同步.阻塞 scrapy执行流程:效率高.异步.非阻塞 scrapy的概念 scrapy是一个爬虫框架 开发速度快 稳定性高 性能优越 scrapy的流程 1. 爬虫模块(Sp ...

  8. TCP/IP协议第一卷第一章

    1.链路层 链路层有时也称作数据链路层或网络接口层,通常包括操作系统中的设备驱动程序和计算机中对应的网络接口卡.它们一起处理与电缆(或其他任何传输媒介)的物理接口细节.把链路层地址和网络层地址联系起来 ...

  9. CSPS模拟 95

    T_T flag彻底倒了 虽然打一开始就没觉得能实现过 可是我好桑心T_T skyh那个没素质的还一直bb T_T

  10. 程序员学点xx 之 Redis

    程序员学点xx 之 Redis 概述 其实程序员也要和操作系统打交道, 比如最常见的,部署自己电脑上的开发环境. 当然有时某些牛人, 觉得运维或基础部门的同事不够给力, 亲自上手部署服务器或线上环境, ...