import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.preprocessing import LabelEncoder
from keras.optimizers import SGD
from keras.layers import LSTM # load dataset
dataframe = pd.read_csv("./data/iris1.csv", header=None)
dataset = dataframe.values
X = dataset[:, 0:19].astype(float)
dummy_y1 = dataset[:, 19]
m,n=1682,6
dum_imax=np.zeros((m,n))
# print(type(dum_imax))
for i in range(m):
# print(i)
# exit()
if dummy_y1[i]!=0:
dum_imax[i][dummy_y1[i]-1]=1
else:
dum_imax[i][5]=1
# print(dum_imax)
dummy_y =dum_imax
print(dummy_y)
print(type(dummy_y[0][0])) def baseline_model():
model = Sequential()
model.add(Dense(output_dim=50, input_dim=19, activation='relu'))
# model.add(LSTM(128))
model.add(Dropout(0.4))
model.add(Dense(output_dim=6, input_dim=50, activation='softmax'))
# Compile model
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
# model.compile(loss='categorical_crossentropy', optimizer=sgd)
#编译模型。由于我们做的是二元分类,所以我们指定损失函数为binary_crossentropy,以及模式为binary
#另外常见的损失函数还有mean_squared_error、categorical_crossentropy等,请阅读帮助文件。
#求解方法我们指定用adam,还有sgd、rmsprop等可选
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
return model
estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=40, batch_size=256)
print(estimator) # splitting data into training set and test set. If random_state is set to an integer, the split datasets are fixed.
X_train, X_test, Y_train, Y_test = train_test_split(X, dummy_y, test_size=0.2, random_state=0)#train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,
print(len(X_train[0]))
print(len(Y_train[0]))
estimator.fit(X_train, Y_train,nb_epoch = 100)#训练模型,学习一百次 # make predictions
print(X_test)
pred = estimator.predict(X_test)
print(pred)
# init_lables = encoder.inverse_transform(pred)
# print(init_lables) # inverse numeric variables to initial categorical labels
# init_lables = encoder.inverse_transform(pred)
# print(init_lables) # k-fold cross-validate
# seed = 42
# np.random.seed(seed)
'''
n_splits : 默认3,最小为2;K折验证的K值
shuffle : 默认False;shuffle会对数据产生随机搅动(洗牌)
random_state :默认None,随机种子
'''
kfold = KFold(n_splits=5, shuffle=True)#定义5折,在对数据进行划分之前,对数据进行随机混洗 results = cross_val_score(estimator, X, dummy_y, cv=kfold)#在数据集上,使用k fold交叉验证,对估计器estimator进行评估。
print("baseline:%.2f%%(%.2f%%)"%(results.mean()*100,results.std()*100))#返回的结果,是10次数据集划分后,每次的评估结果。评估结果包括平均准确率和标准差

Keras人工神经网络多分类(SGD)的更多相关文章

  1. keras人工神经网络构建入门

    //2019.07.29-301.Keras 是提供一些高度可用神经网络框架的 Python API ,能帮助你快速的构建和训练自己的深度学习模型,它的后端是 TensorFlow 或者 Theano ...

  2. [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)

    3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...

  3. neurosolutions 人工神经网络集成开发环境 keras

    人工神经网络集成开发环境 :  http://www.neurosolutions.com/ keras:   https://github.com/fchollet/keras 文档    http ...

  4. [DL学习笔记]从人工神经网络到卷积神经网络_2_卷积神经网络

    先一层一层的说卷积神经网络是啥: 1:卷积层,特征提取 我们输入这样一幅图片(28*28): 如果用传统神经网络,下一层的每个神经元将连接到输入图片的每一个像素上去,但是在卷积神经网络中,我们只把输入 ...

  5. [DL学习笔记]从人工神经网络到卷积神经网络_1_神经网络和BP算法

    前言:这只是我的一个学习笔记,里边肯定有不少错误,还希望有大神能帮帮找找,由于是从小白的视角来看问题的,所以对于初学者或多或少会有点帮助吧. 1:人工全连接神经网络和BP算法 <1>:人工 ...

  6. 人工神经网络 Artificial Neural Network

    2017-12-18 23:42:33 一.什么是深度学习 深度学习(deep neural network)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高 ...

  7. SIGAI深度学习第二集 人工神经网络1

    讲授神经网络的思想起源.神经元原理.神经网络的结构和本质.正向传播算法.链式求导及反向传播算法.神经网络怎么用于实际问题等 课程大纲: 神经网络的思想起源 神经元的原理 神经网络结构 正向传播算法 怎 ...

  8. 机器学习笔记之人工神经网络(ANN)

    人工神经网络(ANN)提供了一种普遍而且实际的方法从样例中学习值为实数.离散值或向量函数.人工神经网络由一系列简单的单元相互连接构成,其中每个单元有一定数量的实值输入,并产生单一的实值输出. 上面是一 ...

  9. 人工神经网络简介和单层网络实现AND运算--AForge.NET框架的使用(五)

    原文:人工神经网络简介和单层网络实现AND运算--AForge.NET框架的使用(五) 前面4篇文章说的是模糊系统,它不同于传统的值逻辑,理论基础是模糊数学,所以有些朋友看着有点迷糊,如果有兴趣建议参 ...

随机推荐

  1. 使用jquery刷新当前页面

    div的局部刷新 $(".dl").load(location.href+" .dl"); 全页面的刷新方法 window.location.reload()刷 ...

  2. linux命令学习之:mv

    mv命令是move的缩写,可以用来移动文件或者将文件改名(move (rename) files),是Linux系统下常用的命令,经常用来备份文件或者目录. 命令格式    mv [选项] 源文件或目 ...

  3. IE7下面踩得坑

    bug1.position:fixed:z-index:99; 出现了z-index:2的层级跑到他上面了, 为什么?会出现这问题??? 检查: 1你的固定定位的容器是不是被其他容器包裹,你包裹得容器 ...

  4. Vue 进度条 和 图片的动态更改

    <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title> ...

  5. mysql 添加外键

    create table class( cid tinyint unsigned primary key auto_increment, caption varchar(15) not null)en ...

  6. andorid 手机外部储存

    .xml <?xml version="1.0" encoding="utf-8"?> <LinearLayout xmlns:android ...

  7. XiaoKL学Python(D)argparse

    该文以Python 2为基础. 1. argparse简介 argparse使得编写用户友好的命令行接口更简单. argparse知道如何解析sys.argv. argparse 模块自动生成 “帮助 ...

  8. 图片延时加载原理 和 使用jquery实现的一个图片延迟加载插件(含图片延迟加载原理)

    图片加载技术分为:图片预加载和图片延时加载. javascript图片预加载和延时加载的区别主要体现在图片传输到客户端的时机上,都是为了提升用户体验的,延时加载又叫懒加载.两种技术的本质:两者的行为是 ...

  9. if __name__ == '__main__的理解

    模块之间引用不能循环成环,圆圈   模块的收搜   !!!把模块当作脚本执行 什么叫模块:py文件,如果一个py文件被导入了,他就是一个模块, 模块没有具体的调用过程 但是能对外提供功能   什么叫脚 ...

  10. 20172306《Java程序设计与数据结构》第十周学习总结

    20172306<Java程序设计>第十周学习总结 教材学习内容总结 本章主要的讲的是集合有关的知识: 1.集合与数据结构 - 集合是一种对象,集合表示一个专用于保存元素的对象,并该对象还 ...