我用numpy实现了VIT,手写vision transformer, 可在树莓派上运行,在hugging face上训练模型保存参数成numpy格式,纯numpy实现
先复制一点知乎上的内容

按照上面的流程图,一个ViT block可以分为以下几个步骤
(1) patch embedding:例如输入图片大小为224x224,将图片分为固定大小的patch,patch大小为16x16,则每张图像会生成224x224/16x16=196个patch,即输入序列长度为196,每个patch维度16x16x3=768,线性投射层的维度为768xN (N=768),因此输入通过线性投射层之后的维度依然为196x768,即一共有196个token,每个token的维度是768。这里还需要加上一个特殊字符cls,因此最终的维度是197x768。到目前为止,已经通过patch embedding将一个视觉问题转化为了一个seq2seq问题
(2) positional encoding(standard learnable 1D position embeddings):ViT同样需要加入位置编码,位置编码可以理解为一张表,表一共有N行,N的大小和输入序列长度相同,每一行代表一个向量,向量的维度和输入序列embedding的维度相同(768)。注意位置编码的操作是sum,而不是concat。加入位置编码信息之后,维度依然是197x768
(3) LN/multi-head attention/LN:LN输出维度依然是197x768。多头自注意力时,先将输入映射到q,k,v,如果只有一个头,qkv的维度都是197x768,如果有12个头(768/12=64),则qkv的维度是197x64,一共有12组qkv,最后再将12组qkv的输出拼接起来,输出维度是197x768,然后在过一层LN,维度依然是197x768
(4) MLP:将维度放大再缩小回去,197x768放大为197x3072,再缩小变为197x768
一个block之后维度依然和输入相同,都是197x768,因此可以堆叠多个block。最后会将特殊字符cls对应的输出 Z0 作为encoder的最终输出 ,代表最终的image presentation(另一种做法是不加cls字符,对所有的tokens的输出做一个平均),如下图公式(4),后面接一个MLP进行图片分类

vit 的 numpy 实现代码,可以直接看懂各个部分的细节实现 ,和bert有一些不一样,除了embedding层不一样之外,还有模型结构有有些不同,主要是layer_normalization放在了attention层和feed_forword层之前,bert都是放在之后
import numpy as np
import os
from PIL import Image # 加载保存的模型数据
model_data = np.load('vit_model_params.npz')
for i in model_data:
# print(i)
print(i,model_data[i].shape) patch_embedding_weight = model_data["vit.embeddings.patch_embeddings.projection.weight"]
patch_embedding_bias = model_data["vit.embeddings.patch_embeddings.projection.bias"]
position_embeddings = model_data["vit.embeddings.position_embeddings"]
cls_token_embeddings = model_data["vit.embeddings.cls_token"] def patch_embedding(images):
# 卷积核大小
kernel_size = 16
return conv2d(images, patch_embedding_weight,patch_embedding_bias,stride=kernel_size) def position_embedding():
return position_embeddings def model_input(images): patch_embedded = np.transpose(patch_embedding(images).reshape([1,768,-1]), (0, 2, 1)) patch_embedded = np.concatenate([cls_token_embeddings,patch_embedded],axis=1) # position_ids = np.array(range(patch_embedded.shape[1])) # 位置id
# 位置嵌入矩阵,形状为 (max_position, embedding_size)
position_embedded = position_embedding() embedding_output = patch_embedded + position_embedded return embedding_output def softmax(x, axis=None):
# e_x = np.exp(x).astype(np.float32) #
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
sum_ex = np.sum(e_x, axis=axis,keepdims=True).astype(np.float32)
return e_x / sum_ex def conv2d(images,weight,bias,stride=1,padding=0):
# 卷积操作
N, C, H, W = images.shape
F, _, HH, WW = weight.shape
# 计算卷积后的输出尺寸
H_out = (H - HH + 2 * padding) // stride + 1
W_out = (W - WW + 2 * padding) // stride + 1
# 初始化卷积层输出
out = np.zeros((N, F, H_out, W_out))
# 执行卷积运算
for i in range(H_out):
for j in range(W_out):
# 提取当前卷积窗口
window = images[:, :, i * stride:i * stride + HH, j * stride:j * stride + WW]
# 执行卷积运算
out[:, :, i, j] = np.sum(window * weight, axis=(1, 2, 3)) + bias
# 输出结果
# print("卷积层输出尺寸:", out.shape)
return out def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.shape[-1]
scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
if mask is not None:
scores = np.where(mask, scores, np.full_like(scores, -np.inf))
attention_weights = softmax(scores, axis=-1)
# print(attention_weights)
# print(np.sum(attention_weights,axis=-1))
output = np.matmul(attention_weights, V)
return output, attention_weights def multihead_attention(input, num_heads,W_Q,B_Q,W_K,B_K,W_V,B_V,W_O,B_O): q = np.matmul(input, W_Q.T)+B_Q
k = np.matmul(input, W_K.T)+B_K
v = np.matmul(input, W_V.T)+B_V # 分割输入为多个头
q = np.split(q, num_heads, axis=-1)
k = np.split(k, num_heads, axis=-1)
v = np.split(v, num_heads, axis=-1) outputs = []
for q_,k_,v_ in zip(q,k,v):
output, attention_weights = scaled_dot_product_attention(q_, k_, v_)
outputs.append(output)
outputs = np.concatenate(outputs, axis=-1)
outputs = np.matmul(outputs, W_O.T)+B_O
return outputs def layer_normalization(x, weight, bias, eps=1e-12):
mean = np.mean(x, axis=-1, keepdims=True)
variance = np.var(x, axis=-1, keepdims=True)
std = np.sqrt(variance + eps)
normalized_x = (x - mean) / std
output = weight * normalized_x + bias
return output def feed_forward_layer(inputs, weight, bias, activation='relu'):
linear_output = np.matmul(inputs,weight) + bias if activation == 'relu':
activated_output = np.maximum(0, linear_output) # ReLU激活函数
elif activation == 'gelu':
activated_output = 0.5 * linear_output * (1 + np.tanh(np.sqrt(2 / np.pi) * (linear_output + 0.044715 * np.power(linear_output, 3)))) # GELU激活函数 elif activation == "tanh" :
activated_output = np.tanh(linear_output)
else:
activated_output = linear_output # 无激活函数 return activated_output def residual_connection(inputs, residual):
# 残差连接
residual_output = inputs + residual
return residual_output def vit(input,num_heads=12): for i in range(12):
# 调用多头自注意力函数
W_Q = model_data['vit.encoder.layer.{}.attention.attention.query.weight'.format(i)]
B_Q = model_data['vit.encoder.layer.{}.attention.attention.query.bias'.format(i)]
W_K = model_data['vit.encoder.layer.{}.attention.attention.key.weight'.format(i)]
B_K = model_data['vit.encoder.layer.{}.attention.attention.key.bias'.format(i)]
W_V = model_data['vit.encoder.layer.{}.attention.attention.value.weight'.format(i)]
B_V = model_data['vit.encoder.layer.{}.attention.attention.value.bias'.format(i)]
W_O = model_data['vit.encoder.layer.{}.attention.output.dense.weight'.format(i)]
B_O = model_data['vit.encoder.layer.{}.attention.output.dense.bias'.format(i)]
intermediate_weight = model_data['vit.encoder.layer.{}.intermediate.dense.weight'.format(i)]
intermediate_bias = model_data['vit.encoder.layer.{}.intermediate.dense.bias'.format(i)]
dense_weight = model_data['vit.encoder.layer.{}.output.dense.weight'.format(i)]
dense_bias = model_data['vit.encoder.layer.{}.output.dense.bias'.format(i)]
LayerNorm_before_weight = model_data['vit.encoder.layer.{}.layernorm_before.weight'.format(i)]
LayerNorm_before_bias = model_data['vit.encoder.layer.{}.layernorm_before.bias'.format(i)]
LayerNorm_after_weight = model_data['vit.encoder.layer.{}.layernorm_after.weight'.format(i)]
LayerNorm_after_bias = model_data['vit.encoder.layer.{}.layernorm_after.bias'.format(i)] output = layer_normalization(input,LayerNorm_before_weight,LayerNorm_before_bias)
output = multihead_attention(output, num_heads,W_Q,B_Q,W_K,B_K,W_V,B_V,W_O,B_O)
output1 = residual_connection(input,output)
#这里和模型输出一致 output = layer_normalization(output1,LayerNorm_after_weight,LayerNorm_after_bias) #一致
output = feed_forward_layer(output, intermediate_weight.T, intermediate_bias, activation='gelu')
output = feed_forward_layer(output, dense_weight.T, dense_bias, activation='')
output2 = residual_connection(output1,output) input = output2 bert_pooler_dense_weight = model_data['vit.layernorm.weight']
bert_pooler_dense_bias = model_data['vit.layernorm.bias']
output = layer_normalization(output2[:,0],bert_pooler_dense_weight,bert_pooler_dense_bias ) #一致
classifier_weight = model_data['classifier.weight']
classifier_bias = model_data['classifier.bias']
output = feed_forward_layer(output,classifier_weight.T,classifier_bias,activation="" ) #一致
output = np.argmax(output,axis=-1)
return output folder_path = './cifar10' # 替换为图片所在的文件夹路径
def infer_images_in_folder(folder_path):
for file_name in os.listdir(folder_path):
file_path = os.path.join(folder_path, file_name)
if os.path.isfile(file_path) and file_name.endswith(('.jpg', '.jpeg', '.png')):
image = Image.open(file_path)
image = image.resize((224, 224))
label = file_name.split(".")[0].split("_")[1]
image = np.array(image)/255.0
image = np.transpose(image, (2, 0, 1))
image = np.expand_dims(image,axis=0)
print("file_path:",file_path,"img size:",image.shape,"label:",label)
input = model_input(image)
predicted_class = vit(input)
print('Predicted class:', predicted_class) if __name__ == "__main__": infer_images_in_folder(folder_path)
结果:
file_path: ./cifar10/8619_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/6042_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/6801_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7946_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/6925_2.jpg img size: (1, 3, 224, 224) label: 2
Predicted class: [2]
file_path: ./cifar10/6007_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7903_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/7064_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/2713_8.jpg img size: (1, 3, 224, 224) label: 8
Predicted class: [8]
file_path: ./cifar10/8575_9.jpg img size: (1, 3, 224, 224) label: 9
Predicted class: [9]
file_path: ./cifar10/1985_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/5312_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/593_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/8093_7.jpg img size: (1, 3, 224, 224) label: 7
Predicted class: [7]
file_path: ./cifar10/6862_5.jpg img size: (1, 3, 224, 224) label: 5
模型参数:
vit.embeddings.cls_token (1, 1, 768)
vit.embeddings.position_embeddings (1, 197, 768)
vit.embeddings.patch_embeddings.projection.weight (768, 3, 16, 16)
vit.embeddings.patch_embeddings.projection.bias (768,)
vit.encoder.layer.0.attention.attention.query.weight (768, 768)
vit.encoder.layer.0.attention.attention.query.bias (768,)
vit.encoder.layer.0.attention.attention.key.weight (768, 768)
vit.encoder.layer.0.attention.attention.key.bias (768,)
vit.encoder.layer.0.attention.attention.value.weight (768, 768)
vit.encoder.layer.0.attention.attention.value.bias (768,)
vit.encoder.layer.0.attention.output.dense.weight (768, 768)
vit.encoder.layer.0.attention.output.dense.bias (768,)
vit.encoder.layer.0.intermediate.dense.weight (3072, 768)
vit.encoder.layer.0.intermediate.dense.bias (3072,)
vit.encoder.layer.0.output.dense.weight (768, 3072)
vit.encoder.layer.0.output.dense.bias (768,)
vit.encoder.layer.0.layernorm_before.weight (768,)
vit.encoder.layer.0.layernorm_before.bias (768,)
vit.encoder.layer.0.layernorm_after.weight (768,)
vit.encoder.layer.0.layernorm_after.bias (768,)
vit.encoder.layer.1.attention.attention.query.weight (768, 768)
vit.encoder.layer.1.attention.attention.query.bias (768,)
vit.encoder.layer.1.attention.attention.key.weight (768, 768)
vit.encoder.layer.1.attention.attention.key.bias (768,)
vit.encoder.layer.1.attention.attention.value.weight (768, 768)
vit.encoder.layer.1.attention.attention.value.bias (768,)
vit.encoder.layer.1.attention.output.dense.weight (768, 768)
vit.encoder.layer.1.attention.output.dense.bias (768,)
vit.encoder.layer.1.intermediate.dense.weight (3072, 768)
vit.encoder.layer.1.intermediate.dense.bias (3072,)
vit.encoder.layer.1.output.dense.weight (768, 3072)
vit.encoder.layer.1.output.dense.bias (768,)
vit.encoder.layer.1.layernorm_before.weight (768,)
vit.encoder.layer.1.layernorm_before.bias (768,)
vit.encoder.layer.1.layernorm_after.weight (768,)
vit.encoder.layer.1.layernorm_after.bias (768,)
vit.encoder.layer.2.attention.attention.query.weight (768, 768)
vit.encoder.layer.2.attention.attention.query.bias (768,)
vit.encoder.layer.2.attention.attention.key.weight (768, 768)
vit.encoder.layer.2.attention.attention.key.bias (768,)
vit.encoder.layer.2.attention.attention.value.weight (768, 768)
vit.encoder.layer.2.attention.attention.value.bias (768,)
vit.encoder.layer.2.attention.output.dense.weight (768, 768)
vit.encoder.layer.2.attention.output.dense.bias (768,)
vit.encoder.layer.2.intermediate.dense.weight (3072, 768)
vit.encoder.layer.2.intermediate.dense.bias (3072,)
vit.encoder.layer.2.output.dense.weight (768, 3072)
vit.encoder.layer.2.output.dense.bias (768,)
vit.encoder.layer.2.layernorm_before.weight (768,)
vit.encoder.layer.2.layernorm_before.bias (768,)
vit.encoder.layer.2.layernorm_after.weight (768,)
vit.encoder.layer.2.layernorm_after.bias (768,)
vit.encoder.layer.3.attention.attention.query.weight (768, 768)
vit.encoder.layer.3.attention.attention.query.bias (768,)
vit.encoder.layer.3.attention.attention.key.weight (768, 768)
vit.encoder.layer.3.attention.attention.key.bias (768,)
vit.encoder.layer.3.attention.attention.value.weight (768, 768)
vit.encoder.layer.3.attention.attention.value.bias (768,)
vit.encoder.layer.3.attention.output.dense.weight (768, 768)
vit.encoder.layer.3.attention.output.dense.bias (768,)
vit.encoder.layer.3.intermediate.dense.weight (3072, 768)
vit.encoder.layer.3.intermediate.dense.bias (3072,)
vit.encoder.layer.3.output.dense.weight (768, 3072)
vit.encoder.layer.3.output.dense.bias (768,)
vit.encoder.layer.3.layernorm_before.weight (768,)
vit.encoder.layer.3.layernorm_before.bias (768,)
vit.encoder.layer.3.layernorm_after.weight (768,)
vit.encoder.layer.3.layernorm_after.bias (768,)
vit.encoder.layer.4.attention.attention.query.weight (768, 768)
vit.encoder.layer.4.attention.attention.query.bias (768,)
vit.encoder.layer.4.attention.attention.key.weight (768, 768)
vit.encoder.layer.4.attention.attention.key.bias (768,)
vit.encoder.layer.4.attention.attention.value.weight (768, 768)
vit.encoder.layer.4.attention.attention.value.bias (768,)
vit.encoder.layer.4.attention.output.dense.weight (768, 768)
vit.encoder.layer.4.attention.output.dense.bias (768,)
vit.encoder.layer.4.intermediate.dense.weight (3072, 768)
vit.encoder.layer.4.intermediate.dense.bias (3072,)
vit.encoder.layer.4.output.dense.weight (768, 3072)
vit.encoder.layer.4.output.dense.bias (768,)
vit.encoder.layer.4.layernorm_before.weight (768,)
vit.encoder.layer.4.layernorm_before.bias (768,)
vit.encoder.layer.4.layernorm_after.weight (768,)
vit.encoder.layer.4.layernorm_after.bias (768,)
vit.encoder.layer.5.attention.attention.query.weight (768, 768)
vit.encoder.layer.5.attention.attention.query.bias (768,)
vit.encoder.layer.5.attention.attention.key.weight (768, 768)
vit.encoder.layer.5.attention.attention.key.bias (768,)
vit.encoder.layer.5.attention.attention.value.weight (768, 768)
vit.encoder.layer.5.attention.attention.value.bias (768,)
vit.encoder.layer.5.attention.output.dense.weight (768, 768)
vit.encoder.layer.5.attention.output.dense.bias (768,)
vit.encoder.layer.5.intermediate.dense.weight (3072, 768)
vit.encoder.layer.5.intermediate.dense.bias (3072,)
vit.encoder.layer.5.output.dense.weight (768, 3072)
vit.encoder.layer.5.output.dense.bias (768,)
vit.encoder.layer.5.layernorm_before.weight (768,)
vit.encoder.layer.5.layernorm_before.bias (768,)
vit.encoder.layer.5.layernorm_after.weight (768,)
vit.encoder.layer.5.layernorm_after.bias (768,)
vit.encoder.layer.6.attention.attention.query.weight (768, 768)
vit.encoder.layer.6.attention.attention.query.bias (768,)
vit.encoder.layer.6.attention.attention.key.weight (768, 768)
vit.encoder.layer.6.attention.attention.key.bias (768,)
vit.encoder.layer.6.attention.attention.value.weight (768, 768)
vit.encoder.layer.6.attention.attention.value.bias (768,)
vit.encoder.layer.6.attention.output.dense.weight (768, 768)
vit.encoder.layer.6.attention.output.dense.bias (768,)
vit.encoder.layer.6.intermediate.dense.weight (3072, 768)
vit.encoder.layer.6.intermediate.dense.bias (3072,)
vit.encoder.layer.6.output.dense.weight (768, 3072)
vit.encoder.layer.6.output.dense.bias (768,)
vit.encoder.layer.6.layernorm_before.weight (768,)
vit.encoder.layer.6.layernorm_before.bias (768,)
vit.encoder.layer.6.layernorm_after.weight (768,)
vit.encoder.layer.6.layernorm_after.bias (768,)
vit.encoder.layer.7.attention.attention.query.weight (768, 768)
vit.encoder.layer.7.attention.attention.query.bias (768,)
vit.encoder.layer.7.attention.attention.key.weight (768, 768)
vit.encoder.layer.7.attention.attention.key.bias (768,)
vit.encoder.layer.7.attention.attention.value.weight (768, 768)
vit.encoder.layer.7.attention.attention.value.bias (768,)
vit.encoder.layer.7.attention.output.dense.weight (768, 768)
vit.encoder.layer.7.attention.output.dense.bias (768,)
vit.encoder.layer.7.intermediate.dense.weight (3072, 768)
vit.encoder.layer.7.intermediate.dense.bias (3072,)
vit.encoder.layer.7.output.dense.weight (768, 3072)
vit.encoder.layer.7.output.dense.bias (768,)
vit.encoder.layer.7.layernorm_before.weight (768,)
vit.encoder.layer.7.layernorm_before.bias (768,)
vit.encoder.layer.7.layernorm_after.weight (768,)
vit.encoder.layer.7.layernorm_after.bias (768,)
vit.encoder.layer.8.attention.attention.query.weight (768, 768)
vit.encoder.layer.8.attention.attention.query.bias (768,)
vit.encoder.layer.8.attention.attention.key.weight (768, 768)
vit.encoder.layer.8.attention.attention.key.bias (768,)
vit.encoder.layer.8.attention.attention.value.weight (768, 768)
vit.encoder.layer.8.attention.attention.value.bias (768,)
vit.encoder.layer.8.attention.output.dense.weight (768, 768)
vit.encoder.layer.8.attention.output.dense.bias (768,)
vit.encoder.layer.8.intermediate.dense.weight (3072, 768)
vit.encoder.layer.8.intermediate.dense.bias (3072,)
vit.encoder.layer.8.output.dense.weight (768, 3072)
vit.encoder.layer.8.output.dense.bias (768,)
vit.encoder.layer.8.layernorm_before.weight (768,)
vit.encoder.layer.8.layernorm_before.bias (768,)
vit.encoder.layer.8.layernorm_after.weight (768,)
vit.encoder.layer.8.layernorm_after.bias (768,)
vit.encoder.layer.9.attention.attention.query.weight (768, 768)
vit.encoder.layer.9.attention.attention.query.bias (768,)
vit.encoder.layer.9.attention.attention.key.weight (768, 768)
vit.encoder.layer.9.attention.attention.key.bias (768,)
vit.encoder.layer.9.attention.attention.value.weight (768, 768)
vit.encoder.layer.9.attention.attention.value.bias (768,)
vit.encoder.layer.9.attention.output.dense.weight (768, 768)
vit.encoder.layer.9.attention.output.dense.bias (768,)
vit.encoder.layer.9.intermediate.dense.weight (3072, 768)
vit.encoder.layer.9.intermediate.dense.bias (3072,)
vit.encoder.layer.9.output.dense.weight (768, 3072)
vit.encoder.layer.9.output.dense.bias (768,)
vit.encoder.layer.9.layernorm_before.weight (768,)
vit.encoder.layer.9.layernorm_before.bias (768,)
vit.encoder.layer.9.layernorm_after.weight (768,)
vit.encoder.layer.9.layernorm_after.bias (768,)
vit.encoder.layer.10.attention.attention.query.weight (768, 768)
vit.encoder.layer.10.attention.attention.query.bias (768,)
vit.encoder.layer.10.attention.attention.key.weight (768, 768)
vit.encoder.layer.10.attention.attention.key.bias (768,)
vit.encoder.layer.10.attention.attention.value.weight (768, 768)
vit.encoder.layer.10.attention.attention.value.bias (768,)
vit.encoder.layer.10.attention.output.dense.weight (768, 768)
vit.encoder.layer.10.attention.output.dense.bias (768,)
vit.encoder.layer.10.intermediate.dense.weight (3072, 768)
vit.encoder.layer.10.intermediate.dense.bias (3072,)
vit.encoder.layer.10.output.dense.weight (768, 3072)
vit.encoder.layer.10.output.dense.bias (768,)
vit.encoder.layer.10.layernorm_before.weight (768,)
vit.encoder.layer.10.layernorm_before.bias (768,)
vit.encoder.layer.10.layernorm_after.weight (768,)
vit.encoder.layer.10.layernorm_after.bias (768,)
vit.encoder.layer.11.attention.attention.query.weight (768, 768)
vit.encoder.layer.11.attention.attention.query.bias (768,)
vit.encoder.layer.11.attention.attention.key.weight (768, 768)
vit.encoder.layer.11.attention.attention.key.bias (768,)
vit.encoder.layer.11.attention.attention.value.weight (768, 768)
vit.encoder.layer.11.attention.attention.value.bias (768,)
vit.encoder.layer.11.attention.output.dense.weight (768, 768)
vit.encoder.layer.11.attention.output.dense.bias (768,)
vit.encoder.layer.11.intermediate.dense.weight (3072, 768)
vit.encoder.layer.11.intermediate.dense.bias (3072,)
vit.encoder.layer.11.output.dense.weight (768, 3072)
vit.encoder.layer.11.output.dense.bias (768,)
vit.encoder.layer.11.layernorm_before.weight (768,)
vit.encoder.layer.11.layernorm_before.bias (768,)
vit.encoder.layer.11.layernorm_after.weight (768,)
vit.encoder.layer.11.layernorm_after.bias (768,)
vit.layernorm.weight (768,)
vit.layernorm.bias (768,)
classifier.weight (10, 768)
classifier.bias (10,)
hungging face模型训练代码 对cifar10训练,保存模型参数为numpy格式,方便numpy的模型加载:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm
import numpy as np # 设置随机种子
torch.manual_seed(42) # 定义超参数
batch_size = 64
num_epochs = 1
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
]) # 加载CIFAR-10数据集
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 加载预训练的ViT模型
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device) # 替换分类头
num_classes = 10
# vit_model.config.classifier = 'mlp'
# vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device) # parameters = list(vit_model.parameters())
# for x in parameters[:-1]:
# x.requires_grad = False # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate) # 微调ViT模型
for epoch in range(num_epochs):
print("epoch:",epoch)
vit_model.train()
train_loss = 0.0
train_correct = 0 bar = tqdm(train_loader,total=len(train_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) # 前向传播
outputs = vit_model(images)
loss = criterion(outputs.logits, labels) # 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step() train_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
train_correct += (predicted == labels).sum().item() # 在训练集上计算准确率
train_accuracy = 100.0 * train_correct / len(train_dataset) # 在测试集上进行评估
vit_model.eval()
test_loss = 0.0
test_correct = 0 with torch.no_grad():
bar = tqdm(test_loader,total=len(test_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) outputs = vit_model(images)
loss = criterion(outputs.logits, labels) test_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
test_correct += (predicted == labels).sum().item() # 在测试集上计算准确率
test_accuracy = 100.0 * test_correct / len(test_dataset) # 打印每个epoch的训练损失、训练准确率和测试准确率
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%') torch.save(vit_model.state_dict(), 'vit_model_parameters.pth') # 打印BERT模型的权重维度
for name, param in vit_model.named_parameters():
print(name, param.data.shape) # # # 保存模型参数为NumPy格式
model_params = {name: param.data.cpu().numpy() for name, param in vit_model.named_parameters()}
np.savez('vit_model_params.npz', **model_params)
# model_params
Epoch [1/1], Train Loss: 97.7498, Train Accuracy: 96.21%, Test Accuracy: 96.86%
我用numpy实现了VIT,手写vision transformer, 可在树莓派上运行,在hugging face上训练模型保存参数成numpy格式,纯numpy实现的更多相关文章
- 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- Tensorflow之基于MNIST手写识别的入门介绍
Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...
- GAN实战笔记——第三章第一个GAN模型:生成手写数字
第一个GAN模型-生成手写数字 一.GAN的基础:对抗训练 形式上,生成器和判别器由可微函数表示如神经网络,他们都有自己的代价函数.这两个网络是利用判别器的损失记性反向传播训练.判别器努力使真实样本输 ...
- 【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集
一.前述 本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类. 同时对模型的保存和恢复做下示例. 二.具体原理 代码一:实现代码 #!/usr/bin/python ...
- OpenCV+TensorFlow实现自定义手写图像识别
完整版请点击链接:https://mp.weixin.qq.com/s/5gHXGmLbtO7m3dOFrDUiHQ 或微信关注“大数据技术宅” 继用TensorFlow教你做手写字识别(准确率 ...
- 手写AVL 树(下)
上一篇 手写AVL树上实现了AVL树的插入和查询 上代码: 头文件:AVL.h #include <iostream> template<typename T1,typename T ...
- mnist 手写数字识别
mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...
- 手写Json转换
在做项目的时候总是要手动将集合转换成json每次都很麻烦,于是就尝试着写了一个公用的方法,用于转换List to json: using System; using System.Collection ...
- 全命令行手写MapReduce并且打包运行
主要要讲的有3个 java中的package是干啥的? 工作了好几年的都一定真正理解java里面的package关键字,这里在写MapReduce需要进行打包的时候突然发现命令行下打包运行居然不会了, ...
- 手写迷你SpringMVC框架
前言 学习如何使用Spring,SpringMVC是很快的,但是在往后使用的过程中难免会想探究一下框架背后的原理是什么,本文将通过讲解如何手写一个简单版的springMVC框架,直接从代码上看框架中请 ...
随机推荐
- Android开发_记事本(1)
一些知识 Textview TextView中有下述几个属性: id:为TextView设置一个组件id,根据id,我们可以在Java代码中通过findViewById()的方法获取到该对象,然后进行 ...
- super 与 this 关键字
super与this用法相似: 1.普通的直接引用 2.形参与成员名字重名,用 this 来指代类本身,super指代父类 public class Students extends Person { ...
- IIS 部署.NET CORE 项目 出现 HTTP 错误 500.19 - Internal Server Error
当出现这个错误时是因为服务器上没有.NET CORE对应的SDK以及运行时文件,我的.NET CORE版本是2.2,下载的就是2.2对应的文件. 附上.NET CORE2.2版本的下载链接 下载 .N ...
- Python实现网络工具
使用python编写网络工具 基础内容 介绍基本的网络编程 Socket编程 Socket又称"套接字",应用程序通常通过"套接字"向网络发出请求或者应答网络请 ...
- 为什么 APISIX Ingress 是比 Emissary-ingress 更好的选择?
本文从可扩展性和服务发现集成等多个维度对比了 APISIX Ingress 与 Emissary-ingress 的性能. 作者:容鑫,API7.ai 云原生技术工程师,Apache APISIX C ...
- vite项目优化----- 解决终端optimized dependencies changed. reloading问题
写在前面网上都说vite要比webpack快,但个人感受,默认情况下, vite项目的启动确实比webpack快,但如果某个界面是首次进入,且依赖比较多/比较复杂的话,那就会比较慢了. 这篇文章就是用 ...
- Windows安装系统
0x01下载PE 微PE 0x02安装PE 0x021方式一:安装到系统 此方法开机有选择系统的选项,强迫症使用方法二 0x022方式二:安装到U盘 此方法需要一个U盘 确认无误后点击 立即安装到U盘 ...
- [Pytorch框架] 1.6 训练一个分类器
文章目录 训练一个分类器 关于数据? 训练一个图像分类器 在GPU上训练 多GPU训练 下一步? 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下 ...
- Websocket 60秒断开,连接不稳定
本地测试都是正常的,线上测试总是过一会就断开... 线上新增了https协议,导致页面中的链接必须也是ssl Websocket链接地址从ws://ws.xxx.com改成了wss://ws.xxx. ...
- mysql 结合python一些日常写法
python sql语句in写法 sql = "SELECT * FROM user WHERE name in ({})".format(','.join(["'%s' ...