基于上一篇resnet网络结构进行实战。

再来贴一下resnet的基本结构方便与代码进行对比

resnet的自定义类如下:

import tensorflow as tf
from tensorflow import keras class BasicBlock(keras.layers.Layer): # filter_num指定通道数,stride指定步长
def __init__(self,filter_num,stride=1):
super(BasicBlock, self).__init__() # 注意padding=same并不总使得输入维度等于输出维度,而是对不同的步长有不同的策略,使得滑动更加完整
self.conv1 = keras.layers.Conv2D(filter_num,(3,3),strides=stride,padding='same')
self.bn1 = keras.layers.BatchNormalization()
self.relu = keras.layers.Activation('relu') self.conv2 = keras.layers.Conv2D(filter_num,(3,3),strides=1,padding='same')
self.bn2 = keras.layers.BatchNormalization() if stride!=1:
self.dowmsample = keras.Sequential()
self.dowmsample.add(keras.layers.Conv2D(filter_num,(1,1),strides=stride))
else:
self.dowmsample = lambda x:x def call(self, inputs, training=None): out = self.conv1(inputs)
out = self.bn1(out)
out = self.relu(out) out = self.conv2(out)
out = self.bn2(out) identity = self.dowmsample(inputs) output = keras.layers.add([out,identity])
output = tf.nn.relu(output) return output class ResNet(keras.Model): # resnet基本结构为[2,2,2,2],即分为四个部分,每个部分又分两个小部分
def __init__(self,layer_dims,num_classes=100):
super(ResNet,self).__init__() # 预处理层
self.stem = keras.Sequential([
keras.layers.Conv2D(64,(3,3),strides=(1,1)),
keras.layers.BatchNormalization(),
keras.layers.Activation('relu'),
keras.layers.MaxPool2D(pool_size=(2,2),strides=(1,1),padding='same')
]) self.layer1 = self.build_resblock(64,layer_dims[0])
self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) # 自适应输出,方便送入全连层进行分类
self.avgpool = keras.layers.GlobalAveragePooling2D()
self.fc = keras.layers.Dense(num_classes) def call(self, inputs, training=None):
x = self.stem(inputs) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x) x = self.avgpool(x)
x = self.fc(x) return x def build_resblock(self,filter_num,blocks,stride=1):
res_blocks = keras.Sequential();
res_blocks.add(BasicBlock(filter_num,stride)) for _ in range(1,blocks):
res_blocks.add(BasicBlock(filter_num,1)) return res_blocks def resnet18():
return ResNet([2,2,2,2])

训练过程如下:

import tensorflow as tf
from tensorflow import keras
import os
from resnet import resnet18 os.environ['TF_CPP_MIN_LOG'] = '' def preprocess(x,y):
x = 2*tf.cast(x,dtype=tf.float32)/255.-1
y = tf.cast(y,dtype=tf.int32)
return x,y (x,y),(x_test,y_test) = keras.datasets.cifar100.load_data()
y = tf.squeeze(y,axis=1)
y_test = tf.squeeze(y_test,axis=1)
print(x.shape,y.shape,x_test.shape,y_test.shape) train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(64) test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = train_db.map(preprocess).batch(64) def main():
model = resnet18()
model.build(input_shape=(None,32,32,3))
optimizer = keras.optimizers.Adam(lr=1e-3)
model.summary() for epoch in range(50):
for step,(x,y) in enumerate(train_db):
with tf.GradientTape() as tape:
logits = model(x)
y_onehot = tf.one_hot(y,depth=10)
loss = tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss = tf.reduce_mean(loss) gradient = tape.gradient(loss,model.trainable_variables)
optimizer.apply_gradients(zip(gradient,model.trainable_variables)) if step % 100 == 0:
print(epoch,step,'loss:',float(loss)) total_num = 0
total_correct = 0
for x,y in test_db:
logits = model(x)
prob = tf.nn.softmax(logits,axis=1)
pred = tf.argmax(prob,axis=1)
pred = tf.cast(pred,dtype=tf.int32) correct = tf.cast(tf.equal(pred,y),dtype=tf.int32)
correct = tf.reduce_sum(correct) total_num += x.shape[0]
total_correct += correct
acc = total_correct/total_num print("acc:",acc) if __name__ == '__main__':
main()

打印网络结构和参数量如下:

Resnet——深度残差网络(二)的更多相关文章

  1. Resnet——深度残差网络(一)

    我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...

  2. 使用dlib中的深度残差网络(ResNet)实现实时人脸识别

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  3. 深度残差网络(DRN)ResNet网络原理

    一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...

  4. Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。

    如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...

  5. CNN卷积神经网络_深度残差网络 ResNet——解决神经网络过深反而引起误差增加的根本问题,Highway NetWork 则允许保留一定比例的原始输入 x。(这种思想在inception模型也有,例如卷积是concat并行,而不是串行)这样前面一层的信息,有一定比例可以不经过矩阵乘法和非线性变换,直接传输到下一层,仿佛一条信息高速公路,因此得名Highway Network

    from:https://blog.csdn.net/diamonjoy_zone/article/details/70904212 环境:Win8.1 TensorFlow1.0.1 软件:Anac ...

  6. 关于深度残差网络(Deep residual network, ResNet)

    题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...

  7. 深度残差网络(ResNet)

    引言 对于传统的深度学习网络应用来说,网络越深,所能学到的东西越多.当然收敛速度也就越慢,训练时间越长,然而深度到了一定程度之后就会发现越往深学习率越低的情况,甚至在一些场景下,网络层数越深反而降低了 ...

  8. 深度残差网络——ResNet学习笔记

    深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...

  9. ResNet(深度残差网络)

    注:平原改为简单堆叠网络 一般x是恒等映射,当x与fx尺寸不同的时候,w作用就是将x变成和fx尺寸相同. 过程: 先用w将x进行恒等映射.扩维映射或者降维映射d得到wx.(没有参数,不需要优化器训练) ...

随机推荐

  1. (初学JS)JS基础——ATM机终端程序编写<1.0>

    初步学习了JS基础,为了更好地将所学知识熟练运用,我进行了银行ATM存取款机的模拟程序编写,主要通过VScode终端实现系列操作. 我的ATM程序包括6个主要功能:1.查询余额 2.存钱 3. 取钱 ...

  2. 使用Gradle构建springboot多模块项目,并混合groovy开发

    idea设置本地gradle 打包: build.gradle //声明gradle脚本自身需要使用的资源,优先执行 buildscript { ext { springBootVersion = ' ...

  3. Stopping service [Tomcat] Disconnected from the target VM, address:XXXXXX解决方案

    原文出处:https://blog.csdn.net/u013294097/article/details/90677049 Stopping service [Tomcat] Disconnecte ...

  4. try catch finally的理解

    定义以及用法: try/catch/finally 语句用于处理代码中可能出现的错误信息. 错误可能是语法错误,通常是程序员造成的编码错误或错别字.也可能是拼写错误或语言中缺少的功能(可能由于浏览器差 ...

  5. Arduino系列之光照传感器(三)

    今天,我将简单做一个当光照值低于某个值的时候,灯光自动打开,当高于某个值的时候,自动关闭. 设计代码原理: 首先,定义一个全局变量,并赋予初始值 然后,初始化程序 将设定某个IO口为输出模式 读取光度 ...

  6. python中的变量和字符串

    一.变量 1.python变量 *变量用于存储某个或某些特定的值,它与一个特定标识符相关联,该标识符称为变量名称.变量名指向存储在内存中的值.在创建变量时会在内存中开辟一个空间.基于变量的数据类型,解 ...

  7. AWS的边缘计算平台GreenGrass和IoT

    AWS的边缘计算平台GreenGrass和IoT 为什么需要有边缘计算? 如今公有云和私有云平台提供的服务已经连接上了绝大多数的桌面设备和移动设备.但是更多的设备比如,车辆,工程机械,医疗设备,无人机 ...

  8. ubuntu 如何搭建svn 服务器

    1.在终端中直接输入  sudo apt-get install subversion,选择安装即可 来这个subversion同时包含了服务端和客户端. 2.(可选)看版本命令 svnserve - ...

  9. 《Python学习手册 第五版》 -第12章 if测试和语法规则

    本章节的内容,主要讲解if语句,if语句是三大复合语句之一(其他两个是while和for),能处理编程中大多数逻辑运算 本章的重点内容如下: 1.if语句的基本形式(多路分支) 2.布尔表达式 3.i ...

  10. JAVA编程思想——分析阅读

    需要源码.JDK1.6 .编码风格参考阿里java规约 7/12开始 有点意识到自己喜欢理论大而泛的模糊知识的学习,而不喜欢实践和细节的打磨,是因为粗心浮躁导致的么? cron表达式使用 设计能力.领 ...