《用Python玩转数据》项目—线性回归分析入门之波士顿房价预测(二)
接上一部分,此篇将用tensorflow建立神经网络,对波士顿房价数据进行简单建模预测。
二、使用tensorflow拟合boston房价datasets
1、数据处理依然利用sklearn来分训练集和测试集。
2、使用一层隐藏层的简单网络,试下来用当前这组超参数收敛较快,准确率也可以。
3、激活函数使用relu来引入非线性因子。
4、原本想使用如下方式来动态更新lr,但是尝试下来效果不明显,就索性不要了。
def learning_rate(epoch):
if epoch < 200:
return 0.01
if epoch < 400:
return 0.001
if epoch < 800:
return 1e-4
好了,废话不多说了,看代码如下:
from sklearn import datasets
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf dataset = datasets.load_boston()
x = dataset.data
target = dataset.target
y = np.reshape(target,(len(target), 1)) x_train, x_verify, y_train, y_verify = train_test_split(x, y, random_state=1)
y_train = y_train.reshape(-1)
train_data = np.insert(x_train, 0, values=y_train, axis=1) def r_square(y_verify, y_pred):
var = np.var(y_verify)
mse = np.sum(np.power((y_verify-y_pred.reshape(-1,1)), 2))/len(y_verify)
res = 1 - mse/var
print('var:', var)
print('MSE-ljj:', mse)
print('R2-ljj:', res) EPOCH = 3000
lr = tf.placeholder(tf.float32, [], 'lr')
x = tf.placeholder(tf.float32, shape=[None, 13], name='input_feature_x')
y = tf.placeholder(tf.float32, shape=[None, 1], name='input_feature_y') W = tf.Variable(tf.truncated_normal(shape=[13, 10], stddev=0.1))
b = tf.Variable(tf.constant(0., shape=[10])) W2 = tf.Variable(tf.truncated_normal(shape=[10, 1], stddev=0.1))
b2 = tf.Variable(tf.constant(0., shape=[1])) with tf.Session() as sess:
hidden1 = tf.nn.relu(tf.add(tf.matmul(x, W), b)) y_predict = tf.add(tf.matmul(hidden1, W2), b2)
loss = tf.reduce_mean(tf.reduce_sum(tf.pow(y-y_predict,2), reduction_indices=[1]))
print(loss.shape)
train = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss) sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
W_res = 0
b_res = 0
try:
last_chk_path = tf.train.latest_checkpoint(checkpoint_dir='/home/ljj/PycharmProjects/mooc/train_record')
saver.restore(sess, save_path=last_chk_path)
except:
print('no save file to recover-----------start new train instead--------') loss_list = []
over_flag = 0
for i in range(EPOCH):
if over_flag ==1:
break
y_t = train_data[:, 0].reshape(-1, 1)
_, W_res, b_res, loss_train = sess.run([train, W, b, loss],
feed_dict={x: train_data[:, 1:],
y: y_t,
lr: 0.01}) checkpoint_file = os.path.join('/home/ljj/PycharmProjects/mooc/train_record', 'checkpoint')
saver.save(sess, checkpoint_file, global_step=i)
loss_list.append(loss_train)
if loss_train < 0.2:
over_flag = 1
break
if i %500 == 0:
print('EPOCH = {:}, train_loss ={:}'.format(i, loss_train))
if i % 500 == 0:
r = loss.eval(session=sess, feed_dict={x: x_verify,
y: y_verify,
lr: 0.01})
print('verify_loss = ',r)
np.random.shuffle(train_data) plt.plot(range(len(loss_list)-1), loss_list[1:], 'r')
plt.show() print('final loss = ',loss.eval(session=sess, feed_dict={x: x_verify,
y: y_verify,
lr: 0.01})) y_pred = sess.run(y_predict, feed_dict={x: x_verify,
y: y_verify,
lr: 0.01}) plt.subplot(2,1,1)
plt.xlim([0,50])
plt.plot(range(len(y_verify)), y_pred,'b--')
plt.plot(range(len(y_verify)), y_verify,'r')
plt.title('validation') y_ss = sess.run(y_predict, feed_dict={x: x_train,
y: y_train.reshape(-1, 1),
lr: 0.01})
plt.subplot(2,1,2)
plt.xlim([0,50])
plt.plot(range(len(y_train)), y_ss,'r--')
plt.plot(range(len(y_train)), y_train,'b')
plt.title('train') plt.savefig('tf.png')
plt.show() r_square(y_verify, y_pred)
训练了大概3000个epoch后,保存模型,之后可以多次训练,但是loss基本收敛了,没有太大变化。
输出结果如下:
final loss = 15.117827
var: 99.0584735569471
MSE-ljj: 15.11782691349897
R2-ljj: 0.8473848185757882

从图像上看,拟合效果也是一般,再拿一个放大版本的validation图,同样取前50个样本,这样方便和之前的线性回归模型对比。


最后我们还是用数据来说明:
tf模型结果中,
R2:0.847 > 0. 779
MSE:15.1 < 21.8
都比sklearn的线性回归结果要好。所以,此tf模型对波士顿房价数据的可解释性更强。
《用Python玩转数据》项目—线性回归分析入门之波士顿房价预测(二)的更多相关文章
- 用Python玩转数据第六周——高级数据处理与可视化
1.matplotlib中有两个模块,pyplot和pylab import matplotlib.pyplot as plt ///plt.plot(x,y) import pylab as pl ...
- 用Python玩转数据——第五周数据统计和可视化
一.数据获取 1.本地数据 with 语句,pd.read_csv('data.csv') 2.网站上数据 2.1 直接获取网页源码,在用正则表达式进行删选 2.2 API接口获取---以豆瓣为例 i ...
- (转载)微软数据挖掘算法:Microsoft 线性回归分析算法(11)
前言 此篇为微软系列挖掘算法的最后一篇了,完整该篇之后,微软在商业智能这块提供的一系列挖掘算法我们就算总结完成了,在此系列中涵盖了微软在商业智能(BI)模块系统所能提供的所有挖掘算法,当然此框架完全可 ...
- Python之机器学习-波斯顿房价预测
目录 波士顿房价预测 导入模块 获取数据 打印数据 特征选择 散点图矩阵 关联矩阵 训练模型 可视化 波士顿房价预测 导入模块 import pandas as pd import numpy as ...
- 使用sklearn进行数据挖掘-房价预测(4)—数据预处理
在使用机器算法之前,我们先把数据做下预处理,先把特征和标签拆分出来 housing = strat_train_set.drop("median_house_value",axis ...
- 使用sklearn进行数据挖掘-房价预测(3)—绘制数据的分布
使用sklearn进行数据挖掘系列文章: 1.使用sklearn进行数据挖掘-房价预测(1) 2.使用sklearn进行数据挖掘-房价预测(2)-划分测试集 3.使用sklearn进行数据挖掘-房价预 ...
- $用python玩点有趣的数据分析——一元线性回归分析实例
Refer:http://python.jobbole.com/81215/ 本文参考了博乐在线的这篇文章,在其基础上加了一些自己的理解.其原文是一篇英文的博客,讲的通俗易懂. 本文通过一个简单的例子 ...
- 2.3 Hive的数据类型讲解及实际项目中如何使用python脚本对数据进行ETL
一.hive Data Types https://cwiki. apache. org/confluence/display/HiveLanguageManual+Types Numeric Typ ...
- Python即时网络爬虫项目启动说明
作为酷爱编程的老程序员,实在按耐不下这个冲动,Python真的是太火了,不断撩拨我的心. 我是对Python存有戒备之心的,想当年我基于Drupal做的系统,使用php语言,当语言升级了,推翻了老版本 ...
随机推荐
- 【转载】IP地址和子网划分学习笔记之《子网掩码详解》
原文地址: https://blog.51cto.com/6930123/2112748 一.子网掩码 IP地址是以网络号和主机号来标示网络上的主机的,我们把网络号相同的主机称之为本地网络,网络号不相 ...
- java多线程面试中常见知识点
1.进程和线程 (1)进程是资源分配的最小单位,线程是程序执行的最小单位. (2)进程有自己的独立地址空间,每启动一个进程,系统就会为它分配地址空间,建立数据表来维护代码段.堆栈段和数据段,这种操作非 ...
- git add.后回退 代码丢失
记录一次操作git丢失代码的过程: 写完代码后:git staus git add. git status 发现有一堆.class 文件不想提交,想着代码回退到add 之前,使用了 git log 开 ...
- throw与throws
throws可以单独使用(一直上抛) throw要么和try-catch-finally语句配套使用,要么与throws配套使用 /** * 总结: * 1.throws是方法抛出异常.如: p ...
- nginx+uWSGI+django+virtualenv+supervisor发布web服务器
nginx+uWSGI+django+virtualenv+supervisor发布web服务器 导论 WSGI是Web服务器网关接口.它是一个规范,描述了Web服务器如何与Web应用程序通信,以 ...
- centos7搭建vsftpd并启用虚拟用户
虚拟用户的特点是只能访问服务器为其提供的FTP服务,不能访问系统的其它资源,所以,如果想让用户对FTP服务器站内具有写权限,但又不允许访问系统其他资源,可以使用虚拟用户来提高系统的安全性. 在vsft ...
- HFun.快速开发平台(五)=》自定义系统数据选择
本篇介绍HFun.快速开发平台的另一项系统常用功能:系统数据或参数选择,主要应用在表单录入中信息的选择,如类别,编号等.先贴出本系统实现的页面效果: 如上图所示,系统中将参数的选择统一展现为该方式,开 ...
- Django基础-02
django的介绍: Django 中提供了开发网站经常用到的模块,常见的代码都为你写好了,通过减少重复的代码,Django 使你能够专注于 web 应用上有 趣的关键性的东西.为了达到这个目标,Dj ...
- 一个简单的Quartz定时任务
package com.shuadan.quartz; import org.springframework.scheduling.annotation.Scheduled; import org.s ...
- 声明一个set集合,使用HashSet类,来保存十个字符串信息,然后通过这个集合,然后使用iterator()方法,得到一个迭代器,遍历所有的集合中所有的字符串;然后拿出所有的字符串拼接到一个StringBuffer对象中,然后输出它的长度和具体内容; 验证集合的remove()、size()、contains()、isEmpty()等
package com.lanxi.demo1_3; import java.util.HashSet; import java.util.Iterator; import java.util.Set ...