对于一个模型,都可以从以下几个方面进行调参:

1. 对weight和bias进行初始化(效果很好,一般都可以提升1-2%)

Point 1 (CNN):

 for conv in self.convs1:
init.xavier_normal(conv.weight, gain=np.sqrt(2.0))  # 对weight进行正态分布初始化
# init.normal(conv.weight, mean=0, std=0.1)
# init.constant(conv.bias, 0.1)              # 对bias初始化为0.1

Point 2 (LSTM):

(1)Bias vectors are initialized to zero, except the bias b f for the forget gate in LSTM , which is initialized to 1.0 .(参见论文End-to-end Sequence Labeling via Bi-directional LSTM-CNNs-CRF)。weight 使用高斯分布或是均匀分布都可以。详细讲解参考博文Deep Learning 之 参数初始化

(2)简单的设置就是,weight设为0.1,bias设为0。

         init.xavier_normal(self.lstm.all_weights[0][0], gain=np.sqrt(2.0))
self.lstm.all_weights[0][3].data[20:40].fill_(1) # forget gate
self.lstm.all_weights[0][3].data[0:20].fill_(0)
self.lstm.all_weights[0][3].data[40:80].fill_(0)

注:对于封装好的lstm,其提供了all_weights接口统一对其参数进行初始化,不能单个定义,forget gate对应的下标是20-39。若是使用lstmcell则可以对单个想要修改的参数进行修改。

2. clip  gradients让权重的梯度更新限制在一定范围内,防止单个节点出现梯度爆炸、梯度消失。

 optimizer.zero_grad()
logit = model(feature)
loss = F.cross_entropy(logit, target)
loss.backward()
# clip gradients
utils.clip_grad_norm(model.parameters(), 5)
optimizer.step()

3. L2 regularization

L2值也叫惩罚值,是为了防止过拟合问题。提供了接口可直接设值,一般设为1e-8。

 optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.01)

4. batch normalization批标准化若设置正确,据说会大大加大迭代速度,效果明显。

若是BatchNorm2d(x),input是(batchsize,channel,height,width),x值对应channel,即维度1。所以channel=0时,求一次mean,var,做一次normalize;channel=1时,求一次.......channel=x时,求一次。BatchNorm1d时情况也是一样的,x对应的是维度1的值,若是不对应,则需要进行转置,如下示例。

 m = nn.BatchNorm1d(2)
input = torch.randn(2, 10)
input = Variable(input)
input = Variable(torch.transpose(input.data, 0, 1))
print(input)
output = m(input)
print(output)

Point 1 (CNN):

     def __init__(self, args):
super(CNN, self).__init__()
self.bn = nn.BatchNorm2d(1) def forward(self, x):
for conv in self.convs1:
xx = conv(x) # variable [torch.FloatTensor of size 16x200x35x1]
xx = Variable(torch.transpose(xx.data, 2, 3))
xx = Variable(torch.transpose(xx.data, 1, 2))
xx = self.bn(xx)
xx = F.relu(xx)
xx = xx.squeeze(1)
a.append(xx)

Point 2 (LSTM):

 class BiLSTM(nn.Module):
def __init__(self, args):
super(BiLSTM, self).__init__()
self.bn1 = nn.BatchNorm1d(2*self.hidden_size)   def forward(self, sentence):
out = self.bn1(out)
out = F.tanh(out)
y = self.hidden2label(out)

结果:以上两种设置并没有提高准确率。

Point 3 (BN-LSTM):

参看论文RECURRENT BATCH NORMALIZATION,不使用pytorch框架,自己实现。

调参tips的更多相关文章

  1. 01.CNN调参

    转载:调参是个头疼的事情,Yann LeCun.Yoshua Bengio和Geoffrey Hinton这些大牛为什么能够跳出各种牛逼的网络? 下面一些推荐的书和文章:调参资料总结Neural Ne ...

  2. 【Python机器学习实战】决策树与集成学习(七)——集成学习(5)XGBoost实例及调参

    上一节对XGBoost算法的原理和过程进行了描述,XGBoost在算法优化方面主要在原损失函数中加入了正则项,同时将损失函数的二阶泰勒展开近似展开代替残差(事实上在GBDT中叶子结点的最优值求解也是使 ...

  3. scikit-learn随机森林调参小结

    在Bagging与随机森林算法原理小结中,我们对随机森林(Random Forest, 以下简称RF)的原理做了总结.本文就从实践的角度对RF做一个总结.重点讲述scikit-learn中RF的调参注 ...

  4. scikit-learn 梯度提升树(GBDT)调参小结

    在梯度提升树(GBDT)原理小结中,我们对GBDT的原理做了总结,本文我们就从scikit-learn里GBDT的类库使用方法作一个总结,主要会关注调参中的一些要点. 1. scikit-learn ...

  5. word2vec参数调整 及lda调参

     一.word2vec调参   ./word2vec -train resultbig.txt -output vectors.bin -cbow 0 -size 200 -window 5 -neg ...

  6. scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类 (python代码)

    scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类数据集 fetch_20newsgroups #-*- coding: UTF-8 -*- import ...

  7. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  8. 漫谈PID——实现与调参

    闲话: 作为一个控制专业的学生,说起PID,真是让我又爱又恨.甚至有时候会觉得我可能这辈子都学不会pid了,但是经过一段时间的反复琢磨,pid也不是很复杂.所以在看懂pid的基础上,写下这篇文章,方便 ...

  9. hyperopt自动调参

    hyperopt自动调参 在传统机器学习和深度学习领域经常需要调参,调参有些是通过通过对数据和算法的理解进行的,这当然是上上策,但还有相当一部分属于"黑盒" hyperopt可以帮 ...

随机推荐

  1. windows下,cmd 运行 python 脚本,选中文字就停止运行了【已解决】

    参考资料: https://jingyan.baidu.com/article/ce09321bb95dda2bff858f26.html 问题原因: cmd 里面,快速编辑模式会暂停程序 解决步骤: ...

  2. java自动化测试开发环境搭建(更新至2018年10月8日 11:42:15)

    1.安装JDK的1.8版本 官网下载地址:https://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151 ...

  3. c++ primer plus 第6版 部分一 1-4章

    c++ primer plus 第6版 源代码 ---编译器---目标代码---连接程序(启动代码--库代码)---可执行代码 源代码扩展名:c   cc   cxx     C    cpp     ...

  4. Java开发微信公众号(二)---开启开发者模式,接入微信公众平台开发

    接入微信公众平台开发,开发者需要按照如下步骤完成: 1.填写服务器配置 2.验证服务器地址的有效性 3.依据接口文档实现业务逻辑 资料准备: 1.一个可以访问的外网,即80的访问端口,因为微信公众号接 ...

  5. SQLAlchemy 学习笔记(三):ORM 中的关系构建

    个人笔记,不保证正确. 关系构建:ForeignKey 与 relationship 关系构建的重点,在于搞清楚这两个函数的用法.ForeignKey 的用法已经在 SQL表达式语言 - 表定义中的约 ...

  6. ionic2实战-使用Chart.js

    前言 Chart.js官网 Chart.js中文文档 安装Chart.js 执行cnpm install typings -g,全局安装Typings 执行typings search chart.j ...

  7. 笔记:CS231n+assignment2(作业二)(二)

    一.参数更新策略     1.SGD 也就是随机梯度下降,最简单的更新形式是沿着负梯度方向改变参数(因为梯度指向的是上升方向,但是我们通常希望最小化损失函数).假设有一个参数向量x及其梯度dx,那么最 ...

  8. vs2017 出现“文件中的类都不能进行设计,因此未能为该文件显示设计器”问题处理

    今天拷贝了以前的一个项目.打算出一个新版本. 但是拷贝了sln文件后,去除掉以前的项目,新增了一个  winfrom项目中 出现了:文件中的类都不能进行设计,因此未能为该文件显示设计器.错误 百度了一 ...

  9. Windows + Eclipse 构建mahout运行环境

    mahout的完整运行还是需要hadoop的支持的,不过很多算法只需要能把hadoop的jar包加入到classpath之中就能正常运行. 比如我们在使用LogisticModelParameters ...

  10. hdu 1503 最长公共子序列

    /* 给两个串a,b.输出一个最短的串(含等于a的子序列且含等于b的子序列) */ #include <iostream> #include <cstdio> #include ...