# coding:UTF-8
'''
Date:20160901
@author: zhaozhiyong
'''
import numpy as np
from lr_train import sig def load_weight(w):
'''导入LR模型
input: w(string)权重所在的文件位置
output: np.mat(w)(mat)权重的矩阵
'''
f = open(w)
w = []
for line in f.readlines():
lines = line.strip().split("\t")
w_tmp = []
for x in lines:
w_tmp.append(float(x))
w.append(w_tmp)
f.close()
return np.mat(w) def load_data(file_name, n):
'''导入测试数据
input: file_name(string)测试集的位置
n(int)特征的个数
output: np.mat(feature_data)(mat)测试集的特征
'''
f = open(file_name)
feature_data = []
for line in f.readlines():
feature_tmp = []
lines = line.strip().split("\t")
# print lines[2]
if len(lines) <> n - 1:
continue
feature_tmp.append(1)
for x in lines:
# print x
feature_tmp.append(float(x))
feature_data.append(feature_tmp)
f.close()
return np.mat(feature_data) def predict(data, w):
'''对测试数据进行预测
input: data(mat)测试数据的特征
w(mat)模型的参数
output: h(mat)最终的预测结果
'''
h = sig(data * w.T)#sig
m = np.shape(h)[0]
for i in xrange(m):
if h[i, 0] < 0.5:
h[i, 0] = 0.0
else:
h[i, 0] = 1.0
return h def save_result(file_name, result):
'''保存最终的预测结果
input: file_name(string):预测结果保存的文件名
result(mat):预测的结果
'''
m = np.shape(result)[0]
#输出预测结果到文件
tmp = []
for i in xrange(m):
tmp.append(str(result[i, 0]))
f_result = open(file_name, "w")
f_result.write("\t".join(tmp))
f_result.close() if __name__ == "__main__":
# 1、导入LR模型
print "---------- 1.load model ------------"
w = load_weight("weights")
n = np.shape(w)[1]
# 2、导入测试数据
print "---------- 2.load data ------------"
testData = load_data("test_data", n)
# 3、对测试数据进行预测
print "---------- 3.get prediction ------------"
h = predict(testData, w)#进行预测
# 4、保存最终的预测结果
print "---------- 4.save prediction ------------"
save_result("result", h)

  

转自:

https://github.com/zhaozhiyong19890102/Python-Machine-Learning-Algorithm

02-赵志勇机器学习-Logistics_Regression-test(转载)的更多相关文章

  1. 00-赵志勇机器学习-Logistics_Regression-data.txt(转载)

    4.45925637575900 8.22541838354701 0 0.0432761720122110 6.30740040001402 0 6.99716180262699 9.3133933 ...

  2. 12-赵志勇机器学习-Label_Propagation

    (草稿) 过程: 1. 初始化所有节点的 labels 成唯一的值: 2. 对每个节点,将 label 更新为和其相连的所有节点中,标签最多的 节点的label: 2. 初始化情况下,假如所有相连的节 ...

  3. 11-赵志勇机器学习-DBSCAN聚类

    (草稿) 两点关系的三种定义: 1. 直接密度可达:A在B的邻域内: 2. 密度可达:AB之间存在,直接密度可达的点串: 3. 密度连接:AB之间存在点k,使得Ak和Bk都密度可达: 过程: 1. 对 ...

  4. 09-赵志勇机器学习-k-means

    (草稿) k-means: 1. 随机选取n个中心 2. 计算每个点到各个中心的距离 3. 距离小于阈值的归成一类. 4. 计算新类的质心,作为下一次循环的n个中心 5. 直到新类的质心和对应本次循环 ...

  5. 10-赵志勇机器学习-meanshift

    (草稿) meanshift 也是一种聚类方法. 优点在于:不需要提前指定类型数. 缺点就是计算量大 过程:(最一般的做法,没有使用核函数) 1. 逐点迭代,设置为位置中心 2. 计算所有点到位置中心 ...

  6. 01-赵志勇机器学习-Logistics_Regression-train

    Logistics Regression 二分类问题. 模型 线性模型 响应 sigmoid 损失函数(显示) 最小均方 优化方法 BGD 例子: #coding utf-8 import numpy ...

  7. 周志华-机器学习西瓜书-第三章习题3.5 LDA

    本文为周志华机器学习西瓜书第三章课后习题3.5答案,编程实现线性判别分析LDA,数据集为书本第89页的数据 首先介绍LDA算法流程: LDA的一个手工计算数学实例: 课后习题的代码: # coding ...

  8. 25个Java机器学习工具&库--转载

    本列表总结了25个Java机器学习工具&库: 1. Weka集成了数据挖掘工作的机器学习算法.这些算法可以直接应用于一个数据集上或者你可以自己编写代码来调用.Weka包括一系列的工具,如数据预 ...

  9. 机器学习周志华 pdf统计学习人工智能资料下载

    周志华-机器学习 pdf,下载地址: https://u12230716.pipipan.com/fs/12230716-239561959 统计学习方法-李航,  下载地址: https://u12 ...

随机推荐

  1. oracle 配置DBlink 链接mysql库

    一,环境配置与准备.简介 \ oracle mysql 主机名 oracle01 mysqlre1 IP 192.168.0.10 192.168.0.187 本文章是oracle通过dblink连接 ...

  2. 工作中常用的Linux命令介绍与实践

    前言 做后端开发的同学,一般都会接触到服务器,而我们现在的系统用的比较多的服务器系统就是linux了,平时多多少少也会接触到一些linux下的shell命令.我们来介绍下linux一些常用的命令和使用 ...

  3. Qt Quick 组件与动态对象

    博客24## 一.Components(组件) Component 是由 Qt 框架或开发者封装好的.只暴露了必要接口的 QML 类型,可以重复利用.一个 QML 组件就像一个黑盒子,它通过属性.信号 ...

  4. 33,Leetcode 搜索旋转排序数组-C++ 递归二分法

    题目描述 假设按照升序排序的数组在预先未知的某个点上进行了旋转. ( 例如,数组 [0,1,2,4,5,6,7] 可能变为 [4,5,6,7,0,1,2] ). 搜索一个给定的目标值,如果数组中存在这 ...

  5. BScroll使用

    当页面内容的高度超过视口高度的时候,会出现纵向滚动条:当页面内容的宽度超过视口宽度的时候,会出现横向滚动条.也就是当我们的视口展示不下内容的时候,会通过滚动条的方式让用户滚动屏幕看到剩余的内容. 话说 ...

  6. HUSKY CLOCK1.0上线啦!

    有人需要HUSKY CLOCK1.0下载资源的请联系1335415335@qq.com! 感谢支持,您的认可是我们前进的动力!

  7. VSCode批量替换使用注意问题

    VSCode批量替换功能很强大,需要注意两点 1.不要搜到文件个数超过到10000时替换,这时替换过程中可能会出错崩溃(也可能是服务器上内存较小导致) 2.不要在搜索中反复替换可能会导致数据错乱 比如 ...

  8. java跳出循环break;return;continue使用

    for(int i=0;i<5;i++){ if(i==2){ System.out.println("i==2时忽略了"); continue;//忽略i==2时的循环 } ...

  9. 【java】Java多线程总结之线程安全队列Queue【转载】

    原文地址:https://www.cnblogs.com/java-jun-world2099/articles/10165949.html ============================= ...

  10. ifame内嵌页面全屏完美展示

    <body style= marginwidth= marginheight= width='100%' height='100%' allowfullscreen='true' src='ht ...