
SnapKV's Introduction

SnapKV 📷

We introduce SnapKV, an innovative, out-of-the-box KV cache compression method.

Requirements

Currently tested with transformers==4.37.0; compatibility with higher versions has not yet been verified.

transformers>=4.36
flash-attn==2.4.0

Installation

git clone git@github.com:FasterDecoding/SnapKV.git
cd SnapKV
pip install -e .

Quick Start

Use SnapKV-optimized Models

For example:

from snapkv.monkeypatch.monkeypatch import replace_mistral
replace_mistral() # Apply the monkey patches to enable SnapKV

Check the example notebook.
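A fuller end-to-end sketch (not from the repo: the checkpoint name, dtype, and generation settings are illustrative assumptions; the patch is applied before the model is loaded, which is the safe order for monkey patching):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from snapkv.monkeypatch.monkeypatch import replace_mistral

replace_mistral()  # patch Mistral's attention before the model is instantiated

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # the repo pins flash-attn==2.4.0
    device_map="auto",
)

prompt = "Summarize the following document: ..."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))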

Customize Your SnapKV-optimized Models

SnapKV can be easily integrated with other models.

You can follow the comments marked with [SnapKV] in the existing models to construct your own. (Currently we support the Llama family, Mistral, and Mixtral.)

The detailed algorithm of SnapKV is in snapkv_utils.py; a rough sketch of its selection step follows.
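For orientation, here is a self-contained sketch of that selection step, extrapolated from the update_kv snippet quoted in the issues below. The pooling and aggregation details (kernel_size, summing softmax scores) are assumptions rather than the repo's exact code, and causal masking inside the observation window is omitted for brevity:

import math
import torch
import torch.nn.functional as F

def snapkv_select(key_states, query_states, value_states,
                  window_size=32, max_capacity_prompt=1024, kernel_size=7):
    # All tensors: (bsz, num_heads, q_len, head_dim), called once at prefill.
    bsz, num_heads, q_len, head_dim = query_states.shape
    if q_len < max_capacity_prompt:
        return key_states, value_states  # short prompt: keep everything

    # 1. Score every prefix position using only the last `window_size` queries.
    attn = torch.matmul(query_states[..., -window_size:, :],
                        key_states.transpose(2, 3)) / math.sqrt(head_dim)
    attn = F.softmax(attn, dim=-1, dtype=torch.float32)
    # Sum the observation window's votes for each prefix position.
    scores = attn[..., :-window_size].sum(dim=-2)  # (bsz, heads, q_len - window_size)

    # 2. Max-pool the scores so selected positions bring local context along.
    scores = F.max_pool1d(scores.reshape(bsz * num_heads, 1, -1),
                          kernel_size=kernel_size, padding=kernel_size // 2,
                          stride=1).reshape(bsz, num_heads, -1)

    # 3. Keep the top-(max_capacity_prompt - window_size) prefix positions per head.
    top_k = max_capacity_prompt - window_size
    idx = scores.topk(top_k, dim=-1).indices.sort(dim=-1).values
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    k_sel = key_states[..., :-window_size, :].gather(2, idx)
    v_sel = value_states[..., :-window_size, :].gather(2, idx)

    # 4. The observation window itself is always kept.
    k_out = torch.cat([k_sel, key_states[..., -window_size:, :]], dim=2)
    v_out = torch.cat([v_sel, value_states[..., -window_size:, :]], dim=2)
    return k_out, v_out

The key design choice: only the last window_size queries vote, so the scoring matmul costs O(window_size · n) memory rather than O(n²).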

Partial Results

[Figure] Comprehensive experiment results on LongBench.

[Figure] Pressure test results on Needle-in-a-Haystack.

TODO

  • Add observation experiments for reproduction.
  • Add LongBench for reproduction.
  • Explore prompt-phase compression.

Citation

If you find this project helpful, please consider citing our report 😊

@article{li2024snapkv,
  title={SnapKV: LLM Knows What You are Looking for Before Generation},
  author={Li, Yuhong and Huang, Yingbing and Yang, Bowen and Venkitesh, Bharat and Locatelli, Acyr and Ye, Hanchen and Cai, Tianle and Lewis, Patrick and Chen, Deming},
  journal={arXiv preprint arXiv:2404.14469},
  year={2024}
}

SnapKV's People

Contributors

leeyeehoo · wendyh1108 · ctlllll


SnapKV's Issues

What prompt was used in the Needle-in-a-Haystack test?

I tried to reproduce the needle test with LWM-Text-Chat-1M, but the model just refuses to answer. I have tried the following prompts in the needle test, and the model just generates </s>:

<s>[INST] <<SYS>>
You are a helpful AI bot that answers questions for a user. Keep your response short and direct
<</SYS>>
{ context }

{retrieval_question} Don't give information outside the document or repeat your findings
[/INST]

and

<s>[INST] <<SYS>>
You are a helpful AI bot that answers questions for a user. Keep your response short and direct
<</SYS>>
{ context }

{retrieval_question} Don't give information outside the document or repeat your findings
[/INST]</s>

Questions on paper and code [prompting for Mistral, positional index, minor errors & questions in paper]

Hello :)
Thank you for the excellent work and for sharing your code. I've learned a lot and have a few questions about the paper and settings:

  • In Figures 2 and 3, what specifically do "prompt" and "context" represent? My guess is that "prompt" refers to the entire input sequence length, and "context" includes specific instructions. Should their labels be switched?

  • Could you share the specific prompt details applied in the Mistral experiment for measuring LongBench performance? Using the default LongBench settings, I observed lower performance overall, particularly in Qasper:

    • For Mistral-v2: Full: 28.92, SnapKV 2048: 26.43, 4096: 28.42 (reported: 33.06/32.47/33.36 respectively).
    • Intuitively, I think that moving the task-specific prompt (e.g., "You are given a scientific article and a question. Answer the question....") from LongBench to the end of the input sequence, so that it falls within the window range, might improve performance. Was there any such modification?
  • Following the SnapKV methodology, I expect the KV cache size to always be bounded by max_capacity_prompt. Yet why does an OOM error occur when a certain length is exceeded (131K in Sec. 5.1)? Could it be due to recalculating the attention weights in Line 9 of Listing 1?

Additionally, there seems to be a minor error in Figure 7 where both the top and bottom plots are labeled as "without Pooling." It might be less confusing to label the bottom plot as "with Pooling."

Thank you for any insights you can provide. I really appreciate the motivation and methodology behind your work!

Grouped query attention implementation

Thank you for your nice work and for sharing the code. Grouped-query attention is used in the Mistral and Mixtral models. However, I found that the implementation in snapkv_utils.py is written for multi-head attention, so it may not be correct for grouped-query attention.
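One possible direction (my assumption, not a confirmed fix): expand the grouped KV heads to match the query heads before scoring, mirroring the repeat_kv helper inside transformers' Llama/Mistral modeling code, then reduce the scores back to one vote per KV head before the top-k so the compressed cache stays grouped:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seq_len, head_dim) -> (bsz, num_kv_heads * n_rep, seq_len, head_dim)
    bsz, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, slen, head_dim)

# Inside update_kv, score with the expanded keys, then reduce per group, e.g.:
# scores: (bsz, num_heads, L) with num_heads = num_kv_heads * n_rep
# grouped = scores.view(bsz, num_kv_heads, n_rep, -1).sum(dim=2)  # one vote per KV head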

Can't run LongBench!

Here is my env. The version of transformers meets the requirement in monkeypatch.py:

torch==2.2.0
transformers==4.37.0

The traceback is as follows:

>> python pred_snap.py --model llama2-7b-chat-4k --compress_args_path ablation_c1024_w32_k7_maxpool.json

Traceback (most recent call last):
File "experiments/LongBench/pred_snap.py", line 321, in
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "experiments/LongBench/pred_snap.py", line 132, in get_pred_single_gpu
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 1474, in generate
return self.greedy_search(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 2335, in greedy_search
outputs = self(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
outputs = self.model(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1035, in forward
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 398, in _prepare_4d_causal_attention_mask_for_sdpa
expanded_4d_mask = attn_mask_converter.to_4d(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 137, in to_4d
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
RuntimeError: The size of tensor a (3509) must match the size of tensor b (7017) at non-singleton dimension 3

I think the reason is that DynamicCache.get_usable_length conflicts with the causal-mask construction in _prepare_4d_causal_attention_mask_for_sdpa.

I would like to know how I can quickly fix this. Thanks :)
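One workaround to try (an assumption based on the traceback, not a confirmed fix): the failure happens on the SDPA mask path, while the repo pins flash-attn==2.4.0, so loading the model with flash attention enabled should bypass _prepare_4d_causal_attention_mask_for_sdpa entirely:

import torch
from transformers import AutoModelForCausalLM

# The checkpoint name here is illustrative; use the one resolved by pred_snap.py.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # skip the SDPA mask preparation
)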

It seems that SnapKV needs to do "prefill" at least once before the prompt can be compressed

SnapKV needs a full-length q, k matmul before its first self-attention, which has $O(n^2)$ space complexity. So does SnapKV need to be able to do "prefill" at least once before the prompt can be compressed?

After that, it can reduce the memory footprint during the decoding phase.

def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
    # check if prefix phase
    assert key_states.shape[-2] == query_states.shape[-2]
    bsz, num_heads, q_len, head_dim = query_states.shape
    if q_len < self.max_capacity_prompt:
        return key_states, value_states
    else:
        # score all keys against only the last `window_size` queries
        attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)
        ...
