这几天又在玩树莓派,先是搞了个物联网,又在尝试在树莓派上搞一些简单的神经网络,这次搞得是mlp识别mnist手写数字识别

训练代码在电脑上,cpu就能训练,很快的:

 1 import torch
2 import torch.nn as nn
3 import torch.optim as optim
4 from torchvision import datasets, transforms
5
6 # 设置随机种子
7 torch.manual_seed(42)
8
9 # 定义MLP模型
10 class MLP(nn.Module):
11 def __init__(self):
12 super(MLP, self).__init__()
13 self.fc1 = nn.Linear(784, 256)
14 self.fc2 = nn.Linear(256, 128)
15 self.fc3 = nn.Linear(128, 10)
16
17 def forward(self, x):
18 x = x.view(-1, 784)
19 x = torch.relu(self.fc1(x))
20 x = torch.relu(self.fc2(x))
21 x = self.fc3(x)
22 return x
23
24 # 加载MNIST数据集
25 transform = transforms.Compose([
26 transforms.ToTensor(),
27 # transforms.Normalize((0.1307,), (0.3081,))
28 ])
29
30 train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
31 test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
32
33 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
34 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
35
36 # 创建模型实例
37 model = MLP()
38
39 # 定义损失函数和优化器
40 criterion = nn.CrossEntropyLoss()
41 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
42
43 # 训练模型
44 def train(model, train_loader, optimizer, criterion, epochs):
45 model.train()
46 for epoch in range(1, epochs + 1):
47 for batch_idx, (data, target) in enumerate(train_loader):
48 optimizer.zero_grad()
49 output = model(data)
50 loss = criterion(output, target)
51 loss.backward()
52 optimizer.step()
53
54 if batch_idx % 100 == 0:
55 print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
56 epoch, batch_idx * len(data), len(train_loader.dataset),
57 100. * batch_idx / len(train_loader), loss.item()))
58
59 # 训练模型
60 train(model, train_loader, optimizer, criterion, epochs=5)
61
62 # 保存模型为NumPy格式
63 numpy_model = {}
64 numpy_model['fc1.weight'] = model.fc1.weight.detach().numpy()
65 numpy_model['fc1.bias'] = model.fc1.bias.detach().numpy()
66 numpy_model['fc2.weight'] = model.fc2.weight.detach().numpy()
67 numpy_model['fc2.bias'] = model.fc2.bias.detach().numpy()
68 numpy_model['fc3.weight'] = model.fc3.weight.detach().numpy()
69 numpy_model['fc3.bias'] = model.fc3.bias.detach().numpy()
70
71 # 保存为NumPy格式的数据
72 import numpy as np
73 np.savez('mnist_model.npz', **numpy_model)

然后需要自己倒出一些图片在dataset里:我保存在了mnist_pi文件夹下,“_”后面的是标签,主要是在pc端导出保存到树莓派下

树莓派推理端的代码,需要numpy手动重新搭建网络,然后加载那些保存的矩阵参数,做矩阵乘法和加法

 1 import numpy as np
2 import os
3 from PIL import Image
4
5 # 加载模型
6 model_data = np.load('mnist_model.npz')
7 weights1 = model_data['fc1.weight']
8 biases1 = model_data['fc1.bias']
9 weights2 = model_data['fc2.weight']
10 biases2 = model_data['fc2.bias']
11 weights3 = model_data['fc3.weight']
12 biases3 = model_data['fc3.bias']
13
14 # 进行推理
15 def predict(image, weights1, biases1,weights2, biases2,weights3, biases3):
16 image = image.flatten()/255 # 将输入图像展平并进行归一化
17 output = np.dot(weights1, image) + biases1
18 output = np.dot(weights2, output) + biases2
19 output = np.dot(weights3, output) + biases3
20 predicted_class = np.argmax(output)
21 return predicted_class
22
23
24
25
26 folder_path = './mnist_pi' # 替换为图片所在的文件夹路径
27 def infer_images_in_folder(folder_path):
28 for file_name in os.listdir(folder_path):
29 file_path = os.path.join(folder_path, file_name)
30 if os.path.isfile(file_path) and file_name.endswith(('.jpg', '.jpeg', '.png')):
31 image = Image.open(file_path)
32 label = file_name.split(".")[0].split("_")[1]
33 image = np.array(image)
34 print("file_path:",file_path,"img size:",image.shape,"label:",label)
35 predicted_class = predict(image, weights1, biases1,weights2, biases2,weights3, biases3)
36 print('Predicted class:', predicted_class)
37
38 infer_images_in_folder(folder_path)

结果:

效果还不错:

这次内容就到这里了,下次争取做一个卷积的神经网络在树莓派上推理,然后争取做一个目标检测的模型在树莓派上

在树莓派上使用numpy实现简单的神经网络推理,pytorch在服务器或PC上训练好模型保存成numpy格式的数据,推理在树莓派上加载模型的更多相关文章

  1. Numpy实现简单BP神经网络识别手写数字

    本文将用Numpy实现简单BP神经网络完成对手写数字图片的识别,数据集为42000张带标签的28x28像素手写数字图像.在计算机完成对手写数字图片的识别过程中,代表图片的28x28=764个像素的特征 ...

  2. C#开发BIMFACE系列53 WinForm程序中使用CefSharp加载模型图纸1 简单应用

    BIMFACE二次开发系列目录     [已更新最新开发文章,点击查看详细] 在我的博客<C#开发BIMFACE系列52 CS客户端集成BIMFACE应用的技术方案>中介绍了多种集成BIM ...

  3. pytorch在CPU和GPU上加载模型

    pytorch允许把在GPU上训练的模型加载到CPU上,也允许把在CPU上训练的模型加载到GPU上.CPU->CPU,GPU->GPU torch.load('gen_500000.pkl ...

  4. 【神经网络与深度学习】如何将别人训练好的model用到自己的数据上

    caffe团队用imagenet图片进行训练,迭代30多万次,训练出来一个model.这个model将图片分为1000类,应该是目前为止最好的图片分类model了. 假设我现在有一些自己的图片想进行分 ...

  5. 将mnist数据集保存成numpy格式

    import numpy as np from urllib import request import gzip import pickle filename = [ ["training ...

  6. 【模块化编程】理解requireJS-实现一个简单的模块加载器

    在前文中我们不止一次强调过模块化编程的重要性,以及其可以解决的问题: ① 解决单文件变量命名冲突问题 ② 解决前端多人协作问题 ③ 解决文件依赖问题 ④ 按需加载(这个说法其实很假了) ⑤ ..... ...

  7. Tensorflow模型加载与保存、Tensorboard简单使用

    先上代码: from __future__ import absolute_import from __future__ import division from __future__ import ...

  8. 手把手教你实现Android RecyclerView上拉加载功能

    摘要 一直在用到RecyclerView时都会微微一颤,因为一直都没去了解怎么实现上拉加载,受够了每次去Github找开源引入,因为感觉就为了一个上拉加载功能而去引入一大堆你不知道有多少BUG的代码, ...

  9. springboot+layui实现PC端用户的增删改查 & 整合mui实现app端的自动登录和用户的上拉加载 & HBuilder打包app并在手机端下载安装

    springboot整合web开发的各个组件在前面已经有详细的介绍,下面是用springboot整合layui实现了基本的增删改查. 同时在学习mui开发app,也就用mui实现了一个简单的自动登录和 ...

  10. SwipeRefreshLayout详解和自定义上拉加载更多

    个人主页 演示Demo下载 本文重点介绍了SwipeRefreshLayout的使用和自定View继承SwipeRefreshLayout添加上拉加载更多的功能. 介绍之前,先来看一下SwipeRef ...

随机推荐

  1. 同步协程的必备工具: WaitGroup

    1. 简介 本文将介绍 Go 语言中的 WaitGroup 并发原语,包括 WaitGroup 的基本使用方法.实现原理.使用注意事项以及常见的使用方式.能够更好地理解和应用 WaitGroup 来协 ...

  2. 2023最新ELK日志平台(elasticsearch+logstash+kibana)搭建

    前言 去年公司由于不断发展,内部自研系统越来越多,所以后来搭建了一个日志收集平台,并将日志收集功能以二方包形式引入自研系统,避免每个自研系统都要建立一套自己的日志模块,节约了开发时间,管理起来也更加容 ...

  3. java选择结构-switch

    java选择结构-switch java的另一个多选择结构switch-case case中的value为常数值. 不加break,会一直执行到最后,包括default(case穿透) switch( ...

  4. Laf v1.0 发布:函数计算只有两种,30s 放弃的和 30s 上线的

    一般情况下,开发一个系统都需要前端和后端,仅靠一个人几乎无法胜任,需要考虑的特性和功能非常多,比如: 需要一个数据库来存放数据: 需要一个文件存储来存放各种文件,比如图片文件: 后端需要提供接口供前端 ...

  5. Windows11快捷键大集合+手动给程序添加快捷键

    本文收集了170多个windows11上的快捷键,其中有少部分是windows11新添加的.大部分的win10快捷键也适用于win11.这些快捷键涵盖了系统设置.命令行程序执行.Snap布局切换.对话 ...

  6. 可视化—AntV G6 紧凑树实现节点与边动态样式、超过X条展示更多等实用小功能

    通过一段时间的使用和学习,对G6有了更一步的经验,这篇博文主要从以下几个小功能着手介绍,文章最后会给出完整的demo代码. 目录 1. 树图的基本布局和使用 2. 根据返回数据的属性不同,定制不一样的 ...

  7. Service Mesh之Istio部署bookinfo

    前文我们了解了service mesh.分布式服务治理和istio部署相关话题,回顾请参考https://www.cnblogs.com/qiuhom-1874/p/17281541.html:今天我 ...

  8. SpringCloud源码学习笔记3——Nacos服务注册源码分析

    系列文章目录和关于我 一丶基本概念&Nacos架构 1.为什么需要注册中心 实现服务治理.服务动态扩容,以及调用时能有负载均衡的效果. 如果我们将服务提供方的ip地址配置在服务消费方的配置文件 ...

  9. TypeScript必知三部曲(一)TypeScript编译方案以及IDE对TS的类型检查

    TypeScript代码的编译过程一直以来会给很多小伙伴造成困扰,typescript官方提供tsc对ts代码进行编译,babel也表示能够编译ts代码,它们二者的区别是什么?我们应该选择哪种方案?为 ...

  10. 自己动手从零写桌面操作系统GrapeOS系列教程——4.1 在VirtualBox中安装CentOS

    学习操作系统原理最好的方法是自己写一个简单的操作系统. 之前讲解开发环境时并没有介绍具体的安装过程,有网友反应CentOS的安装配置有问题,尤其是共享文件夹.本讲我们就来补充介绍一下在VirtualB ...