import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.utils import shuffle
import matplotlib.pyplot as plt def initialize_params(dims):
w = np.zeros((dims, 1))
b = 0
return w, b def linear_loss(X, y, w, b):
num_train = X.shape[0]
# 模型公式
y_hat = np.dot(X, w) + b
# 损失函数
loss = np.sum((y_hat - y) ** 2) / num_train
# 参数偏导
dw = np.dot(X.T, (y_hat - y)) / num_train
db = np.sum(y_hat - y) / num_train
return y_hat, loss, dw, db def linear_train(X, y, learning_rate, epochs):
# 参数初始化
w, b = initialize_params(X.shape[1]) loss_list = []
for i in range(1, epochs):
# 计算当前预测值、损失和梯度
y_hat, loss, dw, db = linear_loss(X, y, w, b)
loss_list.append(loss) # 基于梯度下降的参数更新
w += -learning_rate * dw
b += -learning_rate * db # 打印迭代次数和损失
if i % 10000 == 0:
print('epoch %d loss %f' % (i, loss)) # 保存参数
params = {
'w': w,
'b': b
} # 保存梯度
grads = {
'dw': dw,
'db': db
}
return loss_list, loss, params, grads def predict(X, params):
w = params['w']
b = params['b']
y_pred = np.dot(X, w) + b
return y_pred if __name__ == "__main__":
# 加载数据
diabets = load_diabetes()
data = diabets.data
target = diabets.target # 打乱数据
X, y = shuffle(data, target, random_state=13) # 划分训练集和测试集
offset = int(X.shape[0] * 0.9)
X_train, y_train = X[:offset], y[:offset]
X_test, y_test = X[offset:], y[offset:]
y_train = y_train.reshape((-1, 1))
y_test = y_test.reshape((-1, 1)) print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape) # 训练
loss_list, loss, params, grads = linear_train(X_train, y_train, 0.01, 100000)
print(params) # 预测
y_pred = predict(X_test, params)
print(y_pred[:5]) # 画图
f = X_test.dot(params['w']) + params['b']
plt.scatter(range(X_test.shape[0]), y_test)
plt.plot(f, color='darkorange')
plt.xlabel('x')
plt.xlabel('y')
plt.show() plt.plot(loss_list, color='blue')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()

Python实现机器学习算法:线性回归的更多相关文章

  1. Python实现机器学习算法:AdaBoost算法

    Python程序 ''' 数据集:Mnist 训练集数量:60000(实际使用:10000) 测试集数量:10000(实际使用:1000) 层数:40 ------------------------ ...

  2. Python实现机器学习算法:决策树算法

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ------------------------------ 运行结果:ID3(未剪枝) 正确率:85.9% 运行时长:35 ...

  3. Python实现机器学习算法:感知机

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ------------------------------ 运行结果: 正确率:81.72%(二分类) ''' impor ...

  4. Python实现机器学习算法:EM算法

    ''' 数据集:伪造数据集(两个高斯分布混合) 数据集长度:1000 ------------------------------ 运行结果: ---------------------------- ...

  5. Python实现机器学习算法:朴素贝叶斯算法

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ''' import numpy as np import time def loadData(fileName): ''' ...

  6. Python实现机器学习算法:K近邻算法

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000(实际使用:200) ''' import numpy as np import time def loadData(file ...

  7. Python实现机器学习算法:逻辑回归

    import numpy as np import matplotlib.pyplot as plt from sklearn.datasets.samples_generator import ma ...

  8. 建模分析之机器学习算法(附python&R代码)

    0序 随着移动互联和大数据的拓展越发觉得算法以及模型在设计和开发中的重要性.不管是现在接触比较多的安全产品还是大互联网公司经常提到的人工智能产品(甚至人类2045的的智能拐点时代).都基于算法及建模来 ...

  9. 10 种机器学习算法的要点(附 Python 和 R 代码)

    本文由 伯乐在线 - Agatha 翻译,唐尤华 校稿.未经许可,禁止转载!英文出处:SUNIL RAY.欢迎加入翻译组. 前言 谷歌董事长施密特曾说过:虽然谷歌的无人驾驶汽车和机器人受到了许多媒体关 ...

随机推荐

  1. FTL 数字有逗号

    Long i=100000000l; Map model=new Map(); model.put("t",i); 在freemarker中显示为100,000,000 想按原样输 ...

  2. 给定一个正整数,实现一个方法求出离该整数最近的大于自身的 换位数 <把一个整数各个数位进行全排列>

    """给定一个正整数,实现一个方法求出离该整数最近的大于自身的 换位数 -> 把一个整数各个数位进行全排列""" # 使用 permu ...

  3. CRM 权限设置

    表结构的设计 权限表 url -url地址的正则表达式 ^$ title - 标题 角色表 name - 角色名称 permissions 多对多关联权限表 (权限和角色的关系表) 用户表 name ...

  4. VI编辑器常用命令

    Linux下的文本编辑器有很多种,vi 是最常用的,也是各版本Linux的标配.注意,vi 仅仅是一个文本编辑器,可以给字符着色,可以自动补全,但是不像 Windows 下的 word 有排版功能. ...

  5. JustOj 2042: Dada的游戏

    题目描述 Dada无聊时,喜欢做一个游戏,将很多钱分成若干堆排成一列,每堆钱数不固定,谁能找出每堆钱数严格递增的最长区间,谁就是人生赢家了.Dada可能脑子里的水还没干,她找不出来,你来帮她找找吧. ...

  6. 远程获得乐趣的 Linux 命令

    今天是我们为期 24 天的 Linux 命令行玩具日历的最后一天.希望你一直有在看,但如果没有,请从头开始,继续努力.你会发现 Linux 终端有很多游戏.消遣和奇怪之处. 虽然你之前可能已经看过我们 ...

  7. CentOS7安装MySQL冲突和问题解决小结

    摘自:https://blog.csdn.net/typa01_kk/article/details/49059729 问题1: [root@localhost install-files]# rpm ...

  8. MyBatis中#{ }和${ }的区别,数据库优化遵循层次和查询方法

    MyBatis中#{ }和${ }的区别详解 1.#将传入的数据当成一个字符串,会对自动传入的数据加一个 双引号. 例如order by #id#,如果传入的值是111,那么解析成sql时变为orde ...

  9. Python调用大漠插件

    Python版本要用32位的?我去官网下载,太慢了,就在腾讯软件里面下载了一个,结果实验成功 import win32com.client dm = win32com.client.Dispatch( ...

  10. Linux指令之netstat

    查看某个端口的连接数 netstat -nat | grep -iw "8463" | wc -l [Mac&Redhat通用] 查看连接状况 netstat -nat | ...