import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # torch can only train on Variable, so convert them to Variable
# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
print(net) # net architecture optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss plt.ion() # something about plotting for t in range(200):
prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # must be (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()

pytorch之 regression的更多相关文章

  1. pytorch 4 regression 回归

    import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...

  2. Linear Regression with PyTorch

    Linear Regression with PyTorch Problem Description 初始化一组数据 \((x,y)\),使其满足这样的线性关系 \(y = w x + b\) .然后 ...

  3. Task3.PyTorch实现Logistic regression

    1.PyTorch基础实现代码 import torch from torch.autograd import Variable torch.manual_seed(2) x_data = Varia ...

  4. pytorch之 RNN regression

    关于RNN模型参数的解释,可以参看RNN参数解释 1 import torch from torch import nn import numpy as np import matplotlib.py ...

  5. (转)Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph.

    Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph 2019-04-27 09:33:58 This ...

  6. PyTorch(一)Basics

    PyTorch Basics import torch import torchvision import torch.nn as nn import numpy as np import torch ...

  7. (转) The Incredible PyTorch

    转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...

  8. Pytorch自定义dataloader以及在迭代过程中返回image的name

    pytorch官方给的加载数据的方式是已经定义好的dataset以及loader,如何加载自己本地的图片以及label? 形如数据格式为 image1 label1 image2 label2 ... ...

  9. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

随机推荐

  1. python的list()函数

    list()函数将其它序列转换为 列表 (就是js的数组). 该函数不会改变   其它序列 效果图一: 代码一: # 定义一个元组序列 tuple_one = (123,','abc') print( ...

  2. 一文读懂MapReduce 附流量解析实例

    1.MapReduce是什么 Hadoop MapReduce是一个软件框架,基于该框架能够容易地编写应用程序,这些应用程序能够运行在由上千个商用机器组成的大集群上,并以一种可靠的,具有容错能力的方式 ...

  3. Spring Boot2 系列教程 (十一) | 整合数据缓存 Cache

    如题,今天介绍 SpringBoot 的数据缓存.做过开发的都知道程序的瓶颈在于数据库,我们也知道内存的速度是大大快于硬盘的,当需要重复获取相同数据时,一次又一次的请求数据库或者远程服务,导致大量时间 ...

  4. 使用ajax向后台发送请求跳转页面无效的原因

    Ajax只是利用脚本访问对应url获取数据而已,不能做除了获取返回数据以外的其它动作了.所以浏览器端是不会发起重定向的. 1)正常的http url请求,只有浏览器和服务器两个参与者.浏览器端发起一个 ...

  5. Django 导入配置文件

    from django.conf import settings

  6. W3C 带来了一个新的语言

    2019年12月5日,W3C 宣布: WebAssembly 核心规范 正式成为 Web 官方标准. 继 HTML, CSS, JavaScript 之后,WebAssembly 成为了第4个 Web ...

  7. 51Nod 1238 - 最小公倍数之和 V3(毒瘤数学+杜教筛)

    题目 戳这里 推导 ∑i=1n∑j=1nlcm(i,j)~~~\sum_{i=1}^{n}\sum_{j=1}^{n}lcm(i,j)   ∑i=1n​∑j=1n​lcm(i,j) =∑i=1n∑j= ...

  8. Java Swing图形界面开发

    本文转自xietansheng的CSDN博客内容,这是自己见过的最通俗易懂.最适合快速上手做Java GUI开发的教程了,这里整合一下作为自己以后复习的笔记: 原文地址:https://blog.cs ...

  9. DFS或BFS(深度优先搜索或广度优先搜索遍历无向图)-04-无向图-岛屿数量

    给定一个由 '1'(陆地)和 '0'(水)组成的的二维网格,计算岛屿的数量.一个岛被水包围,并且它是通过水平方向或垂直方向上相邻的陆地连接而成的.你可以假设网格的四个边均被水包围. 示例 1: 输入: ...

  10. CSS-07-CSS文本设置

    <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...