pytorch固定部分参数

不用梯度

如果是Variable,则可以初始化时指定

j = Variable(torch.randn(5,5), requires_grad=True)

但是如果是m = nn.Linear(10,10)是没有requires_grad传入的

for i in m.parameters():
i.requires_grad=False

另外一个小技巧就是在nn.Module里,可以在中间插入这个

for p in self.parameters():
p.requires_grad=False # eg 前面的参数就是False,而后面的不变
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5) for p in self.parameters():
p.requires_grad=False self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def freeze(test_net):
ct = 0
for child in test_net.children():
ct += 1
if ct < 3:
for param in child.parameters():
param.requires_grad = False

过滤

optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

pytorch固定部分参数的更多相关文章

  1. PyTorch常用参数初始化方法详解

    1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...

  2. PyTorch固定参数

    In situation of finetuning, parameters in backbone network need to be frozen. To achieve this target ...

  3. pytorch和tensorflow的爱恨情仇之定义可训练的参数

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch版本:1.6.0 tensorflow版本:1.15.0 之前我们就已 ...

  4. pytorch和tensorflow的爱恨情仇之参数初始化

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch和tensorflow的爱恨情仇之定义可训练的参数 pytorch版本 ...

  5. nn.ConvTranspose2d的参数output_padding的作用

    参考:https://blog.csdn.net/qq_41368247/article/details/86626446 使用前提:stride > 1 补充:same卷积操作 是通过padd ...

  6. Pytorch在colab和kaggle中使用TensorBoard/TensorboardX可视化

    在colab和kaggle内核的Jupyter notebook中如何可视化深度学习模型的参数对于我们分析模型具有很大的意义,相比tensorflow, pytorch缺乏一些的可视化生态包,但是幸好 ...

  7. [源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC

    [源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC 目录 [源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC 0x00 摘要 0x0 ...

  8. [源码解析] PyTorch 分布式(17) --- 结合DDP和分布式 RPC 框架

    [源码解析] PyTorch 分布式(17) --- 结合DDP和分布式 RPC 框架 目录 [源码解析] PyTorch 分布式(17) --- 结合DDP和分布式 RPC 框架 0x00 摘要 0 ...

  9. 【AI】Pytorch_预训练模型

    1. 模型下载 import re import os import glob import torch from torch.hub import download_url_to_file from ...

随机推荐

  1. 国内的go get问题的解决 --gopm

    一.golang之旅--gopm 1.什么是gopm 在nodejs中我们有npm,可以通过npm来下载安装一些依赖包.在go中也开发了类似的东西,那就是gopm.这玩意儿是七牛开发的.在这里说下,七 ...

  2. [转载]MGR变量group_replication_primary_member

    组复制的状态变量只有一个,目前在8.0.17这个版本上还是存在的,未来会被废弃掉,我们在单主的模式下,可以用来查看哪个节点是读写节点. mysql: [Warning] Using a passwor ...

  3. centos和Ubuntu系统最小化安装基础命令

    CentOS系统常用的基础软件如下 yum install vim iotop bc gcc gcc-c++ glibc glibc-devel pcre \ pcre-devel openssl o ...

  4. December 07th, Week 49th Saturday, 2019

    Snowflakes are pretty patterns etched in water's dreams. 雪花,是水在梦中镌刻的美丽图案. From Anthony T.Hincks. Tod ...

  5. 21.决策树(ID3/C4.5/CART)

    总览 算法   功能  树结构  特征选择  连续值处理 缺失值处理  剪枝  ID3  分类  多叉树  信息增益   不支持 不支持  不支持 C4.5  分类  多叉树  信息增益比   支持 ...

  6. 【Golang基础】defer执行顺序

    defer 执行顺序类似栈的先入后出原则(FILO)     一个defer引发的小坑:打开文件,读取内容,删除文件   // 原始问题代码 func testFun(){ // 打开文件 file, ...

  7. celery配置

    celery配置 celery的官方文档其实相对还是写的很不错的.但是在一些深层次的使用上面却显得杂乱甚至就没有某些方面的介绍, 通过我的一个测试环境的settings.py来说明一些使用celery ...

  8. Spring Boot MVC api返回的String无法关联到视图页面

    1:问题 使用 @Restcontroller 返回值定义为String 时 无法返回具体的页面 @RestController public class HelloController { @Get ...

  9. PHP 开发工程师基础篇 - PHP 数组

    数组 (Array) 数组是 PHP 中最重要的数据类型,可以说是掌握数组,基本上 PHP 一大半问题都可以解决. PHP 数组与其他编程语言数组概念不一样.其他编程语言数组是由相同类型的元素(ele ...

  10. NIO零拷贝的深入分析

    深入分析通过Socket进行数据文件传递中的传统IO的弊端以及NIO的零拷贝实现原理,及用户空间和内核空间的切换方式 传统的IO流程 在这个过程中: 数据从磁盘拷贝进内核空间缓冲区 从内核空间缓冲区拷 ...