这篇文章中,我们将使用TensorFlow.js来根据数据拟合曲线。即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数。简单的说,就是给一些在二维坐标中的散点图,然后我们建立一个系数未知的多项式,通过TensorFlow.js来训练模型,最终找到这些未知的系数,让这个多项式和散点图拟合。

  

一、运行代码

  这篇文章关注的是创建模型以及学习模型的系数,完整的代码在这里可以找到。为了在本地运行,如下所示:

$ git clone https://github.com/tensorflow/tfjs-examples.git
$ cd tfjs-examples/polynomial-regression-core
$ yarn
$ yarn watch

  即首先将核心代码下载到本地,然后进入polynomial-regression-core(即多项式回归核心)部分,最后进行yarn安装并运行。

二、输入数据

  我们的数据在x坐标轴和y坐标轴内,看上去就是将之放在了笛卡尔坐标系,如下所示:

  4

  

  即这个图是 y = ax3 + bx2 + cx + d得到的,而在上图中,我们也看到了其真实系数为a=-0.800,b=-0.200,c=0.900,d=0.500,然后这些点是根据真实的点做了一定的偏移。  

  我们的任务就是通过机器学习得到这个函数的系数a、b、c以及d来最好的匹配这些数据。接下来,我们就看看如何通过TensorFlow.js来学习得到这些数据。

三、学习步骤

第一步 :设置变量

  首先,我们需要创建一些变量。即开始我们是不知道a、b、c、d的值的,所以先给他们一个随机数,入戏所示:

const a = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
const c = tf.variable(tf.scalar(Math.random()));
const d = tf.variable(tf.scalar(Math.random()));

第二步:创建模型

  我们可以通过TensorFlow.js中的链式调用操作来实现这个多项式方程  y = ax3 + bx2 + cx + d,下面的代码就创建了一个 predict 函数,这个函数将x作为输入,y作为输出:

function predict(x) {
// y = a * x ^ 3 + b * x ^ 2 + c * x + d
return tf.tidy(() => {
return a.mul(x.pow(tf.scalar())) // a * x^3
.add(b.mul(x.square())) // + b * x ^ 2
.add(c.mul(x)) // + c * x
.add(d); // + d
});
}

  其中,在上一篇文章中,我们讲到tf.tify函数用来清除中间张量,其他的都很好理解。

  接着,让我们把这个多项式函数的系数使用之前得到的随机数,可以看到,得到的图应该是这样:

  因为开始时,我们使用的系数是随机数,所以这个函数和给定的数据匹配的非常差,而我们写的模型就是为了通过学习得到更精确的系数值。

第三步:训练模型

  最后一步就是要训练这个模型使得系数和这些散点更加匹配,而为了训练模型,我们需要定义下面的三样东西:

  • 损失函数(loss function):这个损失函数代表了给定多项式和数据的匹配程度。 损失函数值越小,那么这个多项式和数据就跟匹配。
  • 优化器(optimizer):这个优化器实现了一个算法,它会基于损失函数的输出来修正系数值。所以优化器的目的就是尽可能的减小损失函数的值。
  • 训练迭代器(traing loop):即它会不断地运行这个优化器来减少损失函数。

  所以,上面这三样东西的 关系就非常清楚了: 训练迭代器使得优化器不断运行,使得损失函数的值不断减小,以达到多项式和数据尽可能匹配的目的。这样,最终我们就可以得到a、b、c、d较为精确的值了。

  

四、定义损失函数

  这篇文章中,我们使用MSE(均方误差,mean squared error)作为我们的损失函数。MSE的计算非常简单,就是先根据给定的x得到实际的y值与预测得到的y值之差 的平方,然后在对这些差的平方求平均数即可

  

  于是,我们可以这样定义MSE损失函数:

function loss(predictions, labels) {
// 将labels(实际的值)进行抽象
// 然后获取平均数.
const meanSquareError = predictions.sub(labels).square().mean();
return meanSquareError;
}

   即这个损失函数返回的就是一个均方差,如果这个损失函数的值越小,显然数据和系数就拟合的越好。

  

五、定义优化器

  对于我们的优化器而言,我们选用 SGD (Stochastic Gradient Descent)优化器,即随机梯度下降SGD的工作原理就是利用数据中任意的点的梯度以及使用它们的值来决定增加或者减少我们模型中系数的值

  TensorFlow.js提供了一个很方便的函数用来实现SGD,所以你不需要担心自己不会这些特别复杂的数学运算。 即 tf.train.sdg 将一个学习率(learning rate)作为输入,然后返回一个SGDOptimizer对象,它与优化损失函数的值是有关的。

  在提高它的预测能力时,学习率(learning rate)会控制模型调整幅度将会有多大。低的学习率会使得学习过程运行的更慢一些(更多的训练迭代获得更符合数据的系数),而高的学习率将会加速学习过程但是将会导致最终的模型可能在正确值周围摇摆。简单的说,你既想要学的快,又想要学的好,这是不可能的。

  下面的代码就创建了一个学习率为0.5的SGD优化器。

const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);

  

六、定义训练迭代器

  既然我们已经定义了损失函数和优化器,那么现在我们就可以创建一个训练迭代器了,它会不断地运行SGD优化器来使不断修正、完善模型的系数来减小损失(MSE)。下面就是我们创建的训练迭代器:

function train(xs, ys, numIterations = ) {

  const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate); for (let iter = ; iter < numIterations; iter++) {
optimizer.minimize(() => {
const predsYs = predict(xs);
return loss(predsYs, ys);
});
}
}

  现在,让我们一步一步地仔细看看上面的代码。首先,我们定义了训练函数,并且以数据中x和y的值以及制定的迭代次数作为输入:

function train(xs, ys, numIterations) {
...
}

  接下来,我们定义了之前讨论过的学习率(learning rate)以及SGD优化器:

const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);

  

  最后,我们定义了一个for循环,这个循环会运行numIterations次训练。在每一次迭代中,我们都调用了optimizer优化器的minimize函数,这就是见证奇迹的地方:

for (let iter = ; iter < numIterations; iter++) {
optimizer.minimize(() => {
const predsYs = predict(xs);
return loss(predsYs, ys);
});
}

  minimize 接受了一个函数作为参数,这个函数做了下面的两件事情:

  1. 首先它对所有的x值通过我们在之前定义的pridict函数预测了y值。
  2. 然后它通过我们之前定义的损失函数返回了这些预测的均方误差。

    

  minimize函数之后会自动调整这些变量(即系数a、b、c、d)来使得损失函数更小。

  在运行训练迭代器之后,a、b、c以及d就会是通过模型75次SGD迭代之后学习到的结果了。

  

七、观察结果吧!

  一旦程序运行结束,我们就可以得到最终的a、b、c和d的结果了,然后使用它们来绘制曲线,如下所示:

  这个结果已经比开始随机分配系数的结果拟合的好得多了!

TensorFlow.js之根据数据拟合曲线的更多相关文章

  1. TensorFlow.js入门(一)一维向量的学习

    TensorFlow的介绍   TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tensor(张量)意味着N维数组,Flow(流)意味着 ...

  2. 转《在浏览器中使用tensorflow.js进行人脸识别的JavaScript API》

    作者 | Vincent Mühle 编译 | 姗姗 出品 | 人工智能头条(公众号ID:AI_Thinker) [导读]随着深度学习方法的应用,浏览器调用人脸识别技术已经得到了更广泛的应用与提升.在 ...

  3. TensorFlow.js之安装与核心概念

    TensorFlow.js是通过WebGL加速.基于浏览器的机器学习js框架.通过tensorflow.js,我们可以在浏览器中开发机器学习.运行现有的模型或者重新训练现有的模型. 一.安装     ...

  4. 大前端技术系列:TWA技术+TensorFlow.js => 集成原生和AI功能的app

    大前端技术系列:TWA技术+TensorFlow.js => 集成原生和AI功能的app ( 本文内容为melodyWxy原作,git地址:https://github.com/melodyWx ...

  5. TensorFlow.js入门:一维向量的学习

    转载自:https://blog.csdn.net/weixin_34061042/article/details/89700664 一维向量及其运算 tensor 是 TensorFlow.js 的 ...

  6. 【一统江湖的大前端(9)】TensorFlow.js 开箱即用的深度学习工具

    示例代码托管在:http://www.github.com/dashnowords/blogs 博客园地址:<大史住在大前端>原创博文目录 目录 一. 上手TensorFlow.js 二. ...

  7. js声明json数据,打印json数据,遍历json数据

    1.js声明json数据: 2.打印json数据: 3.遍历json数据 //声明JSON var json = {}; json.a = 1; //第一种赋值方式(仿对象型) json['b'] = ...

  8. 通过js获取前台数据向一般处理程序传递Json数据,并解析Json数据,将前台传来的Json数据写入数据库表中

    摘自:http://blog.csdn.net/mazhaojuan/article/details/8592015 通过js获取前台数据向一般处理程序传递Json数据,并解析Json数据,将前台传来 ...

  9. 抓取Js动态生成数据且以滚动页面方式分页的网页

    代码也可以从我的开源项目HtmlExtractor中获取. 当我们在进行数据抓取的时候,如果目标网站是以Js的方式动态生成数据且以滚动页面的方式进行分页,那么我们该如何抓取呢? 如类似今日头条这样的网 ...

随机推荐

  1. lnmp源码编译安装zabbix

    软件安装 Mysql 安装 tar xf mysql-5.7.13-1.el6.x86_64.rpm-bundle.tar -C mysql rpm -e --nodeps  mysql-libs-5 ...

  2. VB网络编程中Winsock的使用

    原文链接:http://tech.163.com/06/0407/14/2E46BB930009159S.html 如同上面的内容所描述的,不论您使用UDP协议或是TCP协议,Winsock控件都可以 ...

  3. Mysql之数据库操作

    数据库操作: 链接数据库: mysql -uroot -p masql -uroot -pmysql 退出数据库: exit/quit/ctrl + d   sql语句最后需要分号结尾: 查看时间: ...

  4. 用户权限,pymysql

    单表查询的完整语法 select [distinct] [*|字段|聚合函数|表达式] from tablewhere group byhaving distinctorder bylimit mys ...

  5. Web结构组件

    一.Web结构组件 1.代理 位于客户端和服务器之间的HTTP实体,接收客户端的所有HTTP请求,并将这些请求转发给HTTP服务器. 2.缓存 HTTP的仓库,使常用的页面的副本可以保存在离客户端更近 ...

  6. How to resolve "your security settings have blocked an untrusted application from running" in Mac

    If you encounter the error "your security settings have blocked an untrusted application from r ...

  7. poj 2352 stars 【树状数组】

    题目 题意:按y递增的顺序给出n颗星星的坐标(y相等则x递增),每个星星的等级等于在它左边且在它下边(包括水平和垂直方向)的星星的数量,求出等级为0到n-1的星星分别有多少个. 因为y递增的顺序给出, ...

  8. 2.panel面板

    注:什么时候使用组件,什么时候使用js编写:当要加载的配置项较少的时候可以使用组件,当它要加载的配置项较多的时候就是用js来实现.

  9. iOS 5 故事板进阶(1)

    译自<iOS 5 by tutorials> 在上一章,你已经学习了故事板的基本用法.包括如何向故事板中添加 View Controller,通过 segues 切换 View Contr ...

  10. ASP.NET Web API 框架研究 IoC容器 DependencyResolver

    一.概念 1.IoC(Inversion of Control),控制反转 即将依赖对象的创建和维护交给一个外部容器来负责,而不是应用本身.如,在类型A中需要使用类型B的实例,而B的实例的创建不是由A ...