温习一下,写着玩。

import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim class RNN(nn.Module): def __init__(self,input_dim , hidden_dim):
super(RNN,self).__init__()
self._rnn = nn.RNN(input_size = input_dim , hidden_size= hidden_dim )
self.linear = nn.Linear(hidden_dim , 1)
self.relu = nn.ReLU() def forward(self , _in):
layer1 , h = self._rnn(_in)
layer2 = self.relu(self.linear(self.relu(layer1)))
return layer2 def init_weight(self):
nn.init.normal_(self.linear.weight.data , 0 , np.sqrt(2 / 16))
nn.init.uniform_(self.linear.bias, 0, 0) def getBinDict(bit_size = 16):
max = pow(2,bit_size)
bin_dict = {}
for i in range(max):
s = '{:016b}'.format(i)
arr = np.array(list(s))
arr = arr.astype(int)
bin_dict[i] = arr
return bin_dict binary_dim = 16
int2binary = getBinDict(binary_dim) def getBatch( batch_size):
x = np.random.randint(0,256,[batch_size , 2])
x_arr = np.zeros([binary_dim , batch_size , 2 ] , dtype=int)
y_arr = np.zeros([binary_dim,batch_size,1] , dtype=int)
for i in range(0 , binary_dim):
batch_x_arr = np.zeros([batch_size,2] , dtype=int)
batch_y_arr = np.zeros([batch_size,1] , dtype=int)
for j in range(len(x)):
batch_x_arr[j] =[int2binary[int(x[j][0])][i] , int2binary[int(x[j][1])][i]]
batch_y_arr[j] =[int2binary[ int(x[j][0]) + int(x[j][1])][i]] #此处要翻转,rnn处理时是从下标为0处开始,所以要把二进制的高低位翻转
y_arr[binary_dim - i - 1] = batch_y_arr
x_arr[binary_dim - i - 1] = batch_x_arr
return x_arr , y_arr , x def getInt(y , bit_size):
arr = np.zeros([len(y[0])])
for i in range(len(y[0])):
for j in range(bit_size):
arr[i] += (int(y[j][i][0]) * pow(2 , j))
return arr if __name__ == '__main__':
input_size = 2
hidden_size = 8
batch_size = 100
net = RNN(input_size, hidden_size)
net.init_weight()
print(net)
optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=1e-4)
loss_function = nn.MSELoss()#.CrossEntropyLoss()
for i in range(100000):
net.zero_grad()
x ,y , t = getBatch(batch_size)
in_x = torch.Tensor(x)
y = torch.Tensor(y)
output = net(in_x)
loss = loss_function(output , y)
loss.backward()
optimizer.step() if i % 100== 0:
output2 = torch.round(output)
result = getInt(output2,binary_dim)
print(t , result)
print('iterater:%d loss:%f'%(i , loss))

pytorch rnn的更多相关文章

  1. pytorch rnn 2

    import torch import torch.nn as nn import numpy as np import torch.optim as optim class RNN(nn.Modul ...

  2. [PyTorch] rnn,lstm,gru中输入输出维度

    本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...

  3. pytorch --Rnn语言模型(LSTM,BiLSTM) -- 《Recurrent neural network based language model》

    论文通过实现RNN来完成了文本分类. 论文地址:88888888 模型结构图: 原理自行参考论文,code and comment: # -*- coding: utf-8 -*- # @time : ...

  4. pytorch RNN层api的几个参数说明

    classtorch.nn.RNN(*args, **kwargs) input_size – The number of expected features in the input x hidde ...

  5. 机器翻译注意力机制及其PyTorch实现

    前面阐述注意力理论知识,后面简单描述PyTorch利用注意力实现机器翻译 Effective Approaches to Attention-based Neural Machine Translat ...

  6. PyTorch专栏(六): 混合前端的seq2seq模型部署

    欢迎关注磐创博客资源汇总站: http://docs.panchuang.net/ 欢迎关注PyTorch官方中文教程站: http://pytorch.panchuang.net/ 专栏目录: 第一 ...

  7. 混合前端seq2seq模型部署

    混合前端seq2seq模型部署 本文介绍,如何将seq2seq模型转换为PyTorch可用的前端混合Torch脚本.要转换的模型来自于聊天机器人教程Chatbot tutorial. 1.混合前端 在 ...

  8. “你什么意思”之基于RNN的语义槽填充(Pytorch实现)

    1. 概况 1.1 任务 口语理解(Spoken Language Understanding, SLU)作为语音识别与自然语言处理之间的一个新兴领域,其目的是为了让计算机从用户的讲话中理解他们的意图 ...

  9. Pytorch系列教程-使用字符级RNN生成姓名

    前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutor ...

随机推荐

  1. c经典算法

    1. 河内之塔 说明 河内之塔(Towers of Hanoi)是法国人M.Claus(Lucas)于1883年从泰国带至法国的,河内为越战时 北越的首都,即现在的胡志明市:1883年法国数学家 Ed ...

  2. 匿名内部类 Inner class

    先说结论 匿名内部类分两种,一种是接口的匿名实现,一种是类的匿名子类!后者往往用于修改特定方法. 再说起因 本来以为匿名内部类很简单,就是接口的匿名实现,直到我发现了下面这段代码: public cl ...

  3. 【linux】硬盘分区

    fdisk -l fdisk /dev/sda d--删除分区 n-新建分区 p--主分区 e--扩展分区 t--改变分区格式 82为swap分区 w--保存退出 http://www.blogjav ...

  4. java 反序列化漏洞检测及修复

    Jboss.Websphere和weblogic的反序列化漏洞已经出来一段时间了,还是有很多服务器没有解决这个漏洞: 反序列化漏洞原理参考:JAVA反序列化漏洞完整过程分析与调试 这里参考了网上的 J ...

  5. a &a &a[0]之间的区别和联系

    数组中,a为数组的首地址,&a[0]为数组第一个元素的地址. 所以 a == &a[0] 但是,&a又是什么东西呢? 我们来做下面的代码测试: #include <std ...

  6. 第二十一篇:Linux 操作系统中的进程结构

    前言 在 Linux 中,一个正在执行的程序往往由各种各样的进程组成,这些进程除了父子关系,还有其他的关系.依赖于这些关系,所有进程构成一个整体,给用户提供完整的服务( 考虑到了终端,即与用户的交互 ...

  7. NestedScrollView,RecyclerView

    为什么把它们放一起呢, 是因为它有着相同的特点 在新版的support-v4兼容包里面有一个NestedScrollView控件,这个控件其实和普通的ScrollView并没有多大的区别,这个控件其实 ...

  8. VS2010类模板更改,增加版权等等信息

    本文转载自XDOTNET 在开发过程中往往需要在每一个页面(类)增加注释等等内容,VS2010中可以修改模板,在原有模板中增加一个类,会引用System等等命名空间,以及一些程序集.下面我们来看看如何 ...

  9. 170329、用 Maven 部署 war 包到远程 Tomcat 服务器

    过去我们发布一个Java Web程序通常的做法就是把它打成一个war包,然后用SSH这样的工具把它上传到服务器,并放到相应的目录里,让Tomcat自动去解包,完成部署. 很显然,这样做不够方便,且我们 ...

  10. CEF3 HTML5 audio标签为什么不能播放mp3格式的音频文件

    CEF3 HTML5 audio标签 为什么不能播放mp3格式的音频文件   原因略.   解决方法: 找一个最新版的chrome ,我用的是24版本.路径 C:\Documents and Sett ...