PromptKD: Unsupervised Prompt Distillation for Vision-Language Models
Abstract
Prompt learning has emerged as a valuable technique for enhancing vision-language models (VLMs) such as CLIP for downstream tasks in specific domains. Existing work mainly focuses on designing various learning forms of prompts, neglecting the potential of prompts as effective distillers for learning from larger teacher models. In this paper, we introduce an unsupervised domain prompt distillation framework, which aims to transfer the knowledge of a larger teacher model to a lightweight target model through prompt-driven imitation using unlabeled domain images. Specifically, our framework consists of two distinct stages. In the initial stage, we pre-train a large CLIP teacher model using domain (few-shot) labels. After pre-training, we leverage the unique decoupled-modality characteristics of CLIP by pre-computing and storing the text features as class vectors only once through the teacher text encoder. In the subsequent stage, the stored class vectors are shared across the teacher and student image encoders for calculating the predicted logits. Further, we align the logits of both the teacher and student models via KL divergence, encouraging the student image encoder to generate probability distributions similar to the teacher's through the learnable prompts. The proposed prompt distillation process eliminates the reliance on labeled data, enabling the algorithm to leverage a vast amount of unlabeled images within the domain. Finally, the well-trained student image encoder and pre-stored text features (class vectors) are utilized for inference. To the best of our knowledge, we are the first to (1) perform unsupervised domain-specific prompt-driven knowledge distillation for CLIP, and (2) establish a practical pre-storing mechanism of text features as shared class vectors between teacher and student. Extensive experiments on 11 datasets demonstrate the effectiveness of our method. Code is publicly available at https://github.com/zhengli97/PromptKD.
Figure 1. Harmonic mean (HM) comparison on base-to-novel generalization. All methods adopt the ViT-B/16 image encoder from the pre-trained CLIP model. PromptKD achieves state-of-the-art performance on 11 diverse recognition datasets.
1. Introduction
Recently, large pretrained vision-language models (VLMs), such as CLIP [41, 68] and ALIGN [17], have demonstrated superior generalization ability for domain-specific downstream tasks. Unlike conventional visual frameworks, a vision-language model like CLIP usually employs a two-tower architecture that includes an image encoder and a text encoder. These models are trained using a contrastive loss to learn a unified embedding space that aligns the representations of multi-modal signals.
To better optimize the models for domain-specific downstream tasks, various methods [10, 21, 65, 71, 72] have been proposed to adapt the representation while keeping the original CLIP model fixed. Inspired by the success of the Natural Language Processing (NLP) [26, 28] area, prompt learning [18, 71, 72] has been proposed to acquire continuous prompt representations as a replacement for meticulously designed hard prompts. Based on the type of information learned by the prompt, existing methods can be roughly divided into three types: text-based, visual-based, and both. Text-based methods [71, 72] propose to adaptively learn appropriate text prompts for downstream tasks, rather than using fixed forms. Visual-based methods [5, 18] follow similar principles and further apply them to the visual modality. Text-visual-based prompt methods [21, 22, 25, 52] adopt a simultaneous learning strategy for prompts in both the image and text branches, instead of treating them separately.
Figure 2. Architecture comparison between the classic KD paradigm for CLIP (e.g., CLIP-KD [63]) and our prompt distillation framework. (a) Classic KD methods perform distillation between independent teacher and student models. The student is typically fully fine-tuned using the teacher's soft labels. (b) PromptKD breaks this teacher-student independence. We propose to reuse the previously well-trained text features from the teacher pre-training stage and incorporate them into the student image encoder branch for both distillation and inference.
Prior research has primarily concentrated on acquiring effective formats of prompts using scarce labeled data while preserving CLIP's strong generalization capabilities. In this paper, we introduce a novel unsupervised framework (termed "PromptKD") in which the prompt acts as a domain knowledge distiller, allowing a CLIP student model to absorb knowledge from a large CLIP teacher model on extensive unlabeled domain data. Specifically, our framework consists of two distinct stages: the teacher pre-training stage and the student distillation stage.
In the initial stage, we first pre-train a large CLIP teacher model using existing advanced approaches [21, 22] on few-shot labeled domain data. After pre-training, we propose to leverage the unique decoupled-modality characteristics of CLIP by pre-computing and storing the text features as class vectors only once through the teacher text encoder.
In the subsequent stage, the stored class vectors are shared across the teacher and student image encoders to calculate the predicted logits, without any extra computation cost from the text branch. Different from the traditional knowledge distillation scheme, where the student's weights are usually fully tuned to mimic the teacher's statistical behavior as shown in Fig. 2(a), we propose to utilize the student's learnable visual prompts to align the logits of the teacher and student models via KL divergence, encouraging the student image encoder to generate probability distributions similar to the teacher's through prompt distillation. Because the teacher and student features have different dimensions, an extra projector is introduced to bridge the dimension disparity.
With the benefits of the teacher-student paradigm, we can leverage the pre-trained teacher to generate soft labels for unlabeled images from the target domain, thus enabling the training of students without the need for labeled images. Finally, the well-trained student image encoder, along with the pre-stored teacher text features (class vectors), are employed for inference purposes. An architectural comparison of the classic distillation paradigm for CLIP and our proposed prompt distillation framework is illustrated in Fig. 2.
Experimental results in Fig. 1 show that our PromptKD outperforms previous methods and achieves state-of-the-art performance on 11 diverse recognition datasets using the ViT-B/16 CLIP image encoder. Specifically, our method achieves average improvements of 2.70% and 4.63% on the base and novel classes, respectively, across these datasets.
Our contributions can be summarized as follows:
- We propose PromptKD, an unsupervised domain prompt distillation framework in which prompts act as knowledge distillers, allowing a lightweight CLIP student to learn from a large CLIP teacher on unlabeled domain images.
- We establish a practical pre-storing mechanism that reuses the teacher's text features as class vectors shared between teacher and student, removing the text branch from student training and inference.
- Extensive experiments on 11 diverse recognition datasets demonstrate that PromptKD achieves state-of-the-art performance.
2. Related Work
Prompt Learning in Vision-Language Models. Prompt learning is a technique that transfers a large pre-trained model, like CLIP [41], to downstream tasks [11, 42, 66] without completely re-training the original model. It adapts the representations for specific tasks through learnable text or visual soft prompts instead of manually crafted hard prompts (e.g., "a photo of a {classname}"). Soft prompts [18, 25, 44, 71, 72] can be optimized by back-propagating through the frozen pretrained model, resulting in better performance. Existing works mainly focus on designing various efficient forms of prompts using scarce labeled domain data. MaPLe [21] proposes to learn prompts for the image and text branches simultaneously, rather than for a single side. PromptSRC [22] utilizes the original CLIP features to regularize the learning of prompts for each branch. Previous works necessitate forward and backward computations for each input in both the image [8, 56] and text branches. In our work, we leverage the unique decoupled-modality characteristic of CLIP, saving well-trained teacher text features as class vectors for student distillation. In this way, the training of the student CLIP is simplified to forward and backward calculations of the image branch only, without requiring the text branch.
Zero-shot Learning. Given a labeled training set of seen classes, zero-shot learning (ZSL) [32, 55, 58] aims to learn a classifier that can classify testing samples of unseen classes. Existing methods can be roughly divided into two types based on whether test images are available: inductive [59, 67] and transductive [49, 51] ZSL. Previous works on prompt learning, such as MaPLe and PromptSRC, have mainly focused on the inductive setting, where only labeled training instances are available. In our paper, we explore the transductive ZSL setting, where both seen and unseen class images are utilized in model learning. Specifically, our teacher model follows the same training scheme as PromptSRC, i.e., it is trained on samples from seen classes with ground truth labels. The difference is that the target student model is trained on the full unlabeled dataset, which contains samples of both seen and unseen classes, without using any ground truth labels.
Knowledge Distillation. Knowledge distillation [15] aims to train a lightweight student model under the supervision of a large pretrained teacher model. In recent years, various distillation forms have emerged for effective knowledge transfer from teachers to students, such as logits alignment [29, 31, 69, 70], feature imitation [4, 27, 64], and sample relationship matching [38, 61]. Beyond image classification, knowledge distillation has achieved great success in many vision tasks, including object detection [2, 19, 54], image segmentation [33, 62], and pose estimation [30]. Recently, many works [24, 40, 57, 63] have turned their attention to the CLIP model, proposing to leverage CLIP's exceptional generalization capabilities to enhance the learning of existing models. CLIP-KD [63] finds that, when distilling pre-trained CLIP models, simple feature mimicry with an MSE loss yields the best results. TinyCLIP [57] performs cross-modal feature alignment in the affinity space between teacher and student. Our approach differs from previous distillation methods that train the entire student model under a pre-trained large CLIP teacher. In our work, we employ a more efficient approach, utilizing student prompts for distillation while keeping the student's original CLIP weights frozen. This allows us to achieve the desired knowledge transfer without extensive re-training of the student model.
3. Method
Prompt learning [18, 72] aims to enhance the performance of existing VLMs like CLIP on downstream tasks by incorporating learnable prompts. Existing works mainly focus on devising effective learning formats of prompts using scarce labeled domain data while ensuring strong generalization to unseen images. In this paper, we are the first to explore prompts as effective knowledge distillers, allowing the CLIP student model to learn from a large CLIP teacher model by aligning their predictions on extensive unlabeled domain images. An overview of our proposed prompt distillation method is illustrated in Fig. 3. Specifically, our method comprises two main stages: the teacher pre-training stage and the student prompt distillation stage. In the initial stage, we pre-train a large CLIP teacher model using existing advanced approaches on few-shot labeled data, as depicted in Fig. 3(a). After pre-training, we extract and preserve the highly proficient text features obtained from the teacher text encoder as class vectors. In the subsequent stage, the pre-stored class vectors are reused by multiplying them with the outputs of both the teacher and student image encoders, yielding predictions for each model. We then initiate the distillation process by promoting prompt imitation, encouraging the student model to generate predictions similar to the teacher model's, as illustrated in Fig. 3(b). An additional projector is introduced to align the dimensions of the teacher text features and the student image features. Finally, the well-trained student image encoder branch and the pre-stored teacher text features (class vectors) are utilized for inference (see Fig. 3(c)).
Below we first introduce the background knowledge of VLMs and the knowledge distillation method in Sec. 3.1. Then we introduce our method in detail in Sec. 3.2.
3.1. Background
Vision-Language Models. Existing VLMs like CLIP [41] and ALIGN [17] are designed to align images and texts in order to learn a joint embedding space. Following [21, 22], given an input image with normalized image feature $u\in\mathbb{R}^{d}$ and normalized text features $\{w_{i}\}_{i=1}^{N}$ of the $N$ candidate classes, the prediction probability for class $y$ is computed as
Figure 3. An overview of our proposed prompt distillation (PromptKD) framework. (a) We first pre-train a large CLIP teacher model using existing state-of-the-art prompt learning methods with labeled training images, then save the well-trained text features of all possible classes for the next stages. (b) During the distillation stage, training is focused on the student image prompts and the projection layer; there is no extra computational expense for text encoding because the pre-saved text features serve as class vectors. (c) Finally, the well-trained student and the pre-stored class vectors are utilized for inference.
$$
p(y|x)=\frac{\exp(u w_{y}^{\mathsf{T}}/\tau)}{\sum_{i=1}^{N}\exp(u w_{i}^{\mathsf{T}}/\tau)},
$$
where $u w_{y}^{\mathsf{T}}$ denotes the output logit for class $y$ and $\tau$ is the temperature parameter.
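For concreteness, the following is a minimal PyTorch sketch of this prediction rule in Eqn. (1), assuming the image feature `u` and the per-class text features `w` are already L2-normalized (the tensor names and dimensions here are illustrative, not taken from the released code):

```python
import torch
import torch.nn.functional as F

def clip_class_probs(u: torch.Tensor, w: torch.Tensor, tau: float = 0.01) -> torch.Tensor:
    """Compute p(y|x) as in Eqn. (1).

    u:   (d,)   L2-normalized image feature.
    w:   (N, d) L2-normalized text features, one row per class.
    tau: temperature parameter.
    """
    logits = u @ w.t() / tau          # (N,) cosine similarities scaled by 1/tau
    return F.softmax(logits, dim=-1)  # probability over the N classes

# toy usage with random, normalized features
u = F.normalize(torch.randn(512), dim=-1)
w = F.normalize(torch.randn(10, 512), dim=-1)
probs = clip_class_probs(u, w)        # shape (10,), sums to 1
```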
Instead of manually crafted hard prompts, recent works like CoOp [72] propose to adaptively learn appropriate soft textual prompts for downstream tasks. Concretely, $M$ learnable textual vectors $\{v_{1},v_{2},...,v_{M}\}$, i.e., a prefix, are added before the CLASS token to create a contextualized representation. The prompt $t_{i}$ for class $c_{i}$ then becomes $t_{i}=\{v_{1},v_{2},...,v_{M},c_{i}\}$, where each vector $v_{i}$ $(i\in\{1,2,...,M\})$ has the same dimension as the word embeddings and $M$ is a hyperparameter that determines the length of the prefix. In addition to text prompt tuning methods, visual prompts have also been extensively explored. Some works [18, 21, 22] follow the same idea as the text prompt methods, adding multiple learnable visual prefixes to the image patches as input to the image encoder. These visual prompts aim to guide the image encoder to extract more meaningful and task-relevant visual features. By incorporating these learnable visual prefixes, the model can leverage additional context and prior knowledge to improve its performance on image understanding tasks.
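As a small illustration, a soft prompt of this form could be assembled as follows; `ctx` holds the $M$ learnable context vectors and `class_emb` stands for the word embedding(s) of a class name (names and shapes are our own, for illustration only):

```python
import torch
import torch.nn as nn

M, embed_dim = 4, 512                                  # prefix length and embedding width
ctx = nn.Parameter(torch.randn(M, embed_dim) * 0.02)   # learnable context vectors v_1..v_M

def build_prompt(class_emb: torch.Tensor) -> torch.Tensor:
    """Return t_i = {v_1, ..., v_M, c_i} for one class.

    class_emb: (L, embed_dim) token embeddings of the class name.
    """
    return torch.cat([ctx, class_emb], dim=0)           # shape (M + L, embed_dim)

# toy usage: a single-token class name
prompt = build_prompt(torch.randn(1, embed_dim))        # shape (M + 1, embed_dim)
```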
Knowledge Distillation. Originally proposed by Hinton et al. [15], knowledge distillation aims to transfer the knowledge of a pretrained heavy teacher model to a lightweight student model. After the distillation, the student can master the expertise of the teacher and be used for final deployment. Specifically, the Kullback-Leibler (KL) divergence loss is utilized to match the output distribution of two models, which can be formulated as follows:
$$
L_{kd}(q^{t},q^{s},\tau)=\tau^{2}\,\mathrm{KL}(\sigma(q^{t}/\tau),\sigma(q^{s}/\tau)),
$$
where $q^{t}$ and $q^{s}$ denote the logits predicted by the teacher and the student, respectively, $\sigma(\cdot)$ is the softmax function, and $\tau$ is the temperature [15, 31], which controls the softness of the distribution.
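A minimal PyTorch sketch of this loss, assuming `q_t` and `q_s` are raw (unscaled) logits; the function name is ours:

```python
import torch
import torch.nn.functional as F

def kd_loss(q_t: torch.Tensor, q_s: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Temperature-scaled KL divergence between teacher and student logits.

    q_t, q_s: (B, N) teacher / student logits.
    The tau**2 factor keeps gradient magnitudes comparable across temperatures.
    """
    p_t = F.softmax(q_t / tau, dim=-1)           # softened teacher targets
    log_p_s = F.log_softmax(q_s / tau, dim=-1)   # student log-probabilities
    return tau ** 2 * F.kl_div(log_p_s, p_t, reduction="batchmean")
```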
3.2. PromptKD: Prompt Distillation for VLMs
Our proposed prompt distillation framework comprises two stages: teacher pre-training and student prompt distillation, as illustrated in Fig. 3. In this section, we provide a comprehensive explanation of each stage.
Stage I: Teacher Pre-training. In the initial stage, we begin by pre-training a large CLIP teacher model using labeled domain data, as illustrated in Fig. 3(a). To accomplish this, we can employ existing prompt learning methods such as MaPLe [21] and PromptSRC [22], or alternatively, utilize a publicly available pretrained CLIP model for simplicity. Given a labeled domain dataset $D_{labeled}=\{x_{i},y_{i}\}_{i=1}^{M}$ with a fixed set of class names, the teacher CLIP model takes training images and text descriptions containing the category names as input and passes them through the image encoder $f_{I}^{t}$ and text encoder $f_{T}^{t}$ to obtain the corresponding normalized image features $u\in\mathbb{R}^{d}$ and text features $w\in\mathbb{R}^{d}$. The final output $p^{t}$ is calculated by Eqn. (1). The parameters of the teacher soft prompts are updated by minimizing the cross-entropy loss between the predicted probabilities $p^{t}$ and the ground truth labels $y$.
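A sketch of one such pre-training step is shown below; the encoder handles and the optimizer over the prompt parameters are placeholders, and the actual teacher in our experiments is trained with the full PromptSRC recipe rather than this plain cross-entropy loop:

```python
import torch.nn.functional as F

def teacher_prompt_step(f_img_t, f_txt_t, prompt_optimizer,
                        images, labels, prompt_texts, tau=0.01):
    """One few-shot training step for the teacher prompts (frozen CLIP weights)."""
    u = F.normalize(f_img_t(images), dim=-1)        # (B, d) image features
    w = F.normalize(f_txt_t(prompt_texts), dim=-1)  # (N, d) class text features
    logits = u @ w.t() / tau                        # (B, N), as in Eqn. (1)
    loss = F.cross_entropy(logits, labels)          # supervised by few-shot labels
    prompt_optimizer.zero_grad()
    loss.backward()
    prompt_optimizer.step()
    return loss.item()
```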
Once the training of the text encoder is completed, its output features remain fixed and do not require further updates. We therefore save the well-trained teacher text features of all $N$ classes, $W=[w_{1},w_{2},...,w_{N}]\in\mathbb{R}^{N\times d}$, as shared class vectors that will be utilized in the subsequent stages. This eliminates the need for a student CLIP text branch, resulting in substantial computational savings during training. In addition, through our PromptKD method, we can replace the teacher's heavy image encoder with a lightweight student image encoder, reducing the computational cost during deployment while maintaining competitive performance.
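A minimal sketch of this one-off caching step, assuming a hypothetical helper `tokenize_class_prompts` that combines the learned teacher prompt with every class name:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def build_class_vectors(teacher_text_encoder, classnames, tokenize_class_prompts,
                        path="teacher_class_vectors.pt"):
    """Run the teacher text branch once and store W = [w_1, ..., w_N] (N x d)."""
    tokens = tokenize_class_prompts(classnames)             # learned prompt + class names
    w = F.normalize(teacher_text_encoder(tokens), dim=-1)   # (N, d) class vectors
    torch.save(w.cpu(), path)                               # reused in Stage II and at inference
    return w
```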
Stage II: Student Prompt Distillation. At this stage, we aim to train a student model by encouraging it to align with the teacher's outputs through prompt imitation, as shown in Fig. 3(b). Thanks to the strategy of reusing teacher text features, we only need to train the learnable visual prompts of the student image encoder branch $f_{I}^{s}$ and the feature projector. Given an unlabeled domain dataset $D_{unlabeled}$, by inputting an image $x$ into both the pre-trained teacher's and the untrained student's image branches, we obtain the normalized teacher image features $u^{t}=f_{I}^{t}(x)/||f_{I}^{t}(x)||_{2}\in\mathbb{R}^{d}$ and student image features $u^{s}=P(f_{I}^{s}(x))/||P(f_{I}^{s}(x))||_{2}\in\mathbb{R}^{d}$. The learnable projector $P(\cdot)$ in the student image encoder branch matches the feature dimensions at a relatively small cost while being effective enough to ensure accurate alignment. We then multiply the pre-stored teacher text features $W\in\mathbb{R}^{N\times d}$ with the teacher and student image features to obtain the output logits $q^{t}=u^{t}W^{\mathsf{T}}\in\mathbb{R}^{N}$ and $q^{s}=u^{s}W^{\mathsf{T}}\in\mathbb{R}^{N}$, respectively.
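The projector $P(\cdot)$ can be as simple as a single linear layer that maps the student feature width to the teacher's embedding dimension, followed by L2 normalization; a minimal sketch (the dimensions shown are illustrative), complementing Algorithm 1 below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureProjector(nn.Module):
    """Map student image features (d_s) into the teacher's feature space (d_t)."""

    def __init__(self, d_s: int = 512, d_t: int = 768):
        super().__init__()
        self.proj = nn.Linear(d_s, d_t)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # u^s = P(f_I^s(x)) / ||P(f_I^s(x))||_2
        return F.normalize(self.proj(x), dim=-1)
```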
Algorithm 1 Pseudocode of PromptKD in PyTorch.
```python
# tea_t: text encoder of the teacher CLIP
# tea_i: image encoder of the teacher CLIP
# stu_i: image encoder of the student CLIP
# Proj:  feature projector
# l_tea, l_stu: teacher / student output logits

# init: pre-compute and store the teacher text features (class vectors)
f_txt_t = tea_t(txt_of_all_classes)

# forward
for img in unlabeled_dataset:
    f_img_t = tea_i(img)
    f_img_s = stu_i(img)
    f_img_s = Proj(f_img_s)

    # get output predictions
    l_tea = f_img_t @ f_txt_t.t()
    l_stu = f_img_s @ f_txt_t.t()

    # compute the distillation loss
    loss = KLDivergence(l_stu, l_tea)
    loss.backward()
```
We optimize the student model to produce outputs similar to the teacher's on the unlabeled domain dataset $D_{unlabeled}$, which can be formulated as follows:
$$
L_{stu}=L_{kd}(q^{t},q^{s},\tau).
$$
Algorithm 1 provides PromptKD’s PyTorch-style pseudocode.
Inference. Finally, the well-trained student image encoder $f_{I}^{s}$ , along with the pre-stored teacher text features $W$ (class vectors), are employed for inference purposes.
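Under the same naming assumptions as Algorithm 1, inference reduces to a single image-branch forward pass against the cached class vectors; a minimal sketch:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def predict(stu_i, Proj, image, class_vectors, tau=0.01):
    """Classify an image batch with the student branch and pre-stored class vectors W."""
    u_s = F.normalize(Proj(stu_i(image)), dim=-1)   # (B, d) projected student features
    logits = u_s @ class_vectors.t() / tau          # (B, N) similarity to the class vectors
    return logits.argmax(dim=-1)                    # predicted class indices
```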
4. Experiments
4.1. Settings
Base-to-novel Generalization. Following [21, 22, 71], we split the training and testing sets of each dataset into base and novel classes. The teacher is pre-trained with the PromptSRC [22] method, using the same training settings as PromptSRC. During distillation, we use the entire unlabeled training set to train our students. After distillation, the student's performance on the base and novel classes is evaluated on the testing set.
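The harmonic mean reported in these comparisons is simply $\mathrm{HM}=2\cdot\mathrm{Base}\cdot\mathrm{Novel}/(\mathrm{Base}+\mathrm{Novel})$; for example:

```python
def harmonic_mean(base_acc: float, novel_acc: float) -> float:
    """HM of base and novel accuracy, as used in base-to-novel generalization tables."""
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)

# e.g. the PromptKD entry of the first results table below
print(round(harmonic_mean(80.83, 74.66), 2))  # -> 77.62
```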
Cross-dataset Evaluation. As in PromptSRC [22], our teacher model is pre-trained on the source dataset (i.e., ImageNet) with a 16-shot training data configuration. We then use the unlabeled training sets of the target datasets to train the students and evaluate their performance on the corresponding test sets. In PromptKD, we use unlabeled images of unseen classes for student training, which follows the transductive zero-shot learning setting. Previous methods such as CoOp, MaPLe, and PromptSRC are trained on seen-class data and belong to the inductive paradigm.
| ViT-B/16 | Base | Novel | HM |
|---|---|---|---|
| CLIP | 72.43 | 68.14 | 70.22 |
| CoOp | 76.47 | 67.88 | 71.92 |
| CoCoOp | 75.98 | 70.43 | 73.10 |
| MaPLe | 76.66 | 70.54 | 73.47 |
| PromptSRC | 77.60 | 70.73 | 74.01 |
| PromptKD | 80.83 | 74.66 | 77.62 |

| ViT-B/16 | Base | Novel | HM |
|---|---|---|---|
| CLIP | 96.84 | 94.00 | 95.40 |
| CoOp | 98.00 | 89.81 | 93.73 |
| CoCoOp | 97.96 | 93.81 | 95.84 |
| MaPLe | 97.74 | 94.36 | 96.02 |
| PromptSRC | 98.10 | 94.03 | 96.02 |
| PromptKD | 98.91 | 96.65 | 97.77 |