Resnet——深度残差网络(二)
基于上一篇resnet网络结构进行实战。
再来贴一下resnet的基本结构方便与代码进行对比

resnet的自定义类如下:
import tensorflow as tf
from tensorflow import keras class BasicBlock(keras.layers.Layer): # filter_num指定通道数,stride指定步长
def __init__(self,filter_num,stride=1):
super(BasicBlock, self).__init__() # 注意padding=same并不总使得输入维度等于输出维度,而是对不同的步长有不同的策略,使得滑动更加完整
self.conv1 = keras.layers.Conv2D(filter_num,(3,3),strides=stride,padding='same')
self.bn1 = keras.layers.BatchNormalization()
self.relu = keras.layers.Activation('relu') self.conv2 = keras.layers.Conv2D(filter_num,(3,3),strides=1,padding='same')
self.bn2 = keras.layers.BatchNormalization() if stride!=1:
self.dowmsample = keras.Sequential()
self.dowmsample.add(keras.layers.Conv2D(filter_num,(1,1),strides=stride))
else:
self.dowmsample = lambda x:x def call(self, inputs, training=None): out = self.conv1(inputs)
out = self.bn1(out)
out = self.relu(out) out = self.conv2(out)
out = self.bn2(out) identity = self.dowmsample(inputs) output = keras.layers.add([out,identity])
output = tf.nn.relu(output) return output class ResNet(keras.Model): # resnet基本结构为[2,2,2,2],即分为四个部分,每个部分又分两个小部分
def __init__(self,layer_dims,num_classes=100):
super(ResNet,self).__init__() # 预处理层
self.stem = keras.Sequential([
keras.layers.Conv2D(64,(3,3),strides=(1,1)),
keras.layers.BatchNormalization(),
keras.layers.Activation('relu'),
keras.layers.MaxPool2D(pool_size=(2,2),strides=(1,1),padding='same')
]) self.layer1 = self.build_resblock(64,layer_dims[0])
self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) # 自适应输出,方便送入全连层进行分类
self.avgpool = keras.layers.GlobalAveragePooling2D()
self.fc = keras.layers.Dense(num_classes) def call(self, inputs, training=None):
x = self.stem(inputs) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x) x = self.avgpool(x)
x = self.fc(x) return x def build_resblock(self,filter_num,blocks,stride=1):
res_blocks = keras.Sequential();
res_blocks.add(BasicBlock(filter_num,stride)) for _ in range(1,blocks):
res_blocks.add(BasicBlock(filter_num,1)) return res_blocks def resnet18():
return ResNet([2,2,2,2])
训练过程如下:
import tensorflow as tf
from tensorflow import keras
import os
from resnet import resnet18 os.environ['TF_CPP_MIN_LOG'] = '' def preprocess(x,y):
x = 2*tf.cast(x,dtype=tf.float32)/255.-1
y = tf.cast(y,dtype=tf.int32)
return x,y (x,y),(x_test,y_test) = keras.datasets.cifar100.load_data()
y = tf.squeeze(y,axis=1)
y_test = tf.squeeze(y_test,axis=1)
print(x.shape,y.shape,x_test.shape,y_test.shape) train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(64) test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = train_db.map(preprocess).batch(64) def main():
model = resnet18()
model.build(input_shape=(None,32,32,3))
optimizer = keras.optimizers.Adam(lr=1e-3)
model.summary() for epoch in range(50):
for step,(x,y) in enumerate(train_db):
with tf.GradientTape() as tape:
logits = model(x)
y_onehot = tf.one_hot(y,depth=10)
loss = tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss = tf.reduce_mean(loss) gradient = tape.gradient(loss,model.trainable_variables)
optimizer.apply_gradients(zip(gradient,model.trainable_variables)) if step % 100 == 0:
print(epoch,step,'loss:',float(loss)) total_num = 0
total_correct = 0
for x,y in test_db:
logits = model(x)
prob = tf.nn.softmax(logits,axis=1)
pred = tf.argmax(prob,axis=1)
pred = tf.cast(pred,dtype=tf.int32) correct = tf.cast(tf.equal(pred,y),dtype=tf.int32)
correct = tf.reduce_sum(correct) total_num += x.shape[0]
total_correct += correct
acc = total_correct/total_num print("acc:",acc) if __name__ == '__main__':
main()
打印网络结构和参数量如下:

Resnet——深度残差网络(二)的更多相关文章
- Resnet——深度残差网络(一)
我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...
- 使用dlib中的深度残差网络(ResNet)实现实时人脸识别
opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...
- 深度残差网络(DRN)ResNet网络原理
一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...
- Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。
如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...
- CNN卷积神经网络_深度残差网络 ResNet——解决神经网络过深反而引起误差增加的根本问题,Highway NetWork 则允许保留一定比例的原始输入 x。(这种思想在inception模型也有,例如卷积是concat并行,而不是串行)这样前面一层的信息,有一定比例可以不经过矩阵乘法和非线性变换,直接传输到下一层,仿佛一条信息高速公路,因此得名Highway Network
from:https://blog.csdn.net/diamonjoy_zone/article/details/70904212 环境:Win8.1 TensorFlow1.0.1 软件:Anac ...
- 关于深度残差网络(Deep residual network, ResNet)
题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...
- 深度残差网络(ResNet)
引言 对于传统的深度学习网络应用来说,网络越深,所能学到的东西越多.当然收敛速度也就越慢,训练时间越长,然而深度到了一定程度之后就会发现越往深学习率越低的情况,甚至在一些场景下,网络层数越深反而降低了 ...
- 深度残差网络——ResNet学习笔记
深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...
- ResNet(深度残差网络)
注:平原改为简单堆叠网络 一般x是恒等映射,当x与fx尺寸不同的时候,w作用就是将x变成和fx尺寸相同. 过程: 先用w将x进行恒等映射.扩维映射或者降维映射d得到wx.(没有参数,不需要优化器训练) ...
随机推荐
- 《高性能MySQL》之MySQL查询性能优化
为什么查询会慢? 响应时间过长.如果把查询看做是一个任务,那么它由一系列子任务组成,每个子任务都会消耗一定的时间.如果要优化查询,实际上优化其子任务,要么消除其中一些子任务,要么减少子任务的执行次数, ...
- redis--->事务和锁
redis 的事务.锁.流水线 Redis与 mysql事务的对比 开启 mysql:start transaction redis:multi 语句:mysql:普通sql redis:普通命令 成 ...
- 上线前一个小时,dubbo这个问题可把我折腾惨了
前因 那是一个月黑风高的夜晚,不管有没有圆圆的月亮,都无法解救要加班的我.这就是苦涩的人生啊! 那天正好是春节回家的日子,定了晚上的票,然后还是上线的日子. 测试在做回归测试的时候,发现一个老功能报错 ...
- Java常见问题汇总
1.String,StringBuffer,StringBulider的区别及应用场景 2.Servlet生命周期 3.向上转型与向下转型 4.Java的多态性 5.重写和重载的区别 6.深拷贝和浅拷 ...
- Iptables和Firewall-selinux
一.Iptables防火墙 ---------- **三表五链:**三表: filter过滤表 nat转换表 mangle表五链: PREROUTING--->在进行路由选择前处理数据包 INP ...
- SVN : 在SVN检测下来的Maven项目没有Maven标志
在Ecplise使用import->从SVN检出项目, 检出的项目没有了 Maven标志 解决方案 右键点击项目->configure->Convert to Maven Proje ...
- python中Threadlocal变量
在多线程环境下,每个线程都有自己的数据.一个线程使用自己的局部变量比使用全局变量好,因为局部变量只有线程自己能看见,不会影响其他线程,而全局变量的修改必须加锁. 不加锁就会出现变量会被修改的问题,进而 ...
- Django项目在Linux服务器上部署和躺过的坑
引言 在各方的推荐下,领导让我在测试环境部署之前开发的测试数据预报平台.那么问题来了,既然要在服务器上部署, 就需要准备: 1.linux服务器配置 2.linux安装python环境搭建与配置 3. ...
- java(list,set,map)链接
http://blog.csdn.net/smileiam/article/details/49836865 http://blog.csdn.net/u013344815/article/detai ...
- HDU_4456_二维树状数组
http://acm.hdu.edu.cn/showproblem.php?pid=4456 第一道二维树状数组就这么麻烦,题目要计算的是一个菱形范围内的和,于是可以把原来的坐标系旋转45度,就是求一 ...