深度学习模型stacking模型融合python代码,看了你就会使
话不多说,直接上代码
def stacking_first(train, train_y, test):
savepath = './stack_op{}_dt{}_tfidf{}/'.format(args.option, args.data_type, args.tfidf)
os.makedirs(savepath, exist_ok=True) count_kflod = 0
num_folds = 6
kf = KFold(n_splits=num_folds, shuffle=True, random_state=10)
# 测试集上的预测结果
predict = np.zeros((test.shape[0], config.n_class))
# k折交叉验证集的预测结果
oof_predict = np.zeros((train.shape[0], config.n_class))
scores = []
f1s = [] for train_index, test_index in kf.split(train):
# 训练集划分为6折,每一折都要走一遍。那么第一个是5份的训练集索引,第二个是1份的测试集,此处为验证集是索引 kfold_X_train = {}
kfold_X_valid = {} # 取数据的标签
y_train, y_test = train_y[train_index], train_y[test_index]
# 取数据
kfold_X_train, kfold_X_valid = train[train_index], train[test_index] # 模型的前缀
model_prefix = savepath + 'DNN' + str(count_kflod)
if not os.path.exists(model_prefix):
os.mkdir(model_prefix) M = 4 # number of snapshots
alpha_zero = 1e-3 # initial learning rate
snap_epoch = 16
snapshot = SnapshotCallbackBuilder(snap_epoch, M, alpha_zero) # 使用训练集的size设定维度,fit一个模型出来
res_model = get_model(train)
res_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# res_model.fit(train_x, train_y, batch_size=BATCH_SIZE, epochs=EPOCH, verbose=1, class_weight=class_weight)
res_model.fit(kfold_X_train, y_train, batch_size=BATCH_SIZE, epochs=snap_epoch, verbose=1,
validation_data=(kfold_X_valid, y_test),
callbacks=snapshot.get_callbacks(model_save_place=model_prefix)) # 找到这个目录下所有已经训练好的深度学习模型,通过".h5"
evaluations = []
for i in os.listdir(model_prefix):
if '.h5' in i:
evaluations.append(i) # 给测试集和当前的验证集开辟空间,就是当前折的数据预测结果构建出这么多的数据集[数据个数,类别]
preds1 = np.zeros((test.shape[0], config.n_class))
preds2 = np.zeros((len(kfold_X_valid), config.n_class))
# 遍历每一个模型,用他们分别预测当前折数的验证集和测试集,N个模型的结果求平均
for run, i in enumerate(evaluations):
res_model.load_weights(os.path.join(model_prefix, i))
preds1 += res_model.predict(test, verbose=1) / len(evaluations)
preds2 += res_model.predict(kfold_X_valid, batch_size=128) / len(evaluations) # 测试集上预测结果的加权平均
predict += preds1 / num_folds
# 每一折的预测结果放到对应折上的测试集中,用来最后构建训练集
oof_predict[test_index] = preds2 # 计算精度和F1
accuracy = mb.cal_acc(oof_predict[test_index], np.argmax(y_test, axis=1))
f1 = mb.cal_f_alpha(oof_predict[test_index], np.argmax(y_test, axis=1), n_out=config.n_class)
print('the kflod cv is : ', str(accuracy))
print('the kflod f1 is : ', str(f1))
count_kflod += 1 # 模型融合的预测结果,存起来,用以以后求平均值
scores.append(accuracy)
f1s.append(f1)
# 指标均值,最为最后的预测结果
print('total scores is ', np.mean(scores))
print('total f1 is ', np.mean(f1s))
return predict
深度学习模型stacking模型融合python代码,看了你就会使的更多相关文章
- 时间序列深度学习:seq2seq 模型预测太阳黑子
目录 时间序列深度学习:seq2seq 模型预测太阳黑子 学习路线 商业中的时间序列深度学习 商业中应用时间序列深度学习 深度学习时间序列预测:使用 keras 预测太阳黑子 递归神经网络 设置.预处 ...
- 【转】[caffe]深度学习之图像分类模型AlexNet解读
[caffe]深度学习之图像分类模型AlexNet解读 原文地址:http://blog.csdn.net/sunbaigui/article/details/39938097 本文章已收录于: ...
- [caffe]深度学习之图像分类模型VGG解读
一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...
- 深度学习 vs. 概率图模型 vs. 逻辑学
深度学习 vs. 概率图模型 vs. 逻辑学 摘要:本文回顾过去50年人工智能(AI)领域形成的三大范式:逻辑学.概率方法和深度学习.文章按时间顺序展开,先回顾逻辑学和概率图方法,然后就人工智能和机器 ...
- 深度学习的seq2seq模型——本质是LSTM,训练过程是使得所有样本的p(y1,...,yT‘|x1,...,xT)概率之和最大
from:https://baijiahao.baidu.com/s?id=1584177164196579663&wfr=spider&for=pc seq2seq模型是以编码(En ...
- 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战
推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...
- 深入浅出深度学习:原理剖析与python实践_黄安埠(著) pdf
深入浅出深度学习:原理剖析与python实践 目录: 第1 部分 概要 1 1 绪论 2 1.1 人工智能.机器学习与深度学习的关系 3 1.1.1 人工智能——机器推理 4 1.1.2 机器学习—— ...
- 一文看懂Stacking!(含Python代码)
一文看懂Stacking!(含Python代码) https://mp.weixin.qq.com/s/faQNTGgBZdZyyZscdhjwUQ
- 风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施
风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施 Python 语言可能发生的命令执行漏洞 内置危险函数 eval和exec函数 eval eval是一个python内置函数, ...
随机推荐
- 基于Centos搭建Django 环境搭建
CentOS 7.2 64 位操作系统 安装 Django 先安装 PIP,再通过 PIP 安装 Django 安装 PIP cd /data; mkdir tmp; cd tmp; wget htt ...
- WCF兼容WebAPI输出Json格式数据,从此WCF一举两得
问题起源: 很多时候为了业务层调用(后台代码),一些公共服务就独立成了WCF,使用起来非常方便,添加服务引用,然后简单配置就可以调用了. 如果这个时候Web站点页面需要调用怎么办呢? 复杂的XML , ...
- easy_install与pip 区别
作为Python爱好者,如果不知道easy_install或者pip中的任何一个的话,那么...... easy_insall的作用和perl中的cpan,ruby中的gem类似,都提供了在线一键 ...
- linux每日命令(4):pwd命令
Linux中用 pwd 命令来查看"当前工作目录"的完整路径. 简单得说,每当你在终端进行操作时,你都会有一个当前工作目录. 在不太确定当前位置时,就会使用pwd来判定当前目录在文 ...
- Java compiler level does not match the version of the installed Java project facet解决办法
意思就是project facet 和 java compiler level 不一致 解决办法: 修改project facet 方法一: 选中工程,右键 Property -> Projec ...
- java中的数据加密4 数字签名
数字签名 它是确定交换消息的通信方身份的第一个级别.A通过使用公钥加密数据后发给B,B利用私钥解密就得到了需要的数据,问题来了,由于都是使用公钥加密,那么如何检验是A发过来的消息呢?上面也提到了一点, ...
- Android自动化测试之Monkeyrunner从零开始
最近由于公司在组织一个Free CoDE的项目,也就是由大家自己选择研究方向来做一些自己感兴趣的研究.由于之前我学过一点点关于android的东西,并且目前android开发方兴未艾如火如荼,但自动化 ...
- supervisor开机自动启动脚本+redis+MySQL+tomcat+nginx进程自动重启配置
[root@mongodb-host supervisord]# cat mongo.conf [program:mongo]command=/usr/local/mongodb/bin/mongod ...
- 【转】Web前端性能优化——如何提高页面加载速度
前言: 在同样的网络环境下,两个同样能满足你的需求的网站,一个“Duang”的一下就加载出来了,一个纠结了半天才出来,你会选择哪个?研究表明:用户最满意的打开网页时间是2-5秒,如果等待超过10秒, ...
- (实用)CentOS 6.3更新内置Python2.6
在安装Kilo版的OpenStack时,我们发现社区已经将Python升到2.7,而CentOS 6.3上仍然在使用2.6版的Python.本文记录将CentOS 6.3内置的Python2.6更新为 ...