Darknet网络代码
Darknet网络代码
import math
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
class Mish(nn.Module):
def __init__(self):
super(Mish, self).__init__()
def forward(self, x):
x = x * torch.tanh(F.softmax(x))
# F.softmax(x) = ln(1+e^x)
# tanh(x) = (e^x-e^(-x))/(e^x+e^(-x))
return x
# 卷积+归一化+激活
class CB(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
super(CB, self).__init__()
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.activation = Mish()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.activation(x)
return x
class Resblock(nn.Module):
def __init__(self, channels, hidden_channels=None):
super(Resblock, self).__init__()
if hidden_channels is None:
hidden_channels = channels
self.resblock = nn.Sequential(
nn.Conv2d(in_channels = channels, out_channels= hidden_channels, kernel_size=1),
nn.Conv2d(in_channels=hidden_channels, out_channels=channels, kernel_size=3, padding=1)
)
def forward(self, x):
x = x + self.resblock(x)
return x
class Resblock_body(nn.Module):
def __init__(self, in_channels, out_channels, num_blocks, first):
super(Resblock_body, self).__init__()
self.downsample_conv = CB(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2)
if first:
self.split_conv0 = CB(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
self.split_conv1 = CB(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
self.blocks_conv = nn.Sequential(
Resblock(channels=out_channels, hidden_channels=out_channels//2),
CB(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
)
self.concat_conv = CB(in_channels=out_channels * 2, out_channels=out_channels, kernel_size=1)
else:
self.split_conv0 = CB(in_channels=out_channels, out_channels=out_channels // 2, kernel_size=1)
self.split_conv1 = CB(in_channels=out_channels, out_channels=out_channels // 2, kernel_size=1)
self.blocks_conv = nn.Sequential(
*[Resblock(out_channels // 2) for _ in range(num_blocks)],
CB(in_channels=out_channels // 2, out_channels=out_channels // 2, kernel_size=1)
)
self.concat_conv = CB(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
def forward(self, x):
x = self.downsample_conv(x)
x0 = self.split_conv0(x)
x1 = self.split_conv1(x)
x1 = self.blocks_conv(x1)
x = torch.cat([x0, x1], dim=1)
x = self.concat_conv(x)
return x
class CspDarknet(nn.Module):
def __init__(self, layers):
super(CspDarknet, self).__init__()
self.inplanes = 32
self.conv1 = CB(in_channels = 3, out_channels=self.inplanes, kernel_size=3, stride=1)
self.feature_channels = [64, 128, 256, 512, 1024]
self.stages =nn.ModuleList
self.stages = nn.ModuleList([
# 416,416,32 -> 208,208,64
Resblock_body(in_channels=self.inplanes, out_channels=self.feature_channels[0], num_blocks=layers[0], first=True),
# 208,208,64 -> 104,104,128
Resblock_body(in_channels=self.feature_channels[0], out_channels=self.feature_channels[1], num_blocks=layers[1], first=False),
# 104,104,128 -> 52,52,256
Resblock_body(in_channels=self.feature_channels[1], out_channels=self.feature_channels[2], num_blocks=layers[2], first=False),
# 52,52,256 -> 26,26,512
Resblock_body(in_channels=self.feature_channels[2], out_channels=self.feature_channels[3], num_blocks=layers[3], first=False),
# 26,26,512 -> 13,13,1024
Resblock_body(in_channels=self.feature_channels[3], out_channels=self.feature_channels[4], num_blocks=layers[4], first=False)
])
self.num_features=1
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
x = self.conv1(x)
x = self.stages[0](x)
x = self.stages[1](x)
out3 = self.stages[2](x)
out4 = self.stages[3](out3)
out5 = self.stages[4](out4)
return out3, out4, out5
def darknet53(pretrained):
model = CspDarknet([1, 2, 8, 8, 4])
if pretrained:
model.load_state_dict(torch.load("D:/finishi_project/model_data/CSPdarknet53_backbone_weights.pth"))
return model
# from torchsummary import summary
# dar = darknet53(False)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# summary(dar, input_size=(3, 416, 416))
代码没有注释,欢迎留言共同讨论,顺便给个关注,感谢。
Darknet网络代码的更多相关文章
- 学习构建调试Linux内核网络代码的环境MenuOS系统
构建调试Linux内核网络代码的环境MenuOS系统 一.前言 这是网络程序设计的第三次实验,主要是学习自己编译linux内核,构建一个具有简易功能的操作系统,同时在系统上面进行调试linux内核网络 ...
- team foundation server——网络代码管理工具
像我们平时有时会莫名的弹出一个如下图所示的提示框,这个是什么呢?这个就是有人用team foundation server进行过代码管理的项目 那么team foundation server到底是什 ...
- 修复Microsoft Store 无法连接网络 代码: 0x80072EFD
事情的经过是这样的,我的Windows版本是1709,前两天刚从1703升上来,今天突然发现它自己给我装了个skype,我上Microsoft商店里查一下是什么情况,结果突然发现它又双双双不正常工作了 ...
- h5 网络断网时,返回上一个页面 demo (与检测网络代码相结合,更直观看到结果)
页面一: <!DOCTYPE html><html lang="en"><head> <meta charset="UTF-8& ...
- 构建调试Linux内核网络代码的环境MenuOS系统
构建MenuOS系统 1.将指定文件拷贝到本地: git clone https://github.com/mengning/linuxnet.git 此过程可能需要输入github账号和密码. 2. ...
- Git学习:利用Git和TortoiseGit把代码传输到网络服务器
版本控制这块,一直用SVN.感觉挺好用,比VSS要好用些.不过,近期在网上,又谈到时下很流行的Git.就想看看Git到底是何方神圣.趁着五一在家无事,就静下心来,简单研究一下. 当下,网络上提供的基于 ...
- 第二十四节,TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码)
在介绍这一节之前,需要你对slim模型库有一些基本了解,具体可以参考第二十二节,TensorFlow中的图片分类模型库slim的使用.数据集处理,这一节我们会详细介绍slim模型库下面的一些函数的使用 ...
- 目标检测网络之 YOLOv2
YOLOv1基本思想 YOLO将输入图像分成SxS个格子,若某个物体 Ground truth 的中心位置的坐标落入到某个格子,那么这个格子就负责检测出这个物体. 每个格子预测B个bounding b ...
- c 网络与套接字socket
我们已经知道如何使用I/O与文件通信,还知道了如何让同一计算机上的两个进程进行通信,这篇文章将创建具有服务器和客户端功能的程序 互联网中大部分的底层网络代码都是用C语言写的. 网络程序通常有两部分组成 ...
- Linux就这个范儿 第11章 独霸网络的蜘蛛神功
Linux就这个范儿 第11章 独霸网络的蜘蛛神功 第11章 应用层 (Application):网络服务与最终用户的一个接口.协议有:HTTP FTP TFTP SMTP SNMP DNS表示层 ...
随机推荐
- P2617 Dynamic Rankings 解题报告
link 整体二分是一种东西,比如上面这道题. 先考虑一个不带修版本的,也就是经典问题区间 kth,显然我们可以主席树但是我知道你很想用主席树但是你先别用不用主席树,用一种离线的算法,叫整体二分. 首 ...
- Spring注解补充(一)
注解补充 挑一些常用,但是深入不多的总结一下. Bean的声明周期 在@Bean注解中,添加init属性和destroy属性 @Bean(initMethod = "initMethod&q ...
- SpringMVC:RESTful案例
目录 相关准备 功能清单 具体功能:访问首页 ①配置view-controller ②创建页面 具体功能:查询所有员工数据 ①控制器方法 ②创建employee_list.html 具体功能:删除 ① ...
- 基于Docker部署Dubbo+Nacos服务
一.说明 本文介绍基于 Docker 部署一套 Dubbo + Nacos 的微服务环境,并解决容器里的 IP 及端口的访问问题. 基于上文<基于jib-maven-plugin快速构建微服务d ...
- PostgreSQL的10进制与16进制互转
1.10进制转16进制Postgres里面有一个内置的10进制转16进制的函数:to_hex(int)/to_hex(bigint) [postgres@localhost ~]$ psql Pass ...
- pat乙级1014 福尔摩斯的约会
#include<stdio.h> #include<stdlib.h> #include<string.h> #include<math.h> int ...
- taro框架开发微信小程序遇到的问题
ios端,如果input放在了dispplay flex里面,会导致一系列问题 滑动屏幕,键盘不收起,input值随屏幕滚动 input之前切换,键盘不弹起来或有时弹有时不弹 键盘莫名收起 input ...
- K8s之Etcd的备份与恢复
ETCD简介 ETCD用于共享和配置服务发现的分布式,一致性的KV存储系统. ETCD是CoreOS公司发起的一个开源项目,授权协议为Apache. ETCD 存储 k8s 所有数据信息 ETCD 是 ...
- 如何将多个TXT合并成一个TXT,文件名称提取
方法1:1.将所有需要合并的TXT整理到一个文件夹中,切记,TXT合并最好每个TXT内容头或尾留一行间距,因为合并是直接合并,不会保留间距. 2.使用Windows命令cmd,切换到文件所在文件夹 3 ...
- Java面向对像之方法重写
方法重写Override 重写:需要有继承关系,子类重写父类的方法! 特点: 1.方法名必须相同 2.参数列表必须相同 3.修饰符:范围可以扩大:public > Protected > ...
