CS224W 图机器学习16 PART1:图上下文学习框架

由于本小节为论文研讨课,因此本文将以论文阅读笔记的形式展开

《PRODIGY: Enabling In-context Learning Over Graphs》

摘要:

  • 本文提出了一种名为 PRODIGY(Pretraining Over Diverse In-Context Graph System)的预训练框架,该框架借鉴了大模型的训练思路,先通过图提示(Graph Prompt)来表示图的上下文学习任务,再使用名为邻居匹配(neighbor matching)进行自监督预训练,使得图神经网络学习节点/边的表示
  • 最终的实验也表示了该框架的有效性,在未训练过的数据上也具备较好的预测表现; 预测性能相比于普通的分类模型平均提高 18%,在有限数据微调的情况下预测性能平均提高 32.6%

原始论文

1 图的上下文学习

大语言模型(LLMs)的上下文学习:经过预训练的 LLMs 能根据文本提示或任务示例来直接对下游任务进行预测,而无需更新模型权重,这种能力也被称为上下文学习(in-context learning,ICL)或语境学习

不同于 LLMs 的上下文学习,过往图模型的预训练过程一般都满足图编码器的一般范式,即先使用有监督任务(比如节点预测或图分类)进行预训练,再根据具体的下游任务微调分类头(微调的过程需要较多的下游任务数据)

本节课提出的 PRODIGY 框架,是首个支持图上下文学习的框架,以 PRODIGY 框架预训练的模型可以在不同图提示(Graph Prompt)的作用下,无需微调即可适应不同的下游任务,并且性能比微调模型更好

图上下文学习的难点:

  1. 如何在图机器学习中模拟 LLMs 中的自然语言提示?图提示(Graph Prompt)
  2. 如何设计模型架构和预训练目标,让模型实现跨任务上下文学习能力?邻居匹配自监督

2 图提示的表示

上下文学习的 LLMs 在处理下游任务时需要在文本提示中给出一定的示例,来方便 LLMs 理解下游任务的性质,并更好地给出预测结果。这一过程也被称为少样本提示(Few-shot Prompting)

  • 图 (A):上下文学习的图模型也需要相应的少样本提示过程,只是其提示示例的内容是以图数据的形式,这种提示示例也被称为图提示(Graph Prompt);具体来说,图提示会包含数据图(Data Graphs)和任务图(Task Graphs)两部分
  • 图 (B):数据图(Data Graphs)包含了输入节点及其上下文信息,其具体形式一般为节点邻域检索抽样得到的子图(显式)或基于图编码器得到的嵌入表示(隐式);其目的是尽可能收集与输入节点有关的信息
  • 图 (C):任务图(Task Graphs)用于捕获输入和标签之间的联系,实现多种下游任务类型的适配
  • 在数据节点和标签节点之间的边有三种类型,输入的查询数据节点 $Q$ 与所有标签相连(黄色边);示例数据节点 $S$ 与真实标签节点的边为绿色(标记为 T),与其他标签节点的边为红色(标记为 F)
  • 不同的标签节点代表不同的预测任务,还可以在标签节点特征中附件任务信息和指令以提高通用性

3 基于图提示的模型架构

数据图(Data Graphs)的消息传递:

  • 应用 GNN 模型 $M_D$ (比如 GCN 或 GAT)来学习每个节点的嵌入表示 $E$,其嵌入维度为 $d$
  • 对于每个数据图,针对其包含的所有节点嵌入进行聚合池化操作,得到每个数据图的嵌入 $G$
  • 对于节点分类问题,直接使用更新后的节点嵌入作为最终的数据图嵌入:$G_i=E_{v_i}$
  • 对于链接预测问题,先聚合链接 $i$ 上所有节点的嵌入,再拼接起止节点的嵌入 $E_{v_1}$ 和 $E_{v_2}$,最后添加额外的线性投影层保持最终的嵌入维度为 $d$:$G_i=W^T(E_{v_1}||E_{v_2}||max(E_i))+b$
  • 其中 $W\in R^{3d\times d}$ 是可学习的权重矩阵,$b$ 是可学习的偏差项

注意,数据图的消息传递中不存在查询数据 $Q$ 与示例数据 $S$ 之间的通信

任务图(Task Graphs)的消息传递:

  • 假设输入的示例数据图嵌入为 $G$,应用 GNN 模型 $M_T$ 来更新任务图上的节点表示 $H$
  • 任务图上的数据节点嵌入 $H_i$ 用 $M_D$ 模型输出的节点嵌入进行初始化,标签节点嵌入 $H_ji$ 则进行随机初始化(可能附带额外的标签信息),数据节点和标签节点之间的每条边用二值化特征 $e_{ij}$ 来表示边的类型
  • 任务图中节点和边嵌入的更新则依赖基于 Attention 机制的 GNN,公式如下:

$$ \begin{array}{c}{{\beta_{i j}=M L P\left(W_{q}^{T}H_{i}^{l}||W_{k}^{T}H_{j}^{l}||e_{i j}\right)}} \\ {{\alpha_{i j}=\frac{\exp(\beta_{i j})}{\sum_{k\in{\cal N}(i)\cup{i}}\exp(\beta_{i k})}}} \\ {{H_{i}^{l+1}=R e L U\left(B N\left(H_{i}^{l}+W_{o}^{T}\sum_{j\in{\cal N}(i)\cup{i}}\alpha_{i j}W_{v}^{T}H_{j}^{l}\right)\right)}}\end{array} $$

  • 最终任务分类的判定则依赖于每组查询中数据表示和标签表示间的余弦相似度:

$$ O_{i}=[\mathrm{cosine\similarity}(H\{x_{i}},H_{y}),\forall,y\in{\mathcal{Y}}] $$

4 图的上下文预训练目标

为了方便模型直接应用于下游任务,需要以图提示的形式显式构建上下文预训练任务

本文提出了两种上下文预训练任务方法:邻居匹配和多任务

  1. 基于邻居匹配的自监督预训练任务:

  • 邻居匹配的任务目标:输入包括查询节点、采样一组节点及其 2-hop 邻节点,任务目标/期望输出是对查询节点进行分类,即判断该查询节点是否为采样节点的 2-hop 邻节点
  • 基于邻居匹配的预训练的模型非常适合节点分类的下游任务;对于下游任务为链路预测的情况,可考虑将预训练目标中的邻居匹配问题调整为边匹配问题;
  1. 基于多任务的有监督预训练任务
  • 当预训练图中的节点或边具备标签时,可以考虑利用这些信息来执行有监督的预训练
  • 有监督预训练能直接应用于格式类似的下游任务,但可能不兼容这预训练中不存在的标签

在预训练时还通过节点删除(DropNode)或特征屏蔽(MaskNode)等方式进行图数据增强

无论是自监督的邻居匹配预训练,还是有监督的多任务预训练,都采用交叉熵作为损失函数

5 PRODIGY 框架的实验分析

数据集说明:

  • 预训练数据集 1: MAG240M,一个具有 1.22 亿个节点和 13 亿条边的大规模引用网络
  • 预训练数据集 2: Wiki ,由 Wikipedia 构建的知识图谱(KG),有 480 万个节点和 590 万条边
  • 下游任务数据集(对应节点或边的分类任务): arXiv 、 ConceptNet、 FB15K-237 、 NELL

模型初始化:

  • 数据图:主要使用预训练的 RoBERTa 来进行引文网络中的节点特征初始化
  • 任务图:主要使用预训练的 MPNet 来进行节点和边特征的初始化

不同预训练模型的下游任务(节点分类)表现:

Classes NoPretrain Contrastive PG-NM PG-MT PRODIGY Finetune
3 33.16 ± 0.30 65.08 ± 0.34 72.50 ± 0.35 65.64 ± 0.33 73.09 ± 0.36 65.42 ± 5.53
5 18.33 ± 0.21 51.63 ± 0.29 61.21 ± 0.28 51.97 ± 0.27 61.52 ± 0.28 53.49 ± 4.61
10 9.19 ± 0.11 36.78 ± 0.19 46.12 ± 0.19 37.23 ± 0.20 46.74 ± 0.20 30.22 ± 3.77
20 4.72 ± 0.06 25.18 ± 0.11 33.71 ± 0.12 25.91 ± 0.12 34.41 ± 0.12 17.68 ± 1.15
40 2.62 ± 0.02 17.02 ± 0.07 23.69 ± 0.06 17.19 ± 0.08 25.13 ± 0.07 8.04 ± 3.00
  • PG-NM 表示仅依赖邻居匹配的预训练方法,PG-MT 表示多任务的预训练方法
  • PRODIGY 框架的提升表现在纯自监督的 PG-NM,也能更好地适用于下游任务
  • PRODIGY 框架的准确率明显高于其他方法,比作为最佳基线的对比学习(Contrastive)高 28.6%~48%,在所有方法平均改进了 77%,也优于目前的 SOTA 微调方法

不同预训练模型的下游任务(边分类)表现:

Classes NoPretrain Contrastive PG-NM PG-MT PRODIGY Finetune
4 30.4 ± 0.63 44.01 ± 0.61 46.94 ± 0.61 51.78 ± 0.63 53.97 ± 0.63 53.85 ± 9.29
5 33.54 ± 0.61 81.35 ± 0.58 80.35 ± 0.57 89.15 ± 0.46 88.02 ± 0.48 82.01 ± 12.83
10 20.0 ± 0.35 70.88 ± 0.48 71.68 ± 0.45 82.26 ± 0.40 81.1 ± 0.39 71.97 ± 6.16
20 9.2 ± 0.18 59.8 ± 0.35 59.9 ± 0.35 73.47 ± 0.32 72.04 ± 0.33 64.01 ± 4.66
40 2.5 ± 0.08 49.39 ± 0.23 46.82 ± 0.21 58.34 ± 0.22 59.58 ± 0.22 57.27 ± 3.33
5 33.44 ± 0.57 84.08 ± 0.54 80.53 ± 0.58 84.79 ± 0.51 87.02 ± 0.44 87.22 ± 12.75
10 18.82 ± 0.31 76.54 ± 0.45 72.77 ± 0.48 78.5 ± 0.44 81.06 ± 0.41 71.90 ± 5.90
20 7.42 ± 0.16 66.56 ± 0.35 62.82 ± 0.36 69.82 ± 0.34 72.66 ± 0.32 66.19 ± 8.46
40 3.04 ± 0.07 57.44 ± 0.24 49.59 ± 0.22 53.55 ± 0.23 60.02 ± 0.22 55.06 ± 4.19

其他结论:

  • 消融实验表明,属性预测对 PG-NM 的影响最大,删除会降低 7%的性能
  • 随着提示样本和训练集的增加,PRODIGY 的性能与稳定性均优于普通微调方法
  • 由于 PRODIGY 预训练目标的复杂性,PRODIGY 能随着训练集的增多而持续提升

往年同期文章