图神经网络(GNN)中的图分类任务与图级表示学习详解
一、问题描述
图分类是图神经网络(GNN)中的一项核心任务,其目标是对整个图结构(Graph-Level)进行预测或分类。与节点分类(预测图中单个节点的标签)或边预测(预测节点间是否存在连接)不同,图分类需要将整个图映射到一个类别标签。例如,在分子性质预测中,每个分子可表示为一个图(原子为节点,化学键为边),需判断该分子是否具有特定毒性或活性。
关键挑战:如何将图中所有节点的信息有效聚合,生成一个能够代表全图结构的固定维度的嵌入向量(即图级表示),并基于此向量进行分类。
二、图级表示学习的核心步骤
图分类流程通常分为三步:
- 节点嵌入生成:通过GNN层迭代聚合邻居信息,学习每个节点的低维向量表示。
- 图级池化:将所有节点嵌入聚合为单一图嵌入向量。
- 分类器预测:将图嵌入输入全连接层和Softmax函数,输出分类结果。
以下逐步解析关键环节:
步骤1:节点嵌入生成(消息传递机制)
- 设图 \(G = (V, E)\),其中 \(V\) 为节点集合,\(E\) 为边集合。
- 每个节点 \(v\) 初始特征为 \(h_v^{(0)}\)。
- 在第 \(k\) 层GNN中,节点通过聚合邻居信息更新自身表示:
\[ h_v^{(k)} = \text{UPDATE}\left( h_v^{(k-1)}, \text{AGGREGATE}\left( \{ h_u^{(k-1)} \mid u \in \mathcal{N}(v) \} \right) \right) \]
其中 \(\mathcal{N}(v)\) 是节点 \(v\) 的邻居集合,AGGREGATE 常用求和、均值或最大值池化,UPDATE 可通过神经网络(如MLP)实现。
- 经过 \(K\) 层迭代后,每个节点获得包含 \(K\)-跳邻居信息的嵌入 \(h_v^{(K)}\)。
步骤2:图级池化(关键难点)
直接拼接所有节点嵌入会导致向量维度随节点数变化,无法输入固定维度的分类器。常用池化方法包括:
2.1 全局池化
- 求和池化:图嵌入 \(h_G = \sum_{v \in V} h_v^{(K)}\)。
- 均值池化:图嵌入 \(h_G = \frac{1}{|V|} \sum_{v \in V} h_v^{(K)}\)。
- 最大值池化:图嵌入 \(h_G = \max_{v \in V} h_v^{(K)}\)(按维度取最大值)。
- 优缺点:计算简单,但可能丢失图中局部结构信息(如所有节点嵌入求和后无法区分环状与链状结构)。
2.2 层次化池化
为保留局部结构,引入分层池化操作:
- 图粗化:将相似节点聚类为超节点,形成规模更小的新图。
- 示例(DiffPool):
- 学习一个软分配矩阵 \(S \in \mathbb{R}^{n \times m}\),将 \(n\) 个节点映射到 \(m\) 个簇(\(m < n\))。
- 新图的节点特征为 \(X' = S^T H\),邻接矩阵为 \(A' = S^T A S\)。
- 重复多轮后,最终用全局池化生成图嵌入。
- 优点:可捕获图中层次化社区结构。
- 缺点:分配矩阵的学习增加了计算复杂度。
2.3 注意力池化
- 如Self-Attention Pooling(SAGPool):
- 计算每个节点的注意力分数,保留得分最高的部分节点,剔除低分节点。
- 通过注意力权重加权聚合,保留重要子结构。
步骤3:分类与损失函数
- 将图嵌入 \(h_G\) 输入全连接层: \(z = W h_G + b\)。
- 使用Softmax输出概率分布: \(\hat{y} = \text{softmax}(z)\)。
- 损失函数常采用交叉熵损失: \(L = -\sum_{i} y_i \log \hat{y}_i\),其中 \(y\) 为真实标签。
三、代表性模型简析
- GraphSAGE(归纳式学习):
- 通过采样邻居生成节点嵌入,图分类时直接使用全局最大池化。
- GIN(图同构网络):
- 证明当GNN的聚合函数为单射时,其区分图结构的能力最强。
- 图分类公式: \(h_G = \text{MLP} \left( \sum_{v \in V} h_v^{(K)} \right)\),强调求和池化的重要性。
四、关键设计原则
- 过平滑问题:GNN层数过多时,节点嵌入趋于相似,导致图级表示失真。需平衡层数与感受野。
- 图结构敏感性:池化操作需保留图的拓扑特征(如连通性、子图模式)。
- 计算效率:层次化池化虽效果好,但可能成为训练瓶颈,需权衡性能与资源。
总结:图分类的核心在于通过消息传递和池化操作将不规则图结构转化为有意义的固定维向量。设计需结合任务特性选择池化策略,确保图级表示同时包含节点属性与拓扑信息。