
easycontext's Introduction

EasyContext

🤗 Hugging Face

Memory optimization and training recipes to extrapolate language models' context length to 1 million tokens, with minimal hardware.

Updates

  • [06/25] EasyContext is part of our effort to develop LongVA, a long context vision language model. Check it out if it interests you!
  • [05/11] Add Ulysses.
  • [05/06] Add distractors (multi-needle) in the NIAH evaluation script. You can set the number of distractors using --num_distractor.
  • [05/06] IMPORTANT! If you want to use eval_needle.py to evaluate Llama-3 models, you need to add one extra space (" ") after QUESTION_STR (see the snippet below this list). I believe this has something to do with the tokenizer.
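A minimal illustration of the Llama-3 workaround above (the question wording shown here is a placeholder, not the exact string in eval_needle.py):

# eval_needle.py -- only when evaluating Llama-3 models
QUESTION_STR = "What is the special magic number mentioned in the context?"  # placeholder wording
QUESTION_STR = QUESTION_STR + " "  # extra trailing space; appears to be needed by the Llama-3 tokenizer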

What is this?

Many companies have been promoting their models' capability to handle long context. For those outside these companies, a context of 1 million tokens still seems somewhat magical or prohibitively compute-hungry. This repo aims to demystify long-context scaling and show that it is actually quite straightforward.

This repo does not propose new ideas. Instead, we showcase how to combine existing techniques to train language models with a context length of:

  • 700K context length with 8 A100s (Llama-2-7B).
  • 1M context length with 16 A100s (Llama-2-13B).

No approximations are used. The models can be trained with full finetuning, full attention, and full sequence length. Our training script (train.py) has less than 200 lines of code.

The techniques used are sequence parallelism, DeepSpeed ZeRO-3 offload, FlashAttention, and activation checkpointing.

We support several sequence parallel methods: zigzag ring attention (zigzag_ring_attn), dist flash attention (dist_flash_attn), and Ulysses.

We then train Llama-2-7B on 8 A100s by gradually increasing its RoPE base frequency to 1B. Notably, the model is trained with only a 512K sequence length, yet it generalizes to nearly 1M context.
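As a rough sketch of what "gradually increasing the RoPE base frequency" looks like in code (the staged values below are illustrative assumptions, not the repo's exact recipe; see train_scripts/ and the --seq-length / --rope-theta flags of train.py for the real schedule):

from transformers import AutoConfig, LlamaForCausalLM

# Illustrative staged schedule: finetune at progressively longer sequences while raising
# the RoPE base frequency (theta). The values are assumptions; train_scripts/ is authoritative.
stages = [
    {"seq_length": 64_000,  "rope_theta": 5_000_000},
    {"seq_length": 256_000, "rope_theta": 100_000_000},
    {"seq_length": 512_000, "rope_theta": 1_000_000_000},  # the "1B" base frequency mentioned above
]

ckpt = "meta-llama/Llama-2-7b-hf"
for stage in stages:
    config = AutoConfig.from_pretrained(ckpt)
    config.rope_theta = stage["rope_theta"]  # larger theta lets RoPE cover longer positions
    model = LlamaForCausalLM.from_pretrained(ckpt, config=config,
                                             _attn_implementation="flash_attention_2")
    # ... finetune on sequences of stage["seq_length"] tokens, save the checkpoint,
    #     then start the next stage from it ...
    ckpt = f"./output/llama2_7b_{stage['seq_length']}"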

Usage

from easy_context import prepare_seq_parallel_inputs, apply_seq_parallel_monkey_patch, prepare_dataloader
from transformers import LlamaForCausalLM

# Swap the attention implementation from plain flash attention to a sequence-parallel
# variant such as "dist_flash_attn" or "zigzag_ring_attn" (Ulysses is also supported)
apply_seq_parallel_monkey_patch("dist_flash_attn", "llama")

# Make sure you toggle on flash_attention_2
model = LlamaForCausalLM.from_pretrained(model_name, _attn_implementation="flash_attention_2")
accelerator = ...
train_dataloader = ...
prepare_dataloader("dist_flash_attn", train_dataloader, accelerator)

# In your training loop...
for step, batch in enumerate(train_dataloader):
    # Shard the sequence across processes
    prepared = prepare_seq_parallel_inputs(
        "dist_flash_attn",
        batch["input_ids"],
        batch["position_ids"],
        batch["target_ids"],
        accelerator.process_index,
        accelerator.num_processes,
        accelerator.device,
    )
    local_input_ids = prepared["local_input_ids"]
    local_position_ids = prepared["local_position_ids"]
    local_target_ids = prepared["local_target_ids"]
    # Then do the model forward as usual
    logits = model(local_input_ids, position_ids=local_position_ids).logits
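The snippet stops at the logits; a minimal sketch of turning them into a training loss on each shard is shown below. It assumes local_target_ids is aligned with local_input_ids and that a standard causal shift still applies; check train.py for the exact loss used in this repo.

import torch.nn.functional as F

# Hedged sketch: shifted causal-LM cross entropy on the local sequence shard.
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = local_target_ids[:, 1:].contiguous()
loss = F.cross_entropy(
    shift_logits.view(-1, shift_logits.size(-1)),
    shift_labels.view(-1),
    ignore_index=-100,  # assumes ignored/padded positions are marked with -100
)
accelerator.backward(loss)  # each rank backpropagates through its own shard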

Results

Needle-in-a-haystack

There are still some red cells in the heatmap. I am not sure whether instruction tuning or heavier long-context training would help.

Perplexity

Perplexity is tested on two documents of 500K to 600K tokens from the proofpile test set (the longest I could find).

Installation

The code is tested on Python 3.10.0, PyTorch 2.4.0 (nightly), and CUDA 11.8.

conda create -n easycontext python=3.10 -y && conda activate easycontext
pip install --pre torch==2.4.0.dev20240324  --index-url https://download.pytorch.org/whl/nightly/cu118
pip install packaging &&  pip install ninja && pip install flash-attn --no-build-isolation --no-cache-dir
pip install -r requirements.txt

Note that the PyTorch nightly build is a must: I noticed that PyTorch 2.2.0 OOMs at 700K context length on 8 A100s.

Evaluation

Needle-in-a-haystack

accelerate launch --num_processes 8 --config_file  accelerate_configs/deepspeed_inference.yaml  --main_process_port 6000 eval_needle.py \
    --model PY007/EasyContext-1M-Llama-2-7B  \
    --max_context_length 1000000 \
    --min_context_length 50000 \
    --context_interval   50000 \
    --depth_interval 0.1 \
    --num_samples 2 \
    --rnd_number_digits 7 \
    --haystack_dir PaulGrahamEssays 

The above command takes around 6 hours. To reduce the runtime, consider increasing --context_interval and --depth_interval.
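For intuition, the evaluation inserts a "needle" containing a random number at a chosen depth into the haystack text and asks the model to retrieve it. The helper below is only an illustrative sketch of that setup (names and prompt wording are assumptions); eval_needle.py is the authoritative implementation.

import random

def build_niah_prompt(haystack_text: str, context_chars: int, depth: float,
                      rnd_number_digits: int = 7):
    """Illustrative needle-in-a-haystack prompt builder (not the repo's exact code)."""
    needle_number = "".join(random.choice("0123456789") for _ in range(rnd_number_digits))
    needle = f"The special magic number is {needle_number}. "
    haystack = haystack_text[:context_chars]    # eval_needle.py truncates in tokens, not characters
    insert_at = int(len(haystack) * depth)      # depth in [0, 1]: 0 = start, 1 = end of the context
    prompt = haystack[:insert_at] + needle + haystack[insert_at:]
    prompt += "\nWhat is the special magic number mentioned in the text? "  # trailing space for Llama-3
    return prompt, needle_number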

Perplexity

There are only two documents in the proofpile test set longer than 500K tokens.

accelerate launch --config_file  accelerate_configs/deepspeed_inference.yaml --num_processes 8 --main_process_port 6000 eval_ppl.py \
    --tokenized emozilla/proofpile-test-tokenized  \
    --dataset-min-tokens 500000 \
    --samples 2 \
    --output-file data/debug.csv \
    --min-tokens 50000 \
    --max-tokens 500000 \
    --tokens-step 50000 \
    --truncate \
    --aggressive-memory \
    -m PY007/EasyContext-1M-Llama-2-7B
python plot.py data/debug.csv --xmax 550000 --ymax 2 --ymin 1.5
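Conceptually, eval_ppl.py sweeps over prefix lengths of those long documents and reports perplexity at each length. A bare-bones, single-GPU version of that measurement (ignoring sequence parallelism, batching, and the script's exact flags) looks roughly like the sketch below; very long prefixes only fit in memory because the real script shards the sequence across GPUs.

import math
import torch
import torch.nn.functional as F

@torch.inference_mode()
def perplexity_at_length(model, input_ids: torch.Tensor, num_tokens: int) -> float:
    """Hedged sketch: next-token perplexity over the first num_tokens tokens of one document."""
    ids = input_ids[:, :num_tokens].to(model.device)
    logits = model(ids).logits[:, :-1, :]
    targets = ids[:, 1:]
    nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return math.exp(nll.item())

# Mirror --min-tokens / --max-tokens / --tokens-step:
# for n in range(50_000, 500_001, 50_000):
#     print(n, perplexity_at_length(model, tokenized_doc, n))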

Training

See train_scripts/

Speed

Switching from data parallel to ring attention results in a minor but not significant drop in throughput. However, throughput drops significantly as we increase the sequence length, due to the quadratic complexity of self-attention. I do not think this is due to increased communication cost in ring attention, as the volatile GPU utilization is almost always 100%. Throughput is measured on 8 A100s with Llama-7B over the first 5 training steps, so expect some variance.

Setup                 Throughput on 8 A100s
64K, data parallel    10240 tokens/s
64K, ring attention    7816 tokens/s
128K, ring attention   4266 tokens/s
512K, ring attention   2133 tokens/s
700K, ring attention   1603 tokens/s

I still remember the discussions two years ago about whether sparse attention is relevant; one big counterargument was that the quadratic complexity of self-attention does not dominate in practice. I think it is time to revisit this in the long-context era.
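A back-of-the-envelope FLOP count supports this. Per layer, the attention score/value matmuls cost on the order of 4·s²·d FLOPs, while the linear projections and MLP cost roughly 8·s·d² + 6·s·d·d_ff, so attention's share grows with the sequence length s. The numbers below are rough estimates for Llama-2-7B shapes, not measurements.

# Rough, hedged per-layer FLOP estimate for Llama-2-7B-like shapes (forward pass only).
d = 4096       # hidden size
d_ff = 11008   # MLP intermediate size

def attention_share(seq_len: int) -> float:
    attn_flops = 4 * seq_len**2 * d                             # QK^T and attn @ V
    linear_flops = 8 * seq_len * d**2 + 6 * seq_len * d * d_ff  # QKVO projections + gated MLP
    return attn_flops / (attn_flops + linear_flops)

for s in (64_000, 128_000, 512_000, 700_000):
    print(f"{s // 1000}K tokens: attention is ~{attention_share(s):.0%} of layer FLOPs")

Even at 64K tokens, attention already accounts for the majority of per-layer compute in this estimate, and by 700K it is nearly all of it, which is consistent with the throughput trend above.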

TODOs

  • Switch to a monkey patch implementation.
  • Add dist flash attn.
  • Set up a pip package.
  • EasyContext-Llama-2-13B-1M, if I have spare compute.
  • Instruction tuning.
  • EasyContext-Mistral-7B-1M, if I have spare compute.
  • Add PoSE.

We do not have a clear timeline for the TODOs. Community contributions & collaborations are more than welcome. Please feel free to open an issue or a pull request.

Some Random Thoughts

Until now, handling long sequences in video generation models has been considered a big challenge. I believe the fact that 8 A100s can hold a 700K-token context for a 7B transformer during training is not just cool for language models; it is huge for video generation too. A 700K context length means we could finetune on or generate roughly 1,400 frames, assuming each frame contains 512 tokens (see the quick arithmetic below). So if one day Meta or someone else open-sources such a model, at least we can finetune it. Another nice thing about encoder-only transformers is that we do not need to store a KV cache, which is a huge memory saver.
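The frame arithmetic, spelled out (assuming 512 tokens per frame as above):

tokens_per_frame = 512
context_length = 700_000
print(context_length // tokens_per_frame)  # 1367, i.e. roughly 1,400 frames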

Acknowledgements

This work is built on top of the following papers/repositories:

Citation

If you find this work useful, please consider citing:

@article{zhang2024longva,
  title={Long Context Transfer from Language to Vision},
  author={Peiyuan Zhang and Kaichen Zhang and Bo Li and Guangtao Zeng and Jingkang Yang and Yuanhan Zhang and Ziyue Wang and Haoran Tan and Chunyuan Li and Ziwei Liu},
  journal={arXiv preprint arXiv:2406.16852},
  year={2024},
  url = {https://arxiv.org/abs/2406.16852}
}

and also consider giving a star >_<


easycontext's Issues

about seq parallel global batch size

Hello, thank you for your good work.
I use the following bash script:

--batch-size 1 \
--gradient-accumulate-every 48  \

and this single_node.yaml

num_machines: 1
num_processes: 2

I want to know whether the effective global batch size is 48 or 96 with sequence parallelism in your dist_flash_attn.

Confused by the train scripts

In train_scripts/EasyContext-1M-Llama-2-7B.sh, line 53 specifies --model PY007/Llama2-7B-64K. Why isn't it --model ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5, which is the output model of the previous training process?

Also, would you upload training scripts for llama-2 13B in the future? I really appreciate this work and I am looking forward to it. Thanks!

Can it train CodeLlama?

Thank you very much for your code.

I ran train.py with the CodeLlama-34B base model.

The training went well, and I confirmed that a 76 GB checkpoint was generated, the same size as CodeLlama-34B.
Afterwards, when trying to load the generated model through LlamaForCausalLM, the following error occurred:

ValueError: Trying to set a tensor of shape torch.Size([0]) in "weight" (which has shape torch.Size([32000, 8192])), this looks incorrect.

Is there anything I missed or need to fix?

attention_mask

Hello, is it possible to add attention_mask to prepare_seq_parallel_inputs? I noticed that there is an assertion in monkey_patch.py that restricts attention_mask to None.

Dataset length question

Hello,
I am testing training with your code. Thank you as always.

Currently, I have created a dataset with a 1:1 ratio of 8K and 64K samples.

Afterwards, I ran training with the code, but

    q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (1024) must match the size of tensor b (8192) at non-singleton dimension 2
  0%|          | 0/301 [00:00<?, ?it/s]

the above error occurred.

My guess is that the 64K samples are fine, but the problem appears while training on the 8K samples.

Should all samples be the same length for training?

For samples shorter than seq-length, I am wondering whether I should pad them.

Thanks for your help.

--seq-length 65535 \

What technique is used to extend the context length?

Actually, this project's source code is simpler and clearer than YaRN's, but I did not see any YaRN-related code in it; intuitively, this project looks more like it uses linear interpolation. Does this project use YaRN or linear interpolation?

Looking forward to your reply, and thank you for clarifying.

Rotary embedding size mismatch

Hi authors,

Thank you for the great repo! I am testing with Llama3-72b and got this error:

[rank0]: File "/home/ubuntu/miniconda3/envs/llm2/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 447, in forward
[rank0]: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ubuntu/miniconda3/envs/llm2/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 206, in apply_rotary_pos_emb
[rank0]: q_embed = (q * cos) + (rotate_half(q) * sin)
[rank0]: ~~^~~~~
[rank0]: RuntimeError: The size of tensor a (3999) must match the size of tensor b (4000) at non-singleton dimension 2

Do you have any idea about that error?

Not the real auto-regressive decoding mode?

Dear author,

In the eval_forward function below, it seems this is not real autoregressive decoding. Since you concatenate the input and answer_ids together to form the new input_ids, decoding is performed in teacher-forcing mode, not real auto-regressive decoding.

Am I correct?


def eval_forward(accelerator, model, input_ids, pad_id, answer_ids):
    # first append labels to input_ids
    prompt_length = input_ids.shape[1]
    labels_length = answer_ids.shape[1]
    input_ids = torch.cat([input_ids, answer_ids], dim=1)
    # second pad input_ids to the multiple of accelerator.num_processes
    pad_tensor = torch.tensor(
        [pad_id]
        * (
            (accelerator.num_processes * 2)
            - input_ids.shape[1] % (accelerator.num_processes * 2)
        )
    ).unsqueeze(0)
    input_ids = torch.cat([input_ids, pad_tensor], dim=1)
    position_ids = (
        torch.arange(input_ids.shape[1]).unsqueeze(0).expand(input_ids.shape[0], -1)
    )
    prepared = prepare_seq_parallel_inputs(
        "zigzag_ring_attn",
        input_ids,
        position_ids,
        None,
        accelerator.process_index,
        accelerator.num_processes,
        accelerator.device,
    )
    local_input_ids = prepared["local_input_ids"]
    local_position_ids = prepared["local_position_ids"]
    with torch.inference_mode():
        logits = model(
            local_input_ids,
            position_ids=local_position_ids,
            use_cache=False,
        ).logits
        pred = logits.argmax(dim=-1)

Need a running script for ‘dist_flash_attn’

Can you provide a script to run dist_flash_attn? I tried setting parallel_mode to dist_flash_attn but it didn't work successfully.

When trying to use 'dist_flash_attn' with 2x A100, process 0 gets stuck in torch.cuda.synchronize() in _lightseq_forward of a certain decoder layer, while process 1 runs to the same step of the next decoder layer. Strangely, the model gets stuck on the second sample. What might be causing this bug? Is there any way to solve this problem?

It seems that communication of process 0 in maybe_send_recv_fwd_qkvo is not completed.

Model stopped updating after 300-400 steps.

I am training with the following script:

export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024'

accelerate launch \
--config_file  accelerate_configs/single_node.yaml \
train.py \
--batch-size 1 \
--gradient-accumulate-every 4 \
--output-dir ./output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \
--wandb EasyContext \
--max-train-steps 2500  \
--learning-rate 2e-5  \
--dataset yaofu/slimpajama-per-source-length-upsample \
--model meta-llama/Llama-2-7b-hf  \
--seq-length 32768 \
--rope-theta 500000 \
--parallel_mode data_parallel

The difference between the above command and train_scripts/EasyContext-1M-Llama-2-7B.sh is that I changed the --max-train-steps and --rope-theta. Additionally, I modified the if block in Line 161-165 in train.py to enable model saving every 100 steps (I set --save_interval=100):

if accelerator.sync_gradients:
    progress_bar.update(1)
    if loss_log is not None:
        progress_bar.set_postfix(loss_log)
    completed_steps += 1

    if completed_steps % args.save_interval == 0:
        ckpt_save_dir = f"{args.output_dir}/step{completed_steps}"
        os.makedirs(ckpt_save_dir, exist_ok=True)
        accelerator.wait_for_everyone()

        state_dict = accelerator.get_state_dict(model)

        accelerator.unwrap_model(model).save_pretrained(
            f"{ckpt_save_dir}",
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=state_dict,
        )

        accelerator.print(f"Saved model to {ckpt_save_dir}")

        accelerator.wait_for_everyone()

All saved models are evaluated on the latest version of lm_eval. I found that all models saved after step 400 (step 400 included) are identical. That is, when checked with cmp <model_at_step_400-0000X-of-0000X.safetensors> <model_at_step_T(T>400)-0000X-of-0000X.safetensors>, no errors are given. Besides, when evaluated on lm_eval, these models give identical results on all datasets tested (including MMLU, TQA, Hellaswag, Winogrande, etc.).

The models are all trained on 8 A800 GPUs (80G) and this issue can be reproduced on different model structures (YARN, which is LLaMA-2 with a different positional embedding). I wonder if you have any insights towards this issue. Thanks!

Which image is used for this job?

I want to ask which image is used for this job. I can't run train.sh after completing the Installation steps with pytorch:23.06.

error when finetuning yi-34b

Hi, thank you for this great project. I am finetuning Yi-34B, and when loading the model, a CUDA OOM error occurs. So I changed zero3_init_flag to true to avoid OOM when loading the model. But during training, there are some other errors, which I paste here. Could you please help me? Thank you!

/opt/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [257,0,0], thread: [31,0,0] Assertion srcIndex < srcSelectDimSize failed.
[rank6]:[E410 09:24:27.054138428 ProcessGroupNCCL.cpp:1430] [PG 0 Rank 6] Process group watchdog thread terminated with exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /opt/pytorch/c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits, std::allocator >) + 0xae (0x7fd1a42fb67e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&) + 0xf3 (0x7fd1a42a5375 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3f2 (0x7fd1a43b0612 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x5e (0x7fd182ac63de in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x78 (0x7fd182aca678 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x8ad (0x7fd182ad2fbd in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x128 (0x7fd182ad3c08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: + 0xdc253 (0x7fd1a3eb0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: + 0x94ac3 (0x7fd1a4e6bac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #9: clone + 0x44 (0x7fd1a4efca04 in /lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
what(): [PG 0 Rank 6] Process group watchdog thread terminated with exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /opt/pytorch/c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits, std::allocator >) + 0xae (0x7fd1a42fb67e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&) + 0xf3 (0x7fd1a42a5375 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3f2 (0x7fd1a43b0612 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x5e (0x7fd182ac63de in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x78 (0x7fd182aca678 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x8ad (0x7fd182ad2fbd in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x128 (0x7fd182ad3c08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: + 0xdc253 (0x7fd1a3eb0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: + 0x94ac3 (0x7fd1a4e6bac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #9: clone + 0x44 (0x7fd1a4efca04 in /lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at /opt/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1434 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits, std::allocator >) + 0xae (0x7fd1a42fb67e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: + 0xfded22 (0x7fd182afbd22 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: + 0xd342da (0x7fd1828512da in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: + 0xdc253 (0x7fd1a3eb0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #4: + 0x94ac3 (0x7fd1a4e6bac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #5: clone + 0x44 (0x7fd1a4efca04 in /lib/x86_64-linux-gnu/libc.so.6)

/opt/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [68,0,0], thread: [64,0,0] Assertion srcIndex < srcSelectDimSize failed.
(the same assertion failure is repeated for threads [65,0,0] through [84,0,0])

Llama-2 models do not support `sliding_window` parameter

In train.py, line 63 specifies sliding_window=None. However, if Llama-2 models are used, the initialization function does not support this parameter. I guess this was carried over from Mistral training scripts.

Simply getting rid of this parameter works fine for me.

Appending answer_ids to prompt in `eval_needle.py`

Hi,

In eval_needle.py, I see that the answer_ids are being appended to the input prompt.

input_ids = torch.cat([input_ids, answer_ids], dim=1)

Could you please help me understand why this was implemented this way?

Wouldn't that make the model generate output in teacher-forcing mode instead of doing autoregressive decoding?

OOM when seq-length=700k

Hi, author. When I set seq-length=700k, an OOM occurred. My torch version is 2.4.0.dev20240324. Do I need to set gradient-accumulate-every to 1?

Max train steps: 90
  0%|          | 0/90 [00:00<?, ?it/s]
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/utils/checkpoint.py:144: UserWarning: Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. Device state will only be saved for devices of a single device type, and the remaining devices will be ignored. Consequently, if any checkpointed functions involve randomness, this may result in incorrect gradients. (Note that if CUDA devices are among the devices detected, it will be prioritized; otherwise, the first device encountered will be selected.)
  warnings.warn(
(the two warnings above are repeated by every rank)
[2024-04-08 12:44:38,013] [WARNING] [stage3.py:2069:step] 9 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  1%| 1/90 [16:59<25:12:38, 1019.75s/it, loss=7.13, ppl=1.25e+3]
(the stage3 allocator cache-flush warning above repeats after every step, interleaved with the progress bar; training reaches step 19/90 at roughly 1060 s/it, with the loss falling from 7.13 to about 4.9, before crashing)
[rank1]: Traceback (most recent call last):
[rank1]:   File "/data/jkl/proj/EasyContext/train.py", line 219, in <module>
[rank1]:     main(args.parse_args())
[rank1]:   File "/data/jkl/proj/EasyContext/train.py", line 138, in main
[rank1]:     accelerator.backward(loss)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/accelerator.py", line 1995, in backward
[rank1]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 166, in backward
[rank1]:     self.engine.backward(loss, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
[rank1]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2213, in backward
[rank1]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank1]:     scaled_loss.backward(retain_graph=retain_graph)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 304, in backward
[rank1]:     outputs = ctx.run_function(*detached_inputs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1577, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/data/jkl/proj/EasyContext/easy_context/zigzag_ring_attn/monkey_patch.py", line 84, in new_decoder_forward
[rank1]:     hidden_states = self.mlp(hidden_states)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1577, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 240, in forward
[rank1]:     down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1577, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
[rank1]:     return F.linear(input, self.weight, self.bias)
[rank1]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.79 GiB. GPU  has a total capacity of 79.15 GiB of which 801.25 MiB is free. Including non-PyTorch memory, this process has 78.33 GiB memory in use. Of the allocated memory 45.74 GiB is allocated by PyTorch, and 31.66 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
W0408 18:07:00.108000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735592 closing signal SIGTERM
W0408 18:07:00.109000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735594 closing signal SIGTERM
W0408 18:07:00.109000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735595 closing signal SIGTERM
W0408 18:07:00.109000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735596 closing signal SIGTERM
W0408 18:07:00.110000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735597 closing signal SIGTERM
W0408 18:07:00.110000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735598 closing signal SIGTERM
W0408 18:07:00.110000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735599 closing signal SIGTERM
E0408 18:07:21.018000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 1 (pid: 735593) of binary: /data/jkl/miniconda3/envs/easycontext/bin/python
Traceback (most recent call last):
  File "/data/jkl/miniconda3/envs/easycontext/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1042, in launch_command
    deepspeed_launcher(args)
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/commands/launch.py", line 754, in deepspeed_launcher
    distrib_run.run(args)
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-04-08_18:07:00
  host      : ubuntu
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 735593)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Error when the model vocabulary is larger than 120K

Dear author,

When using EasyContext to evaluate a model with a vocabulary of more than 120K tokens, I ran into a strange problem:

The prediction is correct when a single GPU is used, but wrong when 8 GPUs are used: the 'pred' values produced by the code below are all zeros, which is quite strange.

Is there any limit on the vocabulary size in the ring-attention implementation, and what could be the cause?

For reference, here is what I have tried:

  1. I tried the model you provided; it is correct in both single-GPU and multi-GPU mode.
  2. To rule out a tokenizer issue, I pass input_ids directly into the eval_forward function, so the input_ids are identical for single-GPU and multi-GPU inference.

with torch.inference_mode():
    logits = self.model(
        local_input_ids,
        position_ids=local_position_ids,
        use_cache=False,
    ).logits
    pred = logits.argmax(dim=-1)

The value of the pred ids: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:6')

Requirements for input length

Thanks for your great work!
I noticed that the length of input_ids needs to be divisible by the world size, otherwise the forward pass gets stuck. What is the reason for this?
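A possible workaround, sketched under stated assumptions (this helper is not part of the repo; pad_token_id and the -100 target fill are assumptions): right-pad the tokenized sequence so its length is divisible by the world size before it is sharded.

import torch

def pad_to_multiple(input_ids, target_ids, world_size, pad_token_id=0):
    # Right-pad so that seq_len % world_size == 0. Targets are padded with -100 so
    # the extra positions are ignored by a cross-entropy loss with the default ignore_index.
    seq_len = input_ids.shape[-1]
    pad_len = (-seq_len) % world_size
    if pad_len == 0:
        return input_ids, target_ids
    pad_ids = torch.full((*input_ids.shape[:-1], pad_len), pad_token_id,
                         dtype=input_ids.dtype, device=input_ids.device)
    pad_tgt = torch.full((*target_ids.shape[:-1], pad_len), -100,
                         dtype=target_ids.dtype, device=target_ids.device)
    return torch.cat([input_ids, pad_ids], dim=-1), torch.cat([target_ids, pad_tgt], dim=-1)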

Train speed is too slow

I found that at a context length of 512K the training speed is much slower than in your reported results. It takes about 586 seconds to train one batch of 512K tokens, i.e. 512000 / 585.85 ≈ 873.94 tokens/s.
I used 8× A100-80G with NVLink.

accelerate launch \
--config_file accelerate_configs/single_node.yaml \
train.py \
--batch-size 1 \
--gradient-accumulate-every 2  \
--output-dir  ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \
--seed 2027 \
--max-train-steps 90  \
--learning-rate 1e-5  \
--dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \
--model meta-llama/Llama-2-7b-hf  \
--seq-length 512000 \
--rope-theta 250000000 \
--parallel_mode zigzag_ring_attn


Does the input sharding give exact optimization of the long sequence?

Thanks for your exciting work!

I found that the extract_local function seems to split an input sequence of length L into chunks of length L/world_size. Are the parameters optimized (in the backward pass) for each chunk rather than for the whole long sequence? Have you checked whether this introduces any approximation error, or is the optimization exactly equivalent to training on the full sequence?
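For intuition, a simplified sharding sketch (an illustration only, not the repo's exact extract_local, which may use a zigzag layout for load balancing): each rank keeps one slice of the tokens, but because ring attention circulates key/value blocks across all ranks, each rank's outputs still attend to the full sequence, so the backward pass optimizes the full-sequence loss rather than each chunk independently.

# Simplified illustration of per-rank sequence sharding (assumption: contiguous slices).
def shard_sequence(input_ids, rank, world_size):
    seq_len = input_ids.shape[-1]
    assert seq_len % world_size == 0  # the forward pass hangs otherwise (see the question above)
    chunk = seq_len // world_size
    return input_ids[..., rank * chunk : (rank + 1) * chunk]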

Does this repo work with FSDP or ZeRO?

Thanks for sharing this fantastic repo! I am wondering whether it works with FSDP or ZeRO? This does not seem to be mentioned in the documentation. Thank you!

LongBench/InfiniteBench

Hi,

Great work! I'm wondering whether any evaluations have been run on long-context benchmarks such as LongBench or InfiniteBench, which cover multiple task types?

How to generate autoregressively?

In eval_needle.py, the predictions are gathered with gather() and reassembled with undo_extract_local() to recover the full sequence of preds, and the predicted token is then selected using prompt_length. For autoregressive generation I only need the next token; can I get that predicted token without calling gather() and undo_extract_local()?
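One possible shortcut, sketched under the assumption of a contiguous per-rank layout (the zigzag layout used for ring attention places chunks differently, so the owning rank would have to be computed accordingly): only the rank that holds position prompt_length - 1 needs the logits, and a single-token broadcast replaces the full gather() and undo_extract_local(). The helper below is hypothetical, not part of the repo.

import torch
import torch.distributed as dist

def next_token(local_logits, prompt_length, seq_len, rank, world_size, device):
    # Assumes each rank holds a contiguous chunk of seq_len // world_size positions.
    chunk = seq_len // world_size
    owner = (prompt_length - 1) // chunk            # rank holding the last prompt position
    token = torch.zeros(1, dtype=torch.long, device=device)
    if rank == owner:
        local_pos = (prompt_length - 1) % chunk     # index inside this rank's chunk
        token[0] = local_logits[0, local_pos].argmax(dim=-1)
    dist.broadcast(token, src=owner)                # every rank now has the next token id
    return token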

dataset description

Great work! Would it be possible to add a description clarifying how the training datasets are generated? For example, the two datasets used in the scripts: PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K and PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M. Thanks!

Does it support SFT training?

I noticed that the code does not support passing an attention_mask, which seems to make padding SFT data impossible:

assert attention_mask is None

In addition, will the loss calculation in the code cause any issues for SFT data whose labels contain -100 values (for the prompt and padding parts)?

EasyContext/train.py

Lines 117 to 138 in fe49492

prepared = prepare_seq_parallel_inputs(
    args.parallel_mode,
    input_ids,
    position_ids,
    target_ids,
    accelerator.process_index,
    accelerator.num_processes,
    accelerator.device,
)
local_input_ids = prepared["local_input_ids"]
local_position_ids = prepared["local_position_ids"]
local_target_ids = prepared["local_target_ids"]
loss_log = None
with accelerator.accumulate(model):
    logits = model(
        local_input_ids,
        position_ids=local_position_ids,
    ).logits
    loss = loss_func(
        logits.reshape(-1, logits.shape[-1]), local_target_ids.reshape(-1)
    )

Looking forward to your response. Thank you.
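On the -100 question: if the loss is a standard cross-entropy, positions labeled -100 are simply skipped, since PyTorch's CrossEntropyLoss uses ignore_index=-100 by default; whether the repo's loss_func behaves the same way is an assumption worth verifying. A minimal check:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()                      # default ignore_index=-100
logits = torch.randn(6, 32000)                       # 6 positions, toy vocabulary size
labels = torch.tensor([-100, -100, 5, 7, -100, 9])   # prompt/padding positions masked with -100
loss = loss_fn(logits, labels)                       # averaged over the 3 unmasked positions only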
