pytorch之 regression
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # torch can only train on Variable, so convert them to Variable
# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
print(net) # net architecture optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss plt.ion() # something about plotting for t in range(200):
prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # must be (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()
pytorch之 regression的更多相关文章
- pytorch 4 regression 回归
import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...
- Linear Regression with PyTorch
Linear Regression with PyTorch Problem Description 初始化一组数据 \((x,y)\),使其满足这样的线性关系 \(y = w x + b\) .然后 ...
- Task3.PyTorch实现Logistic regression
1.PyTorch基础实现代码 import torch from torch.autograd import Variable torch.manual_seed(2) x_data = Varia ...
- pytorch之 RNN regression
关于RNN模型参数的解释,可以参看RNN参数解释 1 import torch from torch import nn import numpy as np import matplotlib.py ...
- (转)Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph.
Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph 2019-04-27 09:33:58 This ...
- PyTorch(一)Basics
PyTorch Basics import torch import torchvision import torch.nn as nn import numpy as np import torch ...
- (转) The Incredible PyTorch
转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...
- Pytorch自定义dataloader以及在迭代过程中返回image的name
pytorch官方给的加载数据的方式是已经定义好的dataset以及loader,如何加载自己本地的图片以及label? 形如数据格式为 image1 label1 image2 label2 ... ...
- 【深度学习】Pytorch 学习笔记
目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...
随机推荐
- 阿里开源服务发现组件 Nacos快速入门
最近几年随着云计算和微服务不断的发展,各大云厂商也都看好了微服务解决方案这个市场,纷纷推出了自己针对微服务上云架构的解决方案,并且诞生了云原生,Cloud Native的概念. 云原生是一种专门针对云 ...
- SPFA判负环模板
void DFS_SPFA(int u){ if(flag) return; vis[u]=true; for(int i=head[u];i;i=edges[i].nxt){ if(fl ...
- dp-多重背包
(推荐 : http://blog.csdn.net/insistgogo/article/details/11176693 ) 学会了前两个背包 , 学这个背包还是很轻松的 . 多重背包 , 顾名思 ...
- 线段树 or 并查集 (多一个时间截点)
There is a company that has N employees(numbered from 1 to N),every employee in the company has a im ...
- 贪心 + DFS
A New Year party is not a New Year party without lemonade! As usual, you are expecting a lot of gues ...
- 【 Tomcat 】tomcat8.0 基本参数调优配置-----(1)
Tomcat 的缺省配置是不能稳定长期运行的,也就是不适合生产环境,它会死机,让你不断重新启动,甚至在午夜时分唤醒你.对于操作系统优化来说,是尽可能的增大可使用的内存容量.提高CPU 的频率,保证文件 ...
- 关于爬虫的日常复习(9)—— 实战:分析Ajax抓取今日头条接拍美图
- 5、python基本数据类型之数值类型
前言:python的基本数据类型可以分为三类:数值类型.序列类型.散列类型,本文主要介绍数值类型. 一.数值类型 数值类型有四种: 1)整数(int):整数 2)浮点数(float):小数 3)布尔值 ...
- 简单实现Android手机“全局可调试”(ro.debuggable = 1)的方法【锤子坚果3】
在Android真机上调试程序有一个前提,就是这个apk包必须有 debuggable=true 的属性才行.而除了自己开发的apk能够控制打包属性之外,其他的程序发行之后显然不会设这个值为 true ...
- CTF--HTTP服务--SQL注入-X-Forwarded-For报文头
开门见山 1. 扫描靶场ip,发现PCS 192.168.31.196 2. 扫描靶场开放服务信息 3. 扫描靶场全部信息 4. 探测敏感信息 5. 查看靶场80端口的主界面 6. 使用AVWS工具进 ...