一、项目简介

手动实现mini深度学习框架,主要精力不放在运算优化上,仅体会原理。

地址见:miniDeepFrame

相关博客

『TensorFlow』卷积层、池化层详解
『科学计算』全连接层、均方误差、激活函数实现

文件介绍

Layer.py 层 class,已实现:全连接层,卷积层,平均池化层
Loss.py 损失函数 class,已实现:均方误差损失函数
Activate.py 激活函数 class,已实现:sigmoid、tanh、relu
test.py 训练测试代码

主流框架对于卷积相关层的实现都是基于矩阵乘法运算,而非这里的多层for循环。由于计算机计算矩阵乘法速度非常快,所以这是一个虽然提高内存消耗但是计算速度显著上升的方法,把feature map中的感受野(包含重叠的部分,所以会加大内存消耗)和卷积核全部拉伸成为向量,组成两个矩阵相乘,再想办法恢复为输出的feature map(详见『TensorFlow』卷积层、池化层详解)。

二、测试输出

我们此时不对层函数进行封装,仅仅实现了最简单的前向传播、反向传播、参数获取几个功能,利用这些功能,我们已经可以实现一个最简单的神经网络,

声明并初始化各层class的实例,这会使得各个实例初始化可学习参数

(【注】一般的框架会在运行时,即第一次前向传播时才初始化参数,本demo由于是动态的,所以没必要这样写)

进入循环体:

  获取数据,向前传播,计算损失函数&损失函数的梯度

  向后传播,获取各个参数的梯度

  对参数循环,利用参数梯度更新参数

在test.py中,我们使用tensorflow的接口,下载并读取mnist数据集,然后训练一个10分类的分类器,观察收敛过程。

损失函数收敛展示

实际运行test.py,会输出loss函数结果,并绘制成图,左图展示了整个loss函数收敛过程,

实际训练并查看中间输出可以看见,最开始几次训练的损失函数下降的极快,相应的梯度值如果添加了中间的输出也会极大(10^3量级,对应的参数初始化为-1~1之间),于是下图截掉了前四次迭代输出的Loss,能够更好的展示后面的收敛过程:

『计算机视觉』mini深度学习框架实现的更多相关文章

  1. 『计算机视觉』Mask-RCNN_从服装关键点检测看KeyPoints分支

    下图Github地址:Mask_RCNN       Mask_RCNN_KeyPoints『计算机视觉』Mask-RCNN_论文学习『计算机视觉』Mask-RCNN_项目文档翻译『计算机视觉』Mas ...

  2. 『计算机视觉』Mask-RCNN_训练网络其三:训练Model

    Github地址:Mask_RCNN 『计算机视觉』Mask-RCNN_论文学习 『计算机视觉』Mask-RCNN_项目文档翻译 『计算机视觉』Mask-RCNN_推断网络其一:总览 『计算机视觉』M ...

  3. 『计算机视觉』Mask-RCNN_训练网络其二:train网络结构&损失函数

    Github地址:Mask_RCNN 『计算机视觉』Mask-RCNN_论文学习 『计算机视觉』Mask-RCNN_项目文档翻译 『计算机视觉』Mask-RCNN_推断网络其一:总览 『计算机视觉』M ...

  4. 『计算机视觉』Mask-RCNN_训练网络其一:数据集与Dataset类

    Github地址:Mask_RCNN 『计算机视觉』Mask-RCNN_论文学习 『计算机视觉』Mask-RCNN_项目文档翻译 『计算机视觉』Mask-RCNN_推断网络其一:总览 『计算机视觉』M ...

  5. 『计算机视觉』Mask-RCNN_锚框生成

    Github地址:Mask_RCNN 『计算机视觉』Mask-RCNN_论文学习 『计算机视觉』Mask-RCNN_项目文档翻译 『计算机视觉』Mask-RCNN_推断网络其一:总览 『计算机视觉』M ...

  6. 『计算机视觉』FPN:feature pyramid networks for object detection

    对用卷积神经网络进行目标检测方法的一种改进,通过提取多尺度的特征信息进行融合,进而提高目标检测的精度,特别是在小物体检测上的精度.FPN是ResNet或DenseNet等通用特征提取网络的附加组件,可 ...

  7. 『计算机视觉』经典RCNN_其二:Faster-RCNN

    项目源码 一.Faster-RCNN简介 『cs231n』Faster_RCNN 『计算机视觉』Faster-RCNN学习_其一:目标检测及RCNN谱系 一篇讲的非常明白的文章:一文读懂Faster ...

  8. 28款GitHub最流行的开源机器学习项目,推荐GitHub上10 个开源深度学习框架

    20 个顶尖的 Python 机器学习开源项目 机器学习 2015-06-08 22:44:30 发布 您的评价: 0.0 收藏 1收藏 我们在Github上的贡献者和提交者之中检查了用Python语 ...

  9. Cs231n课堂内容记录-Lecture 8 深度学习框架

    Lecture 8  Deep Learning Software 课堂笔记参见:https://blog.csdn.net/u012554092/article/details/78159316 今 ...

随机推荐

  1. 对text字段聚合,没有设置fielddate所以出错

    http://192.168.60.26:9200/linewell_assets_mgt_es_yh_test/lw_devices/ _mapping { "properties&quo ...

  2. 探究Java中的锁

    一.锁的作用和比较 1.Lock接口及其类图 Lock接口:是Java提供的用来控制多个线程访问共享资源的方式. ReentrantLock:Lock的实现类,提供了可重入的加锁语义 ReadWrit ...

  3. Please run SwitchHosts! as an Administrator 原因

    github 访问慢的初期,不得已修改host,但直接修改host的文件太不够灵活了,使用switchhost工具. win10 遇到上述问题如这个地址 要撞墙了. 解决方法: 进入 C:\Windo ...

  4. git常用操作命令使用说明

    设置用户名和邮箱 git config --global user.email 'xxx' git config --global user.name 'xxx' 创建分支 git branch xx ...

  5. spring 事务注解

    在spring中使用事务需要遵守一些规范和了解一些坑点,别想当然.列举一下一些注意点. 在需要事务管理的地方加@Transactional 注解.@Transactional 注解可以被应用于接口定义 ...

  6. [py]一致性hash原理

    1,可变,不可变 python中值得是引用地址是否变化. 2.可hash 生命周期里不可变得值都可hash 3.python中内置数据结构特点 有序不可变 有序可变 无序可变 无序不可变 5.一致性h ...

  7. Java过关测验

    库存物资管理系统一.背景资料:1.有一个存放商品的仓库,每天都有商品出库和入库.2.每种商品都有名称.生产厂家.型号.规格等.3.出入库时必须填写出入库单据,单据包括商品名称.生产厂家.型号.规格.数 ...

  8. Oracle sqlplus的输出表的排版,数据表发生折行问题

    当查寻数据表的时候,会发生折行的问题 这时,我们可以用下面的语句 设置每行显示的记录长度:set    linesize    300;    --->  每行显示300个字符. 设置每页显示的 ...

  9. Java代码质量改进之:同步对象的选择

    在Java中,让线程同步的一种方式是使用synchronized关键字,它可以被用来修饰一段代码块,如下: synchronized(被锁的同步对象) { // 代码块:业务代码 } 当synchro ...

  10. Spring Boot 的 application.properties

    更改默认端口:8080 server.port = 8081 更改context-path :/server.context-path = /springboot #server.address= # ...