Task4.用PyTorch实现多层网络
1.引入模块,读取数据
2.构建计算图(构建网络模型)
3.损失函数与优化器
4.开始训练模型
5.对训练的模型预测结果进行评估
import torch.nn.functional as F
import torch.nn.init as init
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import math
%matplotlib inline
#%matplotlib inline 可以在Ipython编译器里直接使用
#功能是可以内嵌绘图,并且可以省略掉plt.show()这一步。 xy=np.loadtxt('./data/diabetes.csv.gz',delimiter=',',dtype=np.float32)
x_data=torch.from_numpy(xy[:,0:-1])#取除了最后一列的数据
y_data=torch.from_numpy(xy[:,[-1]])#取最后一列的数据,[-1]加中括号是为了keepdim print(x_data.size(),y_data.size())
#print(x_data.shape,y_data.shape) #建立网络模型
class Model(torch.nn.Module): def __init__(self):
super(Model,self).__init__()
self.l1=torch.nn.Linear(8,6)
self.l2=torch.nn.Linear(6,4)
self.l3=torch.nn.Linear(4,1) def forward(self,x):
out1=F.relu(self.l1(x))
out2=F.dropout(out1,p=0.5)
out3=F.relu(self.l2(out2))
out4=F.dropout(out3,p=0.5)
y_pred=F.sigmoid(self.l3(out3))
return y_pred def weights_init(m):
classname=m.__class__.__name__
if classname.find('Linear')!=-1:
m.weight.data=torch.randn(m.weight.data.size()[0],m.weight.data.size()[1])
m.bias.data=torch.randn(m.bias.data.size()[0]) #our model
model=Model()
model.apply(weights_init)
criterion=torch.nn.BCELoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.1) #training loop
Loss=[]
for epoch in range(2000):
y_pred=model(x_data)
loss=criterion(y_pred,y_data)
if epoch%100 == 0:
print("epoch = ",epoch," loss = ",loss.data)
Loss.append(loss.data)
optimizer.zero_grad()
loss.backward()
optimizer.step() hour_var = Variable(torch.randn(1,8))
print("predict",model(hour_var).data[0]>0.5)
plt.plot(Loss)
这里说明一下,这个dataset不是自带的,需要大家自己去下载,我找的时候费了不少功夫,这里提供一个网址给大家下载https://github.com/LianHaiMiao/pytorch-lesson-zh/blob/master/dataSet/diabetes.csv.gz
参考:https://blog.csdn.net/qq_35547281/article/details/89285980
Task4.用PyTorch实现多层网络的更多相关文章
- 神经网络:多层网络与C++实现
相关源码可参考最新的实现:https://github.com/ronnyyoung/EasyML ,中的neural_network模块,后持续更新,包括加入CNN的结构. 一.引言 在前一篇关于神 ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
- 简单的RNN和BP多层网络之间的区别
先来个简单的多层网络 RNN的原理和出现的原因,解决什么场景的什么问题 关于RNN出现的原因,RNN详细的原理,已经有很多博文讲解的非常棒了. 如下: http://ai.51cto.com/art/ ...
- PyTorch对ResNet网络的实现解析
PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...
- pytorch空间变换网络
pytorch空间变换网络 本文将学习如何使用称为空间变换器网络的视觉注意机制来扩充网络.可以在DeepMind paper 阅读更多有关空间变换器网络的内容. 空间变换器网络是对任何空间变换的差异化 ...
- 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析
转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...
- [pytorch笔记] 调整网络学习率
1. 为网络的不同部分指定不同的学习率 class LeNet(t.nn.Module): def __init__(self): super(LeNet, self).__init__() self ...
- pytorch搭建简单网络
pytorch搭建一个简单神经网络 import torch import torch.nn as nn # 定义数据 # x:输入数据 # y:标签 x = torch.Tensor([[0.2, ...
- pytorch实现AlexNet网络
直接上图吧 写网络就像搭积木
随机推荐
- 测开之路一百四十四:ORM之SQLAlchemy查询
在上一篇的基础上,插入数据 查询 Department.query.all() # 用表对象查db.session.query(Department).all() # 用db对象查 查询前两条,直接p ...
- [数据结构] 2.3 Trie树
抱歉更新晚了,看了几天三体,2333,我们继续数据结构之旅. 一.什么是Tire树? Tire树有很多名字:字典树.单词查找树. 故名思意,它就是一本”字典“,当我们查找"word" ...
- ARTS-1
ARTS的初衷 Algorithm:主要是为了编程训练和学习.每周至少做一个 leetcode 的算法题(先从Easy开始,然后再Medium,最后才Hard).进行编程训练,如果不训练你看再多的算法 ...
- cocos2dx基础篇(14) 滚动视图CCScrollView
[3.x] (1)去掉 "CC" (2)滚动方向 > CCScrollViewDirection 改为强枚举 ScrollView::Dire ...
- java文件编译后,出现xx$1.class的原因
java编译后的文件名字带有$接数字的就是匿名内部类的编译结果,接名字的就是内部类的编译结果 例如:TestFrame$1.class是匿名内部类的编译结果, TestFrame$MyJob.clas ...
- 如何在centos7中显示/etc/目录下以非字母开头,后面跟了一个字母及其它任意字符的文件或目录
ls /etc |grep "^[^[:alpha:]][[:alpha:]].*"
- 一分钟安装mysql
学数据库的人都知道,MySQL数据库是比较基本的掌握要求,不仅开源而且社区版本是免费使用的.由于工作上或者经常更换系统的原因,有时候会需要安装MySQL数据库.为了不至于每次安装都要查阅资料,现把安装 ...
- 虚拟机的网卡基本配置和基本linux命令
1.切换到/etc/sysconfig/network-script目录 cd /etc/sysconfig/network-scripts 2.将ifcfg-eth0备份成ifcfg-eth0. c ...
- mybatis报错(三)报错Result Maps collection does not contain value for java.lang.Integer解决方法
转自:https://blog.csdn.net/zengdeqing2012/article/details/50978682 1 [WARN ] 2016-03-25 13:03:23,955 - ...
- Linux系统性能测试工具(二)——内存压力测试工具memtester
本文介绍关于Linux系统(适用于centos/ubuntu等)的内存压力测试工具-memtester.内存性能测试工具包括: 内存带宽测试工具——mbw: 内存压力测试工具——memtester: ...