TensorFlow.NET机器学习入门【6】采用神经网络处理Fashion-MNIST
"如果一个算法在MNIST上不work,那么它就根本没法用;而如果它在MNIST上work,它在其他数据上也可能不work"。
—— 马克吐温
上一篇文章我们实现了一个MNIST手写数字识别的程序,通过一个简单的两层神经网络,就轻松获得了98%的识别成功率。这个成功率不代表你的网络是有效的,因为MNIST实在是太简单了,我们需要更复杂的数据集来检验网络的有效性!这就有了Fashion-MNIST数据集,它采用10种服装的图片来取代数字0~9,除此之外,其图片大小、数量均和MNIST一致。
上篇文章的代码几乎不用改动,只要改个获取原始图片文件的文件夹名称即可。
程序运行结果识别成功率大约为82%左右。
我们可以对网络进行调整,看能否提高识别率,具体可用的方法:
1、增加网络层
2、增加神经元个数
3、改用其它激活函数
试验结果表明,不管如何调整,识别率始终上不去多少。可见该网络方案已经碰到了瓶颈,如果要大幅度提高识别率必须要采取新的方案了。
下篇文章我们将介绍卷积神经网络(CNN)的应用,通过CNN来处理图像数据将是一个更好、更科学的解决方案。
由于本文代码和上一篇文章的代码高度一致,这里就不再详细说明了。全部代码如下:

/// <summary>
/// 采用神经网络处理Fashion-MNIST数据集
/// </summary>
public class NN_MultipleClassification_Fashion_MNIST
{
private readonly string TrainImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train";
private readonly string TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\test";
private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_data.bin";
private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_label.bin"; private readonly int img_rows = 28;
private readonly int img_cols = 28;
private readonly int num_classes = 10; // total classes public void Run()
{
var model = BuildModel();
model.summary(); model.compile(optimizer: keras.optimizers.Adam(0.001f),
loss: keras.losses.SparseCategoricalCrossentropy(),
metrics: new[] { "accuracy" }); (NDArray train_x, NDArray train_y) = LoadTrainingData();
model.fit(train_x, train_y, batch_size: 1024, epochs: 20); test(model);
} /// <summary>
/// 构建网络模型
/// </summary>
private Model BuildModel()
{
// 网络参数
int n_hidden_1 = 128; // 1st layer number of neurons.
int n_hidden_2 = 128; // 2nd layer number of neurons.
float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer>
{
keras.layers.InputLayer((img_rows,img_cols)),
keras.layers.Flatten(),
keras.layers.Rescaling(scale),
keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
}); return model;
} /// <summary>
/// 加载训练数据
/// </summary>
/// <param name="total_size"></param>
private (NDArray, NDArray) LoadTrainingData()
{
try
{
Console.WriteLine("Load data");
IFormatter serializer = new BinaryFormatter();
FileStream loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read);
float[,,] arrx = serializer.Deserialize(loadFile) as float[,,]; loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read);
int[] arry = serializer.Deserialize(loadFile) as int[];
Console.WriteLine("Load data success");
return (np.array(arrx), np.array(arry));
}
catch (Exception ex)
{
Console.WriteLine($"Load data Exception:{ex.Message}");
return LoadRawData();
}
} private (NDArray, NDArray) LoadRawData()
{
Console.WriteLine("LoadRawData"); int total_size = 60000;
float[,,] arrx = new float[total_size, img_rows, img_cols];
int[] arry = new int[total_size]; int count = 0; DirectoryInfo RootDir = new DirectoryInfo(TrainImagePath);
foreach (var Dir in RootDir.GetDirectories())
{
foreach (var file in Dir.GetFiles("*.png"))
{
Bitmap bmp = (Bitmap)Image.FromFile(file.FullName);
if (bmp.Width != img_cols || bmp.Height != img_rows)
{
continue;
} for (int row = 0; row < img_rows; row++)
for (int col = 0; col < img_cols; col++)
{
var pixel = bmp.GetPixel(col, row);
int val = (pixel.R + pixel.G + pixel.B) / 3; arrx[count, row, col] = val;
arry[count] = int.Parse(Dir.Name);
} count++;
} Console.WriteLine($"Load image data count={count}");
} Console.WriteLine("LoadRawData finished");
//Save Data
Console.WriteLine("Save data");
IFormatter serializer = new BinaryFormatter(); //开始序列化
FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write);
serializer.Serialize(saveFile, arrx);
saveFile.Close(); saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write);
serializer.Serialize(saveFile, arry);
saveFile.Close();
Console.WriteLine("Save data finished"); return (np.array(arrx), np.array(arry));
} /// <summary>
/// 消费模型
/// </summary>
private void test(Model model)
{
Random rand = new Random(1); DirectoryInfo TestDir = new DirectoryInfo(TestImagePath);
foreach (var ChildDir in TestDir.GetDirectories())
{
Console.WriteLine($"Folder:【{ChildDir.Name}】");
var Files = ChildDir.GetFiles("*.png");
for (int i = 0; i < 10; i++)
{
int index = rand.Next(1000);
var image = Files[index]; var x = LoadImage(image.FullName);
var pred_y = model.Apply(x);
var result = argmax(pred_y[0].numpy()); Console.WriteLine($"FileName:{image.Name}\tPred:{result}");
}
}
} private NDArray LoadImage(string filename)
{
float[,,] arrx = new float[1, img_rows, img_cols];
Bitmap bmp = (Bitmap)Image.FromFile(filename); for (int row = 0; row < img_rows; row++)
for (int col = 0; col < img_cols; col++)
{
var pixel = bmp.GetPixel(col, row);
int val = (pixel.R + pixel.G + pixel.B) / 3;
arrx[0, row, col] = val;
} return np.array(arrx);
} private int argmax(NDArray array)
{
var arr = array.reshape(-1); float max = 0;
for (int i = 0; i < 10; i++)
{
if (arr[i] > max)
{
max = arr[i];
}
} for (int i = 0; i < 10; i++)
{
if (arr[i] == max)
{
return i;
}
} return 0;
}
}
【相关资源】
源码:Git: https://gitee.com/seabluescn/tf_not.git
项目名称:NN_MultipleClassification_Fashion_MNIST
TensorFlow.NET机器学习入门【6】采用神经网络处理Fashion-MNIST的更多相关文章
- TensorFlow.NET机器学习入门【3】采用神经网络实现非线性回归
上一篇文章我们介绍的线性模型的求解,但有很多模型是非线性的,比如: 这里表示有两个输入,一个输出. 现在我们已经不能采用y=ax+b的形式去定义一个函数了,我们只能知道输入变量的数量,但不知道某个变量 ...
- TensorFlow.NET机器学习入门【4】采用神经网络处理分类问题
上一篇文章我们介绍了通过神经网络来处理一个非线性回归的问题,这次我们将采用神经网络来处理一个多元分类的问题. 这次我们解决这样一个问题:输入一个人的身高和体重的数据,程序判断出这个人的身材状况,一共三 ...
- TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)
从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作.这次我们要解决机器学习的经典问题,MNIST手写数字识别. 首先介绍一下数据集.请首先解压:TF_Net\Asset\mnist_png. ...
- TensorFlow.NET机器学习入门【7】采用卷积神经网络(CNN)处理Fashion-MNIST
本文将介绍如何采用卷积神经网络(CNN)来处理Fashion-MNIST数据集. 程序流程如下: 1.准备样本数据 2.构建卷积神经网络模型 3.网络学习(训练) 4.消费.测试 除了网络模型的构建, ...
- TensorFlow.NET机器学习入门【8】采用GPU进行学习
随着网络越来约复杂,训练难度越来越大,有条件的可以采用GPU进行学习.本文介绍如何在GPU环境下使用TensorFlow.NET. TensorFlow.NET使用GPU非常的简单,代码不用做任何修改 ...
- TensorFlow.NET机器学习入门【0】前言与目录
曾经学习过一段时间ML.NET的知识,ML.NET是微软提供的一套机器学习框架,相对于其他的一些机器学习框架,ML.NET侧重于消费现有的网络模型,不太好自定义自己的网络模型,底层实现也做了高度封装. ...
- TensorFlow.NET机器学习入门【1】开发环境与类型简介
项目开发环境为Visual Studio 2019 + .Net 5 创建新项目后首先通过Nuget引入相关包: SciSharp.TensorFlow.Redist是Google提供的TensorF ...
- TensorFlow.NET机器学习入门【2】线性回归
回归分析用于分析输入变量和输出变量之间的一种关系,其中线性回归是最简单的一种. 设: Y=wX+b,现已知一组X(输入)和Y(输出)的值,要求出w和b的值. 举个例子:快年底了,销售部门要发年终奖了, ...
- 45、Docker 加 tensorflow的机器学习入门初步
[1]最近领导天天在群里发一些机器学习的链接,搞得好像我们真的要搞机器学习似的,吃瓜群众感觉好神奇呀. 第一步 其实也是最后一步,就是网上百度一下,Docker Toolbox,下载下来,下载,安装之 ...
随机推荐
- C++构造函数和析构函数初步认识
构造函数 1.构造函数与类名相同,是特殊的公有成员函数.2.构造函数无函数返回类型说明,实际上构造函数是有返回值的,其返回值类型即为构造函数所构建到的对象.3.当新对象被建立时,构造函数便被自动调用, ...
- mysql之join浅析
1.可以使用join吗?使用join有什么问题呢?-- >超过3个表不使用join,笛卡尔积问题 -->这些问题是怎么造成的呢? 如果可以使用 Index Nested-Loop Join ...
- 【编程思想】【设计模式】【创建模式creational】Pool
Python版 https://github.com/faif/python-patterns/blob/master/creational/pool.py #!/usr/bin/env python ...
- Shell脚本实现自动修改IP地址
作为一名Linux SA,日常运维中很多地方都会用到脚本,而服务器的ip一般采用静态ip或者MAC绑定,当然后者比较操作起来相对繁琐,而前者我们可以设置主机名.ip信息.网关等配置.修改成特定的主机名 ...
- 【Linux】【Shell】【Basic】文件查找locate,find
1.locate: 1.1. 简介:依赖于事先构建好的索引库: 系统自动实现(周期性任务): 手动更新数据库(updatedb): 1.2. 工作特性:查找速度快:模糊 ...
- jQuery遍历的几种方式
一.jQuery对象遍历 1 <script type="text/javascript" src="js/jquery-3.4.1.js">< ...
- 【C/C++】习题3-7 DNA/算法竞赛入门经典/数组与字符串
[题目] 输入m组n长的DNA序列,要求找出和其他Hamming距离最小的那个序列,求其与其他的Hamming距离总和. 如果有多个序列,求字典序最小的. [注]这道题是我理解错误,不是找出输入的序列 ...
- 【JAVA今法修真】 第四章 redis特性 击穿雪崩!
感谢这段时间大家的支持,关注我的微信号:南橘ryc ,回复云小霄,就可以获取到最新的福利靓照一张,还等什么,赶快来加入我们吧~ "明日便是决赛了,咋只会用法器没练过法术呢.". 选 ...
- Linux下安装中文字体
目录 一.Centos系列 二.Ubuntu系列 一.Centos系列 1.安装字体库 yum -y install fontconfig 2.添加中文字体,建立存储中文字体的文件夹 mkdir /u ...
- GaussDB(DWS)中共享消息队列实现的三大功能
摘要:本文将详细介绍GaussDB(DWS)中共享消息队列的实现. 本文分享自华为云社区<GaussDB(DWS)CBB组件之共享消息队列介绍>,作者:疯狂朔朔. 1)共享消息队列是什么? ...