TensorFlow.NET机器学习入门【6】采用神经网络处理Fashion-MNIST
"如果一个算法在MNIST上不work,那么它就根本没法用;而如果它在MNIST上work,它在其他数据上也可能不work"。
—— 马克吐温
上一篇文章我们实现了一个MNIST手写数字识别的程序,通过一个简单的两层神经网络,就轻松获得了98%的识别成功率。这个成功率不代表你的网络是有效的,因为MNIST实在是太简单了,我们需要更复杂的数据集来检验网络的有效性!这就有了Fashion-MNIST数据集,它采用10种服装的图片来取代数字0~9,除此之外,其图片大小、数量均和MNIST一致。
上篇文章的代码几乎不用改动,只要改个获取原始图片文件的文件夹名称即可。
程序运行结果识别成功率大约为82%左右。
我们可以对网络进行调整,看能否提高识别率,具体可用的方法:
1、增加网络层
2、增加神经元个数
3、改用其它激活函数
试验结果表明,不管如何调整,识别率始终上不去多少。可见该网络方案已经碰到了瓶颈,如果要大幅度提高识别率必须要采取新的方案了。
下篇文章我们将介绍卷积神经网络(CNN)的应用,通过CNN来处理图像数据将是一个更好、更科学的解决方案。
由于本文代码和上一篇文章的代码高度一致,这里就不再详细说明了。全部代码如下:


/// <summary>
/// 采用神经网络处理Fashion-MNIST数据集
/// </summary>
public class NN_MultipleClassification_Fashion_MNIST
{
private readonly string TrainImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train";
private readonly string TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\test";
private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_data.bin";
private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_label.bin"; private readonly int img_rows = 28;
private readonly int img_cols = 28;
private readonly int num_classes = 10; // total classes public void Run()
{
var model = BuildModel();
model.summary(); model.compile(optimizer: keras.optimizers.Adam(0.001f),
loss: keras.losses.SparseCategoricalCrossentropy(),
metrics: new[] { "accuracy" }); (NDArray train_x, NDArray train_y) = LoadTrainingData();
model.fit(train_x, train_y, batch_size: 1024, epochs: 20); test(model);
} /// <summary>
/// 构建网络模型
/// </summary>
private Model BuildModel()
{
// 网络参数
int n_hidden_1 = 128; // 1st layer number of neurons.
int n_hidden_2 = 128; // 2nd layer number of neurons.
float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer>
{
keras.layers.InputLayer((img_rows,img_cols)),
keras.layers.Flatten(),
keras.layers.Rescaling(scale),
keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
}); return model;
} /// <summary>
/// 加载训练数据
/// </summary>
/// <param name="total_size"></param>
private (NDArray, NDArray) LoadTrainingData()
{
try
{
Console.WriteLine("Load data");
IFormatter serializer = new BinaryFormatter();
FileStream loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read);
float[,,] arrx = serializer.Deserialize(loadFile) as float[,,]; loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read);
int[] arry = serializer.Deserialize(loadFile) as int[];
Console.WriteLine("Load data success");
return (np.array(arrx), np.array(arry));
}
catch (Exception ex)
{
Console.WriteLine($"Load data Exception:{ex.Message}");
return LoadRawData();
}
} private (NDArray, NDArray) LoadRawData()
{
Console.WriteLine("LoadRawData"); int total_size = 60000;
float[,,] arrx = new float[total_size, img_rows, img_cols];
int[] arry = new int[total_size]; int count = 0; DirectoryInfo RootDir = new DirectoryInfo(TrainImagePath);
foreach (var Dir in RootDir.GetDirectories())
{
foreach (var file in Dir.GetFiles("*.png"))
{
Bitmap bmp = (Bitmap)Image.FromFile(file.FullName);
if (bmp.Width != img_cols || bmp.Height != img_rows)
{
continue;
} for (int row = 0; row < img_rows; row++)
for (int col = 0; col < img_cols; col++)
{
var pixel = bmp.GetPixel(col, row);
int val = (pixel.R + pixel.G + pixel.B) / 3; arrx[count, row, col] = val;
arry[count] = int.Parse(Dir.Name);
} count++;
} Console.WriteLine($"Load image data count={count}");
} Console.WriteLine("LoadRawData finished");
//Save Data
Console.WriteLine("Save data");
IFormatter serializer = new BinaryFormatter(); //开始序列化
FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write);
serializer.Serialize(saveFile, arrx);
saveFile.Close(); saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write);
serializer.Serialize(saveFile, arry);
saveFile.Close();
Console.WriteLine("Save data finished"); return (np.array(arrx), np.array(arry));
} /// <summary>
/// 消费模型
/// </summary>
private void test(Model model)
{
Random rand = new Random(1); DirectoryInfo TestDir = new DirectoryInfo(TestImagePath);
foreach (var ChildDir in TestDir.GetDirectories())
{
Console.WriteLine($"Folder:【{ChildDir.Name}】");
var Files = ChildDir.GetFiles("*.png");
for (int i = 0; i < 10; i++)
{
int index = rand.Next(1000);
var image = Files[index]; var x = LoadImage(image.FullName);
var pred_y = model.Apply(x);
var result = argmax(pred_y[0].numpy()); Console.WriteLine($"FileName:{image.Name}\tPred:{result}");
}
}
} private NDArray LoadImage(string filename)
{
float[,,] arrx = new float[1, img_rows, img_cols];
Bitmap bmp = (Bitmap)Image.FromFile(filename); for (int row = 0; row < img_rows; row++)
for (int col = 0; col < img_cols; col++)
{
var pixel = bmp.GetPixel(col, row);
int val = (pixel.R + pixel.G + pixel.B) / 3;
arrx[0, row, col] = val;
} return np.array(arrx);
} private int argmax(NDArray array)
{
var arr = array.reshape(-1); float max = 0;
for (int i = 0; i < 10; i++)
{
if (arr[i] > max)
{
max = arr[i];
}
} for (int i = 0; i < 10; i++)
{
if (arr[i] == max)
{
return i;
}
} return 0;
}
}
【相关资源】
源码:Git: https://gitee.com/seabluescn/tf_not.git
项目名称:NN_MultipleClassification_Fashion_MNIST
TensorFlow.NET机器学习入门【6】采用神经网络处理Fashion-MNIST的更多相关文章
- TensorFlow.NET机器学习入门【3】采用神经网络实现非线性回归
上一篇文章我们介绍的线性模型的求解,但有很多模型是非线性的,比如: 这里表示有两个输入,一个输出. 现在我们已经不能采用y=ax+b的形式去定义一个函数了,我们只能知道输入变量的数量,但不知道某个变量 ...
- TensorFlow.NET机器学习入门【4】采用神经网络处理分类问题
上一篇文章我们介绍了通过神经网络来处理一个非线性回归的问题,这次我们将采用神经网络来处理一个多元分类的问题. 这次我们解决这样一个问题:输入一个人的身高和体重的数据,程序判断出这个人的身材状况,一共三 ...
- TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)
从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作.这次我们要解决机器学习的经典问题,MNIST手写数字识别. 首先介绍一下数据集.请首先解压:TF_Net\Asset\mnist_png. ...
- TensorFlow.NET机器学习入门【7】采用卷积神经网络(CNN)处理Fashion-MNIST
本文将介绍如何采用卷积神经网络(CNN)来处理Fashion-MNIST数据集. 程序流程如下: 1.准备样本数据 2.构建卷积神经网络模型 3.网络学习(训练) 4.消费.测试 除了网络模型的构建, ...
- TensorFlow.NET机器学习入门【8】采用GPU进行学习
随着网络越来约复杂,训练难度越来越大,有条件的可以采用GPU进行学习.本文介绍如何在GPU环境下使用TensorFlow.NET. TensorFlow.NET使用GPU非常的简单,代码不用做任何修改 ...
- TensorFlow.NET机器学习入门【0】前言与目录
曾经学习过一段时间ML.NET的知识,ML.NET是微软提供的一套机器学习框架,相对于其他的一些机器学习框架,ML.NET侧重于消费现有的网络模型,不太好自定义自己的网络模型,底层实现也做了高度封装. ...
- TensorFlow.NET机器学习入门【1】开发环境与类型简介
项目开发环境为Visual Studio 2019 + .Net 5 创建新项目后首先通过Nuget引入相关包: SciSharp.TensorFlow.Redist是Google提供的TensorF ...
- TensorFlow.NET机器学习入门【2】线性回归
回归分析用于分析输入变量和输出变量之间的一种关系,其中线性回归是最简单的一种. 设: Y=wX+b,现已知一组X(输入)和Y(输出)的值,要求出w和b的值. 举个例子:快年底了,销售部门要发年终奖了, ...
- 45、Docker 加 tensorflow的机器学习入门初步
[1]最近领导天天在群里发一些机器学习的链接,搞得好像我们真的要搞机器学习似的,吃瓜群众感觉好神奇呀. 第一步 其实也是最后一步,就是网上百度一下,Docker Toolbox,下载下来,下载,安装之 ...
随机推荐
- 谈一谈 DDD
一.前言 最近 10 年的互联网发展,从电子商务到移动互联,再到"互联网+"与传统行业的互联网转型,是一个非常痛苦的转型过程.在这个过程中,一方面会给我们带来诸多的挑战,另一方面又 ...
- Java项目发现==顺手改成equals之后,会发生什么?
最近发生一件很尴尬的事情,在维护一个 Java 项目的时候,发现有使用 == 来比较两个对象的属性, 于是顺手就把 == 改成了 equals.悲剧发生...... == 和 equals 的区别 = ...
- 试了下GoAsm
在VC里我们: #include <windows.h> DWORD dwNumberOfBytesWritten; int main() { HANDLE hStdOut = GetSt ...
- 源码分析-Consumer
消息消费概述 消息消费以组的模式开展,一个消费组内可以包含多个消费者,每一个消费者组可订阅多个主题,消费组之间有集群模式和广播模式两种消费模式. 集群模式,主题下的同一条消息只允许被其中一个消费者消费 ...
- 一起手写吧!promise.all
Promise.all 接收一个 promise 对象的数组作为参数,当这个数组里的所有 promise 对象全部变为resolve或 有 reject 状态出现的时候,它才会去调用 .then 方法 ...
- Vue相关,vue父子组件生命周期执行顺序。
一.实例代码 父组件: <template> <div id="parent"> <child></child> </div& ...
- 【bfs】洛谷 P1443 马的遍历
题目:P1443 马的遍历 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 记录一下第一道ac的bfs,原理是利用队列queue记录下一层的所有点,然后一层一层遍历: 其中: 1.p ...
- 用户创建firefox配置文件
1.打开cmd进放 firefox.exe所在的目录 如:D:\>cd D:\Mozilla Firefox 2.运行如命令:D:\Mozilla Firefox>firefox.exe ...
- shell脚本实现openss自建CA和证书申请
#!/bin/bash # #******************************************************************** #Author: Ma Xue ...
- ES6解构赋值的简单使用
相较于常规的赋值方式,解构赋值最主要的是'解构'两个字,在赋值的过程中要清晰的知道等号右边的结构. 先简单地看一下原来的赋值方式. var a=[1,2] 分析一下这句代码的几个点: (1)变量申明和 ...