聊聊 从源码来看ChatGLM-6B的模型结构
基于ChatGLM-6B第一版,要注意还有ChatGLM2-6B以及ChatGLM3-6B
概述
ChatGLM是transformer架构的神经网络模型,因此从transformer结构入手,分析其源码结构。
transformer结构:

位置编码
ChatGLM-6B的位置编码采用的旋转位置编码(RoPB)实现。其源码:
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq.half()
self.learnable = learnable
if learnable:
self.inv_freq = torch.nn.Parameter(inv_freq)
self.max_seq_len_cached = None
else:
self.register_buffer('inv_freq', inv_freq)
self.max_seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
pass
def forward(self, x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
emb = emb.float()
# [sx, 1 (b * np), hn]
cos_cached = emb.cos()[:, None, :]
sin_cached = emb.sin()[:, None, :]
if self.precision == torch.bfloat16:
cos_cached = cos_cached.bfloat16()
sin_cached = sin_cached.bfloat16()
if self.learnable:
return cos_cached, sin_cached
self.cos_cached, self.sin_cached = cos_cached, sin_cached
return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
def _apply(self, fn):
if self.cos_cached is not None:
self.cos_cached = fn(self.cos_cached)
if self.sin_cached is not None:
self.sin_cached = fn(self.sin_cached)
return super()._apply(fn)
## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/
激活函数
ChatGLM-6B采用的激活函数是GeLU(高斯误差线性单元),其源码:
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
return gelu_impl(x)
编码器-解码器(encoder-decoder)
接下来就是编码器解码器结构,如何抓住模型源头来分析?可以从transformers的API入手:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to("cuda:1").eval()
print(mode)
## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/
输出:
ChatGLMForConditionalGeneration(
(transformer): ChatGLMModel(
(word_embeddings): Embedding(130528, 4096)
(layers): ModuleList(
(0-27): 28 x GLMBlock(
(input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(attention): SelfAttention(
(rotary_emb): RotaryEmbedding()
(query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
(dense): Linear(in_features=4096, out_features=4096, bias=True)
)
(post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(mlp): GLU(
(dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
(dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
)
)
)
(final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=4096, out_features=130528, bias=False)
)
从脑图的角度来梳理下其结构

其结构图表示如下:

将结构图与最开始的transformer结构图对比来看,两者还是比较符合的。
官方源码中标注了编码器与解码器是一体的,只需要配置参数即可切换为解码器。如下:

聊聊 从源码来看ChatGLM-6B的模型结构的更多相关文章
- 死磕Java之聊聊HashSet源码(基于JDK1.8)
HashSet的UML图 HashSet的成员变量及其含义 public class HashSet<E> extends AbstractSet<E> implements ...
- 从源码来看ReentrantLock和ReentrantReadWriteLock
上一篇花了点时间将同步器看了一下,心中对锁的概念更加明确了一点,知道我们所使用到的锁是怎么样获取同步状态的,我们也写了一个自定义同步组件Mutex,讲到了它其实就是一个简版的ReentrantLock ...
- 死磕Java之聊聊ThreadLocal源码(基于JDK1.8)
记得在一次面试中被问到ThreadLocal,答得马马虎虎,所以打算研究一下ThreadLocal的源码 面试官 : 用过ThreadLocal吗? 楼主答 : 用过,当时使用ThreadLocal的 ...
- 聊聊ThreadLocal源码(基于JDK1.8)
原文:https://cloud.tencent.com/developer/article/1333298 聊聊JDK源码中ThreadLocal的实现 主要方法: ThreadLocal的get方 ...
- 死磕Java之聊聊HashMap源码(基于JDK1.8)
死磕Java之聊聊HashMap源码(基于JDK1.8) http://cmsblogs.com/?p=4731 为什么面试要问hashmap 的原理
- Spring5源码分析(1)设计思想与结构
1 源码地址(带有中文注解)git@github.com:yakax/spring-framework-5.0.2.RELEASE--.git Spring 的设计初衷其实就是为了简化我们的开发 基于 ...
- Redis 源码简洁剖析 09 - Reactor 模型
Reactor 模型 事件驱动框架 Redis 如何实现 Reactor 模型 事件的数据结构:aeFileEvent 主循环:aeMain 函数 事件捕获与分发:aeProcessEvents 函数 ...
- 死磕Java之聊聊ArrayList源码(基于JDK1.8)
工作快一年了,近期打算研究一下JDK的源码,也就因此有了死磕java系列 ArrayList 是一个数组队列,相当于动态数组.与Java中的数组相比,它的容量能动态增长.它继承于AbstractLis ...
- 死磕Java之聊聊LinkedList源码(基于JDK1.8)
工作快一年了,近期打算研究一下JDK的源码,也就因此有了死磕java系列 LinkedList 是一个继承于AbstractSequentialList的双向链表,链表不需要capacity的设定,它 ...
- 从OkHttp的源码来看 HTTP
先来了解一下OkHttp的历史,最早是square公司觉得Android给的HttpClient这块的库不太好用,于是乎做了一层包装,再后来他们包装的这个库被Android官方给收回去了,而Andro ...
随机推荐
- centos7安装Python3.7,执行./configure时报错,configure: error: no acceptable C compiler found in $PATH
执行./configure时报错,configure: error: no acceptable C compiler found in $PATH 在安装python3.7,配置编译路径时会遇到以下 ...
- MySQL运维2-主从复制
一.主从复制概念 主从复制是指将主数据库的DDL和DML操作通过二进制日志传到从服务器中,然后在从服务器上对这些日志重新执行也叫重做,从而使得从数据库和主库的数据保持同步. MySQL支持一台主库同时 ...
- Linux系列教程——Linux文件查找、Linux压缩打包、Linux软件管理
@ 目录 1 Linux文件查找 1.find查找概述 2.find查找示例 1.find名称查找 2.find大小查找 3.find类型查找 4.find时间查找 5.find用户查找 6.find ...
- 如何查询4GL程序中创建的临时表中的数据
前提:将dba_segments这个表的select权限授权给各个营运中心(即数据库用户) ①.用sys账号以dba的权限登录数据库 <topprod:/u1/topprod/tiptop> ...
- 【Cucumber】关于BDD自然语言自动化测试的语法总结
1.关键字 - Feature 每一个.feature文件必须以关键字Feature开始,Feature关键字之后可以添加该feature的描述,其作用类似于注释,仅仅为了便于理解沟通交流,描述内容中 ...
- 通过unittest加载测试用例的不同方法
使用python+unitest做自动化测试执行时, 执行用例时就涉及测试用例的加载. 即如何把测试cases加载到测试suite,然后进行运行. 一般把用例加载方法分为两大类:通过unittest. ...
- 【不限框架】超好用的3d开源图片预览插件推荐
今天给大家推荐一款超好用的图片预览插件-image-preview 简单说明 image-preview是一款主要面向移动端web应用,同时兼容pc,基于原生js,不限框架,react,vue,ang ...
- 概率期望 DP 题解合集
期望这东西学了一次忘了,再学一次过了两天又不会了.我是鱼. 故写此博客以便加深记忆及日后复习. NOIP 前恶补期望(? 希望有用,RP++() 经典问题 1 某事件发生概率为 \(p\),则该事件首 ...
- 实战攻防演练-Linux写入ssh密钥,利用密钥登录
前言 密钥形式登录的原理是利用密钥生成器制作一对密钥,一只公钥和一只私钥.将公钥添加到服务器的某个账户上,然后在客户端利用私钥即可完成认证并登录.这样一来,没有私钥,任何人都无法通过 SSH 暴力破解 ...
- 火山引擎 ByteHouse:只需 2 个方法,增强 ClickHouse 数据导入能力
更多技术交流.求职机会,欢迎关注字节跳动数据平台微信公众号,回复[1]进入官方交流群 作为企业数字化建设的必备要素,易用的数据引擎能帮助企业提升数据使用效率,更好提升数据应用价值,夯实数字化建设基础. ...