import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.preprocessing import LabelEncoder
from keras.optimizers import SGD
from keras.layers import LSTM # load dataset
dataframe = pd.read_csv("./data/iris1.csv", header=None)
dataset = dataframe.values
X = dataset[:, 0:19].astype(float)
dummy_y1 = dataset[:, 19]
m,n=1682,6
dum_imax=np.zeros((m,n))
# print(type(dum_imax))
for i in range(m):
# print(i)
# exit()
if dummy_y1[i]!=0:
dum_imax[i][dummy_y1[i]-1]=1
else:
dum_imax[i][5]=1
# print(dum_imax)
dummy_y =dum_imax
print(dummy_y)
print(type(dummy_y[0][0])) def baseline_model():
model = Sequential()
model.add(Dense(output_dim=50, input_dim=19, activation='relu'))
# model.add(LSTM(128))
model.add(Dropout(0.4))
model.add(Dense(output_dim=6, input_dim=50, activation='softmax'))
# Compile model
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
# model.compile(loss='categorical_crossentropy', optimizer=sgd)
#编译模型。由于我们做的是二元分类,所以我们指定损失函数为binary_crossentropy,以及模式为binary
#另外常见的损失函数还有mean_squared_error、categorical_crossentropy等,请阅读帮助文件。
#求解方法我们指定用adam,还有sgd、rmsprop等可选
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
return model
estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=40, batch_size=256)
print(estimator) # splitting data into training set and test set. If random_state is set to an integer, the split datasets are fixed.
X_train, X_test, Y_train, Y_test = train_test_split(X, dummy_y, test_size=0.2, random_state=0)#train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,
print(len(X_train[0]))
print(len(Y_train[0]))
estimator.fit(X_train, Y_train,nb_epoch = 100)#训练模型,学习一百次 # make predictions
print(X_test)
pred = estimator.predict(X_test)
print(pred)
# init_lables = encoder.inverse_transform(pred)
# print(init_lables) # inverse numeric variables to initial categorical labels
# init_lables = encoder.inverse_transform(pred)
# print(init_lables) # k-fold cross-validate
# seed = 42
# np.random.seed(seed)
'''
n_splits : 默认3,最小为2;K折验证的K值
shuffle : 默认False;shuffle会对数据产生随机搅动(洗牌)
random_state :默认None,随机种子
'''
kfold = KFold(n_splits=5, shuffle=True)#定义5折,在对数据进行划分之前,对数据进行随机混洗 results = cross_val_score(estimator, X, dummy_y, cv=kfold)#在数据集上,使用k fold交叉验证,对估计器estimator进行评估。
print("baseline:%.2f%%(%.2f%%)"%(results.mean()*100,results.std()*100))#返回的结果,是10次数据集划分后,每次的评估结果。评估结果包括平均准确率和标准差

Keras人工神经网络多分类(SGD)的更多相关文章

  1. keras人工神经网络构建入门

    //2019.07.29-301.Keras 是提供一些高度可用神经网络框架的 Python API ,能帮助你快速的构建和训练自己的深度学习模型,它的后端是 TensorFlow 或者 Theano ...

  2. [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)

    3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...

  3. neurosolutions 人工神经网络集成开发环境 keras

    人工神经网络集成开发环境 :  http://www.neurosolutions.com/ keras:   https://github.com/fchollet/keras 文档    http ...

  4. [DL学习笔记]从人工神经网络到卷积神经网络_2_卷积神经网络

    先一层一层的说卷积神经网络是啥: 1:卷积层,特征提取 我们输入这样一幅图片(28*28): 如果用传统神经网络,下一层的每个神经元将连接到输入图片的每一个像素上去,但是在卷积神经网络中,我们只把输入 ...

  5. [DL学习笔记]从人工神经网络到卷积神经网络_1_神经网络和BP算法

    前言:这只是我的一个学习笔记,里边肯定有不少错误,还希望有大神能帮帮找找,由于是从小白的视角来看问题的,所以对于初学者或多或少会有点帮助吧. 1:人工全连接神经网络和BP算法 <1>:人工 ...

  6. 人工神经网络 Artificial Neural Network

    2017-12-18 23:42:33 一.什么是深度学习 深度学习(deep neural network)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高 ...

  7. SIGAI深度学习第二集 人工神经网络1

    讲授神经网络的思想起源.神经元原理.神经网络的结构和本质.正向传播算法.链式求导及反向传播算法.神经网络怎么用于实际问题等 课程大纲: 神经网络的思想起源 神经元的原理 神经网络结构 正向传播算法 怎 ...

  8. 机器学习笔记之人工神经网络(ANN)

    人工神经网络(ANN)提供了一种普遍而且实际的方法从样例中学习值为实数.离散值或向量函数.人工神经网络由一系列简单的单元相互连接构成,其中每个单元有一定数量的实值输入,并产生单一的实值输出. 上面是一 ...

  9. 人工神经网络简介和单层网络实现AND运算--AForge.NET框架的使用(五)

    原文:人工神经网络简介和单层网络实现AND运算--AForge.NET框架的使用(五) 前面4篇文章说的是模糊系统,它不同于传统的值逻辑,理论基础是模糊数学,所以有些朋友看着有点迷糊,如果有兴趣建议参 ...

随机推荐

  1. swift - view的指定位置切圆角

    1. extension UIView{ func addCorner(conrners: UIRectCorner , radius: CGFloat) { let maskPath = UIBez ...

  2. PHP系统编程--PHP进程信号处理(转)

    原地址:https://www.cnblogs.com/linzhenjie/p/5485436.html PHP的pcntl扩展提供了信号处理的功能,利用它可以让PHP来接管信号的处理,在开发服务器 ...

  3. Ubuntu 16.04安装JDK并配置环境变量-【小白版】

    系统版本:Ubuntu 16.04 JDK版本:jdk1.8.0_121 1.官网下载JDK文件jdk-8u121-linux-x64.tar.gz 我这里下的是最新版,其他版本也可以 2.创建一个目 ...

  4. JS中判断某个字符串是否包含另一个字符串的五种方法

    String对象的方法 方法一: indexOf()   (推荐) ? 1 2 var str = "123" console.log(str.indexOf("2&qu ...

  5. mysql 存储过程 与 循环

    mysql 操作同样有循环语句操作,三种标准循环模式:while, loop,repeat, 外加一种非标准循环:goto [在c或c#中貌似出现过类型循环但是一般不建议用!] 一般格式为:delim ...

  6. laravel中类似于thinkPHP中trace功能

    答案来自https://segmentfault.com/q/1010000007716945 一楼: 到 https://packagist.org 上搜索你想要的关键词,比如查debugbar 列 ...

  7. 7.Mysql存储引擎

    7.表类型(存储引擎)的选择7.1 Mysql存储引擎概述 mysql支持插件式存储引擎,即存储引擎以插件形式存在于mysql库中. mysql支持的存储引擎包括:MyISAM.InnoDB.BDB. ...

  8. 20172325『Java程序设计』课程 结对编程练习_四则运算第三周阶段总结

    20172325『Java程序设计』课程 结对编程练习_四则运算第三周阶段总结 结对伙伴 学号:20172306 姓名:刘辰 在这次项目的完成过程中刘辰同学付出了很多,在代码的实践上完成的很出色,在技 ...

  9. vim删除单词

    参考资料: https://blog.csdn.net/grey_csdn/article/details/72355735 混迹于Windows.Linux以及Mac,选择加强自己的VIM水平应该不 ...

  10. 设计师们做UI设计和交互设计、界面设计等一般会去什么网站呢?

    明明可靠颜值吃饭,却偏偏要靠才华立身,UI设计师就是这样一群神奇的物种.面对“大的同时小一点”.“五彩斑斓黑”.“下班之前给我”……这些甲方大大刁钻的需求,设计师每天都在咬牙微笑讨生活.你可以批评我的 ...