Python实现机器学习算法:线性回归
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
def initialize_params(dims):
w = np.zeros((dims, 1))
b = 0
return w, b
def linear_loss(X, y, w, b):
num_train = X.shape[0]
# 模型公式
y_hat = np.dot(X, w) + b
# 损失函数
loss = np.sum((y_hat - y) ** 2) / num_train
# 参数偏导
dw = np.dot(X.T, (y_hat - y)) / num_train
db = np.sum(y_hat - y) / num_train
return y_hat, loss, dw, db
def linear_train(X, y, learning_rate, epochs):
# 参数初始化
w, b = initialize_params(X.shape[1])
loss_list = []
for i in range(1, epochs):
# 计算当前预测值、损失和梯度
y_hat, loss, dw, db = linear_loss(X, y, w, b)
loss_list.append(loss)
# 基于梯度下降的参数更新
w += -learning_rate * dw
b += -learning_rate * db
# 打印迭代次数和损失
if i % 10000 == 0:
print('epoch %d loss %f' % (i, loss))
# 保存参数
params = {
'w': w,
'b': b
}
# 保存梯度
grads = {
'dw': dw,
'db': db
}
return loss_list, loss, params, grads
def predict(X, params):
w = params['w']
b = params['b']
y_pred = np.dot(X, w) + b
return y_pred
if __name__ == "__main__":
# 加载数据
diabets = load_diabetes()
data = diabets.data
target = diabets.target
# 打乱数据
X, y = shuffle(data, target, random_state=13)
# 划分训练集和测试集
offset = int(X.shape[0] * 0.9)
X_train, y_train = X[:offset], y[:offset]
X_test, y_test = X[offset:], y[offset:]
y_train = y_train.reshape((-1, 1))
y_test = y_test.reshape((-1, 1))
print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)
# 训练
loss_list, loss, params, grads = linear_train(X_train, y_train, 0.01, 100000)
print(params)
# 预测
y_pred = predict(X_test, params)
print(y_pred[:5])
# 画图
f = X_test.dot(params['w']) + params['b']
plt.scatter(range(X_test.shape[0]), y_test)
plt.plot(f, color='darkorange')
plt.xlabel('x')
plt.xlabel('y')
plt.show()
plt.plot(loss_list, color='blue')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()
Python实现机器学习算法:线性回归的更多相关文章
- Python实现机器学习算法:AdaBoost算法
Python程序 ''' 数据集:Mnist 训练集数量:60000(实际使用:10000) 测试集数量:10000(实际使用:1000) 层数:40 ------------------------ ...
- Python实现机器学习算法:决策树算法
''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ------------------------------ 运行结果:ID3(未剪枝) 正确率:85.9% 运行时长:35 ...
- Python实现机器学习算法:感知机
''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ------------------------------ 运行结果: 正确率:81.72%(二分类) ''' impor ...
- Python实现机器学习算法:EM算法
''' 数据集:伪造数据集(两个高斯分布混合) 数据集长度:1000 ------------------------------ 运行结果: ---------------------------- ...
- Python实现机器学习算法:朴素贝叶斯算法
''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ''' import numpy as np import time def loadData(fileName): ''' ...
- Python实现机器学习算法:K近邻算法
''' 数据集:Mnist 训练集数量:60000 测试集数量:10000(实际使用:200) ''' import numpy as np import time def loadData(file ...
- Python实现机器学习算法:逻辑回归
import numpy as np import matplotlib.pyplot as plt from sklearn.datasets.samples_generator import ma ...
- 建模分析之机器学习算法(附python&R代码)
0序 随着移动互联和大数据的拓展越发觉得算法以及模型在设计和开发中的重要性.不管是现在接触比较多的安全产品还是大互联网公司经常提到的人工智能产品(甚至人类2045的的智能拐点时代).都基于算法及建模来 ...
- 10 种机器学习算法的要点(附 Python 和 R 代码)
本文由 伯乐在线 - Agatha 翻译,唐尤华 校稿.未经许可,禁止转载!英文出处:SUNIL RAY.欢迎加入翻译组. 前言 谷歌董事长施密特曾说过:虽然谷歌的无人驾驶汽车和机器人受到了许多媒体关 ...
随机推荐
- JavaScript--详解typeof的用法
typeof定义 typeof是一元运算符,用来返回操作参数的类型(不是值) 检查一个变量是否存在,是否有值 typeof在两种情况下会返回"undefined&q ...
- tensorflow学习4-过拟合-over-fitting
过拟合: 真实的应用中,并不是让模型尽量模拟训练数据的行为,而是希望训练数据对未知做出判断. 模型过于复杂后,模型会积极每一个噪声的部分,而不是学习数据中的通用 趋势.当一个模型的参数比训练数据还要多 ...
- 转:wcf大文件传输解决之道(2)
此篇文章主要是基于http协议应用于大文件传输中的应用,现在我们先解析下wcf中编码器的定义,编码器实现了类的编码,并负责将Message内存中消息转变为网络发送的字节流或者字节缓冲区(对于发送方而言 ...
- android搜索框列表布局,流程及主要步骤思维导图
android搜索框列表布局,流程及主要步骤思维导图 android搜索框列表布局,流程及主要步骤思维导图 activity_coin_search.xml----------<com.scwa ...
- 数据库中的undo日志、redo日志
MySQL中有六种日志文件,分别是:重做日志(redo log).回滚日志(undo log).二进制日志(binlog).错误日志(errorlog).慢查询日志(slow query log).一 ...
- 关于treeMap
https://www.cnblogs.com/skywang12345/p/3310928.html
- 怎样从外网访问内网Redis数据库?
本地安装了一个Redis数据库,只能在局域网内访问到,怎样从外网也能访问到本地的Redis数据库呢?本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装并启动Redis数据库 默认安装的Redis ...
- 【shell脚本】通过遍历文件的一种批量执行shell命令的方法。
在分析数据时,经常会有许多机械重复的命令带入,作为一个半路出家的程序猿,我曾经对这种工作束手无策.不像一个熟手那样举重若轻的分析,感觉自己的生信分析完全是个体力活.为了打开这样的局面,我开始学习如何批 ...
- mysql 通过查看mysql 配置参数、状态来优化你的mysql
我把MYISAM改成了INNODB,数据库对CPU方面的占用变小很多' mysql的监控方法大致分为两类: 1.连接到mysql数据库内部,使用show status,show variables,f ...
- oracle 11g亿级复杂SQL优化一例(数量级性能提升)
自从16年之后,因为工作原因,项目中就没有再使用oracle了,最近最近支持一个项目,又要开始负责这块事情了.最近在跑性能测试,配置全部调好之后,不少sql还存在性能低下的问题,主要涉及执行计划的不合 ...