Keras Quick Reference: MNIST Training and Prediction on CPU and GPU, Model Export, Model Import and Re-prediction, ONNX Export and Prediction
What this is for
This article is a quick reference to help graduate students and self-styled "alchemist" algorithm engineers get started with Keras quickly. It assumes the reader already understands basic deep learning and dataset concepts.
System environment
python 3.7.4
tensorflow 2.6.0
keras 2.6.0
onnx 1.9.0
onnxruntime-gpu 1.9.0
tf2onnx 1.9.3
Data preparation
The MNIST dataset CSV file is a 42000 x 785 matrix.
42000 is the number of images.
Of the 785 columns, the first is the image label (0, 1, 2, ..., 9) and the remaining 784 are the image data vector (the 28x28 image flattened into a 784-dimensional vector). The dataset looks like this (a sketch for generating such a CSV follows the preview):
1 0 0 0 0 0 0 0 0 0 ..
0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
4 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
7 0 0 0 0 0 0 0 0 0
3 0 0 0 0 0 0 0 0 0
5 0 0 0 0 0 0 0 0 0
3 0 0 0 0 0 0 0 0 0
8 0 0 0 0 0 0 0 0 0
9 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
3 0 0 0 0 0 0 0 0 0
3 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
2 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
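If you do not already have a mnist_train.csv in this layout, the following is a minimal sketch for building one from keras.datasets.mnist. This is an assumption about how the file can be produced, not how the original was made: the standard Keras download yields 60000 training rows, not the 42000 rows shown above.
# Minimal sketch: build a label-first CSV in the same layout from keras.datasets.mnist
import numpy as np
from keras.datasets import mnist

(x, y), _ = mnist.load_data()                                  # x: (60000, 28, 28), y: (60000,)
mat = np.hstack([y.reshape(-1, 1), x.reshape(len(x), -1)])     # first column = label, then 784 pixel columns
np.savetxt("mnist_train.csv", mat, fmt="%d", delimiter=",")    # no header, as expected by pd.read_csv(header=None)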
1. Import the required packages
import os
import onnx
import keras
import logging
import subprocess
import numpy as np
import pandas as pd
import tensorflow as tf
import onnxruntime as ort
from sklearn.metrics import accuracy_score
from keras.models import Sequential, Model, load_model, save_model
from keras.layers import Dense, Activation, Dropout, Conv2D, Flatten, MaxPool2D, Input, Conv1D
from keras.utils.np_utils import to_categorical
tf.autograph.set_verbosity(0)
logging.getLogger("tensorflow").setLevel(logging.ERROR)
2. Parameter setup
N_EPOCH = 1
N_BATCH = 64
N_BATCH_NUM = 500
S_DATA_PATH = r"mnist_train.csv"
S_KERAS_MODEL_DIR_PATH = r"cnn_keras"
S_KERAS_MODEL_PATH = r"cnn_keras.h5"
S_ONNX_MODEL_PATH = r"cnn_keras.onnx"
S_DEVICE, N_DEVICE_ID, S_DEVICE_FULL = "cuda", 0, "cuda:0" # use the GPU
# S_DEVICE, N_DEVICE_ID, S_DEVICE_FULL = "cpu", 0, "cpu" # uncomment this line to run on the CPU if no GPU is available
if S_DEVICE == "cpu":
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
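Before training it can help to confirm that TensorFlow actually sees a GPU; a minimal sketch using the tf.config API, run after the imports and parameters above:
# Sanity check: list the GPUs TensorFlow can see for the chosen device setting
gpus = tf.config.list_physical_devices("GPU")
print("visible GPUs:", gpus)
if S_DEVICE == "cuda" and not gpus:
    print("warning: S_DEVICE is cuda but no GPU is visible; execution will fall back to CPU")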
3. Load the data
df = pd.read_csv(S_DATA_PATH, header=None)
np_mat = np.array(df)
print(df.shape)
print(np_mat.shape)
X = np_mat[:, 1:]
Y = np_mat[:, 0]
X = X.astype(np.float32) / 255
X_train = X[:N_BATCH * N_BATCH_NUM]
X_test = X[N_BATCH * N_BATCH_NUM:]
Y_train = Y[:N_BATCH * N_BATCH_NUM]
Y_test = Y[N_BATCH * N_BATCH_NUM:]
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
Y_train = to_categorical(Y_train, num_classes=10)
Y_test = to_categorical(Y_test, num_classes=10)
print(X_train.shape)
print(Y_train.shape)
print(X_test.shape)
print(Y_test.shape)
Output
(42000, 785)
(42000, 785)
(32000, 28, 28, 1)
(32000, 10)
(10000, 28, 28, 1)
(10000, 10)
4. Model construction
x_in = Input(shape=(28, 28, 1)) # image shape must be (height, width, channels)
x = Conv2D(filters=32, kernel_size=(3, 3))(x_in)
x = MaxPool2D(pool_size=(2, 2))(x)
x = Dropout(0.2)(x)
x = Flatten()(x)
x = Dense(128)(x)
x = Activation('relu')(x)
x = Dense(10)(x)
y = Activation('softmax')(x)
model = Model(x_in, y)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
print(model.summary())
Output
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 26, 26, 32) 320
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0
_________________________________________________________________
dropout (Dropout) (None, 13, 13, 32) 0
_________________________________________________________________
flatten (Flatten) (None, 5408) 0
_________________________________________________________________
dense (Dense) (None, 128) 692352
_________________________________________________________________
activation (Activation) (None, 128) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 1290
_________________________________________________________________
activation_1 (Activation) (None, 10) 0
=================================================================
Total params: 693,962
Trainable params: 693,962
Non-trainable params: 0
_________________________________________________________________
None
5. Train and save the model
model.fit(X_train,
Y_train,
epochs=N_EPOCH,
batch_size=N_BATCH,
verbose=1,
validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])
save_model(model, S_KERAS_MODEL_PATH)
Output
486/500 [============================>.] - ETA: 0s - loss: 0.2873 - accuracy: 0.9144
500/500 [==============================] - 4s 3ms/step - loss: 0.2837 - accuracy: 0.9155 - val_loss: 0.1352 - val_accuracy: 0.9616
Test score: 0.13516278564929962
Test accuracy: 0.9616000056266785
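S_KERAS_MODEL_DIR_PATH is defined in step 2 but never used above; if you also want TensorFlow's SavedModel directory format instead of only the single .h5 file, a minimal sketch:
# Save the same model in SavedModel (directory) format; load_model also accepts the directory path
model.save(S_KERAS_MODEL_DIR_PATH)
# model_from_dir = load_model(S_KERAS_MODEL_DIR_PATH)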
6. Load the model and use the loaded model
loaded_model = load_model(S_KERAS_MODEL_PATH)  # use a new name so the imported load_model function is not shadowed
print("load model ok")
score = loaded_model.evaluate(X_test, Y_test, verbose=0)
print('load model Test score:', score[0])
print('load model Test accuracy:', score[1])
Output
load model ok
load model Test score: 0.13516278564929962
load model Test accuracy: 0.9616000056266785
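accuracy_score is imported in step 1 but never used above; a short sketch that cross-checks the loaded model with predict plus sklearn instead of evaluate:
# Predict class probabilities, take the argmax, and compare with the integer labels
y_prob = loaded_model.predict(X_test, batch_size=N_BATCH)
y_pred = np.argmax(y_prob, axis=1)
y_true = np.argmax(Y_test, axis=1)   # Y_test is one-hot, convert back to class ids
print("sklearn accuracy:", accuracy_score(y_true, y_pred))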
7. Export to ONNX
s_cmd = 'python -m tf2onnx.convert --keras %s --output %s' % (S_KERAS_MODEL_PATH, S_ONNX_MODEL_PATH)
print(s_cmd)
print(os.system(s_cmd))
# proc = subprocess.run(s_cmd.split(), check=True)
# print(proc.returncode)
Output
python -m tf2onnx.convert --keras G:\Data\task_model_out\_tmp_out\cnn_keras.h5 --output G:\Data\task_model_out\_tmp_out\cnn_keras.onnx
0
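Shelling out with os.system works, but tf2onnx 1.9 also exposes a Python API; a sketch using tf2onnx.convert.from_keras, where the input signature below is an assumption matching this model's (28, 28, 1) input:
# Alternative: convert in-process with the tf2onnx Python API instead of calling the CLI
import tf2onnx

spec = (tf.TensorSpec((None, 28, 28, 1), tf.float32, name="input_1"),)
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec,
                                                 output_path=S_ONNX_MODEL_PATH)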
8. Load the ONNX model and run inference
model = onnx.load(S_ONNX_MODEL_PATH)
print(onnx.checker.check_model(model)) # Check that the model is well formed
print(onnx.helper.printable_graph(model.graph)) # Print a human readable representation of the graph
ls_input_name, ls_output_name = [input.name for input in model.graph.input], [output.name for output in model.graph.output]
print("input name ", ls_input_name)
print("output name ", ls_output_name)
s_input_name = ls_input_name[0]
x_input = X_train[:N_BATCH*2, :, :, :].astype(np.float32)
ort_val = ort.OrtValue.ortvalue_from_numpy(x_input, S_DEVICE, N_DEVICE_ID)
print("val device ", ort_val.device_name())
print("val shape ", ort_val.shape())
print("val data type ", ort_val.data_type())
print("is_tensor ", ort_val.is_tensor())
print("array_equal ", np.array_equal(ort_val.numpy(), x_input))
providers = 'CUDAExecutionProvider' if S_DEVICE == "cuda" else 'CPUExecutionProvider'
print("providers ", providers)
ort_session = ort.InferenceSession(S_ONNX_MODEL_PATH, providers=[providers]) # create the session on the selected provider (GPU or CPU)
ort_session.set_providers([providers])
outputs = ort_session.run(None, {s_input_name: ort_val})
print("sess env ", ort_session.get_providers())
print(type(outputs))
print(outputs[0])
Output
None
graph tf2onnx (
%input_1:0[FLOAT, unk__17x28x28x1]
) initializers (
%new_shape__15[INT64, 4]
%model/dense_1/MatMul/ReadVariableOp:0[FLOAT, 128x10]
%model/dense_1/BiasAdd/ReadVariableOp:0[FLOAT, 10]
%model/dense/MatMul/ReadVariableOp:0[FLOAT, 5408x128]
%model/dense/BiasAdd/ReadVariableOp:0[FLOAT, 128]
%model/conv2d/Conv2D/ReadVariableOp:0[FLOAT, 32x1x3x3]
%model/conv2d/BiasAdd/ReadVariableOp:0[FLOAT, 32]
%const_fold_opt__16[INT64, 2]
) {
%model/conv2d/BiasAdd__6:0 = Reshape(%input_1:0, %new_shape__15)
%model/conv2d/BiasAdd:0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], strides = [1, 1]](%model/conv2d/BiasAdd__6:0, %model/conv2d/Conv2D/ReadVariableOp:0, %model/conv2d/BiasAdd/ReadVariableOp:0)
%model/max_pooling2d/MaxPool:0 = MaxPool[kernel_shape = [2, 2], strides = [2, 2]](%model/conv2d/BiasAdd:0)
%model/max_pooling2d/MaxPool__12:0 = Transpose[perm = [0, 2, 3, 1]](%model/max_pooling2d/MaxPool:0)
%model/flatten/Reshape:0 = Reshape(%model/max_pooling2d/MaxPool__12:0, %const_fold_opt__16)
%model/dense/MatMul:0 = MatMul(%model/flatten/Reshape:0, %model/dense/MatMul/ReadVariableOp:0)
%model/dense/BiasAdd:0 = Add(%model/dense/MatMul:0, %model/dense/BiasAdd/ReadVariableOp:0)
%model/activation/Relu:0 = Relu(%model/dense/BiasAdd:0)
%model/dense_1/MatMul:0 = MatMul(%model/activation/Relu:0, %model/dense_1/MatMul/ReadVariableOp:0)
%model/dense_1/BiasAdd:0 = Add(%model/dense_1/MatMul:0, %model/dense_1/BiasAdd/ReadVariableOp:0)
%Identity:0 = Softmax[axis = 1](%model/dense_1/BiasAdd:0)
return %Identity:0
}
input name ['input_1:0']
output name ['Identity:0']
val device cuda
val shape [128, 28, 28, 1]
val data type tensor(float)
is_tensor True
array_equal True
providers CUDAExecutionProvider
sess env ['CUDAExecutionProvider', 'CPUExecutionProvider']
<class 'list'>
[[1.0287621e-04 9.9524093e-01 5.0408958e-04 ... 6.5664819e-05
3.8182980e-03 1.2303158e-05]
[9.9932754e-01 2.7173186e-08 3.5315077e-04 ... 3.0959238e-06
8.5986117e-05 3.6047477e-06]
[1.1101285e-05 9.9719965e-01 3.8205151e-04 ... 1.2267688e-03
7.8595197e-04 4.0839368e-05]
...
[2.8337089e-02 1.5399084e-05 2.1733245e-01 ... 1.5945830e-05
2.1134425e-02 1.7111158e-03]
[1.7888090e-06 3.3868539e-06 5.2631256e-04 ... 9.9888057e-01
5.4794059e-06 5.5255485e-04]
[4.1398227e-05 1.0462944e-06 5.5901739e-03 ... 3.1221823e-09
6.6847453e-04 7.8918066e-07]]
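To sanity-check the exported graph, the ONNX outputs can be compared against the Keras predictions on the same batch; a minimal sketch, where loaded_model is the Keras model loaded in step 6:
# The ONNX probabilities should match the Keras predictions up to floating-point tolerance
keras_out = loaded_model.predict(x_input, batch_size=N_BATCH)
print("max abs diff:", np.abs(outputs[0] - keras_out).max())
print("argmax agreement:", np.mean(np.argmax(outputs[0], axis=1) == np.argmax(keras_out, axis=1)))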