CNN实现

概述

我在qwe中有两种,第一种是按照Ng课程中的写法,多层循环嵌套得到每次的“小方格”,然后WX+b,这样的做法是最简单,直观。但是效率极其慢。基本跑个10张以内图片都会卡的要死。

第二种方法是使用img2col,将其转换为对应的矩阵,然后直接做一次矩阵乘法运算。

先看第一种

def forward(self, X):
m, n_H_prev, n_W_prev, n_C_prev = X.shape
(f, f, n_C_prev, n_C) = self.W.shape
n_H = int((n_H_prev - f + 2 * self.pad) / self.stride) + 1
n_W = int((n_W_prev - f + 2 * self.pad) / self.stride) + 1
n_H, n_W, n_C = self.output_size Z = np.zeros((m, n_H, n_W, n_C))
X_pad = zero_pad(X, self.pad)
for i in range(m):
for h in range(n_H):
for w in range(n_W):
for c in range(n_C):
vert_start = h * self.stride
vert_end = vert_start + f
horiz_start = w * self.stride
horiz_end = horiz_start + f
A_slice_prev =X_pad[i,vert_start:vert_end, horiz_start:horiz_end, :]
Z[i,h,w,c] = conv_single_step(A_slice_prev, self.W[...,c], self.b[...,c]) def conv_single_step(X, W, b):
# 对一个裁剪图像进行卷积
# X.shape = f, f, prev_channel_size
return np.sum(np.multiply(X, W) + b)

对于m,n_H,n_W,n_C循环就是取得裁剪小方块,可以看到这里的计算复杂度m * n_H * n_W * n_C * (f*f的矩阵计算)

第二种方法,先转换成大矩阵,再进行一次矩阵运算,相当于节省了多次小矩阵运算时间,这还是很可观的,能查个几十倍的速度。

img2col原理很简单,详情可参考caffe im2col

就是循环将每一部分都拉长成一维矩阵拼凑起来。

对于CNN来说,H就是要计算方块的个数即m(样本数) n_H(最终生成图像行数)n_W(最终生成图像列数),W就是f(核kernel长)f(核宽)*(输入样本通道输)

然后还要把参数矩阵W也拉成这个样子,H就是f(核长)f(核宽)(输入样本通道输),W列数就是核数kernel_size

如下图



def img2col(X, pad, stride, f):
pass
ff = f * f
m, n_H_prev, n_W_prev, n_C_prev= X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
Z = np.zeros((m * n_H * n_W, f * f * n_C_prev))
X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=0)
row = -1 for i in range(m):
for h in range(n_H):
for w in range(n_W):
row += 1
vert_start = h * stride
horiz_start = w * stride
for col in range(f * f * n_C_prev):
t = col // n_C_prev
hh = t // f
ww = t % f
cc = col % n_C_prev
Z[row, col] = X_pad[i, vert_start + hh, horiz_start + ww, cc] def speed_forward(model, X):
W = model.W
b = model.b
stride = model.stride
pad = model.pad
(n_C_prev, f, f, n_C) = W.shape
m, n_H_prev, n_W_prev, n_C_prev = X.shape n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1 # WW = W.swapaxes(2,1)
# WW = WW.swapaxes(1,0) XX = img2col(X, pad, stride, f)
# WW = WW.reshape(f*f*n_C_prev, n_C)
WW = W.reshape(f*f*n_C_prev, n_C)
model.XX = XX
model.WW = WW Z = np.dot(XX, WW) + b
return Z.reshape(m, n_H, n_W, n_C)

这种耗时操作,最好使用Cython扩展来写,不然速度还是不够理想。Cython扩展代码code

反向传播同理,具体代码参考

github

qwe框架- CNN 实现的更多相关文章

  1. 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)

    1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...

  2. 深蓝色 --ppt

    Deep Learning of Binary Hash Codes for Fast Image Retrieval [Paper] [Code-Caffe] 1. 摘要 针对图像检索问题,提出简单 ...

  3. [基础]Deep Learning的基础概念

    目录 DNN CNN DNN VS CNN Example 卷积的好处why convolution? DCNN 卷积核移动的步长 stride 激活函数 active function 通道 cha ...

  4. qwe 简易深度框架

    qwe github地址 简介 简单的深度框架,参考Ng的深度学习课程作业,使用了keras的API设计. 方便了解网络具体实现,避免深陷于成熟框架的细节和一些晦涩的优化代码. 网络层实现了Dense ...

  5. 【深度学习系列3】 Mariana CNN并行框架与图像识别

    [深度学习系列3] Mariana CNN并行框架与图像识别 本文是腾讯深度学习系列文章的第三篇,聚焦于腾讯深度学习平台Mariana中深度卷积神经网络Deep CNNs的多GPU模型并行和数据并行框 ...

  6. 卷积神经网络CNN与深度学习常用框架的介绍与使用

    一.神经网络为什么比传统的分类器好 1.传统的分类器有 LR(逻辑斯特回归) 或者 linear SVM ,多用来做线性分割,假如所有的样本可以看做一个个点,如下图,有蓝色的点和绿色的点,传统的分类器 ...

  7. 我所写的CNN框架 VS caffe

    我所写的CNN框架 VS caffe 一个月前.自己模仿caffe实现了一个卷积神经网络的框架. 同样点 1无缝支持CPU和GPU模式,GPU模式使用cuda实现. 不同点 1我的CNN不依赖与不论什 ...

  8. ubuntu之路——day19.2 开源框架与迁移、CNN中的数据扩充

    开源框架与迁移 上面介绍了一些已经取得很好成绩的CNN框架,我们可以直接从GitHub上下载这些神经网络的结构和已经在ImageNet等数据集上训练好的权重超参数. 在应用于我们自己的数据时. 1.如 ...

  9. CNN基础框架简介

    卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...

随机推荐

  1. TCP连接之未连接队列的理解[转]

    tcp服务器在TCP/IP协议中,TCP协议提供可靠的连接服务,采用三次握手建立一个连接. 第一次握手:建立连接时,客户端发送syn包(syn=j)到服务器,并进入SYN_SEND状态,等待服务器确认 ...

  2. dfs 与 dijkstra 总结

    Dijkstra: //寻求加权图起始点到各个节点的最短路径 for i <- 1:n do distance[i] <- INF; distance[0] <- 0;//起始节点距 ...

  3. JavaScript转unix时间戳

    由于 unix 的时间戳是10位不带毫秒的,所以前端获取到时间戳之后需要做一下处理,才能获取正确的时间. // 假设这里是从服务端获取到的时间戳 var unixTime = data.time; / ...

  4. linux基本语法和常用运维命令

    linux上的操作一般是命令行操作,看起来很高大上,让人畏而远之. Help!Help! 忽然间闯入的linux黑黑的世界,怎么办,不要慌.赶紧敲出一个help命令,然后回车,黑色的窗口就会展示一些常 ...

  5. ABP官方文档翻译 7.2 Hangfire集成

    Hangfire集成 介绍 ASP.NET Core集成 ASP.NET MVC 5.x集成 面板授权 介绍 Hangfire是一个综合的后台job管理器.你可以 把它集成到ABP,用来取代默认的后台 ...

  6. 软件开发:网站&视频&书籍推荐(不断更新)

    利用书籍进行系统学习,凭借博客/新闻等资料开阔眼界,辅之以代码及项目实战,并勤加以总结,方可进步. 常用网站: Leetcode刷题:https://leetcode.com/ ,练习数据结构和算法必 ...

  7. srand()和rand()函数的使用

    rand()函数不接受参数,默认以1为种子(即起始值). 随机数生成器总是以相同的种子开始,所以形成的伪随机数列也相同,失去了随机意义.(但这样便于程序调试) srand()函数就是指明种子的大小:只 ...

  8. putty,xshell以及密钥认证:linux学习第二篇

    1.    Putty下载 官网:https://www.chiark.greenend.org 下载putty的zip包 2.    Putty使用 2000为可查看的文件行数,建议设置为2000 ...

  9. Linux知识体系之路径属性与目录

    最近在看鸟哥的Linux私房菜,我觉得这本书还是很不错的.这里进行相关的总结. 1.Linux目录权限概念   Linux一般讲目录可存取的方式分为三个类别,分别是owner/group/other, ...

  10. 2n皇后问题

    此题为蓝桥杯基础练习题. 问题描述 给定一个n*n的棋盘,棋盘中有一些位置不能放皇后.现在要向棋盘中放入n个黑皇后和n个白皇后,使任意的两个黑皇后都不在同一行.同一列或同一条对角线上,任意的两个白皇后 ...