视频指导:https://www.bilibili.com/video/BV1Y7411d7Ys?p=6

一些数据集

在pytorch框架下,里面面有配套的数据集,pytorch里面有一个torchversion的包,他可以提供相应的,比如MNIST这种很基础的数据集,但是安装的时候这些数据集不会包含在包里面,所以需要下载,具体代码以及解释如下:

import torchvision
train_set = torchvision.datasets.MNIST(root= '../dataset/mnist', train=True, download=True) //root后面表示把数据集安装在什么位置,train表示是要训练集还是测试集,download表示是否需要下载(如果没有下载就需要下载)
text_set = torchvision.datasets.MNIST(root= '../dataset/mnist', train=False, download=True)

还有一个叫做CIFAR-10 的数据集,训练集里面包含了5w个样本,测试集包含了1w个样本,内容就是一些图片,被分成了十类,分别表示不同的事务

import torchvision
train_set = torchvision.datasets.CIFAR10(...)
text_set = torchvision.datasets.CIFAR10(...)

逻辑斯蒂回归

首先我们自己设计几个数据

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])

然后开始设计模型

class LogisiticRegressionModel(torch.nn.Module):
def __init__(self):
super(LogisiticRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1) def forward(self,x):
y_pred = F.sigmoid(self.linear(x))
return y_pred

这里和昨天不一样的地方在于forward里面使用了逻辑斯蒂回归

然后是优化器和损失函数

model = LogisiticRegressionModel()
critertion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

损失函数采用的是BCE,就是交叉熵损失

开始训练

for epoch in range(1000):
y_pred = model(x_data)
loss = critertion(y_pred,y_data)
print(epoch, loss.item()) optimizer.zero_grad()
loss.backward()
optimizer.step()

结果可视化

用matplotlib进行绘图可以使结果可视化,具体代码和解释如下

import numpy as np
import matplotlib.pyplot as plt x = np.linspace(0, 10, 200) //意思就是在0到10之间弄出来两百个点
x_t = torch.Tensor(x).view((200, 1)) //把他变成200行一列的矩阵
y_t = model(x_t) //送入模型
y = y_t.data.numpy() //用numpy接收y_data的数据
plt.plot(x, y) //绘图
plt.plot([0, 10], [0.5, 0.5], c='r') //表示在0到10,0.5到0.5上面画一条线(显然是一条直线)
plt.ylabel('probability of pass')
plt.grid()
plt.show()

总结

第一遍跑有一个警告

UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead. warnings.warn

查了一下,把

F.sigmoid(self.hidden(x))
#修改成以下语句
torch.sigmoid(self.hidden(x))

即可

估计是版本问题吧

然后我有魔改了一下代码,具体如下

import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]]) class LogisiticRegressionModel(torch.nn.Module):
def __init__(self):
super(LogisiticRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1) def forward(self,x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred model = LogisiticRegressionModel()
critertion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for epoch in range(50000):
y_pred = model(x_data)
loss = critertion(y_pred,y_data)
print(epoch, loss.item()) optimizer.zero_grad()
loss.backward()
optimizer.step() import numpy as np
import matplotlib.pyplot as plt x = np.linspace(0, 100, 2000)
x_t = torch.Tensor(x).view((2000, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 100], [0.5, 0.5], c='r')
plt.xlabel('hours')
plt.ylabel('pass')
plt.grid()
plt.show()

最后的结果,我愿称之为,过拟合之王



哈哈哈哈哈,很有意思

  • 总之整个代码编写就分这几步:
  • 准备数据;
  • 设计模型;
  • 构造损失函数和优化器;
  • 训练过程;
  • 打印结果

    熟悉了就可以搞出自己的模型了

【项目实战】pytorch实现逻辑斯蒂回归的更多相关文章

  1. python机器学习实现逻辑斯蒂回归

    逻辑斯蒂回归 关注公众号"轻松学编程"了解更多. [关键词]Logistics函数,最大似然估计,梯度下降法 1.Logistics回归的原理 利用Logistics回归进行分类的 ...

  2. 【转】机器学习笔记之(3)——Logistic回归(逻辑斯蒂回归)

    原文链接:https://blog.csdn.net/gwplovekimi/article/details/80288964 本博文为逻辑斯特回归的学习笔记.由于仅仅是学习笔记,水平有限,还望广大读 ...

  3. 机器学习之LinearRegression与Logistic Regression逻辑斯蒂回归(三)

    一 评价尺度 sklearn包含四种评价尺度 1 均方差(mean-squared-error) 2 平均绝对值误差(mean_absolute_error) 3 可释方差得分(explained_v ...

  4. [置顶] 局部加权回归、最小二乘的概率解释、逻辑斯蒂回归、感知器算法——斯坦福ML公开课笔记3

    转载请注明:http://blog.csdn.net/xinzhangyanxiang/article/details/9113681 最近在看Ng的机器学习公开课,Ng的讲法循循善诱,感觉提高了不少 ...

  5. 【分类器】感知机+线性回归+逻辑斯蒂回归+softmax回归

    一.感知机     详细参考:https://blog.csdn.net/wodeai1235/article/details/54755735 1.模型和图像: 2.数学定义推导和优化: 3.流程 ...

  6. spark机器学习从0到1逻辑斯蒂回归之(四)

      逻辑斯蒂回归 一.概念 逻辑斯蒂回归(logistic regression)是统计学习中的经典分类方法,属于对数线性模型.logistic回归的因变量可以是二分类的,也可以是多分类的.logis ...

  7. 逻辑斯蒂回归(Logistic Regression)

    逻辑回归名字比较古怪,看上去是回归,却是一个简单的二分类模型. 逻辑回归的模型是如下形式: 其中x是features,θ是feature的权重,σ是sigmoid函数.将θ0视为θ0*x0(x0取值为 ...

  8. 逻辑斯蒂回归VS决策树VS随机森林

    LR 与SVM 不同 1.logistic regression适合需要得到一个分类概率的场景,SVM则没有分类概率 2.LR其实同样可以使用kernel,但是LR没有support vector在计 ...

  9. [转]逻辑斯蒂回归 via python

    # -*- coding:UTF-8 -*-import numpydef loadDataSet(): return dataMat,labelMat def sigmoid(inX): retur ...

随机推荐

  1. Docker — 从入门到实践PDF下载(可复制版)

    0.9-rc2(2017-12-09)修订说明:本书内容将基于DockerCEv17.MM进行重新修订,计划2017年底发布0.9.0版本.旧版本(Docker1.13-)内容,请阅读docker-l ...

  2. ReentrantLock 公平锁源码 第0篇

    ReentrantLock 0 关于ReentrantLock的文章其实写过的,但当时写的感觉不是太好,就给删了,那为啥又要再写一遍呢 最近闲着没事想自己写个锁,然后整了几天出来后不是跑丢线程就是和没 ...

  3. JAVA语言的跨平台性和JDK,JRE与JVM

    Java虚拟机--JVM ~JVM:java虚拟机简称JVM是运行所有java程序的假想计算机,是java程序的运行环境,是java最具有吸引力的特性之一,我们编写的java代码,都运行在JVM之上 ...

  4. HashMap源码深度剖析,手把手带你分析每一行代码,包会!!!

    HashMap源码深度剖析,手把手带你分析每一行代码! 在前面的两篇文章哈希表的原理和200行代码带你写自己的HashMap(如果你阅读这篇文章感觉有点困难,可以先阅读这两篇文章)当中我们仔细谈到了哈 ...

  5. 低代码如何构建支持OAuth2.0的后端Web API

    OAuth2.0 OAuth 是一个安全协议,用于保护全球范围内大量且不断增长的Web API.它用于连接不同的网站,还支持原生应用和移动应用于云服务之间的连接,同时它也是各个领域标准协议中的安全层. ...

  6. CF1703E Mirror Grid 题解

    给定一个矩阵,判断最少将多少个格反转后使得旋转零度,九十度,一百八十度,二百七十度相等. 枚举矩阵每个位置是 \(0\) 还是 \(1\),若已经判断过则跳过,全统 \(1\) 和全统 \(0\) 取 ...

  7. Arraylist集合的概述和基本使用与常用方法

    什么是ArrayList类 java.util.ArrayList 是大小可变的数组实现的,存储在内的数据称为元素,此类提供一些方法来操作内部存储的元素.ArrayList中可不断添加元素,其大小也自 ...

  8. Thymeleaf是什么?该如何使用。

    先了解Thymeleaf是什么 1. Thymeleaf 简介 Thymeleaf 是新⼀代 Java 模板引擎,与 Velocity.FreeMarker 等传统 Java 模板引擎不同,Thyme ...

  9. [NCTF2019]SQLi-1||SQL注入

    1.打开之后首先尝试万能密码登录和部分关键词(or.select.=.or.table.#.-等等)登录,显示被检测到了攻击行为并进行了拦截,结果如下: 2.使用dirmap进行目录扫描,发现robo ...

  10. 参考MySQL Internals手册,使用Golang写一个简单解析binlog的程序

    GreatSQL社区原创内容未经授权不得随意使用,转载请联系小编并注明来源. MySQL作为最流行的开源关系型数据库,有大量的拥趸.其生态已经相当完善,各项特性在圈内都有大量研究.每次新特性发布,都会 ...