CNN实现

概述

我在qwe中有两种,第一种是按照Ng课程中的写法,多层循环嵌套得到每次的“小方格”,然后WX+b,这样的做法是最简单,直观。但是效率极其慢。基本跑个10张以内图片都会卡的要死。

第二种方法是使用img2col,将其转换为对应的矩阵,然后直接做一次矩阵乘法运算。

先看第一种

def forward(self, X):
m, n_H_prev, n_W_prev, n_C_prev = X.shape
(f, f, n_C_prev, n_C) = self.W.shape
n_H = int((n_H_prev - f + 2 * self.pad) / self.stride) + 1
n_W = int((n_W_prev - f + 2 * self.pad) / self.stride) + 1
n_H, n_W, n_C = self.output_size Z = np.zeros((m, n_H, n_W, n_C))
X_pad = zero_pad(X, self.pad)
for i in range(m):
for h in range(n_H):
for w in range(n_W):
for c in range(n_C):
vert_start = h * self.stride
vert_end = vert_start + f
horiz_start = w * self.stride
horiz_end = horiz_start + f
A_slice_prev =X_pad[i,vert_start:vert_end, horiz_start:horiz_end, :]
Z[i,h,w,c] = conv_single_step(A_slice_prev, self.W[...,c], self.b[...,c]) def conv_single_step(X, W, b):
# 对一个裁剪图像进行卷积
# X.shape = f, f, prev_channel_size
return np.sum(np.multiply(X, W) + b)

对于m,n_H,n_W,n_C循环就是取得裁剪小方块,可以看到这里的计算复杂度m * n_H * n_W * n_C * (f*f的矩阵计算)

第二种方法,先转换成大矩阵,再进行一次矩阵运算,相当于节省了多次小矩阵运算时间,这还是很可观的,能查个几十倍的速度。

img2col原理很简单,详情可参考caffe im2col

就是循环将每一部分都拉长成一维矩阵拼凑起来。

对于CNN来说,H就是要计算方块的个数即m(样本数) n_H(最终生成图像行数)n_W(最终生成图像列数),W就是f(核kernel长)f(核宽)*(输入样本通道输)

然后还要把参数矩阵W也拉成这个样子,H就是f(核长)f(核宽)(输入样本通道输),W列数就是核数kernel_size

如下图



def img2col(X, pad, stride, f):
pass
ff = f * f
m, n_H_prev, n_W_prev, n_C_prev= X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
Z = np.zeros((m * n_H * n_W, f * f * n_C_prev))
X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=0)
row = -1 for i in range(m):
for h in range(n_H):
for w in range(n_W):
row += 1
vert_start = h * stride
horiz_start = w * stride
for col in range(f * f * n_C_prev):
t = col // n_C_prev
hh = t // f
ww = t % f
cc = col % n_C_prev
Z[row, col] = X_pad[i, vert_start + hh, horiz_start + ww, cc] def speed_forward(model, X):
W = model.W
b = model.b
stride = model.stride
pad = model.pad
(n_C_prev, f, f, n_C) = W.shape
m, n_H_prev, n_W_prev, n_C_prev = X.shape n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1 # WW = W.swapaxes(2,1)
# WW = WW.swapaxes(1,0) XX = img2col(X, pad, stride, f)
# WW = WW.reshape(f*f*n_C_prev, n_C)
WW = W.reshape(f*f*n_C_prev, n_C)
model.XX = XX
model.WW = WW Z = np.dot(XX, WW) + b
return Z.reshape(m, n_H, n_W, n_C)

这种耗时操作,最好使用Cython扩展来写,不然速度还是不够理想。Cython扩展代码code

反向传播同理,具体代码参考

github

qwe框架- CNN 实现的更多相关文章

  1. 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)

    1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...

  2. 深蓝色 --ppt

    Deep Learning of Binary Hash Codes for Fast Image Retrieval [Paper] [Code-Caffe] 1. 摘要 针对图像检索问题,提出简单 ...

  3. [基础]Deep Learning的基础概念

    目录 DNN CNN DNN VS CNN Example 卷积的好处why convolution? DCNN 卷积核移动的步长 stride 激活函数 active function 通道 cha ...

  4. qwe 简易深度框架

    qwe github地址 简介 简单的深度框架,参考Ng的深度学习课程作业,使用了keras的API设计. 方便了解网络具体实现,避免深陷于成熟框架的细节和一些晦涩的优化代码. 网络层实现了Dense ...

  5. 【深度学习系列3】 Mariana CNN并行框架与图像识别

    [深度学习系列3] Mariana CNN并行框架与图像识别 本文是腾讯深度学习系列文章的第三篇,聚焦于腾讯深度学习平台Mariana中深度卷积神经网络Deep CNNs的多GPU模型并行和数据并行框 ...

  6. 卷积神经网络CNN与深度学习常用框架的介绍与使用

    一.神经网络为什么比传统的分类器好 1.传统的分类器有 LR(逻辑斯特回归) 或者 linear SVM ,多用来做线性分割,假如所有的样本可以看做一个个点,如下图,有蓝色的点和绿色的点,传统的分类器 ...

  7. 我所写的CNN框架 VS caffe

    我所写的CNN框架 VS caffe 一个月前.自己模仿caffe实现了一个卷积神经网络的框架. 同样点 1无缝支持CPU和GPU模式,GPU模式使用cuda实现. 不同点 1我的CNN不依赖与不论什 ...

  8. ubuntu之路——day19.2 开源框架与迁移、CNN中的数据扩充

    开源框架与迁移 上面介绍了一些已经取得很好成绩的CNN框架,我们可以直接从GitHub上下载这些神经网络的结构和已经在ImageNet等数据集上训练好的权重超参数. 在应用于我们自己的数据时. 1.如 ...

  9. CNN基础框架简介

    卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...

随机推荐

  1. dom4j读取xml

    -----记录和回顾是一个比学习更重要的环节----- 一.首先,我们需要知道xml是做什么的 1.作为软件的配置文件 2.作为数据的载体(小型的数据库) 二.xml的语法 xml文件以xml后缀名结 ...

  2. 备忘:Junit单元测试

    junit 目前测试都是在main方法中调用目前的结果都需要人工对比是否是想要的 1.使用Junit测试方法,绿色条条代表方法测试成功,没有bug,如果是红色条条代表有异常,测试不通过2.点击方法名. ...

  3. 浏览器中页面的visibility状态及变化监听

    需求 在浏览器中播放视频,当用户进行页面切换操作时.需要根据视频播放页是否处于可见状态,来控制视频的暂停及重新播放. 相关文档 参考MDN中,关于页面的可见性相关的API说明.https://deve ...

  4. apache 限制IP访问

    <Directory "/var/www"> Options All AllowOverride None Order Deny,Allow Deny From all ...

  5. ABP官方文档翻译 5.4 SwaggerUI集成

    SwaggerUI集成 介绍 ASP.NET Core 安装Nuget包 配置 测试 ASP.NET 5.x 安装Nuget包 配置 测试 介绍 在它的网站上:“...使用Swagger可用的API, ...

  6. SSH 面试题集锦

    1.  BeanFactory的作用是什么?   [中] BeanFactory是配置.创建.管理bean的容器,有时候也称为bean上下文.Bean与bean的依赖关系,也是由BeanFactory ...

  7. BZOJ 3239: Discrete Logging [BGSG]

    裸题 求\(ind_{n,a}b\),也就是\(a^x \equiv b \pmod n\) 注意这里开根不能直接下取整 这个题少了一些特判也可以过... #include <iostream& ...

  8. win7(windows 7)系统下安装SQL2005(SQL Server 2005)图文教程

    操作系统:Microsoft Windows 7 旗舰版(32位) 数据库版本:SQL Server 2005 简体中文开发板 数据库下载链接: https://pan.baidu.com/s/1cq ...

  9. 试用MarkDown

    自定义界面风格 可以在设置中选择日间,或者夜间模式进行定义.具体的定义项的说明,可以查看菜单栏 (Windows版本位于托盘按钮上) 自定义的帮助. MarkEditor几乎所有跟色彩有关的界面,都已 ...

  10. 关于c++栈溢出的问题

    我自己定义了一个数据类型node,嵌套在另一个数据类型当中时候,用到了delete函数, 在我node的声明当中声明了几个指针 在我的析构函数中却调用了delet函数 结果程序结果是能跑出来 提示我栈 ...