Keras 训练一个单层全连接网络的线性回归模型
1、准备环境,探索数据
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt # 创建数据集
rng = np.random.RandomState(27)
X = np.linspace(-3, 5, 300)
rng.shuffle(X) # 将数据集随机化
y = 0.5 * X + 1 + np.random.normal(0, 0.05, 300) # 假设真实模型为:y = 0.5X + 1 # 绘制数据集
plt.scatter(X, y, s=0.5)
plt.show()

2、准备数据训练模型
# 划分训练集和测试集
X_train, y_train = X[:400], y[:400]
X_test, y_test = X[-100:], y[-100:] # 定义模型
model = Sequential () # 用 Keras 序贯模型(Sequential)定义一个单输入单输出的模型 model
model.add(Dense(output_dim=1, input_dim=1)) # 通过 add()方法一层, Dense 是全连接层,第一层需要定义输入 # 设置模型参数
model.compile(loss='mse', optimizer='sgd') # 通过compile()方法选择损失函数(均方误差)和 优化器(随机梯度下降) # 开始训练
print('Training ==========')
for step in range(301):
cost = model.train_on_batch(X_train, y_train) # Keras 的 train_on_batch() 函数训练模型
if step % 100 == 0:
print('train cost: ', cost)

3、测试训练好的模型
print('\nTesting ==========')
cost = model.evaluate(X_test, y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights() # 查看训练出的网络参数
print('Weights=', W, '\nbiases=', b) # 由于网络只有一层,且每次训练的输入和输出只有一个节点,因此第一层训练出 y=WX+b 的模型,其中 W,b 为训练出的参数

最终的测试 cost 为: 0.0026768923737108706
4、可视化测试结果
y_pred = model.predict(X_test) # 用测试集进行预测
plt.scatter(X_test, y_test, s=4) # 绘制测试点图
plt.plot(X_test, y_pred, lw=0.7) # 绘制回归直线
plt.show()

。。。
Keras 训练一个单层全连接网络的线性回归模型的更多相关文章
- PRML读书笔记——线性回归模型(上)
本章开始学习第一个有监督学习模型--线性回归模型."线性"在这里的含义仅限定了模型必须是参数的线性函数.而正如我们接下来要看到的,线性回归模型可以是输入变量\(x\)的非线性函数. ...
- TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化
线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...
- keras训练cnn模型时loss为nan
keras训练cnn模型时loss为nan 1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimiz ...
- Keras(一)Sequential与Model模型、Keras基本结构功能
keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...
- 线性回归模型的 MXNet 与 TensorFlow 实现
本文主要探索如何使用深度学习框架 MXNet 或 TensorFlow 实现线性回归模型?并且以 Kaggle 上数据集 USA_Housing 做线性回归任务来预测房价. 回归任务,scikit-l ...
- 【scikit-learn】scikit-learn的线性回归模型
内容概要 怎样使用pandas读入数据 怎样使用seaborn进行数据的可视化 scikit-learn的线性回归模型和用法 线性回归模型的评估測度 特征选择的方法 作为有监督学习,分类问题是预 ...
- R语言解读多元线性回归模型
转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...
- 机器学习(一) 从一个R语言案例学线性回归
写在前面的话 按照正常的顺序,本文应该先讲一些线性回归的基本概念,比如什么叫线性回归,线性回规的常用解法等.但既然本文名为<从一个R语言案例学会线性回归>,那就更重视如何使用R语言去解决线 ...
- 多元线性回归模型的特征压缩:岭回归和Lasso回归
多元线性回归模型中,如果所有特征一起上,容易造成过拟合使测试数据误差方差过大:因此减少不必要的特征,简化模型是减小方差的一个重要步骤.除了直接对特征筛选,来也可以进行特征压缩,减少某些不重要的特征系数 ...
随机推荐
- 不使用xftp上传/下载文件到linux
yum install lrzsz # 安装软件 window端上传到linux端: 1. window端先压缩需上传的文件 2. linux端运行命令rz 3. 在弹出的窗口选择压缩好的文件, ...
- 考试总结(橙题WA)
又逢校内测,成绩变化大 初见三道题,暗喜AK辣 谁知数据毒,特判不到家 三题两题WA,心态已爆炸 T1(我不想再见到这道题): 附上多年前AC但是随便出(毒瘤)一组数据就可以卡掉的代码: #inclu ...
- 洛谷4965 薇尔莉特的打字机(Trie,DP)
神仙题. 考虑在一棵 Trie 上进行染色,将可能出现的串的末尾染成黑色.答案就是黑点的个数.一开始只有 \(A\) 的末尾点是黑色. 当出现一个字符(不是退格)\(c\) 时,就要将每个黑点的 \( ...
- MySQL实战45讲学习笔记:第二十四讲
一.引子 在前面的文章中,我不止一次地和你提到了 binlog,大家知道 binlog 可以用来归档,也可以用来做主备同步,但它的内容是什么样的呢?为什么备库执行了 binlog 就可以跟主库保持一致 ...
- [LeetCode] 763. Partition Labels 分割标签
A string S of lowercase letters is given. We want to partition this string into as many parts as pos ...
- js中的require、define、export、import【转】
原文链接:https://www.cnblogs.com/libin-1/p/7127481.html 为什么有模块概念 理想情况下,开发者只需要实现核心的业务逻辑,其他都可以加载别人已经写好的模块. ...
- Java之数据库基础理论
一.事务的四大特性 ACID 只有满足一致性,事务的执行结果才是正确的. 在无并发的情况下,事务串行执行,隔离性一定能够满足.此时要只要能满足原子性,就一定能满足一致性. 在并发的情况下,多个事务并发 ...
- linux php composer安装和使用教程
linux php composer安装和使用教程建议在linux下 下载后 然后再下载到本地 win上最好别用composer下载速度超级慢 或者根本下不动 项目依赖包 ...
- 用cp命令拷贝文件,源目录后带不带斜杠的区别
当我还是Linux超级傻白的时候,需要拷贝一个很大的数据集,然后再拷贝源文件夹的后面跟了一个前倾斜杠,然后就发现居然拷贝的是整个文件夹里的东西,而不是文件夹本身.事儿倒是不大,我重新建一个文件夹,把这 ...
- JavaIO学习:字节流
JavaIO流之字节流 字节流 抽象基类:InputStream,OutputStream. 字节流可以操作任何数据. 注意: 字符流使用的数组是字符数组,char[] chs : 字节流使用的数组是 ...