import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.utils import shuffle
import matplotlib.pyplot as plt def initialize_params(dims):
w = np.zeros((dims, 1))
b = 0
return w, b def linear_loss(X, y, w, b):
num_train = X.shape[0]
# 模型公式
y_hat = np.dot(X, w) + b
# 损失函数
loss = np.sum((y_hat - y) ** 2) / num_train
# 参数偏导
dw = np.dot(X.T, (y_hat - y)) / num_train
db = np.sum(y_hat - y) / num_train
return y_hat, loss, dw, db def linear_train(X, y, learning_rate, epochs):
# 参数初始化
w, b = initialize_params(X.shape[1]) loss_list = []
for i in range(1, epochs):
# 计算当前预测值、损失和梯度
y_hat, loss, dw, db = linear_loss(X, y, w, b)
loss_list.append(loss) # 基于梯度下降的参数更新
w += -learning_rate * dw
b += -learning_rate * db # 打印迭代次数和损失
if i % 10000 == 0:
print('epoch %d loss %f' % (i, loss)) # 保存参数
params = {
'w': w,
'b': b
} # 保存梯度
grads = {
'dw': dw,
'db': db
}
return loss_list, loss, params, grads def predict(X, params):
w = params['w']
b = params['b']
y_pred = np.dot(X, w) + b
return y_pred if __name__ == "__main__":
# 加载数据
diabets = load_diabetes()
data = diabets.data
target = diabets.target # 打乱数据
X, y = shuffle(data, target, random_state=13) # 划分训练集和测试集
offset = int(X.shape[0] * 0.9)
X_train, y_train = X[:offset], y[:offset]
X_test, y_test = X[offset:], y[offset:]
y_train = y_train.reshape((-1, 1))
y_test = y_test.reshape((-1, 1)) print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape) # 训练
loss_list, loss, params, grads = linear_train(X_train, y_train, 0.01, 100000)
print(params) # 预测
y_pred = predict(X_test, params)
print(y_pred[:5]) # 画图
f = X_test.dot(params['w']) + params['b']
plt.scatter(range(X_test.shape[0]), y_test)
plt.plot(f, color='darkorange')
plt.xlabel('x')
plt.xlabel('y')
plt.show() plt.plot(loss_list, color='blue')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()

Python实现机器学习算法:线性回归的更多相关文章

  1. Python实现机器学习算法:AdaBoost算法

    Python程序 ''' 数据集:Mnist 训练集数量:60000(实际使用:10000) 测试集数量:10000(实际使用:1000) 层数:40 ------------------------ ...

  2. Python实现机器学习算法:决策树算法

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ------------------------------ 运行结果:ID3(未剪枝) 正确率:85.9% 运行时长:35 ...

  3. Python实现机器学习算法:感知机

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ------------------------------ 运行结果: 正确率:81.72%(二分类) ''' impor ...

  4. Python实现机器学习算法:EM算法

    ''' 数据集:伪造数据集(两个高斯分布混合) 数据集长度:1000 ------------------------------ 运行结果: ---------------------------- ...

  5. Python实现机器学习算法:朴素贝叶斯算法

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ''' import numpy as np import time def loadData(fileName): ''' ...

  6. Python实现机器学习算法:K近邻算法

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000(实际使用:200) ''' import numpy as np import time def loadData(file ...

  7. Python实现机器学习算法:逻辑回归

    import numpy as np import matplotlib.pyplot as plt from sklearn.datasets.samples_generator import ma ...

  8. 建模分析之机器学习算法(附python&R代码)

    0序 随着移动互联和大数据的拓展越发觉得算法以及模型在设计和开发中的重要性.不管是现在接触比较多的安全产品还是大互联网公司经常提到的人工智能产品(甚至人类2045的的智能拐点时代).都基于算法及建模来 ...

  9. 10 种机器学习算法的要点(附 Python 和 R 代码)

    本文由 伯乐在线 - Agatha 翻译,唐尤华 校稿.未经许可,禁止转载!英文出处:SUNIL RAY.欢迎加入翻译组. 前言 谷歌董事长施密特曾说过:虽然谷歌的无人驾驶汽车和机器人受到了许多媒体关 ...

随机推荐

  1. Block 实践

    OC版 函数中无参无返回值 /* 作为函数参数类型的格式 返回值类型 (^)(形参列表) */ CZPerson.h - (void) test:(void (^)(void))block; CZPe ...

  2. Redis入门——安装与基本命令

    1. Redis安装 下载地址:https://github.com/MSOpenTech/redis/releases 下载zip文件后直接解压 2. 启动Redis服务端 解压目录下执行redis ...

  3. Linux基础命令---文本显示tac

    tac 将指定文件中的行,按照反序方式显示.此命令的适用范围:RedHat.RHEL.Ubuntu.CentOS.SUSE.openSUSE.Fedora. 1.语法         tac [选项] ...

  4. 转:【专题六】UDP编程

    引用: 前一个专题简单介绍了TCP编程的一些知识,UDP与TCP地位相当的另一个传输层协议,它也是当下流行的很多主流网络应用(例如QQ.MSN和Skype等一些即时通信软件传输层都是应用UDP协议的) ...

  5. navicat远程连接阿里云ECS上的MYSQL报Lost connection to MySQL server at 'reading initial communication packet'

    问题现象 MySQL 远程连接报错:Lost connection to MySQL server at 'reading initial communication packet' 解决方案 1.检 ...

  6. web前端利用leaflet生成粒子风场,类似windy

    wind.js如下: $(function() { var dixing = L.tileLayer.chinaProvider('Google.Satellite.Map', { maxZoom: ...

  7. 计蒜客---N的-2进制表示

    对于十进制整数N,试求其-2进制表示. 例如,因为  1*1  +  1*-2  +  1*4  +  0*-8  +1*16  +  1*-32  =  -13  ,所以(-13)_10  =  ( ...

  8. The Little Prince-12/05

    The Little Prince-12/05 "When a mystery is too overpowering, one dare not disobey. Absurd as it ...

  9. SpringMVC实现 MultipartFile 文件上传

    1. Maven 工程引入所需要的依赖包 2. 页面需要开放多媒体标签 3. 配置文件上传试图解析器 4. 接收图片信息,通过 IO 流写入磁盘(调用解析其中的方法即可) 如下: 1.1 引入所依赖的 ...

  10. iOS项目之解析HTML数据

    最近因为需求,一直在做HTML数据的解析,从网页中去获取需要的数据,然后展示到自己的app中. 在网上找了很多资料,大多都是TFHpple这个第三方框架,能够根据标签节点获取对应的数据,但是现在我需要 ...