图神经网络(GNN)中的图注意力网络(GAT)层间注意力机制与多头注意力融合策略详解
1. 问题描述
图注意力网络(GAT)通过注意力机制为邻居节点分配不同的权重,克服了GCN中权重固定的限制。但在深层GAT中,层间注意力机制如何协作?多头注意力如何融合?这是理解GAT表达能力的关键。
2. 单层GAT的注意力机制回顾
- 输入:节点特征集合 \(\mathbf{h} = \{\vec{h}_1, \vec{h}_2, ..., \vec{h}_N\}\),其中 \(\vec{h}_i \in \mathbb{R}^F\)
- 注意力系数计算:
\[ e_{ij} = a(\mathbf{W}\vec{h}_i, \mathbf{W}\vec{h}_j) \]
其中 \(\mathbf{W} \in \mathbb{R}^{F' \times F}\) 是共享权重矩阵,\(a\) 是单层前馈网络。
- 归一化注意力权重:
\[ \alpha_{ij} = \frac{\exp(\text{LeakyReLU}(e_{ij}))}{\sum_{k \in \mathcal{N}_i} \exp(\text{LeakyReLU}(e_{ik}))} \]
- 输出特征:
\[ \vec{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}_i} \alpha_{ij} \mathbf{W} \vec{h}_j\right) \]
3. 多层GAT的层间注意力协作机制
- 逐层权重更新:每一层的注意力权重 \(\alpha_{ij}^{(l)}\) 基于上一层输出特征 \(\mathbf{H}^{(l-1)}\) 动态计算
\[ \alpha_{ij}^{(l)} = f_{\text{att}}(\mathbf{W}^{(l)} \vec{h}_i^{(l-1)}, \mathbf{W}^{(l)} \vec{h}_j^{(l-1)}) \]
- 感受野扩展:浅层关注局部邻居,深层通过堆叠聚合高阶邻居信息
- 第1层:聚合1-hop邻居
- 第2层:聚合2-hop邻居(通过邻居的邻居)
- 注意力权重的演化:深层网络可学习到更复杂的依赖关系(如社区结构、节点重要性)
4. 多头注意力的三种融合策略
- 拼接(Concat)策略(常用于中间层):
\[ \vec{h}_i' = \|_{k=1}^K \sigma\left(\sum_{j \in \mathcal{N}_i} \alpha_{ij}^k \mathbf{W}^k \vec{h}_j\right) \]
- 输出维度:\(K \times F'\)(特征拼接)
- 优点:保留各头独立性,增强表达能力
- 平均(Average)策略(常用于输出层):
\[ \vec{h}_i' = \sigma\left(\frac{1}{K} \sum_{k=1}^K \sum_{j \in \mathcal{N}_i} \alpha_{ij}^k \mathbf{W}^k \vec{h}_j\right) \]
- 输出维度:\(F'\)(特征取平均)
- 优点:稳定训练,减少方差
- 加权融合策略(进阶方法):
\[ \vec{h}_i' = \sum_{k=1}^K \beta_k \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij}^k \mathbf{W}^k \vec{h}_j \right) \]
- 其中 \(\beta_k\) 是可学习的头重要性权重
- 优点:自适应调整多头贡献
5. 层间与多头协作的数学表达
设GAT有 \(L\) 层,每层 \(K\) 个头:
- 第 \(l\) 层第 \(k\) 头的输出:
\[ \vec{h}_{i}^{(l,k)} = \sum_{j \in \mathcal{N}_i} \alpha_{ij}^{(l,k)} \mathbf{W}^{(l,k)} \vec{h}_j^{(l-1)} \]
- 层内多头融合(以拼接为例):
\[ \vec{h}_i^{(l)} = \|_{k=1}^K \text{ELU}\left(\vec{h}_{i}^{(l,k)}\right) \]
- 层间传递:\(\mathbf{H}^{(l)}\) 作为 \(\mathbf{H}^{(l+1)}\) 的输入
6. 实际应用中的设计选择
- 残差连接:深层GAT添加跨层连接 \(\vec{h}_i^{(l)} + \vec{h}_i^{(l-1)}\) 缓解梯度消失
- 注意力dropout:对归一化前的 \(e_{ij}\) 使用dropout增加鲁棒性
- 批量归一化:对 \(\mathbf{W}\vec{h}_i\) 进行归一化稳定训练
7. 实例说明
假设一个3层GAT(每层4头)处理社交网络:
- 第1层:学习直接朋友的注意力(如亲密程度)
- 第2层:聚合朋友的朋友,捕捉社区倾向
- 第3层:融合多阶信息,识别用户影响力
- 最终输出层使用平均融合生成节点嵌入
通过层间协作与多头融合,GAT能够自适应学习复杂图结构中的异构关系,这是其优于传统GCN的核心原因。