基于keras的残差网络
1 前言
理论上,网络层数越深,拟合效果越好。但是,层数加深也会导致梯度消失或梯度爆炸现象产生。当网络层数已经过深时,深层网络表现为“恒等映射”。实践表明,神经网络对残差的学习比对恒等关系的学习表现更好,因此,残差网络在深层模型中广泛应用。
本文以MNIST手写数字分类为例,为方便读者认识残差网络,网络中只有全连接层,没有卷积层。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类
笔者工作空间如下:

代码资源见-->残差网络(ResNet)案例分析
2 实验
renet.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import add,Input,Dense,Activation
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#残差块
def ResBlock(x,hidden_size1,hidden_size2):
r=Dense(hidden_size1,activation='relu')(x) #第一隐层
r=Dense(hidden_size2)(r) #第二隐层
if x.shape[1]==hidden_size2:
shortcut=x
else:
shortcut=Dense(hidden_size2)(x) #shortcut(捷径)
o=add([r,shortcut])
o=Activation('relu')(o) #激活函数
return o
#残差网络
def ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y):
inputs=Input(shape=(784,))
x=ResBlock(inputs,30,30)
x=ResBlock(x,30,30)
x=ResBlock(x,20,20)
x=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=x)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=50,verbose=2,validation_data=(valid_x,valid_y))
#评估模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
print('test_loss:',pre[0],'- test_acc:',pre[1])
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y)
网络各层输出尺寸:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 784) 0
__________________________________________________________________________________________________
dense_1 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 30) 930 dense_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 30) 0 dense_2[0][0]
dense_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 30) 0 add_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 30) 930 activation_1[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 30) 930 dense_4[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 30) 0 dense_5[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 30) 0 add_2[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 20) 420 dense_6[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 20) 0 dense_7[0][0]
dense_8[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 20) 0 add_3[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 10) 210 activation_3[0][0]
==================================================================================================
Total params: 51,760
Trainable params: 51,760
Non-trainable params: 0
网络训练结果:
Epoch 48/50
- 1s - loss: 0.0019 - acc: 0.9999 - val_loss: 0.1463 - val_acc: 0.9706
Epoch 49/50
- 1s - loss: 0.0016 - acc: 0.9999 - val_loss: 0.1502 - val_acc: 0.9722
Epoch 50/50
- 1s - loss: 0.0013 - acc: 0.9999 - val_loss: 0.1542 - val_acc: 0.9728
test_loss: 0.16228994959965348 - test_acc: 0.9721000045537949
声明:本文转自基于keras的残差网络
基于keras的残差网络的更多相关文章
- 基于 Keras 用 LSTM 网络做时间序列预测
目录 基于 Keras 用 LSTM 网络做时间序列预测 问题描述 长短记忆网络 LSTM 网络回归 LSTM 网络回归结合窗口法 基于时间步的 LSTM 网络回归 在批量训练之间保持 LSTM 的记 ...
- BERT实战——基于Keras
1.keras_bert 和 kert4keras keras_bert 是 CyberZHG 大佬封装好了Keras版的Bert,可以直接调用官方发布的预训练权重. github:https://g ...
- 使用dlib中的深度残差网络(ResNet)实现实时人脸识别
opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...
- [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88)
[深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88) 个人主页--> https://xiaosongshine.github.io/ 项目g ...
- 残差网络ResNet笔记
发现博客园也可以支持Markdown,就把我之前写的博客搬过来了- 欢迎转载,请注明出处:http://www.cnblogs.com/alanma/p/6877166.html 下面是正文: Dee ...
- 深度残差网络(DRN)ResNet网络原理
一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...
- Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。
如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...
- 关于深度残差网络(Deep residual network, ResNet)
题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...
- 深度残差网络——ResNet学习笔记
深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...
- Resnet——深度残差网络(一)
我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...
随机推荐
- std::istringstream的用法
1.概要 std::istringstream 是 C++ 标准库中的一个类,它用于从字符串中提取数据,并将数据转换为不同的数据类型.它通常用于从字符串中解析数据,例如整数.浮点数等.以下是关于 st ...
- 【ThreadX-NetX】Azure RTOS NetX概述
Azure RTOS NetX是工业级TCP / IP IPv4嵌入式网络堆栈,专门针对深度嵌入式,实时和IoT应用程序而设计.Azure RTOS NetX是Microsoft最初的IPv4网络堆栈 ...
- 百度网盘(百度云)SVIP超级会员共享账号每日更新(2023.12.15)
一.百度网盘SVIP超级会员共享账号 可能很多人不懂这个共享账号是什么意思,小编在这里给大家做一下解答. 我们多知道百度网盘很大的用处就是类似U盘,不同的人把文件上传到百度网盘,别人可以直接下载,避免 ...
- [转帖]能使 Oracle 索引失效的六大限制条件
Oracle 索引的目标是避免全表扫描,提高查询效率,但有些时候却适得其反. 例如一张表中有上百万条数据,对某个字段加了索引,但是查询时性能并没有什么提高,这可能是 oracle 索引失效造成的.or ...
- [转帖]sqluldr2 oracle直接导出数据为文本的小工具使用
https://www.cnblogs.com/ocp-100/p/11098373.html 近期客户有需求,导出某些审计数据,供审计人进行核查,只能导出成文本或excel格式的进行查看,这里我们使 ...
- [转帖]JVM(3)之垃圾回收(GC垃圾收集器+垃圾回收算法+安全点+记忆集与卡表+并发可达性分析......)
<深入理解java虚拟机>+宋红康老师+阳哥大厂面试题2总结整理 一.堆的结构组成 堆位于运行时数据区中是线程共享的.一个进程对应一个jvm实例.一个jvm实例对应一个运行时数据区.一个运 ...
- [转帖]python读取配置文件获取所有键值对_python总结——处理配置文件(ConfigParser)
python处理ConfigParser 使用ConfigParser模块读写ini文件 (转载) ConfigParserPython 的ConfigParser Module中定义了3个类对INI ...
- CentOS8 安装Oracle19c RPM的办法
1. 下载相应的rpm包 我这边使用的主要有: -rw-r--r-- 1 root root 19112 Apr 5 15:13 compat-libcap1-1.10-7.el7.x86_64.rp ...
- Linux 排除某些目录下 重复jar包的方法
Linux 排除某些目录下 取重复jar包的方法 find . -path ./runtime/java -prune -o -name '*.jar' -exec basename {} \;| s ...
- WebAssembly入门笔记[2]:利用Memory传递数据
利用灵活的"导入"和"导出"机制,WebAssembly与承载的JavaScript应用之间可以很便利地"互通有无".<与JavaSc ...