import torch
import torch.nn as nn
import numpy as np
torch.__version__
'1.0.0'

3.1 logistic回归实战

在这一章里面,我们将处理一下结构化数据,并使用logistic回归对结构化数据进行简单的分类。

3.1.1 logistic回归介绍

logistic回归是一种广义线性回归(generalized linear model),与多重线性回归分析有很多相同之处。它们的模型形式基本上相同,都具有 wx + b,其中w和b是待求参数,其区别在于他们的因变量不同,多重线性回归直接将wx+b作为因变量,即y =wx+b,而logistic回归则通过函数L将wx+b对应一个隐状态p,p =L(wx+b),然后根据p 与1-p的大小决定因变量的值。如果L是logistic函数,就是logistic回归,如果L是多项式函数就是多项式回归。

说的更通俗一点,就是logistic回归会在线性回归后再加一层logistic函数的调用。

logistic回归主要是进行二分类预测,我们在激活函数时候讲到过 Sigmod函数,Sigmod函数是最常见的logistic函数,因为Sigmod函数的输出的是是对于0~1之间的概率值,当概率大于0.5预测为1,小于0.5预测为0。

下面我们就来使用公开的数据来进行介绍

3.1.2 UCI German Credit 数据集

UCI German Credit是UCI的德国信用数据集,里面有原数据和数值化后的数据。

German Credit数据是根据个人的银行贷款信息和申请客户贷款逾期发生情况来预测贷款违约倾向的数据集,数据集包含24个维度的,1000条数据,

在这里我们直接使用处理好的数值化的数据,作为展示。

地址

3.2 代码实战

我们这里使用的 german.data-numeric是numpy处理好数值化数据,我们直接使用numpy的load方法读取即可

data=np.loadtxt("german.data-numeric")

数据读取完成后我们要对数据做一下归一化的处理

n,l=data.shape
for j in range(l-1):
meanVal=np.mean(data[:,j])
stdVal=np.std(data[:,j])
data[:,j]=(data[:,j]-meanVal)/stdVal

打乱数据

np.random.shuffle(data)

区分训练集和测试集,由于这里没有验证集,所以我们直接使用测试集的准确度作为评判好坏的标准

区分规则:900条用于训练,100条作为测试

german.data-numeric的格式为,前24列为24个维度,最后一个为要打的标签(0,1),所以我们将数据和标签一起区分出来

train_data=data[:900,:l-1]
train_lab=data[:900,l-1]-1
test_data=data[900:,:l-1]
test_lab=data[900:,l-1]-1

下面我们定义模型,模型很简单

class LR(nn.Module):
def __init__(self):
super(LR,self).__init__()
self.fc=nn.Linear(24,2) # 由于24个维度已经固定了,所以这里写24
def forward(self,x):
out=self.fc(x)
out=torch.sigmoid(out)
return out

测试集上的准确率

def test(pred,lab):
t=pred.max(-1)[1]==lab
return torch.mean(t.float())

下面就是对一些设置

net=LR()
criterion=nn.CrossEntropyLoss() # 使用CrossEntropyLoss损失
optm=torch.optim.Adam(net.parameters()) # Adam优化
epochs=1000 # 训练1000次

下面开始训练了

for i in range(epochs):
# 指定模型为训练模式,计算梯度
net.train()
# 输入值都需要转化成torch的Tensor
x=torch.from_numpy(train_data).float()
y=torch.from_numpy(train_lab).long()
y_hat=net(x)
loss=criterion(y_hat,y) # 计算损失
optm.zero_grad() # 前一步的损失清零
loss.backward() # 反向传播
optm.step() # 优化
if (i+1)%100 ==0 : # 这里我们每100次输出相关的信息
# 指定模型为计算模式
net.eval()
test_in=torch.from_numpy(test_data).float()
test_l=torch.from_numpy(test_lab).long()
test_out=net(test_in)
# 使用我们的测试函数计算准确率
accu=test(test_out,test_l)
print("Epoch:{},Loss:{:.4f},Accuracy:{:.2f}".format(i+1,loss.item(),accu))
Epoch:100,Loss:0.6313,Accuracy:0.76
Epoch:200,Loss:0.6065,Accuracy:0.79
Epoch:300,Loss:0.5909,Accuracy:0.80
Epoch:400,Loss:0.5801,Accuracy:0.81
Epoch:500,Loss:0.5720,Accuracy:0.82
Epoch:600,Loss:0.5657,Accuracy:0.81
Epoch:700,Loss:0.5606,Accuracy:0.81
Epoch:800,Loss:0.5563,Accuracy:0.81
Epoch:900,Loss:0.5527,Accuracy:0.81
Epoch:1000,Loss:0.5496,Accuracy:0.80

训练完成了,我们的准确度达到了80%

[Pytorch框架] 3.1 logistic回归实战的更多相关文章

  1. Logistic回归实战篇之预测病马死亡率

    利用sklearn.linear_model.LogisticRegression训练和测试算法. 示例代码: import numpy as np import matplotlib.pyplot ...

  2. 机器学习实战笔记5(logistic回归)

    1:简单概念描写叙述 如果如今有一些数据点,我们用一条直线对这些点进行拟合(改线称为最佳拟合直线),这个拟合过程就称为回归.训练分类器就是为了寻找最佳拟合參数,使用的是最优化算法. 基于sigmoid ...

  3. 机器学习实战(Machine Learning in Action)学习笔记————05.Logistic回归

    机器学习实战(Machine Learning in Action)学习笔记————05.Logistic回归 关键字:Logistic回归.python.源码解析.测试作者:米仓山下时间:2018- ...

  4. [机器学习实战-Logistic回归]使用Logistic回归预测各种实例

    目录 本实验代码已经传到gitee上,请点击查收! 一.实验目的 二.实验内容与设计思想 实验内容 设计思想 三.实验使用环境 四.实验步骤和调试过程 4.1 基于Logistic回归和Sigmoid ...

  5. 机器学习实战 - 读书笔记(05) - Logistic回归

    解释 Logistic回归用于寻找最优化算法. 最优化算法可以解决最XX问题,比如如何在最短时间内从A点到达B点?如何投入最少工作量却获得最大的效益?如何设计发动机使得油耗最少而功率最大? 我们可以看 ...

  6. 【机器学习实战】第5章 Logistic回归

    第5章 Logistic回归 Logistic 回归 概述 Logistic 回归虽然名字叫回归,但是它是用来做分类的.其主要思想是: 根据现有数据对分类边界线建立回归公式,以此进行分类. 须知概念 ...

  7. 05机器学习实战之Logistic 回归

    Logistic 回归 概述 Logistic 回归 或者叫逻辑回归 虽然名字有回归,但是它是用来做分类的.其主要思想是: 根据现有数据对分类边界线(Decision Boundary)建立回归公式, ...

  8. 《机器学习实战》Logistic回归

    注释:Ng的视频有完整的推到步骤,不过理论和实践还是有很大差别的,代码实现还得完成 1.Logistic回归理论 http://www.cnblogs.com/wjy-lulu/p/7759515.h ...

  9. <机器学习实战>读书笔记--logistic回归

    1. 利用logistic回归进行分类的主要思想是:根据现有数据对分类边界线建立回归公式,以此进行分类. 2.sigmoid函数的分类 Sigmoid函数公式定义 3.梯度上升法    基本思想:要找 ...

  10. 机器学习实战之Logistic回归

    Logistic回归一.概述 1. Logistic Regression 1.1 线性回归 1.2 Sigmoid函数 1.3 逻辑回归 1.4 LR 与线性回归的区别 2. LR的损失函数 3. ...

随机推荐

  1. [Vue warn]: Duplicate keys detected: ''. This may cause an update error. found in

    原因: 使用element-ui 导致 使用路由模式之后  index 没写 导致 解决办法: 删掉  或者天添加路由

  2. Caused by: java.lang.NoClassDefFoundError: net/minidev/asm/FieldFilter 报错的解决

    Caused by: org.springframework.beans.factory.BeanCreationException: Error creating bean with name 'r ...

  3. PerfDog的使用教程

    一.介绍: 移动全平台iOS/Android性能测试.分析工具平台,快速定位分析性能问题.PerfDog支持移动平台所有应用程序(游戏.APP应用.浏览器.小程序.小游戏.H5.后台系统进程等).An ...

  4. DNS Capture: UDP, TCP, IP-Fragmentation, EDNS, ECS, Cookie

    EDNS 扩展实现"EDNS Client Subnet" (ECS) 和 DNS cookies.' 这里不讨论相关概念,实现如有疑问请查看: https://weberblog ...

  5. Python中的join函数用法

    函数:string.join()Python中有join()和os.path.join()两个函数,具体作用如下:    join():    连接字符串数组.将字符串.元组.列表中的元素以指定的字符 ...

  6. python调用方法或者变量时出现未定义异常的原因,可能会是没有正确实例化

    当引用某个某块时 例如 Testpython import test class test(object): def __init__(): -- self.mimi = test def test1 ...

  7. Net DB Web多级缓存的实现

    1.客户端缓存(浏览器缓存) HTTP有一套控制缓存的协议-RFC7234,其中最重要的就是cache-control这个相应报文头,服务器返回时,如果Response带上 cache-control ...

  8. Rainbond PipeLine插件部署与springboot应用部署实践

    前言:上一篇介绍额rainbond单机部署+单个节点的k8s环境搭建,本篇介绍rainbond5.12新增的pipeline插件的使用 1.Pipeline插件的安装 安装gitlab与gitlab- ...

  9. GUI编程--3 Swing

    GUI编程-3 Swing 3.1 JFrame 窗口 窗口: package com.ssl.lesson04; import javax.swing.*; import java.awt.*; p ...

  10. 前端开发工具 VS Code 安裝及使用

    一.下载地址 https://code.visualstudio.com/ 下载完后,傻瓜式安装即可 关注公众号"Java程序员进阶"回复"vs"也可获取 二. ...