huggingface vit训练CIFAR10数据集代码 ,可以改dataset训练自己的数据
上代码,使用hugging face fineturn vit模型
自己写的代码
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("mps")
# torch.device("cpu") # 加载 MNIST 数据集
train_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=True, transform=ToTensor(), download=True)
test_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=False, transform=ToTensor()) def collate_fn(batch):
"""
对batch数据进行处理
:param batch: [一个getitem的结果,getitem的结果,getitem的结果]
:return: 元组
"""
reviews,labels = zip(*batch)
# print(reviews)
# print(labels)
# reviews = torch.Tensor(reviews)
labels = torch.Tensor(labels) return reviews,labels
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,collate_fn=collate_fn) # url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw) processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.config.classifier = 'mlp'
model.config.num_labels = 10
# print(model.get_output_embeddings)
# print(model.classifier)
model.classifier = nn.Linear(768,10)
print(model.classifier) parameters = list(model.parameters())
for x in parameters[:-1]:
x.requires_grad = False model.to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001) def train(model, dataloader, optimizer, criterion):
model.train()
running_loss = 0.0
for inputs, labels in tqdm(dataloader, desc="Training"):
# print(inputs)
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device)
# print(inputs['pixel_values'].shape)
# print(labels.shape)
optimizer.zero_grad() outputs = model(**inputs)
logits = outputs.logits # print(logits,labels)
loss = criterion(logits, labels.long())
loss.backward()
optimizer.step()
# model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])
running_loss += loss.item() * inputs['pixel_values'].size(0) epoch_loss = running_loss / len(dataloader.dataset)
return epoch_loss def evaluate(model, dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(dataloader, desc="Evaluating"):
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device) outputs = model(**inputs)
logits = outputs.logits predicted= logits.argmax(-1) total += labels.size(0)
correct += (predicted == labels).sum().item() accuracy = correct / total * 100
return accuracy # 训练和评估
num_epochs = 10 for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
train_loss = train(model, train_loader, optimizer, criterion)
print(f"Training Loss: {train_loss:.4f}") test_acc = evaluate(model, test_loader)
print(f"Test Accuracy: {test_acc:.2f}%")
chatgpt生成的代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm # 设置随机种子
torch.manual_seed(42) # 定义超参数
batch_size = 32
num_epochs = 10
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
]) # 加载CIFAR-10数据集
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 加载预训练的ViT模型
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device) # 替换分类头
num_classes = 10
vit_model.config.classifier = 'mlp'
vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate) # 微调ViT模型
for epoch in range(num_epochs):
print("epoch:",epoch)
vit_model.train()
train_loss = 0.0
train_correct = 0 bar = tqdm(train_loader,total=len(train_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) # 前向传播
outputs = vit_model(images)
loss = criterion(outputs.logits, labels) # 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step() train_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
train_correct += (predicted == labels).sum().item() # 在训练集上计算准确率
train_accuracy = 100.0 * train_correct / len(train_dataset) # 在测试集上进行评估
vit_model.eval()
test_loss = 0.0
test_correct = 0 with torch.no_grad():
bar = tqdm(test_loader,total=len(test_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) outputs = vit_model(images)
loss = criterion(outputs.logits, labels) test_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
test_correct += (predicted == labels).sum().item() # 在测试集上计算准确率
test_accuracy = 100.0 * test_correct / len(test_dataset) # 打印每个epoch的训练损失、训练准确率和测试准确率
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')
huggingface vit训练CIFAR10数据集代码 ,可以改dataset训练自己的数据的更多相关文章
- Ubuntu+caffe训练cifar-10数据集
1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...
- Keras学习:试用卷积-训练CIFAR-10数据集
import numpy as np import cPickle import keras as ks from keras.layers import Dense, Activation, Fla ...
- MXNet学习:试用卷积-训练CIFAR-10数据集
第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...
- 使用caffe训练mnist数据集 - caffe教程实战(一)
个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...
- CaffeExample 在CIFAR-10数据集上训练与测试
本文主要来自Caffe作者Yangqing Jia网站给出的examples. @article{jia2014caffe, Author = {Jia, Yangqing and Shelhamer ...
- 仿照CIFAR-10数据集格式,制作自己的数据集
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50801226 前一篇博客:C/C++ ...
- TensorFlow CNN 测试CIFAR-10数据集
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...
- caffe︱cifar-10数据集quick模型的官方案例
准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...
- 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集
上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...
- CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】
CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...
随机推荐
- SpringMVC简介 & 原理
特点 1.轻量级,简单易学 2.高效,基于请求响应的MVC框架 3.与Spring兼容性好,与之无缝接合(就是它的一部分) 4.约定优于配置(maven) 5.功能强大:支持RESTful 数据验证 ...
- 技能get-ps抠颜色一样的图
公司要插个小图片,从网上down下来的图片是不过是jpg的,背景不透明,这时候可以使用ps工具把这种同一颜色的内容扣下来. 操作步骤: 选择-色彩范围,然后用取样器取颜色,再调节拉条选取颜色范围,最后 ...
- docker安装kafka和zookeeper
参考,欢迎点击原文:https://www.cnblogs.com/360minitao/p/14665845.html(主要) https://blog.csdn.net/qq_22041375/a ...
- php处理序列化jQuery serializeArray数据
介绍jquery的几个常用处理表单的函数: 1.序列化表单内容元素为字符串,常用于ajax提交. $("form").serialize() 2. serializeArray() ...
- 记录--基于Vue2.0实现后台系统权限控制
这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 基于Vue.js 2.x系列 + Element UI 的后台系统权限控制 前言:关于vue权限路由的那些事儿-- 项目背景:现有一个后台 ...
- C++关于栈对象返回的问题
本次实验环境 环境1:Win10, QT 5.12 环境2:Centos7,g++ 4.8.5 一. 主要结论 可以返回栈上的对象(各平台会有不同的优化),不可以返回栈对象的引用. 二.先看看函数传参 ...
- PDF的分割与合并
1.进行PDF切割 python代码如下: # 20220521 # 1.选择要分割的文件 # 2.选择要保存的位置,分割为多个文件时,可自动用页码命名 # 3.输入要分割的页码,可以是一个范围1-2 ...
- vivado的VIO调试工具的使用
vivado中的VIO调试工具的使用 1.实验原理 前面一篇介绍了ILA的独立测试,vivado中还有其他的FPGA测试工具.其中VIO就是个比较常用的工具.相对于ILA更多的关注波形,VIO则专注于 ...
- UE427-C++实现摄像机视角的移动,类似开镜效果
教程 方法 调整相机视野和弹簧臂的长度 //自带的tick函数内 需要使用DeltaTime if (bZoomIn) { ZoomFactor += DeltaTime / 0.5f; } else ...
- C++中std::function常见用法
C++标准库中的std::function是一个通用的函数封装,可以用来存储.复制.调用任何可调用对象(函数.函数指针.成员函数指针.lambda表达式等).以下是std::function的一些常见 ...