接上一部分,此篇将用tensorflow建立神经网络,对波士顿房价数据进行简单建模预测。

二、使用tensorflow拟合boston房价datasets

1、数据处理依然利用sklearn来分训练集和测试集。

2、使用一层隐藏层的简单网络,试下来用当前这组超参数收敛较快,准确率也可以。

3、激活函数使用relu来引入非线性因子。

4、原本想使用如下方式来动态更新lr,但是尝试下来效果不明显,就索性不要了。

def learning_rate(epoch):
if epoch < 200:
return 0.01
if epoch < 400:
return 0.001
if epoch < 800:
return 1e-4

好了,废话不多说了,看代码如下:

from sklearn import datasets
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf dataset = datasets.load_boston()
x = dataset.data
target = dataset.target
y = np.reshape(target,(len(target), 1)) x_train, x_verify, y_train, y_verify = train_test_split(x, y, random_state=1)
y_train = y_train.reshape(-1)
train_data = np.insert(x_train, 0, values=y_train, axis=1) def r_square(y_verify, y_pred):
var = np.var(y_verify)
mse = np.sum(np.power((y_verify-y_pred.reshape(-1,1)), 2))/len(y_verify)
res = 1 - mse/var
print('var:', var)
print('MSE-ljj:', mse)
print('R2-ljj:', res) EPOCH = 3000
lr = tf.placeholder(tf.float32, [], 'lr')
x = tf.placeholder(tf.float32, shape=[None, 13], name='input_feature_x')
y = tf.placeholder(tf.float32, shape=[None, 1], name='input_feature_y') W = tf.Variable(tf.truncated_normal(shape=[13, 10], stddev=0.1))
b = tf.Variable(tf.constant(0., shape=[10])) W2 = tf.Variable(tf.truncated_normal(shape=[10, 1], stddev=0.1))
b2 = tf.Variable(tf.constant(0., shape=[1])) with tf.Session() as sess:
hidden1 = tf.nn.relu(tf.add(tf.matmul(x, W), b)) y_predict = tf.add(tf.matmul(hidden1, W2), b2)
loss = tf.reduce_mean(tf.reduce_sum(tf.pow(y-y_predict,2), reduction_indices=[1]))
print(loss.shape)
train = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss) sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
W_res = 0
b_res = 0
try:
last_chk_path = tf.train.latest_checkpoint(checkpoint_dir='/home/ljj/PycharmProjects/mooc/train_record')
saver.restore(sess, save_path=last_chk_path)
except:
print('no save file to recover-----------start new train instead--------') loss_list = []
over_flag = 0
for i in range(EPOCH):
if over_flag ==1:
break
y_t = train_data[:, 0].reshape(-1, 1)
_, W_res, b_res, loss_train = sess.run([train, W, b, loss],
feed_dict={x: train_data[:, 1:],
y: y_t,
lr: 0.01}) checkpoint_file = os.path.join('/home/ljj/PycharmProjects/mooc/train_record', 'checkpoint')
saver.save(sess, checkpoint_file, global_step=i)
loss_list.append(loss_train)
if loss_train < 0.2:
over_flag = 1
break
if i %500 == 0:
print('EPOCH = {:}, train_loss ={:}'.format(i, loss_train))
if i % 500 == 0:
r = loss.eval(session=sess, feed_dict={x: x_verify,
y: y_verify,
lr: 0.01})
print('verify_loss = ',r)
np.random.shuffle(train_data) plt.plot(range(len(loss_list)-1), loss_list[1:], 'r')
plt.show() print('final loss = ',loss.eval(session=sess, feed_dict={x: x_verify,
y: y_verify,
lr: 0.01})) y_pred = sess.run(y_predict, feed_dict={x: x_verify,
y: y_verify,
lr: 0.01}) plt.subplot(2,1,1)
plt.xlim([0,50])
plt.plot(range(len(y_verify)), y_pred,'b--')
plt.plot(range(len(y_verify)), y_verify,'r')
plt.title('validation') y_ss = sess.run(y_predict, feed_dict={x: x_train,
y: y_train.reshape(-1, 1),
lr: 0.01})
plt.subplot(2,1,2)
plt.xlim([0,50])
plt.plot(range(len(y_train)), y_ss,'r--')
plt.plot(range(len(y_train)), y_train,'b')
plt.title('train') plt.savefig('tf.png')
plt.show() r_square(y_verify, y_pred)

训练了大概3000个epoch后,保存模型,之后可以多次训练,但是loss基本收敛了,没有太大变化。

输出结果如下:

final loss =  15.117827
var: 99.0584735569471
MSE-ljj: 15.11782691349897
R2-ljj: 0.8473848185757882

从图像上看,拟合效果也是一般,再拿一个放大版本的validation图,同样取前50个样本,这样方便和之前的线性回归模型对比。

最后我们还是用数据来说明:

tf模型结果中,

R2:0.847   > 0. 779

MSE:15.1  < 21.8

都比sklearn的线性回归结果要好。所以,此tf模型对波士顿房价数据的可解释性更强。

《用Python玩转数据》项目—线性回归分析入门之波士顿房价预测(二)的更多相关文章

  1. 用Python玩转数据第六周——高级数据处理与可视化

    1.matplotlib中有两个模块,pyplot和pylab import matplotlib.pyplot as plt  ///plt.plot(x,y) import pylab as pl ...

  2. 用Python玩转数据——第五周数据统计和可视化

    一.数据获取 1.本地数据 with 语句,pd.read_csv('data.csv') 2.网站上数据 2.1 直接获取网页源码,在用正则表达式进行删选 2.2 API接口获取---以豆瓣为例 i ...

  3. (转载)微软数据挖掘算法:Microsoft 线性回归分析算法(11)

    前言 此篇为微软系列挖掘算法的最后一篇了,完整该篇之后,微软在商业智能这块提供的一系列挖掘算法我们就算总结完成了,在此系列中涵盖了微软在商业智能(BI)模块系统所能提供的所有挖掘算法,当然此框架完全可 ...

  4. Python之机器学习-波斯顿房价预测

    目录 波士顿房价预测 导入模块 获取数据 打印数据 特征选择 散点图矩阵 关联矩阵 训练模型 可视化 波士顿房价预测 导入模块 import pandas as pd import numpy as ...

  5. 使用sklearn进行数据挖掘-房价预测(4)—数据预处理

    在使用机器算法之前,我们先把数据做下预处理,先把特征和标签拆分出来 housing = strat_train_set.drop("median_house_value",axis ...

  6. 使用sklearn进行数据挖掘-房价预测(3)—绘制数据的分布

    使用sklearn进行数据挖掘系列文章: 1.使用sklearn进行数据挖掘-房价预测(1) 2.使用sklearn进行数据挖掘-房价预测(2)-划分测试集 3.使用sklearn进行数据挖掘-房价预 ...

  7. $用python玩点有趣的数据分析——一元线性回归分析实例

    Refer:http://python.jobbole.com/81215/ 本文参考了博乐在线的这篇文章,在其基础上加了一些自己的理解.其原文是一篇英文的博客,讲的通俗易懂. 本文通过一个简单的例子 ...

  8. 2.3 Hive的数据类型讲解及实际项目中如何使用python脚本对数据进行ETL

    一.hive Data Types https://cwiki. apache. org/confluence/display/HiveLanguageManual+Types Numeric Typ ...

  9. Python即时网络爬虫项目启动说明

    作为酷爱编程的老程序员,实在按耐不下这个冲动,Python真的是太火了,不断撩拨我的心. 我是对Python存有戒备之心的,想当年我基于Drupal做的系统,使用php语言,当语言升级了,推翻了老版本 ...

随机推荐

  1. vscode 创建.net core项目初体验

    微软的virtual studio编辑器那是宇宙第一大编辑器,可惜就是太笨重,遇到性能差一些的电脑设备,简直无法快速的编辑项目. 而vs code编辑器轻便易用,想要编辑哪种项目,只需扩展插件就OK, ...

  2. Shadow Properties之美(二)【Microsoft Entity Framework Core随笔】

    接着上一篇Shadow Properties之美(一),我们来继续举一个有点啰嗦的栗子. 先看简单需求:某HR系统,需要记录员工资料.需要记录的资料有: 员工号(规则:分公司所在城市拼音首字母,加上三 ...

  3. 在win中,给powershell客户端,搭建sshd服务器。

    下载:https://github.com/PowerShell/Win32-OpenSSH/releases     问:为什么要用这个sshd?答:这是微软用,openssh官方的源码,源码网址: ...

  4. Problem 3: Largest prime factor

    The prime factors of 13195 are 5, 7, 13 and 29. What is the largest prime factor of the number 60085 ...

  5. Java继承2

    1.为什么使用继承 从已有的类派生出新的类,称为继承. 在不同的类中也可能会有共同的特征和动作,可以把这些共同的特征和动作放在一个类中,让其它类共享. 因此可以定义一个通用类,然后将其扩展为其它多个特 ...

  6. python单列模式

    单例模式:就是永远用一个对象的实例 初级版 #初级版 class Foo(object): instance=None def __init__(self): pass @classmethod # ...

  7. dubbo入门学习笔记之环境准备

    粗略的学完springcloud后由于公司的项目有用到一点dubbo,刚好手头上又有dubbo的学习资料,于是趁机相对系统的学了下duboo框架,今天开始记录下我的所学所悟;说来惭愧,今年之前,作为一 ...

  8. FCC-js算法题解题笔记

    题目链接:https://learn.freecodecamp.org/javascript-algorithms-and-data-structures/intermediate-algorithm ...

  9. RobotFramework Selenium2 关键字

    *** Settings ***Library Selenium2Library *** Keywords ***Checkbox应该不被选择 [Arguments] ${locator} Check ...

  10. better-scroll

    better-scroll会将默认事件阻止掉,如果自己写的部分需要有点击事件,需要在参数里加上click:true. 同时,在PC上或某些手机端,由于未成功将touchend事件move掉,点击事件会 ...