版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

二、完整代码

  

import numpy as np
np.random.seed(1337)
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt # 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show() X_train, Y_train = X[:160], Y[:160] # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:] #后40组数据为测试数据集 # 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1)) # 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd') # 训练过程
print('Training -----------')
for step in range(501):
cost = model.train_on_batch(X_train, Y_train)
if step % 50 == 0:
print("After %d trainings, the cost: %f" % (step, cost)) # 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b) # 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

  

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

Keras上实现简单线性回归模型的更多相关文章

  1. 基于tensorflow的简单线性回归模型

    #!/usr/local/bin/python3 ##ljj [1] ##linear regression model import tensorflow as tf import matplotl ...

  2. 机器学习(2):简单线性回归 | 一元回归 | 损失计算 | MSE

    前文再续书接上一回,机器学习的主要目的,是根据特征进行预测.预测到的信息,叫标签. 从特征映射出标签的诸多算法中,有一个简单的算法,叫简单线性回归.本文介绍简单线性回归的概念. (1)什么是简单线性回 ...

  3. 机器学习——Day 2 简单线性回归

    写在开头 由于某些原因开始了机器学习,为了更好的理解和深入的思考(记录)所以开始写博客. 学习教程来源于github的Avik-Jain的100-Days-Of-MLCode 英文版:https:// ...

  4. Python回归分析五部曲(一)—简单线性回归

    回归最初是遗传学中的一个名词,是由英国生物学家兼统计学家高尔顿首先提出来的,他在研究人类身高的时候发现:高个子回归人类的平均身高,而矮个子则从另一方向回归人类的平均身高: 回归分析整体逻辑 回归分析( ...

  5. day-12 python实现简单线性回归和多元线性回归算法

    1.问题引入  在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析.这种函数是一个或多个称为回归系数的模型参数的线性组合.一个带有一个自变 ...

  6. PRML读书笔记——线性回归模型(上)

    本章开始学习第一个有监督学习模型--线性回归模型."线性"在这里的含义仅限定了模型必须是参数的线性函数.而正如我们接下来要看到的,线性回归模型可以是输入变量\(x\)的非线性函数. ...

  7. 用Tensorflow完成简单的线性回归模型

    思路:在数据上选择一条直线y=Wx+b,在这条直线上附件随机生成一些数据点如下图,让TensorFlow建立回归模型,去学习什么样的W和b能更好去拟合这些数据点. 1)随机生成1000个数据点,围绕在 ...

  8. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  9. R语言解读一元线性回归模型

    转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...

随机推荐

  1. 洛谷p1458顺序的分数题解

    抱歉,您们的蒟蒻yxj不知道怎么插入链接qwq就只好粘个文本的了qwq:https://www.luogu.org/problemnew/show/P1458 没错,是个黄题,因为你们的小蒟蒻只会这样 ...

  2. hive基础知识四

    1. hive表的数据压缩 1.1 数据的压缩说明 压缩模式评价 可使用以下三种标准对压缩方式进行评价 1.压缩比:压缩比越高,压缩后文件越小,所以压缩比越高越好 2.压缩时间:越快越好 3.已经压缩 ...

  3. ZROI 暑期高端峰会 A班 Day3 字符串

    FBI Warning:本文含有大量人类的本质之一 后缀树 反正后缀树就是反串的后缀自动机的 Parent 树,就不管了. 然而 SAM 也忘了 好的假装自己会吧--dls 后缀自动机 大概记得,不管 ...

  4. ECMAScript6之Class

    1.Class的基本语法 1.1简介 基本上,ES6 的class可以看作只是一个语法糖,它的绝大部分功能,ES5 都可以做到,新的class写法只是让对象原型的写法更加清晰.更像面向对象编程的语法而 ...

  5. JAVA程序执行顺序(静态代码块》非静态代码块》静态方法》构造函数)

    总结:静态代码块总是最先执行. 非静态代码块跟非静态方法一样,跟对象有关.只不过非静态代码块在构造函数之前执行. 父类非静态代码块.构造函数执行完毕后(相当于父类对象初始化完成), 才开始执行子类的非 ...

  6. 计算GPS点之间的距离

    latitude纬度 longtitude经度 // 求弧度 double getRadian(double d) { return d * PI / 180.0; //角度1? = π / 180 ...

  7. python 项目实战之备份文件夹并且压缩文件夹及下面的文件

    #!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2019/11/12 14:21 # @Author : zoulixiang # @S ...

  8. JSON Template

    public java.lang.String toString() {#if ( $members.size() > 0 ) #set ( $i = 0 )return "{\&qu ...

  9. Solr7.x学习(7)-JAVA操作

    maven依赖 <dependency> <groupId>org.apache.solr</groupId> <artifactId>solr-sol ...

  10. Asp.Net Core 2.x 和 3.x WebAPI 使用 Swagger 时 API Controller 控制器 Action 方法 隐藏 hidden 与 and 分组 group

    1.前言 为什么我们要隐藏部分接口? 因为我们在用swagger代替接口的时候,难免有些接口会直观的暴露出来,比如我们结合Consul一起使用的时候,会将健康检查接口以及报警通知接口暴露出来,这些接口 ...