基于ChatGLM-6B第一版,要注意还有ChatGLM2-6B以及ChatGLM3-6B

概述

ChatGLM是transformer架构的神经网络模型,因此从transformer结构入手,分析其源码结构。

transformer结构:

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

位置编码

ChatGLM-6B的位置编码采用的旋转位置编码(RoPB)实现。其源码:

class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq.half()
self.learnable = learnable
if learnable:
self.inv_freq = torch.nn.Parameter(inv_freq)
self.max_seq_len_cached = None
else:
self.register_buffer('inv_freq', inv_freq)
self.max_seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
pass def forward(self, x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
emb = emb.float() # [sx, 1 (b * np), hn]
cos_cached = emb.cos()[:, None, :]
sin_cached = emb.sin()[:, None, :]
if self.precision == torch.bfloat16:
cos_cached = cos_cached.bfloat16()
sin_cached = sin_cached.bfloat16()
if self.learnable:
return cos_cached, sin_cached
self.cos_cached, self.sin_cached = cos_cached, sin_cached
return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] def _apply(self, fn):
if self.cos_cached is not None:
self.cos_cached = fn(self.cos_cached)
if self.sin_cached is not None:
self.sin_cached = fn(self.sin_cached)
return super()._apply(fn) ## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

激活函数

ChatGLM-6B采用的激活函数是GeLU(高斯误差线性单元),其源码:

@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x))) def gelu(x):
return gelu_impl(x)

编码器-解码器(encoder-decoder)

接下来就是编码器解码器结构,如何抓住模型源头来分析?可以从transformers的API入手:

from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to("cuda:1").eval() print(mode) ## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

输出:

ChatGLMForConditionalGeneration(
(transformer): ChatGLMModel(
(word_embeddings): Embedding(130528, 4096)
(layers): ModuleList(
(0-27): 28 x GLMBlock(
(input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(attention): SelfAttention(
(rotary_emb): RotaryEmbedding()
(query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
(dense): Linear(in_features=4096, out_features=4096, bias=True)
)
(post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(mlp): GLU(
(dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
(dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
)
)
)
(final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=4096, out_features=130528, bias=False)
)

从脑图的角度来梳理下其结构



其结构图表示如下:



将结构图与最开始的transformer结构图对比来看,两者还是比较符合的。

官方源码中标注了编码器与解码器是一体的,只需要配置参数即可切换为解码器。如下:

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

聊聊 从源码来看ChatGLM-6B的模型结构的更多相关文章

  1. 死磕Java之聊聊HashSet源码(基于JDK1.8)

    HashSet的UML图 HashSet的成员变量及其含义 public class HashSet<E> extends AbstractSet<E> implements ...

  2. 从源码来看ReentrantLock和ReentrantReadWriteLock

    上一篇花了点时间将同步器看了一下,心中对锁的概念更加明确了一点,知道我们所使用到的锁是怎么样获取同步状态的,我们也写了一个自定义同步组件Mutex,讲到了它其实就是一个简版的ReentrantLock ...

  3. 死磕Java之聊聊ThreadLocal源码(基于JDK1.8)

    记得在一次面试中被问到ThreadLocal,答得马马虎虎,所以打算研究一下ThreadLocal的源码 面试官 : 用过ThreadLocal吗? 楼主答 : 用过,当时使用ThreadLocal的 ...

  4. 聊聊ThreadLocal源码(基于JDK1.8)

    原文:https://cloud.tencent.com/developer/article/1333298 聊聊JDK源码中ThreadLocal的实现 主要方法: ThreadLocal的get方 ...

  5. 死磕Java之聊聊HashMap源码(基于JDK1.8)

    死磕Java之聊聊HashMap源码(基于JDK1.8) http://cmsblogs.com/?p=4731 为什么面试要问hashmap 的原理

  6. Spring5源码分析(1)设计思想与结构

    1 源码地址(带有中文注解)git@github.com:yakax/spring-framework-5.0.2.RELEASE--.git Spring 的设计初衷其实就是为了简化我们的开发 基于 ...

  7. Redis 源码简洁剖析 09 - Reactor 模型

    Reactor 模型 事件驱动框架 Redis 如何实现 Reactor 模型 事件的数据结构:aeFileEvent 主循环:aeMain 函数 事件捕获与分发:aeProcessEvents 函数 ...

  8. 死磕Java之聊聊ArrayList源码(基于JDK1.8)

    工作快一年了,近期打算研究一下JDK的源码,也就因此有了死磕java系列 ArrayList 是一个数组队列,相当于动态数组.与Java中的数组相比,它的容量能动态增长.它继承于AbstractLis ...

  9. 死磕Java之聊聊LinkedList源码(基于JDK1.8)

    工作快一年了,近期打算研究一下JDK的源码,也就因此有了死磕java系列 LinkedList 是一个继承于AbstractSequentialList的双向链表,链表不需要capacity的设定,它 ...

  10. 从OkHttp的源码来看 HTTP

    先来了解一下OkHttp的历史,最早是square公司觉得Android给的HttpClient这块的库不太好用,于是乎做了一层包装,再后来他们包装的这个库被Android官方给收回去了,而Andro ...

随机推荐

  1. JavaAgent寄生在目标进程中引起的ClassNotFoundException

    今天有解决方案部的小伙伴反映,我公司XWind产品在分析客户应用程序的潜在性能问题时,总是显现诊断任务异常,为了定位问题的根因,我们马上要求解决方案部的小伙伴提供XWind相关的日志,从日志中找到了如 ...

  2. 深入理解 python 虚拟机:生成器停止背后的魔法

    深入理解 python 虚拟机:生成器停止背后的魔法 在本篇文章当中主要给大家介绍 Python 当中生成器的实现原理,尤其是生成器是如何能够被停止执行,而且还能够被恢复的,这是一个非常让人疑惑的地方 ...

  3. 分布式与微服务——Iaas,Paas和Saas、单体应用和缺点、微服务概念、传统 分布式 SOA 架构与微服务架构的区别、微服务实战、什么是RPC、CAP定理和BASE理论、唯一ID生成、实现分布式

    文章目录 1-什么是Iaas,Paas和Saas 一 IaaS基础设施服务 二 paas平台即服务 三saas软件即服务 四 总结 2-单体应用和缺点 一 单体应用 二 单体应用的缺陷 3-微服务概念 ...

  4. WPF 笔迹算法 从点集转笔迹轮廓

    本文将告诉大家一些笔迹算法,从用户输入的点集,即鼠标轨迹点或触摸轨迹点等,转换为一个可在界面绘制显示笔迹画面的基础数学算法.尽管本文标记的是 WPF 的笔迹算法,然而实际上本文更侧重基础数学计算,理论 ...

  5. Noi-Linux 2.0 装机+使用整合

    写在前面 网上的东西比较多,也比较杂乱,不是很方便,所以我整合了一些关于 Noi-Linux2.0 虚拟机装机方法+代码编辑环境+实地编程的介绍,看完至少能用起来打代码了. NOI 官网公告(JS 开 ...

  6. mac os 升级到13后,系统免密失败

    # sudo vim /etc/ssh/ssh_config # 添加以下内容 PubkeyAcceptedKeyTypes +ssh-rsa

  7. 深入解析css-笔记

    前言 本文章是根据<深入解析CSS>一书所作的学习笔记,书中的知识点基本都概括在这.希望对您有帮助,另外本博客是通过word笔记文档导入,虽然后续对内容和代码相关进行了一些格式处理,但还是 ...

  8. Java比赛常用API总结

    1.栈和队列 1.1 栈的常用方法 //1.栈顶插入元素 push(element) //2.返回栈顶元素并弹出栈顶元素 pop() //3.返回栈顶元素但不弹出 peek() //4.清空栈 cle ...

  9. NLP文本生成全解析:从传统方法到预训练完整介绍

    本文深入探讨了文本生成的多种方法,从传统的基于统计和模板的技术到现代的神经网络模型,尤其是LSTM和Transformer架构.文章还详细介绍了大型预训练模型如GPT在文本生成中的应用,并提供了Pyt ...

  10. [动态树] Link-Cut Tree

    Link-Cut Tree 0x00 绪言 学长们讲 LCT 的时候,我在另一个机房摸鱼,所以没有听到,就回家看 yxc 的补了补. 0x01 什么是动态树 动态树问题, 即要求我们维护一个由若干棵子 ...