1 前言

理论上,网络层数越深,拟合效果越好。但是,层数加深也会导致梯度消失或梯度爆炸现象产生。当网络层数已经过深时,深层网络表现为“恒等映射”。实践表明,神经网络对残差的学习比对恒等关系的学习表现更好,因此,残差网络在深层模型中广泛应用。

本文以MNIST手写数字分类为例,为方便读者认识残差网络,网络中只有全连接层,没有卷积层。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下:

代码资源见-->残差网络(ResNet)案例分析

2 实验

renet.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import add,Input,Dense,Activation #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #残差块
def ResBlock(x,hidden_size1,hidden_size2):
r=Dense(hidden_size1,activation='relu')(x) #第一隐层
r=Dense(hidden_size2)(r) #第二隐层
if x.shape[1]==hidden_size2:
shortcut=x
else:
shortcut=Dense(hidden_size2)(x) #shortcut(捷径)
o=add([r,shortcut])
o=Activation('relu')(o) #激活函数
return o #残差网络
def ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y):
inputs=Input(shape=(784,))
x=ResBlock(inputs,30,30)
x=ResBlock(x,30,30)
x=ResBlock(x,20,20)
x=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=x)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=50,verbose=2,validation_data=(valid_x,valid_y))
#评估模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
print('test_loss:',pre[0],'- test_acc:',pre[1]) train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 784) 0
__________________________________________________________________________________________________
dense_1 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 30) 930 dense_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 30) 0 dense_2[0][0]
dense_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 30) 0 add_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 30) 930 activation_1[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 30) 930 dense_4[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 30) 0 dense_5[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 30) 0 add_2[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 20) 420 dense_6[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 20) 0 dense_7[0][0]
dense_8[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 20) 0 add_3[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 10) 210 activation_3[0][0]
==================================================================================================
Total params: 51,760
Trainable params: 51,760
Non-trainable params: 0

网络训练结果:

Epoch 48/50
- 1s - loss: 0.0019 - acc: 0.9999 - val_loss: 0.1463 - val_acc: 0.9706
Epoch 49/50
- 1s - loss: 0.0016 - acc: 0.9999 - val_loss: 0.1502 - val_acc: 0.9722
Epoch 50/50
- 1s - loss: 0.0013 - acc: 0.9999 - val_loss: 0.1542 - val_acc: 0.9728
test_loss: 0.16228994959965348 - test_acc: 0.9721000045537949

​ 声明:本文转自基于keras的残差网络

基于keras的残差网络的更多相关文章

  1. 基于 Keras 用 LSTM 网络做时间序列预测

    目录 基于 Keras 用 LSTM 网络做时间序列预测 问题描述 长短记忆网络 LSTM 网络回归 LSTM 网络回归结合窗口法 基于时间步的 LSTM 网络回归 在批量训练之间保持 LSTM 的记 ...

  2. BERT实战——基于Keras

    1.keras_bert 和 kert4keras keras_bert 是 CyberZHG 大佬封装好了Keras版的Bert,可以直接调用官方发布的预训练权重. github:https://g ...

  3. 使用dlib中的深度残差网络(ResNet)实现实时人脸识别

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  4. [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88)

    [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88) 个人主页--> https://xiaosongshine.github.io/ 项目g ...

  5. 残差网络ResNet笔记

    发现博客园也可以支持Markdown,就把我之前写的博客搬过来了- 欢迎转载,请注明出处:http://www.cnblogs.com/alanma/p/6877166.html 下面是正文: Dee ...

  6. 深度残差网络(DRN)ResNet网络原理

    一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...

  7. Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。

    如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...

  8. 关于深度残差网络(Deep residual network, ResNet)

    题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...

  9. 深度残差网络——ResNet学习笔记

    深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...

  10. Resnet——深度残差网络(一)

    我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...

随机推荐

  1. 如何查看centos对于 TIME_WAIT 状态的 Socket 回收时间

    要查看系统对于 TIME_WAIT 状态的 Socket 回收时间,可以通过以下方式查询 TCP 数据结构中的相关字段值: cat /proc/sys/net/ipv4/tcp_fin_timeout ...

  2. MySQL高可用九种方案

    有的时候博客内容会有变动,首发博客是最新的,其他博客地址可能会未同步,认准https://blog.zysicyj.top 首发博客地址 参考视频 MMM 方案(单主) MySQL 高可用方案之 MM ...

  3. [转帖]美国出口管制法律制度及中国企业风险防范——EAR核心内容解读

    http://bzy.scjg.jl.gov.cn/wto/zszc/myxgzs/202202/t20220221_636006.html 发布时间:2022-01-18 一.<美国出口管理条 ...

  4. [转帖]Jmeter学习笔记(六)——使用badboy录制脚本

    https://www.cnblogs.com/pachongshangdexuebi/p/11506274.html 1.下载安装 可以去badboy官网下载地址:http://www.badboy ...

  5. [转帖]【P1】Jmeter 准备工作

    文章目录 一.Jmeter 介绍 1.1.Jmeter 有什么样功能 1.2.Jmeter 与 LoadRunner 比较 1.3.常用性能测试工具 1.4.性能测试工具如何选型 1.5.学习 Jme ...

  6. [转帖]《Linux性能优化实战》笔记(22)—— 网络丢包问题分析

    所谓丢包,是指在网络数据的收发过程中,由于种种原因,数据包还没传输到应用程序中,就被丢弃了.这些被丢弃包的数量,除以总的传输包数,也就是我们常说的丢包率.丢包率是网络性能中最核心的指标之一.丢包通常会 ...

  7. 华城金锐申威SW64服务器重装过程

    华城金锐申威SW64服务器重装过程 背景 这边为了进行兼容性验证新进了两套申威的服务器. 一台机器带着安装好的操作系统了. 但是另外一套没有对应的系统. 端午期间想着趁着上班的人少, 加吧给处理一下. ...

  8. 【转帖】16.JVM栈帧内部结构-局部变量表

    目录 1.局部变量表(Local variables) 1.局部变量表(Local variables) 1.局部变量表也称为局部变量数组或本地变量表. 2.局部变量表定义为一个数字数组,主要用于存储 ...

  9. [转帖]MIPS和ARM授权差异引起的龙芯路线变迁

    https://zhuanlan.zhihu.com/p/99807721 一.MIPS和ARM授权的异同 MIPS授权和ARM授权都分为处理器核授权(Core License)和架构授权(Archi ...

  10. [转帖]02-rsync备份方式

    https://developer.aliyun.com/article/885789?spm=a2c6h.24874632.expert-profile.283.7c46cfe9h5DxWK 简介: ...