Spark学习笔记——手写数字识别
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes, SVMWithSGD}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, RandomForest}
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Entropy /**
* Created by common on 17-5-17.
*/ case class LabeledPic(
label: Int,
pic: List[Double] = List()
) object DigitRecognizer { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("DigitRecgonizer").setMaster("local")
val sc = new SparkContext(conf)
// 去掉第一行,sed 1d train.csv > train_noheader.csv
val trainFile = "file:///media/common/工作/kaggle/DigitRecognizer/train_noheader.csv"
val trainRawData = sc.textFile(trainFile)
// 通过逗号对数据进行分割,生成数组的rdd
val trainRecords = trainRawData.map(line => line.split(",")) val trainData = trainRecords.map { r =>
val label = r(0).toInt
val features = r.slice(1, r.size).map(d => d.toDouble)
LabeledPoint(label, Vectors.dense(features))
} // // 使用贝叶斯模型
// val nbModel = NaiveBayes.train(trainData)
//
// val nbTotalCorrect = trainData.map { point =>
// if (nbModel.predict(point.features) == point.label) 1 else 0
// }.sum
// val nbAccuracy = nbTotalCorrect / trainData.count
//
// println("贝叶斯模型正确率:" + nbAccuracy)
//
// // 对测试数据进行预测
// val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
// // 通过逗号对数据进行分割,生成数组的rdd
// val testRecords = testRawData.map(line => line.split(","))
//
// val testData = testRecords.map { r =>
// val features = r.map(d => d.toDouble)
// Vectors.dense(features)
// }
// val predictions = nbModel.predict(testData).map(p => p.toInt)
// // 保存预测结果
// predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict") // // 使用线性回归模型
// val lrModel = new LogisticRegressionWithLBFGS()
// .setNumClasses(10)
// .run(trainData)
//
// val lrTotalCorrect = trainData.map { point =>
// if (lrModel.predict(point.features) == point.label) 1 else 0
// }.sum
// val lrAccuracy = lrTotalCorrect / trainData.count
//
// println("线性回归模型正确率:" + lrAccuracy)
//
// // 对测试数据进行预测
// val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
// // 通过逗号对数据进行分割,生成数组的rdd
// val testRecords = testRawData.map(line => line.split(","))
//
// val testData = testRecords.map { r =>
// val features = r.map(d => d.toDouble)
// Vectors.dense(features)
// }
// val predictions = lrModel.predict(testData).map(p => p.toInt)
// // 保存预测结果
// predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict1") // // 使用决策树模型
// val maxTreeDepth = 10
// val numClass = 10
// val dtModel = DecisionTree.train(trainData, Algo.Classification, Entropy, maxTreeDepth, numClass)
//
// val dtTotalCorrect = trainData.map { point =>
// if (dtModel.predict(point.features) == point.label) 1 else 0
// }.sum
// val dtAccuracy = dtTotalCorrect / trainData.count
//
// println("决策树模型正确率:" + dtAccuracy)
//
// // 对测试数据进行预测
// val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
// // 通过逗号对数据进行分割,生成数组的rdd
// val testRecords = testRawData.map(line => line.split(","))
//
// val testData = testRecords.map { r =>
// val features = r.map(d => d.toDouble)
// Vectors.dense(features)
// }
// val predictions = dtModel.predict(testData).map(p => p.toInt)
// // 保存预测结果
// predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict2") // // 使用随机森林模型
// val numClasses = 30
// val categoricalFeaturesInfo = Map[Int, Int]()
// val numTrees = 50
// val featureSubsetStrategy = "auto"
// val impurity = "gini"
// val maxDepth = 10
// val maxBins = 32
// val rtModel = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
//
// val rtTotalCorrect = trainData.map { point =>
// if (rtModel.predict(point.features) == point.label) 1 else 0
// }.sum
// val rtAccuracy = rtTotalCorrect / trainData.count
//
// println("随机森林模型正确率:" + rtAccuracy)
//
// // 对测试数据进行预测
// val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
// // 通过逗号对数据进行分割,生成数组的rdd
// val testRecords = testRawData.map(line => line.split(","))
//
// val testData = testRecords.map { r =>
// val features = r.map(d => d.toDouble)
// Vectors.dense(features)
// }
// val predictions = rtModel.predict(testData).map(p => p.toInt)
// // 保存预测结果
// predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict") } }
Spark学习笔记——手写数字识别的更多相关文章
- TessorFlow学习 之 手写数字识别的搭建
手写数字识别的搭建
- 机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别
一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...
- 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)
一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...
- 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
- SVM学习笔记(二)----手写数字识别
引言 上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的.考 ...
- 【深度学习系列】PaddlePaddle之手写数字识别
上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
随机推荐
- 本地搭建Wordpress博客网站(Windows)
最近在写一些web功能测试的一个主题分享,里边有一个分类是数据库测试,那么数据库测试有几个点的方法,其中有一个是学会看数据库的日志.由于公司内部的数据库日志我们测试人员暂时不开放查看,所以打算自己在本 ...
- 20172302 《Java软件结构与数据结构》第九周学习总结
2018年学习总结博客总目录:第一周 第二周 第三周 第四周 第五周 第六周 第七周 第八周 第九周 教材学习内容总结 第十五章 图 1.图:图(graph)是由一些点(vertex)和这些点之间的连 ...
- 如何调整eclipse中代码字体大小
找到windows--->preferences---->General------>Appearance---->color and fonts ---->ba ...
- oracle中的decode的使用(转)
地址:http://www.cnblogs.com/juddhu/archive/2012/03/07/2383101.html 含义解释:decode(条件,值1,返回值1,值2,返回值2,...值 ...
- Delphi发布ActiveX控件 制作CAB包 数字签名相关
文件: SignTool.rar 大小: 84KB 下载: 下载 最近我正在研究ActiveX技术.我使用Delphi 7创建了一个具有ActiveForm的ActiveX控件应用程序.这个控件产生一 ...
- asp.net mvc流程图4.6以前
- 小程序快速部署富文本插件wxParser
为了解决html2wxml在ios下字体过大问题,又发现一个比较好用的富文本插件:wxParser. 目前 wxParser 支持对一般的富文本内容包括标题.字体大小.对齐和列表等进行解析.同时也支持 ...
- Andorid之官方导航栏Toobar
在前面学习使用ActionBar的时候,我们就发现ActionBar中有些方法被标记为过时了,原来在android5.0之后,google推出了一个新的导航工具栏,官方将其定义为:A standard ...
- postgresql ltree类型
最近一个月使用Postgresql的时候,经常遇到ltree的数据,感觉有些别扭,可是有绕不过去.今天决心整理一下,以后使用方便一些. 一.简介 ltree是Postgresql的一个扩展类型,由两位 ...
- 最新整合maven+SSM+Tomcat 实现注册登录
mybatis学习 http://www.mybatis.org/mybatis-3/zh/index.html Spring学习:http://blog.csdn.net/king1425/arti ...