STM32深度学习实战

1. 前言

​ 本文主要记录基于 tensorflow 的简单模型在 stm32 上运行测试的调试记录,开发人员应对深度学习基础理论和 tensorflow 框架基础操作有一定了解,对深度学习在微控制器上的实现评估提供一定的参考方向。

​ 本文实战基于温控主控板硬件及其基础工程进行测试。

​ 总体思想:

  1. 在PC的 tensorflow 虚拟环境中进行模型的装配、训练、保存及测试;
  2. 通过 X-CUBE-AI 拓展软件包从预先训练的神经网络模型生成 STM32 优化库;
  3. 编写测试程序验证神经网络模型。

2. 开发环境准备

  • tensorflow 开发环境已按照《Win10环境安装Anaconda(3-2021.05)+Tensorflow(2.5)》进行部署完成;

  • STM32CubeMX 6.0.1(安装 X-CUBE-AI 拓展软件包);

  • 温控主控板精简工程(其他项目操作类似);

  • 模型开发过程中需要的 python 插件(autopep8、flake8、jupyter、notebook、Keras、matplotlib、numpy、pandas),反正就是缺什么安装什么;

X-CUBE-AI 简介

​ X-CUBE-AI 是 STM32Cube 扩展包的一部分 STM32Cube.AI,通过自动转换预先训练的神经网络并将生成的优化库集成到用户的项目中,扩展 STM32CubeMX 功能。

  • 从预先训练的神经网络模型生成 STM32 优化库;
  • 本地支持各种深度学习框架,如 Keras 和 TensorFlow 精简版,以及支持所有可以导出到 ONNX 标准格式的框架,如 PyTorch、微软认知工具包、MATLAB 等;
  • 支持 Keras 网络和 TensorFlow 的 8 位量化精简量化网络;
  • 允许使用较大的网络,将参数(权重矩阵)存储在外部闪存中,激活缓冲运行在外部 RAM 中;
  • 通过 STM32Cube 集成,可轻松跨不同 STM32 微控制器系列;
  • 使用 TensorFlow 精简神经网络,使用 STM32Cube.AI 运行环境 或 用于微控制器的 TensorFlow Lite 原生运行环境;
  • 免费、用户友好的许可证条款

3. 模型创建

本文以最简单的电压等级模型为例。

3.1 电压等级模型

新建模型训练文件 level_check.py 训练 epochs 20000次

# level_check.py
'''
电源等级检测测试
训练模型阈值
一级 -> v>=8.0
二级 -> 7.8<=v<8.0
三级 -> v<7.8 输入层 -> 隐藏层 -> 输出层
''' # 导入工具包
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import (Sequential, datasets, layers, losses, metrics, optimizers) os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
print(tf.__version__) # %% 读取数据
data = pd.read_csv('data/voltage.csv', sep=',', header=None)
voltage = data.iloc[:, 0] # 取第一列所有数据
level = data.iloc[:, 1:]
level.astype(int) # print(voltage)
# print(level) # %% 建立模型
'''
model = Sequential([layers.Dense(20, activation='relu'),
layers.Dense(10, activation='relu'),
layers.Dense(3, activation='softmax')])
model.build(input_shape=(4, 1))
model.summary()
''' # '''
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=20, activation='relu', input_shape=(1,)))
model.add(tf.keras.layers.Dense(units=10, activation='relu'))
model.add(tf.keras.layers.Dense(units=3, activation='softmax'))
model.summary()
# ''' # %%
model.compile(optimizer=optimizers.Adam(learning_rate=0.001),
loss=losses.categorical_crossentropy,
metrics=[metrics.mse])
history = model.fit(x=voltage, y=level, epochs=20000) print(model.evaluate(voltage, level)) # 保存模型
model.save('level_check.h5')

运行如下图:

训练数据 voltage.csv

7.61,0,0,1
7.62,0,0,1
7.63,0,0,1
7.64,0,0,1
7.65,0,0,1
7.66,0,0,1
7.75,0,0,1
7.78,0,0,1
7.71,0,0,1
7.72,0,0,1
7.8,0,1,0
7.83,0,1,0
7.92,0,1,0
7.85,0,1,0
7.81,0,1,0
7.81,0,1,0
7.84,0,1,0
7.89,0,1,0
7.98,0,1,0
7.88,0,1,0
8.02,1,0,0
8.12,1,0,0
8.05,1,0,0
8.15,1,0,0
8.11,1,0,0
8.01,1,0,0
8.22,1,0,0
8.12,1,0,0
8.14,1,0,0
8.07,1,0,0

​ 模型的载入与测试,并且将模型转换为TF lite格式(ps:如果直接使用.h5文件也是可以的,在CubeMX里输入模型类别选Keras,不知什么原因模型转换不成功,使用 .h5模型即可),代码如下:

# 导入工具包
import sys, os
import datetime
import time
import numpy as np
import tensorflow as tf # 输出函数 输出更加直观
def level_output(level):
for i in range(level.shape[1]):
if level[0, i] == 1.0:
return i+1 # 测试电压
test_v = 7.78
t1 = time.time() # %%导入模型计算
load_model = tf.keras.models.load_model('level_check.h5')
out = load_model.predict([test_v])
print(out) cal_level = np.around(out).astype(int)
print(cal_level) t2 = time.time() # %%输出能源等级
level = level_output(cal_level)
print(level)
print((int(t2*1000)-int(t1*1000))) # %% 转换模型为tf lite格式 不量化
# converter = tf.lite.TFLiteConverter.from_keras_model(load_model)
# tflite_model = converter.convert()
# open("level_check.tflite", "wb").write(tflite_model) # 保存到磁盘

运行如下图:

4. stm32深度学习

​ 基于温控主控板精简工程,打开 CubeMX_Config.ioc 工程,添加 X-CUBE-AI 支持;

  1. 添加 X-CUBE-AI 拓展包

​ 2. 添加模型,点击 Analyze 进行分析,会计算出模型的复杂度及其所需资源 FLASH + RAM。

  1. 生成所需代码,并移植到工程中

  2. 移植测试

    • 复制AI库 board\CubeMX_Config\Middlewares\ST\AI 到 board\CubeAI,并添加如下 SConscript

      import os, sys
      import rtconfig
      from building import * cwd = GetCurrentDir()
      CPPPATH = [cwd + '/Inc']
      LIBPATH = [cwd + '/Lib']
      lib_file_path = os.listdir(LIBPATH[0])
      lib_files = list() for i in range(len(lib_file_path)):
      stem = os.path.splitext(lib_file_path[i])[0]
      lib_files.append(stem) # EXEC_PATH is the compiler execute path, for example, CodeSourcery, Keil MDK, IAR
      for item in lib_files:
      if "GCC" in item and rtconfig.CROSS_TOOL == 'gcc':
      LIBS = [item[3:]]
      elif "Keil" in item and rtconfig.CROSS_TOOL == 'keil':
      LIBS = [item]
      elif "IAR" in item and rtconfig.CROSS_TOOL == 'iar':
      LIBS = [item] src = []
      group = DefineGroup('CubeAI', src, depend = [''], CPPPATH = CPPPATH, LIBS = LIBS, LIBPATH = LIBPATH)
      Return('group')
    • 提取生成的 board\CubeMX_Config\Src 和 Inc 目录下的 模型C代码到应用

      network_config.h

      network.h

      network.c

      network_data.h

      network_data.c

      添加 app_ai.c

      #include <rtthread.h>
      #include <stdlib.h>
      #include <string.h>
      #include <math.h>
      #include "network.h"
      #include "network_data.h" #define DBG_SECTION_NAME "AI"
      #define DBG_LEVEL DBG_LOG
      #include <rtdbg.h> static ai_handle network = AI_HANDLE_NULL;
      static ai_network_report network_info; AI_ALIGNED(4)
      static ai_u8 activations[AI_NETWORK_DATA_ACTIVATIONS_SIZE]; #if !defined(AI_NETWORK_INPUTS_IN_ACTIVATIONS)
      AI_ALIGNED(4)
      static ai_u8 in_data_s[AI_NETWORK_IN_1_SIZE_BYTES];
      #endif #if !defined(AI_NETWORK_OUTPUTS_IN_ACTIVATIONS)
      AI_ALIGNED(4)
      static ai_u8 out_data_s[AI_NETWORK_OUT_1_SIZE_BYTES];
      #endif static void ai_log_err(const ai_error err, const char *fct)
      {
      if (fct)
      LOG_RAW("TEMPLATE - Error (%s) - type=0x%02x code=0x%02x\r\n", fct, err.type, err.code);
      else
      LOG_RAW("TEMPLATE - Error - type=0x%02x code=0x%02x\r\n", err.type, err.code);
      } static int ai_boostrap(ai_handle w_addr, ai_handle act_addr)
      {
      ai_error err; /* 1 - Create an instance of the model */
      err = ai_network_create(&network, AI_NETWORK_DATA_CONFIG);
      if (err.type != AI_ERROR_NONE)
      {
      ai_log_err(err, "ai_network_create");
      return -1;
      } /* 2 - Initialize the instance */
      const ai_network_params params = AI_NETWORK_PARAMS_INIT(
      AI_NETWORK_DATA_WEIGHTS(w_addr),
      AI_NETWORK_DATA_ACTIVATIONS(act_addr)); if (!ai_network_init(network, &params))
      {
      err = ai_network_get_error(network);
      ai_log_err(err, "ai_network_init");
      return -1;
      } /* 3 - Retrieve the network info of the created instance */
      if (!ai_network_get_info(network, &network_info))
      {
      err = ai_network_get_error(network);
      ai_log_err(err, "ai_network_get_error");
      ai_network_destroy(network);
      network = AI_HANDLE_NULL;
      return -3;
      } return 0;
      } static int ai_run(void *data_in, void *data_out)
      {
      ai_i32 batch; ai_buffer *ai_input = network_info.inputs;
      ai_buffer *ai_output = network_info.outputs; ai_input[0].data = AI_HANDLE_PTR(data_in);
      ai_output[0].data = AI_HANDLE_PTR(data_out); batch = ai_network_run(network, ai_input, ai_output);
      if (batch != 1)
      {
      ai_log_err(ai_network_get_error(network), "ai_network_run");
      return -1;
      } return 0;
      } #include "stm32f4xx_hal.h"
      void MX_X_CUBE_AI_Init(void)
      {
      LOG_RAW("\r\nTEMPLATE - initialization\r\n"); CRC_HandleTypeDef hcrc;
      hcrc.Instance = CRC;
      HAL_CRC_Init(&hcrc); ai_boostrap(ai_network_data_weights_get(), activations);
      }
      INIT_APP_EXPORT(MX_X_CUBE_AI_Init); /* 自定义网络查询及配置命令行接口 */
      static void ai_test(int argc, char *argv[])
      {
      if (argc != 2)
      {
      LOG_RAW("use ai <float>\r\n");
      return;
      } float test_data = atof(argv[1]); if (network == AI_HANDLE_NULL)
      {
      ai_error err = {AI_ERROR_INVALID_HANDLE, AI_ERROR_CODE_NETWORK};
      ai_log_err(err, "network not init ok");
      return;
      } if ((network_info.n_inputs != 1) || (network_info.n_outputs != 1))
      {
      ai_error err = {AI_ERROR_INVALID_PARAM, AI_ERROR_CODE_OUT_OF_RANGE};
      ai_log_err(err, "template code should be updated\r\n to support a model with multiple IO");
      return;
      } int ret = ai_run(&test_data, out_data_s);
      LOG_RAW("input data %.2f\r\n", test_data);
      LOG_RAW("output data [%.2f %.2f %.2f]\r\n", *((float *)&out_data_s[0]),
      *((float *)&out_data_s[4]),
      *((float *)&out_data_s[8]));
      LOG_RAW("output data [%d %d %d]\r\n", (uint8_t)round(*((float *)&out_data_s[0])),
      (uint8_t)round(*((float *)&out_data_s[4])),
      (uint8_t)round(*((float *)&out_data_s[8])));
      }
      MSH_CMD_EXPORT_ALIAS(ai_test, ai, ai_test sample);
    • 编译烧录,输入指令 ai 测试

  3. 实战工程 https://10.199.101.2/svn/FILNK/Code/ac_product/ac_control/branches/ai_test

STM32深度学习实战的更多相关文章

  1. 深度学习实战篇-基于RNN的中文分词探索

    深度学习实战篇-基于RNN的中文分词探索 近年来,深度学习在人工智能的多个领域取得了显著成绩.微软使用的152层深度神经网络在ImageNet的比赛上斩获多项第一,同时在图像识别中超过了人类的识别水平 ...

  2. 学习Keras:《Keras快速上手基于Python的深度学习实战》PDF代码+mobi

    有一定Python和TensorFlow基础的人看应该很容易,各领域的应用,但比较广泛,不深刻,讲硬件的部分可以作为入门人的参考. <Keras快速上手基于Python的深度学习实战>系统 ...

  3. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

  4. 『深度应用』NLP机器翻译深度学习实战课程·零(基础概念)

    0.前言 深度学习用的有一年多了,最近开始NLP自然处理方面的研发.刚好趁着这个机会写一系列NLP机器翻译深度学习实战课程. 本系列课程将从原理讲解与数据处理深入到如何动手实践与应用部署,将包括以下内 ...

  5. 『深度应用』NLP机器翻译深度学习实战课程·壹(RNN base)

    深度学习用的有一年多了,最近开始NLP自然处理方面的研发.刚好趁着这个机会写一系列NLP机器翻译深度学习实战课程. 本系列课程将从原理讲解与数据处理深入到如何动手实践与应用部署,将包括以下内容:(更新 ...

  6. TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN

    前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ...

  7. 深度学习--实战 LeNet5

    深度学习--实战 LeNet5 数据集 数据集选用CIFAR-10的数据集,Cifar-10 是由 Hinton 的学生 Alex Krizhevsky.Ilya Sutskever 收集的一个用于普 ...

  8. 【神经网络与深度学习】深度学习实战——caffe windows 下训练自己的网络模型

    1.相关准备 1.1 手写数字数据集 这篇博客上有.jpg格式的图片下载,附带标签信息,有需要的自行下载,博客附带百度云盘下载地址(手写数字.jpg 格式):http://blog.csdn.net/ ...

  9. Tensorflow 2.0 深度学习实战 —— 详细介绍损失函数、优化器、激活函数、多层感知机的实现原理

    前言 AI 人工智能包含了机器学习与深度学习,在前几篇文章曾经介绍过机器学习的基础知识,包括了监督学习和无监督学习,有兴趣的朋友可以阅读< Python 机器学习实战 >.而深度学习开始只 ...

  10. 一箭N雕:多任务深度学习实战

    1.多任务学习导引 多任务学习是机器学习中的一个分支,按1997年综述论文Multi-task Learning一文的定义:Multitask Learning (MTL) is an inducti ...

随机推荐

  1. 堆排序(topk 问题)(NB)

    博客地址:https://www.cnblogs.com/zylyehuo/ # _*_coding:utf-8_*_ # 比较排序 import random def sift(li, low, h ...

  2. Oracle12c 数据库 警告日志

    目录 一:查看警告日志文件的位置 二:警告日志内容 三:告警日志监控: 方案1: 方案2: 方案3: 正文 回到顶部 一:查看警告日志文件的位置 Oracle 12c环境下查询,alert日志并不在b ...

  3. Oracle配置和性能优化方法

          性能是衡量软件系统的一个重要部分,可能引起性能低下的原因很多,如CPU/内存/网络资源不足,硬盘读写速度慢,数据库配置不合理,数据库对象规划或存储方式不合理,模块设计对性能考虑不足等. 1 ...

  4. git库移植

    记一次个人项目移植到组织项目的git应用,留爪. 1. 首先保证你本地有一份完整的库 2. 在 gitee 组织里新建一份裸库 3. 本地库移除所有远程库 git remote //查看所有远程库 g ...

  5. DRAM的读写操作、刷新、恢复的原理

    这一节湖科大教书匠讲得特别好,原理梳理的很清晰,建议去b站看一看 写这个只为了自己复习方便一点 对读操作会破坏数据的理解 预充电利用列线上的寄生电容,使得每列的电压保持在\(Vcc/2\) 进行读操作 ...

  6. MySQL 查询树结构、循环查询、查看函数、视图、存储过程

    MySQL经常会用到查询树结构数据,这里专门收集整了一篇. 构建函数 构建树查询函数:查询父级节点函数 -- 在mysql中完成节点下的所有节点或节点上的所有父节点的查询 -- 根据传入id查询所有父 ...

  7. Ruby+Selenium+testunit web自动化demo

    1.安装对应库 使用RubyMine新建项目打开终端安装对应库 gem install selenium-webdriver gem install test-unit 如果安装不成功,请切换到国内源 ...

  8. elemengui分页

    <!-- 分页模块 --> <template> <div class="block" style="margin-top:20px&quo ...

  9. fidder抓包微信小程序的方法

    想获取小程序的请求和返回数据,要么通过抓包工具抓包,要么使用小程序调试工具直接查看 总结下怎样使用fidder抓包 第一步,各种配置,把下面一系列图片里该勾的都勾上,够好了重启fidder 第二步,打 ...

  10. python,获取当前日期且以当前日期为名称创建文件名

    爬虫爬取信息时,需要把爬取的内容存到txt文档中,且爬虫是每天执行,以日期命名能避免出现名称重复等问题,解决方法如下 import time import os import sys path = o ...