训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步。

比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要根据训练好的房价模型来预测用户房子的价格。

这样就需要在训练模型后把模型保存起来,在使用模型时把模型读取出来对输入的数据进行预测。

这里保存和读取模型有两种方法,都非常简单,差别在于保存和读取速度的快慢上,因为有一个是利用了多进程机制,下面我们分别来看一下。

创建模型

首先我们创建模型并训练数据:

from sklearn.datasets import load_digits
from sklearn.svm import SVC # 加载数据
digits = load_digits()
X = digits.data
y = digits.target model = SVC()
model.fit(X, y)

用pickle读写模型

pickle是python中用于数据序列化的模块,因此,对于模型的序列化也可以用此模块来进行:

import pickle
# 以写二进制的方式打开文件
file = open("D:/data/python/model.pickle", "wb")
# 把模型写入到文件中
pickle.dump(model, file)
# 关闭文件
file.close()

这样会创建D:/data/python/model.pickle的文件,大家可以自己去尝试下看看,我这边生成的文件大概1M左右。

有了模型文件之后,在进行预测时我们就不需要进行训练了,而只要把这个训练好的模型文件读取出来,然后直接进行预测就可以:

import pickle
# 以读二进制的方式打开文件
file = open("D:/data/python/model.pickle", "rb")
# 把模型从文件中读取出来
model = pickle.load(file)
# 关闭文件
file.close() # 用模型进行预测
from sklearn.datasets import load_digits
digits = load_digits()
X = digits.data
y = digits.target print("预测值:", model.predict(X[15:20]))
print("实际值:", y[15:20])

输出为:

预测值: [5 6 7 8 9]
实际值: [5 6 7 8 9]

用joblib进行模型的读写

直接上代码:

from sklearn.datasets import load_digits
from sklearn.svm import SVC # 用模型进行训练
digits = load_digits()
X = digits.data
y = digits.target
model = SVC()
model.fit(X, y) # 用joblib保存模型
from sklearn.externals import joblib
joblib.dump(model, "D:/data/python/model.joblib")

这样就会生成D:/data/python/model.joblib文件,看起来比pickle生成的文件大一点点。

读取模型:

# 用joblib读取模型
from sklearn.externals import joblib
model = joblib.load("D:/data/python/model.joblib") # 对数据进行预测
from sklearn.datasets import load_digits
digits = load_digits()
X = digits.data
y = digits.target print("预测值:", model.predict(X[15:20]))
print("实际值:", y[15:20])

输出为:

预测值: [5 6 7 8 9]
实际值: [5 6 7 8 9]

看起来也很简单,同pickle的区别是joblib会以多进程方式来进行,据说性能会好些。

sklearn保存模型-【老鱼学sklearn】的更多相关文章

  1. sklearn标准化-【老鱼学sklearn】

    在前面的一篇博文中关于计算房价中我们也大致提到了标准化的概念,也就是比如对于影响房价的参数中有面积和户型,面积的取值范围可以很广,它可以从0-500平米,而户型一般也就1-5. 标准化就是要把这两种参 ...

  2. sklearn数据库-【老鱼学sklearn】

    在做机器学习时需要有数据进行训练,幸好sklearn提供了很多已经标注好的数据集供我们进行训练. 本节就来看看sklearn提供了哪些可供训练的数据集. 这些数据位于datasets中,网址为:htt ...

  3. sklearn交叉验证-【老鱼学sklearn】

    交叉验证(Cross validation),有时亦称循环估计, 是一种统计学上将数据样本切割成较小子集的实用方法.于是可以先在一个子集上做分析, 而其它子集则用来做后续对此分析的确认及验证. 一开始 ...

  4. sklearn模型的属性与功能-【老鱼学sklearn】

    本节主要讲述模型中的各种属性及其含义. 例如上个博文中,我们有用线性回归模型来拟合房价. # 创建线性回归模型 model = LinearRegression() # 训练模型 model.fit( ...

  5. sklearn交叉验证2-【老鱼学sklearn】

    过拟合 过拟合相当于一个人只会读书,却不知如何利用知识进行变通. 相当于他把考试题目背得滚瓜烂熟,但一旦环境稍微有些变化,就死得很惨. 从图形上看,类似下图的最右图: 从数学公式上来看,这个曲线应该是 ...

  6. sklearn交叉验证3-【老鱼学sklearn】

    在上一个博文中,我们用learning_curve函数来确定应该拥有多少的训练集能够达到效果,就像一个人进行学习时需要做多少题目就能拥有较好的考试成绩了. 本次我们来看下如何调整学习中的参数,类似一个 ...

  7. tensorflow卷积神经网络-【老鱼学tensorflow】

    前面我们曾有篇文章中提到过关于用tensorflow训练手写2828像素点的数字的识别,在那篇文章中我们把手写数字图像直接碾压成了一个784列的数据进行识别,但实际上,这个图像是2828长宽结构的,我 ...

  8. 二分类问题续 - 【老鱼学tensorflow2】

    前面我们针对电影评论编写了二分类问题的解决方案. 这里对前面的这个方案进行一些改进. 分批训练 model.fit(x_train, y_train, epochs=20, batch_size=51 ...

  9. 转sklearn保存模型

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

随机推荐

  1. 基于vue现有项目的服务器端渲染SSR改造

    前面的话 不论是官网教程,还是官方DEMO,都是从0开始的服务端渲染配置.对于现有项目的服务器端渲染SSR改造,特别是基于vue cli生成的项目,没有特别提及.本文就小火柴的前端小站这个前台项目进行 ...

  2. Magento 2 安装数据表

    Magento 2 安装数据表 第1步:安装脚本 首先,我们将为CRUD模型创建数据库表.为此,我们需要插入安装文件 app/code/Mageplaza/HelloWorld/Setup/Insta ...

  3. IP地址及网络常识

    一.IP 互联网网络协议(internret protocol address ,IP),IP地址是IP协议提供的一种统一的标准化的地址格式,它会为互联网中的每个网络和每台主机备提供一个逻辑地址,来区 ...

  4. [hashcat]基于字典和暴力破解尝试找到rar3-hp的压缩包密码

    1.使用rar2john找到md5 2.基于字典 hashcat -a 0 -m 12500 /root/Desktop/md5.txt /usr/share/wordlists/weakpass.t ...

  5. java java.net.URLConnection 实现http get,post

    抽象类 URLConnection 是所有类的超类,它代表应用程序和 URL 之间的通信链接.此类的实例可用于读取和写入此 URL 引用的资源.通常,创建一个到 URL 的连接需要几个步骤: open ...

  6. 微服务与容器化Docker

    1.Docker的应用案例 2. 3. 4.docker的核心:镜像.仓库.容器 Build构建镜像:类似于集装箱. Ship运输镜像,仓库:类似于码头.将镜像运输到仓库. Run运行镜像:容器:类似 ...

  7. python httpserver

    python3: python -m http.server 80 python2: python -m SimpleHTTPServer 9004

  8. dom4j,json,pattern性能对比【原】

    报文大概2000字节,对比时为只取其中某个节点的值即可. 以下对比可知取少量节点时pattern性能是远大于dom4j,和json的, 但取大量的时候就不能这么以偏概全了. dom4j和pattern ...

  9. 类型和原生函数及类型转换(二:终结js类型判断)

    typeof instanceof isArray() Object.prototype.toString.call() DOM对象与DOM集合对象的类型判断 一.typeof typeof是一个一元 ...

  10. Spring Security 之方法级的安全管控

    默认情况下, Spring Security 并不启用方法级的安全管控. 启用方法级的管控后, 可以针对不同的方法通过注解设置不同的访问条件. Spring Security 支持三种方法级注解, 分 ...