版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

二、完整代码

  

import numpy as np
np.random.seed(1337)
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt # 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show() X_train, Y_train = X[:160], Y[:160] # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:] #后40组数据为测试数据集 # 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1)) # 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd') # 训练过程
print('Training -----------')
for step in range(501):
cost = model.train_on_batch(X_train, Y_train)
if step % 50 == 0:
print("After %d trainings, the cost: %f" % (step, cost)) # 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b) # 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

  

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

Keras上实现简单线性回归模型的更多相关文章

  1. 基于tensorflow的简单线性回归模型

    #!/usr/local/bin/python3 ##ljj [1] ##linear regression model import tensorflow as tf import matplotl ...

  2. 机器学习(2):简单线性回归 | 一元回归 | 损失计算 | MSE

    前文再续书接上一回,机器学习的主要目的,是根据特征进行预测.预测到的信息,叫标签. 从特征映射出标签的诸多算法中,有一个简单的算法,叫简单线性回归.本文介绍简单线性回归的概念. (1)什么是简单线性回 ...

  3. 机器学习——Day 2 简单线性回归

    写在开头 由于某些原因开始了机器学习,为了更好的理解和深入的思考(记录)所以开始写博客. 学习教程来源于github的Avik-Jain的100-Days-Of-MLCode 英文版:https:// ...

  4. Python回归分析五部曲(一)—简单线性回归

    回归最初是遗传学中的一个名词,是由英国生物学家兼统计学家高尔顿首先提出来的,他在研究人类身高的时候发现:高个子回归人类的平均身高,而矮个子则从另一方向回归人类的平均身高: 回归分析整体逻辑 回归分析( ...

  5. day-12 python实现简单线性回归和多元线性回归算法

    1.问题引入  在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析.这种函数是一个或多个称为回归系数的模型参数的线性组合.一个带有一个自变 ...

  6. PRML读书笔记——线性回归模型(上)

    本章开始学习第一个有监督学习模型--线性回归模型."线性"在这里的含义仅限定了模型必须是参数的线性函数.而正如我们接下来要看到的,线性回归模型可以是输入变量\(x\)的非线性函数. ...

  7. 用Tensorflow完成简单的线性回归模型

    思路:在数据上选择一条直线y=Wx+b,在这条直线上附件随机生成一些数据点如下图,让TensorFlow建立回归模型,去学习什么样的W和b能更好去拟合这些数据点. 1)随机生成1000个数据点,围绕在 ...

  8. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  9. R语言解读一元线性回归模型

    转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...

随机推荐

  1. (HK1-2)海康相机直接连接电脑不经过路由器设置

    解决电脑无法通过网线直连海康摄像机的问题 https://blog.csdn.net/u014552102/article/details/86708371 一.现象:    通过博主的另外一篇博客h ...

  2. Windbg Register(寄存器)窗口的使用

    寄存器是位于在 CPU 的小易失性内存单位. 许多寄存器专用于特定用途,并可用于用户模式应用程序使用的其他寄存器. 基于 x86 和基于 x64 的处理器在有可用的寄存器的不同集合. 如何打开寄存器窗 ...

  3. redis使用摘要

    一.redis使用: 在下载安装好redis后,pycharm内也需要安装redis工具包.cmd窗口运行pip install redis后才可在pycharm 内导入import redis来使用 ...

  4. 奇技淫巧and板子

    目录 求第\(k\)大的数 求长度不小于L的子段使之和最大 ST表 \(O(1)\)实现能查询栈中最小元素 二分 树和图的深度优先遍历和广度优先遍历 树的dfs序 求树的重心 图的联通块划分 拓扑排序 ...

  5. vs启动报4.X的错

    参考: https://www.cnblogs.com/zsx-blog/p/6136956.html https://blog.csdn.net/lishaoran369/article/detai ...

  6. Computer-Hunters——项目系统设计与数据库设计

    Computer-Hunters--项目系统设计与数据库设计 前言 本次作业属于2019秋福大软件工程实践Z班 本次作业要求 团队名称: Computer-Hunters 本次作业目标:撰写一份针对团 ...

  7. ROS+gre over ipv6,在ipv6的基础上建立GRE隧道

    感谢群众大神 @镜花水月全程技术支持.感谢! 环境: 1.阿里云华北5,申请ipv6测试,申请ECS带ipv6公网ip 2.为了便于测试便捷,在vultr创建2.5美元的vps,带ipv6 对了,测试 ...

  8. AQS源码的简单理解

    概念 AQS全称 AbstractQueuedSynchronizer. AQS是一个并发包的基础组件,用来实现各种锁,各种同步组件的.它包含了state变量.加锁线程.等待队列等并发中的核心组件. ...

  9. 记录一次在生成数据库服务器上出现The timeout period elapsed prior to completion of the operation or the server is not responding.和Exception has been thrown by the target of an invocation的解决办法

    记一次查询超时的解决方案The timeout period elapsed...... https://www.cnblogs.com/wyt007/p/9274613.html Exception ...

  10. 4 实战CPU上下文