import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # torch can only train on Variable, so convert them to Variable
# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
print(net) # net architecture optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss plt.ion() # something about plotting for t in range(200):
prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # must be (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()

pytorch之 regression的更多相关文章

  1. pytorch 4 regression 回归

    import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...

  2. Linear Regression with PyTorch

    Linear Regression with PyTorch Problem Description 初始化一组数据 \((x,y)\),使其满足这样的线性关系 \(y = w x + b\) .然后 ...

  3. Task3.PyTorch实现Logistic regression

    1.PyTorch基础实现代码 import torch from torch.autograd import Variable torch.manual_seed(2) x_data = Varia ...

  4. pytorch之 RNN regression

    关于RNN模型参数的解释,可以参看RNN参数解释 1 import torch from torch import nn import numpy as np import matplotlib.py ...

  5. (转)Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph.

    Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph 2019-04-27 09:33:58 This ...

  6. PyTorch(一)Basics

    PyTorch Basics import torch import torchvision import torch.nn as nn import numpy as np import torch ...

  7. (转) The Incredible PyTorch

    转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...

  8. Pytorch自定义dataloader以及在迭代过程中返回image的name

    pytorch官方给的加载数据的方式是已经定义好的dataset以及loader,如何加载自己本地的图片以及label? 形如数据格式为 image1 label1 image2 label2 ... ...

  9. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

随机推荐

  1. 阿里开源服务发现组件 Nacos快速入门

    最近几年随着云计算和微服务不断的发展,各大云厂商也都看好了微服务解决方案这个市场,纷纷推出了自己针对微服务上云架构的解决方案,并且诞生了云原生,Cloud Native的概念. 云原生是一种专门针对云 ...

  2. SPFA判负环模板

    void DFS_SPFA(int u){   if(flag) return; vis[u]=true;   for(int i=head[u];i;i=edges[i].nxt){   if(fl ...

  3. dp-多重背包

    (推荐 : http://blog.csdn.net/insistgogo/article/details/11176693 ) 学会了前两个背包 , 学这个背包还是很轻松的 . 多重背包 , 顾名思 ...

  4. 线段树 or 并查集 (多一个时间截点)

    There is a company that has N employees(numbered from 1 to N),every employee in the company has a im ...

  5. 贪心 + DFS

    A New Year party is not a New Year party without lemonade! As usual, you are expecting a lot of gues ...

  6. 【 Tomcat 】tomcat8.0 基本参数调优配置-----(1)

    Tomcat 的缺省配置是不能稳定长期运行的,也就是不适合生产环境,它会死机,让你不断重新启动,甚至在午夜时分唤醒你.对于操作系统优化来说,是尽可能的增大可使用的内存容量.提高CPU 的频率,保证文件 ...

  7. 关于爬虫的日常复习(9)—— 实战:分析Ajax抓取今日头条接拍美图

  8. 5、python基本数据类型之数值类型

    前言:python的基本数据类型可以分为三类:数值类型.序列类型.散列类型,本文主要介绍数值类型. 一.数值类型 数值类型有四种: 1)整数(int):整数 2)浮点数(float):小数 3)布尔值 ...

  9. 简单实现Android手机“全局可调试”(ro.debuggable = 1)的方法【锤子坚果3】

    在Android真机上调试程序有一个前提,就是这个apk包必须有 debuggable=true 的属性才行.而除了自己开发的apk能够控制打包属性之外,其他的程序发行之后显然不会设这个值为 true ...

  10. CTF--HTTP服务--SQL注入-X-Forwarded-For报文头

    开门见山 1. 扫描靶场ip,发现PCS 192.168.31.196 2. 扫描靶场开放服务信息 3. 扫描靶场全部信息 4. 探测敏感信息 5. 查看靶场80端口的主界面 6. 使用AVWS工具进 ...