pytorch RNN层api的几个参数说明
classtorch.nn.RNN(*args, **kwargs)
input_size – The number of expected features in the input x
hidden_size – The number of features in the hidden state h
num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1
nonlinearity – The non-linearity to use. Can be either ‘tanh’ or ‘relu’. Default: ‘tanh’
bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature)
dropout – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0
bidirectional – If True, becomes a bidirectional RNN. Default: False
有个参数一直理解错误,导致了认知困难
首先,RNN这里的序列长度,是动态的,不写在参数里的,具体会由输入的input参数而定
而num_layers并不是RNN的序列长度,而是堆叠层数,由上一层每个时间节点的输出作为下一层每个时间节点的输入
RNN的对象接受的参数,input维度是(seq_len, batch_size, input_dim),h0维度是(num_layers * directions, batch_size, hidden_dim)
其中,input的seq_len决定了序列的长度,h0是提供给每层RNN的初始输入,所有num_layers要和RNN的num_layers对得上
返回两个值,一个output,一个hn
hn的维度是(num_layers * directions, batch_size, hidden_dim),是RNN的右侧输出,如果是双向的话,就还有一个左侧输出
output的维度是(seq_len, batch_size, hidden_dim * directions),是RNN的上侧输出
pytorch RNN层api的几个参数说明的更多相关文章
- 自己动手实现深度学习框架-7 RNN层--GRU, LSTM
目标 这个阶段会给cute-dl添加循环层,使之能够支持RNN--循环神经网络. 具体目标包括: 添加激活函数sigmoid, tanh. 添加GRU(Gate Recurrent U ...
- Zigbee协议栈OSAL层API函数【转载】
OSAL层提供了很多的API来对整个的协议栈进行管理.主要有下面的几类:信息管理.任务同步.时间管理.中断管理.任务管理.内存管理.电源管理以及非易失存储管理.看到这些管理是不是感 ...
- 【转载】 Caffe BN+Scale层和Pytorch BN层的对比
原文地址: https://blog.csdn.net/elysion122/article/details/79628587 ------------------------------------ ...
- [PyTorch] rnn,lstm,gru中输入输出维度
本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...
- Android 访问Android Wear数据层Api——同步Data Items
Data Items它被用来同步手机和wear数据接口,一个Date Items通常包含以下几个部分: Payload 字节数组.无论你需要设置数据类型,我们同意对象序列化和反序列化,大小不能超过10 ...
- pytorch rnn 2
import torch import torch.nn as nn import numpy as np import torch.optim as optim class RNN(nn.Modul ...
- pytorch rnn
温习一下,写着玩. import torch import torch.nn as nn import numpy as np import torch.optim as optim class RN ...
- pytorch --Rnn语言模型(LSTM,BiLSTM) -- 《Recurrent neural network based language model》
论文通过实现RNN来完成了文本分类. 论文地址:88888888 模型结构图: 原理自行参考论文,code and comment: # -*- coding: utf-8 -*- # @time : ...
- Pytorch基础——使用 RNN 生成简单序列
一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无 ...
随机推荐
- cesium入门示例-测量工具
作为cesium入门示例级别的最后一篇,参考cesium-长度测量和面积测量实现测量工具封装,修改了其中的距离测量函数,计算贴地距离,并对事件内部处理做了调整.包括贴地距离测量.面积测量.结果清除. ...
- Mac下如何使用homebrew
Homebrew简称brew,是Mac OSX上的软件包管理工具,能在Mac中方便的安装软件或者卸载软件. 常用的命令: 搜索软件:brew search 软件名,如brew search wget ...
- 模型压缩之Channel Pruning
论文地址 channel pruning是指给定一个CNN模型,去掉卷积层的某几个输入channel以及相应的卷积核, 并最小化裁剪channel后与原始输出的误差. 可以分两步来解决: channe ...
- windowserver 2012安装openssh
下载https://github.com/PowerShell/Win32-OpenSSH/releases解压放到C:\Program Files\OpenSSH-Win64 进入到C:\Progr ...
- 2019DDCTF部分Writeup
-- re Windows Reverse1 通过DIE查壳发现存在upx,在linux上upx -d脱壳即可,拖入IDA,通过关键字符串找到关键函数: main函数中也没有什么,将输入的字符串带到s ...
- postgresql学习记录1
数据库9.3.5,系统fedora20,不同系统操作略有不同. 使用yum 命令安装即可:sudo yum install postgresql,postgresql-server 安装完毕后系统中会 ...
- 201771010105—达拉草 实验一 软件工程准备—<软件工程构建之法—初步了解和认识>
项目 内容 这个作业属于哪个课程 https://www.cnblogs.com/nwnu-daizh/ 这个作业的要求在哪里 https://www.cnblogs.com/nwnu-daizh/p ...
- ITT Corporation之“中国战略”
前言:众所周知,中国已经成为全世界第二大经济体,并且坐拥14亿人口的庞大市场,蕴藏着巨大的市场机遇,海外高科技企业想法获得长足的发展重视和开拓中国市场成为重中之重,诸如特斯拉,google,苹果等,近 ...
- LeetCode--二叉树2--运用递归解决树的问题
LeetCode--二叉树2--运用递归解决树的问题 在前面的章节中,我们已经介绍了如何利用递归求解树的遍历. 递归是解决树的相关问题最有效和最常用的方法之一. 我们知道,树可以以递归的方式定义为一个 ...
- Java设计模式二
今天谈的是工厂模式,该模式用于封装和对对象的创建,万物皆对象,那么万物又是产品类,如一个水果厂生产三种水果罐头,我们就可以将这三种水果作为产品类,再定义一个接口用来设定对水果罐头的生成方法,在工厂类中 ...