这里基本完全参考网络资源完成,有疑问欢迎留言!

LinearRegression.h

#pragma once
#ifndef ML_LINEAEEEGRESSION_H
#define ML_LINEARREGRESSION_H
class LinearRegression {
public:
/*特征*/
double *x;
/*预测值*/
double *y;
/*样本数量*/
int m;
/*系数*/
double *theta;
/*创建实例*/
LinearRegression(double x[], double y[], int m);
/*训练 */
void train(double alpha, int iterations);
/*预测*/
double predict(double x);
private:
/*计算损失模型*/
static double compute_cost(double x[], double y[], double theta[], int m);
/*计算单个预测值*/
static double h(double x, double theta[]);
/*预测*/
static double *calculate_predictions(double x[], double theta[], int m);
/*梯度下降*/
static double *gradient_descent(double x[], double y[], double alpha, int iter, double *j, int m); };
#endif // !ML_LINEAEEEGRESSION_H

LinearRegression.cpp

#include "iostream"
#include "linearRegression.h"
#include "Utils.h"
using namespace std; /*初始化*/
LinearRegression::LinearRegression(double x[], double y[], int m)
{
this->x = x;
this->y = y;
this->m = m;
} /*
alpha:learn rate
iterations:iterators
*/
void LinearRegression::train(double alpha, int iterations)
{
double *J = new double[iterations];
this->theta = gradient_descent(x, y, alpha, iterations, J, m);
cout << "J=";
for (int i = ; i < iterations; ++i)
{
cout << J[i] << " " << endl;;
}
cout << "\n" << "Theta: " << theta[] << " " << theta[] << endl;
}
/*预测*/
double LinearRegression::predict(double x)
{
cout << "y':" << h(x, theta) << endl;
return h(x, theta);
} /*计算损失模型*/
double LinearRegression::compute_cost(double x[], double y[], double theta[], int m)
{
double *predictions = calculate_predictions(x, theta, m);
double *diff = Utils::array_diff(predictions, y, m);
double *sq_errors = Utils::array_pow(diff, m, );
return (1.0 / ( * m))*Utils::array_sum(sq_errors, m);
}
/*计算单个预测值*/
double LinearRegression::h(double x, double theta[])
{
return theta[] + theta[] * x;
}
/*预测*/
double *LinearRegression::calculate_predictions(double x[], double theta[], int m)
{
double *predictions = new double[m];
for (int i = ; i < m; i++)
{
predictions[i] = h(x[i], theta);
}
return predictions;
}
/*梯度下降*/
double *LinearRegression::gradient_descent(double x[], double y[], double alpha, int iter, double *J, int m)
{
double *theta = new double[];
theta[] = ;
theta[] = ;
for (int i = ; i < iter; i++)
{
double *predictions = calculate_predictions(x, theta, m);
double *diff = Utils::array_diff(predictions, y, m);
double *error_x1 = diff;
double *error_x2 = Utils::array_multiplication(diff, x, m);
/*这里可以设定J损失函数的阈值,也可以设定梯度变化量的阈值*/
theta[] = theta[] - alpha*(1.0 / m) * Utils::array_sum(error_x1, m);
theta[] = theta[] - alpha*(1.0 / m)*Utils::array_sum(error_x2, m);
J[i] = compute_cost(x, y, theta, m);
}
return theta;
}

Test.cpp

#include "iostream"
#include "linearRegression.h" using namespace std; int main()
{
double x[] = {,,,,};
double y[] = {,,,,}; LinearRegression test(x,y,);
test.train(0.1, );
test.predict();
system("pause");
return ;
}

C++ LinearRegression代码实现的更多相关文章

  1. 代码-Weka的LinearRegression类

    package kit.weka; import weka.classifiers.Evaluation; import weka.classifiers.functions.LinearRegres ...

  2. TensorFlow——LinearRegression简单模型代码

    代码函数详解 tf.random.truncated_normal()函数 tf.truncated_normal函数随机生成正态分布的数据,生成的数据是截断的正态分布,截断的标准是2倍的stddev ...

  3. 建模分析之机器学习算法(附python&R代码)

    0序 随着移动互联和大数据的拓展越发觉得算法以及模型在设计和开发中的重要性.不管是现在接触比较多的安全产品还是大互联网公司经常提到的人工智能产品(甚至人类2045的的智能拐点时代).都基于算法及建模来 ...

  4. TensorFlow实现线性回归模型代码

    模型构建 1.示例代码linear_regression_model.py #!/usr/bin/python # -*- coding: utf-8 -* import tensorflow as ...

  5. Spark MLlib线性回归代码实现及结果展示

    线性回归(Linear Regression)是利用称为线性回归方程的最小平方函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析. 这种函数是一个或多个称为回归系数的模型参数的线性组合.只有 ...

  6. 10 种机器学习算法的要点(附 Python 和 R 代码)

    本文由 伯乐在线 - Agatha 翻译,唐尤华 校稿.未经许可,禁止转载!英文出处:SUNIL RAY.欢迎加入翻译组. 前言 谷歌董事长施密特曾说过:虽然谷歌的无人驾驶汽车和机器人受到了许多媒体关 ...

  7. Python机器学习/LinearRegression(线性回归模型)(附源码)

    LinearRegression(线性回归) 2019-02-20  20:25:47 1.线性回归简介 线性回归定义: 百科中解释 我个人的理解就是:线性回归算法就是一个使用线性函数作为模型框架($ ...

  8. spark LinearRegression 预测缺失字段的值

    最近在做金融科技建模的时候,字段里面很多缺少值得时候,模型对于新用户的预测会出现很大的不稳定,即PSI较大的情况. 虽然我们依据字段IV值得大小不断的在调整字段且开发新变量,但是很多IV值很大的字段直 ...

  9. <转>机器学习系列(9)_机器学习算法一览(附Python和R代码)

    转自http://blog.csdn.net/han_xiaoyang/article/details/51191386 – 谷歌的无人车和机器人得到了很多关注,但我们真正的未来却在于能够使电脑变得更 ...

随机推荐

  1. 2017工业软件top100

  2. centos R包 tidyverse安装

    tidyverse安装失败,install.packages('tidyverse') 错误原因大概是其中有个依赖包xml2安装不上,解决办法是yum install libxml2-devel,这样 ...

  3. 123、TensorFlow的Job

    # 如果你在分布式环境中部署TensorFlow # 你或许需要指定job name和task ID # 来将变量放置在参数服务器上 # 将操作放在worker job import tensorfl ...

  4. Git学习及使用

    一.认知git理论 1.git出现的背景 版本控制 版本控制是一种记录一个或若干文件内容变化,以便将来查阅特定版本修订情况的系统. 许多人习惯用复制整个项目目录的方式来保存不同的版本,或许还会改名加上 ...

  5. 笨方法学Python 错误记录

    ex8:忘记输入“空格”ex9:忘记输入“冒号”ex14:%前后要空格,否则errorex21:多个函数嵌套,漏写括号)ex24:%d,漏写d,导致程序错误:"""之间的 ...

  6. c#访问webapi以及获取

    提交post #region XML方式提交        public static void XML() {            HttpWebRequest wReq = (HttpWebRe ...

  7. spring-第六篇之创建bean的3种方式

    1.创建bean的方式有3种: 1>使用构造器创建bean,即设值注入.构造注入本质都是使用bean的构造器创建bean的. 2>使用静态工厂方法创建bean. 3>调用实例工厂方法 ...

  8. #python# 代理过程中遇到的error

    做一下总结 urllib.error.HTTPError: HTTP Error 503: Too many open connections TimeoutError: [WinError 1006 ...

  9. jar包 war包

    jar包和war包的区别: war是一个web模块,其中需要包括WEB-INF,是可以直接运行的WEB模块.而jar一般只是包括一些class文件,在声明了Main_class之后是可以用java命令 ...

  10. Codeforces - 1191B - Tokitsukaze and Mahjong - 模拟

    https://codeforces.com/contest/1191/problem/B 小心坎张听的情况. #include<bits/stdc++.h> using namespac ...