First, some background copied from Zhihu.

Following the flow diagram above, the ViT forward pass can be broken into the following steps.

(1) Patch embedding: for a 224x224 input image split into fixed-size 16x16 patches, each image yields 224x224 / (16x16) = 196 patches, so the input sequence length is 196. Each flattened patch has dimension 16x16x3 = 768, and the linear projection layer has shape 768xN (N = 768), so after the projection the shape is still 196x768: 196 tokens, each of dimension 768. A special [CLS] token is then prepended, giving a final shape of 197x768. At this point the patch embedding has turned a vision problem into a sequence modeling problem.
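A quick shape walk-through of the arithmetic above, using random data instead of the trained weights (illustrative only):

import numpy as np

img = np.random.randn(1, 3, 224, 224)                     # one RGB image
# cut into a 14x14 grid of 16x16 patches and flatten each patch: (1, 196, 768)
patches = img.reshape(1, 3, 14, 16, 14, 16)
patches = patches.transpose(0, 2, 4, 1, 3, 5).reshape(1, 196, 3 * 16 * 16)
W_proj = np.random.randn(768, 768)                         # linear projection, 768 -> 768
tokens = patches @ W_proj                                  # (1, 196, 768)
cls = np.zeros((1, 1, 768))                                # stand-in for the learnable [CLS] token
tokens = np.concatenate([cls, tokens], axis=1)
print(tokens.shape)                                        # (1, 197, 768)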

(2) Positional encoding (standard learnable 1D position embeddings): ViT also adds position embeddings. The position embedding can be viewed as a table with N rows, where N equals the input sequence length; each row is a vector whose dimension matches the token embedding dimension (768). Note that the position embedding is summed with the token embeddings, not concatenated, so after adding positions the shape is still 197x768.
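Because the position embedding is summed element-wise, the shape is unchanged; a minimal sketch with random tensors:

import numpy as np

tokens = np.random.randn(1, 197, 768)       # [CLS] + 196 patch tokens
pos_table = np.random.randn(1, 197, 768)    # one learnable row per position
tokens = tokens + pos_table                 # sum, not concat: still (1, 197, 768)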

(3) LN / multi-head attention / LN: after layer normalization the shape is still 197x768. For multi-head self-attention, the input is first projected to q, k, v. With a single head, q, k, v each have shape 197x768; with 12 heads (768 / 12 = 64), each head's q, k, v have shape 197x64, and there are 12 such groups. The 12 head outputs are concatenated back to 197x768, and another layer normalization keeps the shape at 197x768.
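A minimal sketch of the head split (768 / 12 heads = 64 dims per head), with random weights standing in for the trained projections:

import numpy as np

def attention(q, k, v):                       # single-head scaled dot-product attention
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v

x = np.random.randn(1, 197, 768)
W_q, W_k, W_v, W_o = (np.random.randn(768, 768) for _ in range(4))
q, k, v = x @ W_q, x @ W_k, x @ W_v           # each (1, 197, 768)
heads = [attention(q_, k_, v_) for q_, k_, v_ in
         zip(*(np.split(t, 12, axis=-1) for t in (q, k, v)))]   # 12 heads of (1, 197, 64)
out = np.concatenate(heads, axis=-1) @ W_o    # concatenated back to (1, 197, 768)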

(4) MLP: expand the dimension and project it back: 197x768 is expanded to 197x3072, then reduced back to 197x768.
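The same expansion written out in shapes, with the tanh-approximated GELU used by ViT (random weights, shapes only):

import numpy as np

x = np.random.randn(1, 197, 768)
W1, b1 = np.random.randn(768, 3072), np.zeros(3072)
W2, b2 = np.random.randn(3072, 768), np.zeros(768)
h = x @ W1 + b1                                                              # (1, 197, 3072)
h = 0.5 * h * (1 + np.tanh(np.sqrt(2 / np.pi) * (h + 0.044715 * h ** 3)))    # GELU
y = h @ W2 + b2                                                              # (1, 197, 768)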

After one block the shape is the same as the input, 197x768, so multiple blocks can be stacked. Finally, the output corresponding to the special [CLS] token, z_0, is taken as the encoder output and used as the final image representation (an alternative is to drop the [CLS] token and average over all token outputs), as in Eq. (4) of the paper; an MLP head is then attached for image classification.
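A sketch of the two readout options (random tensors, hypothetical 10-class head):

import numpy as np

tokens = np.random.randn(1, 197, 768)        # encoder output: [CLS] + 196 patch tokens
cls_repr = tokens[:, 0]                      # (1, 768)  [CLS] readout
mean_repr = tokens[:, 1:].mean(axis=1)       # (1, 768)  alternative: average the patch tokens
W_head, b_head = np.random.randn(10, 768), np.zeros(10)
logits = cls_repr @ W_head.T + b_head        # (1, 10)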

The NumPy implementation of ViT follows; you can read off the details of each component directly. It differs from BERT in a few ways: besides the embedding layer, the block structure is also different, mainly in that layer normalization is applied before the attention and feed-forward layers (pre-LN), whereas BERT applies it after them (post-LN).
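Before the full listing, a minimal sketch of the two block orderings; attn, ffn, ln1 and ln2 are placeholder callables:

def pre_ln_block(x, attn, ffn, ln1, ln2):     # ViT: normalize before each sublayer
    x = x + attn(ln1(x))
    x = x + ffn(ln2(x))
    return x

def post_ln_block(x, attn, ffn, ln1, ln2):    # BERT: normalize after each sublayer
    x = ln1(x + attn(x))
    x = ln2(x + ffn(x))
    return x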

import numpy as np
import os
from PIL import Image

# load the saved model parameters
model_data = np.load('vit_model_params.npz')
for i in model_data:
    # print(i)
    print(i, model_data[i].shape)

patch_embedding_weight = model_data["vit.embeddings.patch_embeddings.projection.weight"]
patch_embedding_bias = model_data["vit.embeddings.patch_embeddings.projection.bias"]
position_embeddings = model_data["vit.embeddings.position_embeddings"]
cls_token_embeddings = model_data["vit.embeddings.cls_token"]


def patch_embedding(images):
    # the convolution kernel size equals the patch size
    kernel_size = 16
    return conv2d(images, patch_embedding_weight, patch_embedding_bias, stride=kernel_size)


def position_embedding():
    return position_embeddings


def model_input(images):
    # (1, 768, 14, 14) -> (1, 196, 768)
    patch_embedded = np.transpose(patch_embedding(images).reshape([1, 768, -1]), (0, 2, 1))
    # prepend the [CLS] token -> (1, 197, 768)
    patch_embedded = np.concatenate([cls_token_embeddings, patch_embedded], axis=1)
    # position_ids = np.array(range(patch_embedded.shape[1]))  # position ids
    # position embedding table of shape (1, max_position, embedding_size)
    position_embedded = position_embedding()
    embedding_output = patch_embedded + position_embedded
    return embedding_output


def softmax(x, axis=None):
    # e_x = np.exp(x).astype(np.float32)
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    sum_ex = np.sum(e_x, axis=axis, keepdims=True).astype(np.float32)
    return e_x / sum_ex


def conv2d(images, weight, bias, stride=1, padding=0):
    # naive convolution; the broadcast below assumes batch size N == 1
    N, C, H, W = images.shape
    F, _, HH, WW = weight.shape
    # output spatial size
    H_out = (H - HH + 2 * padding) // stride + 1
    W_out = (W - WW + 2 * padding) // stride + 1
    # allocate the output
    out = np.zeros((N, F, H_out, W_out))
    # slide the window over every output position
    for i in range(H_out):
        for j in range(W_out):
            # current window
            window = images[:, :, i * stride:i * stride + HH, j * stride:j * stride + WW]
            # convolve
            out[:, :, i, j] = np.sum(window * weight, axis=(1, 2, 3)) + bias
    # print("conv output shape:", out.shape)
    return out


def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, np.full_like(scores, -np.inf))
    attention_weights = softmax(scores, axis=-1)
    # print(attention_weights)
    # print(np.sum(attention_weights, axis=-1))
    output = np.matmul(attention_weights, V)
    return output, attention_weights


def multihead_attention(input, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O):
    q = np.matmul(input, W_Q.T) + B_Q
    k = np.matmul(input, W_K.T) + B_K
    v = np.matmul(input, W_V.T) + B_V
    # split into num_heads heads
    q = np.split(q, num_heads, axis=-1)
    k = np.split(k, num_heads, axis=-1)
    v = np.split(v, num_heads, axis=-1)
    outputs = []
    for q_, k_, v_ in zip(q, k, v):
        output, attention_weights = scaled_dot_product_attention(q_, k_, v_)
        outputs.append(output)
    # concatenate the heads and apply the output projection
    outputs = np.concatenate(outputs, axis=-1)
    outputs = np.matmul(outputs, W_O.T) + B_O
    return outputs


def layer_normalization(x, weight, bias, eps=1e-12):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    std = np.sqrt(variance + eps)
    normalized_x = (x - mean) / std
    output = weight * normalized_x + bias
    return output


def feed_forward_layer(inputs, weight, bias, activation='relu'):
    linear_output = np.matmul(inputs, weight) + bias
    if activation == 'relu':
        activated_output = np.maximum(0, linear_output)  # ReLU
    elif activation == 'gelu':
        # GELU (tanh approximation)
        activated_output = 0.5 * linear_output * (1 + np.tanh(np.sqrt(2 / np.pi) * (linear_output + 0.044715 * np.power(linear_output, 3))))
    elif activation == "tanh":
        activated_output = np.tanh(linear_output)
    else:
        activated_output = linear_output  # no activation
    return activated_output


def residual_connection(inputs, residual):
    # residual connection
    residual_output = inputs + residual
    return residual_output


def vit(input, num_heads=12):
    for i in range(12):
        # load the weights of encoder layer i
        W_Q = model_data['vit.encoder.layer.{}.attention.attention.query.weight'.format(i)]
        B_Q = model_data['vit.encoder.layer.{}.attention.attention.query.bias'.format(i)]
        W_K = model_data['vit.encoder.layer.{}.attention.attention.key.weight'.format(i)]
        B_K = model_data['vit.encoder.layer.{}.attention.attention.key.bias'.format(i)]
        W_V = model_data['vit.encoder.layer.{}.attention.attention.value.weight'.format(i)]
        B_V = model_data['vit.encoder.layer.{}.attention.attention.value.bias'.format(i)]
        W_O = model_data['vit.encoder.layer.{}.attention.output.dense.weight'.format(i)]
        B_O = model_data['vit.encoder.layer.{}.attention.output.dense.bias'.format(i)]
        intermediate_weight = model_data['vit.encoder.layer.{}.intermediate.dense.weight'.format(i)]
        intermediate_bias = model_data['vit.encoder.layer.{}.intermediate.dense.bias'.format(i)]
        dense_weight = model_data['vit.encoder.layer.{}.output.dense.weight'.format(i)]
        dense_bias = model_data['vit.encoder.layer.{}.output.dense.bias'.format(i)]
        LayerNorm_before_weight = model_data['vit.encoder.layer.{}.layernorm_before.weight'.format(i)]
        LayerNorm_before_bias = model_data['vit.encoder.layer.{}.layernorm_before.bias'.format(i)]
        LayerNorm_after_weight = model_data['vit.encoder.layer.{}.layernorm_after.weight'.format(i)]
        LayerNorm_after_bias = model_data['vit.encoder.layer.{}.layernorm_after.bias'.format(i)]

        # pre-LN -> multi-head self-attention -> residual
        output = layer_normalization(input, LayerNorm_before_weight, LayerNorm_before_bias)
        output = multihead_attention(output, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O)
        output1 = residual_connection(input, output)
        # output1 matches the HF model output at this point
        # pre-LN -> MLP (768 -> 3072 -> 768) -> residual
        output = layer_normalization(output1, LayerNorm_after_weight, LayerNorm_after_bias)  # matches
        output = feed_forward_layer(output, intermediate_weight.T, intermediate_bias, activation='gelu')
        output = feed_forward_layer(output, dense_weight.T, dense_bias, activation='')
        output2 = residual_connection(output1, output)
        input = output2

    # final layer norm on the [CLS] token, then the classification head
    final_layernorm_weight = model_data['vit.layernorm.weight']
    final_layernorm_bias = model_data['vit.layernorm.bias']
    output = layer_normalization(output2[:, 0], final_layernorm_weight, final_layernorm_bias)  # matches
    classifier_weight = model_data['classifier.weight']
    classifier_bias = model_data['classifier.bias']
    output = feed_forward_layer(output, classifier_weight.T, classifier_bias, activation="")  # matches
    output = np.argmax(output, axis=-1)
    return output


folder_path = './cifar10'  # change this to the folder containing your images


def infer_images_in_folder(folder_path):
    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)
        if os.path.isfile(file_path) and file_name.endswith(('.jpg', '.jpeg', '.png')):
            image = Image.open(file_path)
            image = image.resize((224, 224))
            label = file_name.split(".")[0].split("_")[1]
            image = np.array(image) / 255.0
            image = np.transpose(image, (2, 0, 1))
            image = np.expand_dims(image, axis=0)
            print("file_path:", file_path, "img size:", image.shape, "label:", label)
            input = model_input(image)
            predicted_class = vit(input)
            print('Predicted class:', predicted_class)


if __name__ == "__main__":
    infer_images_in_folder(folder_path)

Results:

file_path: ./cifar10/8619_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/6042_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/6801_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7946_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/6925_2.jpg img size: (1, 3, 224, 224) label: 2
Predicted class: [2]
file_path: ./cifar10/6007_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7903_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/7064_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/2713_8.jpg img size: (1, 3, 224, 224) label: 8
Predicted class: [8]
file_path: ./cifar10/8575_9.jpg img size: (1, 3, 224, 224) label: 9
Predicted class: [9]
file_path: ./cifar10/1985_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/5312_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/593_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/8093_7.jpg img size: (1, 3, 224, 224) label: 7
Predicted class: [7]
file_path: ./cifar10/6862_5.jpg img size: (1, 3, 224, 224) label: 5

Model parameters:

vit.embeddings.cls_token (1, 1, 768)
vit.embeddings.position_embeddings (1, 197, 768)
vit.embeddings.patch_embeddings.projection.weight (768, 3, 16, 16)
vit.embeddings.patch_embeddings.projection.bias (768,)
vit.encoder.layer.0.attention.attention.query.weight (768, 768)
vit.encoder.layer.0.attention.attention.query.bias (768,)
vit.encoder.layer.0.attention.attention.key.weight (768, 768)
vit.encoder.layer.0.attention.attention.key.bias (768,)
vit.encoder.layer.0.attention.attention.value.weight (768, 768)
vit.encoder.layer.0.attention.attention.value.bias (768,)
vit.encoder.layer.0.attention.output.dense.weight (768, 768)
vit.encoder.layer.0.attention.output.dense.bias (768,)
vit.encoder.layer.0.intermediate.dense.weight (3072, 768)
vit.encoder.layer.0.intermediate.dense.bias (3072,)
vit.encoder.layer.0.output.dense.weight (768, 3072)
vit.encoder.layer.0.output.dense.bias (768,)
vit.encoder.layer.0.layernorm_before.weight (768,)
vit.encoder.layer.0.layernorm_before.bias (768,)
vit.encoder.layer.0.layernorm_after.weight (768,)
vit.encoder.layer.0.layernorm_after.bias (768,)
vit.encoder.layer.1.attention.attention.query.weight (768, 768)
vit.encoder.layer.1.attention.attention.query.bias (768,)
vit.encoder.layer.1.attention.attention.key.weight (768, 768)
vit.encoder.layer.1.attention.attention.key.bias (768,)
vit.encoder.layer.1.attention.attention.value.weight (768, 768)
vit.encoder.layer.1.attention.attention.value.bias (768,)
vit.encoder.layer.1.attention.output.dense.weight (768, 768)
vit.encoder.layer.1.attention.output.dense.bias (768,)
vit.encoder.layer.1.intermediate.dense.weight (3072, 768)
vit.encoder.layer.1.intermediate.dense.bias (3072,)
vit.encoder.layer.1.output.dense.weight (768, 3072)
vit.encoder.layer.1.output.dense.bias (768,)
vit.encoder.layer.1.layernorm_before.weight (768,)
vit.encoder.layer.1.layernorm_before.bias (768,)
vit.encoder.layer.1.layernorm_after.weight (768,)
vit.encoder.layer.1.layernorm_after.bias (768,)
vit.encoder.layer.2.attention.attention.query.weight (768, 768)
vit.encoder.layer.2.attention.attention.query.bias (768,)
vit.encoder.layer.2.attention.attention.key.weight (768, 768)
vit.encoder.layer.2.attention.attention.key.bias (768,)
vit.encoder.layer.2.attention.attention.value.weight (768, 768)
vit.encoder.layer.2.attention.attention.value.bias (768,)
vit.encoder.layer.2.attention.output.dense.weight (768, 768)
vit.encoder.layer.2.attention.output.dense.bias (768,)
vit.encoder.layer.2.intermediate.dense.weight (3072, 768)
vit.encoder.layer.2.intermediate.dense.bias (3072,)
vit.encoder.layer.2.output.dense.weight (768, 3072)
vit.encoder.layer.2.output.dense.bias (768,)
vit.encoder.layer.2.layernorm_before.weight (768,)
vit.encoder.layer.2.layernorm_before.bias (768,)
vit.encoder.layer.2.layernorm_after.weight (768,)
vit.encoder.layer.2.layernorm_after.bias (768,)
vit.encoder.layer.3.attention.attention.query.weight (768, 768)
vit.encoder.layer.3.attention.attention.query.bias (768,)
vit.encoder.layer.3.attention.attention.key.weight (768, 768)
vit.encoder.layer.3.attention.attention.key.bias (768,)
vit.encoder.layer.3.attention.attention.value.weight (768, 768)
vit.encoder.layer.3.attention.attention.value.bias (768,)
vit.encoder.layer.3.attention.output.dense.weight (768, 768)
vit.encoder.layer.3.attention.output.dense.bias (768,)
vit.encoder.layer.3.intermediate.dense.weight (3072, 768)
vit.encoder.layer.3.intermediate.dense.bias (3072,)
vit.encoder.layer.3.output.dense.weight (768, 3072)
vit.encoder.layer.3.output.dense.bias (768,)
vit.encoder.layer.3.layernorm_before.weight (768,)
vit.encoder.layer.3.layernorm_before.bias (768,)
vit.encoder.layer.3.layernorm_after.weight (768,)
vit.encoder.layer.3.layernorm_after.bias (768,)
vit.encoder.layer.4.attention.attention.query.weight (768, 768)
vit.encoder.layer.4.attention.attention.query.bias (768,)
vit.encoder.layer.4.attention.attention.key.weight (768, 768)
vit.encoder.layer.4.attention.attention.key.bias (768,)
vit.encoder.layer.4.attention.attention.value.weight (768, 768)
vit.encoder.layer.4.attention.attention.value.bias (768,)
vit.encoder.layer.4.attention.output.dense.weight (768, 768)
vit.encoder.layer.4.attention.output.dense.bias (768,)
vit.encoder.layer.4.intermediate.dense.weight (3072, 768)
vit.encoder.layer.4.intermediate.dense.bias (3072,)
vit.encoder.layer.4.output.dense.weight (768, 3072)
vit.encoder.layer.4.output.dense.bias (768,)
vit.encoder.layer.4.layernorm_before.weight (768,)
vit.encoder.layer.4.layernorm_before.bias (768,)
vit.encoder.layer.4.layernorm_after.weight (768,)
vit.encoder.layer.4.layernorm_after.bias (768,)
vit.encoder.layer.5.attention.attention.query.weight (768, 768)
vit.encoder.layer.5.attention.attention.query.bias (768,)
vit.encoder.layer.5.attention.attention.key.weight (768, 768)
vit.encoder.layer.5.attention.attention.key.bias (768,)
vit.encoder.layer.5.attention.attention.value.weight (768, 768)
vit.encoder.layer.5.attention.attention.value.bias (768,)
vit.encoder.layer.5.attention.output.dense.weight (768, 768)
vit.encoder.layer.5.attention.output.dense.bias (768,)
vit.encoder.layer.5.intermediate.dense.weight (3072, 768)
vit.encoder.layer.5.intermediate.dense.bias (3072,)
vit.encoder.layer.5.output.dense.weight (768, 3072)
vit.encoder.layer.5.output.dense.bias (768,)
vit.encoder.layer.5.layernorm_before.weight (768,)
vit.encoder.layer.5.layernorm_before.bias (768,)
vit.encoder.layer.5.layernorm_after.weight (768,)
vit.encoder.layer.5.layernorm_after.bias (768,)
vit.encoder.layer.6.attention.attention.query.weight (768, 768)
vit.encoder.layer.6.attention.attention.query.bias (768,)
vit.encoder.layer.6.attention.attention.key.weight (768, 768)
vit.encoder.layer.6.attention.attention.key.bias (768,)
vit.encoder.layer.6.attention.attention.value.weight (768, 768)
vit.encoder.layer.6.attention.attention.value.bias (768,)
vit.encoder.layer.6.attention.output.dense.weight (768, 768)
vit.encoder.layer.6.attention.output.dense.bias (768,)
vit.encoder.layer.6.intermediate.dense.weight (3072, 768)
vit.encoder.layer.6.intermediate.dense.bias (3072,)
vit.encoder.layer.6.output.dense.weight (768, 3072)
vit.encoder.layer.6.output.dense.bias (768,)
vit.encoder.layer.6.layernorm_before.weight (768,)
vit.encoder.layer.6.layernorm_before.bias (768,)
vit.encoder.layer.6.layernorm_after.weight (768,)
vit.encoder.layer.6.layernorm_after.bias (768,)
vit.encoder.layer.7.attention.attention.query.weight (768, 768)
vit.encoder.layer.7.attention.attention.query.bias (768,)
vit.encoder.layer.7.attention.attention.key.weight (768, 768)
vit.encoder.layer.7.attention.attention.key.bias (768,)
vit.encoder.layer.7.attention.attention.value.weight (768, 768)
vit.encoder.layer.7.attention.attention.value.bias (768,)
vit.encoder.layer.7.attention.output.dense.weight (768, 768)
vit.encoder.layer.7.attention.output.dense.bias (768,)
vit.encoder.layer.7.intermediate.dense.weight (3072, 768)
vit.encoder.layer.7.intermediate.dense.bias (3072,)
vit.encoder.layer.7.output.dense.weight (768, 3072)
vit.encoder.layer.7.output.dense.bias (768,)
vit.encoder.layer.7.layernorm_before.weight (768,)
vit.encoder.layer.7.layernorm_before.bias (768,)
vit.encoder.layer.7.layernorm_after.weight (768,)
vit.encoder.layer.7.layernorm_after.bias (768,)
vit.encoder.layer.8.attention.attention.query.weight (768, 768)
vit.encoder.layer.8.attention.attention.query.bias (768,)
vit.encoder.layer.8.attention.attention.key.weight (768, 768)
vit.encoder.layer.8.attention.attention.key.bias (768,)
vit.encoder.layer.8.attention.attention.value.weight (768, 768)
vit.encoder.layer.8.attention.attention.value.bias (768,)
vit.encoder.layer.8.attention.output.dense.weight (768, 768)
vit.encoder.layer.8.attention.output.dense.bias (768,)
vit.encoder.layer.8.intermediate.dense.weight (3072, 768)
vit.encoder.layer.8.intermediate.dense.bias (3072,)
vit.encoder.layer.8.output.dense.weight (768, 3072)
vit.encoder.layer.8.output.dense.bias (768,)
vit.encoder.layer.8.layernorm_before.weight (768,)
vit.encoder.layer.8.layernorm_before.bias (768,)
vit.encoder.layer.8.layernorm_after.weight (768,)
vit.encoder.layer.8.layernorm_after.bias (768,)
vit.encoder.layer.9.attention.attention.query.weight (768, 768)
vit.encoder.layer.9.attention.attention.query.bias (768,)
vit.encoder.layer.9.attention.attention.key.weight (768, 768)
vit.encoder.layer.9.attention.attention.key.bias (768,)
vit.encoder.layer.9.attention.attention.value.weight (768, 768)
vit.encoder.layer.9.attention.attention.value.bias (768,)
vit.encoder.layer.9.attention.output.dense.weight (768, 768)
vit.encoder.layer.9.attention.output.dense.bias (768,)
vit.encoder.layer.9.intermediate.dense.weight (3072, 768)
vit.encoder.layer.9.intermediate.dense.bias (3072,)
vit.encoder.layer.9.output.dense.weight (768, 3072)
vit.encoder.layer.9.output.dense.bias (768,)
vit.encoder.layer.9.layernorm_before.weight (768,)
vit.encoder.layer.9.layernorm_before.bias (768,)
vit.encoder.layer.9.layernorm_after.weight (768,)
vit.encoder.layer.9.layernorm_after.bias (768,)
vit.encoder.layer.10.attention.attention.query.weight (768, 768)
vit.encoder.layer.10.attention.attention.query.bias (768,)
vit.encoder.layer.10.attention.attention.key.weight (768, 768)
vit.encoder.layer.10.attention.attention.key.bias (768,)
vit.encoder.layer.10.attention.attention.value.weight (768, 768)
vit.encoder.layer.10.attention.attention.value.bias (768,)
vit.encoder.layer.10.attention.output.dense.weight (768, 768)
vit.encoder.layer.10.attention.output.dense.bias (768,)
vit.encoder.layer.10.intermediate.dense.weight (3072, 768)
vit.encoder.layer.10.intermediate.dense.bias (3072,)
vit.encoder.layer.10.output.dense.weight (768, 3072)
vit.encoder.layer.10.output.dense.bias (768,)
vit.encoder.layer.10.layernorm_before.weight (768,)
vit.encoder.layer.10.layernorm_before.bias (768,)
vit.encoder.layer.10.layernorm_after.weight (768,)
vit.encoder.layer.10.layernorm_after.bias (768,)
vit.encoder.layer.11.attention.attention.query.weight (768, 768)
vit.encoder.layer.11.attention.attention.query.bias (768,)
vit.encoder.layer.11.attention.attention.key.weight (768, 768)
vit.encoder.layer.11.attention.attention.key.bias (768,)
vit.encoder.layer.11.attention.attention.value.weight (768, 768)
vit.encoder.layer.11.attention.attention.value.bias (768,)
vit.encoder.layer.11.attention.output.dense.weight (768, 768)
vit.encoder.layer.11.attention.output.dense.bias (768,)
vit.encoder.layer.11.intermediate.dense.weight (3072, 768)
vit.encoder.layer.11.intermediate.dense.bias (3072,)
vit.encoder.layer.11.output.dense.weight (768, 3072)
vit.encoder.layer.11.output.dense.bias (768,)
vit.encoder.layer.11.layernorm_before.weight (768,)
vit.encoder.layer.11.layernorm_before.bias (768,)
vit.encoder.layer.11.layernorm_after.weight (768,)
vit.encoder.layer.11.layernorm_after.bias (768,)
vit.layernorm.weight (768,)
vit.layernorm.bias (768,)
classifier.weight (10, 768)
classifier.bias (10,)

Hugging Face training code: fine-tune on CIFAR-10 and save the model parameters in NumPy format so that the NumPy implementation can load them:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm
import numpy as np

# set the random seed
torch.manual_seed(42)

# hyperparameters
batch_size = 64
num_epochs = 1
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# load the CIFAR-10 dataset
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform)

# create the data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# load the pretrained ViT model
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

# replace the classification head with a 10-class linear layer
num_classes = 10
# vit_model.config.classifier = 'mlp'
# vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device)

# optionally freeze everything except the classifier:
# parameters = list(vit_model.parameters())
# for x in parameters[:-1]:
#     x.requires_grad = False

# loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# fine-tune the ViT model
for epoch in range(num_epochs):
    print("epoch:", epoch)
    vit_model.train()
    train_loss = 0.0
    train_correct = 0

    bar = tqdm(train_loader, total=len(train_loader))
    for images, labels in bar:
        images = images.to(device)
        labels = labels.to(device)

        # forward pass
        outputs = vit_model(images)
        loss = criterion(outputs.logits, labels)

        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.logits, 1)
        train_correct += (predicted == labels).sum().item()

    # accuracy on the training set
    train_accuracy = 100.0 * train_correct / len(train_dataset)

    # evaluate on the test set
    vit_model.eval()
    test_loss = 0.0
    test_correct = 0

    with torch.no_grad():
        bar = tqdm(test_loader, total=len(test_loader))
        for images, labels in bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = vit_model(images)
            loss = criterion(outputs.logits, labels)

            test_loss += loss.item()
            _, predicted = torch.max(outputs.logits, 1)
            test_correct += (predicted == labels).sum().item()

    # accuracy on the test set
    test_accuracy = 100.0 * test_correct / len(test_dataset)

    # print the training loss, training accuracy and test accuracy for each epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')

torch.save(vit_model.state_dict(), 'vit_model_parameters.pth')

# print the weight shapes of the ViT model
for name, param in vit_model.named_parameters():
    print(name, param.data.shape)

# save the model parameters in NumPy format
model_params = {name: param.data.cpu().numpy() for name, param in vit_model.named_parameters()}
np.savez('vit_model_params.npz', **model_params)
# model_params
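As an optional sanity check (a sketch, not part of the original script), the exported .npz can be compared tensor by tensor against the saved PyTorch state dict:

import numpy as np
import torch

params = np.load('vit_model_params.npz')
state = torch.load('vit_model_parameters.pth', map_location='cpu')
for name in params.files:
    # named_parameters keys also appear in the state dict under the same names
    assert np.allclose(params[name], state[name].numpy()), name
print('all', len(params.files), 'tensors match')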

Epoch [1/1], Train Loss: 97.7498, Train Accuracy: 96.21%, Test Accuracy: 96.86%
