본문 바로가기

Test-time Adaptation

LVPruning: An Effective yet Simple Language-Guided Vision Token Pruning Approach for Multi-modal Large Language Models

논문 주소: https://arxiv.org/pdf/2501.13652

 

Abstract

멀티모달 대규모 언어 모델은 시각 및 텍스트 modality를 통합함으로 써 놀라운 성공을 거두었습니다. 하지만 처리되는 시각 토큰의 수가 많아 상당한 계산 오버헤드가 발생하며, 이는 자원이 제한된 환경에서의 실용성을 제한합니다. 본 연구에서는 모델의 성능을 유지하면서 계산 부담을 대폭 줄이는 효과적이면서도 간단한 방법인 Language-Guided Vision Token Pruning (LVPruning)을 위해 제안합니다. LVPruning은 cross-attention module을 활용하여 언어 토큰과 상호작용을 기반으로 visual token의 중요도를 계산하고, 어떤 토큰을 제거할 지를 결정합니다. 중요한 점은, LVPruning이 기존 MLLM 파라미터를 수정하지 않고도 통합될 수 있어 적용하거나 제거하기가 매우 간단하다는 것입니다. 본 실험에 따르면 LVPruning은 LLaVA-1.5의 중간 레이어까지 시각 토큰을 최대 90%까지 효과적으로 줄일 수 있으며, 9개의 멀티모달 벤치마크에서 평균 성능 손실이 단 0.45%에 불과한 상태로 추론 TFLOPs를 62.1% 감소시켰습니다.

 

1 Introduction

Multi-modal Large Language Models (MLLMs)는 시각 및 텍스트 정보를 결합하여 두 modality를 모두 이해해야 하는 복잡한 태스크를 수행함으로써 인상적인 결과를 달성했습니다. 하지만 이러한 모델은 계산 집약적일 수 있어 resource-constrained environments에서의 실용성을 제한합니다 이러한 상당한 computational overhead를 초래하는 중요한 사실 중 하나는 이 모델들이 입력 이미지 패치를 나타내는 대량의 vision tokens를 종종 처리하지만, 모든 시각 정보가 이해에 똑같이 중요한 것은 아니라는 점입니다. 인간의 뇌는 관련 없는 세부 사항을 무시하면서 핵심적인 특징에 집중할 수 있어 매우 효율적인 시각적 인지가 가능합니다. 이에 영감을 받아 성능을 크게 희생하지 않으면서도 계산 비용을 줄이고 중요한 vision tokens의 우선순위를 정할 수 있는 MLLMs를 개발할 필요성이 커지고 있습니다.

 

 MLLMs의 계산 효율성을 높이기 위한 이전 접근 방식들은 다양한 전략을 탐구해 왔습니다. Q-formervision encoder로 사용하는 모델들은 시각 정볼르 더 작은 token 세트로 응축합니다. 비록 계산 부하를 효과적으로 줄여주지만, 이러한 응축은 필수적인 시각 정보의 소닛ㄹ로 이어져 성능을 저하시킬 가능성이 있습니다. 반면, LLaVA와 같은 모델들은 모든 vision tokens를 간단한 MLP connector를 통해 language model로 전달하여 높은 성능을 달성하지만, 그 대가로 계산 요구량이 증가합니다. 또한, 시각적 특징만을 기반으로 중요한 vision tokens를 감지하는 token compression 기술들은 single-modal 태스크에서는 유망한 성능을 보여주었으나, MLLMs에서의 시각 정보와 언어 정보 간의 상호작용을 충분히 활용하지 못합니다. 이는 계산 효율성과 모델 성능 사이의 trade-off를 부각시키며, 두 측면을 효과적이고 효율적으로 균형 있게 맞출 수 있는 해결책의 필요성을 시사합니다.

 

 이러한 과제를 해결하기 위해, 본 연구에서는 언어적 context와의 관련성을 기반하여 MLLMs의 Vision tokens 수를 동적으로 줄이는 간단하면서도 효과적인 방법인 Language-Guided Vision Token Pruning (LVPruning)을 제안합니다. 이는 vision tokens language tokens attention을 수행하여 중요도 점수를 계산하는 경량화된 cross-attention decision modules를 도입합니다. 이 relevance scoring을 통해 모델은 각 vision token을 유지할 지 아니면 pruning할 지를 결정할 수 있으며, 정보량이 적은 시각 데이터를 효과적으로 필터링합니다. 이러한 decision modules를 MLLM의 다양한 레이어에 통합함으로써, LVPruning은 모델이 더 깊은 레이어를 처리함에 따라 점진적인 token pruning을 가능하게 합니다. 학습 중에는 모든 기존 모델 파라미터를 freeze하고 삽입된 decision modules만 학습하므로, base model은 변경되지 않은 상태로 유지되며 pruning 메커니즘을 쉽게 적용하거나 제거할 수 있습니다.

 

본 연구의 기여는 세 가지 입니다. 첫 째, 위 그림 1에서 보여주듯이, LVPruning은 모델 성능에 실질적인 영향을 주지않으면서도 vision token의 90%까지를 pruning 함으로써 inference TFLOPs를 최대 62.1%까지 크게 줄일 수 있음을 입증했습니다. 둘째, 기존 MLLMs에 통합하기 쉽고 효과적이며, 원래 구조의 변경을 최소화하는 새로운 language-guided token pruning 메커니즘을 도입했습니다. 셋째, 본 방법은 재학습 없이도 inference 중에 token pruning 비율을 조절할 수 있어, 특정 요구 사항에 따라 효율성과 성능의 균형을 유연하게 맞출 수 있게 해줍니다. 다양한 multi-modal benchmarks에 대한 광범위한 실험을 통해, LVPruning이 정확한 multi-modal 컨텐츠를 이해하고 생성하는 능력을 유지하면서도 MLLMs의 효율성을 향상시키는 실용적인 해결책임을 보여줍니다.

 

2 Related Work

Multi-modal Large Language Models

MLLMs의 최근 발전은 visual 및 textual modalities의 통합을 크게 향상시켰습니다. BLIP-2는 Q-former를 vision encoder로 사용하여 사전 학습된 vision models와 language models를 연결하는 2단계 학습 프레임워크를 도입했으며, 효율적인 처리를 위해 응축된 vision tokens 세트를 효과적으로 생성했습니다. 이를 기반으로 InstructBLIP과 MiniGPT-4는 instruction tuning을 결합하여 복잡한 프롬프트를 따르고 다양한 태스크를 수행하는 모델의 능력을 개선했습니다. 대안적으로, LLaVA-1.5와 같은 모델들은 사전 학습된 vision encoders의 모든 vision tokens를 language model에 직접 입력합니다. 이 접근 방식은 더 풍부한 시각 정보 덕분에 높은 성능을 달성하지만, 상당한 computational overhead를 초래합니다. 이러한 모델들은 계산 효율성과 성능 사이의 trade-off를 보여주며, 정확도를 떨어뜨리지 않으면서 두 측면의 균형을 맞출 수 있는 접근 방식의 필요성을 강조합니다.

Efficient Transformers

Transformer 모델의 계산 효율성을 개선하기 위해 knowledge distillation, token merging/pruning, 그리고 quantization과 같은 많은 기술이 제안되었습니다. NLP 태스크의 경우, DistillBERT, MiniLM과 같은 방법들은 knowledge distillation을 사용하여 더 효율적인 추론을 위한 작은 모델들을 만듭니다. computer vision 태스크의 경우, 연구들은 이미지 분류에서의 중요도에 따라 tokens를 pruning하거나 merging 하는 데 집중합니다. 이러한 방식들은 추론 중에 정보량이 적은 패치를 식별하거나 유사한 tokens를 병합하여 vision tokens의 수를 줄입니다. 우리 연구와 가장 유사한 작업은 MLP layers를 사용하여 token pruning 결정을 예측하는 DynamicVit입니다. 그들은 이미지 분류 태스크를 위해 Vision-Transformer 기반 모델에 여러 pruning layers를 계층적으로 삽입했습니다. 그러나 이러한 방법들은 single-modal 목표를 위해 특별히 설계되었으며, multi-modal 설정에서의 과제를 다루지 않습니다. 본 연구는 이미지 이해 태스크 맥락에서 MLLMs를 위한 token pruning에 집중함으로써 차별화됩니다.

 

3 Methodology 

이 섹션에서는 LVPruning 프레임워크를 상세히 설명합니다. LVPruning은 모든 vision tokens MLP connector를 통해 language model로 전달하는 MLLMs를 위해 설계되었습니다. 이 구조는 Transformer 기반의 사전 학습된 CLIP vision encoder, MLP vision-language connector, 그리고 LLM backbone으로 구성됩니다. 먼저, 이미지 입력은 패치로 나뉘어 CLIP 모델에 의해 처리되며, 각 이미지 패치는 대표적인 vision token이 됩니다. 다음으로, vision-language connectorvision tokens를 LLM의 텍스트 공간 차원으로 투영하면, 결합된 vision 및 text tokens이 casual text generation을 위해 LLM에 입력됩니다. LLM의 특정 레이어에서 cross-attention decition modules는 추론을 가이드하기 위해 가장 중요한 토큰 (어텐션 점수가 가장 높은 토큰)을 동적으로 선택하고 불필요한 토큰을 제거합니다. 학습 중에는 attention masks를 적용하여 pruning된 vision tokens를 마스킹합니다. 중요한 점은, token pruning 후 positional embeddings를 업데이트하는 대신, 표준 LLMs에서 사용되는 것처럼 남은 vision tokens에 대해 기존의 positional embeddings를 유지한다는 것입니다.

 

3.1 Cross-Attention Decision Module

 이제 Token pruning을 위해 설계된 cross-attention decision module의 세부 아키텍처를 설명합니다. vision tokens를 선택하고 폐기하는 역할을 담당하는 decision modulecross-attention layersMLP layer로 구성됩니다. 점진적인 pruning을 위해 이러한 모듈의 여러 인스턴스가 LLM backbone의 서로 다른 레이어에 삽입됩니다. $H \in \mathbb{R}^{N \times d}$를 LLM hidden layer의 출력이라고 가정하며, 여기서 $N$sequence length이고 $d$hidden representations의 차원입니다. $H$vision tokenstext tokens의 하위 집합을 포함합니다.  vision tokens의 인덱스 집합을 $I_V = {n_{v_i} | v_i \in \mathbb{N}, 0 \leq v_i < N}$로 정의하고, text tokens의 인덱스 집합을 $I_T = {n_{t_i} | t_i \in \mathbb{N}, (0 < t_i \leq N) \wedge (t_i \notin I_V)}$로 정의합니다.

 

본 연구에서는 vision tokensquery tokens

 

로 사용하고, text tokensKeyValue tokens

로 사용합니다.

 

여기서 $W_q, W_k, W_v$linear projection layers입니다. 그런 다음 attention matrix를 계산하고 그 출력을 FFN에 공급합니다.

본 연구에서는 FFN 출력을 linear layer $\mathbf{W}_O$에 공급하여 vision token을 유지하거나 제거하기 위한 점수를 예측합니다.

여기서 $\gamma_{i,0}$은 vision token $\mathbf{H}{I_V,i}$를 유지하기 위한 점수를 나타내고, $\gamma{i,1}$은 해당 토큰을 제거하기 위한 점수를 나타냅니다. vision token pruning 결정은 $\boldsymbol{\gamma}$를 기반으로 생성됩니다. 결정을 생성하고 적용하는 메커니즘은 학습과 추론 사이에 차이가 있습니다. 여러 decision modulesLLM의 서로 다른 레이어에 삽입되어, vision tokensLLM 전체에 걸쳐 점진적으로 pruning됩니다.

 

3.2 End-to-End Training

Cross-attention decision module의 출력인 $\boldsymbol{\gamma}$를 기반으로 토큰 pruning 결정을 생성하고 적용하는 과정이 differentiable하도록 보장하기 위해, 본 연구에서는 Gumbel-softmax 분포를 $\boldsymbol{\gamma}$에 적용하여 이를 one-hot vectors $\mathbf{D}_{GS}$로 재분배합니다. 그런 다음 각 벡터의 첫 번째 차원을 결정 $\mathbf{D}$로 사용하여 특정 토큰을 유지할지 여부를 결정합니다.

 

여기서 $D_i = 1$vision token $\mathbf{H}_{I_V,i}$를 유지함을 의미하며, 그 반대도 마찬가지입니다. 결정에 의해 주어지는 유지된 토큰의 수는 학습 중에 고정되지 않으며, 원치 않는 vision tokens를 직접 제거하는 것은 batch processing을 방해합니다. 이 우려를 해결하기 위해, 본 연구에서는 $\mathbf{D}$를 기반으로 visionlanguage tokens 모두에 대해 attention masks $\mathbf{M}$을 생성합니다. 구체적으로는 다음과 같습니다.

따라서 $\mathbf{M}$은 모든 language tokens에는 마스크 값 1이 할당되는 반면, vision tokens는 토큰 pruning 결정 $\mathbf{D}$에 따라 마스킹되도록 구성됩니다. 또한, 수치적 안정성을 높이기 위해 $\mathbf{M}$의 모든 대각 요소는 1로 설정됩니다.

 

그러나 $\mathbf{M}$은 원래의 causalpadding attention masks를 무시합니다. 이를 해결하기 위해, 먼저 원래의 attention mask $\mathbf{\bar{M}}$을 가공되지 않은 attention scores에 적용하여 causal attention matrix $\mathbf{\bar{A}}$를 얻습니다. 그런 다음 $\mathbf{\bar{A}}$에 Softmax 연산과 함께 $\mathbf{M}$을 적용하여 최종 attention matrix $\mathbf{\hat{A}}$를 얻습니다. 구체적으로, $l$ 레이어의 decision module에서 생성된 attention mask$\mathbf{M}_l$이라 정의합니다. $l+x$ 레이어에서의 attention scores는 다음과 같이 계산됩니다.

 여기서 $(l + x) \in \mathbb{N} < l'$이며, $l'$은 다음 decision module의 레이어 위치입니다. 만약 $\mathbf{M}_{i,j} = 0$이면, 최종 attention matrix에서 토큰 $\mathbf{H}_j$에 대한 attention score는 0이 되며, 결과적으로 $\mathbf{H}_j$는 다른 어떤 토큰에도 기여하지 않게 됩니다. 또한, $l$ 레이어와 $l'$ 레이어($l' > l$)에서 얻은 토큰 pruning 결정을 각각 $\mathbf{D}^l, \mathbf{D}^{l'}$라고 정의할 때, $\mathbf{D}^{l'}$를 다음과 같이 업데이트합니다.

 

여기서 $\odot$element-wise production이며, 이는 이전에 제거된 vision token이 다시는 사용되지 않음을 의미합니다.

 

결론적으로, 식 (7) - (9)는 전체 토큰 수를 변경하지 않으면서 원치 않는 vision tokens가 다른 토큰에 미치는 영향을 제거합니다. 식 (5), (6), (8), (9)를 통해 토큰 pruning 결정을 생성하고 적용하는 과정은 완전히 differentiable해집니다. 이 두 가지 요소는 LVPruningend-to-end training 능력을 달성하게 합니다.

 

Training Objectives

 LVPruning의 학습 목표는 decision modules가 서로 다른 레이어에서 미리 정해진 비율로 vision tokens를 제거하도록 가르치는 동시에, Token pruning에도 불구하고 MLLMsvision instruction-following 능력을 유지하도록 fine-tuning하는 것입니다. 주요 학습 목표는 instruction tuning을 위한 causal language modeling이며, 이를 $L_{causal}$이라 합니다.

 

 또한, 각 decision module에서 유지되는 vision tokens의 비율이 사전 정의된 값과 일치하도록 하기 위해, $S$개의 decision modules를 특정 레이어 인덱스 $L_{idx} = [l_1, ..., l_S]$에 삽입하고, 목표 토큰 유지 비율 $P = [\rho_1, ..., \rho_S]$를 설정합니다. 이를 강제하기 위해, 우리는 토큰 pruning 결정을 제약하는 Mean Squared Error (MSE) loss인 $L_{ratio}$를 적용합니다.

여기서 모든 실험에서 $\beta = 0.5$로 설정했습니다. 최종 학습 목표는 $L_{causal}$과 $L_{ratio}$의 가중 합입니다.

 

3.3 Inference

학습 단계에서는 무관한 vision tokens의 영향을 배제하기 위해 attention masks가 사용됩니다. 그러나 inference 단계에서는 계산 비용을 줄이기 위해 이러한 토큰들을 실제로 제거해야 하며, 이는 상당한 실무적 어려움을 야기합니다. 첫째, D에 의해 결정되는 유지된 vision tokens의 수량이 가변적이어서 batch inference 프로세스가 복잡해집니다. 둘째, 현대의 LLMs는 일반적으로 각 레이어의 토큰에 대해 positional embeddings를 사용합니다. 학습 중에 확인된 분포와의 정렬을 보장하기 위해 유지된 토큰에 대해 원래의 positional embeddings를 유지하는 것이 필수적입니다.

 

첫 번째 문제를 극복하기 위해, 추론 중에 사용할 토큰 유지 비율 세트 $\hat{\mathbf{P}} = [\hat{\rho}_1, \dots, \hat{\rho}_S]$를 정의합니다. 추론 비율은 학습 시의 비율과 반드시 같을 필요는 없다는 점에 유의하십시오. $s$번째 token pruning 레이어에서, 먼저 다음과 같이 decision scores를 정렬합니다.

 

 

그런 다음 점수가 가장 높은 상위 $k_s = \hat{\rho}_s \times |I_V|$개의 vision tokens를 유지합니다. 전체 vision tokens 중 유지된 인덱스는 $\hat{I}s = {Q{s, 1:k_s}}$이며, 전체 토큰 중 유지된 vision token 인덱스는 $\hat{I}v^s = I{v, \hat{I}_s}$입니다. 우리는 첫 번째 레이어의 positional embeddings$\mathbf{PE}_1$으로 정의합니다. 각 pruning 이후에도 visiontext tokens 모두에 대한 positional embeddings가 변경되지 않도록 보장하기 위해, $s$번째 token pruning 레이어의 positional embeddings는 다음과 같습니다.

4 Experimental Setup

본 실험의 목적은 token pruning 기술을 사용하여 MLLMs의 효율성을 높일 수 있는지 확인하는 것입니다. 구체적으로 (i) LVPruning이 성능을 유지하며 computational costs를 얼마나 절감하는지, (ii) 최신 MLLMs와 비교해 효율성과 성능 사이의 균형을 얼마나 잘 잡는지 조사합니다.

 

4.1 Implementation Details

 실험은 LLaVA-1.5-7B를 베이스로 하며, LLaMA LLM의 1, 8, 16번째 레이어 뒤에 총 3개의 decision modules를 삽입했습니다. 학습 시 token kept ratio$\rho = 0.5$로 설정했습니다. 기존 LLaVA 파라미터는 모두 freeze하고 삽입된 모듈만 학습시켰으며, 8개의 A100 GPU를 사용했습니다. 추론 시에는 별도의 튜닝 없이 세 가지 비율($\rho = 0.6, 0.5, 0.45$)을 평가했습니다.

 

4.2 Dataset and Benchmarks

LLaVA-1.5 데이터셋 중 이미지가 포함된 약 620k 샘플을 사용하여 1 epoch 동안 학습했습니다. 성능 평가는 VQAv2, GQA, MMBench, POPE 등 총 9개의 멀티모달 벤치마크를 통해 이루어졌으며, 이를 통해 inference FLOPs와 정확도를 측정했습니다

 

5 Experimental Results

5.1 Performance Preservation and Computation Cost Reduction

 LVPruning은 9개의 multi-modal benchmarks에서 성능을 거의 유지하면서도 inference cost를 획기적으로 낮췄습니다. token kept ratio($\rho$)가 0.6일 때, TFLOPs를 약 52.6% 절감(8.38에서 3.97로 감소)했음에도 VQAv2, GQA, VizWiz 등에서 성능 하락이 거의 없거나 오히려 향상되었습니다. 심지어 $\rho = 0.45$로 설정하여 TFLOPs를 66.7%나 줄였을 때도 성능 저하가 매우 적어, 모델의 efficiency와 태스크 성능 사이의 균형이 매우 뛰어남을 입증했습니다.

 

5.2 Comparisons with state-of-the-art MLLMs

 LVPruning($\rho = 0.5$)을 최신 Q-former 기반 모델들과 비교한 결과, 경쟁력 있는 inference FLOPs로 더 우수한 성능을 달성했습니다. VQAv2 벤치마크에서 BLIP2-14BIDEFICS-9B보다 높은 정확도를 기록했으며, 특히 VizWiz 태스크에서는 기존 모델들을 큰 차이로 앞섰습니다. 또한 MMBenchLLaVA-Wild 등 다양한 instruction following 벤치마크에서도 일관되게 경쟁력 있는 결과를 보여주며, computational efficiency와 성능 사이의 최적의 trade-off를 증명했습니다.

 

6 Conclusion

본 연구에서는 기존 MLLMs에 최소한의 아키텍처 변경으로 통합할 수 있는 새로운 language-guided vision token pruning 방식인 LVPruning을 소개합니다. LVPruning은 언어 토큰을 기반으로 각 vision token에 대한 relevance scores를 계산하며, LLM 전체에 걸쳐 중복된 토큰을 점진적으로 제거합니다. 중간 레이어까지 최대 90%의 vision tokens를 제거하여, 9개의 multi-modal benchmarks에서 평균 성능 손실이 약 0.45%에 불과한 상태로 FLOPs를 62.1% 감소시켰습니다. 이는 LVPruningmulti-modal tasks에서 성능을 유지하면서 MLLM 효율성을 향상시키는 실용적인 해결책임을 보여줍니다.

 

7 Limitations

LVPruning은 계산 부하를 줄이는 데 큰 가능성을 보여주었으나, 본 연구의 몇 가지 한계점도 인정해야 합니다. 본 연구의 평가는 특정 benchmarks 세트에서 수행되었습니다. 다른 데이터셋이나 실제 애플리케이션에서의 LVPruning 성능은 아직 탐구되지 않은 상태로 남아 있습니다. 따라서 LVPruningcomputational overhead를 줄이는 데 효과적일지라도, 필수적인 시각 정보가 손상되지 않도록 보장하기 위해 이 방법을 적용할 때는 각 태스크의 구체적인 요구 사항을 고려하는 것이 중요합니다. 향후 연구에는 LVPruning의 실질적인 영향을 더 잘 이해하기 위해 인간의 피드백을 통한 성능 평가가 포함될 수 있습니다.