Keras上实现简单线性回归模型
神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。
一、详细解读
我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:
1.导入模块并生成数据
首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。
2.建立模型
然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。
3.激活模型
model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。
4.训练模型
训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。
5.验证模型
用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。
Weights= [[ 0.49136472]]
biases= [ 2.00405312]
6.可视化学习结果
最后可以画出预测结果,与测试集的值进行对比。
二、完整代码
import numpy as np
np.random.seed(1337)
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt # 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show() X_train, Y_train = X[:160], Y[:160] # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:] #后40组数据为测试数据集 # 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1)) # 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd') # 训练过程
print('Training -----------')
for step in range(501):
cost = model.train_on_batch(X_train, Y_train)
if step % 50 == 0:
print("After %d trainings, the cost: %f" % (step, cost)) # 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b) # 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()
三、其他补充
1. numpy.linspace
numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)
返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.
>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)
Keras上实现简单线性回归模型的更多相关文章
- 基于tensorflow的简单线性回归模型
#!/usr/local/bin/python3 ##ljj [1] ##linear regression model import tensorflow as tf import matplotl ...
- 机器学习(2):简单线性回归 | 一元回归 | 损失计算 | MSE
前文再续书接上一回,机器学习的主要目的,是根据特征进行预测.预测到的信息,叫标签. 从特征映射出标签的诸多算法中,有一个简单的算法,叫简单线性回归.本文介绍简单线性回归的概念. (1)什么是简单线性回 ...
- 机器学习——Day 2 简单线性回归
写在开头 由于某些原因开始了机器学习,为了更好的理解和深入的思考(记录)所以开始写博客. 学习教程来源于github的Avik-Jain的100-Days-Of-MLCode 英文版:https:// ...
- Python回归分析五部曲(一)—简单线性回归
回归最初是遗传学中的一个名词,是由英国生物学家兼统计学家高尔顿首先提出来的,他在研究人类身高的时候发现:高个子回归人类的平均身高,而矮个子则从另一方向回归人类的平均身高: 回归分析整体逻辑 回归分析( ...
- day-12 python实现简单线性回归和多元线性回归算法
1.问题引入 在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析.这种函数是一个或多个称为回归系数的模型参数的线性组合.一个带有一个自变 ...
- PRML读书笔记——线性回归模型(上)
本章开始学习第一个有监督学习模型--线性回归模型."线性"在这里的含义仅限定了模型必须是参数的线性函数.而正如我们接下来要看到的,线性回归模型可以是输入变量\(x\)的非线性函数. ...
- 用Tensorflow完成简单的线性回归模型
思路:在数据上选择一条直线y=Wx+b,在这条直线上附件随机生成一些数据点如下图,让TensorFlow建立回归模型,去学习什么样的W和b能更好去拟合这些数据点. 1)随机生成1000个数据点,围绕在 ...
- TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化
线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...
- R语言解读一元线性回归模型
转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...
随机推荐
- SPA项目开发之登录注册
CMD安装所需要的pom依赖 npm install element-ui -S npm install axios -S npm install qs -S npm install vue-axio ...
- Django 1.11 网站分页设计
参考网址:https://www.cnblogs.com/kongzhagen/p/6640975.html
- 联想笔记本上Ubuntu无线网卡问题
可能有两个问题: 1.无无线网卡驱动 2.无线网卡驱动不能自动加载 问题1:无线网卡驱动 百度出网卡驱动iwlwifi-9000,如iwlwifi-9000-pu-b0-jf-b0-34.618819 ...
- 洛谷P3063 [USACO12DEC]牛奶的路由Milk Routing
链接 其实在博客园里写题解都挺应付的都是在洛谷写了之后 挑一部分粘过来 在洛谷写的也都是废话,是为了凑篇幅 主要就是代码 大体思路就一提 这题贪心不行废话 跑m遍SPFA更新最小值 注意数组记得清空 ...
- prometheus、node_exporter、cAdvisor常用参数
本节将介绍一下我在使用过程中用到的promethues.node_exporter.cAdvisor的常用参数,做一个总结 一.prometheus prometheus分为容器安装和二进制文件安装, ...
- express框架,使用 static 访问 public 内静态文件
使用 express 生成 node 服务器后,我们需要访问放在public文件夹内的静态文件,如上传的图片 我们需要在app.js中添加配置项: app.use('/public',express. ...
- Qt应用开发所需
Qt判断当前操作系统? 可使用宏判断,例如: #ifdef Q_OS_MAC //mac ... #endif #ifdef Q_OS_LINUX //linux ... #endif #ifdef ...
- 逻辑运算符/三元运算符/Scanner
逻辑运算符 &(并且) , |(或者) , !(非) , ^(异或) , &&(短路与) , ||(短路或) 注意事项: a:逻辑运算符一般用于连接boolean类型的表达式或 ...
- xcode添加一个真机设备
1.首先先安装Xcode并且运行Xcode,点击左角菜单Xcode -> Preferences:点击Accounts+号弹菜单点击Add Apple ID:弹框输入账号密码普通账号行需要发者账 ...
- AKKA事件机制
AKKA Event Bus 事件机制就用于当前运行环境,与集群环境不同,详细见AKKA 集群中的发布与订阅Distributed Publish Subscribe in Cluster 简单实现示 ...