pytorch rnn
温习一下,写着玩。
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
class RNN(nn.Module):
def __init__(self,input_dim , hidden_dim):
super(RNN,self).__init__()
self._rnn = nn.RNN(input_size = input_dim , hidden_size= hidden_dim )
self.linear = nn.Linear(hidden_dim , 1)
self.relu = nn.ReLU()
def forward(self , _in):
layer1 , h = self._rnn(_in)
layer2 = self.relu(self.linear(self.relu(layer1)))
return layer2
def init_weight(self):
nn.init.normal_(self.linear.weight.data , 0 , np.sqrt(2 / 16))
nn.init.uniform_(self.linear.bias, 0, 0)
def getBinDict(bit_size = 16):
max = pow(2,bit_size)
bin_dict = {}
for i in range(max):
s = '{:016b}'.format(i)
arr = np.array(list(s))
arr = arr.astype(int)
bin_dict[i] = arr
return bin_dict
binary_dim = 16
int2binary = getBinDict(binary_dim)
def getBatch( batch_size):
x = np.random.randint(0,256,[batch_size , 2])
x_arr = np.zeros([binary_dim , batch_size , 2 ] , dtype=int)
y_arr = np.zeros([binary_dim,batch_size,1] , dtype=int)
for i in range(0 , binary_dim):
batch_x_arr = np.zeros([batch_size,2] , dtype=int)
batch_y_arr = np.zeros([batch_size,1] , dtype=int)
for j in range(len(x)):
batch_x_arr[j] =[int2binary[int(x[j][0])][i] , int2binary[int(x[j][1])][i]]
batch_y_arr[j] =[int2binary[ int(x[j][0]) + int(x[j][1])][i]]
#此处要翻转,rnn处理时是从下标为0处开始,所以要把二进制的高低位翻转
y_arr[binary_dim - i - 1] = batch_y_arr
x_arr[binary_dim - i - 1] = batch_x_arr
return x_arr , y_arr , x
def getInt(y , bit_size):
arr = np.zeros([len(y[0])])
for i in range(len(y[0])):
for j in range(bit_size):
arr[i] += (int(y[j][i][0]) * pow(2 , j))
return arr
if __name__ == '__main__':
input_size = 2
hidden_size = 8
batch_size = 100
net = RNN(input_size, hidden_size)
net.init_weight()
print(net)
optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=1e-4)
loss_function = nn.MSELoss()#.CrossEntropyLoss()
for i in range(100000):
net.zero_grad()
x ,y , t = getBatch(batch_size)
in_x = torch.Tensor(x)
y = torch.Tensor(y)
output = net(in_x)
loss = loss_function(output , y)
loss.backward()
optimizer.step()
if i % 100== 0:
output2 = torch.round(output)
result = getInt(output2,binary_dim)
print(t , result)
print('iterater:%d loss:%f'%(i , loss))
pytorch rnn的更多相关文章
- pytorch rnn 2
import torch import torch.nn as nn import numpy as np import torch.optim as optim class RNN(nn.Modul ...
- [PyTorch] rnn,lstm,gru中输入输出维度
本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...
- pytorch --Rnn语言模型(LSTM,BiLSTM) -- 《Recurrent neural network based language model》
论文通过实现RNN来完成了文本分类. 论文地址:88888888 模型结构图: 原理自行参考论文,code and comment: # -*- coding: utf-8 -*- # @time : ...
- pytorch RNN层api的几个参数说明
classtorch.nn.RNN(*args, **kwargs) input_size – The number of expected features in the input x hidde ...
- 机器翻译注意力机制及其PyTorch实现
前面阐述注意力理论知识,后面简单描述PyTorch利用注意力实现机器翻译 Effective Approaches to Attention-based Neural Machine Translat ...
- PyTorch专栏(六): 混合前端的seq2seq模型部署
欢迎关注磐创博客资源汇总站: http://docs.panchuang.net/ 欢迎关注PyTorch官方中文教程站: http://pytorch.panchuang.net/ 专栏目录: 第一 ...
- 混合前端seq2seq模型部署
混合前端seq2seq模型部署 本文介绍,如何将seq2seq模型转换为PyTorch可用的前端混合Torch脚本.要转换的模型来自于聊天机器人教程Chatbot tutorial. 1.混合前端 在 ...
- “你什么意思”之基于RNN的语义槽填充(Pytorch实现)
1. 概况 1.1 任务 口语理解(Spoken Language Understanding, SLU)作为语音识别与自然语言处理之间的一个新兴领域,其目的是为了让计算机从用户的讲话中理解他们的意图 ...
- Pytorch系列教程-使用字符级RNN生成姓名
前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutor ...
随机推荐
- C++ 类的继承三(继承中的构造与析构)
//继承中的构造与析构 #include<iostream> using namespace std; /* 继承中的构造析构调用原则 1.子类对象在创建时会首先调用父类的构造函数 2.父 ...
- MyBatis是支持普通 SQL查询
MyBatis是支持普通 SQL查询,存储过程和高级映射的优秀持久层框架.MyBatis 消除了几乎所有的JDBC代码和参数的手工设置以及结果集的检索.MyBatis 使用简单的 XML或注解用于配置 ...
- java----内部类与匿名内部类的各种注意事项与知识点
Java 内部类分四种:成员内部类.局部内部类.静态内部类和匿名内部类.1.成员内部类: 即作为外部类的一个成员存在,与外部类的属性.方法并列.注意:成员内部类中不能定义静态变量,但可以访问外部类的所 ...
- SQL语句:语法错误(操作符丢失)在查询表达式中
所谓操作符丢失,应该是你在拼接SQL语句是少了关键词或者分隔符,导致系统无法识别SQL语句.建议:1.监控SQL语句,看看哪里出现问题:断点看下最后的sql到底是什么样子就知道了,另外你可以把这段sq ...
- django admin后台css样式丢失
尼玛 坑爹啊 怎么光秃秃的,跟人家的不一样啊 打开firebug 发现报错,找不到css 通过google找到原因,是因为admin所需的js ,css等静态文件虽然都在django的安装目录内,但是 ...
- WPF 隧道路由事件
阅读本文前,请先了解 冒泡路由事件:http://www.cnblogs.com/andrew-blog/p/WPF_BubbledEvent.html 隧道路由事件的工作方式和冒泡路由事件相同,但方 ...
- Linq循环DataTable,使用匿名对象取出需要的列
var g_id = context.Request["g_id"]; DataTable dt = new DataTable(); var sql = @"selec ...
- iOS面试题--网络--如何处理多个网络请求的并发的情况
如何处理多个网络请求的并发的情况 一.概念 1.并发 当有多个线程在操作时,如果系统只有一个CPU,则它根本不可能真正同时进行一个以上的线程,它只能把CPU运行时间划分成若干个时间段,再将时间 段分配 ...
- ORA-00845 MEMORY_TARGET not supported on this system解决办法
ORA-00845: MEMORY_TARGET not supported on this system报错解决 Oracle 11g数据库修改pfile参数后启动数据库报错ora-00845 SQ ...
- Windows7 x64系统下安装Nodejs并在WebStorm下搭建编译less环境
1. 打开Nodejs官网http://www.nodejs.org/,点“DOWNLOADS”,点64-bit下载“node-v0.10.33-x64.msi”. 2. 下载好后,双击“node-v ...