
Griffin

Implementation of Griffin from the paper "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models" (arXiv:2402.19427).

Install

$ pip install griffin-torch

Usage

import torch
from griffin_torch.main import Griffin

# Input: a batch of token IDs, shape (batch, seq_len)
x = torch.randint(0, 100, (1, 10))

# Model
model = Griffin(
    dim=512,  # Model dimension
    num_tokens=100,  # Vocabulary size (number of distinct token IDs)
    seq_len=10,  # Length of the input sequence
    depth=8,  # Number of stacked Griffin blocks
    mlp_mult=4,  # Multiplier for the hidden dimension in the MLPs
    dropout=0.1,  # Dropout rate
)

# Forward pass
y = model(x)

print(y)

License

MIT

Citation

@misc{de2024griffin,
    title={Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models}, 
    author={Soham De and Samuel L. Smith and Anushan Fernando and Aleksandar Botev and George Cristian-Muraru and Albert Gu and Ruba Haroun and Leonard Berrada and Yutian Chen and Srivatsan Srinivasan and Guillaume Desjardins and Arnaud Doucet and David Budden and Yee Whye Teh and Razvan Pascanu and Nando De Freitas and Caglar Gulcehre},
    year={2024},
    eprint={2402.19427},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}


Issues

[BUG] RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

2024-03-05 17:07:41,378 - numexpr.utils - INFO - Note: NumExpr detected 96 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-03-05 17:07:41,378 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.
2024-03-05 17:07:41.611780: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-05 17:07:41.611844: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-05 17:07:41.613235: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-05 17:07:41.618957: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-05 17:07:42.159194: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/kye/miniconda3/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/kye/miniconda3/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Using 4 GPUs!
Using StableAdamWUnfused-v1
training:   0%|                                      | 0/100000 [00:00<?, ?it/s]torch.Size([1, 1024, 512])
torch.Size([1, 1024, 512])
torch.Size([1, 1024, 512])
torch.Size([1, 1024, 512])
training:   0%|                                      | 0/100000 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/kye/Griffin/train.py", line 107, in <module>
    loss = model(next(train_loader))
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/_utils.py", line 722, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/zeta/structs/auto_regressive_wrapper.py", line 320, in forward
    logits = self.net(inp, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/Griffin/griffin_torch/main.py", line 327, in forward
    x = layer(x) + x
        ^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/Griffin/griffin_torch/main.py", line 204, in forward
    linear_1, linear_2 = nn.Linear(d, d)(x), nn.Linear(d, d)(x)
                         ^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kye/miniconda3/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

(base) kye@api:~/Griffin$ 
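
A minimal reproduction of the failure mode the traceback points at: main.py constructs nn.Linear inside forward (line 204 in the trace), so the freshly created weights stay on the CPU while the input has already been moved to cuda:0. The snippet below is illustrative, not the repo's code, and requires a CUDA device to reproduce the error:

    import torch
    from torch import nn

    x = torch.randn(1, 10, 512, device="cuda")  # input already on the GPU
    layer = nn.Linear(512, 512)                 # freshly constructed -> weights on CPU
    y = layer(x)                                # RuntimeError: ... cpu and cuda:0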

[BUG] implementation of Griffin

Dear authors,
It seems that the local attention described in the paper is not implemented. Furthermore, I think the initialization of the linear layers should be done in the __init__ function, not in the forward function, as in the code below:

    linear_1, linear_2 = nn.Linear(d, d)(x), nn.Linear(d, d)(x)

    linear_1 = nn.Conv1d(
        in_channels=s,
        out_channels=s,
        kernel_size=3,
        padding=1,
    )(linear_1)
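
A minimal sketch of the refactor suggested above (class and attribute names are illustrative, not the repo's actual ones). Creating the layers once in __init__ also resolves the device-mismatch error from the first issue, since model.to(device) then moves these submodules along with the rest of the model:

    import torch
    from torch import nn

    class GatedBranch(nn.Module):
        # Illustrative refactor: layers are registered once in __init__
        # instead of being rebuilt on every forward call.
        def __init__(self, dim: int, seq_len: int):
            super().__init__()
            self.linear_1 = nn.Linear(dim, dim)
            self.linear_2 = nn.Linear(dim, dim)
            self.conv = nn.Conv1d(
                in_channels=seq_len,
                out_channels=seq_len,
                kernel_size=3,
                padding=1,
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, seq_len, dim)
            a = self.conv(self.linear_1(x))
            b = self.linear_2(x)
            return a * b  # placeholder gating; the real block applies GeLU and the RG-LRU here

The sketch only shows where the layers belong; whether the convolution should mix over the sequence axis or the feature axis is a separate question from the one raised here.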

[BUG] Calculation of alpha and initialization of lambda incorrect

Hello,

I noticed two deviations from the Griffin paper in your code.

Lambda

In the code, Lambda is initialized as:

    self.Lambda.data.uniform_(
        torch.logit(torch.tensor(0.9)),
        torch.logit(torch.tensor(0.999)),
    )

However, the Griffin paper states in the second part of chapter 2.4:

We initialize Λ such that a^c is uniformly distributed between 0.9 and 0.999 at the start of training,

and a = sigmoid(Λ).

So the initialization of Lambda should actually be calculated as Λ = -log(1 / u^(1/c) - 1), where u (the target value of a^c) is drawn uniformly from [0.9, 0.999].

Alpha

Second, a_t is defined in the paper (equation 3) as:

a_t = a^(c*r_t)

which is nowhere near the formula in the code:

a = torch.sigmoid(self.Lambda)
at = a / self.c**rt

In addition, the implementation should follow the recommendation from Appendix A (Implementation) of the paper (equation 6) and compute this operation in log-space:

    a_t = exp(-c * softplus(-Λ) ⊙ r_t)

Note that the formula in the paper is missing the minus sign before Λ, but you can easily verify yourself that it should be there:
https://www.wolframalpha.com/input?i=exp%28-8*log%281%2Bexp%28-l%29%29%29+%3D+sigmoid%28l%29%5E8
or for general c:
https://www.wolframalpha.com/input?i=exp%28-c*log%281%2Bexp%28-l%29%29%29+%3D+sigmoid%28l%29%5Ec
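
A hedged sketch of the corrected computation this report describes (variable names are illustrative; c = 8 follows the constant used in the paper):

    import torch
    import torch.nn.functional as F

    c = 8.0      # fixed constant from the paper
    dim = 512

    # Initialization: draw the target value of a^c uniformly from [0.9, 0.999],
    # then invert a = sigmoid(Lambda) to recover Lambda.
    u = torch.empty(dim).uniform_(0.9, 0.999)    # target value of a^c
    a = u ** (1.0 / c)
    Lambda = torch.log(a) - torch.log1p(-a)      # logit(a) = -log(1/a - 1)

    # Recurrence weight, computed in log-space (Appendix A, eq. 6, with the minus sign):
    # a_t = exp(-c * softplus(-Lambda) * r_t) == sigmoid(Lambda) ** (c * r_t)
    r_t = torch.rand(dim)                        # stands in for the recurrence gate output
    a_t = torch.exp(-c * F.softplus(-Lambda) * r_t)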

[BUG] Dimension errors in RGLRU class

[Description]

In forward() of the RGLRU class, x has shape [batch_size, dim], but self.Wa has shape [hidden_dim, dim], where hidden_dim = mul * dim. How can RGLRU work with these mismatched dimensions?

Also, the line

    linear_1 = self.lru(linear_1)

is commented out in the code.
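
For reference, a minimal, hedged sketch of a shape-consistent RG-LRU step, following my reading of equation (4) in the paper (a sequential form for clarity; names are illustrative and this is not the repo's implementation):

    import torch
    from torch import nn
    import torch.nn.functional as F

    class RGLRUSketch(nn.Module):
        # Every tensor keeps the same last dimension `dim`, which is the
        # shape consistency this report asks about.
        def __init__(self, dim: int, c: float = 8.0):
            super().__init__()
            self.c = c
            self.W_a = nn.Linear(dim, dim)                # recurrence gate
            self.W_x = nn.Linear(dim, dim)                # input gate
            self.Lambda = nn.Parameter(torch.zeros(dim))  # initialize as in the previous issue

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, seq_len, dim)
            b, s, d = x.shape
            h = x.new_zeros(b, d)
            outputs = []
            for t in range(s):
                x_t = x[:, t]                              # (batch, dim)
                r_t = torch.sigmoid(self.W_a(x_t))         # recurrence gate
                i_t = torch.sigmoid(self.W_x(x_t))         # input gate
                a_t = torch.exp(-self.c * F.softplus(-self.Lambda) * r_t)  # a^(c * r_t), log-space
                h = a_t * h + torch.sqrt(1 - a_t ** 2) * (i_t * x_t)
                outputs.append(h)
            return torch.stack(outputs, dim=1)             # (batch, seq_len, dim)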
