C++ LinearRegression代码实现
这里基本完全参考网络资源完成,有疑问欢迎留言!
LinearRegression.h
#pragma once
#ifndef ML_LINEAEEEGRESSION_H
#define ML_LINEARREGRESSION_H
class LinearRegression {
public:
/*特征*/
double *x;
/*预测值*/
double *y;
/*样本数量*/
int m;
/*系数*/
double *theta;
/*创建实例*/
LinearRegression(double x[], double y[], int m);
/*训练 */
void train(double alpha, int iterations);
/*预测*/
double predict(double x);
private:
/*计算损失模型*/
static double compute_cost(double x[], double y[], double theta[], int m);
/*计算单个预测值*/
static double h(double x, double theta[]);
/*预测*/
static double *calculate_predictions(double x[], double theta[], int m);
/*梯度下降*/
static double *gradient_descent(double x[], double y[], double alpha, int iter, double *j, int m); };
#endif // !ML_LINEAEEEGRESSION_H
LinearRegression.cpp
#include "iostream"
#include "linearRegression.h"
#include "Utils.h"
using namespace std; /*初始化*/
LinearRegression::LinearRegression(double x[], double y[], int m)
{
this->x = x;
this->y = y;
this->m = m;
} /*
alpha:learn rate
iterations:iterators
*/
void LinearRegression::train(double alpha, int iterations)
{
double *J = new double[iterations];
this->theta = gradient_descent(x, y, alpha, iterations, J, m);
cout << "J=";
for (int i = ; i < iterations; ++i)
{
cout << J[i] << " " << endl;;
}
cout << "\n" << "Theta: " << theta[] << " " << theta[] << endl;
}
/*预测*/
double LinearRegression::predict(double x)
{
cout << "y':" << h(x, theta) << endl;
return h(x, theta);
} /*计算损失模型*/
double LinearRegression::compute_cost(double x[], double y[], double theta[], int m)
{
double *predictions = calculate_predictions(x, theta, m);
double *diff = Utils::array_diff(predictions, y, m);
double *sq_errors = Utils::array_pow(diff, m, );
return (1.0 / ( * m))*Utils::array_sum(sq_errors, m);
}
/*计算单个预测值*/
double LinearRegression::h(double x, double theta[])
{
return theta[] + theta[] * x;
}
/*预测*/
double *LinearRegression::calculate_predictions(double x[], double theta[], int m)
{
double *predictions = new double[m];
for (int i = ; i < m; i++)
{
predictions[i] = h(x[i], theta);
}
return predictions;
}
/*梯度下降*/
double *LinearRegression::gradient_descent(double x[], double y[], double alpha, int iter, double *J, int m)
{
double *theta = new double[];
theta[] = ;
theta[] = ;
for (int i = ; i < iter; i++)
{
double *predictions = calculate_predictions(x, theta, m);
double *diff = Utils::array_diff(predictions, y, m);
double *error_x1 = diff;
double *error_x2 = Utils::array_multiplication(diff, x, m);
/*这里可以设定J损失函数的阈值,也可以设定梯度变化量的阈值*/
theta[] = theta[] - alpha*(1.0 / m) * Utils::array_sum(error_x1, m);
theta[] = theta[] - alpha*(1.0 / m)*Utils::array_sum(error_x2, m);
J[i] = compute_cost(x, y, theta, m);
}
return theta;
}
Test.cpp
#include "iostream"
#include "linearRegression.h" using namespace std; int main()
{
double x[] = {,,,,};
double y[] = {,,,,}; LinearRegression test(x,y,);
test.train(0.1, );
test.predict();
system("pause");
return ;
}
C++ LinearRegression代码实现的更多相关文章
- 代码-Weka的LinearRegression类
package kit.weka; import weka.classifiers.Evaluation; import weka.classifiers.functions.LinearRegres ...
- TensorFlow——LinearRegression简单模型代码
代码函数详解 tf.random.truncated_normal()函数 tf.truncated_normal函数随机生成正态分布的数据,生成的数据是截断的正态分布,截断的标准是2倍的stddev ...
- 建模分析之机器学习算法(附python&R代码)
0序 随着移动互联和大数据的拓展越发觉得算法以及模型在设计和开发中的重要性.不管是现在接触比较多的安全产品还是大互联网公司经常提到的人工智能产品(甚至人类2045的的智能拐点时代).都基于算法及建模来 ...
- TensorFlow实现线性回归模型代码
模型构建 1.示例代码linear_regression_model.py #!/usr/bin/python # -*- coding: utf-8 -* import tensorflow as ...
- Spark MLlib线性回归代码实现及结果展示
线性回归(Linear Regression)是利用称为线性回归方程的最小平方函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析. 这种函数是一个或多个称为回归系数的模型参数的线性组合.只有 ...
- 10 种机器学习算法的要点(附 Python 和 R 代码)
本文由 伯乐在线 - Agatha 翻译,唐尤华 校稿.未经许可,禁止转载!英文出处:SUNIL RAY.欢迎加入翻译组. 前言 谷歌董事长施密特曾说过:虽然谷歌的无人驾驶汽车和机器人受到了许多媒体关 ...
- Python机器学习/LinearRegression(线性回归模型)(附源码)
LinearRegression(线性回归) 2019-02-20 20:25:47 1.线性回归简介 线性回归定义: 百科中解释 我个人的理解就是:线性回归算法就是一个使用线性函数作为模型框架($ ...
- spark LinearRegression 预测缺失字段的值
最近在做金融科技建模的时候,字段里面很多缺少值得时候,模型对于新用户的预测会出现很大的不稳定,即PSI较大的情况. 虽然我们依据字段IV值得大小不断的在调整字段且开发新变量,但是很多IV值很大的字段直 ...
- <转>机器学习系列(9)_机器学习算法一览(附Python和R代码)
转自http://blog.csdn.net/han_xiaoyang/article/details/51191386 – 谷歌的无人车和机器人得到了很多关注,但我们真正的未来却在于能够使电脑变得更 ...
随机推荐
- 前端面试之路之HTML面试真题
1.doctype的意义是什么 让浏览器以标准模式渲染 让浏览器知道元素的合法性 2.HTML XHTML HTML5的关系 HTML属于SGML XHTML属于XML,是HTML进行XML严格化的结 ...
- PHP Trait特性
php类的单继承性,无法同时从两个基类中继承属性和方法,为了解决这个问题,使用Trait特性解决. Trait是一种代码复用技术,为PHP的单继承限制提供了一套灵活的代码复用机制. 用法:通过在类中使 ...
- P哥的桶(线段树+线性基)
https://www.luogu.org/problem/P4839 题目: 有两个操作 1 a b 在a的位置添加b数值 (注意一个位置可以有多个值) 2 a b : 在 a到b的范围任取任意 ...
- 如何修改运行中的docker容器的端口映射
在docker run创建并运行容器的时候,可以通过-p指定端口映射规则.但是,我们经常会遇到刚开始忘记设置端口映射或者设置错了需要修改.当docker start运行容器后并没有提供一个-p选项或设 ...
- HDU 1003 解题报告
问题描述:求最大连续字串 分析:一道简单的DP,状态转移方程是d[i] = ( d[i-1]+a[i] > a[i] ) ? d[i-1]+a[i] : a[i] d[i]表示以第i个数字结尾的 ...
- HDU 5634 Rikka with Phi
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5634 ------------------------------------------------ ...
- p5471 [NOI2019]弹跳
分析 代码 #include<bits/stdc++.h> using namespace std; #define fi first #define se second #define ...
- fiddler抓取火狐浏览器上https协议请求
前言:现在很多网站采用https协议,当打开fiddler时.浏览https协议的网站会提示不安全,若使用fiddler抓取https协议的请求,则需要向浏览器导入证书,才能抓取https协议的请求, ...
- 微软手写识别模块sdk及delphi接口例子
http://download.csdn.net/download/coolstar1204/2008061 微软手写识别模块sdk及delphi接口例子
- C++ 编写的DLL导出的函数名乱码含义解析
C++编译时函数名修饰约定规则: __stdcall调用约定: 1.以"?"标识函数名的开始,后跟函数名: 2.函数名后面以"@@YG"标识参数表的 ...