TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)
从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作。这次我们要解决机器学习的经典问题,MNIST手写数字识别。
首先介绍一下数据集。请首先解压:TF_Net\Asset\mnist_png.tar.gz文件
文件夹内包括两个文件夹:training和validation,其中training文件夹下包括60000个训练图片validation下包括10000个评估图片,图片为28*28像素,分别放在0~9十个文件夹中。
程序总体流程和上一篇文章介绍的BMI分析程序基本一致,毕竟都是多元分类,有几点不一样。
1、BMI程序的特征数据(输入)为一维数组,包含两个数字,MNIST的特征数据为28*28的二位数组;
2、BMI程序的输出为3个,MNIST的输出为10个;
网络模型构建如下:
private readonly int img_rows = 28;
private readonly int img_cols = 28;
private readonly int num_classes = 10; // total classes
/// <summary>
/// 构建网络模型
/// </summary>
private Model BuildModel()
{
// 网络参数
int n_hidden_1 = 128; // 1st layer number of neurons.
int n_hidden_2 = 128; // 2nd layer number of neurons.
float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer>
{
keras.layers.InputLayer((img_rows,img_cols)),
keras.layers.Flatten(),
keras.layers.Rescaling(scale),
keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
}); return model;
}
这个网络里用到了两个新方法,需要解释一下:
1、Flatten方法:这里表示拉平,把28*28的二维数组拉平为含784个数据的一维数组,因为二维数组无法进行运算;
2、Rescaling 方法:就是对每个数据乘以一个系数,因为我们从图片获取的数据为每一个位点的灰度值,其取值范围为0~255,所以乘以一个系数将数据缩小到1以内,以免后面运算时溢出。
其它基本和上一篇文章介绍的差不多,全部代码如下:

/// <summary>
/// 通过神经网络来实现多元分类
/// </summary>
public class NN_MultipleClassification_BMI
{
private readonly Random random = new Random(1); // 网络参数
int num_features = 2; // data features
int num_classes = 3; // total output . public void Run()
{
var model = BuildModel();
model.summary(); Console.WriteLine("Press any key to continue...");
Console.ReadKey(); (NDArray train_x, NDArray train_y) = PrepareData(1000);
model.compile(optimizer: keras.optimizers.Adam(0.001f),
loss: keras.losses.SparseCategoricalCrossentropy(),
metrics: new[] { "accuracy" });
model.fit(train_x, train_y, batch_size: 128, epochs: 300); test(model);
} /// <summary>
/// 构建网络模型
/// </summary>
private Model BuildModel()
{
// 网络参数
int n_hidden_1 = 64; // 1st layer number of neurons.
int n_hidden_2 = 64; // 2nd layer number of neurons. var model = keras.Sequential(new List<ILayer>
{
keras.layers.InputLayer(num_features),
keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
}); return model;
} /// <summary>
/// 加载训练数据
/// </summary>
/// <param name="total_size"></param>
private (NDArray, NDArray) PrepareData(int total_size)
{
float[,] arrx = new float[total_size, num_features];
int[] arry = new int[total_size]; for (int i = 0; i < total_size; i++)
{
float weight = (float)random.Next(30, 100) / 100;
float height = (float)random.Next(140, 190) / 100;
float bmi = (weight * 100) / (height * height); arrx[i, 0] = weight;
arrx[i, 1] = height; switch (bmi)
{
case var x when x < 18.0f:
arry[i] = 0;
break; case var x when x >= 18.0f && x <= 28.0f:
arry[i] = 1;
break; case var x when x > 28.0f:
arry[i] = 2;
break;
}
} return (np.array(arrx), np.array(arry));
} /// <summary>
/// 消费模型
/// </summary>
private void test(Model model)
{
int test_size = 20;
for (int i = 0; i < test_size; i++)
{
float weight = (float)random.Next(40, 90) / 100;
float height = (float)random.Next(145, 185) / 100;
float bmi = (weight * 100) / (height * height); var test_x = np.array(new float[1, 2] { { weight, height } });
var pred_y = model.Apply(test_x); Console.WriteLine($"{i}:weight={(float)weight} \theight={height} \tBMI={bmi:0.0} \tPred:{pred_y[0].numpy()}");
}
}
}
另有两点说明:
1、由于对图片的读取比较耗时,所以我采用了一个方法,就是把读取到的数据序列化到一个二进制文件中,下次直接从二进制文件反序列化即可,大大加快处理速度。
2、我没有采用validation图片进行评估,只是简单选了20个样本测试了一下。
【相关资源】
源码:Git: https://gitee.com/seabluescn/tf_not.git
项目名称:NN_MultipleClassification_MNIST
TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)的更多相关文章
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- TensorFlow 之 手写数字识别MNIST
官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...
- BP神经网络的手写数字识别
BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...
- 利用c++编写bp神经网络实现手写数字识别详解
利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
- 5 TensorFlow入门笔记之RNN实现手写数字识别
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- 卷积神经网络CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- 【机器学习】BP神经网络实现手写数字识别
最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...
- BP神经网络(手写数字识别)
1实验环境 实验环境:CPU i7-3770@3.40GHz,内存8G,windows10 64位操作系统 实现语言:python 实验数据:Mnist数据集 程序使用的数据库是mnist手写数字数据 ...
随机推荐
- hadoop-uber作业模式
如果作业很小,就选择和自己在同一个JVM上运行任务,与在一个节点上顺序运行这些任务相比,当application master 判断在新的容器中的分配和运行任务的开销大于并行运行它们的开销时,就会发生 ...
- pyqt5 改写函数
重新改写了keyPressEvent() class TextEdit(QTextEdit): def __init__(self): QtWidgets.QTextEdit.__init__(sel ...
- MediaPlayer详解
[1]MediaPlayer 详细使用细则 [2]MediaPlayer使用详解_为新手准备 [3]MediaPlayer 概览
- Android Https相关完全解析
转载: 转载请标明出处: http://blog.csdn.net/lmj623565791/article/details/48129405: 本文出自:[张鸿洋的博客] 一.概述 其实这篇文章理论 ...
- Bootstrap-table动态表格
在开发中遇到一个需要动态生成table的需求,包括表头和数据.在调试的过程中遇到很多问题,包括数据分页,解决之后记录一下. 如下代码的数据加载流程: ①表头是动态的,在初始化table之前需要调一次后 ...
- EM配置问题
配置EM,首先要保证dbconsole在运行. C:\Users\dingqi>emctl start dbconsoleEnvironment variable ORACLE_UNQNAME ...
- Linux:expr、let、for、while、until、shift、if、case、break、continue、函数、select
1.expr计算整数变量值 格式 :expr arg 例子:计算(2+3)×4的值 1.分步计算,即先计算2+3,再对其和乘4 s=`expr 2 + 3` expr $s \* 4 2.一步完成计算 ...
- java的父类声明,子类实例化(强制类型转换导致异常ClassCastException)
一.使用原因 父类声明,子类实例化,既可以使用子类强大的功能,又可以抽取父类的共性. 二.使用要点 1.父类类型的引用可以调用父类中定义的所有属性和方法: 2.父类中方法只有在是父类中定义而在子类中没 ...
- java多线程并发编程中的锁
synchronized: https://www.cnblogs.com/dolphin0520/p/3923737.html Lock:https://www.cnblogs.com/dolphi ...
- 团队协作项目——SVN的使用
参考文献:https://www.cnblogs.com/rwh871212/p/6955489.html 老师接了一个新项目,需要团队共同完成开发任务,因此需要SVN.SVN是C/S架构: 1.服务 ...