CS224W 图机器学习15:GNN 拓展到大型图

1 大型图应用与难点

当前的大型图应用场景:

  1. 推荐系统(亚马逊、Youtube、Pinterest 等):用户规模在 100M~1B,产品/视频规模在 10M~1B,任务包括商品推荐或用户分类
  2. 社交网络(Facebook、X、Instagram 等):用户规模在 300M~3B,任务包括好友推荐或用户属性预测
  3. 学术网络(微软学术图谱):作者或论文规模在 120M,任务包括论文分类、合作作者推荐、论文引用推荐
  4. 知识图谱(Wikidata、Freebase 等):实体规模在 80~90M,任务包括图谱补全、知识推理

GNN 应用到大型图的难点:

  1. 随机梯度下降法(SGD)失效:在小批次 SGD 训练时,大规模图采样得到的节点往往彼此隔离,进而影响 GNN 的训练(GNN 通过聚合邻近节点特征来生成节点嵌入)
  2. 全批量梯度法失效: GPU 计算快但内存小(10GB~20GB);整个图和特征无法直接加载到 GPU 中

GNN 扩展到大规模图的方法:

  1. 在子图上执行消息的传递,这样子图可以完整加载到 GPU 内并用于训练;比如 Neighbor Sampling 或 Cluster-GCN
  2. 简化 GNN 过程,转化为一种可以在 CPU 中高效执行的特征预处理操作;比如 Simplified GCN

2 邻域采样的 GraphSAGE

对于单个节点的嵌入表示计算,只需要 K 邻域节点来定义计算图

基于 SGD 策略训练 K 邻域的 GNN:

  • 随机采样 $M$ 个节点,获取每个节点的 $K$ 邻域并构建计算图
  • 基于计算图生成节点的嵌入表示,最后计算 $M$ 个节点的平均损失 $l_{sub}(\theta)$
  • 基于平均损失,使用 SGD 策略更新模型的参数:$\theta <- \theta-\nabla l_{sub}(\theta)$

当前训练策略的问题:

  1. 每个节点嵌入需要聚合大量的信息,并且存在大量的计算冗余
  2. 随着邻域 $K$ 的增加,计算图的复杂度呈指数级增强(尤其是度高的节点)

调整训练策略:构建计算图时,每个节点最多只采样 $H$ 个邻节点

  • 假设第 $k$ 层的节点,最多采样 $H_k$ 个邻节点,则 K 层 GNN 计算图最多涉及 $\Pi_{k=1}^KH_k$ 个节点
  • 较小的 $H$ 能实现更高效的邻居信息聚合,也但也会降低训练的稳定性(聚合信息的方差大)
  • 计算时间依然随着邻域 $K$ 的增加呈指数级增加;每添加一个 GNN 层,计算成本增加 $H$ 倍

邻节点的抽样策略:

  • 随机抽样虽然效率高,但可能采到"不重要"的节点(因此效果一般不是最佳的)
  • 带重启的随机游走,每次游走都有一定概率返回到初始节点(实际实践效果更好)

邻域采样策略,本质是对计算图进行修剪/子采样来提高计算效率

3 子图逐层更新的 Cluster-GCN

在 full-batch GNN 中,所有节点的嵌入使用上一层的嵌入进行同时更新: $$ h_v^{(\ell)}=COMBINE\left(h_v^{(\ell-1)},AGGR\left( \left\{\boldsymbol{h}_u^{(\ell-1)}\right}_{u\in N(v)}\right)\right) $$

  • 假设边的数量为 $E$,GNN 的每一层需要计算 $2E$ 次信息
  • 对于 K 层的 GNN 只需要计算 $2K\times E$ 次信息,计算效率高

逐层的节点嵌入更新能借助历史计算的嵌入,减少了大量的计算冗余

由于 GPU 的内存有限,逐层的节点嵌入更新不适用于大规模图的情况

Cluster-GCN 的训练过程:

  • 首先,Cluster-GCN 使用图社区检测方法(例如 Louvain、METIS)将大图 $G$ 划分为多个节点分组 $V_1,...,V_C$,然后针对每个节点分组构建诱导子图 $G_,...,G_C$
  • 然后,Cluster-GCN 会随机选择一个子图并应用 GNN 的节点逐层更新,获得每个节点的嵌入表示;最后根据该子图的节点嵌入表示计算平均损失,指导参数的更新

子图的构建过程中,应尽可能保留原始图的边连通性结构

Cluster-GCN 的问题:

  1. 诱导子图忽略了组间的链接,导致其他组(子图)的信息丢失,影响 GNN 性能
  2. 图社区检测方法将相似的节点归纳为一组,但每组节点仅覆盖了数据的集中局部
  3. 采样的子图不够多元化,无法表示完整的图结构(不同节点组之间的波动较大,导致计算的梯度不可靠/方差高,影响 SGD 的收敛速度)

Cluster-GCN 的改进:每个小批次考虑并聚合多个节点组

  • 减少每个节点组的规模,同时对多个节点组的采样和聚合
  • 基于多个节点组诱导构建子图,使得模型考虑到组间链接
  • 该方式能改善子图的多元化,但仍存在梯度估计的系统性偏差

假设节点数为 $M$,节点的平均度为 $D_{avg}$,则 Cluster-GCN 的计算复杂度为 $O(K\cdot M\cdot D_{avg})$ ,计算效率远高于邻域采样的 GraphSAGE,尤其是在 $K$ 很大的时候(线性<<指数级)

4 简化模型结构的 SimplGCN

前置知识:LightGCN 通过消除非线性来简化 GCN

SimplGCN 的处理方式类似于 LightGCN,但也有不同之处:

  • SimplGCN 也消除了 GCN 中的非线性激活部分;但与 LightGCN 不同的是,SimplGCN 使用的邻接矩阵 $\tilde{A}$ 会考虑自环的情况,同时遵循原始的 GCN 计算过程
  • SimplGCN 假设输入的节点嵌入表示矩阵 $E$ 是固定的(不需要学习),因此 K 层的 GCN 计算可以一次性完成(该过程可以看作一次预处理步骤):$\tilde{E}=\tilde{A}^KE$

SimplGCN 的性能分析:

  • 简化后的 GCN 的具备极强的可拓展性,其输出 $\tilde{E}$ 配合线性模型的表现接近 GCN
  • 与邻域采样或 Cluster-GCN 相比,SimplGCN 的计算效率更高(常数级的计算复杂度),其预处理过程属于稀疏矩阵的向量积,可以在 CPU 上高效执行
  • SimplGCN 的缺点就是模型表现力不太够,在生成节点嵌入表示时缺少非线性

SimplGCN 在半监督节点分类任务中,表现与原始 GNN 相当,这可能是因为许多节点分类任务表现出同质结构(相互引用的论文更可能属于同一主题,相互关注的网友更可能喜欢同一部电影)

往年同期文章