版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

二、完整代码

  

import numpy as np
np.random.seed(1337)
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt # 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show() X_train, Y_train = X[:160], Y[:160] # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:] #后40组数据为测试数据集 # 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1)) # 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd') # 训练过程
print('Training -----------')
for step in range(501):
cost = model.train_on_batch(X_train, Y_train)
if step % 50 == 0:
print("After %d trainings, the cost: %f" % (step, cost)) # 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b) # 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

  

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

Keras上实现简单线性回归模型的更多相关文章

  1. 基于tensorflow的简单线性回归模型

    #!/usr/local/bin/python3 ##ljj [1] ##linear regression model import tensorflow as tf import matplotl ...

  2. 机器学习(2):简单线性回归 | 一元回归 | 损失计算 | MSE

    前文再续书接上一回,机器学习的主要目的,是根据特征进行预测.预测到的信息,叫标签. 从特征映射出标签的诸多算法中,有一个简单的算法,叫简单线性回归.本文介绍简单线性回归的概念. (1)什么是简单线性回 ...

  3. 机器学习——Day 2 简单线性回归

    写在开头 由于某些原因开始了机器学习,为了更好的理解和深入的思考(记录)所以开始写博客. 学习教程来源于github的Avik-Jain的100-Days-Of-MLCode 英文版:https:// ...

  4. Python回归分析五部曲(一)—简单线性回归

    回归最初是遗传学中的一个名词,是由英国生物学家兼统计学家高尔顿首先提出来的,他在研究人类身高的时候发现:高个子回归人类的平均身高,而矮个子则从另一方向回归人类的平均身高: 回归分析整体逻辑 回归分析( ...

  5. day-12 python实现简单线性回归和多元线性回归算法

    1.问题引入  在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析.这种函数是一个或多个称为回归系数的模型参数的线性组合.一个带有一个自变 ...

  6. PRML读书笔记——线性回归模型(上)

    本章开始学习第一个有监督学习模型--线性回归模型."线性"在这里的含义仅限定了模型必须是参数的线性函数.而正如我们接下来要看到的,线性回归模型可以是输入变量\(x\)的非线性函数. ...

  7. 用Tensorflow完成简单的线性回归模型

    思路:在数据上选择一条直线y=Wx+b,在这条直线上附件随机生成一些数据点如下图,让TensorFlow建立回归模型,去学习什么样的W和b能更好去拟合这些数据点. 1)随机生成1000个数据点,围绕在 ...

  8. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  9. R语言解读一元线性回归模型

    转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...

随机推荐

  1. MongoDB的安装与简单使用

    一.安装MongoDB的步骤 注:本教程全部统一采用hadoop用户名登录Linux系统,用户名:hadoop 密码:hadoop ​ 首先,在Linux系统中打开一个终端,执行如下命令导入公共秘钥到 ...

  2. CSP 2019游记 & 退役记

    扶苏让我记录他AK CSP 的事实 ZAY NB!!! "你不配" 两年半的旅行结束了,我背着满满的行囊下了车,望着毫不犹豫远去的列车,我笑着哭了,笑着翻着我的行囊-- 游记 Da ...

  3. CSS居中方案

    1.行内元素或者内联元素 1.垂直居中 设置行高和高度一致,如果没必要设置高度的话,可以直接利用line-height垂直性,直接设置需要的高度为line-height的高度亦可居中 .center- ...

  4. springcloud的Hystrix turbine断路器聚合监控实现(基于springboot2.02版本)

    本文基于方志朋先生的博客实现:https://blog.csdn.net/forezp/article/details/70233227 一.准本工作 1.工具:Idea,JDK1.8,Maven3. ...

  5. Solr7.x学习(8)-使用spring-data-solr

    1.maven配置 <dependency> <groupId>org.springframework.data</groupId> <artifactId& ...

  6. SharePoint - Another Way to Delete Site Collection

    I had created a site collection. But there is a problem of web-frontend server (I did not know when ...

  7. iOS 测试 WebDriverAgent 简介

    WebDriverAgent 是什么   去年的 SeleniumConf 上,Facebook 推出了一款新的iOS移动测试框架 —— WebDriverAgent,当时的推文上,写的还只支持模拟器 ...

  8. [转帖]来聊聊,华为与H3C(华三)的前世今生!

    本篇,是以真实事件改编,将以故事篇的方式呈现出来. 本故事将分为两个篇幅讲述. 在中国的网络通信设备市场,有两个华字辈的选手,一名叫“华为技术有限公司”,另一名叫“杭州华三通信技术有限公司”. 这两个 ...

  9. C# 实现Escape UnEscape方法(支持中文-转载)

    //Escape方法 public static string Escape(string s) { StringBuilder sb = new StringBuilder(); byte[] by ...

  10. react项目添加本地音频

    <audio src="./res/audio/alarm.mp3" autoplay="autoplay" loop="loop"  ...