图神经网络中的图分类任务与图级表示学习详解
一、问题描述
图分类是图神经网络(GNN)的核心任务之一,目标是对整个图结构进行类别预测(如分子性质分类、社交网络类型识别)。与节点分类不同,图分类需生成图级表示(Graph-Level Representation),即如何将整个图的拓扑信息和节点特征聚合为一个固定维度的向量。这一过程面临两大挑战:
- 图结构不规则:不同图的节点数和连接模式各异,需置换不变性(Permutation Invariance)操作。
- 层次化特征融合:需同时捕获局部子结构(如化学官能团)和全局图模式(如分子环状结构)。
二、图级表示学习的基本流程
图分类任务通常分为三步:
- 节点嵌入生成:通过多层GNN迭代聚合邻域信息,得到每个节点的表示。
- 图级池化:将节点嵌入聚合为图级表示。
- 分类器预测:基于图级表示输出分类结果。
三、节点嵌入生成:消息传递机制
以图卷积网络(GCN)为例,单层GNN的节点更新公式为:
\[H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right) \]
- \(\tilde{A} = A + I\):带自环的邻接矩阵,避免忽略自身特征。
- \(\tilde{D}\):度矩阵,用于归一化邻接矩阵,防止梯度爆炸。
- \(H^{(l)}\):第\(l\)层的节点嵌入矩阵,\(H^{(0)}\)为初始节点特征。
- 多层堆叠后,每个节点的嵌入包含多跳邻域信息(如2层GCN覆盖二阶邻居)。
四、图级池化方法详解
池化操作需满足置换不变性(节点顺序不影响结果),主流方法包括:
1. 全局池化(Global Pooling)
- 均值/求和池化:直接对节点嵌入取均值或求和:
\[ h_G = \frac{1}{n}\sum_{i=1}^n h_i \quad \text{或} \quad h_G = \sum_{i=1}^n h_i \]
-
优点:简单高效,适用于节点重要性均匀的图。
-
缺点:忽略节点差异(如关键功能节点可能被噪声节点稀释)。
-
最大池化:对每个特征维度取最大值:
\[ h_G[d] = \max_{i=1}^n h_i[d] \]
- 优点:突出显著特征,适合关键节点主导的分类任务。
2. 层次化池化(Hierarchical Pooling)
全局池化仅操作一次,可能丢失局部结构。层次化池化通过逐步压缩图结构保留多尺度信息:
- DiffPool原理:
- 学习软分配矩阵 \(S^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}}\),将第\(l\)层的\(n_l\)个节点聚类为\(n_{l+1}\)个超节点。
- 超节点的嵌入和邻接矩阵更新为:
\[ H^{(l+1)} = S^{(l)^T} H^{(l)}, \quad A^{(l+1)} = S^{(l)^T} A^{(l)} S^{(l)} \]
- 最终对最顶层的超节点嵌入进行全局池化。
- 挑战:分配矩阵需通过额外GNN学习,计算复杂度高。
3. 注意力池化(Attention-Based Pooling)
- SortPooling:对节点嵌入按某一维度排序后截取前\(k\)个节点,输入1D卷积层提取特征。
- SAGPool:通过自注意力分数选择保留重要节点:
\[ \text{得分} = \text{GNN}(A, H), \quad \text{保留top-k节点} \]
保留节点间的边重构子图,迭代池化以捕获层次结构。
五、图级表示学习的优化策略
- 跳跃连接(Skip-Connection):
- 将不同GNN层的节点嵌入拼接后再池化(如Jumping Knowledge Networks),融合多尺度特征。
- 图同构约束:
- 通过图同构网络(GIN)等表达能力强的GNN生成节点嵌入,确保池化后能区分不同拓扑的图。
- 辅助任务:
- 联合训练图重构、节点聚类等自监督任务,提升表示质量。
六、实例分析:分子性质预测
以预测分子毒性为例:
- 输入:分子图(原子为节点,化学键为边)。
- 节点嵌入:3层GCN生成原子特征(考虑原子类型、价键信息)。
- 池化:全局求和池化(因分子性质常与原子数量相关)。
- 分类器:池化结果输入MLP输出毒性概率。
七、关键挑战与前沿方向
- 可解释性:通过注意力权重或子结构识别解释分类依据(如识别毒性官能团)。
- 动态图池化:处理节点/边随时间变化的图。
- 无监督图级表示:通过对比学习(如GraphCL)减少对标签的依赖。
通过上述步骤,图级表示学习将复杂图结构转化为判别性向量,为图分类任务提供坚实基础。实际应用中需根据图特点(如规模、密度)选择池化方法,并结合领域知识优化特征融合策略。