pytorch RNN层api的几个参数说明
classtorch.nn.RNN(*args, **kwargs)
input_size – The number of expected features in the input x
hidden_size – The number of features in the hidden state h
num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1
nonlinearity – The non-linearity to use. Can be either ‘tanh’ or ‘relu’. Default: ‘tanh’
bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature)
dropout – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0
bidirectional – If True, becomes a bidirectional RNN. Default: False
有个参数一直理解错误,导致了认知困难
首先,RNN这里的序列长度,是动态的,不写在参数里的,具体会由输入的input参数而定
而num_layers并不是RNN的序列长度,而是堆叠层数,由上一层每个时间节点的输出作为下一层每个时间节点的输入
RNN的对象接受的参数,input维度是(seq_len, batch_size, input_dim),h0维度是(num_layers * directions, batch_size, hidden_dim)
其中,input的seq_len决定了序列的长度,h0是提供给每层RNN的初始输入,所有num_layers要和RNN的num_layers对得上
返回两个值,一个output,一个hn
hn的维度是(num_layers * directions, batch_size, hidden_dim),是RNN的右侧输出,如果是双向的话,就还有一个左侧输出
output的维度是(seq_len, batch_size, hidden_dim * directions),是RNN的上侧输出
pytorch RNN层api的几个参数说明的更多相关文章
- 自己动手实现深度学习框架-7 RNN层--GRU, LSTM
目标 这个阶段会给cute-dl添加循环层,使之能够支持RNN--循环神经网络. 具体目标包括: 添加激活函数sigmoid, tanh. 添加GRU(Gate Recurrent U ...
- Zigbee协议栈OSAL层API函数【转载】
OSAL层提供了很多的API来对整个的协议栈进行管理.主要有下面的几类:信息管理.任务同步.时间管理.中断管理.任务管理.内存管理.电源管理以及非易失存储管理.看到这些管理是不是感 ...
- 【转载】 Caffe BN+Scale层和Pytorch BN层的对比
原文地址: https://blog.csdn.net/elysion122/article/details/79628587 ------------------------------------ ...
- [PyTorch] rnn,lstm,gru中输入输出维度
本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...
- Android 访问Android Wear数据层Api——同步Data Items
Data Items它被用来同步手机和wear数据接口,一个Date Items通常包含以下几个部分: Payload 字节数组.无论你需要设置数据类型,我们同意对象序列化和反序列化,大小不能超过10 ...
- pytorch rnn 2
import torch import torch.nn as nn import numpy as np import torch.optim as optim class RNN(nn.Modul ...
- pytorch rnn
温习一下,写着玩. import torch import torch.nn as nn import numpy as np import torch.optim as optim class RN ...
- pytorch --Rnn语言模型(LSTM,BiLSTM) -- 《Recurrent neural network based language model》
论文通过实现RNN来完成了文本分类. 论文地址:88888888 模型结构图: 原理自行参考论文,code and comment: # -*- coding: utf-8 -*- # @time : ...
- Pytorch基础——使用 RNN 生成简单序列
一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无 ...
随机推荐
- WEB前端资源集(二)
在上一篇为大家整理出了一些资源网站,接下来给大家整理了一些开发中常用的工具. 开发工具篇 开发工具集 Sublime Text 3:SublimeText 3是一个代码编辑器,也是HTML和散文先进的 ...
- cs231n spring 2017 lecture1 Introduction to Convolutional Neural Networks for Visual Recognition
1. 生物学家做实验发现脑皮层对简单的结构比如角.边有反应,而通过复杂的神经元传递,这些简单的结构最终帮助生物体有了更复杂的视觉系统.1970年David Marr提出的视觉处理流程遵循这样的原则,拿 ...
- 关于log4j中log4j.properties和log4j.xml的加载顺序
如果采用log4j输出日志,要对log4j加载配置文件的过程有所了解. log4j启动时,默认会寻找source folder下的log4j.xml配置文件,若没有,会寻找log4j.properti ...
- MySQL数据类型(DATA Type)与数据恢复与备份方法
一.数据类型(DATA Type)概述 MySQL支持多种类型的SQL数据类型:数字类型,日期和时间类型,字符串(字符和字节)类型以及空间类型 数据类型描述使用以下约定: M表示整数类型的最大显示宽度 ...
- 算法笔记-Day_01(1001 害死人不偿命的(3n+1)猜想
卡拉兹(Callatz)猜想: 对任何一个正整数 n,如果它是偶数,那么把它砍掉一半:如果它是奇数,那么把 (3n+1) 砍掉一半.这样一直反复砍下去,最后一定在某一步得到 n=1.卡拉兹在 1950 ...
- 如何理解TCP的三次握手协议?
• TCP是一个面向链接的协议,任何一个面向连接的协议,我们都可以将其类比为我们最熟悉的打电话模型. 如何类比呢?我们可以从建立和销毁两个阶段分别来看这件事情. 建立连接阶段 首先,我们来看看TCP中 ...
- Win10+WSL2+Ubuntu 18.04(WSL下)+VS Code(Win10下)+TexLive 2019(Ubuntu下)安装和配置
本人手头电脑是Win10 Home版全新安装的系统,由于不想在新系统盘里面安装TexLive导致固态硬盘不断扩大,所以,考虑安装Ubuntu做为WSL,然后把TexLive安装在Ubuntu,并通过V ...
- Vizceral小白入门
Vizceral小白入门 接到一个任务,要求将N个program可视化,能一目了然查看当前爬虫状态.记得之前做测试时,一个queue service前端可视化效果不错,经询问是用vizceral开源框 ...
- NumPy的随机函数子库——numpy.random
NumPy的随机函数子库numpy.random 导入模块:import numpy as np 1.numpy.random.rand(d0,d1,...,dn) 生成一个shape为(d0,d1, ...
- html5调用摄像头功能
前言 前些天,线上笔试的时候,发现需要浏览器同意开启摄像头,感觉像是 js 调用的,由于当时笔试,也就没想到这么多