只是简单demo,
可以看出tensorflow非常简洁,适合快速实验

 
 

import tensorflow as tf

import numpy as np

import melt_dataset

import sys

from sklearn.metrics import roc_auc_score

 
 

def init_weights(shape):

return tf.Variable(tf.random_normal(shape, stddev=0.01))

 
 

def model(X, w_h, w_o):

h = tf.nn.sigmoid(tf.matmul(X, w_h)) # this is a basic mlp, think 2 stacked logistic regressions

return tf.matmul(h, w_o) # note that we dont take the softmax at the end because our cost fn does that for us

 
 

batch_size = 50

learning_rate = 0.1

num_iters = 500

hidden_size = 20

 
 

argv = sys.argv

trainset = argv[1]

testset = argv[2]

 
 

trX, trY = melt_dataset.load_dense_data(trainset)

print "finish loading train set ",trainset

teX, teY = melt_dataset.load_dense_data(testset)

print "finish loading test set ", testset

 
 

num_features = trX[0].shape[0]

print 'num_features: ',num_features

print 'trainSet size: ', len(trX)

print 'testSet size: ', len(teX)

print 'batch_size:', batch_size, ' learning_rate:', learning_rate, ' num_iters:', num_iters

 
 

X = tf.placeholder("float", [None, num_features]) # create symbolic variables

Y = tf.placeholder("float", [None, 1])

 
 

w_h = init_weights([num_features, hidden_size]) # create symbolic variables

w_o = init_weights([hidden_size, 1])

 
 

py_x = model(X, w_h, w_o)

 
 

cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(py_x, Y)) # compute costs

train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # construct an optimizer

predict_op = tf.nn.sigmoid(py_x)

 
 

sess = tf.Session()

init = tf.initialize_all_variables()

sess.run(init)

 
 

for i in range(num_iters):

predicts, cost_ = sess.run([predict_op, cost], feed_dict={X: teX, Y: teY})

print i, 'auc:', roc_auc_score(teY, predicts), 'cost:', cost_

for start, end in zip(range(0, len(trX), batch_size), range(batch_size, len(trX), batch_size)):

sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

 
 

predicts, cost_ = sess.run([predict_op, cost], feed_dict={X: teX, Y: teY})

print 'final ', 'auc:', roc_auc_score(teY, predicts),'cost:', cost_

 
 

 
 

 
 

python ./mlp.py corpus/feature.normed.rand.12000.0_2.txt corpus/feature.normed.rand.12000.1_2.txt

 
 

233 auc: 0.932099377357 cost: 0.210673

234 auc: 0.93210173764 cost: 0.210674

235 auc: 0.93210173764 cost: 0.210675

236 auc: 0.932089936225 cost: 0.210676

Tensorflow mlp二分类的更多相关文章

  1. tensorflow实现二分类

    读万卷书,不如行万里路.之前看了不少机器学习方面的书籍,但是实战很少.这次因为项目接触到tensorflow,用一个最简单的深层神经网络实现分类和回归任务. 首先说分类任务,分类任务的两个思路: 如果 ...

  2. Tensorflow CIFAR10 (二分类)

    数据的下载: (共有三个版本:python,matlab,binary version 适用于C语言) http://www.cs.toronto.edu/~kriz/cifar-10-python. ...

  3. tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...

  4. 【原】Spark之机器学习(Python版)(二)——分类

    写这个系列是因为最近公司在搞技术分享,学习Spark,我的任务是讲PySpark的应用,因为我主要用Python,结合Spark,就讲PySpark了.然而我在学习的过程中发现,PySpark很鸡肋( ...

  5. SVM原理以及Tensorflow 实现SVM分类(附代码)

    1.1. SVM介绍 1.2. 工作原理 1.2.1. 几何间隔和函数间隔 1.2.2. 最大化间隔 - 1.2.2.0.0.1. \(L( {x}^*)\)对$ {x}^*$求导为0 - 1.2.2 ...

  6. Kaggle实战之二分类问题

    0. 前言 1. MNIST 数据集 2. 二分类器 3. 效果评测 4. 多分类器与误差分析 5. Kaggle 实战 0. 前言 "尽管新技术新算法层出不穷,但是掌握好基础算法就能解决手 ...

  7. 深度学习之 TensorFlow(二):TensorFlow 基础知识

    1.TensorFlow 系统架构: 分为设备层和网络层.数据操作层.图计算层.API 层.应用层.其中设备层和网络层.数据操作层.图计算层是 TensorFlow 的核心层. 2.TensorFlo ...

  8. keras实现简单性别识别(二分类问题)

    keras实现简单性别识别(二分类问题) 第一步:准备好需要的库 tensorflow  1.4.0 h5py 2.7.0 hdf5 1.8.15.1 Keras     2.0.8 opencv-p ...

  9. tensorflow 教程 文本分类 IMDB电影评论

    昨天配置了tensorflow的gpu版本,今天开始简单的使用一下 主要是看了一下tensorflow的tutorial 里面的 IMDB 电影评论二分类这个教程 教程里面主要包括了一下几个内容:下载 ...

随机推荐

  1. ionic ios发布图标启动也不能正确加载问题

    前两天发布ios的时候发现ios安装的图标和启动页的时候不能正确显示,重新发布也不能正确显示,修改方法 在ionic build ios --release之前执行ionic resources即可

  2. [创业中, 寻求合作] 业务方向:车联网智能终端;APP蓝牙控制汽车;APP网络远程控制汽车 (联系电话:18503086002)

    擅长领域 手机APP蓝牙控制汽车方案 手机APP网络远程控制汽车方案 手机APP与汽车车机的文件极速传输技术 车载OBD终端 (后装) 智能TBOX终端,Base on Linux,使用车规级硬件加密 ...

  3. MAC地址是什么

    简介: MAC(Media Access Control或者Medium Access Control)地址,意译为媒体访问控制,或称为物理地址.硬件地址,用来定义网络设备的位置.在OSI模型中,第三 ...

  4. 面试题目——《CC150》数组与字符串

    面试题1.1:实现一个算法,确定一个字符串的所有字符是否全都不同.假使不允许使用额外的数据结构,又该如何处理? 注意:ASCII字符共有255个,其中0-127的字符有字符表 第一种解法:是<C ...

  5. Debian 8安装中文字体

    1.使用的镜像是debian-8.3.0-amd64-kde-CD-1.iso,下载链接可在Debian网站找到,系统安装完成后中文显示为方框 2.安装字体 apt-get install xfont ...

  6. Git 学习笔记参考

    1.参考学习资料 网上资料: http://www.cnblogs.com/aoguren/p/4189086.html http://www.liaoxuefeng.com/wiki/0013739 ...

  7. 直接拿来用,最火的.NET开源项目

    综合类 微软企业库 微软官方出品,是为了协助开发商解决企业级应用开发过程中所面临的一系列共性的问题, 如安全(Security).日志(Logging).数据访问(Data Access).配置管理( ...

  8. PHP中判断变量为空的几种方法

    判断变量为空,在许多场合都会用到,同时自己和许多新手一样也经常会犯一些错误, 所以自己整理了一下PHP中一些常用的.判断变量为空的方法. 1. isset功能:判断变量是否被初始化本函数用来测试变量是 ...

  9. poj1062 昂贵的聘礼

    Description 年轻的探险家来到了一个印第安部落里.在那里他和酋长的女儿相爱了,于是便向酋长去求亲.酋长要他用10000个金币作为聘礼才答应把女儿嫁给他.探险家拿不出这么多金币,便请求酋长降低 ...

  10. javascript 函数与对象

    javascript中的函数是非常重要的概念,也是比较难于理解的一个知识点! 下面就来聊聊函数: JS基于对象:什么是基于对象呢?简单的说所有代码都是"对象"; 比如函数: fun ...