上代码,使用hugging face fineturn vit模型

自己写的代码

from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("mps")
# torch.device("cpu") # 加载 MNIST 数据集
train_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=True, transform=ToTensor(), download=True)
test_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=False, transform=ToTensor()) def collate_fn(batch):
"""
对batch数据进行处理
:param batch: [一个getitem的结果,getitem的结果,getitem的结果]
:return: 元组
"""
reviews,labels = zip(*batch)
# print(reviews)
# print(labels)
# reviews = torch.Tensor(reviews)
labels = torch.Tensor(labels) return reviews,labels
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,collate_fn=collate_fn) # url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw) processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.config.classifier = 'mlp'
model.config.num_labels = 10
# print(model.get_output_embeddings)
# print(model.classifier)
model.classifier = nn.Linear(768,10)
print(model.classifier) parameters = list(model.parameters())
for x in parameters[:-1]:
x.requires_grad = False model.to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001) def train(model, dataloader, optimizer, criterion):
model.train()
running_loss = 0.0
for inputs, labels in tqdm(dataloader, desc="Training"):
# print(inputs)
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device)
# print(inputs['pixel_values'].shape)
# print(labels.shape)
optimizer.zero_grad() outputs = model(**inputs)
logits = outputs.logits # print(logits,labels)
loss = criterion(logits, labels.long())
loss.backward()
optimizer.step()
# model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])
running_loss += loss.item() * inputs['pixel_values'].size(0) epoch_loss = running_loss / len(dataloader.dataset)
return epoch_loss def evaluate(model, dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(dataloader, desc="Evaluating"):
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device) outputs = model(**inputs)
logits = outputs.logits predicted= logits.argmax(-1) total += labels.size(0)
correct += (predicted == labels).sum().item() accuracy = correct / total * 100
return accuracy # 训练和评估
num_epochs = 10 for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
train_loss = train(model, train_loader, optimizer, criterion)
print(f"Training Loss: {train_loss:.4f}") test_acc = evaluate(model, test_loader)
print(f"Test Accuracy: {test_acc:.2f}%")

  

chatgpt生成的代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm # 设置随机种子
torch.manual_seed(42) # 定义超参数
batch_size = 32
num_epochs = 10
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
]) # 加载CIFAR-10数据集
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 加载预训练的ViT模型
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device) # 替换分类头
num_classes = 10
vit_model.config.classifier = 'mlp'
vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate) # 微调ViT模型
for epoch in range(num_epochs):
print("epoch:",epoch)
vit_model.train()
train_loss = 0.0
train_correct = 0 bar = tqdm(train_loader,total=len(train_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) # 前向传播
outputs = vit_model(images)
loss = criterion(outputs.logits, labels) # 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step() train_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
train_correct += (predicted == labels).sum().item() # 在训练集上计算准确率
train_accuracy = 100.0 * train_correct / len(train_dataset) # 在测试集上进行评估
vit_model.eval()
test_loss = 0.0
test_correct = 0 with torch.no_grad():
bar = tqdm(test_loader,total=len(test_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) outputs = vit_model(images)
loss = criterion(outputs.logits, labels) test_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
test_correct += (predicted == labels).sum().item() # 在测试集上计算准确率
test_accuracy = 100.0 * test_correct / len(test_dataset) # 打印每个epoch的训练损失、训练准确率和测试准确率
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')

  

huggingface vit训练CIFAR10数据集代码 ,可以改dataset训练自己的数据的更多相关文章

  1. Ubuntu+caffe训练cifar-10数据集

    1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...

  2. Keras学习:试用卷积-训练CIFAR-10数据集

    import numpy as np import cPickle import keras as ks from keras.layers import Dense, Activation, Fla ...

  3. MXNet学习:试用卷积-训练CIFAR-10数据集

    第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...

  4. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  5. CaffeExample 在CIFAR-10数据集上训练与测试

    本文主要来自Caffe作者Yangqing Jia网站给出的examples. @article{jia2014caffe, Author = {Jia, Yangqing and Shelhamer ...

  6. 仿照CIFAR-10数据集格式,制作自己的数据集

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50801226 前一篇博客:C/C++ ...

  7. TensorFlow CNN 测试CIFAR-10数据集

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...

  8. caffe︱cifar-10数据集quick模型的官方案例

    准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...

  9. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  10. CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】

    CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...

随机推荐

  1. vscode 切换页签快捷键 自定义 Ctrl+H Ctrl+L 左右切换

    今天需要整理写资料,需要在多个页签之间切换,发现自定义了快捷. 好久不用这个快捷键,都快忘了. vscode 切换页签快捷键 自定义 Ctrl+H Ctrl+L 左右切换

  2. 关于百分百浏览器(cent browser)无法使用QQ快捷登录问题

    个人比较喜欢用百分百浏览器,但是QQ似乎不允许此浏览器进行登录,参考了下网上提供的思路,研究解决了QQ无法登录的问题 主要就设置了下证书,详情步骤见下面图片

  3. .Net Core 你必须知道的source-generators

    源生成器是 C# 9 中引入的一项功能,允许在编译过程中动态生成代码. 它们直接与 C# 编译器集成(Roslyn)并在编译时运行,分析源代码并根据分析结果生成附加代码. 源生成器提供了一种简化的自动 ...

  4. 译:使用 Bun 执行 Shell 脚本

    原文地址(Bun Blog): https://bun.sh/blog/the-bun-shell 作者: jarredsumner 发布时间:2024-01-20 前言 JavaScript 是世界 ...

  5. 爬虫实战:从HTTP请求获取数据解析社区

    在过去的实践中,我们通常通过爬取HTML网页来解析并提取所需数据,然而这只是一种方法.另一种更为直接的方式是通过发送HTTP请求来获取数据.考虑到大多数常见服务商的数据都是通过HTTP接口封装的,因此 ...

  6. python opencv DNN 人脸检测

    import cv2 modelFile = "res10_300x300_ssd_iter_140000_fp16.caffemodel" configFile = " ...

  7. archlinux virtualbox使用文件共享 主机arch,客机windows8.1 windows10

    参照 https://www.cnblogs.com/cuitang/p/11263008.html 1.安装virtualbox增强功能VBoxGuestAdditions.iso (1)从virt ...

  8. Windows下mDNS查询API—DnsStartMulticastQuery/DnsStopMulticastQuery的使用

    背景及问题: 目前很多局域网设备通过mNDS协议实现互联,IP地址为自动IP段-169.254.x.x,有时候设备厂家提供的API需要通过知晓局域网中的IP地址/设备名,才能连接该设备.这样要求每个软 ...

  9. 第一个hello驱动

    Linux驱动程序的分类 字符设备驱动.块设备驱动和网络设备驱动. Linux驱动程序运行方式 把驱动程序编译进内核里面,这样内核启动后就会自动运行驱动程序了: 把驱动程序编译成以.ko为后缀的模块文 ...

  10. #容斥,搜索,线性筛#CF83D Numbers

    洛谷 CF83D 分析 题意就是\(\sum_{i=l}^r[k|i]*[mn[\frac{i}{k}]\geq k]\) 首先线性筛每个数的最小质因数,如果\(\frac{r}{k}\)较小直接暴力 ...