上代码,使用hugging face fineturn vit模型

自己写的代码

from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("mps")
# torch.device("cpu") # 加载 MNIST 数据集
train_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=True, transform=ToTensor(), download=True)
test_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=False, transform=ToTensor()) def collate_fn(batch):
"""
对batch数据进行处理
:param batch: [一个getitem的结果,getitem的结果,getitem的结果]
:return: 元组
"""
reviews,labels = zip(*batch)
# print(reviews)
# print(labels)
# reviews = torch.Tensor(reviews)
labels = torch.Tensor(labels) return reviews,labels
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,collate_fn=collate_fn) # url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw) processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.config.classifier = 'mlp'
model.config.num_labels = 10
# print(model.get_output_embeddings)
# print(model.classifier)
model.classifier = nn.Linear(768,10)
print(model.classifier) parameters = list(model.parameters())
for x in parameters[:-1]:
x.requires_grad = False model.to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001) def train(model, dataloader, optimizer, criterion):
model.train()
running_loss = 0.0
for inputs, labels in tqdm(dataloader, desc="Training"):
# print(inputs)
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device)
# print(inputs['pixel_values'].shape)
# print(labels.shape)
optimizer.zero_grad() outputs = model(**inputs)
logits = outputs.logits # print(logits,labels)
loss = criterion(logits, labels.long())
loss.backward()
optimizer.step()
# model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])
running_loss += loss.item() * inputs['pixel_values'].size(0) epoch_loss = running_loss / len(dataloader.dataset)
return epoch_loss def evaluate(model, dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(dataloader, desc="Evaluating"):
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device) outputs = model(**inputs)
logits = outputs.logits predicted= logits.argmax(-1) total += labels.size(0)
correct += (predicted == labels).sum().item() accuracy = correct / total * 100
return accuracy # 训练和评估
num_epochs = 10 for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
train_loss = train(model, train_loader, optimizer, criterion)
print(f"Training Loss: {train_loss:.4f}") test_acc = evaluate(model, test_loader)
print(f"Test Accuracy: {test_acc:.2f}%")

  

chatgpt生成的代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm # 设置随机种子
torch.manual_seed(42) # 定义超参数
batch_size = 32
num_epochs = 10
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
]) # 加载CIFAR-10数据集
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 加载预训练的ViT模型
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device) # 替换分类头
num_classes = 10
vit_model.config.classifier = 'mlp'
vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate) # 微调ViT模型
for epoch in range(num_epochs):
print("epoch:",epoch)
vit_model.train()
train_loss = 0.0
train_correct = 0 bar = tqdm(train_loader,total=len(train_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) # 前向传播
outputs = vit_model(images)
loss = criterion(outputs.logits, labels) # 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step() train_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
train_correct += (predicted == labels).sum().item() # 在训练集上计算准确率
train_accuracy = 100.0 * train_correct / len(train_dataset) # 在测试集上进行评估
vit_model.eval()
test_loss = 0.0
test_correct = 0 with torch.no_grad():
bar = tqdm(test_loader,total=len(test_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) outputs = vit_model(images)
loss = criterion(outputs.logits, labels) test_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
test_correct += (predicted == labels).sum().item() # 在测试集上计算准确率
test_accuracy = 100.0 * test_correct / len(test_dataset) # 打印每个epoch的训练损失、训练准确率和测试准确率
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')

  

huggingface vit训练CIFAR10数据集代码 ,可以改dataset训练自己的数据的更多相关文章

  1. Ubuntu+caffe训练cifar-10数据集

    1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...

  2. Keras学习:试用卷积-训练CIFAR-10数据集

    import numpy as np import cPickle import keras as ks from keras.layers import Dense, Activation, Fla ...

  3. MXNet学习:试用卷积-训练CIFAR-10数据集

    第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...

  4. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  5. CaffeExample 在CIFAR-10数据集上训练与测试

    本文主要来自Caffe作者Yangqing Jia网站给出的examples. @article{jia2014caffe, Author = {Jia, Yangqing and Shelhamer ...

  6. 仿照CIFAR-10数据集格式,制作自己的数据集

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50801226 前一篇博客:C/C++ ...

  7. TensorFlow CNN 测试CIFAR-10数据集

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...

  8. caffe︱cifar-10数据集quick模型的官方案例

    准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...

  9. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  10. CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】

    CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...

随机推荐

  1. 快速带你入门css

    css复习笔记 1. css样式值 1.1 文字样式 1 p{ 2 font-size: 30px;/*设置文字大小*/ 3 font-weight: bold;/*文字加粗*/ 4 font-sty ...

  2. 开源好用的所见即所得(WYSIWYG)编辑器:Editor.js

    @ 目录 特点 基于区块 干净的数据 界面与交互 插件 标题和文本 图片 列表 Todo 表格 使用 安装 创建编辑器实例 配置工具 本地化 自定义样式 今天介绍一个开源好用的Web所见即所得(WYS ...

  3. Vue 长文本组件(有展开更多按钮)实现 附源码及使用

    原文地址:Vue 长文本组件(有展开更多按钮) | Stars-One的杂货小窝 最近项目需要优化长文本的显示,如果长文本过长,固定显示几行并显示一个展开更多的按钮,点击按钮即可把隐藏的文本显示出来 ...

  4. 修改easyui日期控件只显示年月,并且只能选择年月

    <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...

  5. 大年学习linux(第五节---目录结构)

    五.目录结构 可以用ls / 查看linux的目录结构 bin data etc lib media opt root sbin sys usr boot dev home lib64 mnt pro ...

  6. openApi generator总是生成类名为 defaultApi

    生成器可以开启 useTags 设置,开启之后会根据 api 文档中的 tags 生成前缀类名,因此,要不生成 defaultApi 需要以下操作: 1.openApi 文档中每个 url 必须要有 ...

  7. iOS模拟器 Unable to boot the Simulator —— Ficow笔记

    本文首发于 Ficow Shen's Blog,原文地址: iOS模拟器 Unable to boot the Simulator -- Ficow笔记. 内容概览 前言 终结模拟器进程 命令行改权限 ...

  8. iptables-save 命令使用总结

    转载请注明出处: iptables-save 命令在 Linux 系统中用于将当前运行的 iptables 防火墙规则导出到一个文件中.这对于备份规则.迁移规则或在不同系统间共享规则配置非常有用. 基 ...

  9. [ERROR] “不支持使用 SOAP 编码。SOAP 扩展元素包含 use=“encoded“ “ 无法解析 WSDL。

    下载axis-1_4,地址https://archive.apache.org/dist/ws/axis/1_4/ 解压,进入D:\axis-1_4\lib 执行命令 java -cp mail.ja ...

  10. FPGA中的时钟域问题

    FPGA中的时钟域问题 一.时钟域的定义 所谓时钟域,就是同一个时钟驱动的区域.这里的驱动,是指时钟刷新D触发器的事件,体现在verilog中就是always的边沿触发信号.单一时钟域是FPGA的基本 ...