论文标题:On the Robustness of Open-World Test-Time Training: Self-Training with Dynamic Prototype Expansion
论文链接:https://openaccess.thecvf.com/content/ICCV2023/html/Li_On_the_Robustness_of_Open-World_Test-Time_Training_Self-Training_with_Dynamic_ICCV_2023_paper.html
代码:https://github.com/Yushu-Li/OWTTT
引用:Li Y, Xu X, Su Y, et al. On the Robustness of Open-World Test-Time Training: Self-Training with Dynamic Prototype Expansion[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023: 11836-11846.
将深度学习模型推广到低延迟的未知目标域分布,已经激发了对测试时间训练/自适应(test-time training/adaptation,TTT/TTA)的研究。现有的方法往往侧重于在精心管理的目标域数据下提高测试时间的训练性能。
然而,本研究指出,在目标域受到强大的OOD数据(out-of-distribution,OOD,即超出正常分布范围的数据)污染时,许多最先进的方法无法保持其性能水平。这种情况被称为开放世界测试时训练( open-world test-time training,OWTTT)。这一失败主要是因为这些方法无法有效地区分强OOD数据样本和常规的弱OOD数据样本。
为提高OWTTT的稳健性,研究人员首先开发了一种自适应的强OOD数据剪枝方法,以提高自训练TTT方法的效力。此外,他们提出了一种方法,可以动态扩展原型以更好地表示强OOD数据样本,从而实现更好的弱/强OOD数据分离。最后,他们采用分布对齐来规范自训练,在5个OWTTT基准测试中实现了最先进的性能水平。
本文认为,在现有研究中被忽略的测试时训练问题,即开放世界测试时训练(OWTTT),可能受到强OOD测试数据的影响。我们证明,如果没有特殊处理,最先进的TTT方法在开放世界协议下难以进行良好的泛化。
本文引入了一种基线方法,采用原型聚类和分布对齐正则化。此外,我们进一步开发了一个强OOD检测器和原型扩展方法,以提高在OWTTT协议下的基线鲁棒性。
本文建立了一个基准用于评估OWTTT协议,覆盖了多种域偏移类型,包括常见的数据污染和风格迁移。我们的方法在建议的基准测试上取得了最先进的性能。
无监督域适应:无监督域适应(Unsupervised Domain Adaptation, UDA)旨在提高模型在目标域数据上的泛化能力,即在目标域没有标记数据的情况下。UDA通常通过学习源域和目标域之间的不变特征(invariant features)来实现,这包括发现目标域中的聚类结构、自监督训练、基于距离的对齐等等。尽管UDA在提高模型对目标域的泛化性能方面取得了相当大的进展,但在适应过程中需要同时访问源域和目标域数据并不总是现实的,例如由于数据隐私问题。
测试时训练:在某些情况下,我们希望已经部署到目标领域的模型能够自动适应新环境,而无需访问源领域数据。考虑到这些需求,为了实现对任意未知目标领域的自适应并减少推断延迟,测试时训练/自适应(Test-Time Training/Adaptation,TTT/TTA)出现了。TTT通常通过三种范式实现:(1)测试数据上的自监督学习,使模型能够适应目标领域,而不考虑任何语义信息 [40, 27]。(2)自训练,强化模型对未标记数据的预测,已被证明对TTT非常有效 [43, 6, 25, 15]。(3)分布对齐,通过调整模型权重以生成与源域具有相同分布的特征 [39, 27]。
开放集域适应:开放集识别的概念首次由[37]引入,指的是在这种设置中,训练的模型需要拒绝来自未知语义类别的测试样本。在领域适应的背景下,ATI [5]提出了开放集域适应(Open-Set Domain Adaptation, OSDA),并通过定义和最大化开放集与闭集之间的距离来实现开放集识别。
测试时训练旨在将源域预训练模型调整为可能受到与源域分布不同的目标域。首先,我们对基于自训练的TTT范式进行概述,按照[39]中的定义进行。具体来说,我们将源域和目标域数据集分别表示为
和,它们各自的标签为和。在封闭世界TTT中,两个标签空间是相同的,而在开放世界为Cs⊆Ct。在测试阶段,在时间戳 t 的一小批测试样本记为
。我们进一步将表示学习网络表示为,将分类器头表示为TTT 是通过更新目标域数据集 Dt 上的表示网络和/或分类器参数来实现的。为了避免 TTT 定义之间的混淆,我们采用了[39]中提出的顺序测试时间训练(sTTT)来进行评估。在 sTTT 协议下,对测试样本进行顺序测试,并在观察到一小批测试样本后进行模型更新。对到达时间戳 t 的任何测试样本的预测不会受到到达 t+k 的任何测试样本的影响。
受到在领域适应任务中发现聚类结构的成功启发,我们将测试时训练形式化为在目标域数据中发现聚类结构。聚类中心即为原型,通过在特征空间中测量测试样本与原型的相似度来实现推断。形式上,我们将源域中的原型表示为
,原型聚类目标定义为最小化以下负对数似然损失:最小化上述目标将允许测试样本嵌套到它们的预测原型附近,远离其他原型。在封闭世界测试时训练情境下[39],原型聚类已经展示出强大的性能[25]。然而,在开放世界测试时训练情境下,原型聚类受到严重挑战,因为可能存在强OOD样本。如果强OOD样本被强制分类到任何源类别中,对于带有噪声标签的样本进行自训练会混淆网络对弱OOD样本的判别能力。
强OOD样本修剪:我们开发了一种无需超参数的方法来剔除强OOD样本,以避免模型权重的负面影响。具体来说,我们为每个测试样本定义一个强OOD分数
:我们观察到强OOD分数受到双峰分布的影响,如图3所示。因此,我们不指定一个固定的阈值,而是将最佳阈值定义为分隔两个分布模态的阈值。
具体来说,该问题可以被形式化为将强OOD分数分成两个聚类,最佳阈值将最小化公式 3 中的簇内变异度:
通过最佳阈值τ*,我们可以识别出强OOD样本,这有两方面的好处。首先,在模型权重更新期间,我们可以将检测到的OOD样本排除在基于源领域原型的自训练之外。其次,在推断阶段,它为我们提供了一种区分弱OOD样本和强OOD样本的方法。
识别强OOD样本并将它们排除在模型权重更新之外并不能保证将弱OOD测试样本与强OOD样本有效分离。受到新颖性检测的成功启发[32],我们提出动态扩展原型池,以包含代表强OOD样本的原型。然后,我们应用自训练,同时使用源域原型和强OOD原型,以在特征空间中的弱OOD样本和强OOD样本之间创造更大的差距。
具体地,我们将强OOD原型集表示为
。当没有关于目标域的先验信息可用时,我们将新原型初始化为空集。随着TTT的进行,预计Pn将扩展以适应目标域中的未知分布。原型扩展:扩展强OOD原型池需要评估测试样本与源域原型和强OOD原型的距离。为了估计数据中聚类的数量,本文借鉴了DP-means的思想。DP-means在发现数据中的聚类时通过测量数据点到已知聚类中心的距离来确定新的聚类。在这种方法中,如果距离超过一个阈值,就会初始化一个新的聚类。类似地,为了动态扩展原型池,作者首先将带有扩展强OOD分数
的测试样本定义为最接近已知源域原型和强OOD原型的距离,如Eq. 4所示:使用强OOD原型进行原型聚类:通过识别出额外的强OOD原型,我们首先为测试样本定义了原型聚类损失。考虑到两个方面。首先,分类为已知类别的测试样本应该嵌套到原型附近,远离其他原型,这定义了一个K路分类任务。其次,分类为强OOD原型的测试样本应该嵌套在离任何源域原型更远的地方,这定义了一个K + 1路分类任务。基于这些目标,我们定义了原型聚类损失,如Eq. 5所示:
利用源域和目标域分布之间的 KL-散度损失LKLD在目标域上规范原型聚类:
整体算法如下:
在目标域数据中存在噪声或污染的情况下执行开放世界测试时训练(Open-World TTT):
在风格迁移目标域中执行 Open-World TTT :
数据可视化:
TTT 已被广泛研究,以使适应未知的目标分布与低推理延迟。这项工作的研究目标是探索测试时训练在测试数据中存在强OOD样本的情况下的鲁棒性,也就是所谓的开放世界测试时训练(OWTTT)。为此,研究提出了一种无需超参数的强OOD检测器,该检测器有助于OWTTT中的自训练和分布对齐。此外,研究允许原型池在动态扩展的情况下,自训练能够更好地将弱OOD样本与强OOD样本分开。通过在五个OWTTT基准测试上进行广泛的评估,研究证明了所提出的方法的有效性。