话不多说,直接上代码

 def stacking_first(train, train_y, test):
savepath = './stack_op{}_dt{}_tfidf{}/'.format(args.option, args.data_type, args.tfidf)
os.makedirs(savepath, exist_ok=True) count_kflod = 0
num_folds = 6
kf = KFold(n_splits=num_folds, shuffle=True, random_state=10)
# 测试集上的预测结果
predict = np.zeros((test.shape[0], config.n_class))
# k折交叉验证集的预测结果
oof_predict = np.zeros((train.shape[0], config.n_class))
scores = []
f1s = [] for train_index, test_index in kf.split(train):
# 训练集划分为6折,每一折都要走一遍。那么第一个是5份的训练集索引,第二个是1份的测试集,此处为验证集是索引 kfold_X_train = {}
kfold_X_valid = {} # 取数据的标签
y_train, y_test = train_y[train_index], train_y[test_index]
# 取数据
kfold_X_train, kfold_X_valid = train[train_index], train[test_index] # 模型的前缀
model_prefix = savepath + 'DNN' + str(count_kflod)
if not os.path.exists(model_prefix):
os.mkdir(model_prefix) M = 4 # number of snapshots
alpha_zero = 1e-3 # initial learning rate
snap_epoch = 16
snapshot = SnapshotCallbackBuilder(snap_epoch, M, alpha_zero) # 使用训练集的size设定维度,fit一个模型出来
res_model = get_model(train)
res_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# res_model.fit(train_x, train_y, batch_size=BATCH_SIZE, epochs=EPOCH, verbose=1, class_weight=class_weight)
res_model.fit(kfold_X_train, y_train, batch_size=BATCH_SIZE, epochs=snap_epoch, verbose=1,
validation_data=(kfold_X_valid, y_test),
callbacks=snapshot.get_callbacks(model_save_place=model_prefix)) # 找到这个目录下所有已经训练好的深度学习模型,通过".h5"
evaluations = []
for i in os.listdir(model_prefix):
if '.h5' in i:
evaluations.append(i) # 给测试集和当前的验证集开辟空间,就是当前折的数据预测结果构建出这么多的数据集[数据个数,类别]
preds1 = np.zeros((test.shape[0], config.n_class))
preds2 = np.zeros((len(kfold_X_valid), config.n_class))
# 遍历每一个模型,用他们分别预测当前折数的验证集和测试集,N个模型的结果求平均
for run, i in enumerate(evaluations):
res_model.load_weights(os.path.join(model_prefix, i))
preds1 += res_model.predict(test, verbose=1) / len(evaluations)
preds2 += res_model.predict(kfold_X_valid, batch_size=128) / len(evaluations) # 测试集上预测结果的加权平均
predict += preds1 / num_folds
# 每一折的预测结果放到对应折上的测试集中,用来最后构建训练集
oof_predict[test_index] = preds2 # 计算精度和F1
accuracy = mb.cal_acc(oof_predict[test_index], np.argmax(y_test, axis=1))
f1 = mb.cal_f_alpha(oof_predict[test_index], np.argmax(y_test, axis=1), n_out=config.n_class)
print('the kflod cv is : ', str(accuracy))
print('the kflod f1 is : ', str(f1))
count_kflod += 1 # 模型融合的预测结果,存起来,用以以后求平均值
scores.append(accuracy)
f1s.append(f1)
# 指标均值,最为最后的预测结果
print('total scores is ', np.mean(scores))
print('total f1 is ', np.mean(f1s))
return predict

深度学习模型stacking模型融合python代码,看了你就会使的更多相关文章

  1. 时间序列深度学习:seq2seq 模型预测太阳黑子

    目录 时间序列深度学习:seq2seq 模型预测太阳黑子 学习路线 商业中的时间序列深度学习 商业中应用时间序列深度学习 深度学习时间序列预测:使用 keras 预测太阳黑子 递归神经网络 设置.预处 ...

  2. 【转】[caffe]深度学习之图像分类模型AlexNet解读

    [caffe]深度学习之图像分类模型AlexNet解读 原文地址:http://blog.csdn.net/sunbaigui/article/details/39938097   本文章已收录于: ...

  3. [caffe]深度学习之图像分类模型VGG解读

    一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...

  4. 深度学习 vs. 概率图模型 vs. 逻辑学

    深度学习 vs. 概率图模型 vs. 逻辑学 摘要:本文回顾过去50年人工智能(AI)领域形成的三大范式:逻辑学.概率方法和深度学习.文章按时间顺序展开,先回顾逻辑学和概率图方法,然后就人工智能和机器 ...

  5. 深度学习的seq2seq模型——本质是LSTM,训练过程是使得所有样本的p(y1,...,yT‘|x1,...,xT)概率之和最大

    from:https://baijiahao.baidu.com/s?id=1584177164196579663&wfr=spider&for=pc seq2seq模型是以编码(En ...

  6. 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战

    推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...

  7. 深入浅出深度学习:原理剖析与python实践_黄安埠(著) pdf

    深入浅出深度学习:原理剖析与python实践 目录: 第1 部分 概要 1 1 绪论 2 1.1 人工智能.机器学习与深度学习的关系 3 1.1.1 人工智能——机器推理 4 1.1.2 机器学习—— ...

  8. 一文看懂Stacking!(含Python代码)

    一文看懂Stacking!(含Python代码) https://mp.weixin.qq.com/s/faQNTGgBZdZyyZscdhjwUQ

  9. 风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施

    风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施 Python 语言可能发生的命令执行漏洞 内置危险函数 eval和exec函数 eval eval是一个python内置函数, ...

随机推荐

  1. Mysql INSERT、REPLACE、UPDATE的区别

    用于操作数据库的SQL一般分为两种,一种是查询语句,也就是我们所说的SELECT语句,另外一种就是更新语句,也叫做数据操作语句.言外之意,就是对数据进行修改.在标准的SQL中有3个语句,它们是INSE ...

  2. JavaBean之lombok

    参见:https://www.ibm.com/developerworks/cn/opensource/os-lombok/ http://blog.didispace.com/java-lombok ...

  3. Redis资料整理

    1.Redis命令參考中文简体版. 2.java操作redis.jedis使用api 3.Redis学习笔记. 4.浅谈Redis数据库的键值设计 5.Redis资料汇总专题 6.MongoDB资料汇 ...

  4. 【转载】JAVA基础:注解

    原文:https://www.cnblogs.com/xdp-gacl/p/3622275.html#undefined 一.认识注解 注解(Annotation)很重要,未来的开发模式都是基于注解的 ...

  5. Xcode 插件优缺点对照(推荐 20 款插件)

    Xcode 插件优缺点对照(推荐 20 款插件) 2016-01-22 06:16 编辑: lansekuangtu 分类:iOS开发 来源:董铂然 的博客 28 13527 /XCode/" ...

  6. django admin upload 上传图片到oss Django Aliyun OSS2 Storage

    https://github.com/xiewenya/django-aliyun-oss2-storage Install pip install django-aliyun-oss2-storag ...

  7. Spring自动扫描无法扫描jar包中bean的解决方法(转)

    转载自:http://www.jb51.net/article/116357.htm 在日常开发中往往会对公共的模块打包发布,然后调用公共包的内容.然而,最近对公司的公共模块进行整理发布后.sprin ...

  8. 【九天教您南方cass 9.1】 09 提取坐标的几种方法

    同学们大家好,欢迎收看由老王测量上班记出品的cass9.1视频课程 我是本节课主讲老师九天. 我们讲课的教程附件也是共享的,请注意索取测量空间中. [点击索取cass教程]5元立得 (给客服说暗号:“ ...

  9. js 上一步 下一步 操作

    <a id="syb" href="#" style="display: block;" class="btn button ...

  10. Go指南练习_错误

    源地址 https://tour.go-zh.org/methods/20 一.题目描述 从之前的练习中复制 Sqrt 函数,修改它使其返回 error 值. Sqrt 接受到一个负数时,应当返回一个 ...