©作者 | 范少华
单位 | 北邮 GAMMA Lab
近日,北京邮电大学 GAMMA Lab 与清华大学合作的论文“Generalizing Graph Neural Networks on Out-Of-Distribution Graph”被人工智能顶级期刊 IEEE TPAMI(影响因子 23.6)接收,图神经网络的分布外泛化能力决定了其在实际应用中的稳定性,是近年来的研究热点,该论文的初始版本于 2021 年 11 月放于 arXiv (https://arxiv.org/abs/2111.10657),是早期将因果方法与图神经网络结合解决图分布外泛化问题的文章之一。
本文将介绍图分布外泛化的关键问题、解决方法、以及未来研究工作。
论文标题:
图神经网络在分布外图上的泛化
论文链接:
http://www.shichuan.org/doc/157.pdf代码链接:https://github.com/googlebaba/StableGNN
目前提出的图神经网络 (GNN) 方法没有考虑训练图和测试图之间的不可知偏差,从而导致 GNN 在分布外(OOD)图上的泛化性能变差。导致 GNN 方法泛化性能下降的根本原因是这些方法都是基于 IID 假设。在此条件下,GNN 模型倾向于利用图数据中的虚假相关进行预测。但是,这样的虚假相关可能在未知的测试环境中改变,从而导致 GNN 的性能下降。因此,消除虚假相关的影响对于实现稳定的 GNN 模型至关重要。为了实现此目的,在本文中,我们强调对于图级别任务虚假相关存在于子图级别单元,并且用因果视角来分析 GNN 模型性能下降的原因。基于因果视角的分析,我们提出了一个统一的因果表示框架用于稳定 GNN 模型,称之为 StableGNN。这个框架的主要思想是首先利用一个可微分的图池化层提取图的高层语义特征,然后借助因果分析的区分能力来帮助模型摆脱虚假相关的影响。因此,GNN 模型可以更加专注于有区分性的子结构和标签之间的真实相关性。我们在具有不同偏差程度的仿真数据和 8 个真实的 OOD 图数据上验证了我们方法的有效性。此外,可解释性实验也验证了 StableGNN 可以利用因果结构做预测。本质上说,对于一般的机器学习方法,当遭受分布偏移问题时,准确率下降的根本原因是不相关特征和类别标签之间的虚假相关导致的。这个虚假相关根本上是由不相关特征和相关特征的意外的相关性导致的。而对于本文研究的图级别任务,由于图的性质通常由子图单元决定(比如,在分子图中,原子和化学键团表示其功能单元),所以我们定义一个子图单元可以是一个对于标签相关的或者不相关的特征单元。如图 1 所示,以’‘房子’‘模体分类任务为例,其中图的标签表示一个图中是否有“房子”模体。GCN 模型是在“房子”模体和“星星”模体高度相关的训练图上训练的。在这个数据上,“房子”模体和“星星”模体将会高度相关。这个意料之外的相关性将会导致“星星”模体的结构特征和“房子”标签的虚假相关。图 1 的第二列展示了用于 GCN 预测的最重要的子图可视化结果 (由 GNNExplainer 产生)。由结果可知,GNN 倾向于利用星星模体做预测。然而当遭遇没有“星星”模体的图,或者其他模体(比如,"钻石"模体)和星星模体在一起时,GCN 模型被证明容易产生错误的结果。▲ 图1. "房子"模体分类例子方法
所提出框架的基本想法是设计一个因果表示学习方法来抽取有意义的图高层语义变量然后估计他们对于图级别任务的真实因果效应。如图 2 所示,所提出的模型框架主要分为两个部分:图高层语义表示学习模块和因果变量区分模块。
▲ 图2. StableGNN的模型框架到目前为止变量学习部分学习的变量可能是虚假相关,在本节中,我们首先分析以因果视角分析导致 GNN 模型性能下降的原因,然后提出一个因果变量区分正则化器(CVD)。
以因果视角重视GNN
我们的目标是学习到一个分类器 基于相关的特征 Z。为了达到这个目的,我们需要区分学习到的表示 哪个是属于稳定特征 Z 哪个是属于不稳定特征 M。Z 和 M 的主要区别是对 Y 有没有因果效应。对于一个图级别的分类任务,在学习到节点表示之后,他将会被送到一个分类器里来预测他的标签。这个预测过程可以表示为图 3(a),其中 T 是处理变量,Y 是输出预测值,X 是混淆变量。路径 表示 GNN 的目标是估计一个学习到的表示变量 T 到 Y 的因果效应。同时其他变量将会被视为混淆变量 X。由于子图之间存在虚假相关,因此他们学习到的表示之间也存在虚假相关。因此,存在一个在 X 和 T 之间的路径。并且由于 GNN 同样也使用 confounder 做预测,所以存在一条路径 。因此,这两条路径形成一个 X 到 T 的后门路径 (i.e., ),从而导致 T 和 Y 之间的虚假相关。这个虚假相关将会改变处理变量和标签的之间真实相关性,并且在测试的时候会改变。在这种情景下,目前的 GNN 方法不能准确的评估子图的因果效应,因此 GNN 的性能可能会衰减。混淆平衡技术通常被用来评估变量的因果效应,但是他们通常针对某一变量是由单个维度的特征组成的数据,我们要处理的数据是是多个高维变量组成的,因此,我们提出一种多变量多维度的混淆变量平衡技术,如图 3b 所示:其中 是第 k 个混淆变量的 embedding 矩阵。▲ 图3. GNN的因果视角
重加权HSIC但是上述的混淆变量平衡技术主要针对的是二元处理变量,我们需要处理的高维处理变量。基于混淆平衡技术主要目的是去除处理变量和混淆变量之间的关联,我以我们考虑采用 HSIC 来度量高维变量之间的关联,同时提出采用样本加权的方式去除高维变量之间的关联,方法如下:对于两个变量 U 和 V,我们首先采用随机初始化的样本权重来重加权它们:然后我们可以得到加权的 HSIC:
对于去除所有变量之间的相关性,我们优化如下的全局高维变量去相关项:
实验
首先,相较于基模型我们都取得了比较大的提升效果,证明了我们是个有效的框架。其次,在偏差程度越大的数据上提升效果越明显,证明了我们的方法可以有效对抗数据偏移产生的分布外效果下降的问题。最后,我们的模型相较于 GCN/GraphSAGE 都有明显的提升,证明了我们的方法是一个灵活的框架可以提升现有模型的效果。图 4 是一些可解释性的例子,也能很好的说明我们的模型可以利用因果结构进行预测。
▲ 图4. GCN和StableGCN的可解释性例子图 4 是 MUTAG 数据集上的可解释性实验。蓝色,绿色,红色和黄色分别代表 N,H,O,C 原子。由GNNExplainer 产生的最重要的子图被标为黑色。StableGNN 正确的确定了功能团 NO2 和 NH2,这些功能团被认为是对 Mutagenic 有决定性作用的,而其他方法不能找到有解释性的子图做预测。
▲ 图4. MUTAG数据集上的可解释性实验此外,我们认为本文开启了一个在图数据上进行因果表示学习的方向。本文的最重要的贡献是提出了一个通用的因果表示框架:图高层变量表示学习和因果变量区分,这两个部分都可以为任务而特殊的设计。比如,多通道的滤波器可以被用来学习图上的不同的信号到子空间里。然后对于一些数据也许在高层变量之间存在这更复杂的因果结构,因此发现这些因果结构对于重构原始数据生成过程将会更有效。
引用
[1] Shaohua Fan, Xiao Wang, Chuan Shi, Peng Cui, Bai Wang. Generalizing Graph Neural Networks on Out-Of-Distribution Graphs. IEEE TPAMI 2023[2]R. Ying, D. Bourgeois, J. You, M. Zitnik, and J. Leskovec, “Gnnex-plainer: Generating explanations for graph neural networks,” NeurIPS, 2019.[3] X. Zhang, P. Cui, R. Xu, L. Zhou, Y. He, and Z. Shen, “Deep stablelearning for out-of-distribution generalization,” CVPR, 2021, pp.5372–5382[4]R. Ying, J. You, C. Morris, X. Ren, W. L. Hamilton, and J. Leskovec,“Hierarchical graph representation learning with differentiablepooling,” NeurIPS, 2018.[5] B. Schölkopf, F. Locatello, S. Bauer, N. R. Ke, N. Kalchbrenner,A. Goyal, and Y. Bengio, “Toward causal representation learning,”Proceedings of the IEEE, vol. 109, no. 5, pp. 612–634, 2021.更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧