qwe框架- CNN 实现
CNN实现
概述
我在qwe中有两种,第一种是按照Ng课程中的写法,多层循环嵌套得到每次的“小方格”,然后WX+b,这样的做法是最简单,直观。但是效率极其慢。基本跑个10张以内图片都会卡的要死。
第二种方法是使用img2col,将其转换为对应的矩阵,然后直接做一次矩阵乘法运算。
先看第一种
def forward(self, X):
m, n_H_prev, n_W_prev, n_C_prev = X.shape
(f, f, n_C_prev, n_C) = self.W.shape
n_H = int((n_H_prev - f + 2 * self.pad) / self.stride) + 1
n_W = int((n_W_prev - f + 2 * self.pad) / self.stride) + 1
n_H, n_W, n_C = self.output_size
Z = np.zeros((m, n_H, n_W, n_C))
X_pad = zero_pad(X, self.pad)
for i in range(m):
for h in range(n_H):
for w in range(n_W):
for c in range(n_C):
vert_start = h * self.stride
vert_end = vert_start + f
horiz_start = w * self.stride
horiz_end = horiz_start + f
A_slice_prev =X_pad[i,vert_start:vert_end, horiz_start:horiz_end, :]
Z[i,h,w,c] = conv_single_step(A_slice_prev, self.W[...,c], self.b[...,c])
def conv_single_step(X, W, b):
# 对一个裁剪图像进行卷积
# X.shape = f, f, prev_channel_size
return np.sum(np.multiply(X, W) + b)
对于m,n_H,n_W,n_C循环就是取得裁剪小方块,可以看到这里的计算复杂度m * n_H * n_W * n_C * (f*f的矩阵计算)
第二种方法,先转换成大矩阵,再进行一次矩阵运算,相当于节省了多次小矩阵运算时间,这还是很可观的,能查个几十倍的速度。
img2col原理很简单,详情可参考caffe im2col
就是循环将每一部分都拉长成一维矩阵拼凑起来。
对于CNN来说,H就是要计算方块的个数即m(样本数) n_H(最终生成图像行数)n_W(最终生成图像列数),W就是f(核kernel长)f(核宽)*(输入样本通道输)
然后还要把参数矩阵W也拉成这个样子,H就是f(核长)f(核宽)(输入样本通道输),W列数就是核数kernel_size
如下图


def img2col(X, pad, stride, f):
pass
ff = f * f
m, n_H_prev, n_W_prev, n_C_prev= X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
Z = np.zeros((m * n_H * n_W, f * f * n_C_prev))
X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=0)
row = -1
for i in range(m):
for h in range(n_H):
for w in range(n_W):
row += 1
vert_start = h * stride
horiz_start = w * stride
for col in range(f * f * n_C_prev):
t = col // n_C_prev
hh = t // f
ww = t % f
cc = col % n_C_prev
Z[row, col] = X_pad[i, vert_start + hh, horiz_start + ww, cc]
def speed_forward(model, X):
W = model.W
b = model.b
stride = model.stride
pad = model.pad
(n_C_prev, f, f, n_C) = W.shape
m, n_H_prev, n_W_prev, n_C_prev = X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
# WW = W.swapaxes(2,1)
# WW = WW.swapaxes(1,0)
XX = img2col(X, pad, stride, f)
# WW = WW.reshape(f*f*n_C_prev, n_C)
WW = W.reshape(f*f*n_C_prev, n_C)
model.XX = XX
model.WW = WW
Z = np.dot(XX, WW) + b
return Z.reshape(m, n_H, n_W, n_C)
这种耗时操作,最好使用Cython扩展来写,不然速度还是不够理想。Cython扩展代码code
反向传播同理,具体代码参考
github
qwe框架- CNN 实现的更多相关文章
- 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)
1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...
- 深蓝色 --ppt
Deep Learning of Binary Hash Codes for Fast Image Retrieval [Paper] [Code-Caffe] 1. 摘要 针对图像检索问题,提出简单 ...
- [基础]Deep Learning的基础概念
目录 DNN CNN DNN VS CNN Example 卷积的好处why convolution? DCNN 卷积核移动的步长 stride 激活函数 active function 通道 cha ...
- qwe 简易深度框架
qwe github地址 简介 简单的深度框架,参考Ng的深度学习课程作业,使用了keras的API设计. 方便了解网络具体实现,避免深陷于成熟框架的细节和一些晦涩的优化代码. 网络层实现了Dense ...
- 【深度学习系列3】 Mariana CNN并行框架与图像识别
[深度学习系列3] Mariana CNN并行框架与图像识别 本文是腾讯深度学习系列文章的第三篇,聚焦于腾讯深度学习平台Mariana中深度卷积神经网络Deep CNNs的多GPU模型并行和数据并行框 ...
- 卷积神经网络CNN与深度学习常用框架的介绍与使用
一.神经网络为什么比传统的分类器好 1.传统的分类器有 LR(逻辑斯特回归) 或者 linear SVM ,多用来做线性分割,假如所有的样本可以看做一个个点,如下图,有蓝色的点和绿色的点,传统的分类器 ...
- 我所写的CNN框架 VS caffe
我所写的CNN框架 VS caffe 一个月前.自己模仿caffe实现了一个卷积神经网络的框架. 同样点 1无缝支持CPU和GPU模式,GPU模式使用cuda实现. 不同点 1我的CNN不依赖与不论什 ...
- ubuntu之路——day19.2 开源框架与迁移、CNN中的数据扩充
开源框架与迁移 上面介绍了一些已经取得很好成绩的CNN框架,我们可以直接从GitHub上下载这些神经网络的结构和已经在ImageNet等数据集上训练好的权重超参数. 在应用于我们自己的数据时. 1.如 ...
- CNN基础框架简介
卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...
随机推荐
- Markdown内嵌Html语言
概述 Markdown是内嵌Html语言的,这使得我们可以在Markdown文档里面实现很多有趣的东西.现在记录在此,供自己以后参考,相信对其他人也有用. 介绍 Markdown的语法只有一个目标:作 ...
- FBReader阅读引擎支持的功能
"三十年河东,三十年河西"是一句民间谚语,它的来源是:从前黄河河道不固定,经常会改道(历史上无数次发生).某个地方原来在河的东面,若干年后,因黄河水流改道,这个地方会变为在河的西面 ...
- 12、SEO工程师指南 - 软件项目角色指南系列文章
第11章 SEO工程师 SEO工程师是一个比较新兴的职位,在实际的项目管理过程中,SEO工程师的地位相对靠后,只有在项目试运行以及运营期间才能体现出SEO工程师的作用.在项目完成之后, ...
- 从UUID想到的
1.UUID的定义 通用唯一标识符(UUID)被设计成一个在时间和空间上都独一无二的数字,常被用作唯一性标识. UUID是一个由5位十六进制数的字符串表示的128比特数字,其格式为 aaaaaaaa- ...
- 编程岗位电话面试问答Top 50[转]
原文链接:http://blog.jobbole.com/84618/ 1. 从哈希表,二叉树和链表中取元素的时间复杂度?如果你有数百万记录呢? 哈希表的时间复杂度为O(1),二叉树为O(logN) ...
- Python基础篇(八)
key words:私有变量,类静态变量,生成器,导入Python模块,r查看模块可以使用的函数,查看帮助信息,启动外部程序,集合,堆,时间模块,random模块,shelve模块,文件读取等 > ...
- flask入门与发送邮件与QQ邮箱
前言: 快两个月没写博客了, 原因是懒了, 没有最初写博客那种看到阅读量上涨, 别人给自己文章点赞后的开心. 心态也发生了不少变化. 有机会再来写写. 前两个月我去厦门某公司实习, 本着去厦门玩一玩还 ...
- demo说明
访问http://192.168.90.63:30111/face_mark/, 会看到上图的界面. 下面简单说下如何使用这个demo. 一.选择选择同一个人不同角度的五张图. 选中了五张图片.此时任 ...
- Go基础之--反射
反射:可以在运行时动态获取变量的相关信息 反射需要导入reflect 反射中重要函数的演示 反射有几下几个重要的函数:reflect.TypeOf :获取变量的类型,返回reflect.Type类型r ...
- OPP面向对象的介绍及使用
概述 面向过程与面向对象面向过程:专注于如何去解决一个问题的过程,编程特点是用一个个函数去实现过程操作,没有类与对象的概念面向对象:专注于有哪一个对象实体去解决这个问题,编程特点是:出现了一个个的类, ...