在机器学习中,当确定好一个模型后,我们需要将它保存下来,这样当新数据出现时,我们能够调出这个模型来对新数据进行预测。同时这些新数据将被作为历史数据保存起来,经过一段周期后,使用更新的历史数据再次训练,得到更新的模型。

如果模型的流转都在python内部,那么可以使用内置的pickle库来完成模型的存储和调取。

什么是pickle?pickle是负责将python对象序列化(serialization)和反序列化(de-serialization)的模块。pickle模块可以读入任何python对象,然后将它们转换成字符串,我们再使用dump函数将其储存到文件中,这个过程叫做pickling;反之从文件中提取原始python对象的过程叫做unpickling。

picke.dump() --- 将训练好的模型保存在磁盘上

with open(file_name, 'wb') as file:
pickle.dump(model, file)

pickle.load() --- 读取保存在磁盘上的模型

with open(file_name, 'rb') as file:
model=pickle.load(file)

以线性回归模型为例:

import numpy as np

class Linear_Regression:
def __init__(self):
self._w = None def fit(self, X, y, lr=0.01, epsilon=0.01, epoch=1000):
#训练数据
#将输入的X,y转换为numpy数组
X, y = np.asarray(X, np.float32), np.asarray(y, np.float32)
#给X增加一列常数项
X=np.hstack((X,np.ones((X.shape[0],1))))
#初始化w
self._w = np.zeros((X.shape[1],1)) for _ in range(epoch):
#随机选择一组样本计算梯度
random_num=np.random.choice(len(X))
x_random=X[random_num].reshape(1,2)
y_random=y[random_num] gradient=(x_random.T)*(np.dot(x_random,self._w)-y_random) #如果收敛,那么停止迭代
if (np.abs(self._w-lr*gradient)<epsilon).all():
break
#否则,更新w
else:
self._w =self._w-lr*gradient return self._w def print_results(self):
print("参数w:{}".format(self._w))
print("回归拟合线:y={}x+{}".format(self._w[0],self._w[1])) def predict(self,x):
x=np.asarray(x, np.float32)
x=x.reshape(x.shape[0],1)
x=np.hstack((x,np.ones((x.shape[0],1))))
return np.dot(x,self._w)

训练并保存模型:

import pickle

#创建数据
x=np.linspace(0,100,10).reshape(10,1)
rng=np.random.RandomState(4)
noise=rng.randint(-10,10,size=(10,1))*4
y=4*x+4+noise model=Linear_Regression()
model.fit(x,y,lr=0.0001,epsilon=0.001,epoch=20) with open('model.pickle', 'wb') as file:
pickle.dump(model, file)

然后调取模型并进行预测和打印结果:

with open('model.pickle', 'rb') as file:
model=pickle.load(file)
print(model.predict([50]))
model.print_results()

输出:

[[208.73892002]]
参数w:[[4.17372929]
[0.05245564]]
回归拟合线:y=[4.17372929]x+[0.05245564]

model是保存在磁盘上的一个python对象:

<__main__.Linear_Regression object at 0x0000009FA44B2F98>

用pickle保存机器学习模型的更多相关文章

  1. 使用Flask构建机器学习模型API

    1. Python环境设置和Flask基础 使用"Anaconda"创建一个虚拟环境.如果你需要在Python中创建你的工作流程,并将依赖项分离出来,或者共享环境设置," ...

  2. (sklearn)机器学习模型的保存与加载

    需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...

  3. scikit-learn系列之如何存储和导入机器学习模型

    scikit-learn系列之如何存储和导入机器学习模型   如何存储和导入机器学习模型 找到一个准确的机器学习模型,你的项目并没有完成.本文中你将学习如何使用scikit-learn来存储和导入机器 ...

  4. 使用pmml实现跨平台部署机器学习模型

    一.概述   对于由Python训练的机器学习模型,通常有pickle和pmml两种部署方式,pickle方式用于在python环境中的部署,pmml方式用于跨平台(如Java环境)的部署,本文叙述的 ...

  5. Python 3 利用机器学习模型 进行手写体数字识别

    0.引言 介绍了如何生成数据,提取特征,利用sklearn的几种机器学习模型建模,进行手写体数字1-9识别. 用到的四种模型: 1. LR回归模型,Logistic Regression 2. SGD ...

  6. 使用ML.NET + ASP.NET Core + Docker + Azure Container Instances部署.NET机器学习模型

    本文将使用ML.NET创建机器学习分类模型,通过ASP.NET Core Web API公开它,将其打包到Docker容器中,并通过Azure Container Instances将其部署到云中. ...

  7. tensorflow机器学习模型的跨平台上线

    在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法 ...

  8. 用PMML实现机器学习模型的跨平台上线

    在机器学习用于产品的时候,我们经常会遇到跨平台的问题.比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环 ...

  9. 为你的机器学习模型创建API服务

    1. 什么是API 当调包侠们训练好一个模型后,下一步要做的就是与业务开发组同学们进行代码对接,以便这些‘AI大脑’们可以顺利的被使用.然而往往要面临不同编程语言的挑战,例如很常见的是调包侠们用Pyt ...

随机推荐

  1. 相似文档查找算法之 simHash及其 java 实现

    传统的 hash 算法只负责将原始内容尽量均匀随机地映射为一个签名值,原理上相当于伪随机数产生算法.产生的两个签名,如果相等,说明原始内容在一定概 率 下是相等的:如果不相等,除了说明原始内容不相等外 ...

  2. docker 安装redis mysql rabbitmq

    docker redis mysql rabbitmq 基本命令 安装redis 安装mysql 安装rabbitmq 基本命令 命令格式: docker 命令 [镜像/容器]名字 常用命令: sea ...

  3. 2019 UCloudjava面试笔试题 (含面试题解析)

      本人5年开发经验.18年年底开始跑路找工作,在互联网寒冬下成功拿到阿里巴巴.今日头条.UCloud等公司offer,岗位是Java后端开发,因为发展原因最终选择去了UCloud,入职一年时间了,也 ...

  4. 供应链管理如何提高效率?APS系统成优化引擎

    APS系统,虽然它的起兴只有短短的十几年,但是在这段时间里面,它为很多企业解决了很多人工手动.脑力不可解决的问题. 所以APS被誉为供应链优化引擎,APS常常被称为高级计划与排程,但也有称为高级计划系 ...

  5. 下载Spring

    下载Spring Spring官网并不直接提供Spring的下载,Spring现在托管在GitHub上. 1.进入Spring官网 -> PROJECTS -> SPRING FRAMEW ...

  6. iOS学习——NSLog输出各种类型

    在开发过程中,在调试过程中经常打印不出自己想要的数据格式,还时常报警告,所以整理了一下iOS中用NSLog打印各种数据类型的样式.整型占位符说明 : %d : 十进制整数, 正数无符号, 负数有 “- ...

  7. 使用Prometheus监控Linux系统各项指标

    首先在Linux系统上安装一个探测器node explorer, 下载地址https://prometheus.io/docs/guides/node-exporter/ 这个探测器会定期将linux ...

  8. mysql replace into 实现存在则更新,不存在则插入

    测试用的mysql数据库: 新建测试表: CREATE TABLE `test` ( `id` bigint(20) NOT NULL AUTO_INCREMENT, `text` varchar(2 ...

  9. oracle删除表空间和用户

    步骤一: 删除tablespace(登录对应用户删除表空间) DROP TABLESPACE tablespace_name INCLUDING CONTENTS AND DATAFILES; 步骤二 ...

  10. Linux访问控制列表(Access Control List,简称ACL)

    Linux访问控制列表(Access Control List,简称ACL) 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.ACL概述 ACL:Access Control L ...