输入的数据集是10000行,31645列,其中前31644是特征,最后一列是标签值。训练集和测试集格式是一样的。

特征值都是0,1形式,表示有还是没有这个特征,标签值是0,1形式,2分类。

import keras
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers import Embedding
from keras.layers import LSTM

#*******************#########################**********************************************************
#++++++++++++$$$$$***********************************************
#*****************************************************************************************

#带标签数据训练
# Keras以Numpy数组作为输入数据和标签的数据类型。训练模型一般使用fit函数

# for a single-input model with 2 classes (binary):
#dese后面的数字表示输出的数据的维数,只有一个add,一个激活函数,认为网络只有一层,那么dense后面的数字必须是要输出的类别数,
#如果中间有几层网络,那么每一个dense后面试输出为下一层的网络神经元个数,但是最后一个add的dense后面的数字必须是输出的类别数。

model = Sequential()
model.add(Dense(20, input_dim=31644, activation='sigmoid')) #把dense=1改为20
model.add(Dropout(0.5))
model.add(Dense(10,activation='relu')) #没有input 表示隐层神经元
model.add(Dropout(0.5))
model.add(Dense(1,activation='sigmoid')) #输出1维,表示是输出层神经元

model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['accuracy'])

# generate dummy data
import numpy as np

def meke_sample(p):
   data=[]
   label=[]
   f_tain=open("/Users/zhangb/Desktop/学习研究/测试样本/"+p,'r')
   for i in f_tain.readlines():
        lt=i.strip().split(',')
        data.append([int(x) for x in lt[0:-1]])
        label.append([int(lt[-1])])
    f_tain.close()
    data2=np.array(data)
    label2=np.array(label)
    return [data2,label2]

train_sample=meke_sample("train")
test_sample=meke_sample("test")

train_data=train_sample[0]
train_label=train_sample[1]

test_data=test_sample[0]
test_label=test_sample[1]

#定义混淆矩阵,左边是预测值,上面是实际值 从左到右,从上到下 依次为a,b,c,d表示
def confusion_mat(test_label,predicts):
      test_calss=[int(x) for x in (list(test_label))] #传入的是数组,转成数字列表
      pred_class=[int(x) for x in (list(predicts))]
      a,b,c,d=0,0,0,0
      for i in range(len(test_calss)):
          if pred_class[i]==1 and test_calss[i]==1 :
                 a +=1
          elif pred_class[i]==1 and test_calss[i]==0 :
                 b +=1
          elif pred_class[i]==0 and test_calss[i]==1 :
                  c +=1
          elif pred_class[i]==0 and test_calss[i]==0 :
                  d +=1
      precision_1=a/(a+b+0.0)
      precision_0=d/(c+d+0.0)
      recall_1=a/(a+c+0.0)
      recall_0=d/(d+b+0.0)
      precision=(a+d)/(a+b+c+d+0.0)
      f1=2*precision_1*recall_1/(precision_1+recall_1)
      f0=2*precision_0*recall_0/(precision_0+recall_0)
      return [[a,b],[c,d],[precision_1,precision_0,recall_1,recall_0,precision,f1,f0]]

#
#print label
model.fit(train_data, train_label, nb_epoch=80, batch_size=32)
score = model.evaluate(test_data, test_label, batch_size=32) #得到损失值和准确率
print score
pred=model.predict_classes(test_data) #得到预测值 是数组
#print pred

confusion=confusion_mat(test_label,pred)
precision_1=confusion[2][0]
precision_0=confusion[2][1]
recall_1=confusion[2][2]
recall_0=confusion[2][3]
precision=confusion[2][4]
f1=confusion[2][5]
f0=confusion[2][6]
print confusion[0]
print confusion[1]
print 'precision_1 :'+ str(precision_1) + ' precision_0:' +str(precision_0)
print 'recall_1: '+str(recall_1)+' recall_0: '+ str(recall_0)
print 'precision:' +str(precision)
print 'f1: '+str(f1) + 'f0: '+str(f0)
#print ret
#print score

keras做DNN的更多相关文章

  1. Keras 构建DNN 对用户名检测判断是否为非法用户名(从数据预处理到模型在线预测)

    一.  数据集的准备与预处理 1 . 收集dataset (大量用户名--包含正常用户名与非法用户名) 包含两个txt文件  legal_name.txt  ilegal_name.txt. 如下图所 ...

  2. 用keras做SQL注入攻击的判断

    本文是通过深度学习框架keras来做SQL注入特征识别, 不过虽然用了keras,但是大部分还是普通的神经网络,只是外加了一些规则化.dropout层(随着深度学习出现的层). 基本思路就是喂入一堆数 ...

  3. 使用Keras做OCR时报错:ValueError: Tensor Tensor is not an element of this graph

    现象 项目使用 Flask + Keras + Tensorflow 同样的代码在机器A和B上都能正常运行,但在机器C上就会报如下异常.机器A和B的环境是先安装的,运行.调试成功后才尝试在C上跑. F ...

  4. 数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST

    简介 在上一篇博客:数据挖掘入门系列教程(十点五)之DNN介绍及公式推导中,详细的介绍了DNN,并对其进行了公式推导.本来这篇博客是准备直接介绍CNN的,但是想了一下,觉得还是使用keras构建一个D ...

  5. keras如何求分类问题中的准确率和召回率

    https://www.zhihu.com/question/53294625 由于要用keras做一个多分类的问题,评价标准采用precision,recall,和f1_score:但是keras中 ...

  6. 【Python】keras神经网络识别mnist

    上次用Matlab写过一个识别Mnist的神经网络,地址在:https://www.cnblogs.com/tiandsp/p/9042908.html 这次又用Keras做了一个差不多的,毕竟,现在 ...

  7. keras系列︱seq2seq系列相关实现与案例(feedback、peek、attention类型)

    之前在看<Semi-supervised Sequence Learning>这篇文章的时候对seq2seq半监督的方式做文本分类的方式产生了一定兴趣,于是开始简单研究了seq2seq.先 ...

  8. 条件随机场CRF原理介绍 以及Keras实现

    本文是对CRF基本原理的一个简明的介绍.当然,“简明”是相对而言中,要想真的弄清楚CRF,免不了要提及一些公式,如果只关心调用的读者,可以直接移到文末. 图示# 按照之前的思路,我们依旧来对比一下普通 ...

  9. keras冒bug

    使用keras做vgg16的迁移学习实验,在实现的过程中,冒各种奇怪的bug,甚至剪贴复制还是出问题. 解决方案: 当使用组合keras和tensorflow.keras时.由于版本不一致问题导致很多 ...

随机推荐

  1. CC2530中串口波特率改为9600时单个数据包来不及接收的解决方案

    在调试CC2530过程中发现波特率改为9600时,单个包仅有3个Byte时,接收DMA就会启动 因而数据包被强迫拆分成多个,显然只要将接收DMA启动延时做到足够大即可. 具体修改内容如下图所示: 经过 ...

  2. Mac 下安装python3.7 + pip 利用 chrome + chromedriver + selenium 自动打开网页并自动点击访问指定页面

    1.安装python3.7https://www.python.org/downloads/release/python-370/选择了这个版本,直接默认下一步 2.安装pipcurl https:/ ...

  3. C# DateTime 月第一天和最后一天 取法

    取得某月和上个月第一天和最后一天的方法 /// <summary> /// 取得某月的第一天 /// </summary> /// <param name="d ...

  4. WIN7 X64 下 VS2008升级补丁 (显示隐藏按钮)

    原文地址:http://blog.sina.com.cn/s/blog_57b5da120100gk7l.html 更新列表: 2010年3月26日:增加对日文版的支持. 2010年3月3日:更新代码 ...

  5. BASIC-25_蓝桥杯_回形取数

    示例代码: #include <stdio.h>#define N 200 int main(void){ int num[N][N]; int i= 0, j = 0 , k = 0 , ...

  6. (unittest之装饰器(@classmethod)) 让多个测试用例在一个浏览器里面跑 的方法

    一.装饰器 1.用setUp与setUpClass区别 setup():每个测试case运行前运行teardown():每个测试case运行完后执行setUpClass():必须使用@classmet ...

  7. appium 3-31626 toast识别

    1.toast弹窗,普通方式不能获取 例如使用getPageSource是无法找到toast的信息,uiautomatorViewer加载页面时间较长,也很难采集到toast信息 2.通过curl命令 ...

  8. php 流程控制switch实例

    switch允许对一个标量(表达式)的多个可能结果做选择. 语法: switch (expr) { case result1: statement1 break; case result2: stat ...

  9. R语言中的遗传算法详细解析

    前言 人类总是在生活中摸索规律,把规律总结为经验,再把经验传给后人,让后人发现更多的规规律,每一次知识的传递都是一次进化的过程,最终会形成了人类的智慧.自然界规律,让人类适者生存地活了下来,聪明的科学 ...

  10. Mybatis通过colliection属性递归获取菜单树

    1.现有商品分类数据表category结构如下,三个字段都为varchar类型 2.创建商品分类对应的数据Bean /** * */ package com.xdw.dao; import java. ...