apple的m1芯片比以往cpu芯片在机器学习加速上听说有15倍的提升,也就是可以使用apple mac训练深度学习pytorch模型!!!惊呆了

安装apple m1芯片版本的pytorch

然后使用chatGPT生成一个resnet101的训练代码,这里注意,如果网络特别轻的话是没有加速效果的,还没有cpu的计算来的快

这里要选择好设备不是"cuda"了,cuda是nvidia深度学习加速的配置

# 设置设备
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps") #torch.device("cpu")

  

resnet101的训练代码,训练mnist手写数字识别,之前我还尝试了两层linear的训练代码,低估了apple 的 torch.device("mps"),这两层linear的简单神经网络完全加速不起来,还不如torch.device("cpu")快,换成了resnet101加速效果就很明显了,目测速度在mps上比cpu快了5倍左右
 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm # 设置设备
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps") #torch.device("cpu") # 加载 MNIST 数据集
train_dataset = MNIST(root="/Users/xinyuuliu/Desktop/test_python/", train=True, transform=ToTensor(), download=True)
test_dataset = MNIST(root="/Users/xinyuuliu/Desktop/test_python/", train=False, transform=ToTensor()) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) # 定义 ResNet-101 模型
model = resnet101(pretrained=False)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(2048, 10) # 替换最后一层全连接层
model.to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) # 训练和评估函数
def train(model, dataloader, optimizer, criterion):
model.train()
running_loss = 0.0
for inputs, labels in tqdm(dataloader, desc="Training"):
inputs = inputs.to(device)
labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs)
loss = criterion(outputs, labels) loss.backward()
optimizer.step() running_loss += loss.item() * inputs.size(0) epoch_loss = running_loss / len(dataloader.dataset)
return epoch_loss def evaluate(model, dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(dataloader, desc="Evaluating"):
inputs = inputs.to(device)
labels = labels.to(device) outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1) total += labels.size(0)
correct += (predicted == labels).sum().item() accuracy = correct / total * 100
return accuracy # 训练和评估
num_epochs = 10 for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
train_loss = train(model, train_loader, optimizer, criterion)
print(f"Training Loss: {train_loss:.4f}") test_acc = evaluate(model, test_loader)
print(f"Test Accuracy: {test_acc:.2f}%")

  

结果:

在mps device上,训练时间在10分钟左右

在cpu device上,训练时间在50分钟左右,明显在mps device上速度快了5倍

macbook苹果m1芯片训练机器学习、深度学习模型,resnet101在mnist手写数字识别上做加速,torch.device("mps")的更多相关文章

  1. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  2. 深度学习(一):Python神经网络——手写数字识别

    声明:本文章为阅读书籍<Python神经网络编程>而来,代码与书中略有差异,书籍封面: 源码 若要本地运行,请更改源码中图片与数据集的位置,环境为 Python3.6x. 1 import ...

  3. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  4. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  5. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

  6. SVM学习笔记(二)----手写数字识别

    引言 上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的.考 ...

  7. 【机器学习】k-近邻算法应用之手写数字识别

    上篇文章简要介绍了k-近邻算法的算法原理以及一个简单的例子,今天再向大家介绍一个简单的应用,因为使用的原理大体差不多,就没有没有过多的解释. 为了具有说明性,把手写数字的图像转换为txt文件,如下图所 ...

  8. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  9. 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec

    人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...

  10. TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)

    从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作.这次我们要解决机器学习的经典问题,MNIST手写数字识别. 首先介绍一下数据集.请首先解压:TF_Net\Asset\mnist_png. ...

随机推荐

  1. Javaweb学习第十二弹--Request和Response

    XML配置方式编写Servlet 3.0版本之前,仅仅支持XML配置文件的配置方式 1.编写Servlet类 2.在web.xml中配置该Servlet Request和Response Reques ...

  2. golang中关于死锁的思考与学习

    1.Golang中死锁的触发条件 1.1 书上关于死锁的四个必要条件的讲解 发生死锁时,线程永远不能完成,系统资源被阻碍使用,以致于阻止了其他作业开始执行.在讨论处理死锁问题的各种方法之前,我们首先深 ...

  3. 一文带你搞懂java中的变量的定义是什么意思

    前言 在之前的文章中,壹哥给大家讲解了Java的第一个案例HelloWorld,并详细给大家介绍了Java的标识符,而且现在我们也已经知道该使用什么样的工具进行Java开发.那么接下来,壹哥会集中精力 ...

  4. Django笔记五之字段类型

    这篇笔记介绍字段的类型 Field Type. Django 的model 下的 field 对应的是 MySQL 中的表字段,而我们定义的 field 的类型则对应 MySQL 中的字段类型. 本次 ...

  5. ApplicationRunner 类说明

    在开发中可能会有这样的情景.需要在容器启动的时候执行一些内容.比如读取配置文件,数据库连接之类的.SpringBoot给我们提供了两个接口来帮助我们实现这种需求.这两个接口分别为 CommandLin ...

  6. [数据分析与可视化] Python绘制数据地图2-GeoPandas地图可视化

    本文主要介绍GeoPandas结合matplotlib实现地图的基础可视化.GeoPandas是一个Python开源项目,旨在提供丰富而简单的地理空间数据处理接口.GeoPandas扩展了Pandas ...

  7. day29:计算机网络概念

    目录 1.网络开发的两大架构 2.网络概念 3.OSI七层模型 4.ARP协议 5.TCP三次握手和四次挥手 1.网络开发的两大架构 1.没有网络的时候,文件是如何传输的? 早期没有网络 a.py - ...

  8. git撤销某一次commit提交

    一.使用git rebase命令 如果您想彻底删除 Git 中的某次提交的内容,可以使用 git rebase 命令并将该提交删除. 以下是删除 Git 提交内容的步骤: 找到要删除的提交的哈希值.可 ...

  9. C# 从0到实战--程序入门:基本程序结构·hello,world

    为什么要写博客 某人是一名大学生,到了大二,学院开始教授.Net,从这里我接触到了C#和ASP.Net,这些技术让我感到了想不到的快速开发之震撼.于是突发奇想,写此博客来记录我的学习路程.博客不仅仅是 ...

  10. Puppeteer+RabbitMQ:Node.js 批量加工pdf服务架构设计与落地

    全文约8500字,阅读时长约10分钟. 智慧作业最近上线「个性化手册」(简称个册)功能,一份完整的个性化手册分为三部分: •学情分析:根据学生阶段性的学习和考试情况进行学情分析.归纳.总结,汇总学情数 ...