Retrieval Head

This is the open-source code for the paper: Retrieval Head Mechanistically Explains Long-Context Factuality.

This code is built on top of Needle In A Haystack.

Retrieval Head Detection

An algorithm that statistically calculates the retrieval score of each attention head in a transformer model. Because FlashAttention cannot return the attention matrix, the algorithm first caches the context with FlashAttention and then applies normal attention during decoding.
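This README does not spell out the scoring rule, so the sketch below assumes a copy-paste criterion: a head scores a hit at a decoding step when its most attended context position lies inside the needle and holds the token that was just generated. The score is then the fraction of needle positions the head copied from. The function name, argument names, and toy data are all hypothetical; check the paper and retrieval_head_detection.py for the actual definition.

```python
import numpy as np

def retrieval_score(attn, context_ids, generated_ids, needle_start, needle_end):
    """Assumed scoring rule for ONE head on ONE sample.

    attn: array of shape (num_decoding_steps, context_len) with the head's
          attention weights over the context at each decoding step.
    The needle occupies context positions [needle_start, needle_end).
    """
    copied_positions = set()
    for step, token in enumerate(generated_ids):
        top = int(np.argmax(attn[step]))  # most attended context position
        inside_needle = needle_start <= top < needle_end
        if inside_needle and context_ids[top] == token:
            copied_positions.add(top)
    return len(copied_positions) / (needle_end - needle_start)

# Toy data (made up): the needle is tokens 7, 8, 9 at positions 2..4.
context_ids = [5, 6, 7, 8, 9, 10]
generated_ids = [7, 8, 3]
attn = np.array([
    [0.10, 0.10, 0.60, 0.10, 0.05, 0.05],  # attends to position 2 (token 7)
    [0.10, 0.10, 0.10, 0.60, 0.05, 0.05],  # attends to position 3 (token 8)
    [0.70, 0.10, 0.05, 0.05, 0.05, 0.05],  # attends outside the needle
])
score = retrieval_score(attn, context_ids, generated_ids, 2, 5)
print(score)
```

In this toy case the head copies two of the three needle tokens. Averaging such per-sample scores over many detections would give one number per head, which matches the list-of-scores format of the result files.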

Environment

Core: pytorch=2.0.1, transformers=4.37.2, flash-attn=2.5.6 (my environment)

Other: rouge_score

A single 80G GPU is enough to run detection up to a context length of 50K.

Usage:

python retrieval_head_detection.py  --model_path $path_to_model --s 0 --e 50000

We find that only a few samples are needed to stably detect some of the strongest retrieval heads. If you are in a hurry or have no large GPUs available, you can set '--e' to a lower value, e.g.

python retrieval_head_detection.py  --model_path $path_to_model --s 0 --e 5000

Retrieval scores will be written to './head_score/$model_name.json'.

Currently implemented model families: Llama (Llama-2-7B-80K), Yi, Qwen, Mistral

Results:

All detection results are saved in "./head_score/*.json", where each head is saved in the format of

{layer-head_id: [list of retrieval scores across detections]}

Directly load a result file for analysis

## Load the head score file (llama-2-7b-80k as an example)
import json
import numpy as np
with open('./head_score/llama-2-7b-80k.json') as file:
    head_list = json.loads(file.readline())
## Average each head's retrieval scores, then rank heads from highest to lowest
## Keys look like "layer-head_id"; split them into [layer, head_id]
head_score_list = [([int(x) for x in key.split("-")], np.mean(scores)) for key, scores in head_list.items()]
head_score_list = sorted(head_score_list, key=lambda x: x[1], reverse=True)
## Keep the 10 strongest heads, with scores rounded to two decimals
top_retrieval_heads = [[head, round(float(score), 2)] for head, score in head_score_list][:10]
print(top_retrieval_heads)
'''
Head:[16, 19],   Retrieval Score: 0.94      Head:[11, 15],   Retrieval Score: 0.92      
Head:[8, 26],    Retrieval Score: 0.8       Head:[6, 9],     Retrieval Score: 0.62        
Head:[7, 12],    Retrieval Score: 0.61      Head:[17, 22],   Retrieval Score: 0.56
Head:[11, 2],    Retrieval Score: 0.46      Head:[6, 16],    Retrieval Score: 0.44
Head:[19, 15],   Retrieval Score: 0.42      Head:[21, 30],   Retrieval Score: 0.4
'''
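Beyond taking the top 10, it can be handy to keep every head whose average score clears a cutoff. The following is a minimal sketch under assumptions: the helper name `heads_above`, the threshold value, and the in-memory example dictionary are all made up for illustration and only mimic the `{layer-head_id: [scores]}` format described above.

```python
import numpy as np

def heads_above(head_scores, threshold):
    """Return [((layer, head), mean_score), ...] for heads whose mean score
    is at least `threshold`, sorted from strongest to weakest."""
    kept = []
    for key, scores in head_scores.items():
        layer, head = (int(x) for x in key.split("-"))
        mean_score = float(np.mean(scores))
        if mean_score >= threshold:
            kept.append(((layer, head), mean_score))
    return sorted(kept, key=lambda item: item[1], reverse=True)

# Made-up scores in the same shape as a head_score JSON file.
example_scores = {
    "16-19": [1.0, 0.9, 0.92],
    "11-15": [0.9, 0.94],
    "0-0": [0.0, 0.0],
}
strong_heads = heads_above(example_scores, threshold=0.1)  # 0.1 is an arbitrary cutoff
print([head for head, _ in strong_heads])
```

The same function should work on the dictionary loaded from a real result file, assuming that file follows the format above.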

Influence on Needle-in-a-Haystack

Masking is implemented either by masking the given head in the attention matrix or by masking the query in FlashAttention.
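To illustrate the first option, here is a rough sketch of zeroing selected heads in one layer's attention weights. It uses plain NumPy arrays rather than the model's real tensors. The function name, the array layout `(num_heads, q_len, k_len)`, and the `[layer, head]` pair format are assumptions for illustration, not the code in needle_in_haystack_with_mask.py.

```python
import numpy as np

def mask_heads(attn_weights, layer_idx, heads_to_mask):
    """Zero the attention weights of the listed heads for one layer.

    attn_weights: array of shape (num_heads, q_len, k_len) for layer `layer_idx`.
    heads_to_mask: iterable of [layer, head] pairs; pairs for other layers are ignored.
    Returns a new array; the input is left unchanged.
    """
    masked = attn_weights.copy()
    for layer, head in heads_to_mask:
        if layer == layer_idx:
            masked[head] = 0.0  # this head now contributes nothing to the layer output
    return masked

# Toy example: 4 heads with uniform attention over 3 keys, for layer 16.
weights = np.full((4, 2, 3), 1.0 / 3.0)
masked_weights = mask_heads(weights, layer_idx=16, heads_to_mask=[[16, 1], [11, 2]])
print(masked_weights[1].sum(), masked_weights[2].sum())
```

Only head 1 is zeroed here, because the second pair refers to a different layer.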

Usage:

Set --mask_top to K > 0 to mask out the top K retrieval heads, to K < 0 to mask out |K| random heads, or to 0 for no masking.

A single 80G GPU can test up to ~70K length; two 80G GPUs can test up to 100K length.
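The sign convention of --mask_top can be sketched as a small selection function. This is an illustration only. The function name and the seed parameter are hypothetical, and drawing the random heads from outside the top |K| is an assumption based on the "non-retrieval" wording used for the random baseline, not a description of the script's actual sampling.

```python
import random

def select_heads_to_mask(ranked_heads, mask_top, seed=0):
    """ranked_heads: list of [layer, head] pairs, strongest retrieval head first."""
    if mask_top > 0:
        return ranked_heads[:mask_top]      # top-K retrieval heads
    if mask_top < 0:
        k = -mask_top
        candidates = ranked_heads[k:]       # assumption: skip the top-K heads
        rng = random.Random(seed)
        return rng.sample(candidates, k)    # raises ValueError if too few candidates
    return []                               # mask_top == 0: no masking

ranked = [[16, 19], [11, 15], [8, 26], [6, 9], [7, 12], [17, 22]]
top_two = select_heads_to_mask(ranked, 2)
random_two = select_heads_to_mask(ranked, -2)
print(top_two, random_two)
```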

Masking top 30 retrieval heads vs 30 random heads:

python needle_in_haystack_with_mask.py --mask_top 30 --s 1000 --e 100000 --model_path $path_to_model  # Results will be written to './results/graph/llama-2-7b-80k_block_top30'
python needle_in_haystack_with_mask.py --mask_top -30 --s 1000 --e 100000 --model_path $path_to_model  # Results will be written to './results/graph/llama-2-7b-80k_block_random30'

Results and Visualization:

Replace 'model_name' in './viz/CreateVizFromLLMTesting.ipynb' with the folder name of the Needle-in-a-Haystack results.

Masking the top 30 retrieval heads for Llama-2-7b-80K: [figure]

Masking 30 random non-retrieval heads for Llama-2-7b-80K: [figure]

retrieval_head's Issues

Difference between real_needle and needle in needle.jsonl?

Will the choice of needle cause different observational phenomena? What is the difference between these two needles in

{"needle": "A new report from the WMO shows that records were once again broken, and in some cases smashed, for greenhouse gas levels, surface temperatures, ocean heat and acidification.", "question": "What does a new report from WMO shows ?", "real_needle": "records were once again broken, and in some cases smashed, for greenhouse gas levels, surface temperatures, ocean heat and acidification."}
?
