CNN与RNN的结合

问题

前几天学习了RNN的推导以及代码,那么问题来了,能不能把CNN和RNN结合起来,我们通过CNN提取的特征,能不能也将其看成一个序列呢?答案是可以的。

但是我觉得一般直接提取的特征喂给哦RNN训练意义是不大的,因为RNN擅长处理的是不定长的序列,也就是说,seq size是不确定的,但是一般图像特征的神经元数量都是定的,这个时候再接个rnn说实话意义不大,除非设计一种结构可以让网络不定长输出。(我的一个简单想法就是再设计一条之路去学习一个神经元权重mask,按照规则过滤掉一些神经元,然后丢进rnn或者lstm训练)

如何实现呢

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets,transforms
import torch.optim as optim
from tqdm import tqdm
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__() self.feature_extractor = nn.Sequential(
nn.Conv2d(1,16,kernel_size = 3,stride=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Conv2d(16,64,kernel_size = 3,stride=2),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64,128,kernel_size = 3,stride=2),
nn.BatchNorm2d(128),
nn.ReLU(),
)
self.rnn = nn.RNN(128,256,2) # input_size,output_size,hidden_num
self.h0 = torch.zeros(2,32,256) # 层数 batchsize hidden_dim
self.predictor = nn.Linear(4*256,10)
def forward(self,x):
x = self.feature_extractor(x) # (-1,128,2,2),4个神经元,128维度
x,ht = self.rnn(x.permute(3,4,0,1).contiguous().view(4,-1,128),self.h0) # (h*w,batch_size,hidden_dim) x = self.predictor(x.view(-1,256*4))
return x if __name__ == "__main__":
model = Model()
#summary(model,(1,28,28),device = "cpu")
loss_fn = nn.CrossEntropyLoss()
train_dataset = datasets.MNIST(root="./data/",train = True,transform = transforms.ToTensor(),download = True)
test_dataset = datasets.MNIST(root="./data/",train = False,transform = transforms.ToTensor(),download = True) train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=32,
shuffle=True) test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=128,
shuffle=False)
optimizer = optim.Adam(model.parameters(),lr = 1e-3)
print(len(train_loader))
for epoch in range(100):
epoch_loss = 0.
for x,target in train_loader:
#print(x.size())
y = model(x)
loss = loss_fn(y,target)
epoch_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step() print("epoch : {} and loss is : {}".format(epoch +1,epoch_loss))
torch.save(model.state_dict(),"rnn_cnn.pth")

上面代码可以看出我已经规定了RNN输入神经元的个数,所以肯定是定长的输入,我训练之后是可以收敛的。

对于不定长,其实还是没办法改变每个batch的seq len,因为规定的一定是最长的seq len,所以没办法做到真正的不定长。所以我能做的就是通过支路学习一个权重作用到原来的feature上去,这个权重是0-1权重,其实这样就可以达到效果了。

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets,transforms
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__() self.feature_extractor = nn.Sequential(
nn.Conv2d(1,16,kernel_size = 3,stride=2),
nn.BatchNorm2d(16),
nn.ReLU6(),
nn.Conv2d(16,64,kernel_size = 3,stride=2),
nn.BatchNorm2d(64),
nn.ReLU6(),
nn.Conv2d(64,128,kernel_size = 3,stride=2),
nn.BatchNorm2d(128),
nn.ReLU6(),
)
self.attn = nn.Conv2d(128,1,kernel_size = 1)
self.rnn = nn.RNN(128,256,2) # input_size,output_size,hidden_num self.h0 = torch.zeros(2,32,256) # 层数 batchsize hidden_dim
self.predictor = nn.Linear(4*256,10)
def forward(self,x):
x = self.feature_extractor(x) # (-1,128,2,2),4个神经元,128维度
attn = F.relu(self.attn(x)) # (-1,1,2,2) -> (-1,4)
x = x * attn
#print(x.size())
x,ht = self.rnn(x.permute(3,4,0,1).contiguous().view(4,-1,128),self.h0) # (h*w,batch_size,hidden_dim)
#self.h0 = ht
x = self.predictor(x.view(-1,256*4))
return x if __name__ == "__main__":
model = Model()
#summary(model,(1,28,28),device = "cpu")
#exit()
loss_fn = nn.CrossEntropyLoss()
train_dataset = datasets.MNIST(root="./data/",train = True,transform = transforms.ToTensor(),download = True)
test_dataset = datasets.MNIST(root="./data/",train = False,transform = transforms.ToTensor(),download = True) train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=32,
shuffle=True) test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=128,
shuffle=False)
optimizer = optim.Adam(model.parameters(),lr = 1e-3)
print(len(train_loader))
for epoch in range(100):
epoch_loss = 0.
for x,target in train_loader:
#print(x.size()) y = model(x) loss = loss_fn(y,target)
epoch_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step() print("epoch : {} and loss is : {}".format(epoch +1,epoch_loss))
torch.save(model.state_dict(),"rnn_cnn.pth")

我自己训练了一下,后者要比前者收敛的快的多。

[学习笔记] CNN与RNN方法结合的更多相关文章

  1. 【python学习笔记】9.魔法方法、属性和迭代器

    [python学习笔记]9.魔法方法.属性和迭代器 魔法方法:xx, 收尾各有两个下划线的方法 __init__(self): 构造方法,创建对象时候自动执行,可以为其增加参数, 父类构造方法不会被自 ...

  2. Java8学习笔记(八)--方法引入的补充

    在Java8学习笔记(三)--方法引入中,简要总结了方法引入时的使用规则,但不够完善.这里补充下几种情况: 从形参到实例方法的实参 示例 public class Example { static L ...

  3. 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...

  4. C#设计模式学习笔记:(2)工厂方法模式

    本笔记摘抄自:https://www.cnblogs.com/PatrickLiu/p/7567880.html,记录一下学习过程以备后续查用. 一.引言 接上一篇C#设计模式学习笔记:简单工厂模式( ...

  5. Python学习笔记--Python字符串连接方法总结

    声明: 这些总结的学习笔记,一部分是自己在工作学习中总结,一部分是收集网络中的知识点总结而成的,但不到原文链接.如果有侵权,请知会,多谢. python中有很多字符串连接方式,总结一下: 1)最原始的 ...

  6. 《Python基础教程(第二版)》学习笔记 -> 第九章 魔法方法、属性和迭代器

    准备工作 >>> class NewStyle(object): more_code_here >>> class OldStyle: more_code_here ...

  7. 0040 Java学习笔记-多线程-线程run()方法中的异常

    run()与异常 不管是Threade还是Runnable的run()方法都没有定义抛出异常,也就是说一条线程内部发生的checked异常,必须也只能在内部用try-catch处理掉,不能往外抛,因为 ...

  8. 【zepto学习笔记01】核心方法$()(补)

    前言 昨天学习了核心$(),有几个遗留问题,我们今天来看看吧 $.each 遍历数组/对象,将每条数据作为callback的上下文,并传入数据以及数据的索引进行处理,如果其中一条数据的处理结果明确返回 ...

  9. 【zepto学习笔记01】核心方法$()

    前言 我们移动端基本使用zepto了,而我也从一个小白变成稍微靠谱一点的前端了,最近居然经常要改到zepto源码但是,我对zepto不太熟悉,其实前端水准还是不够,所以便私下偷偷学习下吧,别被发现了 ...

随机推荐

  1. access注入

    前面有自己总结详细的mysql注入,自己access注入碰到的比较少,虽然比较简单,但是这里做一个总结 union联合查询法: 因为union前后字段数相同,所以可以先用order by 22 使查询 ...

  2. 为SourceInsight添加多行注释功能菜单

    由于项目看代码主要使用的是Source Insight,习惯了其他编辑器的多行注释功能,对此感到很不习惯,查询网上程序,可以自行添加. 1.打开project,选择base项目中的utils.em,添 ...

  3. Navicat安装、使用教程

    下载地址:Navicat的安装包及破解文件 一. Navicat安装 Navicat既可安装在服务器端,也可以安装在客户端.安装在服务器端,导入数据时可使用默认用户,也可以使用远程用户:安装在客户端, ...

  4. Cypress自动化测试系列之二

    本文技术难度★★★,如果前编内容顺利执行,请继续. 如果Selenium尚无法灵活运用的读者,本文可能难度较大. “理论联系实惠,密切联系领导,表扬和自我表扬”——我就是老司机,曾经写文章教各位怎么打 ...

  5. 02-jar包操作---引用本地包--maven项目

    在idea工具中,普通项目的话,直接在jar上右键add as library就行了. 如果是maven项目 可以将包,放入lib目录下,然后在pom文件配置引用.例子: <!--引入非本地仓库 ...

  6. 谈谈你对 mysql 引擎中的 MyISAM与InnoDB的区别理解?

    InnoDB和MyISAM是许多人在使用MySQL时最常用的两个表类型,这两个表类型各有优劣,视具体应用而定.基本的差别为:MyISAM类型不支持事务处理等高级处理,而InnoDB类型支持.MyISA ...

  7. windows时钟服务设置

    运行Regedit,打开注册表编辑器. 找到注册表项HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\W32Time\Config\,将Anno ...

  8. nginx 常用的命令和配置文件

    常用的命令 进入 nginx 目录中 cd  /usr/local/nginx/sbin 1.查看 nginx 版本号  ./nginx -v  2.启动 nginx  ./nginx  3.停止 n ...

  9. MHA监控进程异常退出(MHA版本:0.56)

    最近遇到一个非常诡异的问题,mha后台进程自己中断退出了.以下是报错:Mon Dec 21 20:16:07 2015 - [info] OK.Mon Dec 21 20:16:07 2015 - [ ...

  10. 【BZOJ4596】【Luogu P4336】 [SHOI2016]黑暗前的幻想乡 矩阵树定理,容斥

    同样是矩阵树定理的裸题.但是要解决它需要能够想到容斥才可以. \(20\)以内的数据范围一定要试试容斥的想法. #include <bits/stdc++.h> using namespa ...