[深度学习] tf.keras入门3-回归
目录
回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.cn/tutorials/keras/basic_regression
波士顿房价数据集
数据集
波士顿数据集是一个回归问题。每个类的观察值数量是均等的,共有 506 个观察,13 个输入变量和1个输出变量。每条数据包含房屋以及房屋周围的详细信息。其中包含城镇犯罪率,一氧化氮浓度,住宅平均房间数,到中心区域的加权距离以及自住房平均房价等等。
但是对于回归问题,需要读取数据后需要将数据集打散,代码如下:
boston_housing = keras.datasets.boston_housing
(train_data, train_labels), (test_data, test_labels) = boston_housing.load_data()
#打散数据集
order = np.argsort(np.random.random(train_labels.shape))
train_data = train_data[order]
train_labels = train_labels[order]
数据集标签展示:
import pandas as pd
column_names = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
'TAX', 'PTRATIO', 'B', 'LSTAT']
df = pd.DataFrame(train_data, columns=column_names)
df.head()
CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.07875 | 45.0 | 3.44 | 0.0 | 0.437 | 6.782 | 41.1 | 3.7886 | 5.0 | 398.0 | 15.2 | 393.87 | 6.68 |
1 | 4.55587 | 0.0 | 18.10 | 0.0 | 0.718 | 3.561 | 87.9 | 1.6132 | 24.0 | 666.0 | 20.2 | 354.70 | 7.12 |
2 | 0.09604 | 40.0 | 6.41 | 0.0 | 0.447 | 6.854 | 42.8 | 4.2673 | 4.0 | 254.0 | 17.6 | 396.90 | 2.98 |
3 | 0.01870 | 85.0 | 4.15 | 0.0 | 0.429 | 6.516 | 27.7 | 8.5353 | 4.0 | 351.0 | 17.9 | 392.43 | 6.36 |
4 | 0.52693 | 0.0 | 6.20 | 0.0 | 0.504 | 8.725 | 83.0 | 2.8944 | 8.0 | 307.0 | 17.4 | 382.00 | 4.63 |
数据归一化
数据的标准化(normalization)是将数据按比例缩放,使之落入一个小的特定区间。在某些比较和评价的指标处理中经常会用到,去除数据的单位限制,将其转化为无量纲的纯数值,便于不同单位或量级的指标能够进行比较和加权。
具体见https://blog.csdn.net/pipisorry/article/details/52247379
#z-score 标准化
mean = train_data.mean(axis=0)
std = train_data.std(axis=0)
train_data = (train_data - mean) / std
test_data = (test_data - mean) / std
模型训练和预测
模型建立和训练
模型建立的通用模式为网络结构确定(网络层数,节点数,输入,输出)、模型训练参数确定(损失函数,优化器、评价标准)、模型训练(训练次数,批次大小)
#z-score 标准化
mean = train_data.mean(axis=0)
std = train_data.std(axis=0)
train_data = (train_data - mean) / std
test_data = (test_data - mean) / std
#模型建立函数
def build_model():
model = keras.Sequential([
keras.layers.Dense(64, activation=tf.nn.relu,
input_shape=(train_data.shape[1],)),
keras.layers.Dense(64, activation=tf.nn.relu),
keras.layers.Dense(1)
])
optimizer = tf.train.RMSPropOptimizer(0.001)
model.compile(loss='mse',
optimizer=optimizer,
metrics=['mae']) #平均绝对误差
return model
#建立模型
model = build_model()
#模型结构显示
model.summary()
模型的训练代码如下:
# 回调函数
class PrintDot(keras.callbacks.Callback):
def on_epoch_end(self,epoch,logs):
if epoch % 100 == 0: print('')
print('.', end='')
EPOCHS = 500
#模型训练
history = model.fit(train_data, train_labels, epochs=EPOCHS,
validation_split=0.2, verbose=1, #verbose训练过程显示
callbacks=[PrintDot()]) #取测试集中的百分之20作为验证集
模型预测
调用history函数可以实现训练过程的可视化
#模型损失函数展示
def plot_history(history):
plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Abs Error [1000$]')
plt.plot(history.epoch, np.array(history.history['mean_absolute_error']),
label='Train Loss')
plt.plot(history.epoch, np.array(history.history['val_mean_absolute_error']),
label = 'Val loss')
plt.legend()
plt.ylim([0,5])
plot_history(history)
为了提前停止训练,可以通过设置回调函数EarlyStopping设置训练停止条件。
#停止条件设置,即验证集损失连续20次训练没有变化,即停止训练
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=20)
history = model.fit(train_data, train_labels, epochs=EPOCHS,
validation_split=0.2, verbose=0,
callbacks=[early_stop, PrintDot()])
plot_history(history)
模型预测代码如下:
test_predictions = model.predict(test_data).flatten()
print(test_predictions)
总结
总体代码如下:
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
# 其他库
import numpy as np
import matplotlib.pyplot as plt
#查看版本
print(tf.__version__)
#1.9.0
boston_housing = keras.datasets.boston_housing
(train_data, train_labels), (test_data, test_labels) = boston_housing.load_data()
#打散数据集
order = np.argsort(np.random.random(train_labels.shape))
train_data = train_data[order]
train_labels = train_labels[order]
import pandas as pd
column_names = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
'TAX', 'PTRATIO', 'B', 'LSTAT']
df = pd.DataFrame(train_data, columns=column_names)
df.head()
#z-score 标准化
mean = train_data.mean(axis=0)
std = train_data.std(axis=0)
train_data = (train_data - mean) / std
test_data = (test_data - mean) / std
#模型建立函数
def build_model():
model = keras.Sequential([
keras.layers.Dense(64, activation=tf.nn.relu,
input_shape=(train_data.shape[1],)),
keras.layers.Dense(64, activation=tf.nn.relu),
keras.layers.Dense(1)
])
optimizer = tf.train.RMSPropOptimizer(0.001)
model.compile(loss='mse',
optimizer=optimizer,
metrics=['mae']) #平均绝对误差
return model
#建立模型
model = build_model()
#模型结构显示
model.summary()
# 回调函数
class PrintDot(keras.callbacks.Callback):
def on_epoch_end(self,epoch,logs):
if epoch % 100 == 0: print('')
print('.', end='')
EPOCHS = 500
#模型训练
history = model.fit(train_data, train_labels, epochs=EPOCHS,
validation_split=0.2, verbose=1, #verbose训练过程显示
callbacks=[PrintDot()]) #取测试集中的百分之20作为验证集
#模型损失函数展示
def plot_history(history):
plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Abs Error [1000$]')
plt.plot(history.epoch, np.array(history.history['mean_absolute_error']),
label='Train Loss')
plt.plot(history.epoch, np.array(history.history['val_mean_absolute_error']),
label = 'Val loss')
plt.legend()
plt.ylim([0,5])
plot_history(history)
model = build_model()
#停止条件设置,即验证集损失连续20次训练没有变化,即停止训练
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=20)
history = model.fit(train_data, train_labels, epochs=EPOCHS,
validation_split=0.2, verbose=0,
callbacks=[early_stop, PrintDot()])
plot_history(history)
test_predictions = model.predict(test_data).flatten()
print(test_predictions)
对于回归问题的官方总结:
- 均方误差(MSE)是一种常见的用于回归问题损失函数。
- 平均绝对误差(MAE)也是一种常用评价指标而不是精度。
- 对于输入数据,归一化是十分必要的。
- 训练数据较少,则模型结构较小更合适,防止过拟合。
- 提前停止是防止过拟合的好办法。
[深度学习] tf.keras入门3-回归的更多相关文章
- [深度学习] tf.keras入门1-基本函数介绍
目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...
- [深度学习] tf.keras入门4-过拟合和欠拟合
过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...
- [深度学习] tf.keras入门5-模型保存和载入
目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...
- [深度学习] tf.keras入门2-分类
目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...
- 深度学习:Keras入门(一)之基础篇
1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...
- 深度学习:Keras入门(一)之基础篇【转】
本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...
- 深度学习:Keras入门(一)之基础篇(转)
转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)
说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】
本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...
随机推荐
- Hive之命令
Hive之命令 说明:此博客只记录了一些常见的hql,create/select/insert/update/delete这些基础操作是没有记录的. 一.时间级 select day -- 时间 ,d ...
- JavaScript基本语法(数组与JSON)
5.数组 #①使用new关键字创建数组 // 1.创建数组对象 var arr01 = new Array(); // 2.压入数据 arr01.push("apple"); ar ...
- linux 自动备份mysql数据库
今天一早打开服务器.13W个木马.被爆破成功2次,漏洞3个.数据库被删.这是个悲伤的经历 还好之前有备份,服务器也升级了安全机制,只是备份是上个月的备份.所以想写个脚本,试试自动备份数据库. 1. 先 ...
- C#中下载项目中的文件
1.将需要下载的文档添加到项目的文件夹中 2.接口部分 public IActionResult DownLoad() { var filePath = Directory.GetCurrentDir ...
- while循环条件不成立却无法跳出死循环的问题
在进入循环的时候,实际上是将A从内存加载到寄存器里面运行的,在整个循环中,A这个变量都只是在读取寄存器里面的值. 而当进入中断的时候,中断里面会从内存加载A到寄存器,修改完之后又存到内存里,然后退出中 ...
- CSS布局秘籍(1)-任督二脉BFC/IFC
01.CSS布局 1.1.正常布局流(Normal flow) 正常布局流 就是不做任何布局控制,按照HTML的顺序(从左到右,从上而下)进行布局排列.网页基于盒子模型进行正常的布局,主要特点: 盒子 ...
- 嵌入式-C语言基础:malloc动态开辟内存空间
#include<stdio.h> #include<stdlib.h> int main() { // char *p;//定义一个野指针:没有让它指向一个变量的地址 // ...
- day27 CSS浮动、溢出 & js基本语法 & DOM文档流操作
接day26CSS=>CSS定位 overflow属性 值 描述 示例 visible 默认值,内容不会被修剪,会呈现在元素框之外 hidden 内容会被修剪,并且其余内容是不可见的 overf ...
- 【文档资料】Linux、Vi/Vim常用命令、文件夹和文件介绍
一.Linux 1.系统信息[左1] 查看磁盘空间使用情况:df+参数 查看当前指定文件或目录的大小:du 查看不同硬件信息:cat/proc/xxx 查看系统和空闲内存:free +参数 SSH退出 ...
- 【消息队列面试】11-14:kafka高可靠、高吞吐量、消息丢失、消费模式
十一.kafka消息高可靠的解决方案 1.高可靠=避免消息丢失 解决消息丢失的问题 2.如何解决 (1)保证消息发送是可靠的(发成功了/落到partition) a.ack参数 发送端,采用ack机制 ...