TensorFlow.NET机器学习入门【6】采用神经网络处理Fashion-MNIST
"如果一个算法在MNIST上不work,那么它就根本没法用;而如果它在MNIST上work,它在其他数据上也可能不work"。
—— 马克吐温
上一篇文章我们实现了一个MNIST手写数字识别的程序,通过一个简单的两层神经网络,就轻松获得了98%的识别成功率。这个成功率不代表你的网络是有效的,因为MNIST实在是太简单了,我们需要更复杂的数据集来检验网络的有效性!这就有了Fashion-MNIST数据集,它采用10种服装的图片来取代数字0~9,除此之外,其图片大小、数量均和MNIST一致。
上篇文章的代码几乎不用改动,只要改个获取原始图片文件的文件夹名称即可。
程序运行结果识别成功率大约为82%左右。
我们可以对网络进行调整,看能否提高识别率,具体可用的方法:
1、增加网络层
2、增加神经元个数
3、改用其它激活函数
试验结果表明,不管如何调整,识别率始终上不去多少。可见该网络方案已经碰到了瓶颈,如果要大幅度提高识别率必须要采取新的方案了。
下篇文章我们将介绍卷积神经网络(CNN)的应用,通过CNN来处理图像数据将是一个更好、更科学的解决方案。
由于本文代码和上一篇文章的代码高度一致,这里就不再详细说明了。全部代码如下:

/// <summary>
/// 采用神经网络处理Fashion-MNIST数据集
/// </summary>
public class NN_MultipleClassification_Fashion_MNIST
{
private readonly string TrainImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train";
private readonly string TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\test";
private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_data.bin";
private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_label.bin"; private readonly int img_rows = 28;
private readonly int img_cols = 28;
private readonly int num_classes = 10; // total classes public void Run()
{
var model = BuildModel();
model.summary(); model.compile(optimizer: keras.optimizers.Adam(0.001f),
loss: keras.losses.SparseCategoricalCrossentropy(),
metrics: new[] { "accuracy" }); (NDArray train_x, NDArray train_y) = LoadTrainingData();
model.fit(train_x, train_y, batch_size: 1024, epochs: 20); test(model);
} /// <summary>
/// 构建网络模型
/// </summary>
private Model BuildModel()
{
// 网络参数
int n_hidden_1 = 128; // 1st layer number of neurons.
int n_hidden_2 = 128; // 2nd layer number of neurons.
float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer>
{
keras.layers.InputLayer((img_rows,img_cols)),
keras.layers.Flatten(),
keras.layers.Rescaling(scale),
keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
}); return model;
} /// <summary>
/// 加载训练数据
/// </summary>
/// <param name="total_size"></param>
private (NDArray, NDArray) LoadTrainingData()
{
try
{
Console.WriteLine("Load data");
IFormatter serializer = new BinaryFormatter();
FileStream loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read);
float[,,] arrx = serializer.Deserialize(loadFile) as float[,,]; loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read);
int[] arry = serializer.Deserialize(loadFile) as int[];
Console.WriteLine("Load data success");
return (np.array(arrx), np.array(arry));
}
catch (Exception ex)
{
Console.WriteLine($"Load data Exception:{ex.Message}");
return LoadRawData();
}
} private (NDArray, NDArray) LoadRawData()
{
Console.WriteLine("LoadRawData"); int total_size = 60000;
float[,,] arrx = new float[total_size, img_rows, img_cols];
int[] arry = new int[total_size]; int count = 0; DirectoryInfo RootDir = new DirectoryInfo(TrainImagePath);
foreach (var Dir in RootDir.GetDirectories())
{
foreach (var file in Dir.GetFiles("*.png"))
{
Bitmap bmp = (Bitmap)Image.FromFile(file.FullName);
if (bmp.Width != img_cols || bmp.Height != img_rows)
{
continue;
} for (int row = 0; row < img_rows; row++)
for (int col = 0; col < img_cols; col++)
{
var pixel = bmp.GetPixel(col, row);
int val = (pixel.R + pixel.G + pixel.B) / 3; arrx[count, row, col] = val;
arry[count] = int.Parse(Dir.Name);
} count++;
} Console.WriteLine($"Load image data count={count}");
} Console.WriteLine("LoadRawData finished");
//Save Data
Console.WriteLine("Save data");
IFormatter serializer = new BinaryFormatter(); //开始序列化
FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write);
serializer.Serialize(saveFile, arrx);
saveFile.Close(); saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write);
serializer.Serialize(saveFile, arry);
saveFile.Close();
Console.WriteLine("Save data finished"); return (np.array(arrx), np.array(arry));
} /// <summary>
/// 消费模型
/// </summary>
private void test(Model model)
{
Random rand = new Random(1); DirectoryInfo TestDir = new DirectoryInfo(TestImagePath);
foreach (var ChildDir in TestDir.GetDirectories())
{
Console.WriteLine($"Folder:【{ChildDir.Name}】");
var Files = ChildDir.GetFiles("*.png");
for (int i = 0; i < 10; i++)
{
int index = rand.Next(1000);
var image = Files[index]; var x = LoadImage(image.FullName);
var pred_y = model.Apply(x);
var result = argmax(pred_y[0].numpy()); Console.WriteLine($"FileName:{image.Name}\tPred:{result}");
}
}
} private NDArray LoadImage(string filename)
{
float[,,] arrx = new float[1, img_rows, img_cols];
Bitmap bmp = (Bitmap)Image.FromFile(filename); for (int row = 0; row < img_rows; row++)
for (int col = 0; col < img_cols; col++)
{
var pixel = bmp.GetPixel(col, row);
int val = (pixel.R + pixel.G + pixel.B) / 3;
arrx[0, row, col] = val;
} return np.array(arrx);
} private int argmax(NDArray array)
{
var arr = array.reshape(-1); float max = 0;
for (int i = 0; i < 10; i++)
{
if (arr[i] > max)
{
max = arr[i];
}
} for (int i = 0; i < 10; i++)
{
if (arr[i] == max)
{
return i;
}
} return 0;
}
}
【相关资源】
源码:Git: https://gitee.com/seabluescn/tf_not.git
项目名称:NN_MultipleClassification_Fashion_MNIST
TensorFlow.NET机器学习入门【6】采用神经网络处理Fashion-MNIST的更多相关文章
- TensorFlow.NET机器学习入门【3】采用神经网络实现非线性回归
上一篇文章我们介绍的线性模型的求解,但有很多模型是非线性的,比如: 这里表示有两个输入,一个输出. 现在我们已经不能采用y=ax+b的形式去定义一个函数了,我们只能知道输入变量的数量,但不知道某个变量 ...
- TensorFlow.NET机器学习入门【4】采用神经网络处理分类问题
上一篇文章我们介绍了通过神经网络来处理一个非线性回归的问题,这次我们将采用神经网络来处理一个多元分类的问题. 这次我们解决这样一个问题:输入一个人的身高和体重的数据,程序判断出这个人的身材状况,一共三 ...
- TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)
从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作.这次我们要解决机器学习的经典问题,MNIST手写数字识别. 首先介绍一下数据集.请首先解压:TF_Net\Asset\mnist_png. ...
- TensorFlow.NET机器学习入门【7】采用卷积神经网络(CNN)处理Fashion-MNIST
本文将介绍如何采用卷积神经网络(CNN)来处理Fashion-MNIST数据集. 程序流程如下: 1.准备样本数据 2.构建卷积神经网络模型 3.网络学习(训练) 4.消费.测试 除了网络模型的构建, ...
- TensorFlow.NET机器学习入门【8】采用GPU进行学习
随着网络越来约复杂,训练难度越来越大,有条件的可以采用GPU进行学习.本文介绍如何在GPU环境下使用TensorFlow.NET. TensorFlow.NET使用GPU非常的简单,代码不用做任何修改 ...
- TensorFlow.NET机器学习入门【0】前言与目录
曾经学习过一段时间ML.NET的知识,ML.NET是微软提供的一套机器学习框架,相对于其他的一些机器学习框架,ML.NET侧重于消费现有的网络模型,不太好自定义自己的网络模型,底层实现也做了高度封装. ...
- TensorFlow.NET机器学习入门【1】开发环境与类型简介
项目开发环境为Visual Studio 2019 + .Net 5 创建新项目后首先通过Nuget引入相关包: SciSharp.TensorFlow.Redist是Google提供的TensorF ...
- TensorFlow.NET机器学习入门【2】线性回归
回归分析用于分析输入变量和输出变量之间的一种关系,其中线性回归是最简单的一种. 设: Y=wX+b,现已知一组X(输入)和Y(输出)的值,要求出w和b的值. 举个例子:快年底了,销售部门要发年终奖了, ...
- 45、Docker 加 tensorflow的机器学习入门初步
[1]最近领导天天在群里发一些机器学习的链接,搞得好像我们真的要搞机器学习似的,吃瓜群众感觉好神奇呀. 第一步 其实也是最后一步,就是网上百度一下,Docker Toolbox,下载下来,下载,安装之 ...
随机推荐
- dart系列之:HTML的专属领域,除了javascript之外,dart也可以
目录 简介 DOM操作 CSS操作 处理事件 总结 简介 虽然dart可以同时用作客户端和服务器端,但是基本上dart还是用做flutter开发的基本语言而使用的.除了andorid和ios之外,we ...
- 数组的高阶方法map filter reduce的使用
数组中常用的高阶方法: foreach map filter reduce some every 在这些方法中都是对数组中每一个元素进行遍历操作,只有foreach是没有 ...
- linux 常用清空文件方法
1.vim 编辑器 vim /tmp/file :1,$d 或 :%d 2.cat 命令 cat /dev/null > /tmp/file
- vim使用配置(转)
在终端下使用vim进行编辑时,默认情况下,编辑的界面上是没有行号的.语法高亮度显示.智能缩进等功能的. 为了更好的在vim下进行工作,需要手动配置一个配置文件: .vimrc 在启动vim时,当前用户 ...
- 3.3 GO字符串处理
strings方法 index 判断子字符串或字符在父字符串中出现的位置(索引)Index 返回字符串 str 在字符串 s 中的索引( str 的第一个字符的索引),-1 表示字符串 s 不包含字符 ...
- 【科研工具】CAJViewer的一些操作
逐渐发现CAJViewer没有想象中的难用. 添加书签:Ctrl+M 使用按类分类,可以筛选出书签位置,和注释区分. 搜索:Ctrl+F 可以定义多种搜索.
- Web系统与自控系统数据通讯架构 之 OPC DA DataChangeEventHandler 非热点数据更新策略 ,
在使用OPC 采集 工控数据时,在DA模式下.采集数据通常用到 DataChangeEventHandler这个事件.但有时会遇到一些问题,就是当数据不变化时时不会触发 DataChange 这个事件 ...
- 赋能开发:捷码携手达内教育打造IT职业教育新生态
近日,达内教育与远眺科技签约联合培养的第一批低代码开发方向的高职学生,在杭州未来科技城捷码总部顺利毕业,首期合格学员总数超过30名.随着这些接受了"捷码"低代码平台全程" ...
- $(document).ready()与window.onload的区别,站在三个维度回答问题
1.执行时机 window.onload必须等到页面内包括图片的所有元素加载完毕后才能执行. $(document).ready()是DOM结构绘制完毕后就执行,不必等到加载完毕. 2 ...
- AD小白如何发板厂制板--导出gerber文件和钻孔文件+嘉立创下单教程
AD如何发工程制板子? 方式1,发PCB源文件给板厂 方式2,发一些工艺文件给板厂,这样就无须泄漏你的PCB源文件了,一个硬件工程师必须要掌握方式2. 方式2要做的就是导出gerber文件和钻孔文件, ...