基于keras的残差网络
1 前言
理论上,网络层数越深,拟合效果越好。但是,层数加深也会导致梯度消失或梯度爆炸现象产生。当网络层数已经过深时,深层网络表现为“恒等映射”。实践表明,神经网络对残差的学习比对恒等关系的学习表现更好,因此,残差网络在深层模型中广泛应用。
本文以MNIST手写数字分类为例,为方便读者认识残差网络,网络中只有全连接层,没有卷积层。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类
笔者工作空间如下:

代码资源见-->残差网络(ResNet)案例分析
2 实验
renet.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import add,Input,Dense,Activation
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#残差块
def ResBlock(x,hidden_size1,hidden_size2):
r=Dense(hidden_size1,activation='relu')(x) #第一隐层
r=Dense(hidden_size2)(r) #第二隐层
if x.shape[1]==hidden_size2:
shortcut=x
else:
shortcut=Dense(hidden_size2)(x) #shortcut(捷径)
o=add([r,shortcut])
o=Activation('relu')(o) #激活函数
return o
#残差网络
def ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y):
inputs=Input(shape=(784,))
x=ResBlock(inputs,30,30)
x=ResBlock(x,30,30)
x=ResBlock(x,20,20)
x=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=x)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=50,verbose=2,validation_data=(valid_x,valid_y))
#评估模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
print('test_loss:',pre[0],'- test_acc:',pre[1])
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y)
网络各层输出尺寸:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 784) 0
__________________________________________________________________________________________________
dense_1 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 30) 930 dense_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 30) 0 dense_2[0][0]
dense_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 30) 0 add_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 30) 930 activation_1[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 30) 930 dense_4[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 30) 0 dense_5[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 30) 0 add_2[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 20) 420 dense_6[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 20) 0 dense_7[0][0]
dense_8[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 20) 0 add_3[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 10) 210 activation_3[0][0]
==================================================================================================
Total params: 51,760
Trainable params: 51,760
Non-trainable params: 0
网络训练结果:
Epoch 48/50
- 1s - loss: 0.0019 - acc: 0.9999 - val_loss: 0.1463 - val_acc: 0.9706
Epoch 49/50
- 1s - loss: 0.0016 - acc: 0.9999 - val_loss: 0.1502 - val_acc: 0.9722
Epoch 50/50
- 1s - loss: 0.0013 - acc: 0.9999 - val_loss: 0.1542 - val_acc: 0.9728
test_loss: 0.16228994959965348 - test_acc: 0.9721000045537949
声明:本文转自基于keras的残差网络
基于keras的残差网络的更多相关文章
- 基于 Keras 用 LSTM 网络做时间序列预测
目录 基于 Keras 用 LSTM 网络做时间序列预测 问题描述 长短记忆网络 LSTM 网络回归 LSTM 网络回归结合窗口法 基于时间步的 LSTM 网络回归 在批量训练之间保持 LSTM 的记 ...
- BERT实战——基于Keras
1.keras_bert 和 kert4keras keras_bert 是 CyberZHG 大佬封装好了Keras版的Bert,可以直接调用官方发布的预训练权重. github:https://g ...
- 使用dlib中的深度残差网络(ResNet)实现实时人脸识别
opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...
- [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88)
[深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88) 个人主页--> https://xiaosongshine.github.io/ 项目g ...
- 残差网络ResNet笔记
发现博客园也可以支持Markdown,就把我之前写的博客搬过来了- 欢迎转载,请注明出处:http://www.cnblogs.com/alanma/p/6877166.html 下面是正文: Dee ...
- 深度残差网络(DRN)ResNet网络原理
一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...
- Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。
如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...
- 关于深度残差网络(Deep residual network, ResNet)
题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...
- 深度残差网络——ResNet学习笔记
深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...
- Resnet——深度残差网络(一)
我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...
随机推荐
- Clock Domain Crossing
Clock Domain Crossing CDC问题主要有亚稳态问题,多比特信号同步,握手信号同步,异步Fifo等 Topics Describe the SoC Design Issues Und ...
- 08-任务Task和函数Function
任务Task和函数Function 类似于c语言中的函数 Task task 含有input\output\inout语句 task消耗仿真时间 task中可以写延迟:#20 延迟20个仿真时间单位 ...
- 2023强网拟态crypto-一眼看出
1.题目信息 一眼看穿 查看代码 from Crypto.Util.number import * from secret import flag import gmpy2 flag=b'' r = ...
- Laravel - 模板中的url
<!-- 1, url --> <a href="{{url('/')}}">跳转到主页</a> <!-- 2,action 方法 ...
- Redis之入门概括与指令
Redis特点(AP模型,优先保证可用,不会管数据丢失): 快的原因: 基于内存操作,操作不需要跟磁盘交互 k-v结构,类似与hashMap,所以查询速度非常快,接近O(1). 底层数据结构是有如:跳 ...
- [转帖]gooyfs 的编译 github
https://github.com/kahing/goofys/issues/527 @maobaolong @PengleiShi I had the same issue as you guys ...
- [转帖]Redis之安全措施
指令安全 Redis的一些指令会对Redis服务的稳定性及安全性各方面造成影响,例如keys指令在数据量大的情况下会导致Redis卡顿,flushdb和flushall会导致Redis的数据被清空. ...
- OpenSSH 9.2P1升级以及版本显示的处理过程
说明 本次维护的时间是 2023-2-9 最新已发布的补丁是 OpenSSH9.2P1版本 其他本本应该是类似处理. 下载介质 在 OpenSSH官网打开相关界面. http://www.openss ...
- 批量删除一个月为tag的镜像的办法
第一步获取镜像列表 这是一个最简单的列转行. docker images |grep 20220401 |awk 'BEGIN{ORS=","}{print $1}' 第二步执行双 ...
- Debian 安装vim 提示版本问题的处理
https://blog.csdn.net/Oil__/article/details/113384278 purge 还有 --allow-remove-essential 安装失败提示解决方法安装 ...