
yet-another-retnet's Introduction


If you find my projects useful, please consider becoming a sponsor. Everything here comes from my free time, and is released under permissive licenses (e.g. MIT). Your contribution helps fund open-source AI.


yet-another-retnet's People

Contributors

amshaker, dongyeongkim, draguve, fkodom, leor-c

yet-another-retnet's Issues

How do I make a PR?

I have committed some minor changes to a new branch and attempted to push these, but I am getting access denied.

How's this RetNet useful when throughput is actually lower?

Thank you for your work. I did some testing with your implementation, and it is robust and works pretty well!

However, for non-autoregressive applications, the throughput is noticeably worse than that of regular transformers. In essence, the same parallel formulation can be used to generate a token (or a representation) by feeding in all of the tokens at once, without having to keep states and loop through them.

In that case, what makes RetNet a successor to the Transformer?
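The trade-off raised here can be illustrated with a simplified, single-head retention that omits normalization, group norm, and xPos (all assumptions made for brevity; the function names are hypothetical and not the library's API). The parallel form computes every output with one masked matrix product, while the recurrent form carries a fixed-size state and processes one token at a time. Both give the same outputs, so for whole-sequence workloads the parallel form is a legitimate choice; the recurrent form mainly pays off when decoding token by token.

```python
import numpy as np


def parallel_retention(q, k, v, gamma):
    """O(n^2) form: all positions at once, with causal decay mask D[n, m] = gamma^(n - m)."""
    idx = np.arange(q.shape[0])
    dist = idx[:, None] - idx[None, :]
    decay = np.where(dist >= 0, gamma ** np.maximum(dist, 0), 0.0)
    return ((q @ k.T) * decay) @ v


def recurrent_retention(q, k, v, gamma):
    """O(1)-per-token form: a (d_k x d_v) state is decayed and updated at each step."""
    state = np.zeros((q.shape[1], v.shape[1]))
    outputs = []
    for q_n, k_n, v_n in zip(q, k, v):
        state = gamma * state + np.outer(k_n, v_n)
        outputs.append(q_n @ state)
    return np.stack(outputs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
print(np.allclose(parallel_retention(q, k, v, 0.9), recurrent_retention(q, k, v, 0.9)))  # True
```

The recurrent loop touches each token once with a constant-size state, which is what makes per-token decoding cheap; it does not make processing an already-available sequence faster.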

No [tool.poetry] section in pyproject.toml

To help fix my error (see previous issue), I decided to run your pyproject.toml file. However, I get an error saying that there is no [tool.poetry] section. How do you run your TOML file? Here is your pyproject.toml as of 2023-11-05_10:13:

[build-system]
requires = ["setuptools", "setuptools-scm"]

[project]
authors = [
  {name = "Frank Odom", email = "[email protected]"},
]
classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
  # TODO: Check version ranges and pin dependencies
  "einops",
  "torch>=1.8",
]
description = "yet-another-retnet"
dynamic = ["version", "readme"] # NOTE: Must be in sync with [tool.setuptools.dynamic] below
license = {text = "MIT"}
name = "yet-another-retnet"
requires-python = ">=3.8"

[tool.setuptools.dynamic]
# NOTE: Must be in sync with 'project.dynamic' above
readme = {file = ["README.md"], content-type = "text/markdown"}
version = {attr = "yet_another_retnet.VERSION"}

[tool.setuptools.packages.find]
exclude = ["tests"]

# extra packages (e.g. pip install .[test])
[project.optional-dependencies]
test = [
  "black",
  "kaleido",
  "mypy",
  "pre-commit",
  "plotly",
  "pytest",
  "pytest-cov",
  "ruff",
  "types-requests",
]
train = [
  "lightning~=2.0.0",
  "tensorboard~=2.14.0",
  "tiktoken~=0.4.0",
  "torchdata>=0.6.0",
  "tqdm",
]

# ----- Linting, Formatting, and Typing -----

[tool.black]
line-length = 88

[tool.mypy]
check_untyped_defs = "true"
files = "yet_another_retnet/"
ignore_missing_imports = "true"

[tool.pytest.ini_options]
addopts = "--cov --cov-report term-missing --cov-fail-under 80"
filterwarnings = "ignore:.*.:DeprecationWarning"
testpaths = ["tests"]

[tool.ruff]
ignore = ["B905", "E501"]
line-length = 88
select = [
  "B",
  "C",
  "E",
  "F",
  "I",
  "W",
]
# Exclude a variety of commonly ignored directories.
exclude = [
  ".bzr",
  ".direnv",
  ".eggs",
  ".git",
  ".hg",
  ".mypy_cache",
  ".nox",
  ".pants.d",
  ".ruff_cache",
  ".svn",
  ".tox",
  ".venv",
  "__pypackages__",
  "_build",
  "buck-out",
  "build",
  "dist",
  "node_modules",
  "venv",
]

[tool.ruff.mccabe]
max-complexity = 18

Throughput measurements of parallel and recurrence methods

Hi @fkodom ,

Thank you so much for sharing this work with the research community.

I have one question: I measured inference throughput, and the parallel method seems to have higher throughput than the recurrent method, which is inconsistent with the paper. The paper claims that recurrent inference is O(1) and has higher throughput. Have you tested this, or do you know the reason?

Best regards,
Abdelrahman.
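One plausible factor, offered as a rough estimate rather than a measured result: when the whole sequence is already available, the parallel form does its work in one large batched matrix product, while the recurrent form needs one dependent step per token. The hypothetical cost model below counts multiply-accumulates (MACs) for a single simplified retention head, ignoring projections, normalization, and kernel-launch overhead.

```python
def retention_cost(seq_len: int, head_dim: int) -> dict:
    """Rough MAC counts for one simplified retention head (hypothetical cost model)."""
    # Parallel: Q @ K^T (n*n*d), decay mask (n*n), then (.) @ V (n*n*d).
    parallel_macs = 2 * seq_len * seq_len * head_dim + seq_len * seq_len
    # Recurrent, per token: decay the d x d state, add k^T v, then q @ state.
    recurrent_macs = seq_len * 3 * head_dim * head_dim
    return {
        "parallel_macs": parallel_macs,
        "parallel_sequential_steps": 1,
        "recurrent_macs": recurrent_macs,
        "recurrent_sequential_steps": seq_len,
    }


for n in (64, 512, 4096):
    print(n, retention_cost(n, head_dim=256))
```

Under this model the parallel form can even do fewer MACs when seq_len is small relative to head_dim, and it always runs as one step instead of seq_len dependent ones, which suits a GPU. The recurrent form's advantage is the per-token cost of autoregressive decoding, where earlier tokens are already folded into the state.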

Benchmark_inference

I don't quite understand the line:

            transformer_result = benchmark(transformer, x[:, -1:], start_pos=seq_length - 1)

I timed it and also timed the following:

            transformer_result = benchmark(transformer, x[:, 0:1], start_pos=0)

I expected the timings to diverge more and more as seq_length grows, but that isn't the case. The second call gave approximately 50% of the throughput of the first call. Why would that be? I believe I have some kind of misunderstanding. Thanks.
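For context, a typical KV-cache decoder (an assumption here; the benchmark's transformer may be implemented differently) attends from the single new token at start_pos to the start_pos + 1 cached keys, so under that design the start_pos = seq_length - 1 step should cost more than the start_pos = 0 step. If a model instead attends over a fixed-size cache buffer with masking, the cost would not depend on start_pos at all. The numpy sketch below, with hypothetical function names, shows the sliced-cache version and checks it against full causal attention.

```python
import numpy as np


def causal_attention(q, k, v):
    """Full softmax attention with a causal mask (reference)."""
    n = q.shape[0]
    scores = q @ k.T / np.sqrt(q.shape[1])
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v


def decode_step(q_new, k_cache, v_cache, start_pos):
    """Attend from one new token at start_pos to cached keys/values 0..start_pos.

    The work done here grows with start_pos + 1, the number of cached entries used.
    """
    k = k_cache[: start_pos + 1]
    v = v_cache[: start_pos + 1]
    scores = k @ q_new / np.sqrt(q_new.shape[0])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ v
```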

Some issues regarding _build_decay_mask.

Thank you for your implementation, but I have encountered a bug when using the code. There is a major problem in the function _build_decay_mask: the last element of decay_gammas is set to 1. This makes every element of (decay_gammas**distance)[-1] equal to 1 (since 1 ** float("inf") = 1), including the future positions that should be 0, which can lead to information leakage. Here is a suggested modification:

    # Mark the upper-triangular (future) positions, so that only *past* keys
    # can affect the current query.  Filling those distances with infinity is
    # not safe: x^(inf) = 0 only when -1 < x < 1, and 1^(inf) = 1.
    distance_mask = torch.ones_like(distance, dtype=torch.bool).triu_(diagonal=1)
    # distance = distance.masked_fill_(distance_mask, "inf")  # original line, disabled

    distance = rearrange(distance, "n s -> () n s")
    decay_gammas = rearrange(decay_gammas, "h -> h () ()")

    decay_mask = decay_gammas**distance
    # Zero the future positions after exponentiation; the (n, s) mask
    # broadcasts over the head dimension.
    decay_mask = decay_mask.masked_fill_(distance_mask, 0)
    return decay_mask
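The failure mode described in this issue can be reproduced without torch. The numpy sketch below (hypothetical function names; einops rearrange is replaced by plain broadcasting) builds the mask both ways: setting future distances to infinity before exponentiating leaves a head whose gamma is 1 at 1 everywhere, including future positions, while zeroing future positions after exponentiating does not.

```python
import numpy as np


def decay_mask_via_inf(gammas: np.ndarray, n: int) -> np.ndarray:
    """Original approach: future distances set to inf, relying on gamma^inf == 0."""
    idx = np.arange(n)
    distance = (idx[:, None] - idx[None, :]).astype(float)
    distance[np.triu(np.ones((n, n), dtype=bool), k=1)] = np.inf
    return gammas[:, None, None] ** distance[None]


def decay_mask_via_zero_fill(gammas: np.ndarray, n: int) -> np.ndarray:
    """Suggested approach: exponentiate first, then zero out future positions."""
    idx = np.arange(n)
    distance = (idx[:, None] - idx[None, :]).astype(float)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    # Use distance 0 for future slots before exponentiating to avoid huge values.
    mask = gammas[:, None, None] ** np.where(future, 0.0, distance)[None]
    mask[:, future] = 0.0
    return mask


gammas = np.array([0.9, 1.0])  # the last head's gamma is 1, as described in the issue
print(decay_mask_via_inf(gammas, 3)[1])        # all ones: future positions leak
print(decay_mask_via_zero_fill(gammas, 3)[1])  # lower-triangular ones
```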

Invalid precision when running train_project_gutenberg

I am running on a Mac, and I like the clarity of your code. My device is 'cpu', since there is no CUDA on a Mac laptop. Running retnet.py works fine. However, when running train_project_gutenberg using the command:

python -m scripts.train_project_gutenberg

I get the following trace:

Traceback (most recent call last):
  File "/Users/erlebach/opt/miniconda3/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/erlebach/opt/miniconda3/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/scripts/train_project_gutenberg.py", line 388, in <module>
    main(**vars(args))
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/scripts/train_project_gutenberg.py", line 349, in main
    train(
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/scripts/train_project_gutenberg.py", line 205, in train
    fabric = Fabric(
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 127, in __init__
    self._connector = _Connector(
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/lightning/fabric/connector.py", line 135, in __init__
    self._check_config_and_set_final_flags(
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/lightning/fabric/connector.py", line 229, in _check_config_and_set_final_flags
    precision_input = _convert_precision_to_unified_args(precision)
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/lightning/fabric/connector.py", line 559, in _convert_precision_to_unified_args
    raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
ValueError: Precision 'float32' is invalid. Allowed precision values: ('transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true', 64, 32, 16, '64', '32', '16', 'bf16')

What is non-standard in the current code that prevents this from working? Thanks.
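The error message shows that Fabric rejects the dtype-style name 'float32' but accepts Lightning's own precision strings such as '32-true'. One way to bridge the two, sketched below as a hypothetical helper (the alias table, especially which 16-bit mode to choose, is an assumption), is to translate the name before constructing Fabric, e.g. Fabric(precision=to_fabric_precision("float32")).

```python
# Lightning Fabric precision strings, taken from the allowed values in the error above.
FABRIC_PRECISIONS = {"16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"}

# Hypothetical mapping from dtype-style names to Fabric precision strings.
DTYPE_ALIASES = {
    "float64": "64-true",
    "float32": "32-true",
    "float16": "16-mixed",
    "bfloat16": "bf16-mixed",
}


def to_fabric_precision(name: str) -> str:
    """Translate e.g. 'float32' to '32-true'; valid Fabric strings pass through unchanged."""
    precision = DTYPE_ALIASES.get(name, name)
    if precision not in FABRIC_PRECISIONS:
        raise ValueError(f"Unsupported precision {name!r}")
    return precision
```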

Running benchmark_inference on the CPU

I am running scripts/benchmark_inference on the CPU (on a Mac M2 with macOS Ventura) and have run into several issues with the code, listed below. Could you please run the code on the CPU with a version of Torch that does not have CUDA, or amend the README file to state the constraints on the code?

Thanks.

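  • I added the following section after the DEVICE definition:
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Compare the device type string; DEVICE itself is a torch.device object.
# float16 is used only on CUDA; fall back to float32 on the CPU.
DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
print(DEVICE)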

I disabled the call to benchmark_inference_memory. That function calls Profile, which depends on a GPU-enabled version of PyTorch. Perhaps you could fix the code?

  • For some reason, the lines:
        # Benchmark *recurrent* RetNet formulation for inference
        retnet_result = benchmark(
            retnet.forward_recurrent, x[:, 0], seq_idx=0, prev_states=[]
        )

generate the message:

  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/torch/cuda/__init__.py", line 239, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")

Here is the full trace (the line numbers in benchmark_inference might be off by 3-4 lines because of my additions at the top of the file):

Traceback (most recent call last):
  File "/Users/erlebach/opt/miniconda3/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/erlebach/opt/miniconda3/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/scripts/benchmark_inference.py", line 165, in <module>
    retnet_throughputs, transformer_throughputs = benchmark_inference_throughput(
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/scripts/benchmark_inference.py", line 109, in benchmark_inference_throughput
    retnet_result = benchmark(
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/yet_another_retnet/utils/benchmark.py", line 43, in benchmark
    _ = timer.repeat(number=1, repeat=5)
  File "/Users/erlebach/opt/miniconda3/lib/python3.10/timeit.py", line 206, in repeat
    t = self.timeit(number)
  File "/Users/erlebach/opt/miniconda3/lib/python3.10/timeit.py", line 178, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/torch/cuda/__init__.py", line 686, in synchronize
    _lazy_init()
  File "/Users/erlebach/src/2023/retentive_networks/retentive_code/yet-another-retnet/.venv/lib/python3.10/site-packages/torch/cuda/__init__.py", line 239, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
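The last frames show the timing loop calling torch.cuda.synchronize, which fails on a CPU-only build. A device-aware alternative, sketched below with a hypothetical benchmark_throughput function (not the repository's benchmark utility), takes the synchronization step as an optional callable so it can be skipped on the CPU.

```python
import timeit
from typing import Any, Callable, Optional


def benchmark_throughput(
    fn: Callable[..., Any],
    *args: Any,
    synchronize: Optional[Callable[[], None]] = None,
    number: int = 5,
    **kwargs: Any,
) -> float:
    """Return calls per second of fn(*args, **kwargs).

    synchronize runs after every call (e.g. torch.cuda.synchronize on a GPU) so
    that asynchronous kernels finish inside the timed region. Pass None on the
    CPU, where calls are already synchronous.
    """
    sync = synchronize if synchronize is not None else (lambda: None)

    def run() -> None:
        fn(*args, **kwargs)
        sync()

    run()  # warm-up call, excluded from the timing
    seconds = timeit.timeit(run, number=number)
    return number / seconds
```

With PyTorch, the recurrent benchmark could then pass synchronize=torch.cuda.synchronize if torch.cuda.is_available() else None, alongside the existing seq_idx=0 and prev_states=[] keyword arguments.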

Have you ever tried Retnet for vision tasks?

Hi, thank you for your great work. The RetNet version you provided is the clearest and easiest to understand of any I have seen.
Have you ever tried using the RetNet module for vision tasks (instead of a transformer)? If you have, I'm curious how it turned out.
Thank you for your work again!

An initialization issue

Hi again (:
I've found a small problem in the current initialization of the RetNetDecoder class. Specifically, to build a multi-layer model, this class uses deepcopy to copy the single RetNetDecoderLayer object it receives as input. This copying leads to the following problems:

  1. The parameters of the layers are not I.I.D.
  2. Consequently, the "lottery ticket hypothesis" does not apply (at least, there is no established evidence for this phenomenon in the non-I.I.D. case).

It's not a very serious issue, but I think it's worth fixing. I would be happy to implement a solution, but I wanted to discuss which design would be preferred here:
One possible solution could be to change RetNetDecoder.__init__ to accept a list of layer objects (initialized externally).
Alternatively, the layer's arguments could be stored as properties, and the new layers initialized from the properties of the given layer.
Another possible solution could be to define a configuration object from which a RetNetDecoderLayer object is initialized, and pass an instance of it to RetNetDecoder.__init__ instead of an actual layer object.

There may be other solutions as well. Which one do you think would be ideal here? Do you have other solution ideas?
Thanks!
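The configuration-object option might look like the sketch below. It uses plain Python stand-ins (ToyLayerConfig and ToyLayer are hypothetical, not RetNetDecoderLayer) to show the difference: deep copies all start from identical weights, while layers built from a shared config are each initialized independently.

```python
import copy
import random
from dataclasses import dataclass
from typing import List


@dataclass
class ToyLayerConfig:
    """Hypothetical stand-in for the arguments a decoder layer is built from."""
    d_model: int


class ToyLayer:
    """Stand-in for a decoder layer: its weights are drawn when it is constructed."""

    def __init__(self, config: ToyLayerConfig, rng: random.Random) -> None:
        self.weights = [rng.gauss(0.0, 1.0) for _ in range(config.d_model)]


def stack_by_deepcopy(layer: ToyLayer, num_layers: int) -> List[ToyLayer]:
    # Behaviour described in the issue: every layer is a copy of the same layer.
    return [copy.deepcopy(layer) for _ in range(num_layers)]


def stack_from_config(config: ToyLayerConfig, num_layers: int, rng: random.Random) -> List[ToyLayer]:
    # Config-based alternative: each layer is constructed, and so initialized, separately.
    return [ToyLayer(config, rng) for _ in range(num_layers)]


rng = random.Random(0)
config = ToyLayerConfig(d_model=4)
copied = stack_by_deepcopy(ToyLayer(config, rng), num_layers=3)
fresh = stack_from_config(config, num_layers=3, rng=rng)
print(all(layer.weights == copied[0].weights for layer in copied))  # True
```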

How to train with long sequences using chunkwise feature of RetNet?

Hello,

I am interested in training a model using the chunkwise feature of RetNet to handle long sequences. However, I couldn't find detailed instructions on how to do this in the documentation.
Could you please guide me on the best practices or steps for training a model on long sequences using the chunkwise formulation of RetNet?

Thank you.
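As background for this question, chunkwise retention splits the sequence into chunks, computes each chunk in parallel, and carries a recurrent state across chunk boundaries. The simplified single-head numpy sketch below (no normalization or xPos; function names are hypothetical and not the library's API) checks that the chunkwise result matches the parallel one, including a final partial chunk.

```python
import numpy as np


def parallel_retention(q, k, v, gamma):
    """Reference O(n^2) form with causal decay mask D[n, m] = gamma^(n - m)."""
    idx = np.arange(q.shape[0])
    dist = idx[:, None] - idx[None, :]
    decay = np.where(dist >= 0, gamma ** np.maximum(dist, 0), 0.0)
    return ((q @ k.T) * decay) @ v


def chunkwise_retention(q, k, v, gamma, chunk_size):
    """Parallel inside each chunk; a (d_k x d_v) state summarizes earlier chunks."""
    state = np.zeros((q.shape[1], v.shape[1]))
    outputs = []
    for start in range(0, q.shape[0], chunk_size):
        qc = q[start : start + chunk_size]
        kc = k[start : start + chunk_size]
        vc = v[start : start + chunk_size]
        pos = np.arange(qc.shape[0])  # positions within this chunk (last may be short)
        inner = parallel_retention(qc, kc, vc, gamma)
        # Token j of the chunk is j + 1 steps past the last token of the previous chunk.
        cross = (gamma ** (pos + 1))[:, None] * (qc @ state)
        outputs.append(inner + cross)
        # Decay the state across this chunk, then add this chunk's keys and values.
        b = qc.shape[0]
        state = gamma**b * state + (kc * (gamma ** (b - 1 - pos))[:, None]).T @ vc
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 4)) for _ in range(3))
print(np.allclose(chunkwise_retention(q, k, v, 0.9, chunk_size=4), parallel_retention(q, k, v, 0.9)))
```

Each chunk's score matrix is only chunk_size x chunk_size, which is what bounds memory for long sequences.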

ModelCheckpoint() argument after ** must be a mapping, not ModelCheckpoint

I tried to do inference but got the above error.
I'm using Python 3.10 on Ubuntu 22.04 (i9-13900K, RTX 4090).
I fixed this by changing return cls(**checkpoint_dict) to return cls(**vars(checkpoint_dict))

  File "/home/dwood/LLM/yar/./venv/lib/python3.10/site-packages/scripts/train_project_gutenberg.py", line 95, in load
    return cls(**checkpoint_dict)
TypeError: __main__.ModelCheckpoint() argument after ** must be a mapping, not ModelCheckpoint
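The traceback suggests the loaded object is a ModelCheckpoint instance (reported as __main__.ModelCheckpoint, plausibly because it was pickled while the script ran as __main__) rather than a dict, which is why the vars() change works. A hedged sketch of a helper with a hypothetical name that accepts either form, so load could return cls(**checkpoint_kwargs(checkpoint_dict)):

```python
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any, Dict


def checkpoint_kwargs(obj: Any) -> Dict[str, Any]:
    """Return keyword arguments for cls(**...) from a dict or a plain object."""
    if isinstance(obj, Mapping):
        return dict(obj)
    # Objects such as a pickled __main__.ModelCheckpoint expose their fields via __dict__.
    return dict(vars(obj))


print(checkpoint_kwargs({"step": 10}))
print(checkpoint_kwargs(SimpleNamespace(step=10)))
```

This relies on the object having a __dict__; a class that defines __slots__ would need a different conversion.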
