CNN实现

概述

我在qwe中有两种,第一种是按照Ng课程中的写法,多层循环嵌套得到每次的“小方格”,然后WX+b,这样的做法是最简单,直观。但是效率极其慢。基本跑个10张以内图片都会卡的要死。

第二种方法是使用img2col,将其转换为对应的矩阵,然后直接做一次矩阵乘法运算。

先看第一种

def forward(self, X):
m, n_H_prev, n_W_prev, n_C_prev = X.shape
(f, f, n_C_prev, n_C) = self.W.shape
n_H = int((n_H_prev - f + 2 * self.pad) / self.stride) + 1
n_W = int((n_W_prev - f + 2 * self.pad) / self.stride) + 1
n_H, n_W, n_C = self.output_size Z = np.zeros((m, n_H, n_W, n_C))
X_pad = zero_pad(X, self.pad)
for i in range(m):
for h in range(n_H):
for w in range(n_W):
for c in range(n_C):
vert_start = h * self.stride
vert_end = vert_start + f
horiz_start = w * self.stride
horiz_end = horiz_start + f
A_slice_prev =X_pad[i,vert_start:vert_end, horiz_start:horiz_end, :]
Z[i,h,w,c] = conv_single_step(A_slice_prev, self.W[...,c], self.b[...,c]) def conv_single_step(X, W, b):
# 对一个裁剪图像进行卷积
# X.shape = f, f, prev_channel_size
return np.sum(np.multiply(X, W) + b)

对于m,n_H,n_W,n_C循环就是取得裁剪小方块,可以看到这里的计算复杂度m * n_H * n_W * n_C * (f*f的矩阵计算)

第二种方法,先转换成大矩阵,再进行一次矩阵运算,相当于节省了多次小矩阵运算时间,这还是很可观的,能查个几十倍的速度。

img2col原理很简单,详情可参考caffe im2col

就是循环将每一部分都拉长成一维矩阵拼凑起来。

对于CNN来说,H就是要计算方块的个数即m(样本数) n_H(最终生成图像行数)n_W(最终生成图像列数),W就是f(核kernel长)f(核宽)*(输入样本通道输)

然后还要把参数矩阵W也拉成这个样子,H就是f(核长)f(核宽)(输入样本通道输),W列数就是核数kernel_size

如下图



def img2col(X, pad, stride, f):
pass
ff = f * f
m, n_H_prev, n_W_prev, n_C_prev= X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
Z = np.zeros((m * n_H * n_W, f * f * n_C_prev))
X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=0)
row = -1 for i in range(m):
for h in range(n_H):
for w in range(n_W):
row += 1
vert_start = h * stride
horiz_start = w * stride
for col in range(f * f * n_C_prev):
t = col // n_C_prev
hh = t // f
ww = t % f
cc = col % n_C_prev
Z[row, col] = X_pad[i, vert_start + hh, horiz_start + ww, cc] def speed_forward(model, X):
W = model.W
b = model.b
stride = model.stride
pad = model.pad
(n_C_prev, f, f, n_C) = W.shape
m, n_H_prev, n_W_prev, n_C_prev = X.shape n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1 # WW = W.swapaxes(2,1)
# WW = WW.swapaxes(1,0) XX = img2col(X, pad, stride, f)
# WW = WW.reshape(f*f*n_C_prev, n_C)
WW = W.reshape(f*f*n_C_prev, n_C)
model.XX = XX
model.WW = WW Z = np.dot(XX, WW) + b
return Z.reshape(m, n_H, n_W, n_C)

这种耗时操作,最好使用Cython扩展来写,不然速度还是不够理想。Cython扩展代码code

反向传播同理,具体代码参考

github

qwe框架- CNN 实现的更多相关文章

  1. 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)

    1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...

  2. 深蓝色 --ppt

    Deep Learning of Binary Hash Codes for Fast Image Retrieval [Paper] [Code-Caffe] 1. 摘要 针对图像检索问题,提出简单 ...

  3. [基础]Deep Learning的基础概念

    目录 DNN CNN DNN VS CNN Example 卷积的好处why convolution? DCNN 卷积核移动的步长 stride 激活函数 active function 通道 cha ...

  4. qwe 简易深度框架

    qwe github地址 简介 简单的深度框架,参考Ng的深度学习课程作业,使用了keras的API设计. 方便了解网络具体实现,避免深陷于成熟框架的细节和一些晦涩的优化代码. 网络层实现了Dense ...

  5. 【深度学习系列3】 Mariana CNN并行框架与图像识别

    [深度学习系列3] Mariana CNN并行框架与图像识别 本文是腾讯深度学习系列文章的第三篇,聚焦于腾讯深度学习平台Mariana中深度卷积神经网络Deep CNNs的多GPU模型并行和数据并行框 ...

  6. 卷积神经网络CNN与深度学习常用框架的介绍与使用

    一.神经网络为什么比传统的分类器好 1.传统的分类器有 LR(逻辑斯特回归) 或者 linear SVM ,多用来做线性分割,假如所有的样本可以看做一个个点,如下图,有蓝色的点和绿色的点,传统的分类器 ...

  7. 我所写的CNN框架 VS caffe

    我所写的CNN框架 VS caffe 一个月前.自己模仿caffe实现了一个卷积神经网络的框架. 同样点 1无缝支持CPU和GPU模式,GPU模式使用cuda实现. 不同点 1我的CNN不依赖与不论什 ...

  8. ubuntu之路——day19.2 开源框架与迁移、CNN中的数据扩充

    开源框架与迁移 上面介绍了一些已经取得很好成绩的CNN框架,我们可以直接从GitHub上下载这些神经网络的结构和已经在ImageNet等数据集上训练好的权重超参数. 在应用于我们自己的数据时. 1.如 ...

  9. CNN基础框架简介

    卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...

随机推荐

  1. 20165318 预备作业二 学习基础和C语言基础调查

    20165318 学习基础和C语言基础调查 技能学习经验 我们这一代人,或多或少的都上过各种兴趣班,舞蹈钢琴画画书法,我也是如此.可这些技能中,唯一能拿的出手的就是舞蹈了.按照<优秀的教学方法- ...

  2. js函数知识

    1.函数基本知识 通过函数可以封装任意条语句,在任何地方调用,js中用function关键字来声明, //基本格式,函数名,传递参数,代码块 function functionName(arg0,ar ...

  3. Jmeter性能测试,新手上路篇

    1. JMeter简介 Apache JMeter是Apache组织开发的基于Java的压力测试工具.用于对软件做压力测试,它最初被设计用于Web应用测试,但后来扩展到其他测试领域. 它可以用于测试静 ...

  4. SQL性能优化的几点建议

    1. 索引:索引可以提高查询的速度,但不是使用带有索引的字段查询时,索引都会起作用,如下几种特殊情况下,有可能使用带有索引的字段查询时,索引没有起作用:1)使用LIKE关键字的查询语句 如果匹配字符串 ...

  5. Spring基础篇——bean的自动化装配

    上篇博文讲Spring的IOC容器时说道,虽然容器功能强大,但容器本身只是个空壳,需要我们主动放入装配对象,并告诉它对象之间的协作关系,然后容器才能按照我们的指示发挥它的魔力,完成装配bean的使命. ...

  6. 蓝桥杯 基础练习 之 FJ的字符串

    问题描述 FJ在沙盘上写了这样一些字符串: A1 = "A" A2 = "ABA" A3 = "ABACABA" A4 = "AB ...

  7. Go学习笔记03-附录

    第三部分 附录 A. 工具 1. 工具集 1.1 go build gcflags ldflags 更多参数: go tool 6g -h 或 [https://golang.org/cmd/gc/] ...

  8. Go学习笔记01-语言

    1.1 变量 Go 是静态类型语言,不能在运行期改变变量类型.使用关键字 var 定义变量,自动初始化为零值.如果提供初始化值,可省略变量类型,由编译器自动推断. var x int var f fl ...

  9. OpenCV角点检测goodFeaturesToTrack()源代码分析

    上面一篇博客分析了HARRIS和ShiTomasi角点检测的源代码.而为了提取更准确的角点,OpenCV中提供了goodFeaturesToTrack()这个API函数,来获取更加准确的角点位置.这篇 ...

  10. Linux一些常用操作

    1.linux swap分区 可采用文件的方式 dd if=/dev/zero of=/var/swap bs=1024 count=2048000 mkswap /var/swap swapon / ...