话不多说,直接上代码

 def stacking_first(train, train_y, test):
savepath = './stack_op{}_dt{}_tfidf{}/'.format(args.option, args.data_type, args.tfidf)
os.makedirs(savepath, exist_ok=True) count_kflod = 0
num_folds = 6
kf = KFold(n_splits=num_folds, shuffle=True, random_state=10)
# 测试集上的预测结果
predict = np.zeros((test.shape[0], config.n_class))
# k折交叉验证集的预测结果
oof_predict = np.zeros((train.shape[0], config.n_class))
scores = []
f1s = [] for train_index, test_index in kf.split(train):
# 训练集划分为6折,每一折都要走一遍。那么第一个是5份的训练集索引,第二个是1份的测试集,此处为验证集是索引 kfold_X_train = {}
kfold_X_valid = {} # 取数据的标签
y_train, y_test = train_y[train_index], train_y[test_index]
# 取数据
kfold_X_train, kfold_X_valid = train[train_index], train[test_index] # 模型的前缀
model_prefix = savepath + 'DNN' + str(count_kflod)
if not os.path.exists(model_prefix):
os.mkdir(model_prefix) M = 4 # number of snapshots
alpha_zero = 1e-3 # initial learning rate
snap_epoch = 16
snapshot = SnapshotCallbackBuilder(snap_epoch, M, alpha_zero) # 使用训练集的size设定维度,fit一个模型出来
res_model = get_model(train)
res_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# res_model.fit(train_x, train_y, batch_size=BATCH_SIZE, epochs=EPOCH, verbose=1, class_weight=class_weight)
res_model.fit(kfold_X_train, y_train, batch_size=BATCH_SIZE, epochs=snap_epoch, verbose=1,
validation_data=(kfold_X_valid, y_test),
callbacks=snapshot.get_callbacks(model_save_place=model_prefix)) # 找到这个目录下所有已经训练好的深度学习模型,通过".h5"
evaluations = []
for i in os.listdir(model_prefix):
if '.h5' in i:
evaluations.append(i) # 给测试集和当前的验证集开辟空间,就是当前折的数据预测结果构建出这么多的数据集[数据个数,类别]
preds1 = np.zeros((test.shape[0], config.n_class))
preds2 = np.zeros((len(kfold_X_valid), config.n_class))
# 遍历每一个模型,用他们分别预测当前折数的验证集和测试集,N个模型的结果求平均
for run, i in enumerate(evaluations):
res_model.load_weights(os.path.join(model_prefix, i))
preds1 += res_model.predict(test, verbose=1) / len(evaluations)
preds2 += res_model.predict(kfold_X_valid, batch_size=128) / len(evaluations) # 测试集上预测结果的加权平均
predict += preds1 / num_folds
# 每一折的预测结果放到对应折上的测试集中,用来最后构建训练集
oof_predict[test_index] = preds2 # 计算精度和F1
accuracy = mb.cal_acc(oof_predict[test_index], np.argmax(y_test, axis=1))
f1 = mb.cal_f_alpha(oof_predict[test_index], np.argmax(y_test, axis=1), n_out=config.n_class)
print('the kflod cv is : ', str(accuracy))
print('the kflod f1 is : ', str(f1))
count_kflod += 1 # 模型融合的预测结果,存起来,用以以后求平均值
scores.append(accuracy)
f1s.append(f1)
# 指标均值,最为最后的预测结果
print('total scores is ', np.mean(scores))
print('total f1 is ', np.mean(f1s))
return predict

深度学习模型stacking模型融合python代码,看了你就会使的更多相关文章

  1. 时间序列深度学习:seq2seq 模型预测太阳黑子

    目录 时间序列深度学习:seq2seq 模型预测太阳黑子 学习路线 商业中的时间序列深度学习 商业中应用时间序列深度学习 深度学习时间序列预测:使用 keras 预测太阳黑子 递归神经网络 设置.预处 ...

  2. 【转】[caffe]深度学习之图像分类模型AlexNet解读

    [caffe]深度学习之图像分类模型AlexNet解读 原文地址:http://blog.csdn.net/sunbaigui/article/details/39938097   本文章已收录于: ...

  3. [caffe]深度学习之图像分类模型VGG解读

    一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...

  4. 深度学习 vs. 概率图模型 vs. 逻辑学

    深度学习 vs. 概率图模型 vs. 逻辑学 摘要:本文回顾过去50年人工智能(AI)领域形成的三大范式:逻辑学.概率方法和深度学习.文章按时间顺序展开,先回顾逻辑学和概率图方法,然后就人工智能和机器 ...

  5. 深度学习的seq2seq模型——本质是LSTM,训练过程是使得所有样本的p(y1,...,yT‘|x1,...,xT)概率之和最大

    from:https://baijiahao.baidu.com/s?id=1584177164196579663&wfr=spider&for=pc seq2seq模型是以编码(En ...

  6. 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战

    推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...

  7. 深入浅出深度学习:原理剖析与python实践_黄安埠(著) pdf

    深入浅出深度学习:原理剖析与python实践 目录: 第1 部分 概要 1 1 绪论 2 1.1 人工智能.机器学习与深度学习的关系 3 1.1.1 人工智能——机器推理 4 1.1.2 机器学习—— ...

  8. 一文看懂Stacking!(含Python代码)

    一文看懂Stacking!(含Python代码) https://mp.weixin.qq.com/s/faQNTGgBZdZyyZscdhjwUQ

  9. 风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施

    风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施 Python 语言可能发生的命令执行漏洞 内置危险函数 eval和exec函数 eval eval是一个python内置函数, ...

随机推荐

  1. 基于Centos搭建Django 环境搭建

    CentOS 7.2 64 位操作系统 安装 Django 先安装 PIP,再通过 PIP 安装 Django 安装 PIP cd /data; mkdir tmp; cd tmp; wget htt ...

  2. WCF兼容WebAPI输出Json格式数据,从此WCF一举两得

    问题起源: 很多时候为了业务层调用(后台代码),一些公共服务就独立成了WCF,使用起来非常方便,添加服务引用,然后简单配置就可以调用了. 如果这个时候Web站点页面需要调用怎么办呢? 复杂的XML , ...

  3. easy_install与pip 区别

    作为Python爱好者,如果不知道easy_install或者pip中的任何一个的话,那么......   easy_insall的作用和perl中的cpan,ruby中的gem类似,都提供了在线一键 ...

  4. linux每日命令(4):pwd命令

    Linux中用 pwd 命令来查看"当前工作目录"的完整路径. 简单得说,每当你在终端进行操作时,你都会有一个当前工作目录. 在不太确定当前位置时,就会使用pwd来判定当前目录在文 ...

  5. Java compiler level does not match the version of the installed Java project facet解决办法

    意思就是project facet 和 java compiler level 不一致 解决办法: 修改project facet 方法一: 选中工程,右键 Property -> Projec ...

  6. java中的数据加密4 数字签名

    数字签名 它是确定交换消息的通信方身份的第一个级别.A通过使用公钥加密数据后发给B,B利用私钥解密就得到了需要的数据,问题来了,由于都是使用公钥加密,那么如何检验是A发过来的消息呢?上面也提到了一点, ...

  7. Android自动化测试之Monkeyrunner从零开始

    最近由于公司在组织一个Free CoDE的项目,也就是由大家自己选择研究方向来做一些自己感兴趣的研究.由于之前我学过一点点关于android的东西,并且目前android开发方兴未艾如火如荼,但自动化 ...

  8. supervisor开机自动启动脚本+redis+MySQL+tomcat+nginx进程自动重启配置

    [root@mongodb-host supervisord]# cat mongo.conf [program:mongo]command=/usr/local/mongodb/bin/mongod ...

  9. 【转】Web前端性能优化——如何提高页面加载速度

    前言:  在同样的网络环境下,两个同样能满足你的需求的网站,一个“Duang”的一下就加载出来了,一个纠结了半天才出来,你会选择哪个?研究表明:用户最满意的打开网页时间是2-5秒,如果等待超过10秒, ...

  10. (实用)CentOS 6.3更新内置Python2.6

    在安装Kilo版的OpenStack时,我们发现社区已经将Python升到2.7,而CentOS 6.3上仍然在使用2.6版的Python.本文记录将CentOS 6.3内置的Python2.6更新为 ...