pytorch_01_基础_一维线性回归
pytorch基础
pytorch官方文档:https://pytorch.org/docs/master/nn.html#linear-layers
import torch
from torch import nn
1 tensor:张量,表示一个多维的矩阵。
b.numpy()能将b转换为numpy数据类型,同时使用torch.from_num py()将numpy转换为tensor。若将a的类型转换成float,a.float().
b.numpy():将b转换成numpy数据类型
torch.from_num py():把numpy转换成tensor
如果需要更改tensor的数据类型,只需要在转换后的tensor后面加上你需要的类型,eg:将a转换成float,
只需a.float()
如果电脑支持GPU,可以将Tensor 放到GPU上
torch.cuda.is_available()判断一下是否支持GPU,
如果把tensor放到GPU上,只需a.cuda()就把tensor a 放到GPU上
if torch.cuda.is_available():
a_cuda = a.cuda()
print(a_cuda)
2 torch.autograd.Variable(变量):
Variable提供了自动求导的功能,在做运算的时候需要构造一个计算图谱,然后再里面进行前向传播和反向传播
Variable和Tensor本质上没有区别,不过Variable会被放入一个计算图中,然后进行前向传播,反向传播,自动求导
Variable是在torch.autograd.variable中,将一个tensor编程Variable,Variable(a),有三个重要的组成属性:data,grad,grad_fn。
data:取出Variable里面的tensor数值,
grad_fn:得到这个Variable的操作,比如通过加减还是乘除得到,
grad:Variable的反向传播梯度
例如:
x = Variable(torch.Tensor([1],requires_grad=True)
w = Variable(torch.Tensor([2],requires_grad=True)
b = Variable(torch.Tensor([3],requires_grad=True)
y = w*x +b
y.backward()自动求导
print(x.grad)x的梯度
3.
import torch
x = torch.rand(3)
"""
读取数据
torch.utils.data.DataLoader定义一个新的迭代器
dataiter = DataLoader(myDataset,batch_size = 32,shuffle = True,collate_fn = default_colllate)
collate_fn:如何取样本
读取图片:
ImageFolder:处理图片
dset = ImageFolder(root = 'root_path',transform = None,loader = default_loader)
root :根目录,这个目录下有几个文件夹,每个文件夹表示一个类别:transform和target_transform是图片增强
loader是图片读取的办法,因为我们读取的是图片的名字,然后通过loader将图片转换成我们需要的图片类型
进入神经网络。
"""
"""
所有的层结构和损失函数都来自于torch.nn,所有的模型构建都是从这个基类nn.Module继承的
"""
# class net_name(nn.Module):
# def __init__(self,other_arguments):
# super(net_name,self).__init__()
# self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size)
# def forward(self,x):
# x = self.conv1(x)
# return x
"""
所有的网络层都是由nn这个包得到的,eg,nn.Linear
定义好模型之后,通过nn这个包定义损失函数
criterion =nn.CrossEntropyLoss()
loss = criterion(output,target)
"""
""" 4. torch.optim(优化):
实现各种优化的包,
optimizer = torch.optim.SGD(model.parameters(),lr = 0.01,momentum = 0.9)
学习率是0.01,动量是0.9的随机梯度下降,在优化之前将梯度归零,optimizer.zeros(),通过loss.backward()反向传播,
自动求导得到每个参数的梯度,最后需要optimizer.step()就可以通过梯度做一步参数更新。
梯度:导数的多变量表达式,函数的梯度形成了一个向量场,同时也是一个方向,这个方向上导数最大,且等于梯度。
"""
5. 线性模型
一维线性回归:f(xi) = wxi+b
Loss = Σ(f(xi)-yi)2
例子:
一维线性回归
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
# import torch.nn as nn
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype = np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype = np.float32)
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train) class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression,self).__init__()
self.linear = nn.Linear(1,1)#输出都是一维
def forward(self, x):
out = self.linear(x)
return out
if torch.cuda.is_available():
model = LinearRegression().cuda()#如果支持GPU加速,可以通过model.cuda()将模型放到GPU上
else:
model = LinearRegression() criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr = 1e-3) # 训练模型
num_epochs = 100
for epoch in range(num_epochs):
if torch.cuda.is_available():
inputs = torch.autograd.Variable(x_train).cuda()
target = torch.autograd.Variable(y_train).cuda()
else:
inputs = torch.autograd.Variable(x_train)#将数据变成Variable放入计算图
target = torch.autograd.Variable(y_train)
#前向传播
out = model(inputs) # 得到网络前向传播的结果
loss = criterion(out,target)# 得到损失函数
#backword
optimizer.zero_grad()#归零梯度,每次做反向传播之前都要归零梯度,不然梯度会累加在一起,造成结果不收敛
loss.backward()
optimizer.step()
if (epoch+1) % 20 == 0:#loss.data[0]:loss是一个Variable,通过loss.data可以取出一个Tensor,通过loss.data[0]得到一个int或者float类型的数据
# print(num_epochs)
print('epoch[{}/{}],loss:{:.6f}'.format(epoch+1,num_epochs,loss))
"""
#在训练完train样本后,生成模型model要用来测试样本,在model(test)之前,需要加上model.eval(),
否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有batch normalization层所带来的的性质。
"""
model.eval()
predict = model(torch.autograd.Variable(x_train))
predict = predict.data.numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'Original data')
plt.plot(x_train.numpy(),predict,label = 'Filtting Line')
plt.show()
print('finished')
pytorch_01_基础_一维线性回归的更多相关文章
- 算法基础_递归_求杨辉三角第m行第n个数字
问题描述: 算法基础_递归_求杨辉三角第m行第n个数字(m,n都从0开始) 解题源代码(这里打印出的是杨辉三角某一层的所有数字,没用大数,所以有上限,这里只写基本逻辑,要符合题意的话,把循环去掉就好) ...
- 稳定排序nlogn之归并排序_一维,二维
稳定排序nlogn之归并排序_一维,二维 稳定排序:排序时间稳定的排序 稳定排序包括:归并排序(nlogn),基数排序[设待排序列为n个记录,d个关键码,关键码的取值范围为radix,则进行链式基数排 ...
- Flutter实战视频-移动电商-05.Dio基础_引入和简单的Get请求
05.Dio基础_引入和简单的Get请求 博客地址: https://jspang.com/post/FlutterShop.html#toc-4c7 第三方的http请求库叫做Dio https:/ ...
- Flutter实战视频-移动电商-08.Dio基础_伪造请求头获取数据
08.Dio基础_伪造请求头获取数据 上节课代码清楚 重新编写HomePage这个动态组件 开始写请求的方法 请求数据 .但是由于我们没加请求的头 所以没有返回数据 451就是表示请求错错误 创建请求 ...
- [Zlib]_[0基础]_[使用zlib库压缩文件]
场景: 1. WIndows上没找到系统提供的win32 api来生成zip压缩文件, 有知道的大牛麻烦留个言. 2. zlib比較经常使用,编译也方便,使用它来做压缩吧. MacOSX平台默认支持z ...
- python基础_格式化输出(%用法和format用法)(转载)
python基础_格式化输出(%用法和format用法) 目录 %用法 format用法 %用法 1.整数的输出 %o -- oct 八进制%d -- dec 十进制%x -- hex 十六进制 &g ...
- 【转】opencv检测运动物体的基础_特征提取
特征提取是计算机视觉和图像处理中的一个概念.它指的是使用计算机提取图像信息,决定每个图像的点是否属于一个图像特征.特征提取的结果是把图像上的点分为不同的子集,这些子集往往属于孤立的点.连续的曲线或者连 ...
- 02_Java基础_第2天(变量、运算符)_讲义
今日内容介绍 1.变量 2.运算符 01变量概述 * A: 什么是变量? * a: 变量是一个内存中的小盒子(小容器),容器是什么?生活中也有很多容器, * 例如水杯是容器,用来装载水:你家里的大衣柜 ...
- 01_Java基础_第1天(Java概述、环境变量、注释、关键字、标识符、常量)_讲义
今日内容介绍 1.Java开发环境搭建 2.HelloWorld案例 3.注释.关键字.标识符 4.数据(数据类型.常量) 01java语言概述 * A: java语言概述 * a: Java是sun ...
随机推荐
- vue 脚手架搭建步骤!
========================================================== 说出来都是泪,最开始都不知道从哪里开始(回头一看还是很简单的,关键是要找到入口) ...
- Linux软件安装——软件包
Linux软件安装——软件包 摘要:本文主要学习了Linux下软件安装的相关知识. 软件包 简介 Linux下的软件包众多,且几乎都是经GPL授权.免费开源(无偿公开源代码)的.这意味着如果你具备修改 ...
- 最新整理的spring面试题从基础到高级,干货满满
最新整理的spring面试题从基础到高级,干货满满 前言: 收藏了一些关于Spring的面试题,一方面是为了准备找工作的时候看面试题,另一方面,通过面试题的方式加深一些自己的理论知识. spring ...
- 汇编指令之CMP, TEST指令
一.CMP指令 这一块呢,我不想上图了,汇编的博文我已经快要让我写吐了,其实也有好多我没有补充进来,比如进制,LEA指令,数据宽度,有符号,无符号的区分等等,但我真的要吐了,这些玩意我已经不是第一次写 ...
- oracle dg状态检查及相关命令
oracle dg 状态检查 先检查备库的归档日志同步情况 SELECT NAME,applied FROM v$archived_log; alter database recover manage ...
- linux 广播和组播
广播和组播 广播,必须使用UDP协议,是只能在局域网内使用,指定接收端的IP为*.*.*.255后,发送的信息,局域网内的所有接受端就能够接到信息了. 广播的发送端代码 #include <st ...
- Confluence 6.9.0 安装
平台环境:centos 7.6 数据库版本:mysql-5.7.26,提前安装好,安装步骤略. 软件版本:Confluence6.9.0 所需软件:提前下载到本地电脑 atlassian-conflu ...
- R 基于朴素贝叶斯模型实现手机垃圾短信过滤
# 读取数数据, 查看数据结构 df_raw <- read.csv("sms_spam.csv", stringsAsFactors=F) str(df_raw) leng ...
- 便宜的回文 (USACO 2007)(c++)
2019-08-21便宜的回文(USACO 2007) 内存限制:128 MiB 时间限制:1000 ms 标准输入输出 题目类型:传统 评测方式:文本比较 题目描述 追踪每头奶牛的去向是一件棘手的任 ...
- ppm
PPM图像格式是由Jef Poskanzer 在1991年所创造的. PPM(Portable Pixmap Format)还有两位兄长,大哥名叫「PBM」,二哥人称「PGM」,他们三兄弟各有所长,下 ...