torch.topk

功能:找出前k大的数据,及其索引号

  • input:张量
  • k:决定选取k个值,k=1是为top-1
  • dim:索引维度

    返回:
  • Tensor:前k大的值
  • LongTensor:前k大的值所在的位置
torch.topk(input, k, dim=None,largest=True,sorted=True,out=None )

torchvision.utils.make_grid

功能:制作网格图像

  • tensor:图像数据,BCHW形式
  • nrow:行数(列数自动计算)
  • padding:图像间距(像素单位)
  • normalize:是否将像素值标准化
  • range:标准化范围
  • scale_each:是否单张图维度标准化
  • pad_value:padding的像素值
make_grid(tensor, nrow=8, padding=2, normalize=False,range=None,scale_each=False, pad_value=0)

代码结构

  1. 加载图片
  2. 加载模型
  3. 模型推理
  4. 获取类别 topk -> index -> names
  5. 分类结果可视化

注意事项

  1. 模型接收4d张量
  2. 弃用LRN(LRN用处不大)
  3. 增加了一个trick增加AdaptiveAvgPool2d,使图片无论多大分辨率,到达该层输出都是6*6,后面接的FC层,即将不匹配的Tensor都池化为6*6的,接入fc层进行classifier。
  4. 卷积核数量有所改变(pytorch中没有完全按照paper中去设置)

采用norm_mean=[0.485,0.456,0.406]

norm_std=[0.229,0.224,0.225]

这两个是通过Image数据集统计的均值和标准差,但是后面用BatchNormalization就可以不用这个。

transforms.ToTensor()区间是[0,1]

img_tensor.unsqueeze_(0) #chw->bchw

with torch.no_grad():

默认情况下会记录grad,如果只是test情况下不需要做反向传播,就可以用torch.no_grad()只做前向传播省时间。

log_interval,多少个epoch打印信息

val_interval,多少个epoch执行验证集

lr_decay_step=1,学习率多少个epoch改变

transforms.Resize((256))只是短边截为256,长边成比例缩短。

transforms.Resize((256,256))是正方形

# ============================ step 4/5 优化器 ============================
# 冻结卷积层
flag = 0
# flag = 1
if flag:
fc_params_id = list(map(id, alexnet_model.classifier.parameters())) # 返回的是parameters的 内存地址
base_params = filter(lambda p: id(p) not in fc_params_id, alexnet_model.parameters())
optimizer = optim.SGD([
{'params': base_params, 'lr': LR * 0.1}, # 0 设置为0即冻结卷积层
{'params': alexnet_model.classifier.parameters(), 'lr': LR}], momentum=0.9)
bs, ncrops, c, h, w = inputs.size()     # [4, 10, 3, 224, 224
outputs = alexnet_model(inputs.view(-1, c, h, w))
outputs_avg = outputs.view(bs, ncrops, -1).mean(1)

ncrops是分割多出的10张图,然后对10张图结果取平均。

流程就是对4batch的每个10张图的图片进行调整,先将4*10,3,224,224输入model中,对结果进行分割出来,得到4batch的结果。

01_AlexNet的更多相关文章

随机推荐

  1. Codeforces Round #496 (Div. 3) E1. Median on Segments (Permutations Edition) (中位数,思维)

    题意:给你一个数组,求有多少子数组的中位数等于\(m\).(若元素个数为偶数,取中间靠左的为中位数). 题解:由中位数的定义我们知道:若数组中\(<m\)的数有\(x\)个,\(>m\)的 ...

  2. python代理池的构建3——爬取代理ip

    上篇博客地址:python代理池的构建2--代理ip是否可用的处理和检查 一.基础爬虫模块(Base_spider.py) #-*-coding:utf-8-*- ''' 目标: 实现可以指定不同UR ...

  3. DNS 是什么?如何运作的?

    前言 我们在上一篇说到,IP 地址的发明把我们纷乱复杂的网络设备整齐划一地统一在了同一个网络中. 但是类似于 192.168.1.0 这样的地址并不便于人类记忆,于是发明了 域名(Domain Nam ...

  4. XV6学习(14)Lab fs: File system

    代码在github上. 这次实验是要对文件系统修改,使其支持更大的文件以及符号链接,实验本身并不是很复杂.但文件系统可以说是XV6中最复杂的部分,整个文件系统包括了七层:文件描述符,路径名,目录,in ...

  5. 国内centos/windows10安装minikube

    centos/windows10安装minikube 目录 centos/windows10安装minikube A win10安装minikube 1 下载安装kubectl.exe 1.1 准备目 ...

  6. 2018牛客多校第一场 E-Removal【dp】

    题目链接:戳这里 转自:戳这里 题意:长度为n的序列,删掉m个数字后有多少种不同的序列.n<=10^5,m<=10. 题解:dp[i][j]表示加入第i个数字后,总共删掉j个数字时,有多少 ...

  7. codeforce 3C

    B. Lorry time limit per test 2 seconds memory limit per test 64 megabytes input standard input outpu ...

  8. 关于free和delete的使用

    上一篇篇幅太长,这里再区分free和delete的用法. 两个同时存在是有它的原因的,我们前面说过,free是函数,它只释放内存,但不会调用析构函数,如果用free去释放new申请的空间,会因为无法调 ...

  9. ASP.Net MVP Framework had been dead !

    ASP.Net MVP Framework Project Description A project to get you started with creating and designing w ...

  10. URLSearchParams & Location & URL params parse

    URLSearchParams & Location & URL params parse URL params parse node.js env bug node.js & ...