본문 바로가기

Test-time Adaptation

LearnPruner: Rething Attention-based Token Pruning In Vision Language Models

논문주소: https://openreview.net/pdf?id=Dxb6gBJHby

 

Vision-Language Models (VLMs)는 최근 시각적 이해와 추론 영역에서 놀라운 능력을 보여주었으나, 긴 시각적 시퀀스 입력으로 인해 상당한 계산 부담을 초래합니다. 최근 연구들은 중요하지 않은 시각적 토큰을 pruning하여 모델 성능을 유지하면서 상당한 계산량 감소를 달성함으로써 이 문제를 해결하고 있습니다. token pruning의 핵심은 토큰의 중요도를 결정하는 것이며, 현재의 접근 방식은 주로 vision encoders나 Large Language Models (LLMs)의 attentio score에 의존합니다.

 

 본 논문에서는 vision encodersLLMs 양쪽 모두에서 어텐션 메커니즘의 효과를 분석합니다. vision encodersattention sink 현상을 겪으며 이로 인해 정보가 많은 foreground 영역에 대한 집중력이 떨어진다는 것을 발견했습니다. 반면 LLMs의 경우, 이전 연구들이 토큰 위치에 대한 어텐션 편향을 식별했음에도 불구하고, text-to-vision attention은 이러한 편향에 저항력을 보이며 중간 레이어에서 효과적인 pruning 가이드를 가능하게 함을 확인했습니다.

 

 이러한 관찰을 바탕으로, 본 연구에서는 LearnPruner를 제안합니다. 이는 vision encoder 직후에 학습 가능한 pruning module을 통해 중복된 시각 토큰을 먼저 제거한 다음, LLM의 중간 레이어에서 태스크 관련 토큰만을 유지하는 two-stage token pruning 프레임워크입니다. 실험 결과, LearnPruner는 시각 토큰의 5.5%만을 사용하면서도 원래 성능의 약 95%를 보존할 수 있으며, 3.2배의 추론 가속화를 달성하여 우수한 정확도-효율성 trade-off를 입증했습니다.

 

1 INTRODUCTION

 

최근 몇 년 동안 LLMs의 급격한 발전과 함께, VLMs 또한 괄목할 만한 진전을 이루었습니다. LLM의 이해 및 추론 능력을 시각적 모달리티로 확장함으로써, VLMsVisual Question Answering (VQA), Image Captioning, Visual Reasoning 등 다양한 Multimodal Tasks에서 유례없는 성능을 보여주었습니다.

 

 기존의 VLMs는 일반적으로 Vision Encoders를 통해 이미지를 이산적인 Token Sequences로 변환한 다음, Modal Alignment Modules를 거쳐 LLMs에 입력하는 방식을 취합니다. 그러나 방대한 수의 Visual Tokens는 특히 High-Resolution ImageLong-Video Input 시나리오에서 VLM Inference에 심각한 계산적 과제를 안겨줍니다. 이러한 Computational Burden은 자원이 제한된 환경이나 실시간 애플리케이션에서 VLMs의 실제 배포를 심각하게 제한합니다.

 

 이러한 계산적 문제를 완화하기 위해 Token Pruning이 유망한 해결책으로 떠올랐습니다. Token Pruning의 핵심 아이디어는 특정 기준을 통해 개별 Visual TokensImportance Scores를 할당한 후, Inference Phase에서 가장 중요한 상위 k개의 토큰만 유지하고 나머지는 버리는 것입니다. 시각적 콘텐츠는 일반적으로 텍스트보다 정보 밀도가 훨씬 낮기 때문에, Pruning 기술을 통해 Visual Token Count를 실질적으로 줄여도 Performance Degradation은 최소화됩니다.

 

 당연하게도 Token Pruning의 효과는 정확한 Visual Token Importance Assessment에 크게 의존합니다. 최근 연구들은 주로 Vision EncodersLLMAttention Scores를 지표로 활용하는데, 예를 들어 [CLS] Attention이나 Average LLM Attention 등이 있습니다. Attention Mechanism이 토큰의 중요도를 정확하게 반영하는지 조사하기 위해, 우리는 널리 사용되는 VLMLLaVA-1.5Vision EncoderLLM 양쪽에서 Attention Heatmaps를 시각화했습니다.

 

그림 1에 나타난 바와 같이, [CLS] TokenForeground Objects에 부분적으로 집중할 수 있지만, 정보가 적은 Background Regions에 과도한 Attention을 할당하는 경우가 많습니다. 이러한 관찰 결과는 Vision Transformers가 공간적 세부 사항을 버리면서 전역 이미지 정보를 집계하는 ArtifactsHigh-Norm Outlier Tokens를 생성하는 경향이 있다는 이전 연구와 일치합니다. 반면, LLM에서 Positional Encoding의 국소성과 Causal Masking Mechanisms로 인해 이미지의 하단 절반에 편향을 보이는 Attention Shift 현상이 식별되었지만, 이 편향은 주로 Vision-to-Vision 또는 All-Token Attention Patterns에서 관찰됩니다. 이와 대조적으로 Text-to-Vision AttentionQuery-Relevant Regions에 효과적으로 집중하는 것으로 보입니다.

 

위의 관찰 결과는 현재의 Attention-based Token Pruning 방법의 효과에 대한 심층 분석을 수행하게 된 동기가 되었으며, 두 가지 결정적인 통찰을 얻었습니다:

  1. Vision Encoder에서 [CLS] Token은 핵심적인 Salient Foreground Objects에 적절히 주의를 기울이지 못하며, 이는 특히 제한된 Token Budgets 하에서 최적이 아닌 Pruning Results를 초래합니다.
  2. LLM에서 Text-to-Vision AttentionAttention Shift에 저항하는 강건함을 보여주며, Middle Layer에서 Query-Relevant Visual Tokens를 선택하기 위한 신뢰할 수 있는 가이드를 제공할 수 있습니다. 그러나 PruningMiddle Layers까지 지연시키는 것은 여전히 상당한 Redundant Computations를 포함하므로 가속화 이득이 미미합니다.

이러한 한계를 해결하기 위해 본 연구에서는 Vision Encoder 이후와 LLM 내부에서 순차적으로 불필요한 Visual Tokens를 제거하여 VLM의 효율성을 높이는 Two-Stage Token Pruning 프레임워크인 LearnPruner를 제안합니다. 구체적으로, 먼저 기존의 [CLS] Attention Scores를 대신하여 Visual TokensImportance Scores를 직접 예측하는 가벼운 Learnable Module을 채택함으로써 내재적인 Visual Redundancy를 제거합니다. 또한, 보완적인 시각 정보를 제공하기 위해 작은 규모의 Diversity Tokens 세트가 유지됩니다. 남은 토큰들은 Query InstructionsCross-Modal Interaction을 위해 LLM으로 전달됩니다. 그 후, LLMMiddle Layer에서 Query-Aware Token Selection을 수행하여 주어진 Query와 무관한 토큰을 추가로 제거합니다.

 

 정제된 Pruning StrategyImportance Measures의 이점 덕분에 LearnPruner는 유리한 Accuracy-Efficiency Trade-off를 달성합니다. 다양한 VLM Benchmarks에 대한 광범위한 실험 결과, LearnPruner가 기존의 최신 방법들을 능가함을 보여줍니다. 단 5.6%의 토큰만 유지하면서도 LearnPruner는 원래 성능의 94.8%를 보존하며, Prefill TimeTotal Time에서 각각 2.3배 및 1.5배의 가속화를 달성합니다.

 

2. RELATED WORK

2.1 VISION-LANGUAGE MODELS.

Vision-Language Models의 진화는 초기 Joint Embedding 접근 방식에서 시작하여 Large Language Model의 역량을 활용하는 정교한 아키텍처로 발전해 왔습니다. LLaVA-1.5, MiniGPT-4, Qwen-VL과 같은 현대적 VLMsModal Alignment Modules를 통해 Visual Encoders를 강력한 언어 모델과 통합함으로써 Multimodal Understanding에서 주목할 만한 성능을 입증했습니다.

 

 더욱이 최근 VLM 입력은 고해상도 이미지와 비디오 시퀀스로 확장되어 Visual Token Sequences를 비약적으로 증가시켰습니다. 예를 들어, LLaVA-1.5는 이미지당 576개의 토큰을 생성하는 반면, LLaVA-NeXT는 고해상도 이미지를 하위 이미지 그리드로 나누어 최대 $5 \times 576 = 2,880$개의 토큰을 생성하고, LLaVA-OneVision은 비디오 프레임에 Pooling 연산을 적용하여 최대 $32 \times 196 = 6,272$개의 토큰을 생성합니다. 시각적 시퀀스가 길어지는 이러한 추세는 계산 부담을 가중시켰을 뿐만 아니라, Visual Tokens가 입력 구성을 지배하는 Multimodal Inputs의 심각한 불균형을 초래했습니다. 그러나 시각적 모달리티는 언어에 비해 실질적으로 높은 Redundancy를 나타내며, 이는 토큰 수량과 정보 밀도 사이의 불일치를 유발하여 VLMs에서 Visual Token Reduction 연구의 필요성을 자극합니다.

 

2.2 VISUAL TOKEN REDUCTION FOR VLMS

앞서 언급했듯이 Visual Token Reduction은 이미지 내의 중복 정보를 제거하여 Inference Efficiency를 높이는 것을 목표로 합니다. FastVLLM 내에서 한 토큰이 다른 모든 토큰으로부터 받는 Average Attention Scores를 계산하여 토큰 중요도를 결정하고, LLMShallow Layer에서 중요하지 않은 토큰을 Pruning하는 이 분야의 선구적인 작업입니다. PyramidDropLLM 내에서 Redundancy가 점진적으로 증가한다는 것을 관찰하고 Hierarchical Pruning Strategy를 제안했습니다. SparseVLMText-Aware Guidance를 달성하기 위해 Text-to-Vision Attention Scores를 활용하고, Rank of Attention Matrices를 사용하여 Pruning Ratio를 적응적으로 조정합니다.

 

한편, 일부 연구들은 Vision Encoder 직후에 Pruning 전략을 직접 적용합니다. VisPrunerLLM 내에 Attention ShiftAttention Dispersion 문제가 존재한다고 주장하며 중요도 기준을 [CLS] Token Attention Scores로 대체합니다. VisionZip은 작지만 잠재적으로 중요한 정보의 손실을 피하기 위해 Token Merging 기술을 추가로 통합합니다. 대안적으로 Diversity 관점에서 DivPruneDART는 포괄적인 Visual Context를 포착하는 다양한 세트를 유지하기 위해 Feature Similarity를 기반으로 토큰을 선택합니다

 

이러한 Training-free 방법들 외에도, 최근 연구들은 허용 가능한 수준의 추가적인 Training Overhead를 대가로 정확도를 더욱 향상시키기 위해 Training-based Pruning 방법들을 도입했습니다. ATP-LLaVA는 이미지의 Global Attention Distribution을 활용하여 Instance-specific Pruning Thresholds를 예측합니다. TwigVLMLLMShallow Layers에 추가적인 Decoder Blocks를 삽입하여 더 신뢰할 수 있는 Pruning을 수행할 뿐만 아니라, Self-speculative Decoding을 통해 Decoding Stage Acceleration을 가능하게 합니다. 중복 토큰을 제거하기 위해 어텐션 결과에 의존하는 기존 연구들에도 불구하고, Attention Mechanism은 본질적인 한계를 가지고 있으며 그 효과에 대해서는 추가적인 탐구가 필요합니다.

 

3 METHOD

3.1 PRELIMINARY

 기존 연구들은 일반적으로 Transformer Block에서 파생된 Attention Map에 의존하여 Token Pruning을 수행합니다. Attention MechanismVLMs의 두 핵심 구성 요소인 Vision EncoderLLM 모두에서 토큰 간의 상호작용을 촉진하기 위해 널리 적용됩니다

 

수식적으로 토큰 시퀀스 $X=[x_{1},x_{2},\cdot\cdot\cdot,x_{N}]\in\mathbb{R}^{N\times d}$가 주어지면, Attention 계산은 먼저 입력을 각각 Query ($Q$), Key ($K$), Value ($V$)로 변환합니다

 

여기서 $W_{q}, W_{k}, W_{v}$는 투영 행렬이며, $N$은 시퀀스 길이, $d$는 은닉 차원입니다. 이후 Attention Map ($A$)과 출력 ($O$)은 다음과 같이 계산됩니다:

 

여기서 $M \in \mathbb{R}^{N\times N}$은 선택적인 Mask Matrix입니다. Global Attention을 채택하는 Vision Encoders의 경우 $M$은 제로 행렬이며, Causal Attention을 사용하는 LLMs의 경우 각 토큰이 이전 토큰에만 주의를 기울일 수 있도록 상삼각 행렬(upper triangular matrix)의 형태를 가집니다.

 

3.2 STUDY OF ATTENTION IN VLMS

Attentionin Vision Encoder

 

Vision Encoders는 일반적으로 시퀀스 시작 부분에 [CLS] Token을 포함하여 패치 토큰들과 상호작용하고 전역 정보를 집계합니다. [CLS] Token은 시각적으로 중요한(Salient) 영역에 집중할 것으로 기대되므로, 기존 연구들은 자연스럽게 [CLS] TokenAttention Score를 패치 토큰의 중요도 추정치로 활용해 왔습니다 . 그러나 최근 연구들은 대부분의 Vision Encoders가 균일한 배경 영역에서 Artifacts를 생성하는 경향이 있음을 밝혀냈습니다 . 이러한 Artifacts는 의미론적 정보가 제한적임에도 불구하고 다른 토큰들로부터 비정상적으로 높은 Attention을 받는 Attention Sink 현상을 유발합니다. 실험 결과, [CLS] Token의 어텐션에만 의존하는 방식($[CLS]_{all}$)보다 전경 영역으로 토큰 선택을 제한하는 방식($[CLS]_{fg}$)이 더 우수한 성능을 보였으며, 이는 [CLS] Token이 전경 객체에 충분히 집중하지 못함을 시사합니다.

 

Attentionin LLM

LLMs에서 Visual Tokens는 모달리티 내부뿐만 아니라 Text Tokens와도 상호작용합니다. 기존 연구들은 보통 다른 토큰들로부터 받는 Average Attention이나 마지막 지시어 토큰을 사용하여 중요도 점수를 추정하며, 계산 비용을 줄이기 위해 Shallow Layers에서 Pruning을 수행합니다. 그러나 LLMs에서는 인덱스가 높은(이미지 하단부) Visual Tokens가 더 높은 어텐션 점수를 받는 Attention Shift 현상이 식별되었습니다. 분석 결과, Text-to-Vision Attention은 이러한 편향에 훨씬 더 점진적인 트렌드를 보이며 개별 인스턴스에 따라 다양하게 나타나 Pruning 결정에 더 유리한 것으로 확인되었습니다 . 특히 어텐션의 신뢰도는 Shallow Layers에서 점진적으로 증가하여 Middle Layers에서 안정적인 성능을 보이다가 레이어가 깊어질수록 다시 감소하는 경향을 보입니다.

 

3.3 LEARNPRUNER

 위 연구를 바탕으로, 본 논문은 Vision Encoder 이후의 어텐션 기반 기준을 Learnable Pruning Criteria로 대체하고, Vision Encoder 직후와 LLM 내부에서 각각 Pruning을 수행하는 Progressive Pruning 전략을 채택한 LearnPruner를 제안합니다.

Remove Visual Redundancy.

 이 단계는 Vision Encoder 직후에 보다 콤팩트한 토큰 표현을 통해 원본 이미지 정보를 보존하는 것을 목표로 합니다. 본 연구에서는 [CLS] Attention 대신 각 시각 토큰의 중요도 점수를 직접 예측하는 Learnable Pruning Module (LPM)을 채택합니다. 구체적으로, Vision Encoder의 토큰 특징을 가벼운 MLP에 입력하여 각 토큰의 보존 여부를 결정하는 이진 분류를 수행합니다. 불연속적인 이진 결정의 미분 불가능성 문제를 해결하기 위해 학습 시 Straight-Through Estimator (STE)를 적용합니다:

 

 

Remove Text-Irrelevant Content

이미지에는 방대한 정보가 포함되어 있지만, 모든 시각적 내용이 주어진 쿼리에 필요한 것은 아니므로, LLM 내부에서 두 번째 Pruning을 수행합니다. 분석 결과에 따라 Attention Shift에 영향을 덜 받은 Text Attention을 직접 활용하는 Token Pruning을 가이드 합니다. Text Attention은 모든 헤드에 대해 쿼리 토큰들이 받는 average attention으로 계산 됩니다:

본 연구에서는 가장 높은 어텐션 점수를 가진 상위 $k$개의 토큰만을 유지하여 LLM내에서의 추가 상호작용에 참여 시킵니다.

 

4 EXPERIMENTS

4.1 EXPERIMENTAL SETUP

 

  • Dataset & Training: LLaVA-665K 데이터셋의 10%를 사용하여 학습을 진행했습니다.
  • Frozen Weights: 베이스 모델인 VLM의 가중치는 Frozen 상태로 유지하고, 오직 LPM 모듈만 학습시켰습니다.
  • Two-Stage Pruning: 1단계는 Vision Encoder 직후에 수행하며, Diversity Tokens를 10% 포함합니다. 2단계는 LLM의 12번째 레이어에서 수행됩니다.

4.2 MAIN RESULTS

 

 

  • LLaVA-v1.5-7B: 토큰 수를 576개에서 32개로 대폭 줄였음에도 원래 성능의 94.8%를 유지하며 기존 Training-freeTraining-based 방법들을 압도했습니다.
  • High-Resolution & Video: 고해상도 모델인 LLaVA-NeXT에서는 88.9%의 토큰을 제거하고도 97.5%의 성능을 보존했습니다. Video-LLaVA에서도 FastV보다 높은 정확도와 GPT Score를 기록하며 높은 Generalization 능력을 입증했습니다.
  • Architecture Agnostic: LLaMA 계열이 아닌 Qwen2.5-VL에서도 FastV 대비 월등한 성능 유지 능력을 보여주었습니다.

 

4.3 ABLATION STUDIES

 

  • Importance Criteria: 기존의 [CLS] AttentionLPM 점수로 대체했을 때 성능이 1.7% 향상되었습니다. 이는 LPMSalient Foreground Regions를 훨씬 더 효과적으로 포착함을 의미합니다.
  • Two-Stage Strategy: LPM만 사용하는 1단계 방식보다 LLM 내부의 2단계 Pruning을 결합했을 때 더 효율적인 토큰 예산 배분이 가능하여 성능이 개선되었습니다.
  • Attention in LLM: LLM 내부에서는 추가적인 LPM 설치보다 기존의 Attention Signal을 직접 활용하는 것이 충분히 신뢰할 수 있으며 효율적이었습니다.

 

4.4 EFFICIENCY ANALYSIS

  • Inference Acceleration: LLaVA-v1.5-7B 기준으로 Prefill Time은 5.4배, 전체 시간은 2.3배 가속화되었습니다.
  • Memory Efficiency: KV Cache 사용량을 6.8배 절감하여 메모리 부담을 크게 줄였습니다.
  • LPM Overhead: LPM은 매우 가벼운 설계 덕분에 추가적인 계산 비용이나 메모리 사용량이 거의 무시할 수 있는 수준입니다.

 

5 CONCLUSION

본 논문에서는 Vision Encoder와 LLMs 양쪽 모두의 Attention Mechanisms에 대한 심층적인 분석을 수행했습니다. 이러한 관찰을 바탕으로, Vision Encoder 이후에 학습 가능한 Pruning Module을 통해 중복된 Vision Tokens을 먼저 제거한 다음, LLMMiddle Layer에서 텍스트와 무관한 토큰을 추가로 폐기하는 2단계 Pruning 프레임워크인 LearnPruner를 제안합니다. 광범위한 실험 결과는 우리의 LearnPruner가 기존의 최신 방법들을 능가하며 더 나은 Accuracy-Efficiency Trade-off를 달성함을 보여줍니다.