
flash-linear-attention's Introduction

Flash Linear Attention


This repo aims to provide a collection of efficient Triton-based implementations of state-of-the-art linear attention models. Pull requests are welcome!


Models

| Date | Model | Title | Paper | Code | FLA impl |
| --- | --- | --- | --- | --- | --- |
| 2023-07 | RetNet (@MSRA@THU) | Retentive Network: A Successor to Transformer for Large Language Models | [arxiv] | [official] [RetNet] | code |
| 2023-12 | GLA (@MIT@IBM) | Gated Linear Attention Transformers with Hardware-Efficient Training | [arxiv] | [official] | code |
| 2023-12 | Based (@Stanford@Hazyresearch) | An Educational and Effective Sequence Mixer | [blog] | [official] | code |
| 2024-01 | Rebased | Linear Transformers with Learnable Kernel Functions are Better In-Context Models | [arxiv] | [official] | code |
| 2021-02 | Delta Net | Linear Transformers Are Secretly Fast Weight Programmers | [arxiv] | [official] | code |
| 2023-09 | Hedgehog (@HazyResearch) | The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry | [openreview] | | code |
| 2023-10 | PolySketchFormer (@CMU@Google) | Fast Transformers via Sketching Polynomial Kernels | [arxiv] | | TODO |
| 2023-07 | TransnormerLLM (@Shanghai AI Lab) | A Faster and Better Large Language Model with Improved TransNormer | [openreview] [arxiv] | [official] [Lightning2] | TODO |
| 2023-05 | RWKV-v4 (@BlinkDL) | Reinventing RNNs for the Transformer Era | [arxiv] | [official] | TODO |
| 2023-10 | GateLoop | Fully Data-Controlled Linear Recurrence for Sequence Modeling | [openreview] [arxiv] | [official] [jax] | TODO |
| 2021-10 | ABC (@UW) | Attention with Bounded-memory Control | [arxiv] | | code |
| 2023-09 | VQ-transformer | Linear-Time Transformers via Vector Quantization | [arxiv] | [official] | TODO |
| 2023-09 | HGRN | Hierarchically Gated Recurrent Neural Network for Sequence Modeling | [openreview] | [official] | code |
| 2024-04 | HGRN2 | HGRN2: Gated Linear RNNs with State Expansion | [arxiv] | [official] | code |
| 2024-04 | RWKV6 | Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence | [arxiv] | [official] | code |
| 2024-06 | Samba | Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling | [arxiv] | [official] | code |
| 2024-05 | Mamba2 | Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality | [arxiv] | [official] | code |

Installation

Make sure the package requirements are satisfied before installation.

As fla is actively developed, no released packages are provided at this time. If you need to use the fla ops/modules and plan further exploration, you can install the package from source:

pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

or manage fla via git submodules:

git submodule add https://github.com/sustcsonglin/flash-linear-attention.git 3rdparty/flash-linear-attention
ln -s 3rdparty/flash-linear-attention/fla fla

Caution

If you're not working with Triton v2.2 or its nightly release, be aware of potential issues with the FusedChunk implementation, detailed in this issue. You can run the test python tests/test_fused_chunk.py to check whether your version is affected by similar compiler problems. While we offer some fixes for Triton <= 2.1, be aware that they may reduce performance.

For both Triton 2.2 and earlier versions (up to 2.1), you can reliably use the Chunk version (with hidden states materialized in HBM). After careful optimization, this version generally delivers high performance in most scenarios.
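As a rough guard (not part of the repo), you can check the installed Triton version up front and fall back accordingly; the snippet below assumes the packaging library is available and that attn_mode accepts a 'chunk' value, as suggested by the configs and benchmarks later on this page.

# Minimal version-check sketch (an assumption, not repository code).
import triton
from packaging import version

if version.parse(triton.__version__) < version.parse("2.2.0"):
    print(f"Triton {triton.__version__} < 2.2: the FusedChunk kernels may hit the "
          "compiler issue above; consider the Chunk version (e.g. attn_mode='chunk') "
          "or run python tests/test_fused_chunk.py to check.")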

Usage

Token Mixing

We provide "token mixing" linear attention layers in fla.layers for you to use. You can replace the standard multihead attention layer in your model with other linear attention layers. Example usage is as follows:

>>> import torch
>>> from fla.layers import MultiScaleRetention
>>> batch_size, num_heads, seq_len, hidden_size = 32, 4, 2048, 1024
>>> device, dtype = 'cuda:0', torch.bfloat16
>>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
>>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype)
>>> y, *_ = retnet(x)
>>> y.shape
torch.Size([32, 2048, 1024])
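
The other layers in fla.layers follow the same drop-in pattern. A minimal sketch with GatedLinearAttention, assuming it takes the same hidden_size/num_heads keyword arguments (consistent with its usage elsewhere on this page):

>>> from fla.layers import GatedLinearAttention
>>> gla = GatedLinearAttention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
>>> y, *_ = gla(x)  # same (batch, seq_len, hidden_size) in/out shape as above
>>> y.shape
torch.Size([32, 2048, 1024])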

We also provide implementations of models compatible with the 🤗 Transformers library. Here's an example of how to initialize a GLA model from the default configs in fla:

>>> from fla.models import GLAConfig
>>> from transformers import AutoModel
>>> config = GLAConfig()
>>> config
GLAConfig {
  "attn_mode": "fused_chunk",
  "bos_token_id": 1,
  "clamp_min": null,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "max_position_embeddings": 2048,
  "model_type": "gla",
  "num_heads": 4,
  "num_hidden_layers": 24,
  "rms_norm_eps": 1e-06,
  "share_conv_kernel": true,
  "tie_word_embeddings": false,
  "transformers_version": "4.39.1",
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "use_short_conv": false,
  "vocab_size": 32000
}

>>> AutoModel.from_config(config)
GLAModel(
  (embed_tokens): Embedding(32000, 2048)
  (layers): ModuleList(
    (0-23): 24 x GLABlock(
      (attn_norm): RMSNorm()
      (attn): GatedLinearAttention(
        (gate_fn): SiLU()
        (q_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (g_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (gk_proj): Sequential(
          (0): Linear(in_features=2048, out_features=16, bias=False)
          (1): Linear(in_features=16, out_features=1024, bias=True)
        )
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (g_norm_swish_gate): FusedRMSNormSwishGate()
      )
      (mlp_norm): RMSNorm()
      (mlp): GLAMLP(
        (gate_proj): Linear(in_features=2048, out_features=11264, bias=False)
        (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
        (act_fn): SiLU()
      )
    )
  )
  (norm): RMSNorm()
)
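
The config fields above can be overridden to build smaller models for quick experiments. A minimal sketch, assuming GLAConfig accepts these fields as keyword arguments and that the causal LM variant is registered with AutoModelForCausalLM (which the generation example below relies on):

>>> from fla.models import GLAConfig
>>> from transformers import AutoModelForCausalLM
>>> small_config = GLAConfig(hidden_size=512, num_hidden_layers=4, num_heads=4, vocab_size=32000)
>>> small_model = AutoModelForCausalLM.from_config(small_config)
>>> sum(p.numel() for p in small_model.parameters())  # parameter count depends on the chosen config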

Generation

Once a model has been pretrained, you can generate text with the 🤗 text generation APIs. Here is a generation example:

>>> import fla
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> name = 'fla-hub/gla-1.3B-100B'
>>> tokenizer = AutoTokenizer.from_pretrained(name)
>>> model = AutoModelForCausalLM.from_pretrained(name).cuda()
>>> input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration."
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
>>> outputs = model.generate(input_ids, max_length=64)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

We also provide a simple script here for benchmarking generation speed. Run it with:

$ python -m benchmarks.benchmark_generation \
  --path 'fla-hub/gla-1.3B-100B' \
  --repetition_penalty 2. \
  --prompt="Hello everyone, I'm Songlin Yang"

Prompt:
Hello everyone, I'm Songlin Yang
Generated:
Hello everyone, I'm Songlin Yang.
I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have

Prompt length: 10, generation length: 64
Total prompt processing + decoding time: 4593ms

All of the pretrained models currently available can be found in fla-hub.

>>> from huggingface_hub import list_models
>>> for model in list_models(author='fla-hub'): print(model.id)

Evaluations

The lm-evaluation-harness library allows you to easily perform (zero-shot) model evaluations. Follow the steps below to use this library:

  1. Install lm_eval following their instructions.

  2. Run evaluation with:

$ MODEL='fla-hub/gla-1.3B-100B'
$ python -m evals.harness --model hf \
    --model_args pretrained=$MODEL,dtype=bfloat16 \
    --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \
    --batch_size 64 \
    --num_fewshot 0 \
    --device cuda \
    --show_config

We've made fla compatible with hf-style evaluations; you can call evals.harness to run them. The command above reproduces the task results reported in the GLA paper.

Tip

If you are using lm-evaluation-harness as an external library and can't find (almost) any tasks available, simply run the following before calling lm_eval.evaluate() or lm_eval.simple_evaluate() to load the library's stock tasks:

>>> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()
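
Putting the two steps together, here is a programmatic sketch; the exact lm_eval API differs between versions, so treat the call below as an assumption rather than the library's fixed interface:

>>> import fla  # registers fla models with 🤗 transformers
>>> import lm_eval
>>> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()
>>> results = lm_eval.simple_evaluate(
...     model='hf',
...     model_args='pretrained=fla-hub/gla-1.3B-100B,dtype=bfloat16',
...     tasks=['piqa', 'hellaswag'],
...     num_fewshot=0,
... )
>>> results['results']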

Benchmarks

We compared our Triton-based RetNet implementation with the CUDA-based FlashAttention2, using a batch size of 8, 32 heads, and a head dimension of 128, across different sequence lengths. The tests were conducted on a single A100 80GB GPU; the results are shown in the table below.

# you might have to first install `fla` to enable its import via `pip install -e .`
$ python benchmark_retention.py
Performance:
   seq_len  fused_chunk_fwd  chunk_fwd  parallel_fwd  fused_chunk_fwdbwd  chunk_fwdbwd  parallel_fwdbwd  flash_fwd  flash_fwdbwd
0    128.0         0.093184   0.185344      0.067584            1.009664      1.591296         1.044480   0.041984      0.282624
1    256.0         0.165888   0.219136      0.126976            1.024000      1.596928         1.073152   0.074752      0.413696
2    512.0         0.308224   0.397312      0.265216            1.550336      1.603584         1.301504   0.156672      0.883712
3   1024.0         0.603136   0.747520      0.706560            3.044864      3.089408         3.529728   0.467968      2.342912
4   2048.0         1.191424   1.403904      2.141184            6.010880      6.059008        11.009024   1.612800      7.135232
5   4096.0         2.377728   2.755072      7.392256           11.932672     11.938816        37.792770   5.997568     24.435200
6   8192.0         4.750336   5.491712     26.402817           23.759359     23.952385       141.014023  22.682114     90.619904
7  16384.0         9.591296  10.870784    101.262337           47.666176     48.745472       539.853821  91.346947    346.318848
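
For a quick, self-contained timing on your own hardware, here is a rough layer-level sketch using the public MultiScaleRetention module from the usage section and triton.testing.do_bench; it is not the script that produced the table above, and the shapes differ (batch 8, 4 heads), so absolute numbers will not match.

# Rough layer-level timing sketch (an assumption, not benchmark_retention.py).
import torch
import torch.nn.functional as F
from triton.testing import do_bench
from fla.layers import MultiScaleRetention

batch_size, num_heads, seq_len, hidden_size = 8, 4, 2048, 1024
device, dtype = 'cuda', torch.bfloat16

retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
# reference softmax attention on equivalently shaped heads
q = k = v = torch.randn(batch_size, num_heads, seq_len, hidden_size // num_heads, device=device, dtype=dtype)

print('retention fwd (ms):', do_bench(lambda: retnet(x)))
print('sdpa fwd (ms):     ', do_bench(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True)))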

Performance

Different forms of linear attention

Please refer to Section 2.3 of the GLA paper for hardware considerations regarding the different forms of linear attention; a small mode-switching sketch follows the list below.

  • Parallel: Self-attention-styled computation in $O(L^2)$ time with sequence parallelism.
  • FusedRecurrent: Recurrent computation in $O(L)$ time. Hidden states are computed on the fly in shared memory without any materialization to global memory (see Algorithm 1 of this paper for more details!). This saves a lot of I/O cost and should be a strong baseline for speed comparison.
  • FusedChunk: Chunkwise computation in $O(LC)$ time, where $C$ is the chunk size. Hidden states are computed on the fly without any materialization to global memory, like FusedRecurrent. This version is usually better than FusedRecurrent because tensor cores can be used for the sequence-level "reduction", whereas FusedRecurrent cannot use tensor cores at all. Note that there is no sequence-level parallelism in this implementation, so it is not suitable for very small batch sizes. It should be more memory efficient than ParallelChunk.
  • ParallelChunk: Chunkwise computation with sequence parallelism. Hidden states need to be materialized to global memory for each chunk. $C$ needs to be set properly to achieve good performance: when $C$ is small there are too many hidden states to load/store from global memory, and when $C$ is too large the FLOPs are high. Recommended values of $C$ are 64, 128, or 256.
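
A minimal sketch for switching between these kernels at the layer level; the mode names and the mode keyword follow values that appear elsewhere on this page, and it is an assumption that every layer exposes all of them:

>>> import torch
>>> from fla.layers import GatedLinearAttention
>>> x = torch.randn(2, 2048, 1024, device='cuda', dtype=torch.bfloat16)
>>> for mode in ('fused_recurrent', 'fused_chunk', 'chunk'):
...     layer = GatedLinearAttention(hidden_size=1024, num_heads=4, mode=mode).to(device='cuda', dtype=torch.bfloat16)
...     y, *_ = layer(x)
...     print(mode, tuple(y.shape))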

Citation

If you find this repo useful, please consider citing our works:

@article{yang2024delta,
  title   = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length}, 
  author  = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim},
  journal = {arXiv preprint arXiv:2406.06484},
  year    = {2024},
}

@article{yang2023gated,
  title   = {Gated Linear Attention Transformers with Hardware-Efficient Training},
  author  = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon},
  journal = {arXiv preprint arXiv:2312.06635},
  year    = {2023}
}

@software{yang2024fla,
  title  = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
  author = {Yang, Songlin and Zhang, Yu},
  url    = {https://github.com/sustcsonglin/flash-linear-attention},
  month  = jan,
  year   = {2024}
}

flash-linear-attention's People

Contributors

chaoscodes, danfosing, donglixp, doraemonzzz, eltociear, hypnopump, learning-chip, mirceamironenco, ridgerchu, sustcsonglin, uniartisan, yzhangcs


flash-linear-attention's Issues

Transformer model not learning after adding a classification head

I added a classification head to the pretrained Transformer++ model from https://huggingface.co/fla-hub/transformer-1.3B-100B/tree/main and finetuned it on the SST-2 dataset. However, the validation loss has remained constant since the beginning. Below is the sequence classification wrapper I defined; a similar architecture works for the GLA model. Could you help me check whether anything is wrong with my code or elsewhere?
class TransformerForSequenceClassification(TransformerPreTrainedModel):
    def __init__(self, model_name, num_labels, config):
        super().__init__(config)
        self.num_labels = num_labels
        self.model = AutoModelForCausalLM.from_pretrained(model_name).model
        self.config = config
        self.classifier = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        self.model.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                labels = labels.to(pooled_logits.device)
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )

bugs in BasedLinearAttention/LinearAttention/HGRN2Attention Implementation

I've noticed a few bugs in the implementations of BasedLinearAttention, LinearAttention, and HGRN2.

For BasedLinearAttention, self.hidden_size is not assigned.

For LinearAttention, the rearrange of k (and v) is repeated, and the position of feature_map_q (feature_map_k) seems strange. Also, at line 152 we should have something like o = o[0] if isinstance(o, (tuple, list)) else o, since the linear-attention forward returns a list.

For HGRN2, when all three parameters are provided, it raises an error even when expand_ratio * num_heads == hidden_size.

RWKV6 backward issue

Hi,
I grabbed the 3B version of the model from the Hugging Face hub, and when I call loss.backward() (after model.train()) using the transformers library, I get the following error from your library.

  File "/home/ostix/.virtualenvs/AI-architectures/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 61:21:
    mask_bk = i_k * BK + tl.arange(0, BK) < DK
    mask_bv = i_v * BV + tl.arange(0, BV) < DV
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
    _u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
    h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_init_s = initial_state + i_bh * DK * DV + \
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
                     ^
ValueError('Cannot broadcast, the expanded size of the tensor (64) must match the existing size (16) at non-singleton dimension 0: [16, 64], [64, 16]')

Thanks for considering this issue

Lack of speed advantage in GLA training

I have compared the speeds of GLA, Attention, and Flash Attention, as shown in the table below, and found that GLA has little to no advantage in terms of speed. What could be the reasons behind this result?

| seq_len | attention | flash attn 2 | fused_chunk_gla | chunk_gla | fused_recurrent_gla |
| --- | --- | --- | --- | --- | --- |
| 1163 | 0.000822s | 0.000177s | 0.00200s | 0.00138s | 0.000860s |
| 1172 | 0.000769s | 0.000192s | 0.00197s | 0.00138s | 0.000851s |
| 1346 | 0.000782s | 0.000185s | 0.00186s | 0.00143s | 0.000870s |
| 1366 | 0.000827s | 0.000154s | 0.00183s | 0.00144s | 0.000872s |

Environment:

NVIDIA GeForce RTX 3090
Driver Version: 525.89.02      
CUDA Version: 11.8
torch                    2.0.1
accelerate               0.21.0
transformers             4.31.0
triton                   2.2.0

RuntimeError: Triton Error [CUDA]: device-side assert triggered for fla.modules.layernorm.py

Hello,

I had an issue with the RMS LayerNorm as implemented in the FLA library. It is a RuntimeError: Triton Error [CUDA]: device-side assert triggered. I run on A100 GPUs and on latest versions of torch and triton (I've tried downgrading to minimal requirements). Here is the stacktrace, any idea how to solve this?

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/project/main.py", line 372, in <module>
[rank1]:     main()
[rank1]:   File "/home/project/main.py", line 356, in main
[rank1]:     loss = train_batch(opt, model, loss_func, features, labels)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/project/main.py", line 269, in train_batch
[rank1]:     output = model(features, labels=labels)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank1]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/fla/models/gla/modeling_gla.py", line 367, in forward
[rank1]:     outputs = self.model(
[rank1]:               ^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank1]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/fla/models/gla/modeling_gla.py", line 244, in forward
[rank1]:     hidden_states, attentions, past_key_values = layer(
[rank1]:                                                  ^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/fla/models/gla/modeling_gla.py", line 101, in forward
[rank1]:     hidden_states = self.attn_norm(hidden_states)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/fla/modules/layernorm.py", line 657, in forward
[rank1]:     return rms_norm_fn(
[rank1]:            ^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/fla/modules/layernorm.py", line 524, in rms_norm_fn
[rank1]:     return LayerNormFn.apply(
[rank1]:            ^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/fla/utils.py", line 11, in wrapper
[rank1]:     return fn(ctx,
[rank1]:            ^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/fla/modules/layernorm.py", line 413, in forward
[rank1]:     y, mean, rstd, residual_out = _layer_norm_fwd(
[rank1]:                                   ^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/fla/modules/layernorm.py", line 169, in _layer_norm_fwd
[rank1]:     _layer_norm_fwd_1pass_kernel[(M,)](
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
[rank1]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank1]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 156, in run
[rank1]:     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
[rank1]:     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[rank1]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 133, in _bench
[rank1]:     return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/testing.py", line 103, in do_bench
[rank1]:     fn()
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 114, in kernel_call
[rank1]:     self.fn.run(
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/runtime/jit.py", line 691, in run
[rank1]:     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
[rank1]:     ^^^^^^^^^^
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/compiler/compiler.py", line 381, in __getattribute__
[rank1]:     self._init_handles()
[rank1]:   File "/home/.local/lib/python3.11/site-packages/triton/compiler/compiler.py", line 376, in _init_handles
[rank1]:     self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
[rank1]:                                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: Triton Error [CUDA]: device-side assert triggered

inconsistent results when "masking" gating term between "fused_recurrent" and "fused_chunk" (fused_chunk presumably wrong)

Hi,

I tried to implement an option where I can pass a mask to a GLA layer that zeroes out certain hidden states, so that I can pack sequences and avoid data contamination between them. For that I set the corresponding gk to -infty so that gk.exp() = 0.
It boils down to these two lines:

if reset_mask is not None:
    gk = gk.masked_fill(reset_mask, -0.1*torch.finfo(gk.dtype).max)

where reset_mask is a boolean mask that indicates first token of packed sequences in each row.

Apparently it works when using the "fused_recurrent" mode. In this test I compare concat(gla(x1), gla(x2)) against gla(concat(x1, x2), reset_mask=reset_mask). It also works on real data (while the contaminated version fails).


device = "cuda"
mode = "fused_recurrent"
b, n, d = 1, 8, 64
gla = GatedLinearAttention(hidden_size=d, mode=mode).to(device)

x = torch.randn(b, n, d).to(device)
x1, x2 = x.chunk(2, dim=1)

reset_mask = (torch.arange(n, device=device)%(n//2)) == 0 #False everywhere except in first position of x1 and x2
reset_mask = rearrange(reset_mask, "n -> 1 n 1")

y = gla(x, reset_mask=reset_mask)
y1, y2 = gla(x1), gla(x2)

assert torch.allclose(torch.cat((y1, y2), dim=1), y)

However "fused_chunk" curiously fails the test, and training on real data is unstable.

mode = "fused_recurrent"
...
>>> assert torch.allclose(torch.cat((y1, y2), dim=1), y)
AssertionError

Is that a bug?

Thank you,
Théodor

AssertionError('All values in both first input shape ([constexpr[16], constexpr[8]]) and second input shape ([constexpr[8], constexpr[16]]) must be >= 16!')

I am trying to use GLABlock with batch size 1, but encounter this error. How can I resolve this?

My current config:

config = GLAConfig(
    hidden_size=channels,
    num_hidden_layers=n_layer,
    num_attention_heads=num_heads,
    num_heads=num_heads,
    attn_mode='fused_chunk',
    expand_k=0.5,
    expand_v=1.0,
    hidden_act="swish",
    bid_mode='layer',
    use_dirpe=False,
    rms_norm_eps=1e-6,
    if_norm_qkv=False,
    if_scale_qkv=False,
    fuse_norm=True,
)

Triton Error in flash-linear-attention/fla/modules/rmsnorm.py

PyTorch=='2.3.0a0+ebedce2'
Triton=='2.2.0'
(from out-of-box NVIDIA PyTorch Container 24.02 nvcr.io/nvidia/pytorch:24.02-py3)

[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/train.py", line 705, in <module>
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     logits, loss = model(X, Y)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/model_fla.py", line 26, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     ori_output = self.model(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 579, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     mamba_outputs = self.backbone(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 494, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     hidden_states = mixer_block(hidden_states, cache_params=cache_params)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 318, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     hidden_states = self.norm(hidden_states)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 504, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return rms_norm_fn(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 488, in rms_norm_fn
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return super().apply(*args, **kwargs)  # type: ignore[misc]
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 421, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     y, mean, rstd, residual_out = _layer_norm_fwd(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 155, in _layer_norm_fwd
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     MAX_FUSED_SIZE = 65536 // x.element_size()
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 162, in resume_in__layer_norm_fwd_at_155
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     _layer_norm_fwd_1pass_kernel[(M,)](
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in run
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 128, in resume_in_run_at_127
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     if len(self.configs) > 1:
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in resume_in_run_at_128
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 102, in do_bench
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     fn()
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 109, in kernel_call
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.pre_hook(args)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 110, in resume_in_kernel_call_at_109
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.fn.run(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 413, in run
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     grid = get_special_arg("grid")
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 414, in resume_in_run_at_413
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     num_warps = get_special_arg("num_warps")
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 415, in resume_in_run_at_414
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     num_ctas = get_special_arg("num_ctas", 1)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 416, in resume_in_run_at_415
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     num_stages = get_special_arg("num_stages")
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 417, in resume_in_run_at_416
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run_at_417
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     bound_args = self.signature.bind(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 427, in resume_in_run_at_426
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     bound_args.apply_defaults()
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_427
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     assert len(bound_args.arguments) == len(self.params)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_429
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     assert len(bound_args.arguments) == len(self.params)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 430, in resume_in_run_at_429
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 435, in resume_in_run_at_430
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] ==========
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] Traceback (most recent call last):
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 741, in _convert_frame
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     result = inner_convert(frame, cache_entry, hooks, frame_state)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 332, in _convert_frame_assert
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     unimplemented("generator")
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     raise Unsupported(msg)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] torch._dynamo.exc.Unsupported: generator
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* specialization_key             /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 170
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] WON'T CONVERT <genexpr> /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py line 436
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] ========== TorchDynamo Stack Trace ==========
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] Traceback (most recent call last):
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 741, in _convert_frame
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     result = inner_convert(frame, cache_entry, hooks, frame_state)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 332, in _convert_frame_assert
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     unimplemented("generator")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     raise Unsupported(msg)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] torch._dynamo.exc.Unsupported: generator
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] ========== The above exception occurred while processing the following code ==========
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/train.py", line 705, in <module>
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     logits, loss = model(X, Y)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/model_fla.py", line 26, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     ori_output = self.model(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 579, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     mamba_outputs = self.backbone(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 494, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     hidden_states = mixer_block(hidden_states, cache_params=cache_params)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 318, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     hidden_states = self.norm(hidden_states)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 504, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return rms_norm_fn(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 488, in rms_norm_fn
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return super().apply(*args, **kwargs)  # type: ignore[misc]
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 421, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     y, mean, rstd, residual_out = _layer_norm_fwd(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 155, in _layer_norm_fwd
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     MAX_FUSED_SIZE = 65536 // x.element_size()
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 162, in resume_in__layer_norm_fwd_at_155
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     _layer_norm_fwd_1pass_kernel[(M,)](
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in run
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 128, in resume_in_run_at_127
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     if len(self.configs) > 1:
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in resume_in_run_at_128
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 102, in do_bench
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     fn()
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 109, in kernel_call
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.pre_hook(args)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 110, in resume_in_kernel_call_at_109
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.fn.run(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 413, in run
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     grid = get_special_arg("grid")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 414, in resume_in_run_at_413
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     num_warps = get_special_arg("num_warps")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 415, in resume_in_run_at_414
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     num_ctas = get_special_arg("num_ctas", 1)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 416, in resume_in_run_at_415
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     num_stages = get_special_arg("num_stages")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 417, in resume_in_run_at_416
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run_at_417
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     bound_args = self.signature.bind(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 427, in resume_in_run_at_426
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     bound_args.apply_defaults()
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_427
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     assert len(bound_args.arguments) == len(self.params)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_429
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     assert len(bound_args.arguments) == len(self.params)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 430, in resume_in_run_at_429
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 436, in resume_in_run_at_430
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] ==========
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] Traceback (most recent call last):
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 741, in _convert_frame
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     result = inner_convert(frame, cache_entry, hooks, frame_state)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 332, in _convert_frame_assert
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     unimplemented("generator")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     raise Unsupported(msg)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] torch._dynamo.exc.Unsupported: generator
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* <listcomp>             /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 450
[2024-04-05 19:06:32,022] [33/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _device_of /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:231
[2024-04-05 19:06:32,022] [33/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:231 in _device_of (JITFunction)
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         @staticmethod
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:233 in _device_of (JITFunction._device_of)
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             try:
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE SETUP_FINALLY 12 []
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:234 in _device_of (JITFunction._device_of)
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]                 return arg.device.type
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST arg []
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR device [LazyVariableTracker()]
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_arg_ L['arg']
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['arg'] (16384, 1308) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], constraint_sizes=[None, None], tensor_source=LocalSource(local_name='arg', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={})
[2024-04-05 19:06:32,025] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR type [ConstantVariable(device)]
[2024-04-05 19:06:32,025] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE POP_BLOCK None [ConstantVariable(str)]
[2024-04-05 19:06:32,025] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RETURN_VALUE None [ConstantVariable(str)]
[2024-04-05 19:06:32,025] [33/0] torch._dynamo.convert_frame: [DEBUG] Skipping frame because no content in function call _device_of                     /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 231
[2024-04-05 19:06:32,025] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* <listcomp>             /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 451
[2024-04-05 19:06:32,025] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* <listcomp>             /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 453
[2024-04-05 19:06:32,025] [34/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _pinned_memory_of /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:238
[2024-04-05 19:06:32,025] [34/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:238 in _pinned_memory_of (JITFunction)
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         @staticmethod
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:240 in _pinned_memory_of (JITFunction._pinned_memory_of)
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             try:
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE SETUP_FINALLY 12 []
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:241 in _pinned_memory_of (JITFunction._pinned_memory_of)
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]                 return arg.is_pinned()
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST arg []
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR is_pinned [LazyVariableTracker()]
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_arg_ L['arg']
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['arg'] (16384, 1308) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], constraint_sizes=[None, None], tensor_source=LocalSource(local_name='arg', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={})
[2024-04-05 19:06:32,028] [34/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 0 [GetAttrVariable(TensorVariable(), is_pinned)]
Traceback (most recent call last):
  File "/raid/xind/TLM/train.py", line 705, in <module>
    logits, loss = model(X, Y)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/model_fla.py", line 26, in forward
    ori_output = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 579, in forward
    mamba_outputs = self.backbone(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 494, in forward
    hidden_states = mixer_block(hidden_states, cache_params=cache_params)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 318, in forward
    hidden_states = self.norm(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 504, in forward
    return rms_norm_fn(
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 488, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 421, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 155, in _layer_norm_fwd
    MAX_FUSED_SIZE = 65536 // x.element_size()
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 162, in resume_in__layer_norm_fwd_at_155
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in run
    self.nargs = dict(zip(self.arg_names, args))
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
    self.nargs = dict(zip(self.arg_names, args))
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
    self.nargs = dict(zip(self.arg_names, args))
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 128, in resume_in_run_at_127
    if len(self.configs) > 1:
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in resume_in_run_at_128
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 109, in kernel_call
    self.pre_hook(args)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 110, in resume_in_kernel_call_at_109
    self.fn.run(
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 413, in run
    grid = get_special_arg("grid")
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 414, in resume_in_run_at_413
    num_warps = get_special_arg("num_warps")
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 415, in resume_in_run_at_414
    num_ctas = get_special_arg("num_ctas", 1)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 416, in resume_in_run_at_415
    num_stages = get_special_arg("num_stages")
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 417, in resume_in_run_at_416
    enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run_at_417
    bound_args = self.signature.bind(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 427, in resume_in_run_at_426
    bound_args.apply_defaults()
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_427
    assert len(bound_args.arguments) == len(self.params)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_429
    assert len(bound_args.arguments) == len(self.params)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 430, in resume_in_run_at_429
    args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 453, in resume_in_run_at_430
    [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 453, in <listcomp>
    [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 580, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 741, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 384, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 643, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 246, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 524, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 489, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2110, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 780, in run
    and self.step()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 462, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1190, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 644, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py", line 645, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/tensor.py", line 770, in call_method
    return wrap_fx_proxy(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1302, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1387, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1590, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1545, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1086, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1546, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1657, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1638, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1480, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1711, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 442, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 827, in nyi
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"
torch._dynamo.exc.TorchRuntimeError: Failed running call_method is_pinned(*(FakeTensor(..., device='cuda:0', size=(16384, 1308), requires_grad=True),), **{}):
NYI: aten.is_pinned.default

from user code:
   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 241, in _pinned_memory_of
    return arg.is_pinned()


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] Function                                  Runtimes (s)
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] --------------------------------------  --------------
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner                 2.5749
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] OutputGraph.call_user_compiler                  1.282
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] create_aot_dispatcher_function                  1.3513
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] compile_fx.<locals>.fw_compiler_base            1.2468
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] GraphLowering.run                               0.0301
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] GraphLowering.compile_to_module                 1.0085
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] Scheduler.__init__                              0.0083
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] Scheduler.codegen                               0.1939
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] WrapperCodeGen.generate                         0.0011
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] CachingAutotuner.benchmark_all_configs          0.1649
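
As the log itself suggests, the quickest way past the aten.is_pinned NYI failure is to let Dynamo fall back to eager for the frames it cannot handle, or to keep the Triton-backed RMSNorm call out of the compiled region. A minimal sketch (the module patching in the second option is an assumption about a workable escape hatch, not an official fix):

import torch._dynamo
import fla.modules.rmsnorm as fla_rmsnorm

# Fall back to eager whenever Dynamo hits an unsupported op, exactly as the
# exception message above recommends.
torch._dynamo.config.suppress_errors = True

# Alternatively (assumption, not an official fla recommendation): keep the
# Triton-backed RMSNorm out of the compiled region so the rest of the model
# still compiles. rms_norm_fn comes from fla/modules/rmsnorm.py per the traceback.
fla_rmsnorm.rms_norm_fn = torch._dynamo.disable(fla_rmsnorm.rms_norm_fn)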

'RebasedFeatureMap' is missing?

Hi, thanks for all your efforts on such a good repo.

I'm trying to build a model based on RebasedLinearAttention, but I run into this error:

ImportError: cannot import name 'RebasedFeatureMap' from 'fla.modules.feature_map' (.*/fla/modules/feature_map.py)

Is this feature still under development, or should I use HazyResearch's implementation located at https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py?
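
For reference, the rebased kernel is essentially a learnable affine map followed by an element-wise square; a minimal stand-in written from the paper's description, making no assumptions about the eventual fla.modules.feature_map API, could look like this:

import torch
import torch.nn as nn

class NaiveRebasedFeatureMap(nn.Module):
    # Hypothetical stand-in, not the fla implementation: a per-dimension learnable
    # affine map followed by an element-wise square, which is the core idea of the
    # rebased kernel.
    def __init__(self, head_dim: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(head_dim))
        self.beta = nn.Parameter(torch.zeros(head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., head_dim] -> non-negative features of the same shape
        return (self.gamma * x + self.beta).pow(2)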

Chunk-wise linear attn kernel does not work with torch compile (returns incorrect values / NaNs)

I've been trying to use the linear attention kernels in a model that I am compiling, but the Triton kernel does not seem to work with torch compile. Specifically, when comparing the kernel's output to a reference implementation, the two match when torch compile is not used, but show huge discrepancies when the kernel is wrapped with torch compile.

Here is a script to reproduce the error:

import torch
import sys
import os
from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction
fused_chunk_linear_attn = FusedChunkLinearAttentionFunction.apply

def reference_implementation(q, k, v, scale, initial_state, output_final_state):
    # q,k,v: B, H, L, E
    q = q * scale
    attn_weights = torch.matmul(q, k.transpose(2, 3))
    causal_mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2])).to(q.device).bool()
    attn_weights.masked_fill_(~causal_mask, 0.0)
    attn_output = torch.matmul(attn_weights, v)

    if initial_state is not None:
        state_contribution = torch.matmul(q, initial_state)
        attn_output += state_contribution

    hidden_state = None
    if output_final_state:
        hidden_state = torch.einsum("bhle,bhlf->bhef", k, v)
        if initial_state is not None:
            hidden_state += initial_state
    return attn_output, hidden_state


# SE (02/26): borrowing these tolerances from Mamba's test_selective_state_update
# https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/tests/ops/triton/test_selective_state_update.py#L24C1-L26C32
DTYPE_TO_ATOL = {
    torch.float32: 1e-3,
    torch.float16: 1e-2,
    torch.bfloat16: 5e-2,
}
DTYPE_TO_RTOL = {
    torch.float32: 3e-4,
    torch.float16: 5e-4,
    torch.bfloat16: 1e-2,
}


def compare_two_outputs(output1, output2, atol=1e-2, rtol=0):
    def compute_stats(a, b, atol=1e-2, rtol=0):
        abs_diff = torch.abs(a - b)
        mean_abs_diff = torch.mean(abs_diff).item()
        max_abs_diff = torch.max(abs_diff).item()
        not_close_mask = ~torch.isclose(a, b, atol=atol, rtol=rtol)
        mean_abs_diff_not_close = (
            torch.mean(abs_diff[not_close_mask]).item() if not_close_mask.any() else 0
        )
        mean_rel_diff_not_close = (
            torch.mean(abs_diff[not_close_mask] / (torch.abs(b[not_close_mask]) + 1e-7)).item()
            if not_close_mask.any()
            else 0
        )
        return {
            "mean_abs_diff": mean_abs_diff,
            "max_abs_diff": max_abs_diff,
            "mean_abs_diff_not_close": mean_abs_diff_not_close,
            "mean_rel_diff_not_close": mean_rel_diff_not_close,
            "percent_not_close": 100 * not_close_mask.float().mean().item(),
        }

    # if not torch.allclose(output1, output2, atol=atol, rtol=rtol):
    if not torch.allclose(output1, output2):
        output_stats = compute_stats(output1, output2, atol=atol, rtol=rtol)
        print("Different outputs")
        print(f"Mean absolute difference: {output_stats['mean_abs_diff']:.6e}")
        print(f"Max absolute difference: {output_stats['max_abs_diff']:.6e}")
        print(
            f"Mean absolute difference (not close): {output_stats['mean_abs_diff_not_close']:.6e}"
        )
        print(
            f"Relative absolute difference (not close): {output_stats['mean_rel_diff_not_close']:.6e}"
        )
        print(f"Percent not close: {output_stats['percent_not_close']:.2f}%")
    else:
        print("Outputs match")
    print("-" * 80)  # Separator for readability
    return torch.allclose(output1, output2, atol=atol, rtol=rtol)


def test_torch_compile(dtype):
    # Set up test parameters
    B, H, L, E = 1, 8, 2048, 64
    scale = 0.125
    output_final_state = True

    # Create input tensors
    q = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    k = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    v = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    initial_state = (
        torch.empty((B, H, E, E), dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .contiguous()
    )

    # Define a function that uses the kernel
    def attention_fn_kernel(q, k, v, initial_state):  # -> Any:
        return fused_chunk_linear_attn(q, k, v, scale, initial_state, output_final_state)

    def attention_fn_pytorch(q, k, v, initial_state):  # -> Any:
        return reference_implementation(q, k, v, scale, initial_state, output_final_state)

    # Compile/warmup the kernels
    compiled_fn_kernel = torch.compile(attention_fn_kernel)
    compiled_fn_pytorch = torch.compile(attention_fn_pytorch)
    with torch.autocast(device_type="cuda", dtype=dtype):
        output1, _ = compiled_fn_kernel(q, k, v, initial_state)
        output2, _ = compiled_fn_pytorch(q, k, v, initial_state)

    # Run the forward passes
    with torch.autocast(device_type="cuda", dtype=dtype):
        output_kernel_comp, _ = compiled_fn_kernel(q, k, v, initial_state)
        output_pytorch_comp, _ = compiled_fn_pytorch(q, k, v, initial_state)
        output_kernel, _ = attention_fn_kernel(q, k, v, initial_state)
        output_pytorch, _ = attention_fn_pytorch(q, k, v, initial_state)

    # Check if any of the compiled outputs match the original
    atol = DTYPE_TO_ATOL[dtype]
    rtol = DTYPE_TO_RTOL[dtype]

    print("Compiled vs non compiled Pytorch")
    compare_two_outputs(output_pytorch_comp, output_pytorch, atol=atol, rtol=rtol)
    print("Triton (non compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Triton (compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel_comp.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Test compile completed")


if __name__ == "__main__":

    for dtype in [torch.float16, torch.bfloat16, torch.float32]:
        print(f"Testing torc compile for dtype {dtype}")
        test_torch_compile(dtype)
        print("=" * 80)

The output of the script (running on A100 40GB GPU):

Testing torch compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049560e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.805531e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 3.020972e+00
Max absolute difference: 2.664062e+01
Mean absolute difference (not close): 3.029570e+00
Relative absolute difference (not close): 9.999998e-01
Percent not close: 99.72%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.357576e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.615234e-02
Relative absolute difference (not close): 1.953556e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.531446e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================

Relevant libraries from environment:

Versions of relevant libraries:                                                                                                                    
[pip3] numpy==1.24.3
[pip3] torch==2.4.0.dev20240523                                                                                                                    
[pip3] torchaudio==2.2.0.dev20240523
[pip3] torchvision==0.19.0.dev20240523                                                                                                            
[pip3] triton==3.0.0                                                                                                                               
[pip3] triton-nightly==3.0.0.post20240522224832
[conda] blas                      1.0                         mkl    intel
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
[conda] mkl                       2021.4.0              intel_640    intel
[conda] mkl-service               2.4.0           py311h9bf148f_0    pytorch-nightly
[conda] mkl_fft                   1.3.1           py311hc796f24_0    pytorch-nightly
[conda] mkl_random                1.2.2           py311hbba84a0_0    pytorch-nightly
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.26.4                   pypi_0    pypi                                                                         
[conda] numpy-base                1.24.3          py311hfd5febd_0
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.4.0.dev20240523 py3.11_cuda12.4_cudnn8.9.2_0    pytorch-nightly
[conda] pytorch-cuda              12.4                 hc786d27_6    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchaudio                2.2.0.dev20240523     py311_cu124    pytorch-nightly
[conda] torchtriton               3.0.0+45fff310c8           py311    pytorch-nightly
[conda] torchvision               0.19.0.dev20240523     py311_cu124    pytorch-nightly
[conda] triton-nightly            3.0.0.post20240522224832          pypi_0    pypi
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly 
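
Until the torch.compile interaction is understood, one hedged workaround for the script above is to exclude the custom autograd Function from Dynamo tracing so only the surrounding PyTorch code is compiled; a sketch:

import torch
import torch._dynamo
from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction

# Keep the custom autograd Function (and hence the Triton launch) out of the
# compiled graph; only the surrounding PyTorch code gets traced by Dynamo.
# Using torch._dynamo.disable here is an assumption about a reasonable escape
# hatch, not an endorsed fla API.
fused_chunk_linear_attn = torch._dynamo.disable(FusedChunkLinearAttentionFunction.apply)

def attention_fn_kernel(q, k, v, initial_state, scale=0.125, output_final_state=True):
    # Same call as in the reproduction script above, but the kernel itself runs eagerly.
    return fused_chunk_linear_attn(q, k, v, scale, initial_state, output_final_state)

compiled_fn_kernel = torch.compile(attention_fn_kernel)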

Use Cache with GLA model raised error

Hi, I see that RecurrentCache was renamed to Cache for the GLA model. However, it now raises an error because Cache does not have the method from_legacy_cache.
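
A minimal compatibility shim while this is being fixed might look like the sketch below; the import path and the empty-Cache fallback are assumptions, not the documented API:

# Hedged shim: prefer from_legacy_cache when it exists, otherwise start from an
# empty Cache. The import path below is an assumption; adjust it to wherever the
# new Cache class actually lives.
from fla.models.utils import Cache

def to_cache(past_key_values):
    if past_key_values is None or isinstance(past_key_values, Cache):
        return past_key_values
    if hasattr(Cache, "from_legacy_cache"):
        return Cache.from_legacy_cache(past_key_values)
    return Cache()  # assumption: an empty Cache can be constructed with no arguments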

benchmark_training_throughput and bugs

Hi, thanks for your great work. I ran benchmarks to test every model's throughput and memory usage with flash-linear-attention/benchmarks/benchmark_training_throughput.py. Unfortunately, some of them failed.

Working:

  • GLA, GSA, HGRN, RetNet, Transformer, Mamba, Samba

Failed:

  • delta_net, HGRN2, Linear_attn, RWKV6

(For Mamba, I noticed that FLA does not ask for mamba_ssm and causal_conv1d, but it also did not raise any warning that it runs in the slow-forward mode.)

Here are the benchmark results; the error info for the failed runs is attached at the end.

Environment:

NVIDIA A100-SXM4-40GB
NVIDIA-SMI 550.54.15             
Driver Version: 550.54.15      
CUDA Version: 12.4 
torch                    2.3.1
accelerate               0.32.1
transformers             4.42.4
triton                   2.3.1
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.5.82
nvidia-nvtx-cu12         12.1.105
Model        Batch Size  Seq Len  Max Memory (GB)  Throughput (tokens/s)
GLA          8           512      14.77            14959.27
             8           1024     22.88            17467.75
             8           2048     OOM              -
GSA          8           512      16.07            14960.99
             8           1024     24.35            17674.00
             8           2048     OOM              -
HGRN         8           512      16.90            16382.00
             8           1024     26.15            19500.58
             8           2048     OOM              -
RetNet       8           512      15.13            13369.01
             8           1024     22.66            15437.14
             8           2048     37.75            16445.58
Transformer  8           512      13.98            17851.42
             8           1024     20.30            20994.52
             8           2048     32.96            21807.02
Mamba        8           512      15.40            10385.55
             8           1024     22.72            11230.94
             8           2048     37.36            12151.64
Samba        8           512      13.77            16475.11
             8           1024     19.56            18470.86
             8           2048     31.18            19850.30

Delta-Net

Traceback (most recent call last):
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 100, in <module>
    profile(args.name, args.batch_size, args.seq_len, args.warmup_steps, args.steps)
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 66, in profile
    outputs = model(tokens, labels=tokens)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/delta_net/modeling_delta_net.py", line 385, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/delta_net/modeling_delta_net.py", line 263, in forward
    hidden_states, attentions, past_key_values = layer(
                                                 ^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/delta_net/modeling_delta_net.py", line 117, in forward
    hidden_states, attentions, past_key_values = self.attn(
                                                 ^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/layers/delta_net.py", line 228, in forward
    o, recurrent_state = chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/delta_rule/chunk.py", line 543, in chunk_delta_rule
    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT,  initial_state, output_final_state)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/utils.py", line 11, in wrapper
    return fn(ctx,
           ^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/delta_rule/chunk.py", line 503, in forward
    h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)        
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/delta_rule/chunk.py", line 394, in chunk_fwd_h_fn
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
           ^^^^^^^^^
AssertionError: current kernel does not support head dimension larger than 256.
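
The assertion is about the per-head key dimension: the chunked delta-rule kernel caps it at 256, and the benchmark's default config apparently exceeds that. A small sanity check capturing the constraint (how the per-head dimension is derived from the config is an assumption here, not a quote of the fla code):

# Illustrative check of the constraint the kernel asserts; the derivation of the
# per-head key dimension from hidden_size, an expansion factor and num_heads is
# an assumption about typical configs.
def check_delta_net_head_dim(hidden_size: int, num_heads: int, expand_k: float = 1.0) -> int:
    head_k_dim = int(hidden_size * expand_k) // num_heads
    if head_k_dim > 256:
        raise ValueError(
            f"per-head key dim {head_k_dim} exceeds the kernel limit of 256; "
            "increase num_heads or reduce hidden_size/expand_k"
        )
    return head_k_dim

# Example: a 2048-dim model needs at least 8 heads to stay within the limit.
assert check_delta_net_head_dim(hidden_size=2048, num_heads=8) == 256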
----------------------------------------------

HGRN2

Traceback (most recent call last):
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1222, in ast_to_ttir
    generator.visit(fn.parse())
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 303, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 426, in generic_visit
    self.visit(item)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
               ^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 885, in visit_For
    self.visit_compound_statement(node.body)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
               ^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 448, in visit_AugAssign
    self.visit(assign)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 428, in visit_Assign
    values = self.visit(node.value)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 477, in visit_BinOp
    rhs = self.visit(node.right)
          ^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1027, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/language/core.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/language/core.py", line 1018, in dot
    return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1207, in dot
    assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1183, in assert_dtypes_valid
    assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
           ^^^^^^^^^^^^^^^^^^^^^^
AssertionError: First input (fp32) and second input (bf16) must have the same dtype!

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 100, in <module>
    profile(args.name, args.batch_size, args.seq_len, args.warmup_steps, args.steps)
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 66, in profile
    outputs = model(tokens, labels=tokens)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/hgrn2/modeling_hgrn2.py", line 372, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/hgrn2/modeling_hgrn2.py", line 248, in forward
    hidden_states, attentions, past_key_values = layer(
                                                 ^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/hgrn2/modeling_hgrn2.py", line 97, in forward
    hidden_states, attentions, past_key_values = self.attn(
                                                 ^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/layers/hgrn2.py", line 151, in forward
    o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/gla/chunk.py", line 733, in chunk_gla
    o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/utils.py", line 11, in wrapper
    return fn(ctx,
           ^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/gla/chunk.py", line 541, in forward
    h = fwd_inner(
        ^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/gla/chunk.py", line 514, in fwd_inner
    chunk_gla_fwd_kernel_h[grid](
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
                              ^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 191, in compile
    module = src.make_ir(options)
             ^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 117, in make_ir
    return ast_to_ttir(self.fn, self, options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 53:27:        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BK, BT]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        if i_t < NT - 1:
            # [BK,]
            b_gn = tl.load(p_gn, boundary_check=(0,))
        else:
            b_gn = tl.min(b_g, axis=1)
        b_h *= tl.exp(b_gn)[:, None]
        b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
        b_h += tl.dot(b_k, b_v, allow_tf32=False)
                           ^
AssertionError('First input (fp32) and second input (bf16) must have the same dtype!')

Linear_attention

Traceback (most recent call last):
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 100, in <module>
    profile(args.name, args.batch_size, args.seq_len, args.warmup_steps, args.steps)
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 66, in profile
    outputs = model(tokens, labels=tokens)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/linear_attn/modeling_linear_attn.py", line 389, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/linear_attn/modeling_linear_attn.py", line 268, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/linear_attn/modeling_linear_attn.py", line 105, in forward
    hidden_states = self.attn(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/layers/linear_attn.py", line 129, in forward
    q = self.feature_map_q(q)
        ^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/modules/feature_map.py", line 130, in forward
    return self.layer1(x) * self.layer2(x)
           ^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x2048 and 512x512)

Rwkv6

OOM for BS=8

 File "system/fla-bench/flash-linear-attention/fla/layers/rwkv6.py", line 272, in forward
    return self.linear(x + delta * mu)
                       ~~^~~~~~~~~~~~
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 

training efficiency of GLA

Hi, great work!! I am comparing GLA against standard attention (from PyTorch) in diffusion model training. GLA seems noticeably slower than attention, and even more so in the first epoch. I would love to know why this happens and whether there is anything we can do about it; a timing sketch of how this can be measured is included below.
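One plausible contributor to the slow first epoch (an assumption, not a confirmed diagnosis for this setup) is that Triton kernels are JIT-compiled and autotuned on their first call, so timings should be taken after a warm-up. Below is a minimal timing sketch; the GatedLinearAttention arguments, the (batch, seq_len, hidden) input layout and the return-value handling are assumptions that may differ across fla versions.

import torch
from fla.layers import GatedLinearAttention  # import path assumed

batch, seq_len, hidden = 8, 1024, 1024       # illustrative sizes only
gla = GatedLinearAttention(hidden_size=hidden, num_heads=4).cuda().to(torch.bfloat16)
x = torch.randn(batch, seq_len, hidden, device='cuda', dtype=torch.bfloat16)

for _ in range(3):                           # warm-up: the first call pays the Triton JIT/autotune cost
    out = gla(x)                             # depending on the version, `out` may be a tensor or a tuple
torch.cuda.synchronize()

start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
start.record()
out = gla(x)
end.record()
torch.cuda.synchronize()
print(f"steady-state forward: {start.elapsed_time(end):.2f} ms")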

Hello from HF Diffusers

Thanks for the incredibly clean repository!

I am Sayak from the Diffusers team at Hugging Face. My question is probably very naive, so I apologize for that in advance.

I wanted to know whether linear attention can be applied at inference time only. More precisely, can I take a model trained with regular attention and turn it into a linear attention model during inference?

Current FLA RWKV6 implementation has significant precision issues in pure bf16 mode

Below are the results from my experiments:

|    | CUDA bf16 (fp32 internal) | FLA fp32 | FLA bf16 |
|----|----|----|----|
| y  | 0.0016603531206355376 | 5.153093822969028e-07 | 0.0025760101921418134 |
| gr | 0.0017877683404764239 | 5.860136550906496e-07 | 0.0029575546041739134 |
| gk | 0.0017853925508536652 | 5.969486336398631e-07 | 0.002951189528185581 |
| gv | 0.0022316154634133964 | 5.833091583780125e-07 | 0.0031975613176225934 |
| gw | 0.0018482808625786967 | 2.3036314788307143e-05 | 0.08319189127088046 (!!!) |
| gu | 0.0018472627187992381 | 3.5015232226862115e-07 | 0.0017254238302962922 |

As shown, the FLA bf16 results have significantly larger errors, particularly for gw; a sketch of how such relative errors can be measured follows.
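For reference, numbers of this kind can be produced with a measurement along the following lines, assuming they are relative L2 errors of each output (and, after a backward pass, each gradient) against a higher-precision run; the import path, call signature and (B, H, L, K) layout below are assumptions borrowed from other snippets on this page.

import torch
from fla.ops.rwkv6 import fused_recurrent_rwkv6  # import path assumed

def err_ratio(x, ref):
    # relative L2 error of a low-precision result against a higher-precision reference
    x, ref = x.float(), ref.float()
    return ((x - ref).norm() / ref.norm()).item()

B, H, L, K = 2, 4, 512, 64  # illustrative sizes only
r, k, v = (torch.randn(B, H, L, K, device='cuda') for _ in range(3))
w = torch.randn(B, H, L, K, device='cuda').sigmoid()  # decay in (0, 1), as in the snippet further down
u = torch.randn(H, K, device='cuda')

y_fp32, _ = fused_recurrent_rwkv6(r, k, v, w, u)
y_bf16, _ = fused_recurrent_rwkv6(*(t.bfloat16() for t in (r, k, v, w, u)))
print("y:", err_ratio(y_bf16, y_fp32))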

Please look into this precision issue.

Thank you

Error loading pretrained checkpoints through `transformers` library

The following error is raised when I run `model = AutoModelForCausalLM.from_pretrained("fla-hub/gla-1.3B-200B")`:

ValueError: The checkpoint you are trying to load has model type gla but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.
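
For reference, a commonly suggested workaround (offered as an assumption, not a guaranteed fix for every Transformers version) is to import fla before loading, on the premise that the package registers its custom `gla` architecture with the Transformers Auto classes as an import side effect:

import fla  # noqa: F401  (assumed to register fla model types such as `gla` with transformers)
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("fla-hub/gla-1.3B-200B")
model = AutoModelForCausalLM.from_pretrained("fla-hub/gla-1.3B-200B")

Otherwise, upgrading transformers, as the error message itself suggests, is the other obvious thing to try.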

illegal memory access error

Hi Songlin,

Thanks for your great work! I ran some comparative experiments with FLA recently and it shows great performance! But I hit an error when I increased the dimension to 1152 and set the number of heads to 16. The details of the error are shown below:

n136-180-028:1707397:1709473 [6] include/alloc.h:124 NCCL WARN Cuda failure 700 'an illegal memory access was encountered'
n136-180-028:1707397:1709473 [6] NCCL INFO include/alloc.h:245 -> 1

Could you give me some advice or guidance about this? Thanks a lot! By the way, could I get your WeChat for further communication? :D

Yours,
Lianghui

Large precision and gradient discrepancies in the RWKV Triton implementation between chunk and recurrent_fuse


I've noticed significant discrepancies in both outputs and gradients in the chunk implementation of the RWKV operator in Triton. The file in question is flash-linear-attention/fla/ops/rwkv6/chunk_naive.py.

Important observation:
The standard Triton implementation and the PyTorch implementation show no significant precision differences; however, the chunk implementation exhibits extremely large discrepancies.

Details:

  • About 90% of the outputs show large differences in the chunk implementation.
  • Gradient differences in the chunk implementation range from roughly 10% to 200%.
  • The maximum differences observed are as follows:
tensor(145., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(145., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(202., device='cuda:0', dtype=torch.bfloat16)
tensor(198., device='cuda:0', dtype=torch.bfloat16)
tensor(57.2500, device='cuda:0', dtype=torch.bfloat16)
tensor(163., device='cuda:0', dtype=torch.bfloat16)
tensor(1392., device='cuda:0', dtype=torch.bfloat16)

These discrepancies are concerning as they could lead to significant differences in model performance and training stability, especially when using the chunk implementation.

Questions:

  1. Is this level of difference expected in the chunk implementation, given that the standard Triton and PyTorch implementations align well?
  2. Could there be a specific bug or issue in the chunk implementation causing these large discrepancies?
  3. Are there any known numerical stability issues in the chunking process that could explain these differences?
  4. How does the chunking process differ from the standard implementation, and could this difference account for the observed discrepancies?

Steps to reproduce:

python flash-linear-attention/fla/ops/rwkv6/chunk_naive.py

Environment:

  • Setup 1: PyTorch 2.3, NVIDIA 4090

  • Setup 2: PyTorch 2.4, Intel Arc A770

Any insights or suggestions for investigating this issue further would be greatly appreciated. In particular, any guidance on debugging the chunk implementation or understanding why it differs so significantly from the other implementations would be very helpful. Thank you for your time and attention to this matter.

bug in treatment of scale for fused_chunk_linear_attn

Thanks for the amazing library!

I discovered an error in https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/ops/linear_attn/chunk_fuse.py

It works correctly with scale=1.0, but with scale=-1 (i.e. q.shape[-1] ** -0.5) it gives significantly incorrect results that do not match chunk_linear_attn; chunk_linear_attn itself works fine with scales such as -1.

Unfortunately I'm not sure where in the Triton code the scale is applied incorrectly.
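
A minimal repro sketch of the mismatch; the import path, tensor shapes and the (output, final_state) return convention are assumptions and may differ across fla versions:

import torch
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn  # import path assumed

B, H, T, D = 2, 4, 1024, 64  # illustrative sizes only
q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))

# scale=-1 is interpreted as q.shape[-1] ** -0.5, as described above
o_chunk, _ = chunk_linear_attn(q, k, v, scale=-1)
o_fused, _ = fused_chunk_linear_attn(q, k, v, scale=-1)
print((o_chunk - o_fused).abs().max())  # expected to be ~0, but is reportedly large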

Mistakes in the GLA paper

Thank you for all your great work on linear attention and I'm very excited about this repo!

I just wanted to bring up some errors in the GLA paper that might make it look less polished than it deserves; perhaps you could fix these in a new version if you want to.

Shouldn't the division by the temperature term be inside the brackets?
Screenshot 2024-01-10 at 09 35 53

Why is the Q_t transposed here?
Screenshot 2024-01-10 at 09 36 39

For the B calculation it should be the product of beta not B
Screenshot 2024-01-10 at 09 38 06

Again here why is the Q transposed?
Screenshot 2024-01-10 at 09 39 51

The V = V[...] is missing here
Screenshot 2024-01-10 at 09 40 44

Shouldn't a_normaliser and b_normaliser be a_q[...] not a[...]?
Screenshot 2024-01-10 at 09 41 32

Shouldn't it be k = K[:,:,iC + jc,....] not k*c?
Screenshot 2024-01-10 at 09 43 08

Shouldn't it be V_iC + k, not V_iCk?
Screenshot 2024-01-10 at 09 44 12

If any of these are not actually mistakes then feel free to point that out, hope this helps!

RuntimeError: Triton Error [CUDA]: device kernel image is invalid

Greetings!

First off, I want to express my gratitude for sharing your incredible work; it's truly impressive! However, when I attempted to execute the test code outlined in the README, I encountered the following error. Could you kindly offer any guidance or recommendations on how to resolve this issue?

Thank you for your assistance!

Traceback (most recent call last):
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/unit_test.py", line 19, in <module>
    y2 = gla(x)
  File "/home/z'z'z/.conda/envs/gla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/fla/layers/gla.py", line 95, in forward
    o = fused_chunk_gla(q, k, v, gk)
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/fla/ops/gla/chunk_fuse.py", line 516, in fused_chunk_gla
    o, final_state = FusedChunkGLAFunction.apply(
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/fla/ops/utils.py", line 11, in wrapper
    return fn(ctx,
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/fla/ops/gla/chunk_fuse.py", line 320, in forward
    fwd_decay_cumsum[grid](
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/triton/compiler/compiler.py", line 692, in __getattribute__
    self._init_handles()
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/triton/compiler/compiler.py", line 683, in _init_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
RuntimeError: Triton Error [CUDA]: device kernel image is invalid

And my environment is shown below:

PyTorch version: 2.2.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.31

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-113-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 470.129.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   43 bits physical, 48 bits virtual
CPU(s):                          256
On-line CPU(s) list:             0-255
Thread(s) per core:              2
Core(s) per socket:              64
Socket(s):                       2
NUMA node(s):                    8
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7742 64-Core Processor
Stepping:                        0
Frequency boost:                 enabled
CPU MHz:                         3235.858
CPU max MHz:                     2250.0000
CPU min MHz:                     1500.0000
BogoMIPS:                        4491.50
Virtualization:                  AMD-V
L1d cache:                       4 MiB
L1i cache:                       4 MiB
L2 cache:                        64 MiB
L3 cache:                        512 MiB
NUMA node0 CPU(s):               0-15,128-143
NUMA node1 CPU(s):               16-31,144-159
NUMA node2 CPU(s):               32-47,160-175
NUMA node3 CPU(s):               48-63,176-191
NUMA node4 CPU(s):               64-79,192-207
NUMA node5 CPU(s):               80-95,208-223
NUMA node6 CPU(s):               96-111,224-239
NUMA node7 CPU(s):               112-127,240-255
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.2.1+cu118
[pip3] torchaudio==2.2.1+cu118
[pip3] torchvision==0.17.1+cu118
[pip3] triton==2.2.0
[conda] numpy                     1.26.3                   pypi_0    pypi
[conda] torch                     2.2.1+cu118              pypi_0    pypi
[conda] torchaudio                2.2.1+cu118              pypi_0    pypi
[conda] torchvision               0.17.1+cu118             pypi_0    pypi
[conda] triton                    2.2.0                    pypi_0    pypi

Using operators without having `transformers` installed

I'm currently trying to use just the operators defined in fla.ops; however, because of the __init__.py script for the main package, it's not possible to do this without importing things from the HF transformers package, which makes the import slower (and broke it entirely until I upgraded the package).

It would be nice if there were a way to just import the operators without the layer modules or anything else.

RWKV6 backward gives nan gradients

The u gradient is fine; all other gradients grow uncontrollably.

import torch as th
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6

# batch size, heads, sequence length, key dim, value dim
B, H, L, K, V = 2, 4, 256, 64, 64

# r, k, v, w each get shape (B, H, L, K); w is squashed to (0, 1) as the decay
r, k, v, w = th.randn(4, B, H, L, K).cuda()
w = w.sigmoid()
u = th.randn(H, K).cuda()

r.requires_grad = True
k.requires_grad = True
v.requires_grad = True
w.requires_grad = True
u.requires_grad = True

o, state = fused_recurrent_rwkv6(r, k, v, w, u)

print(o.shape)
o.mean().backward()
print(u.grad.shape)  # u.grad is finite
print(w.grad)        # w.grad (like the r/k/v grads) grows uncontrollably / contains nan

Finetune RWKV6 with fla implementations (fine-tuning with the rwkv6 ops in fla)

Starting from Bo's code, after replacing the CUDA kernel with the fla operator, the initial loss is very high, starting at around 8.x. Previously I replaced the CUDA kernel directly with the gla operator; because gla and rwkv compute the state in a staggered order, the r there had to be rolled by one step (fine-tuning worked fine that way). Could you please take a look at my code and check whether it is missing some necessary operation?


Variable-length sequence support

Hello, I noticed that retnet currently does not support variable-length sequences as input, which is quite important during pre-training: think of features similar to seq_idx in mamba2 or cu_seqlens in flash attention (a sketch of that packing convention is given below). Do you have plans to support this feature?
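
For context, a minimal sketch of the cu_seqlens-style packing convention used by flash attention's varlen API; the names and shapes here are purely illustrative and not an existing fla interface:

import torch

# three sequences of lengths 5, 7 and 4, packed into one "batch" of total length 16
lengths = torch.tensor([5, 7, 4])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        lengths.cumsum(0).to(torch.int32)])   # tensor([0, 5, 12, 16])

hidden = 64
x_packed = torch.randn(int(lengths.sum()), hidden)  # (total_tokens, hidden), no padding

# sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor
for i in range(len(lengths)):
    seq_i = x_packed[cu_seqlens[i]:cu_seqlens[i + 1]]
    print(i, seq_i.shape)

With such a convention, the recurrent state would simply be reset at every cu_seqlens boundary.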

Checkpoints for 340M models

Thanks for your great work.
I could not find the checkpoints of the 340M baselines reported in the GLA paper on Hugging Face; could you kindly share them?
