主要为了测试模型增加Lora模块后,参数量和训练速度的变化情况。
结论:正常情况下,增加Lora模块是会增加参数量的,因此前向传播和反向传播的时间也会增加。
但是,在大语言模型训练的情况下,因为基础模型本身参数量非常大,Lora模块增加的参数量相对非常小。并且,基础模型不参与梯度更新,可以做模型量化,实际上是能减少模型训练时间和显存使用量的。
以下是实验脚本和运行结果:
#部分参考https://zhuanlan.zhihu.com/p/666000885
import time
import torch
from torch import nn
from peft import LoraConfig, get_peft_model, PeftModel
from torchsummary import summary x_train = torch.randn((1000, 10))
y_train = torch.randn((1000, 1)) net = nn.Sequential(
nn.Linear(10,20),
nn.Sigmoid(),
nn.Linear(20,30),
nn.Sigmoid(),
nn.Linear(30,1)
)
summary(net, (1,10)) config = LoraConfig(target_modules=["0"], r=2)
model = get_peft_model(net, config)
criterion = torch.nn.MSELoss(reduction='mean') # 定义损失函数,采用均方误差
optimizer = torch.optim.Adam(model.parameters(), lr=0.3) # 定义优化器,采用Adam
summary(model, (1,10)) # base 前向计算时间
start = time.time()
for i in range(100000):
y_pre = net(x_train) # 前向传播
print("base 前向计算时间: ", time.time() - start) # lora 前向计算时间
start = time.time()
for i in range(100000):
y_pre = model(x_train) # 前向传播
print("lora 前向计算时间", time.time() - start) # base 反向传播时间
start = time.time()
for i in range(1000):
y_pre = net(x_train) # 前向传播
loss = criterion(y_pre, y_train) # 计算损失
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向传播
optimizer.step() # 使用优化器更新梯度
print("base loss after training: ", loss.item())
print("base 反向计算时间", time.time() - start) # lora 反向传播时间
start = time.time()
for i in range(1000):
y_pre = model(x_train) # 前向传播
loss = criterion(y_pre, y_train) # 计算损失
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向传播
optimizer.step() # 使用优化器更新梯度
print("lora loss after training: ", loss.item())
print("lora 反向计算时间", time.time() - start)

  运行代码输出:

----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 1, 20] 220
Sigmoid-2 [-1, 1, 20] 0
Linear-3 [-1, 1, 30] 630
Sigmoid-4 [-1, 1, 30] 0
Linear-5 [-1, 1, 1] 31
================================================================
Total params: 881
Trainable params: 881
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 1, 20] 220
Identity-2 [-1, 1, 10] 0
Linear-3 [-1, 1, 2] 20
Linear-4 [-1, 1, 20] 40
Linear-5 [-1, 1, 20] 220
Sigmoid-6 [-1, 1, 20] 0
Linear-7 [-1, 1, 30] 630
Sigmoid-8 [-1, 1, 30] 0
Linear-9 [-1, 1, 1] 31
================================================================
Total params: 1,161
Trainable params: 60
Non-trainable params: 1,101
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.01
----------------------------------------------------------------
base loss after training: 1.0724023580551147
base 反向计算时间 2.9570980072021484
lora loss after training: 1.0643658638000488
lora 反向计算时间 3.053032159805298

Lora训练的参数和性能的更多相关文章

  1. 训练超参数, 出现 Cannot use GPU in CPU-only Caffe 错误?

    当我们用MNIST手写体数字数据库和LeNet CNN 模型训练超参数,运行 examples/mnist/train_lenet.sh是出现Cannot use GPU in CPU-only Ca ...

  2. 05:Sysbench压测-innodb_deadlock_detect参数对性能的影响

    目录 sysbench压测-innodb_deadlock_detect参数对性能的影响 一.OLTP测试前准备 二.进行OLTP测试 三.测试结果解读: 四.关于测试后的结论: 五.关于测试后的性能 ...

  3. 04:Sysbench压测-innodb_flush_log_at_trx_commit,sync_binlog参数对性能的影响

    目录 sysbench压测-innodb_flush_log_at_trx_commit,sync_binlog参数对性能的影响 一.OLTP测试前准备 二.MySQL 数据落盘的过程 三.参数说明 ...

  4. [转帖]PostgreSQL 参数调整(性能优化)

    PostgreSQL 参数调整(性能优化) https://www.cnblogs.com/VicLiu/p/11854730.html 知道一个 shared_pool 文章写的挺好的 还没仔细看 ...

  5. Tomcat学习四步走:内核、集群、参数及性能

    主题简介: 内核实现原理 分布式集群 生产部署关键参数 性能监控和分析 一.内核实现原理 HTTP Web服务器与浏览器之间以HTTP协议通信,浏览器要访问服务器即向服务器发送HTTP请求报文. 如图 ...

  6. pytorch和tensorflow的爱恨情仇之定义可训练的参数

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch版本:1.6.0 tensorflow版本:1.15.0 之前我们就已 ...

  7. C# 中的 in 参数和性能分析

    in 修饰符也是从 C# 7.2 开始引入的,它与我们上一篇中讨论的 <C# 中的只读结构体(readonly struct)>[1] 是紧密相关的. in 修饰符 in 修饰符通过引用传 ...

  8. IO模式设置网络编程常见问题总结—IO模式设置,阻塞与非阻塞的比较,recv参数对性能的影响—O_NONBLOCK(open使用)、IPC_NOWAIT(msgrcv)、MSG_DONTWAIT(re

    非阻塞IO 和阻塞IO: 在网络编程中对于一个网络句柄会遇到阻塞IO 和非阻塞IO 的概念, 这里对于这两种socket 先做一下说明:       基本概念: 阻塞IO:: socket 的阻塞模式 ...

  9. Linux内存管理-内核的shmall和shmmax参数(性能调优)(转)

    内核的shmall和shmmax参数 SHMMAX=配置了最大的内存segment的大小:这个设置的比SGA_MAX_SIZE大比较好. SHMMIN=最小的内存segment的大小 SHMMNI=整 ...

  10. PostgreSQL 参数调整(性能优化)

    昨天分别在外网和无外网环境下安装PostgreSQL,有外网环境下安装的相当顺利.但是在无外网环境下就是两个不同的概念了,可谓十有八折.感兴趣的同学可以搭建一下. PostgreSQL安装完成后第一件 ...

随机推荐

  1. KingbaseESV8R6识别IO使用率过高

    前言 数据库正常运行离不开I/O的使用,在操作系统上,I/O又离不开存储的性能及使用方式,我们可以在存储层利用raid条带化技术使IOPS达到最佳性能. 本篇文章有助于确认数据库I/O使用率过高的原因 ...

  2. Fast多维数组

    #include<iostream> #include<chrono> struct Timer { std::chrono::time_point<std::chron ...

  3. HTML实现发送接收串口和TCP数据

    前提 请安装通讯调试工具,所有的网页必须运行在本工具上,在其他浏览器直接打开是不行的. 效果显示 在网页上右键打开,选择其他应用 2.在其他应用中找到通讯调试工具 如果没有这一项,点更多,在计算机中查 ...

  4. 实用 Linux 命令 Windos 命令 实例演示 持续更新中

    实用 Linux 命令 Windos 命令 实例演示 持续更新中 目录 实用 Linux 命令 Windos 命令 实例演示 持续更新中 Linux 命令 [Command [options] [lo ...

  5. Python 潮流周刊第 45 期(摘要)+ 赠书 5 本《Python语言及其应用(第2版)》

    本周刊由 Python猫 出品,精心筛选国内外的 250+ 信息源,为你挑选最值得分享的文章.教程.开源项目.软件工具.播客和视频.热门话题等内容.愿景:帮助所有读者精进 Python 技术,并增长职 ...

  6. Python 内置数据类型详解

    内置数据类型 在编程中,数据类型是一个重要的概念. 变量可以存储不同类型的数据,不同类型可以执行不同的操作. Python默认内置了以下这些数据类型,分为以下几类: 文本类型:str 数值类型:int ...

  7. AI云增强升级!还原生动人像,拍出质感照片

    近期不少细心用户发现,在用HUAWEI Mate 60 Pro手机拍照后,使用相册中的AI云增强功能,照片变得更加细腻有质感.这是因为AI云增强升级并更新支持了人像模式拍摄的照片,高清自然的人像细节还 ...

  8. Python调用动态库,获取BSTR字符串

    今天客户在用Python调用我们的动态库的时候,遇到一个问题,调用动态库中的函数,函数返回的是BSTR字符串,但是客户接收到的是一个8位长度的数字. 动态库函数原型:EXTERN_C BSTR ELO ...

  9. Go语言的100个错误使用场景(61-68)|并发实践

    目录 前言 9. 并发实践 9.1 context 的不恰当传播(#61) 9.2 开启一个协程但不知道何时关闭(#62) 9.3 在循环中没有谨慎使用协程(#63) 9.4 使用 select 和 ...

  10. c# seo 百度sitemap书写

    前言 我们知道对页面百度收录,有两种方式: 1.百度自己抓取. 2.自己提交sitemap让百度来抓取. sitemap 一般的一个业务逻辑是: 生成sitemap xml,然后每隔一段时间更新即可, ...