一、概述

  在人工智能领域,数据的多样性促使研究人员不断探索新的模型与算法。传统的神经网络在处理像图像、文本这类具有固定结构的数据时表现出色,但面对具有不规则拓扑结构的图数据,如社交网络、化学分子结构、知识图谱等,却显得力不从心。

  图神经网络(Graph Neural Networks, GNN)是一种直接在图结构数据上运行的神经网络,用于处理节点、边或整个图的特征信息。其核心思想是通过聚合邻域节点的特征信息来更新当前节点的表示,从而捕捉图中节点间的依赖关系和拓扑结构特征。

二、模型原理

1. 图结构数据的特点

图由节点(vertices)和边(edges)组成,可表示为\(G=\left( V,E \right)\),其中:

  \(V=\left\{ v_1,v_2,...,v_N \right\}\)为节点集合,可能包含特征向量(如用户属性、原子特征等)。

  \(E=\left\{ (v_i,v_j) \right\}\)为边集合,描述节点间的关系,可能带有权重或类型(如社交关系、化学键)。

节点和边的特征表示:

  节点特征矩阵\(X\in R^{N\times F}\)(\(F\)为节点特征维度);

  边特征矩阵\(E\in R^{M\times D}\)(\(M\)为边数,\(D\)为边特征维度);

  邻接矩阵\(A\in R^{N\times N}\)(表示节点连接关系,无向图中矩阵对称)。

图具有以下特性:

  非欧几里得结构:节点间无序,邻居数量可变。

  异构性:图的规模、密度、节点类型可能差异极大。

2.核心机制:消息传递与节点更新

  图神经网络的核心目标之一是为图中的每个节点生成一个具有代表性的向量表示,也就是将节点的复杂特征和其在图中的拓扑结构信息编码到一个向量空间中,便于后续的节点分类、预测等任务。

节点表示的生成过程基于图的拓扑结构和节点自身的特征,利用神经网络的学习能力,自动提取出对任务有价值的信息。其基本思想是通过不断聚合邻居节点的信息,并结合自身的特征,逐步更新节点的表示,使得每个节点能够充分反映其在图中的角色和上下文信息。

(1)消息聚合(Message Aggregation)

  对于每个节点\(v\),收集其邻域节点\(N(v)\)的特征信息,生成聚合消息\(m\)。

常用聚合函数包括:

  求和(Sum):\(m_i=\sum_{v_j\in N(v_i)}{ReLU(W\cdot h_j+b)}\)

  均值(Mean):\(m_i=\frac{1}{\left| N(v_i) \right|}\sum_{v_j\in N(v_i)}{h_j}\)

  最大值(Max Pooling):\(m_i=\max_{v_j\in N(v_i)}\left\{ h_j \right\}\)

其中,\(h_j\)为邻域节点 \(v_j\)的隐藏状态,\(W\)和\(b\)为可学习参数。

(2)节点状态更新(Update)

利用聚合得到的消息\(m_i\)和当前节点的旧状态 \(h_{i}^{(l)}\),更新节点的隐藏状态:

\[h_{i}^{(l+1)}=\sigma\left( h_{i}^{(l)}\oplus m_i \right)
\]

其中 \(\sigma\)为激活函数(如 ReLU、Sigmoid),\(\oplus\)表示拼接或线性变换操作。

三、典型 GNN 模型架构

  不同 GNN 模型的差异主要体现在消息聚合方式和图结构处理策略上,几种典型模型为:

1. 图卷积网络(GCN, Graph Convolutional Network)

  简化了消息传递过程,通过对称归一化的邻接矩阵直接聚合邻居:

\[h_{i}^{(l+1)}=\sigma \left( \hat{D}^{-\frac{1}{2}}\hat A \hat D^{-\frac{1}{2}}h^{(l)}W^{(l)} \right)
\]

  其中,\(\hat A=A+I\)(\(I\)为单位矩阵,引入自环),\(\hat D\)为\(\hat A\)的度矩阵(对角矩阵,\(\hat D_{ii}=\sum_{j}{\hat A_{ij}}\))。

2. 图注意力网络(GAT, Graph Attention Network)

  引入注意力机制,动态学习邻居的重要性权重:

\[h_{v}^{(l+1)}=\sigma\left( \sum_{u\in N(v)}{\alpha_{uv}Wh_{u}^{(l)}} \right)
\]

  其中,\(\alpha_{uv}\)是通过注意力机制计算的归一化权重。

3. 图采样与聚合网络(GraphSAGE, Graph SAmple and aggreGatE)

  核心思想:对大规模图进行子图采样,避免全图计算的高复杂度。

  采样策略:随机采样固定数量的邻域节点(如固定采样 5 个邻居),再通过聚合函数(如均值、LSTM、池化)更新节点表示。

  适用场景:适用于归纳学习(Inductive Learning,处理训练中未出现的节点)。

四、优势与挑战

优势:

  结构感知:直接利用图的拓扑结构,捕捉节点间依赖关系;

  灵活性:适用于多种图类型(有向图、无向图、异质图);

  可扩展性:结合采样技术可处理大规模图数据。

挑战:

  过平滑(Over-smoothing):深层 GNN 中节点特征趋于同质化,丢失区分度;

  异质图处理:节点和边类型多样时,需设计更复杂的聚合方式;

  计算效率:全图计算的时间复杂度高,需优化采样或稀疏矩阵运算。

五、应用场景

  社交网络:用户行为预测、社区检测;

  生物医学:分子特性预测、药物研发(如 GNN 用于预测蛋白质相互作用);

  推荐系统:建模用户-物品交互图,提升推荐准确性;

  计算机视觉:点云数据处理、场景图生成;

  知识图谱:链接预测、实体分类;

  交通网络:流量预测、路径优化。

六、Python实现示例

(环境:Python 3.11,PyTorch 2.4.0)

import torch
import torch.nn as nn
import torch.nn.functional as F class GraphConvolution(nn.Module):
def __init__(self, input_dim, output_dim):
super(GraphConvolution, self).__init__()
self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim))
self.bias = nn.Parameter(torch.FloatTensor(output_dim))
self.reset_parameters() def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
nn.init.zeros_(self.bias) def forward(self, x, adj):
support = torch.mm(x, self.weight)
output = torch.spmm(adj, support)
return output + self.bias class GNN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(GNN, self).__init__()
self.gc1 = GraphConvolution(input_dim, hidden_dim)
self.gc2 = GraphConvolution(hidden_dim, output_dim) def forward(self, x, adj):
x = F.relu(self.gc1(x, adj))
x = F.dropout(x, training=self.training)
x = self.gc2(x, adj)
return F.log_softmax(x, dim=1) # 示例用法
def test_gnn():
# 创建一个简单的3节点图
# 节点特征矩阵 (3节点,每个节点特征维度为4)
features = torch.FloatTensor([
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8],
[0.9, 1.0, 1.1, 1.2]
]) # 邻接矩阵 (3x3)
adj = torch.FloatTensor([
[1, 1, 0],
[1, 1, 1],
[0, 1, 1]
]) # 添加自环并归一化
adj = adj + torch.eye(adj.size(0))
d_inv_sqrt = torch.pow(adj.sum(1), -0.5).flatten()
d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
adj = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt) # 创建GNN模型
model = GNN(input_dim=4, hidden_dim=8, output_dim=2) # 前向传播
output = model(features, adj)
print("GNN输出:", output) # 随机生成标签并计算损失
labels = torch.LongTensor([0, 1, 0])
loss = F.nll_loss(output, labels)
print("损失值:", loss.item()) if __name__ == "__main__":
test_gnn()

示例实现了一个简单的两层图神经网络,包含

  1. GraphConvolution类实现了基本的图卷积操作,包括权重矩阵和偏置项;

  2. GNN类定义了一个两层GNN模型,使用ReLU激活函数和dropout;

  3. 代码展示了如何创建图数据(特征矩阵和邻接矩阵);

  4. 包含了邻接矩阵的预处理(添加自环和归一化)。

七、小结

  图神经网络通过消息传递机制聚合邻域信息,实现了图结构数据的高效建模。其核心在于设计合理的聚合函数和更新规则,以捕捉不同场景下的图特征。随着研究深入,GNN 在理论分析(如泛化能力、表达能力)和应用创新(如异质图、动态图)方面仍在不断发展,未来有望在更多复杂图任务中发挥关键作用。

End.

图神经网络(GNN)模型的基本原理的更多相关文章

  1. PGL图学习之图神经网络GNN模型GCN、GAT[系列六]

    PGL图学习之图神经网络GNN模型GCN.GAT[系列六] 项目链接:一键fork直接跑程序 https://aistudio.baidu.com/aistudio/projectdetail/505 ...

  2. zz【清华NLP】图神经网络GNN论文分门别类,16大应用200+篇论文最新推荐

    [清华NLP]图神经网络GNN论文分门别类,16大应用200+篇论文最新推荐 图神经网络研究成为当前深度学习领域的热点.最近,清华大学NLP课题组Jie Zhou, Ganqu Cui, Zhengy ...

  3. 图机器学习(GML)&图神经网络(GNN)原理和代码实现(前置学习系列二)

    项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4990947?contributionType=1 欢迎fork欢迎三连!文章篇幅有限, ...

  4. Github项目推荐-图神经网络(GNN)相关资源大列表

    文章发布于公号[数智物语] (ID:decision_engine),关注公号不错过每一篇干货. 转自 | AI研习社 作者|Zonghan Wu 这是一个与图神经网络相关的资源集合.相关资源浏览下方 ...

  5. NLP知识图谱项目合集(信息抽取、文本分类、图神经网络、性能优化等)

    NLP知识图谱项目合集(信息抽取.文本分类.图神经网络.性能优化等) 这段时间完成了很多大大小小的小项目,现在做一个整体归纳方便学习和收藏,有利于持续学习. 1. 信息抽取项目合集 1.PaddleN ...

  6. 【GNN】图神经网络小结

    图神经网络小结 图神经网络小结 图神经网络分类 GCN: 由谱方法到空域方法 GCN概述 GCN的输出机制 GCN的不同方法 基于谱方法的GCN 初始 切比雪夫K阶截断: ChebNet 一阶Cheb ...

  7. PGL图学习之图神经网络ERNIESage、UniMP进阶模型[系列八]

    PGL图学习之图神经网络ERNIESage.UniMP进阶模型[系列八] 原项目链接:fork一下即可:https://aistudio.baidu.com/aistudio/projectdetai ...

  8. 图神经网络(GNN)--slide

    课件是学习小组汇报时用的,许多资料是从大佬哪里搬运的.Tex文档也在里面. GNN课件,下载不了,可以点击 带你入门图神经网络(GNN) 图神经网络(GNN)学习推荐网址 傅里叶分析之掐死教程(完整版 ...

  9. 知识图谱-生物信息学-医学顶刊论文(Bioinformatics-2021)-KG4SL:用于人类癌症综合致死率预测的知识图神经网络

    5.(2021.7.12)Bioinformatics-KG4SL:用于人类癌症综合致死率预测的知识图神经网络 论文标题:KG4SL: knowledge graph neural network f ...

  10. PGL图学习之图神经网络GraphSAGE、GIN图采样算法[系列七]

    0. PGL图学习之图神经网络GraphSAGE.GIN图采样算法[系列七] 本项目链接:https://aistudio.baidu.com/aistudio/projectdetail/50619 ...

随机推荐

  1. go mod 安装使用 beego

    go module基本使用 // 创建目录,初始化新项目 mkdir beemod cd beemod go mod init beemod 创建 server.go 文件 package main ...

  2. MySQL配置主从复制教程(MySQL8)

    原理: 数据库在进行DDL和DML语句操作时,会被记录到binlog的日志文件里,而读取这里面的日志就可以知道数据库进行过哪些DDL和DML操作,这是主数据库的日志,从数据库经过相关配置可以实时获取到 ...

  3. Netty源码—3.Reactor线程模型二

    大纲 5.NioEventLoop的执行总体框架 6.Reactor线程执行一次事件轮询 7.Reactor线程处理产生IO事件的Channel 8.Reactor线程处理任务队列之添加任务 9.Re ...

  4. 【硬件】认识和选购4K画质的显卡

    2.6 认识和选购4K画质的显卡 显卡一般是一块独立的电路板,插在主板上接收由主机发出的控制显示系统工作的指令和显示内容的数字信号,然后通过输出模拟(或数字)信号控制显示器显示各种字符和图形,它和显示 ...

  5. MySQL InnoDB 引擎中的聚簇索引和非聚簇索引有什么区别?

    MySQL InnoDB 引擎中的聚簇索引和非聚簇索引的区别 在 MySQL 的 InnoDB 存储引擎中,聚簇索引和非聚簇索引是两种常见的索引类型,它们在数据存储结构和使用场景上有显著区别. 1. ...

  6. 聊聊SpringAI流式输出的底层实现?

    在 Spring AI 中,流式输出(Streaming Output)是一种逐步返回 AI 模型生成结果的技术,允许服务器将响应内容分批次实时传输给客户端,而不是等待全部内容生成完毕后再一次性返回. ...

  7. Rust实战系列-Rust介绍

    " 学习资料:rust in action[1] 1. Rust 安装 curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | ...

  8. RabbitMq在win10上的安装、用户管理及控制台Demo

    思路: 安装elang--设置elang的环境变量--安装erlang版本对应的rabbitmq--设置rabbitmq的环境变量--安装rabbitmq的可视化管理插件 相关链接: RabbitMQ ...

  9. GitLab CI/CD 的配置文件 .gitlab-ci.yml 简介

    〇.前言 .gitlab-ci.yml 文件主要用于项目的自动化部署配置,自动化可以大大提升团队效率,但同时这个文件的内容也比较复杂,弄清楚也并非易事,本文将对此文件的内容进行简单介绍,供参考. 另外 ...

  10. 『Plotly实战指南』--交互功能进阶篇

    在数据可视化的世界中,交互性是提升用户体验和数据探索效率的关键.从简单的悬停提示到复杂的动态数据更新,交互功能让静态图表变得生动起来. 本文将介绍Plotly的高级交互功能,包括点击事件处理.动态数据 ...