Spark学习笔记——手写数字识别
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes, SVMWithSGD}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, RandomForest}
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Entropy /**
* Created by common on 17-5-17.
*/ case class LabeledPic(
label: Int,
pic: List[Double] = List()
) object DigitRecognizer { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("DigitRecgonizer").setMaster("local")
val sc = new SparkContext(conf)
// 去掉第一行,sed 1d train.csv > train_noheader.csv
val trainFile = "file:///media/common/工作/kaggle/DigitRecognizer/train_noheader.csv"
val trainRawData = sc.textFile(trainFile)
// 通过逗号对数据进行分割,生成数组的rdd
val trainRecords = trainRawData.map(line => line.split(",")) val trainData = trainRecords.map { r =>
val label = r(0).toInt
val features = r.slice(1, r.size).map(d => d.toDouble)
LabeledPoint(label, Vectors.dense(features))
} // // 使用贝叶斯模型
// val nbModel = NaiveBayes.train(trainData)
//
// val nbTotalCorrect = trainData.map { point =>
// if (nbModel.predict(point.features) == point.label) 1 else 0
// }.sum
// val nbAccuracy = nbTotalCorrect / trainData.count
//
// println("贝叶斯模型正确率:" + nbAccuracy)
//
// // 对测试数据进行预测
// val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
// // 通过逗号对数据进行分割,生成数组的rdd
// val testRecords = testRawData.map(line => line.split(","))
//
// val testData = testRecords.map { r =>
// val features = r.map(d => d.toDouble)
// Vectors.dense(features)
// }
// val predictions = nbModel.predict(testData).map(p => p.toInt)
// // 保存预测结果
// predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict") // // 使用线性回归模型
// val lrModel = new LogisticRegressionWithLBFGS()
// .setNumClasses(10)
// .run(trainData)
//
// val lrTotalCorrect = trainData.map { point =>
// if (lrModel.predict(point.features) == point.label) 1 else 0
// }.sum
// val lrAccuracy = lrTotalCorrect / trainData.count
//
// println("线性回归模型正确率:" + lrAccuracy)
//
// // 对测试数据进行预测
// val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
// // 通过逗号对数据进行分割,生成数组的rdd
// val testRecords = testRawData.map(line => line.split(","))
//
// val testData = testRecords.map { r =>
// val features = r.map(d => d.toDouble)
// Vectors.dense(features)
// }
// val predictions = lrModel.predict(testData).map(p => p.toInt)
// // 保存预测结果
// predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict1") // // 使用决策树模型
// val maxTreeDepth = 10
// val numClass = 10
// val dtModel = DecisionTree.train(trainData, Algo.Classification, Entropy, maxTreeDepth, numClass)
//
// val dtTotalCorrect = trainData.map { point =>
// if (dtModel.predict(point.features) == point.label) 1 else 0
// }.sum
// val dtAccuracy = dtTotalCorrect / trainData.count
//
// println("决策树模型正确率:" + dtAccuracy)
//
// // 对测试数据进行预测
// val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
// // 通过逗号对数据进行分割,生成数组的rdd
// val testRecords = testRawData.map(line => line.split(","))
//
// val testData = testRecords.map { r =>
// val features = r.map(d => d.toDouble)
// Vectors.dense(features)
// }
// val predictions = dtModel.predict(testData).map(p => p.toInt)
// // 保存预测结果
// predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict2") // // 使用随机森林模型
// val numClasses = 30
// val categoricalFeaturesInfo = Map[Int, Int]()
// val numTrees = 50
// val featureSubsetStrategy = "auto"
// val impurity = "gini"
// val maxDepth = 10
// val maxBins = 32
// val rtModel = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
//
// val rtTotalCorrect = trainData.map { point =>
// if (rtModel.predict(point.features) == point.label) 1 else 0
// }.sum
// val rtAccuracy = rtTotalCorrect / trainData.count
//
// println("随机森林模型正确率:" + rtAccuracy)
//
// // 对测试数据进行预测
// val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
// // 通过逗号对数据进行分割,生成数组的rdd
// val testRecords = testRawData.map(line => line.split(","))
//
// val testData = testRecords.map { r =>
// val features = r.map(d => d.toDouble)
// Vectors.dense(features)
// }
// val predictions = rtModel.predict(testData).map(p => p.toInt)
// // 保存预测结果
// predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict") } }
Spark学习笔记——手写数字识别的更多相关文章
- TessorFlow学习 之 手写数字识别的搭建
手写数字识别的搭建
- 机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别
一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...
- 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)
一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...
- 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
- SVM学习笔记(二)----手写数字识别
引言 上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的.考 ...
- 【深度学习系列】PaddlePaddle之手写数字识别
上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
随机推荐
- swagger知识点补充
1. swagger知识点补充 1.1. 概述 在swagger的使用过程中,除了网上常见的例子,还会有很多细节上的东西需要注意和改写,这里我列几点我使用过程中遇到的问题和改进方式 1.2. 知识点 ...
- 多线程里面this.getName()和currentThread.getName()有什么区别
public class hello extends Thread { public hello(){ System.out.println("Thread.currentThread(). ...
- iOS Application Extension
链接: iOS App Extension入门 iOS10通知及通知拓展Extension使用详解(附Demo) iOS10 通知extension之 Content Extension
- IntelliJ IDEA2018.1、2017.3破解教程
(1)下载破解补丁 把下载的破解补丁放在你的idea的安装目录下的bin的目录下面(如下图所示),本文示例为G:\idea\IntelliJ IDEA 2017.3.4 破解补丁下载:http://i ...
- nginx多站路由配置tomcat
server { listen 80; server_name 1.goal.cn; index index index.html index.htm index.jsp; root /www/ser ...
- VS2013中Python学习笔记[环境搭建]
前言 Python是一个高层次的结合了解释性.编译性.互动性和面向对象的脚本语言. Python的设计具有很强的可读性,相比其他语言经常使用英文关键字,其他语言的一些标点符号,它具有比其他语言更有特色 ...
- WebMagic编译时提示Failure to transfer org.apache.maven.plugins:maven-surefire-plugin:pom:2.18的解决方法
问题描述: 从http://git.oschina.net/flashsword20/webmagic 下载最新代码,按照http://webmagic.io/docs/zh/posts/ch3 ...
- oracle获取今年在内的前几年、后几年
前几年 select to_char(sysdate, 'yyyy') - level + 1 years from dual connect by level <= num num:即想获取几 ...
- Morris图表使用小记
挺好用的,碰到几个问题,有的是瞎试解决了的: 1.我想折线图能够响应单击事件,即点击某个节点后,就能加载进一步的信息,帮助没找到,参照另外一个地方的写法,居然支持事件 Morris.Line({ el ...
- GraphQL循环引用的问题
下面的代码中, 由于friends字段引用了PersonType字段,而friends本身又是PersonType的一部分,在运行的时候会报错: Expected undefined to be a ...