模型持久化(模型保存与加载)是机器学习完成的最后一步。
因为,在实际情况中,训练一个模型可能会非常耗时,如果每次需要使用模型时都要重新训练,这无疑会浪费大量的计算资源和时间。

通过将训练好的模型持久化到磁盘,我们可以在需要使用模型时直接从磁盘加载到内存,而无需重新训练。这样不仅可以节省时间,还可以提高模型的使用效率。

本篇介绍scikit-learn中几种常用的模型持久化方法。

1. 训练模型

首先,训练一个模型,这里用scikit-learn自带的手写数字数据集作为样本。

import matplotlib.pyplot as plt
from sklearn import datasets # 加载手写数据集
data = datasets.load_digits() # 调整数据格式
n_samples = len(data.images)
X = data.images.reshape((n_samples, -1))
y = data.target # 用支持向量机训练模型
from sklearn.svm import SVC # 定义
reg = SVC() # 训练模型
reg.fit(X, y)

最后的得到的 reg 就是我们训练之后的模型,使用这个模型,就可以预测一些手写数字图片。

但是这个 reg 是代码中的一个变量,如果不能保存下来,那么,每次需要使用的时候,
还要重新执行一次上面的模型训练代码,样本数据量大的话,每次重复训练会浪费大量时间和计算资源。

所以,要将上面的 reg 模型保存下来,下次使用的时候,直接加载,不用重新训练。

2. 模型持久化

2.1. pickle 序列化

pickle格式是python中常用的序列化方式,它通过将python对象及其所拥有的层次结构转化为一个字节流来实现序列化。

将上面的模型保存到磁盘文件model.pkl中。

import pickle

with open("./model.pkl", "wb") as f:
pickle.dump(reg, f)

需要使用模型时,从磁盘加载的方式:

with open("./model.pkl", "rb") as f:
reg_pkl = pickle.load(f)

验证加载之后的模型reg_pkl是否可以正常使用。

y_pred = reg_pkl.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot() plt.show()


从混淆矩阵来看,模型可以正常加载和使用。
关于混淆矩阵具体内容,可以参考:【scikit-learn基础】--『分类模型评估』之评估报告

2.2. joblib 序列化

相比于pickle,保存机器学习模型时,更推荐使用joblib
因为joblib针对大数据进行了优化,使其在处理大型数据集时性能更佳。

序列化的方式也很简单:

import joblib

joblib.dump(reg, "model.jlib")

从磁盘加载模型并验证:

reg_jlib = joblib.load("model.jlib")

y_pred = reg_jlib.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot() plt.show()

2.3. skops 格式

skops是比较新的一种格式,它是专门为了共享基于 scikit-learn 的模型而开发的。
目前还在积极的开发中,github上的地址是:github-skops

相比于picklejoblib,它提供了更加安全的序列化格式,
但使用上和它们差别不大。

import skops.io as sio

# 保存到文件 model.sio
sio.dump(reg, "model.sio")

从文件中读取模型并验证:

reg_sio = sio.load("model.sio")

y_pred = reg_jlib.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot() plt.show()

3. 总结

scikit-learn中,模型持久化是一个重要且实用的技术,它允许我们将训练好的模型保存到磁盘上,以便在不同的时间点或不同的环境中重新加载和使用。
通过模型持久化,我们能够避免每次需要使用时重新训练模型,从而节省大量的时间和计算资源。

本篇介绍的三种方法可以方便的序列化和反序列化模型对象,使其可以轻松地保存到磁盘上,并能够在需要时恢复出原始模型对象。

总而言之,模型持久化不仅使得我们能够在不同的运行会话之间重用模型,还方便了模型的共享和部署。

【scikit-learn基础】--模型持久化的更多相关文章

  1. (原创)(三)机器学习笔记之Scikit Learn的线性回归模型初探

    一.Scikit Learn中使用estimator三部曲 1. 构造estimator 2. 训练模型:fit 3. 利用模型进行预测:predict 二.模型评价 模型训练好后,度量模型拟合效果的 ...

  2. Scikit Learn: 在python中机器学习

    转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...

  3. scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类 (python代码)

    scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类数据集 fetch_20newsgroups #-*- coding: UTF-8 -*- import ...

  4. (原创)(四)机器学习笔记之Scikit Learn的Logistic回归初探

    目录 5.3 使用LogisticRegressionCV进行正则化的 Logistic Regression 参数调优 一.Scikit Learn中有关logistics回归函数的介绍 1. 交叉 ...

  5. tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...

  6. [Tensorflow]模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    文章目录 [Tensorflow]模型持久化的原理,将CKPT转为pb文件,使用pb模型预测 一.模型持久化 1.持久化代码实现 convert_variables_to_constants固化模型结 ...

  7. linux下bus、devices和platform的基础模型

    转自:http://blog.chinaunix.net/uid-20672257-id-3147337.html 一.kobject的定义:kobject是Linux2.6引入的设备管理机制,在内核 ...

  8. Query意图分析:记一次完整的机器学习过程(scikit learn library学习笔记)

    所谓学习问题,是指观察由n个样本组成的集合,并根据这些数据来预测未知数据的性质. 学习任务(一个二分类问题): 区分一个普通的互联网检索Query是否具有某个垂直领域的意图.假设现在有一个O2O领域的 ...

  9. ThinkPHP 学习笔记 ( 三 ) 数据库操作之数据表模型和基础模型 ( Model )

    //TP 恶补ing... 一.定义数据表模型 1.模型映射 要测试数据库是否正常连接,最直接的办法就是在当前控制器中实例化数据表,然后使用 dump 函数输出,查看数据库的链接状态.代码: publ ...

  10. Tensorflow 模型持久化saver及加载图结构

    主要内容: 1. 直接保存,加载模型; (可以指定加载,保存的var_list) 2. 加载,保存指定变量的模型 3. slim加载模型使用 4. 加载模型图结构和参数等 tensorflow 恢复部 ...

随机推荐

  1. springboot2.0+dubbo-spring-boot-starter聚合项目打可执行的jar包

    springboot2.0+dubbo聚合项目打可执行的jar包 springboot2.0+dubbo-spring-boot-starter项目服务方打包和以前老版本的dubbo打包方式不一样,不 ...

  2. Logback 实现日志链路追踪

    本文为博主原创,未经允许不得转载: 在开发过程中,经常会使用log记录一下当前请求的参数,过程和结果,以便帮助定位问题.在并发量下的情况下,日志打印不会剧增,可以很快就能通过打印的日志查看执行的情况. ...

  3. 基于python+django的旅游信息网站-旅游景点门票管理系统设计与实现

    该系统是基于python+django开发的旅游景点门票管理系统.是给师弟做的课程作业.大家学习过程中,遇到问题可以在github咨询作者 演示地址 前台地址: http://travel.gitap ...

  4. npm, yarn和pnpm清理缓存

    .markdown-body { line-height: 1.75; font-weight: 400; font-size: 16px; overflow-x: hidden; color: rg ...

  5. 【TouchGFX】使用CubeMX创建touchgfx项目 -- 初始篇

    1.系统构成,黑色块表示组件非必须 2.环境准备 CubeMX:6.0.1 touchgfx:4.15.0 rt-thread:2020-8-14 commit Keil:5.30 board:stm ...

  6. 2023第十四届极客大挑战 — CRYPTO(WP全)

    浅谈: 本次大挑战我们队伍也是取得了第一名的成绩,首先要感谢同伴的陪伴和帮助.在共同的努力下终不负期望! 但遗憾的是我们没有在某个方向全通关的,呜呜呜~ 继续努力吧!要学的还很多.明年有机会再战!!加 ...

  7. 【Spring 5核心原理】1设计模式

    1.1开闭原则 开闭原则(open-closed principle,OCP)是指一个软件实体(如类,模块和函数)应该对扩展开放,对修改关闭.所谓的开闭,也正是对扩展和修改两个行为的一个原则. 强调用 ...

  8. IDE-常用插件

    2021-8-25_IDE-常用插件 1. 背景 提升编写代码的舒适度,提升开发效率 2. 常用插件列表 IDE EVal Reset 白嫖付费的golang编辑器,reset插件可以重置golang ...

  9. [转帖]实用小技能:一键获取Harbor中镜像信息,快捷查询镜像

    [摘要]一键获取Harbor中的镜像列表,无需登录harbor UI,也可批量下载镜像到本地并保存为tar包.本文已参与「开源摘星计划」,欢迎正在阅读的你加入.活动链接:https://github. ...

  10. [转帖]docker-compose完全清除

    https://www.cnblogs.com/gelandesprung/p/12112420.html#:~:text=docker-compose%E5%AE%8C%E5%85%A8%E6%B8 ...