TCN代码详解-Torch (误导纠正)
TCN代码详解-Torch (误导纠正)
1. 绪论
TCN网络由Shaojie Bai, J. Zico Kolter, Vladlen Koltun 三人于2018提出。对于序列预测而言,通常考虑循环神经网络结构,例如RNN、LSTM、GRU等。他们三个人的研究建议我们,对于某些序列预测(音频合成、字级语言建模和机器翻译),可以考虑使用卷积网络结构。
关于TCN基本构成和他们的原理有相当多的博客已经解释的很详细的了。总结一句话:TCN = 1D FCN + 因果卷积。下面的博客对因果卷积和孔洞卷积有详细的解释。
但是,包括TCN原文作者,上面这些博客对TCN网络结构的阐释无一例外都是使用下面这张图片。而问题在于,如果不熟悉Torch操作和基本的卷积网络操作,这张图片具有很大的误导性。

图1 膨胀因果卷积(膨胀因子d = 1,2,4,滤波器大小k = 3)
结合上图和上面列举的博客,我们可以大致理解到,TCN就是在序列上使用一维卷积核,沿着时间方向,按照空洞卷积的方式,依次计算。
例如,上图中,
- 第一个hidden层是由 \(d=1\) 的空洞卷积,卷积而来,退化为基本的一维卷积操作;
- 第二个hidden层是由 \(d=2\) 的空洞卷积,卷积而来,卷积每个值时隔开了一个值;
- 第二个hidden层是由 \(d=4\) 的空洞卷积,卷积而来,卷积每个值时隔开了三个值;
由此,上图中网络深度为3,每一层有1个卷积操作。
如果你也是这么理解,恭喜你,成功的被我带跑偏了。
2. TCN结构再次图解
上图中网络深度确实为3,但是每一层并不是只有1个卷积操作。这时候就要拿出原论文中第2个图了。

图2 TCN核心结构
这张图左边展示了TCN结构的核心,卷积+残差,作者把它命名为Residual block。我这里简称为block。
可以发现一个block有两个卷积操作和一个残差操作。因此,图1中每到下一层,都会有两个卷积操作和一个残差操作,并不是一个卷积操作。再次提醒,当 \(d=1\) 时,空洞卷积退化为普通的卷积,正如图2右图展示的。
因此,对于图1中由原始序列到第一层hidden的真实结构为:

3. 结合原文的torch代码解释
很多博客再源代码解释时,基本都是一个模子,没有真正解释关键参数的含义,以及他们如何通过torch的tensor作用的。
预了解TCN结构,须明白原论文中作者描述的这样一句话:
Since a TCN’s receptive field depends on the network depth n as well as filter size k and dilation factor d, stabilization of deeper and larger TCNs becomes important.
翻译是:
由于TCN的感受野依赖于网络深度n、滤波器大小k和扩张因子d,因此更大更深的TCN的稳定变得很重要。
下面结合作者源代码,对这三个参数解释。
3.1 TemporalConvNet
网络深度n就是有多少个block,反应到源代码的变量为num_channels的长度,即 \(len(num_channels)\)。
class TemporalConvNet(nn.Module):
def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
super(TemporalConvNet, self).__init__()
"""
:param num_inputs: int, 输入通道数或者特征数
:param num_channels: list, 每层的hidden_channel数. 例如[5,12,3], 代表有3个block,
block1的输出channel数量为5;
block2的输出channel数量为12;
block3的输出channel数量为3.
:param kernel_size: int, 卷积核尺寸
:param dropout: float, drop_out比率
"""
layers = []
num_levels = len(num_channels)
# 可见,如果num_channels=[5,12,3],那么
# block1的dilation_size=1
# block2的dilation_size=2
# block3的dilation_size=4
for i in range(num_levels):
dilation_size = 2 ** i
in_channels = num_inputs if i == 0 else num_channels[i-1]
out_channels = num_channels[i]
layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
padding=(kernel_size-1) * dilation_size, dropout=dropout)]
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
3.2 TemporalBlock
参数dilation的解释,结合上面和下面的代码。
class TemporalBlock(nn.Module):
def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
super(TemporalBlock, self).__init__()
"""
构成TCN的核心Block, 原作者在图中成为Residual block, 是因为它存在残差连接.
但注意, 这个模块包含了2个Conv1d.
:param n_inputs: int, 输入通道数或者特征数
:param n_outputs: int, 输出通道数或者特征数
:param kernel_size: int, 卷积核尺寸
:param stride: int, 步长, 在TCN固定为1
:param dilation: int, 膨胀系数. 与这个Residual block(或者说, 隐藏层)所在的层数有关系.
例如, 如果这个Residual block在第1层, dilation = 2**0 = 1;
如果这个Residual block在第2层, dilation = 2**1 = 2;
如果这个Residual block在第3层, dilation = 2**2 = 4;
如果这个Residual block在第4层, dilation = 2**3 = 8 ......
:param padding: int, 填充系数. 与kernel_size和dilation有关.
:param dropout: float, dropout比率
"""
self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
stride=stride, padding=padding, dilation=dilation))
# 因为 padding 的时候, 在序列的左边和右边都有填充, 所以要裁剪
self.chomp1 = Chomp1d(padding)
self.relu1 = nn.ReLU()
self.dropout1 = nn.Dropout(dropout)
self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
stride=stride, padding=padding, dilation=dilation))
self.chomp2 = Chomp1d(padding)
self.relu2 = nn.ReLU()
self.dropout2 = nn.Dropout(dropout)
self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
self.conv2, self.chomp2, self.relu2, self.dropout2)
# 1×1的卷积. 只有在进入Residual block的通道数与出Residual block的通道数不一样时使用.
# 一般都会不一样, 除非num_channels这个里面的数, 与num_inputs相等. 例如[5,5,5], 并且num_inputs也是5
self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
# 在整个Residual block中有非线性的激活. 这个容易忽略!
self.relu = nn.ReLU()
self.init_weights()
def init_weights(self):
self.conv1.weight.data.normal_(0, 0.01)
self.conv2.weight.data.normal_(0, 0.01)
if self.downsample is not None:
self.downsample.weight.data.normal_(0, 0.01)
def forward(self, x):
out = self.net(x)
res = x if self.downsample is None else self.downsample(x)
return self.relu(out + res)
3.3 Chomp1d
裁剪模块。这里注意,padding的时候对数据列首尾都添加了,torch官方解释如下:
padding controls the amount of padding applied to the input. It can be either a string {‘valid’, ‘same’} or a tuple of ints giving the amount of implicit padding applied on both sides.
注意这里是both sides。例如,还是上述代码中的例子,kernel_size = 3,在第一层(对于第一个block),padding = 2。对于长度为20的序列,先padding,长度为\(20+2\times2=24\),再卷积,长度为\((24-3)+1=22\)。所以要裁掉,保证输出序列与输入序列相等。
class Chomp1d(nn.Module):
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
return x[:, :, :-self.chomp_size].contiguous()
4. 验证TCN的输入输出
根据上述代码的解释和理解,我们可以方便的呃验证其输入和输出。
# 输入27个通道,或者特征
# 构建1层的TCN,最后输出一个通道,或者特征
model2 = TemporalConvNet(num_inputs=27, num_channels=[32,16,4,1], kernel_size=3, dropout=0.3)
import torch
# 检测输出
with torch.no_grad():
# 模型输入一定是 (batch_size, channels, length)
model2.eval()
print(model2(torch.randn(16,27,20)).shape)
打印结果为(16, 1, 20) 。通道数降为1。输入序列长度20, 输出序列长度也是20。
TCN代码详解-Torch (误导纠正)的更多相关文章
- Github-jcjohnson/torch-rnn代码详解
Github-jcjohnson/torch-rnn代码详解 zoerywzhou@gmail.com http://www.cnblogs.com/swje/ 作者:Zhouwan 2016-3- ...
- BM算法 Boyer-Moore高质量实现代码详解与算法详解
Boyer-Moore高质量实现代码详解与算法详解 鉴于我见到对算法本身分析非常透彻的文章以及实现的非常精巧的文章,所以就转载了,本文的贡献在于将两者结合起来,方便大家了解代码实现! 算法详解转自:h ...
- ASP.NET MVC 5 学习教程:生成的代码详解
原文 ASP.NET MVC 5 学习教程:生成的代码详解 起飞网 ASP.NET MVC 5 学习教程目录: 添加控制器 添加视图 修改视图和布局页 控制器传递数据给视图 添加模型 创建连接字符串 ...
- Github-karpathy/char-rnn代码详解
Github-karpathy/char-rnn代码详解 zoerywzhou@gmail.com http://www.cnblogs.com/swje/ 作者:Zhouwan 2016-1-10 ...
- 代码详解:TensorFlow Core带你探索深度神经网络“黑匣子”
来源商业新知网,原标题:代码详解:TensorFlow Core带你探索深度神经网络“黑匣子” 想学TensorFlow?先从低阶API开始吧~某种程度而言,它能够帮助我们更好地理解Tensorflo ...
- JAVA类与类之间的全部关系简述+代码详解
本文转自: https://blog.csdn.net/wq6ylg08/article/details/81092056类和类之间关系包括了 is a,has a, use a三种关系(1)is a ...
- Java中String的intern方法,javap&cfr.jar反编译,javap反编译后二进制指令代码详解,Java8常量池的位置
一个例子 public class TestString{ public static void main(String[] args){ String a = "a"; Stri ...
- Kaggle网站流量预测任务第一名解决方案:从模型到代码详解时序预测
Kaggle网站流量预测任务第一名解决方案:从模型到代码详解时序预测 2017年12月13日 17:39:11 机器之心V 阅读数:5931 近日,Artur Suilin 等人发布了 Kaggl ...
- 基础 | batchnorm原理及代码详解
https://blog.csdn.net/qq_25737169/article/details/79048516 https://www.cnblogs.com/bonelee/p/8528722 ...
- 非极大值抑制(NMS,Non-Maximum Suppression)的原理与代码详解
1.NMS的原理 NMS(Non-Maximum Suppression)算法本质是搜索局部极大值,抑制非极大值元素.NMS就是需要根据score矩阵和region的坐标信息,从中找到置信度比较高的b ...
随机推荐
- webpack打包优化点
目录 1. noParse 2. 包含和排除目录 3. IgnorePlugin 4. happypack 5. DllPlugin动态链接库 6. 热更新 7. 开发环境 tree-shaking ...
- 微信公众号商城、小程序商城、H5商城 实例 前后端源码
CRMEB客户管理+电商营销系统 https://gitee.com/ZhongBangKeJi/CRMEB 演示站后台: http://demo.crmeb.net/admin 账号:demo 密 ...
- 我的Vue之旅、04 CSS媒体查询完全指南(Media Quires)
什么是SCSS Sass: Sass Basics (sass-lang.com) SCSS 是 CSS 的预处理器,它比常规 CSS 更强大. 可以嵌套选择器,更好维护.管理代码. 可以将各种值存储 ...
- 第四章:Django表单 - 1:使用表单
假设你想从表单接收用户名数据,一般情况下,你需要在HTML中手动编写一个如下的表单元素: <form action="/your-name/" method="po ...
- linux系统排查数据包常用命令
1.查看当前系统中生效的所有参数 sysctl -a 2.统计处于TIME_WAIT状态的TCP连接数 netstat -ant|grep TIME_WAIT|wc -l 3.统计TCP连接数 net ...
- Request Body Search
官方文档地址:https://www.elastic.co/guide/en/elasticsearch/reference/master/modules-scripting-using.html
- Spring Boot 项目转容器化 K8S 部署实用经验分享
转载自:https://cloud.tencent.com/developer/article/1477003 我们知道 Kubernetes 是 Google 开源的容器集群管理系统,它构建在目前流 ...
- django-compressor安装失败
报错日志: Installing collected packages: rcssmin, django-compressor Running setup.py install for rcssmin ...
- 痞子衡嵌入式:浅谈i.MXRT10xx系列MCU外接24MHz晶振的作用
大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家介绍的是i.MXRT10xx系列MCU外接24MHz晶振的作用. 痞子衡之前写过一篇关于时钟引脚的文章 <i.MXRT1xxx系列MCU时 ...
- 关于Vue多线程的思考
在前端调用的时候,我们难免需要同一时刻向后端请求多组数据或是总是期待着是否存在一个独立的线程去处理一系列的数据.线程相应,资源的抢占这是前端较为麻烦的点.这里就来聊聊我在前端踩的坑. 首先是线程问题说 ...