JS做深度学习2——导入训练模型

改进项目

前段时间,我做了个RNN预测金融数据的毕业设计(华尔街),当时TensorFlow.js还没有发布,我不得已使用了keras对数据进行了训练,并且拟合好了不同期货的模型,因为当时毕设的网站是用node.js写的,为了可以在网站中预测,我采取的方案是:用python进行训练和预测,然后使用node.js运行python命令,最终在浏览器上可视化出来,这也算的上是黑科技了!

不过这样通过一个解释器调用另一个解释器,语言之间互相通信其实不是什么好的设计方式,且不说维护两门语言的困难,调用其他语言过程会产生的错误和性能问题较多,而且显得整个项目很混乱,强迫症受不了。

如今,Google开始官推TensorFlow的JS API又是前端福音。但根据官网的介绍,TensorFlow.js目前尚不成熟,JS方面尚未实现像Python那么丰富的学习API。所以各种基于TF的深度学习项目如果需要使用JS重构也需要慢慢过渡。

更多关于TensorFlow.js的目前支持状况请参阅:https://www.linpx.com/p/you-want-to-know-that-everything-about-tensorflowjs-is-here.html

如上所述,TensorFlow.js尚不能导出训练文件,但可以导入训练文件,今天根据官网提供的文档,将我毕设项目的预测功能重构为纯JS实现。方法是通过导入python训练好的文件,依靠TensorFlow.js调用进行预测。

官方文档

本文是对TensorFlow官网文档的学习和实用,记录了笔者用tfjs导入模型进行预测的过程,参考资料:https://js.tensorflow.org/tutorials/import-keras.html

接下来就开始翻译~~哦不,是上手coding。

文件格式

如tfjs官方所述,我们通常在keras中训练后导出的是H5格式的文件,tfjs不能直接理解h5文件,故需要先将h5转换格式。

准备插件

安装tensorflowjs:

pip install tensorflowjs

可以看到tensorflowjs版本还很年幼,看来发布不久。

手动转换测试

注意下面这些操作依然是基于python做的,我们先尝试手动转换文件格式。

pip安装完tensorflowjs后,进入cmd尝试转换H5文件

tensorflowjs_converter --input_format keras D:\pro\WallStreet\tf_modules\models\20.h5 D:\pro\WallStreet\tf_modules\models\20

对上面的命令进行一下解释:

input_format这个option后面跟着的是原始文件格式来源(keras),然后紧跟着h5文件的地址,然后是转化后保存的目标目录。

这里注意一下,h5最终会被转化为多个文件,所以目标是个目录而不是文件,目录里有:

其中model.json是js中需要调用的文件,另外两个是训练后的二进制文件。

The target TensorFlow.js Layers format is a directory containing a model.json file and a set of sharded weight files in binary format. The model.json file contains both the model topology (aka "architecture" or "graph": a description of the layers and how they are connected) and a manifest of the weight files.

这是Google的原文,翻译过来是tensorflowjs最终格式是个目录 ,包含了一个model.json,还有一些碎片的权重(神经网络中的名词:训练过程优化所得的w值)文件(二进制格式)。json文件则记录神经网络的拓扑结构,一种对神经网络不同层,不同神经元之间连接的状态的记录。大意就是保存了训练后的模型结构!

这段话比较玄学,以我多年对计算机各种理论的融汇贯通后的理解是,神经网络成型后的模型(也就是训练过后的文件),由两部分组成,一个是神经元本身的内部权重(一些数据),还有事神经元之间连接的桥梁(一些结构),综合起来还是——“数据结构”,这个数据结构类似图,有节点和连接,节点还是一些多维的权值。

由此可见,一个真正的程序员应当好好学习基本的专业课包括数学知识,而不是停留在语言层面,否则涉及到高深的技术时,终将也会茫然。

模型保存时转换

废话多了,在实际项目中,我们不可能手动转换,考虑直接在keras训练后生成tensorflowjs文件,这应是自动化的过程

如下:

import tensorflowjs as tfjs

...

# 创建RNN,并训练
model = Sequential()
model.add(LSTM(128, input_shape=(1, window)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(trainX, trainY, epochs=100, batch_size=100, verbose=2)
#保存训练模型
#model.save('tf_modules/models/%s.h5'%g_id)
tfjs.converters.save_keras_model(model,'tf_modules/models/%s'%g_id)

这只是我改进项目中的python训练代码的片段,可以看到原来的model.save是保存为h5,被我注释后改成了保存为tensorflowjs文件集,使用的是tensorflowjs下的 converters.save_keras_model 方法。

导入模型

万事俱备,只欠东风,可以通过JavaScript导入模型预测了。

当到达这一步时,以为要大功告成了,too young!

你可能会说很简单啊,tfjs一定提供了读取模型的方法,没错,确实提供了,不过很可惜不支持node.js,笔者在写篇博客期间,不停翻阅各种国外文档,耗费了整整一下午,官方给出的例子太过简单了:

import * as tf from '@tensorflow/tfjs';
const model = await tf.loadModel('https://foo.bar/tfjs_artifacts/model.json'); const example = tf.fromPixels(webcamElement); // for example
const prediction = model.predict(example);

就这么短短几行,一个例子,意思是:“你用loadmodel+predict方法就行了!”

这对于tensorflow新手来说无疑又是巨大灾难。

最关键的是,笔者注意到loadModel()方法传入的是一个http地址,而研究了一下午,在尝试了node.js下各种文件读取,http访问后发现原来这货也只支持浏览器端的实现,一句话概括就是目前的tfjs导入模型不支持node.js!

以下就是结论的由来,如果有成功使用Node.js读取model的请留言帮助我,感激不尽。

那怎么办呢,硬着头皮在浏览器下运行咯。具体如下:

第一步:保存模型到静态目录下

为了browser端可以访问到model.json以及另外两个权重的二进制文件,在python代码训练完成后保存模型到静态可访问目录下,这对于node.js来说十分重要,因为node.js通过express中间件可以给定一个静态资源目录(就是存放css,js,img的目录)

例如:

//配置静态文件为assets目录
app.use(express.static(__dirname + "/assets"))

这是指定的静态文件的根目录,训练后的模型也放到这个目录中,浏览器才能访问到。

我的方法是在这个目录下新建一个models目录专门存放model,并且不同商品训练出来的模型都是一个独立目录(前面说了,tensorflowjs转化后形成的是一个目录),目录名字是商品在数据库的主键ID。

这样保存模型的代码就变成这样了

tfjs.converters.save_keras_model(model,'assets/tf_models/%s'%g_id

第二步:浏览器script引入tfjs

script(src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.4",type="text/javascript")

以上是笔者使用了jade引擎,可以自行转化为html的script标签

第三步:预测

以下的JavaScript代码都是在浏览器中运行的:

async function get_predict_data(g_id){
//预测
let model = await tf.loadModel(`/tf_models/${g_id}/model.json`)
let response = await fetch(`/futures/data/${g_id}/30`);
let data = (await response.json()).data;
let max_price = [];
for(let i in data){
max_price.push(data[i].max_price);
}
scaler = new min_max_scale(max_price);
max_price = scaler.fit(max_price);
let tf_price = tf.tensor3d(max_price,[1,1,30]);
let prediction = model.predict(tf_price);
let a = (await prediction.data())[0];
a = scaler.inverse(a);
} //数据归一化
function min_max_scale(data){
this.max = Math.max(...data);
this.min = Math.min(...data);
this.fit = function(){
for(i in data){
data[i] = (data[i]-this.min)/(this.max-this.min);
}
return data;
}
this.inverse = (to_inverse) => to_inverse*(this.max-this.min)+this.min
}

最后这个a即为浏览器下跑模型预测的值。

别看代码很简短,大致用了两个方法,一个是预测方法,另一个归一化数据方法,这里面坑坑可多了。

数据标准化,这里使用的是归一化,在python下使用的是sklearn的MinMaxScaler对象,而npm库中翻遍了也找不到相似的功能模块,只好我自己实现了。

什么是归一化,算法是什么?请参考:https://blog.csdn.net/Jiaach/article/details/79484990

可以看到,get_predict_data(预测函数)中除了官方给出的loadModel和predict函数外还有许多陌生的函数(均来自于tensorflow),这些函数的API对于新手来说十分陌生,官网给的也不是很明确,而且还是英文的,今天先不介绍,这些函数留给今后《JS做深度学习博客系列》一个个介绍。

tensorflow.js导入模型到此告一段落。

完整代码参考我的github项目:https://github.com/devilyouwei/WallStreet

JS做深度学习2——导入训练模型的更多相关文章

  1. JS做深度学习1——偶然发现与入门

    JS做深度学习1--偶然发现与入门 不久前,我初次涉猎了Node.js,并且使用它开发了毕业设计的WEB模块,然后通过在Node中调用系统命令执行Python文件方式实现了深度学习功能模块的对接,Py ...

  2. JS做深度学习3——数据结构

    最近在上海上班了,很久没有写博客了,闲下来继续关注和研究Tensorflow.js 关于深度学习的文章我也已经写了不少,部分早期作品可能包含了不少错误的认识,在后面的博文中会改进或重新审视. 今天聊聊 ...

  3. 【腾讯Bugly干货分享】人人都可以做深度学习应用:入门篇

    导语 2016年,继虚拟现实(VR)之后,人工智能(AI)的概念全面进入大众的视野.谷歌,微软,IBM等科技巨头纷纷重点布局,AI 貌似将成为互联网的下一个风口. 很多开发同学,对人工智能非常感兴趣, ...

  4. 腾讯QQ会员技术团队:人人都可以做深度学习应用:入门篇(下)

    四.经典入门demo:识别手写数字(MNIST) 常规的编程入门有"Hello world"程序,而深度学习的入门程序则是MNIST,一个识别28*28像素的图片中的手写数字的程序 ...

  5. 使用亚马逊云服务器EC2做深度学习(四)配置好的系统镜像

    这是<使用亚马逊云服务器EC2做深度学习>系列的第四篇文章. (一)申请竞价实例  (二)配置Jupyter Notebook服务器  (三)配置TensorFlow  (四)配置好的系统 ...

  6. 使用亚马逊云服务器EC2做深度学习(三)配置TensorFlow

    这是<使用亚马逊云服务器EC2做深度学习>系列的第三篇文章. (一)申请竞价实例  (二)配置Jupyter Notebook服务器  (三)配置TensorFlow  (四)配置好的系统 ...

  7. 使用亚马逊云服务器EC2做深度学习(二)配置Jupyter Notebook服务器

    这是<使用亚马逊云服务器EC2做深度学习>系列的第二篇文章. (一)申请竞价实例  (二)配置Jupyter Notebook服务器  (三)配置TensorFlow  (四)配置好的系统 ...

  8. 使用亚马逊云服务器EC2做深度学习(一)申请竞价实例

    这是<使用亚马逊云服务器EC2做深度学习>系列的第一篇文章. (一)申请竞价实例  (二)配置Jupyter Notebook服务器  (三)配置TensorFlow  (四)配置好的系统 ...

  9. 机器学习(Machine Learning)&深度学习(Deep Learning)资料【转】

    转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...

随机推荐

  1. 65)STL中string的知识

    1)代码展示: string是一个类,只不过封装了 char*  而且还封装了  很多的字符串操作函数 2)string类的初始化: string的构造函数 ²  默认构造函数: string();  ...

  2. 03 Mybatis:05.使用Mybatis完成CRUD

    mybatis框架:共四天 明确:我们在实际开发中,都是越简便越好,所以都是采用不写dao实现类的方式.不管使用XML还是注解配置. 第二天:mybatis基本使用 mybatis的单表crud操作 ...

  3. oracle的用户、权限、表空间的管理

    1.创建表空间 create tablespace test1_tablespace datafile 'test1file.dbf' size 10m; 2.创建临时表空间 create tempo ...

  4. ZJNU 1528 - War--高级

    类似于1213取水 可以把空投当作第0个城市 最后将0~n的所有城市跑最小生成树 /* Written By StelaYuri */ #include<iostream> #includ ...

  5. 2019牛客暑期多校训练营(第五场)B.generator 1

    传送门:https://ac.nowcoder.com/acm/contest/885/B 题意:给出,由公式 求出 思路:没学过矩阵快速幂.题解说是矩阵快速幂,之后就学了一遍.(可以先去学一下矩阵快 ...

  6. IMX6开发板虚拟机加载Ubuntu12.04.2镜像

    基于迅为IMX6开发板安装好虚拟机之后,用户就可以加载 Ubuntu12.04.2 镜像.用户可以在网盘中下载“编译好的镜像”,该镜像已经安装好了编译 Android4.4.2 所需要的大部分软件.用 ...

  7. springboot学习笔记:9.springboot+mybatis+通用mapper+多数据源

    本文承接上一篇文章:springboot学习笔记:8. springboot+druid+mysql+mybatis+通用mapper+pagehelper+mybatis-generator+fre ...

  8. react webpack配置

  9. LGOJ1861 星之器

    前置扯淡 我对这个题目的评价和网上各位大佬的一样:人类智慧题 (显然我不具有人类智慧--) Description link 现在有一个 \(n \times m\) 的矩阵\(A\),里面的每个元素 ...

  10. Opencv笔记(十二)——形态学转换

    学习目标: 学习不同的形态学操作,例如腐蚀,膨胀,开运算,闭运算等 我们要学习的函数有: cv2.erode(), cv2.dilate(), cv2.morphologyEx()等 原理简介: 形态 ...