Task3.PyTorch实现Logistic regression
1.PyTorch基础实现代码
import torch
from torch.autograd import Variable torch.manual_seed(2)
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0], [4.0]]))
y_data = Variable(torch.Tensor([[0.0], [0.0], [1.0], [1.0]])) #初始化
w = Variable(torch.Tensor([-1]), requires_grad=True)
b = Variable(torch.Tensor([0]), requires_grad=True)
epochs = 100
costs = []
lr = 0.1
print("before training, predict of x = 1.5 is:")
print("y_pred = ", float(w.data*1.5 + b.data > 0)) #模型训练
for epoch in range(epochs):
#计算梯度
A = 1/(1+torch.exp(-(w*x_data+b))) #逻辑回归函数
J = -torch.mean(y_data*torch.log(A) + (1-y_data)*torch.log(1-A)) #逻辑回归损失函数
#J = -torch.mean(y_data*torch.log(A) + (1-y_data)*torch.log(1-A)) +alpha*w**2
#基础类进行正则化,加上L2范数
costs.append(J.data)
J.backward() #自动反向传播 #参数更新
w.data = w.data - lr*w.grad.data
w.grad.data.zero_()
b.data = b.data - lr*b.grad.data
b.grad.data.zero_() print("after training, predict of x = 1.5 is:")
print("y_pred =", float(w.data*1.5+b.data > 0))
print(w.data, b.data)
2.用PyTorch类实现Logistic regression,torch.nn.module写网络结构
import torch
from torch.autograd import Variable x_data = Variable(torch.Tensor([[0.6], [1.0], [3.5], [4.0]]))
y_data = Variable(torch.Tensor([[0.], [0.], [1.], [1.]])) class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = torch.nn.Linear(1, 1)
self.sigmoid = torch.nn.Sigmoid() ###### **sigmoid** def forward(self, x):
y_pred = self.sigmoid(self.linear(x))
return y_pred model = Model() criterion = torch.nn.BCELoss(size_average=True) #损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降 for epoch in range(500):
# Forward pass
y_pred = model(x_data) loss = criterion(y_pred, y_data)
if epoch % 20 == 0:
print(epoch, loss.item()) #梯度归零
optimizer.zero_grad()
# 反向传播
loss.backward()
# update weights
optimizer.step() hour_var = Variable(torch.Tensor([[0.5]]))
print("predict (after training)", 0.5, model.forward(hour_var).data[0][0])
hour_var = Variable(torch.Tensor([[7.0]]))
print("predict (after training)", 7.0, model.forward(hour_var).data[0][0])
参考:https://blog.csdn.net/ZZQsAI/article/details/90216593
Task3.PyTorch实现Logistic regression的更多相关文章
- 逻辑回归 Logistic Regression
逻辑回归(Logistic Regression)是广义线性回归的一种.逻辑回归是用来做分类任务的常用算法.分类任务的目标是找一个函数,把观测值匹配到相关的类和标签上.比如一个人有没有病,又因为噪声的 ...
- logistic regression与SVM
Logistic模型和SVM都是用于二分类,现在大概说一下两者的区别 ① 寻找最优超平面的方法不同 形象点说,Logistic模型找的那个超平面,是尽量让所有点都远离它,而SVM寻找的那个超平面,是只 ...
- Logistic Regression - Formula Deduction
Sigmoid Function \[ \sigma(z)=\frac{1}{1+e^{(-z)}} \] feature: axial symmetry: \[ \sigma(z)+ \sigma( ...
- SparkMLlib之 logistic regression源码分析
最近在研究机器学习,使用的工具是spark,本文是针对spar最新的源码Spark1.6.0的MLlib中的logistic regression, linear regression进行源码分析,其 ...
- [OpenCV] Samples 06: [ML] logistic regression
logistic regression,这个算法只能解决简单的线性二分类,在众多的机器学习分类算法中并不出众,但它能被改进为多分类,并换了另外一个名字softmax, 这可是深度学习中响当当的分类算法 ...
- Stanford机器学习笔记-2.Logistic Regression
Content: 2 Logistic Regression. 2.1 Classification. 2.2 Hypothesis representation. 2.2.1 Interpretin ...
- Logistic Regression vs Decision Trees vs SVM: Part II
This is the 2nd part of the series. Read the first part here: Logistic Regression Vs Decision Trees ...
- Logistic Regression Vs Decision Trees Vs SVM: Part I
Classification is one of the major problems that we solve while working on standard business problem ...
- Logistic Regression逻辑回归
参考自: http://blog.sina.com.cn/s/blog_74cf26810100ypzf.html http://blog.sina.com.cn/s/blog_64ecfc2f010 ...
随机推荐
- 【转】python---方法解析顺序MRO(Method Resolution Order)<以及解决类中super方法>
[转]python---方法解析顺序MRO(Method Resolution Order)<以及解决类中super方法> MRO了解: 对于支持继承的编程语言来说,其方法(属性)可能定义 ...
- JS在页面输出九九乘法表
<!--小练习,练习使用循环实现一个九九乘法表 第一步,最低要求:在Console中按行输出 n * m = t 然后,尝试在网页中,使用table来实现一个九九乘法表 --> <! ...
- python -加密(MD5)
import hashlib def md5_passwd(str,salt ='aaaaa') str = str + salt m = hashlib.md5()#构造一个MD5对象 m.upda ...
- 阶段1 语言基础+高级_1-3-Java语言高级_07-网络编程_第3节 综合案例_文件上传_5_综合案例_文件上传案例优化
自定义文件命名 文件名称被写死了 服务器上传了一张图片,服务器就停止了 把服务器端的代码放在while循环里面 服务器也不用 关闭了. 上传完成后服务器端没有关闭 再来启动客户端,又上传一张投片. 多 ...
- 阶段1 语言基础+高级_1-3-Java语言高级_1-常用API_1_第1节 Scanner类_3-Scanner的使用步骤
Scanner如何进行键盘输入,引用类型就包含了Scanner,它就是引用类型,所以也有这三个步骤, 导包.创建.使用 先通过api文档找到它.左边输入要查找scanner.双夹scanner右边就会 ...
- 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_02 递归_2_练习_使用递归计算1-n之间的和
输出6 1到100之间的和 求和的原理
- centos yum 安装php5.6
centos yum 安装php5.6 卸载 php之前的版本: yum remove -y php-common 配置源 CentOS 6.5的源 rpm -Uvh http://ftp.iij.a ...
- hacker101----XSS Review
所有你见过XSS行动在这一点上,但我们来回顾一下今天我们要讨论的XSS类型: 反射型XSS -- 来自用户的输入将直接返回到浏览器,从而允许注入任意内容 [浏览器输入,马上到服务器上,再反射回来直 ...
- 建立 Active Directory域 ----学习笔记
第五章 建立 Active Directory域 1.工作组和域的理解 a.工作组是一种平等身份环境,各个计算机之间各个为一个独立体,不方便管理和资源共享. b.域环境一般情况下满足两类需求, ...
- ntp局域网时间同步操作
需求:局域网里面有两台电脑需要同步时间 一台windows,一台Linux.把windows当作服务器 windows10自带ntp服务器,可以按如下步骤进行设置 1. 打开注册表编辑器,在运行里面输 ...