用 Tensorflow.js 做了一个动漫分类的功能(二)
前言:
前面已经通过采集拿到了图片,并且也手动对图片做了标注。接下来就要通过 Tensorflow.js 基于 mobileNet 训练模型,最后就可以实现在采集中对图片进行自动分类了。
这种功能在应用场景里就比较多了,比如图标素材站点,用户通过上传一个图标,系统会自动匹配出相似的图标,还有二手平台,用户通过上传闲置物品图片,平台自动给出分类等,这些也都是前期对海量图片进行了标注训练而得到一个损失率极低的模型。下面就通过简答的代码实现一个小的动漫分类。

环境:
Node
Http-Server
Parcel
Tensorflow

编码:
1. 训练模型
1.1. 创建项目,安装依赖包
npm install @tensorflow/tfjs --legacy-peer-deps
npm install @tensorflow/tfjs-node-gpu --legacy-peer-deps
1.2. 全局安装 Http-Server
npm install i http-server
1.3. 下载 mobileNet 模型文件 (网上有下载)
1.4. 根目录下启动 Http 服务 (开启跨域),用于 mobileNet 和训练结果的模型可访问
http-server --cors -p 8080

1.5. 创建训练执行脚本 run.js
const tf = require('@tensorflow/tfjs-node-gpu');
const getData = require('./data');
const TRAIN_PATH = './动漫分类/train';
const OUT_PUT = 'output';
const MOBILENET_URL = 'http://127.0.0.1:8080/data/mobilenet/web_model/model.json';
(async () => {
const { ds, classes } = await getData(TRAIN_PATH, OUT_PUT);
console.log(ds, classes);
//引入别人训练好的模型
const mobilenet = await tf.loadLayersModel(MOBILENET_URL);
//查看模型结构
mobilenet.summary();
const model = tf.sequential();
//截断模型,复用了86个层
for (let i = 0; i < 86; ++i) {
const layer = mobilenet.layers[i];
layer.trainable = false;
model.add(layer);
}
//降维,摊平数据
model.add(tf.layers.flatten());
//设置全连接层
model.add(tf.layers.dense({
units: 10,
activation: 'relu'//设置激活函数,用于处理非线性问题
}));
model.add(tf.layers.dense({
units: classes.length,
activation: 'softmax'//用于多分类问题
}));
//设置损失函数,优化器
model.compile({
loss: 'sparseCategoricalCrossentropy',
optimizer: tf.train.adam(),
metrics:['acc']
});
//训练模型
await model.fitDataset(ds, { epochs: 20 });
//保存模型
await model.save(`file://${process.cwd()}/${OUT_PUT}`);
})();
1.6. 创建图片与 Tensor 转换库 data.js
const fs = require('fs');
const tf = require("@tensorflow/tfjs-node-gpu");
const img2x = (imgPath) => {
const buffer = fs.readFileSync(imgPath);
//清除数据
return tf.tidy(() => {
//把图片转成tensor
const imgt = tf.node.decodeImage(new Uint8Array(buffer), 3);
//调整图片大小
const imgResize = tf.image.resizeBilinear(imgt, [224, 224]);
//归一化
return imgResize.toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]);
});
}
const getData = async (traindir, output) => {
let classes = fs.readdirSync(traindir, 'utf-8');
fs.writeFileSync(`./${output}/classes.json`, JSON.stringify(classes));
const data = [];
classes.forEach((dir, dirIndex) => {
fs.readdirSync(`${traindir}/${dir}`)
.filter(n => n.match(/jpg$/))
.slice(0, 1000)
.forEach(filename => {
const imgPath = `${traindir}/${dir}/${filename}`;
data.push({ imgPath, dirIndex });
});
});
console.log(data);
//打乱训练顺序,提高准确度
tf.util.shuffle(data);
const ds = tf.data.generator(function* () {
const count = data.length;
const batchSize = 32;
for (let start = 0; start < count; start += batchSize) {
const end = Math.min(start + batchSize, count);
console.log('当前批次', start);
yield tf.tidy(() => {
const inputs = [];
const labels = [];
for (let j = start; j < end; ++j) {
const { imgPath, dirIndex } = data[j];
const x = img2x(imgPath);
inputs.push(x);
labels.push(dirIndex);
}
const xs = tf.concat(inputs);
const ys = tf.tensor(labels);
return { xs, ys };
});
}
});
return { ds, classes };
}
module.exports = getData;
1.7. 运行执行文件
node run.js

2. 调用模型
2.1. 全局安装 parcel
npm install i parcel
2.2. 创建页面 index.html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
<br>
2.3. 创建模型调用预测脚本 script.js
import * as tf from '@tensorflow/tfjs';
import { img2x, file2img } from './utils';
const MODEL_PATH = 'http://127.0.0.1:8080/t7';
const CLASSES = ["假面骑士","奥特曼","海贼王","火影忍者","龙珠"];
window.onload = async () => {
const model = await tf.loadLayersModel(MODEL_PATH + '/output/model.json');
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
return model.predict(x);
});
const index = pred.argMax(1).dataSync()[0];
console.log(pred.argMax(1).dataSync());
let predictStr = "";
if (typeof CLASSES[index] == 'undefined') {
predictStr = BRAND_CLASSES[index];
} else {
predictStr = CLASSES[index];
}
setTimeout(() => {
alert(`预测结果:${predictStr}`);
}, 0);
};
};
2.4. 创建图片 tensor 格式转换库 utils.js
import * as tf from '@tensorflow/tfjs';
export function img2x(imgEl){
return tf.tidy(() => {
const input = tf.browser.fromPixels(imgEl)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return input;
});
}
export function file2img(f) {
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
2.5. 打包项目并运行
parcel index.html

2.6. 运行效果



注意:
1. 模型训练过程报错
Input to reshape is a tensor with 50176 values, but the requested shape has 150528
1.1. 原因
张量 reshape 不对,实际输入元素个数与所需矩阵元素个数不一致,就是采集过来的图片有多种图片格式,而不同格式的通道不同 (jpg3 通道,png4 通道,灰色图片 1 通道),在将图片转换 tensor 时与代码里的张量形状不匹配。
1.2. 解决方法
一种方法是删除灰色或 png 图片,其二是修改代码 tf.node.decodeImage (new Uint8Array (buffer), 3)


用 Tensorflow.js 做了一个动漫分类的功能(二)的更多相关文章
- TensorFlow.js之根据数据拟合曲线
这篇文章中,我们将使用TensorFlow.js来根据数据拟合曲线.即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数.简单的说,就是给一些在二维 ...
- 关于最近在做的一个js全屏轮播插件
最近去面试了,对方要求我在一个星期内用原生的js代码写一个全屏轮播的插件,第一想法就是跟照片轮播很相似,只是照片轮播是有定义一个宽高度大小已经确定了的容器用来存储所有照片,然后将照片全部左浮动,利用m ...
- 做了一个图片等比缩放的js
做了一个图片等比缩放的js 芋头 发布在view:8447 今天改了一下博客的主题,发现博客主题在ie6下变样了,后来发现是因为某篇文章里的某个图片太大了撑开了容器,导致样式错位,前几天公司需求里 ...
- 4-13 Webpacker-React.js; 用React做一个下拉表格的功能: <详解>
Rails5.1增加了Webpacker: Webpacker essentially is the decisions made by the Rails team and bundled up i ...
- 用 JS 做一个数独游戏(二)
用 JS 做一个数独游戏(二) 在 上一篇博客 中,我们通过 Node 运行了我们的 JavaScript 代码,在控制台中打印出来生成好的数独终盘.为了让我们的数独游戏能有良好的体验,这篇博客将会为 ...
- 用 JS 做一个数独游戏(一)
用 JS 做一个数独游戏(一) 数独的棋盘由 9x9 的方格组成,每一行的数字包含 1 ~ 9 九个数字,并且每一列包含 1 ~ 9 这 9 个不重复的数字,另外,整个棋盘分为 9 个 3x3 的块, ...
- 用js给闺女做了一个加减乘除的html
下班回家用二十分钟给闺女做了一个加减乘除的页面,顺便记录下代码,时间仓促,后期再来修改吧 目录结构 -yq --menu.html --yq.html --yq50.html --yq70.html ...
- 用js,css3 做的一个球
用css3属性很容易做一个立方体,但是要做一个球体,会相对复杂些 原理是:球可以看做是由无数个圆圈构成,然后就可以用圆圈来做球, 下面的例子是我做的一个小球,由72个圆圈组成.如果把每个圆圈的背景颜色 ...
- JS 做时钟
今天,给大家分享一个用JS做的时钟. <!DOCTYPE html><html> <head> <meta charset="utf-8" ...
- JS一般般的网页重构可以使用Node.js做些什么(转)
一.非计算机背景前端如何快速了解Node.js? 做前端的应该都听过Node.js,偏开发背景的童鞋应该都玩过. 对于一些没有计算机背景的,工作内容以静态页面呈现为主的前端,可能并未把玩过Node.j ...
随机推荐
- 【python爬虫】对于微博用户发表文章内容和评论的爬取
此博客仅作为交流学习 对于喜爱的微博用户文章内容进行爬取 (此部分在于app页面进行爬取,比较方便) 分析页面 在这里进行json方法进行,点击Network进行抓包 发现数据加载是由这个页面发出的, ...
- 2022-04-25:给定两个长度为N的数组,a[]和b[] 也就是对于每个位置i来说,有a[i]和b[i]两个属性 i a[i] b[i] j a[j] b[j] 现在想为了i,选一个最
2022-04-25:给定两个长度为N的数组,a[]和b[] 也就是对于每个位置i来说,有a[i]和b[i]两个属性 i a[i] b[i] j a[j] b[j] 现在想为了i,选一个最好的j位置, ...
- rust语言写的贪吃蛇游戏
首先新建工程,然后用vscode打开,命令如下: cargo new snake --bin 文件结构如下: Cargo.Toml文件内容如下: [package] name = "snak ...
- 2022-09-01:字符串的 波动 定义为子字符串中出现次数 最多 的字符次数与出现次数 最少 的字符次数之差。 给你一个字符串 s ,它只包含小写英文字母。请你返回 s 里所有 子字符串的 最大波
2022-09-01:字符串的 波动 定义为子字符串中出现次数 最多 的字符次数与出现次数 最少 的字符次数之差. 给你一个字符串 s ,它只包含小写英文字母.请你返回 s 里所有 子字符串的 最大波 ...
- 2022-08-26:用一个大小为 m x n 的二维网格 grid 表示一个箱子 你有 n 颗球。箱子的顶部和底部都是开着的。 箱子中的每个单元格都有一个对角线挡板,跨过单元格的两个角, 可以将球导
2022-08-26:用一个大小为 m x n 的二维网格 grid 表示一个箱子 你有 n 颗球.箱子的顶部和底部都是开着的. 箱子中的每个单元格都有一个对角线挡板,跨过单元格的两个角, 可以将球导 ...
- 2022-06-12:在N*N的正方形棋盘中,有N*N个棋子,那么每个格子正好可以拥有一个棋子。 但是现在有些棋子聚集到一个格子上了,比如: 2 0 3 0 1 0 3 0 0 如上的二维数组代表,一
2022-06-12:在NN的正方形棋盘中,有NN个棋子,那么每个格子正好可以拥有一个棋子. 但是现在有些棋子聚集到一个格子上了,比如: 2 0 3 0 1 0 3 0 0 如上的二维数组代表,一共3 ...
- 2021-12-16:给定两个数a和b, 第1轮,把1选择给a或者b, 第2轮,把2选择给a或者b, ... 第i轮,把i选择给a或者b。 想让a和b的值一样大,请问至少需要多少轮? 来自字节跳动。
2021-12-16:给定两个数a和b, 第1轮,把1选择给a或者b, 第2轮,把2选择给a或者b, - 第i轮,把i选择给a或者b. 想让a和b的值一样大,请问至少需要多少轮? 来自字节跳动. 答案 ...
- Ado.Net 数据库访问技术(.Net 6版本)
1. ADO.NET的前世今生 ADO.NET的名称起源于ADO(ActiveX Data Objects),是一个COM组件库,用于在以往的Microsoft技术中访问数据.之所以使用ADO.NET ...
- 一次 HPC 病毒感染与解决经历
周一的时候,有同事反馈说,HPC 的项目报告路径正在不断产生 *.exe 和 *.pif 文件,怀疑是不是被病毒感染! 收到信息,第一时间进去目录,的确发现该目录每个几秒钟就自动生成一个 *.exe ...
- 自然语言处理(NLP)
"自然语言处理(Natural Language Processing, NLP)是计算机科学领域与人工智能领域中的一个重要方向.它研究能实现人与计算机之间用自然语言进行有效通信的各种理论和 ...