1.引入模块,读取数据 

2.构建计算图(构建网络模型)

3.损失函数与优化器

4.开始训练模型

5.对训练的模型预测结果进行评估

 import torch.nn.functional as F
import torch.nn.init as init
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import math
%matplotlib inline
#%matplotlib inline 可以在Ipython编译器里直接使用
#功能是可以内嵌绘图,并且可以省略掉plt.show()这一步。 xy=np.loadtxt('./data/diabetes.csv.gz',delimiter=',',dtype=np.float32)
x_data=torch.from_numpy(xy[:,0:-1])#取除了最后一列的数据
y_data=torch.from_numpy(xy[:,[-1]])#取最后一列的数据,[-1]加中括号是为了keepdim print(x_data.size(),y_data.size())
#print(x_data.shape,y_data.shape) #建立网络模型
class Model(torch.nn.Module): def __init__(self):
super(Model,self).__init__()
self.l1=torch.nn.Linear(8,6)
self.l2=torch.nn.Linear(6,4)
self.l3=torch.nn.Linear(4,1) def forward(self,x):
out1=F.relu(self.l1(x))
out2=F.dropout(out1,p=0.5)
out3=F.relu(self.l2(out2))
out4=F.dropout(out3,p=0.5)
y_pred=F.sigmoid(self.l3(out3))
return y_pred def weights_init(m):
classname=m.__class__.__name__
if classname.find('Linear')!=-1:
m.weight.data=torch.randn(m.weight.data.size()[0],m.weight.data.size()[1])
m.bias.data=torch.randn(m.bias.data.size()[0]) #our model
model=Model()
model.apply(weights_init)
criterion=torch.nn.BCELoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.1) #training loop
Loss=[]
for epoch in range(2000):
y_pred=model(x_data)
loss=criterion(y_pred,y_data)
if epoch%100 == 0:
print("epoch = ",epoch," loss = ",loss.data)
Loss.append(loss.data)
optimizer.zero_grad()
loss.backward()
optimizer.step() hour_var = Variable(torch.randn(1,8))
print("predict",model(hour_var).data[0]>0.5)
plt.plot(Loss)

这里说明一下,这个dataset不是自带的,需要大家自己去下载,我找的时候费了不少功夫,这里提供一个网址给大家下载https://github.com/LianHaiMiao/pytorch-lesson-zh/blob/master/dataSet/diabetes.csv.gz
参考:https://blog.csdn.net/qq_35547281/article/details/89285980

Task4.用PyTorch实现多层网络的更多相关文章

  1. 神经网络:多层网络与C++实现

    相关源码可参考最新的实现:https://github.com/ronnyyoung/EasyML ,中的neural_network模块,后持续更新,包括加入CNN的结构. 一.引言 在前一篇关于神 ...

  2. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  3. 简单的RNN和BP多层网络之间的区别

    先来个简单的多层网络 RNN的原理和出现的原因,解决什么场景的什么问题 关于RNN出现的原因,RNN详细的原理,已经有很多博文讲解的非常棒了. 如下: http://ai.51cto.com/art/ ...

  4. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  5. pytorch空间变换网络

    pytorch空间变换网络 本文将学习如何使用称为空间变换器网络的视觉注意机制来扩充网络.可以在DeepMind paper 阅读更多有关空间变换器网络的内容. 空间变换器网络是对任何空间变换的差异化 ...

  6. 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析

    转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...

  7. [pytorch笔记] 调整网络学习率

    1. 为网络的不同部分指定不同的学习率 class LeNet(t.nn.Module): def __init__(self): super(LeNet, self).__init__() self ...

  8. pytorch搭建简单网络

    pytorch搭建一个简单神经网络 import torch import torch.nn as nn # 定义数据 # x:输入数据 # y:标签 x = torch.Tensor([[0.2, ...

  9. pytorch实现AlexNet网络

    直接上图吧 写网络就像搭积木

随机推荐

  1. 阶段3 1.Mybatis_09.Mybatis的多表操作_7 mybatis多对多准备角色表的实体类和映射配置

    创建Role表和user_role表 DROP TABLE IF EXISTS `role`; CREATE TABLE `role` ( `ID` int(11) NOT NULL COMMENT ...

  2. fiddler之数据分析和查看(inspectors)-抓包

    在instpectors中主要是对请求和响应进行查看和分享,监听请求的响应内容.他有多个分页标签.界面分上下两部分,上面部分显示请求的相关信息:下面部分显示响应相关信息.简单说明常用的几个分页标签 一 ...

  3. 中国MOOC_零基础学Java语言_第1周 计算_第1周编程题_1温度转换

    第1周编程题 依照学术诚信条款,我保证此作业是本人独立完成的. 温馨提示: 1.本次作业属于Online Judge题目,提交后由系统即时判分. 2.学生可以在作业截止时间之前不限次数提交答案,系统将 ...

  4. zstack分配的虚拟机的dns设置

    环境: $ uname -a Linux 10-57-19-61 2.6.32-504.el6.x86_64 #1 SMP Wed Oct 15 04:27:16 UTC 2014 x86_64 x8 ...

  5. jmeter链接数据库操作

    jmeter链接数据库操作步骤 首先要先下载mysql-connector-java-5.1.39-bin.jar驱动包 链接:https://pan.baidu.com/s/14F4rp4uH1hX ...

  6. Java回调机制的理解

    用一句话讲明回调机制就是,在A类里面拥有一个类B的对象,调用B类的某个方法并把自身引用传入,在B类的这个方法里面又通过传进来的A的引用来调用A类的某个方法(这个最后调用的A类的方法就叫做回调方法). ...

  7. 20191112 Spring Boot官方文档学习(4.5-4.6)

    4.5.国际化 Spring Boot支持本地化消息,因此您的应用程序可以迎合不同语言首选项的用户.默认情况下,Spring Boot messages在类路径的根目录下查找message resou ...

  8. ascx

    aspx是页面文件ascx是用户控件,用户控件必须嵌入到aspx中才能使用. ascx是用户控件,相当于模板 其实ascx你可以理解为Html里的一部分代码,只是嵌到aspx里而已,因为aspx内容多 ...

  9. springboot swagger教程😀

    传送门开启:https://www.ibm.com/developerworks/cn/java/j-using-swagger-in-a-spring-boot-project/index.html

  10. 编写Servlet步骤以及Servlet生命周期是怎样的

    一.编写Servlet步骤 1.继承HttpServlet,HttpServlet在javax-servlet-api依赖下 2.重写doGet()或者doPost()方法 3.在web.xml中注册 ...