1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using System.Threading.Tasks;
  6.  
  7. namespace ConsoleApp4
  8. {
  9. class Program
  10. {
  11. static void Main(string[] args)
  12. {
  13. List<float[]> inputs_x = new List<float[]>();
  14. inputs_x.Add( new float[] { 0.9f, 0.6f});
  15. inputs_x.Add(new float[] { 2f, 2.5f } );
  16. inputs_x.Add(new float[] { 2.6f, 2.3f });
  17. inputs_x.Add(new float[] { 2.7f, 1.9f });
  18.  
  19. List<float> inputs_y = new List<float>();
  20. inputs_y.Add( 2.5f);
  21. inputs_y.Add( 2.5f);
  22. inputs_y.Add( 3.5f);
  23. inputs_y.Add( 4.2f);
  24.  
  25. float[] weights = new float[3];
  26. for (var i= 0;i < weights.Length;i++)
  27. weights[i] = (float)new Random().NextDouble();
  28.  
  29. int epoch = 30000;
  30. float epsilon =0.00001f;
  31. float lr = 0.01f;
  32.  
  33. float lastCost=0;
  34.  
  35. for (var epoch_i = 0; epoch_i <= epoch; epoch_i++)
  36. {
  37. //随机获取input
  38. var batch = GetRandomBatch(inputs_x, inputs_y, 2);
  39.  
  40. float[] weights_in_poch = new float[weights.Length];
  41.  
  42. foreach (var x_y in batch)
  43. {
  44. var x1 = x_y.Item1.First();
  45. var x2 = x_y.Item1.Skip(1).Take(1).First();
  46. var target_y = x_y.Item2;
  47.  
  48. float diffWithTargetY = target_y - fun(x1, x2, weights[1], weights[2], weights[0]);
  49.  
  50. weights_in_poch[0] += diffWithTargetY * dy_b(x1, x2);
  51. weights_in_poch[1] += diffWithTargetY * dy_theta1(x1, x2);
  52. weights_in_poch[2] += diffWithTargetY * dy_theta2(x1, x2);
  53. }
  54.  
  55. for(var i=0;i<weights.Length;i++)
  56. weights[i] += lr * weights_in_poch[i];
  57.  
  58. float totalErrorCost = 0f;
  59. foreach (var x_y in batch)
  60. {
  61. var x1 = x_y.Item1.First();
  62. var x2 = x_y.Item1.Skip(1).Take(1).First();
  63. var target_y = x_y.Item2;
  64.  
  65. float diffWithTargetY = target_y - fun(x1, x2, weights[1], weights[2], weights[0]);
  66. totalErrorCost += (float)System.Math.Pow(diffWithTargetY, 2)/2;
  67. }
  68.  
  69. float cost = totalErrorCost / batch.Count;
  70.  
  71. if (System.Math.Abs(cost - lastCost) <= epsilon)
  72. {
  73. Console.WriteLine(string.Format("EPOCH {0}", epoch_i));
  74. Console.WriteLine(string.Format("LAST MSE {0}", lastCost));
  75. Console.WriteLine(string.Format("MSE {0}", cost));
  76. break;
  77. }
  78.  
  79. lastCost = cost;
  80.  
  81. if (epoch_i % 100 == 0|| epoch_i==epoch)
  82. {
  83. Console.WriteLine(string.Format("MSE {0}", cost));
  84. }
  85. }
  86.  
  87. print(weights[1], weights[2], weights[0]);
  88.  
  89. Console.ReadLine();
  90. }
  91.  
  92. private static List<Tuple<float[], float>> GetRandomBatch(List<float[]> inputs_x, List<float> inputs_y, int maxCount)
  93. {
  94. List<Tuple<float[], float>> lst = new List<Tuple<float[], float>>();
  95.  
  96. System.Random rnd = new Random((int)DateTime.Now.Ticks);
  97.  
  98. int count = 0;
  99. while (count<maxCount)
  100. {
  101. int rndIndex = rnd.Next(inputs_x.Count);
  102. var item=Tuple.Create<float[], float>(inputs_x[rndIndex], inputs_y[rndIndex]);
  103. lst.Add(item);
  104. count++;
  105. }
  106.  
  107. return lst;
  108. }
  109.  
  110. private static void print(float theta1, float theta2, float b)
  111. {
  112. Console.WriteLine(string.Format("y={0}*x1+{1}*x2+{2}", theta1, theta2, b));
  113. }
  114. private static float fun(float x1, float x2, float theta1, float theta2, float b)
  115. {
  116. return theta1 * x1 + theta2 * x2 + b;
  117. }
  118. private static float dy_theta1(float x1, float x2)
  119. {
  120. return x1;
  121. }
  122.  
  123. private static float dy_theta2(float x1, float x2)
  124. {
  125. return x2;
  126. }
  127.  
  128. private static float dy_b(float x1, float x2)
  129. {
  130. return 1;
  131. }
  132. }
  133. }

  

SGD的更多相关文章

  1. [Machine Learning] 梯度下降法的三种形式BGD、SGD以及MBGD

    在应用机器学习算法时,我们通常采用梯度下降法来对采用的算法进行训练.其实,常用的梯度下降法还具体包含有三种不同的形式,它们也各自有着不同的优缺点. 下面我们以线性回归算法来对三种梯度下降法进行比较. ...

  2. 为什么是梯度下降?SGD

    在机器学习算法中,为了优化损失函数loss function ,我们往往采用梯度下降算法来进行优化.举个例子: 线性SVM的得分函数和损失函数分别为:                         ...

  3. 【原创】batch-GD, SGD, Mini-batch-GD, Stochastic GD, Online-GD -- 大数据背景下的梯度训练算法

    机器学习中梯度下降(Gradient Descent, GD)算法只需要计算损失函数的一阶导数,计算代价小,非常适合训练数据非常大的应用. 梯度下降法的物理意义很好理解,就是沿着当前点的梯度方向进行线 ...

  4. 逻辑回归:使用SGD(Stochastic Gradient Descent)进行大规模机器学习

    Mahout学习算法训练模型 mahout提供了许多分类算法,但许多被设计来处理非常大的数据集,因此可能会有点麻烦.另一方面,有些很容易上手,因为,虽然依然可扩展性,它们具有低开销小的数据集.这样一个 ...

  5. [Machine Learning] 梯度下降(BGD)、随机梯度下降(SGD)、Mini-batch Gradient Descent、带Mini-batch的SGD

    一.回归函数及目标函数 以均方误差作为目标函数(损失函数),目的是使其值最小化,用于优化上式. 二.优化方式(Gradient Descent) 1.最速梯度下降法 也叫批量梯度下降法Batch Gr ...

  6. 监督学习:随机梯度下降算法(sgd)和批梯度下降算法(bgd)

    线性回归 首先要明白什么是回归.回归的目的是通过几个已知数据来预测另一个数值型数据的目标值. 假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变量就是已知的数据x,函数值h(x)就 ...

  7. tensorflow实现最基本的神经网络 + 对比GD、SGD、batch-GD的训练方法

    参考博客:https://zhuanlan.zhihu.com/p/27853521 该代码默认是梯度下降法,可自行从注释中选择其他训练方法 在异或问题上,由于训练的样本数较少,神经网络简单,训练结果 ...

  8. 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)

    在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...

  9. 【深度学习】深入理解优化器Optimizer算法(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)

    在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...

  10. 【DeepLearning】优化算法:SGD、GD、mini-batch GD、Moment、RMSprob、Adam

    优化算法 1 GD/SGD/mini-batch GD GD:Gradient Descent,就是传统意义上的梯度下降,也叫batch GD. SGD:随机梯度下降.一次只随机选择一个样本进行训练和 ...

随机推荐

  1. 记录一下通过分析Tomcat内部jar包找出request.getReader()所用的字符编码在哪里设置和起效的完整分析流程

    前言: 之前写Java服务端处理POST请求时遇到了请求体转换成字符流所用编码来源的疑惑,在doPost方法里通过request.getReader()获取的BufferedReader对象内部的 R ...

  2. Java 浏览器兼容模式

    现在设计的东西,很多浏览器不兼容.下面贴出代码.测试在360和IE浏览器下,可以兼容的 <!doctype html><html><head>    <met ...

  3. ionic开发遇到的坑及总结

    前言 ionic是一个用来开发混合手机应用的,开源的,免费的代码库.可以优化html.css和js的性能,构建高效的应用程序,而且还可以用于构建Sass和AngularJS的优化.ionic会是一个可 ...

  4. FreeMarker 快速入门

    FreeMarker 快速入门 FreeMarker是一个很值得去学习的模版引擎.它是基于模板文件生成其他文本的通用工具.本章内容通过如何使用FreeMarker生成Html web 页面 和 代码自 ...

  5. Net Core下多种ORM框架特性及性能对比

    在.NET Framework下有许多ORM框架,最著名的无外乎是Entity Framework,它拥有悠久的历史以及便捷的语法,在占有率上一路领先.但随着Dapper的出现,它的地位受到了威胁,本 ...

  6. Java面向对象编程基础

    一.Java面向对象编程基础 1.什么是对象?Object 什么都是对象! 只要是客观存在的具体事物,都是对象(汽车.小强.事件.任务.按钮.字体) 2.为什么需要面向对象? 面向对象能够像分析现实生 ...

  7. 基于文本图形(ncurses)的文本搜索工具 ncgrep

    背景 作为一个VIM党,日常工作开发中,会经常利用grep进行关键词搜索,以快速定位到文件.如图: 利用grep进行文本搜索 但是,这一过程会有两个效率问题: 展示的结果无法进行直接交互,需要手动粘贴 ...

  8. hashlib,configparser,logging,模块

    一,hashlib模块 算法介绍 Python的hashlib提供了常见的摘要算法,如MD5,SHA1等等. 什么是摘要算法呢?摘要算法又称哈希算法.散列算法.它通过一个函数,把任意长度的数据转换为一 ...

  9. Python 串口通信操作

    下载  pyserial包 https://pypi.python.org/packages/source/p/pyserial/pyserial-2.7.tar.gz#md5=794506184df ...

  10. 51Nod--1006 lcs

    1006 最长公共子序列Lcs 基准时间限制:1 秒 空间限制:131072 KB 分值: 0 难度:基础题  收藏  关注 给出两个字符串A B,求A与B的最长公共子序列(子序列不要求是连续的). ...