huggingface vit训练CIFAR10数据集代码 ,可以改dataset训练自己的数据
上代码,使用hugging face fineturn vit模型
自己写的代码
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("mps")
# torch.device("cpu") # 加载 MNIST 数据集
train_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=True, transform=ToTensor(), download=True)
test_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=False, transform=ToTensor()) def collate_fn(batch):
"""
对batch数据进行处理
:param batch: [一个getitem的结果,getitem的结果,getitem的结果]
:return: 元组
"""
reviews,labels = zip(*batch)
# print(reviews)
# print(labels)
# reviews = torch.Tensor(reviews)
labels = torch.Tensor(labels) return reviews,labels
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,collate_fn=collate_fn) # url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw) processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.config.classifier = 'mlp'
model.config.num_labels = 10
# print(model.get_output_embeddings)
# print(model.classifier)
model.classifier = nn.Linear(768,10)
print(model.classifier) parameters = list(model.parameters())
for x in parameters[:-1]:
x.requires_grad = False model.to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001) def train(model, dataloader, optimizer, criterion):
model.train()
running_loss = 0.0
for inputs, labels in tqdm(dataloader, desc="Training"):
# print(inputs)
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device)
# print(inputs['pixel_values'].shape)
# print(labels.shape)
optimizer.zero_grad() outputs = model(**inputs)
logits = outputs.logits # print(logits,labels)
loss = criterion(logits, labels.long())
loss.backward()
optimizer.step()
# model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])
running_loss += loss.item() * inputs['pixel_values'].size(0) epoch_loss = running_loss / len(dataloader.dataset)
return epoch_loss def evaluate(model, dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(dataloader, desc="Evaluating"):
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device) outputs = model(**inputs)
logits = outputs.logits predicted= logits.argmax(-1) total += labels.size(0)
correct += (predicted == labels).sum().item() accuracy = correct / total * 100
return accuracy # 训练和评估
num_epochs = 10 for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
train_loss = train(model, train_loader, optimizer, criterion)
print(f"Training Loss: {train_loss:.4f}") test_acc = evaluate(model, test_loader)
print(f"Test Accuracy: {test_acc:.2f}%")
chatgpt生成的代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm # 设置随机种子
torch.manual_seed(42) # 定义超参数
batch_size = 32
num_epochs = 10
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
]) # 加载CIFAR-10数据集
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 加载预训练的ViT模型
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device) # 替换分类头
num_classes = 10
vit_model.config.classifier = 'mlp'
vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate) # 微调ViT模型
for epoch in range(num_epochs):
print("epoch:",epoch)
vit_model.train()
train_loss = 0.0
train_correct = 0 bar = tqdm(train_loader,total=len(train_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) # 前向传播
outputs = vit_model(images)
loss = criterion(outputs.logits, labels) # 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step() train_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
train_correct += (predicted == labels).sum().item() # 在训练集上计算准确率
train_accuracy = 100.0 * train_correct / len(train_dataset) # 在测试集上进行评估
vit_model.eval()
test_loss = 0.0
test_correct = 0 with torch.no_grad():
bar = tqdm(test_loader,total=len(test_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) outputs = vit_model(images)
loss = criterion(outputs.logits, labels) test_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
test_correct += (predicted == labels).sum().item() # 在测试集上计算准确率
test_accuracy = 100.0 * test_correct / len(test_dataset) # 打印每个epoch的训练损失、训练准确率和测试准确率
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')
huggingface vit训练CIFAR10数据集代码 ,可以改dataset训练自己的数据的更多相关文章
- Ubuntu+caffe训练cifar-10数据集
1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...
- Keras学习:试用卷积-训练CIFAR-10数据集
import numpy as np import cPickle import keras as ks from keras.layers import Dense, Activation, Fla ...
- MXNet学习:试用卷积-训练CIFAR-10数据集
第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...
- 使用caffe训练mnist数据集 - caffe教程实战(一)
个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...
- CaffeExample 在CIFAR-10数据集上训练与测试
本文主要来自Caffe作者Yangqing Jia网站给出的examples. @article{jia2014caffe, Author = {Jia, Yangqing and Shelhamer ...
- 仿照CIFAR-10数据集格式,制作自己的数据集
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50801226 前一篇博客:C/C++ ...
- TensorFlow CNN 测试CIFAR-10数据集
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...
- caffe︱cifar-10数据集quick模型的官方案例
准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...
- 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集
上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...
- CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】
CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...
随机推荐
- spirmmvc框架整合手抄版示例,供基础搭建代码对照
注明所有文档和图片完整对照,辟免笔记出错,不能复习 package com.ithm.config; import com.alibaba.druid.pool.DruidDataSource; ...
- 实现一个 SEO 友好的响应式多语言官网 (Vite-SSG + Vuetify3) 我的踩坑之旅
在 2023 年的年底,我终于有时间下定决心把我的 UtilMeta 项目官网 进行翻新,主要的原因是之前的官网是用 Vue2 实现的一个 SPA 应用,对搜索引擎 SEO 很不友好,这对于介绍项目的 ...
- Zabbix“专家坐诊”第190期问答汇总
问题一 Q:请问为啥用拓扑图监控交换机接口流量,获取不到数据,显示未知,键值也没错 ,最新数据也能看到,是什么原因呢? A:把第一个值改成主机名. 问题二 Q:请问下zabbix server 有什么 ...
- [VueJsDev] 快速入门 - 开发前小知识
[VueJsDev] 目录列表 https://www.cnblogs.com/pengchenggang/p/17037320.html 开发前小知识 ::: details 目录 目录 开发前小知 ...
- 摆脱鼠标操作 - vscode - vim - 官方说明文档 github上的,防止打不开,这里发一个
Key - command done - command done with VS Code specific customization ️ - some variations of the com ...
- 从一线方案商的角度来看高通QCC3020芯片
写在前面的话 QCC3020的推出已经有一段时间了.在蓝牙音频的圈子里,属于家喻户晓的芯片了.再加上高通的大力宣传和一些顶尖级产品的使用,可以说,它是高通在吸收CSR的技术之后,着力推出的最具竞争 ...
- window.showModalDialog与opener及returnValue
首先来看看 window.showModalDialog 的参数 vReturnValue = window.showModalDialog(sURL [, vArguments] [, sFeatu ...
- 21_显示YUV图片&视频
一.显示YUV图片 显示 YUV 图片和显示 BMP 图片的大致流程是一样的.显示 BMP 图片我们可以直接获取到 BMP 图片的 surface,然后直接从 surface 创建纹理.显示 YUV ...
- JAVA 获取文件 MD5
private static final char[] hexCode = "0123456789ABCDEF".toCharArray(); // 文件类取MD5 public ...
- OWOD:开放世界目标检测,更贴近现实的检测场景 | CVPR 2021 Oral
不同于以往在固定数据集上测试性能,论文提出了一个更符合实际的全新检测场景Open World Object Detection,需要同时识别出未知类别和已知类别,并不断地进行增量学习.论文还给出了OR ...