
Comments (15)

ramkumarkoppu commented on June 3, 2024

@ramkumarkoppu just curious Ram, are you interested in this repo for academia or for some company you work at?

This is for a personal research project.

from q-transformer.

justindujardin commented on June 3, 2024

@ramkumarkoppu could you try changing your Agent initialization to

agent = Agent(
    model,
    environment = env,
    num_episodes = 100,
    max_num_steps_per_episode = 10,
)

@lucidrains just tried locally and reproduced the error. Your fix works for me.

ramkumarkoppu commented on June 3, 2024

Hi @lucidrains, thanks for your support. But I do have this info to share on this issue:
Mine is 64-bit,
[screenshot of system info omitted]
and after training completes, I get this:

training completed
Traceback (most recent call last):
  File "/home/ram/github/q-transformer/usage.py", line 78, in <module>
    actions = model.get_optimal_actions(video, instructions)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 1107, in get_optimal_actions
    encoded_state = self.encode_state(*args, **kwargs)
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 1169, in encode_state
    tokens = self.vit(
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<@beartype(q_transformer.q_robotic_transformer.MaxViT.forward) at 0x7f9f8e77faf0>", line 78, in forward
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 447, in forward
    x = self.conv_stem(img)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
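The RuntimeError above means the input tensor lives on the CPU while the model's weights live on the GPU. A minimal sketch of the usual fix (using a stand-in conv layer, not the repo's code) is to move the input to the weights' device before the forward pass:

```python
import torch

# stand-in for the ViT conv stem; the repo's model would be on the GPU
model = torch.nn.Conv2d(3, 8, 3)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

img = torch.randn(1, 3, 32, 32)  # CPU tensor, like `video` in the example

# move the input to wherever the weights are before calling forward
img = img.to(next(model.parameters()).device)

out = model(img)
```

Whether the library should do this internally (as the later fix apparently does) or the caller should pass a tensor on the right device is a design choice; the sketch just shows the mechanics.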

lucidrains commented on June 3, 2024

@ramkumarkoppu hi Ram, and thanks for your interest in this repo

i couldn't reproduce your bug, but found another one when i ran the example (should be fixed in the latest version)

could you try the latest version and see if it incidentally resolves your issue?

lucidrains commented on June 3, 2024

@ramkumarkoppu just curious Ram, are you interested in this repo for academia or for some company you work at?

justindujardin commented on June 3, 2024

@ramkumarkoppu do you have enough disk space for the memmap files? I had to decrease the demo hyperparams because my disk wasn't big enough to hold them. I think the error was similar to this.

ramkumarkoppu commented on June 3, 2024

Hi @lucidrains,
Tried the new version, but I still have the same problem:

pip show q-transformer
Name: q-transformer
Version: 0.1.9
Summary: Q-Transformer
Home-page: https://github.com/lucidrains/q-transformer
Author: Phil Wang
Author-email: [email protected]
License: MIT
Location: /home/ram/github/q-transformer
Editable project location: /home/ram/github/q-transformer
Requires: accelerate, beartype, classifier-free-guidance-pytorch, einops, ema-pytorch, numpy, torch, torchtyping
Required-by: 
python example1.py 
using memory efficient attention
/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Traceback (most recent call last):
  File "/home/ram/github/q-transformer/example1.py", line 47, in <module>
    agent = Agent(
  File "<@beartype(q_transformer.agent.Agent.__init__) at 0x7f5dbc1953a0>", line 145, in __init__
  File "/home/ram/github/q-transformer/q_transformer/agent.py", line 208, in __init__
    self.states      = open_memmap(str(states_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *state_shape))
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/numpy/lib/format.py", line 945, in open_memmap
    marray = numpy.memmap(filename, dtype=dtype, shape=shape, order=order,
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/numpy/core/memmap.py", line 254, in __new__
    fid.seek(bytes - 1, 0)
OSError: [Errno 22] Invalid argument

my python environment is:

conda list
# packages in environment at /home/ram/anaconda3/envs/q-transformer:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
accelerate                0.26.1                   pypi_0    pypi
anyio                     4.2.0            py39h06a4308_0  
argon2-cffi               23.1.0                   pypi_0    pypi
argon2-cffi-bindings      21.2.0           py39h7f8727e_0  
arrow                     1.3.0                    pypi_0    pypi
asttokens                 2.4.1                    pypi_0    pypi
async-lru                 2.0.4            py39h06a4308_0  
attrs                     23.2.0                   pypi_0    pypi
babel                     2.14.0                   pypi_0    pypi
backcall                  0.2.0              pyhd3eb1b0_0  
beartype                  0.17.0                   pypi_0    pypi
beautifulsoup4            4.12.3                   pypi_0    pypi
bleach                    6.1.0                    pypi_0    pypi
brotli-python             1.0.9            py39h6a678d5_7  
ca-certificates           2023.12.12           h06a4308_0  
certifi                   2023.11.17       py39h06a4308_0  
cffi                      1.16.0           py39h5eee18b_0  
chardet                   5.2.0                    pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
classifier-free-guidance-pytorch 0.5.2                    pypi_0    pypi
comm                      0.2.1                    pypi_0    pypi
cryptography              41.0.7           py39hdda0065_0  
cyrus-sasl                2.1.28               h52b45da_1  
dbus                      1.13.18              hb2f20db_0  
debugpy                   1.8.0                    pypi_0    pypi
decorator                 5.1.1              pyhd3eb1b0_0  
defusedxml                0.7.1              pyhd3eb1b0_0  
einops                    0.7.0                    pypi_0    pypi
ema-pytorch               0.3.2                    pypi_0    pypi
exceptiongroup            1.2.0            py39h06a4308_0  
executing                 2.0.1                    pypi_0    pypi
expat                     2.5.0                h6a678d5_0  
fastjsonschema            2.19.1                   pypi_0    pypi
filelock                  3.13.1                   pypi_0    pypi
fontconfig                2.14.1               h4c34cd2_2  
fqdn                      1.5.1                    pypi_0    pypi
freetype                  2.12.1               h4a9f257_0  
fsspec                    2023.12.2                pypi_0    pypi
ftfy                      6.1.3                    pypi_0    pypi
glib                      2.69.1               he621ea3_2  
gst-plugins-base          1.14.1               h6a678d5_1  
gstreamer                 1.14.1               h5eee18b_1  
huggingface-hub           0.20.3                   pypi_0    pypi
icu                       73.1                 h6a678d5_0  
idna                      3.6                      pypi_0    pypi
importlib-metadata        7.0.1            py39h06a4308_0  
importlib_metadata        7.0.1                hd3eb1b0_0  
ipykernel                 6.29.0                   pypi_0    pypi
ipython                   8.18.1                   pypi_0    pypi
ipywidgets                8.1.1                    pypi_0    pypi
isoduration               20.11.0                  pypi_0    pypi
jedi                      0.19.1                   pypi_0    pypi
jinja2                    3.1.3                    pypi_0    pypi
jpeg                      9e                   h5eee18b_1  
json5                     0.9.14                   pypi_0    pypi
jsonpointer               2.4                      pypi_0    pypi
jsonschema                4.21.1                   pypi_0    pypi
jsonschema-specifications 2023.12.1                pypi_0    pypi
jupyter                   1.0.0            py39h06a4308_8  
jupyter-core              5.7.1                    pypi_0    pypi
jupyter-events            0.9.0                    pypi_0    pypi
jupyter-lsp               2.2.2                    pypi_0    pypi
jupyter-server            2.12.5                   pypi_0    pypi
jupyter-server-terminals  0.5.2                    pypi_0    pypi
jupyter_client            8.6.0            py39h06a4308_0  
jupyter_console           6.6.3            py39h06a4308_0  
jupyter_core              5.5.0            py39h06a4308_0  
jupyter_events            0.8.0            py39h06a4308_0  
jupyter_server            2.10.0           py39h06a4308_0  
jupyter_server_terminals  0.4.4            py39h06a4308_1  
jupyterlab                4.0.11                   pypi_0    pypi
jupyterlab-pygments       0.3.0                    pypi_0    pypi
jupyterlab-server         2.25.2                   pypi_0    pypi
jupyterlab_pygments       0.1.2                      py_0  
jupyterlab_server         2.25.1           py39h06a4308_0  
jupyterlab_widgets        3.0.9            py39h06a4308_0  
krb5                      1.20.1               h143b758_1  
ld_impl_linux-64          2.38                 h1181459_1  
libclang                  14.0.6          default_hc6dbbc7_1  
libclang13                14.0.6          default_he11475f_1  
libcups                   2.4.2                h2d74bed_1  
libedit                   3.1.20230828         h5eee18b_0  
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libllvm14                 14.0.6               hdb19cb5_3  
libpng                    1.6.39               h5eee18b_0  
libpq                     12.17                hdbd6064_0  
libsodium                 1.0.18               h7b6447c_0  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
libxcb                    1.15                 h7f8727e_0  
libxkbcommon              1.0.1                h5eee18b_1  
libxml2                   2.10.4               hf1b16e4_1  
lz4-c                     1.9.4                h6a678d5_0  
markupsafe                2.1.4                    pypi_0    pypi
matplotlib-inline         0.1.6            py39h06a4308_0  
mistune                   3.0.2                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
mysql                     5.7.24               h721c034_2  
nbclient                  0.9.0                    pypi_0    pypi
nbconvert                 7.14.2                   pypi_0    pypi
nbformat                  5.9.2            py39h06a4308_0  
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.6.0                    pypi_0    pypi
networkx                  3.2.1                    pypi_0    pypi
notebook                  7.0.7                    pypi_0    pypi
notebook-shim             0.2.3            py39h06a4308_0  
numpy                     1.26.3                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.18.1                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.3.101                 pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
open-clip-torch           2.24.0                   pypi_0    pypi
openssl                   3.0.12               h7f8727e_0  
overrides                 7.6.0                    pypi_0    pypi
packaging                 23.2                     pypi_0    pypi
pandocfilters             1.5.1                    pypi_0    pypi
parso                     0.8.3              pyhd3eb1b0_0  
pcre                      8.45                 h295c915_0  
pexpect                   4.9.0                    pypi_0    pypi
pickleshare               0.7.5           pyhd3eb1b0_1003  
pillow                    10.2.0                   pypi_0    pypi
pip                       23.3.1           py39h06a4308_0  
platformdirs              4.1.0                    pypi_0    pypi
ply                       3.11             py39h06a4308_0  
prometheus-client         0.19.0                   pypi_0    pypi
prometheus_client         0.14.1           py39h06a4308_0  
prompt-toolkit            3.0.43           py39h06a4308_0  
prompt_toolkit            3.0.43               hd3eb1b0_0  
protobuf                  4.25.2                   pypi_0    pypi
psutil                    5.9.8                    pypi_0    pypi
ptyprocess                0.7.0              pyhd3eb1b0_2  
pure_eval                 0.2.2              pyhd3eb1b0_0  
pycparser                 2.21               pyhd3eb1b0_0  
pygments                  2.17.2                   pypi_0    pypi
pyopenssl                 23.2.0           py39h06a4308_0  
pyqt                      5.15.10          py39h6a678d5_0  
pyqt5-sip                 12.13.0          py39h5eee18b_0  
pysocks                   1.7.1            py39h06a4308_0  
python                    3.9.18               h955ad1f_0  
python-dateutil           2.8.2              pyhd3eb1b0_0  
python-fastjsonschema     2.16.2           py39h06a4308_0  
python-json-logger        2.0.7            py39h06a4308_0  
pytz                      2023.3.post1     py39h06a4308_0  
pyyaml                    6.0.1            py39h5eee18b_0  
pyzmq                     25.1.2           py39h6a678d5_0  
q-transformer             0.1.9                     dev_0    <develop>
qt-main                   5.15.2              h53bd1ea_10  
qtconsole                 5.5.1                    pypi_0    pypi
qtpy                      2.4.1            py39h06a4308_0  
readline                  8.2                  h5eee18b_0  
referencing               0.32.1                   pypi_0    pypi
regex                     2023.12.25               pypi_0    pypi
requests                  2.31.0           py39h06a4308_0  
rfc3339-validator         0.1.4            py39h06a4308_0  
rfc3986-validator         0.1.1            py39h06a4308_0  
rpds-py                   0.17.1                   pypi_0    pypi
safetensors               0.4.2                    pypi_0    pypi
send2trash                1.8.2            py39h06a4308_0  
sentencepiece             0.1.99                   pypi_0    pypi
setuptools                68.2.2           py39h06a4308_0  
sip                       6.7.12           py39h6a678d5_0  
six                       1.16.0             pyhd3eb1b0_1  
sniffio                   1.3.0            py39h06a4308_0  
soupsieve                 2.5              py39h06a4308_0  
sqlite                    3.41.2               h5eee18b_0  
stack-data                0.6.3                    pypi_0    pypi
stack_data                0.2.0              pyhd3eb1b0_0  
sympy                     1.12                     pypi_0    pypi
terminado                 0.18.0                   pypi_0    pypi
timm                      0.9.12                   pypi_0    pypi
tinycss2                  1.2.1            py39h06a4308_0  
tk                        8.6.12               h1ccaba5_0  
tokenizers                0.15.1                   pypi_0    pypi
tomli                     2.0.1            py39h06a4308_0  
torch                     2.1.2                    pypi_0    pypi
torchtyping               0.1.4                    pypi_0    pypi
torchvision               0.16.2                   pypi_0    pypi
tornado                   6.4                      pypi_0    pypi
tqdm                      4.66.1                   pypi_0    pypi
traitlets                 5.14.1                   pypi_0    pypi
transformers              4.37.1                   pypi_0    pypi
triton                    2.1.0                    pypi_0    pypi
typeguard                 4.1.5                    pypi_0    pypi
types-python-dateutil     2.8.19.20240106          pypi_0    pypi
typing-extensions         4.9.0            py39h06a4308_1  
typing_extensions         4.9.0            py39h06a4308_1  
tzdata                    2023d                h04d1e81_0  
uri-template              1.3.0                    pypi_0    pypi
urllib3                   2.1.0                    pypi_0    pypi
wcwidth                   0.2.13                   pypi_0    pypi
webcolors                 1.13                     pypi_0    pypi
webencodings              0.5.1                    pypi_0    pypi
websocket-client          1.7.0                    pypi_0    pypi
wheel                     0.41.2           py39h06a4308_0  
widgetsnbextension        4.0.9                    pypi_0    pypi
xz                        5.4.5                h5eee18b_0  
yaml                      0.2.5                h7b6447c_0  
zeromq                    4.3.5                h6a678d5_0  
zipp                      3.17.0           py39h06a4308_0  
zlib                      1.2.13               h5eee18b_0  
zstd                      1.5.5                hc292b87_0  
python --version
Python 3.9.18

The code I am using is copied into example1.py from the Usage section of this repo:

import torch

from q_transformer import (
    QRoboticTransformer,
    QLearner,
    Agent,
    ReplayMemoryDataset
)

# the attention model

model = QRoboticTransformer(
    vit = dict(
        num_classes = 1000,
        dim_conv_stem = 64,
        dim = 64,
        dim_head = 64,
        depth = (2, 2, 5, 2),
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1
    ),
    num_actions = 8,
    action_bins = 256,
    depth = 1,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2,
    dueling = True
)

# you need to supply your own environment, by overriding BaseEnvironment

from q_transformer.mocks import MockEnvironment

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),
    text_embed_shape = (768,)
)

# env.init()     should return instructions and initial state: Tuple[str, Tensor[*state_shape]]
# env(actions)   should return rewards, next state, and done flag: Tuple[Tensor[()], Tensor[*state_shape], Tensor[()]]

# agent is a class that allows the q-model to interact with the environment to generate a replay memory dataset for learning

agent = Agent(
    model,
    environment = env,
    num_episodes = 10000,
    max_num_steps_per_episode = 1000,
)

agent()

# Q learning on the replay memory dataset on the model

q_learner = QLearner(
    model,
    dataset = ReplayMemoryDataset(),
    num_train_steps = 10000,
    learning_rate = 3e-4,
    batch_size = 32
)

q_learner()

# after much learning
# your robot should be better at selecting optimal actions

video = torch.randn(2, 3, 6, 224, 224)

instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

actions = model.get_optimal_actions(video, instructions)
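The comments in the listing above describe the environment contract (`env.init()` returns instructions plus an initial state; `env(actions)` returns reward, next state, and done flag). A hedged sketch of a custom environment satisfying that contract — `RandomEnvironment` is purely illustrative and does not subclass the repo's `BaseEnvironment`:

```python
import torch

class RandomEnvironment:
    # illustrative stand-in implementing the contract described in the
    # usage comments above (not the repo's BaseEnvironment itself)
    def __init__(self, state_shape=(3, 6, 8, 8), max_steps=3):
        self.state_shape = state_shape
        self.max_steps = max_steps
        self.steps_taken = 0

    def init(self):
        # instructions and initial state: Tuple[str, Tensor[*state_shape]]
        self.steps_taken = 0
        return 'pick up the cube', torch.randn(self.state_shape)

    def __call__(self, actions):
        # reward, next state, and done flag
        self.steps_taken += 1
        reward = torch.tensor(0.0)
        next_state = torch.randn(self.state_shape)
        done = torch.tensor(self.steps_taken >= self.max_steps)
        return reward, next_state, done

env = RandomEnvironment()
instruction, state = env.init()
reward, next_state, done = env(torch.zeros(8, dtype=torch.long))
```

A real environment would return camera frames and task-dependent rewards instead of random tensors, but the shapes and types must match this contract for the Agent to consume them.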

ramkumarkoppu commented on June 3, 2024

@ramkumarkoppu do you have enough disk space for the memmap files? I had to decrease the demo hyperparams because my disk wasn't big enough to hold them. I think the error was similar to this.

It seems there is enough disk space:

df -h .
Filesystem      Size  Used Avail Use% Mounted on
/dev/nvme0n1p5  288G  163G  111G  60% /
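As an aside, free space may not actually be the binding constraint here: with the demo hyperparameters (num_episodes = 10000, max_num_steps_per_episode = 1000), preallocating the states memmap alone would need far more than 111 GB. A back-of-the-envelope sketch, assuming the buffer shape is (num_episodes, max_steps + 1, *state_shape) in float32:

```python
import numpy as np

num_episodes = 10_000
max_num_steps_per_episode = 1_000
state_shape = (3, 6, 224, 224)

# assumed layout: one float32 entry per (episode, step, *state_shape)
n_elems = num_episodes * (max_num_steps_per_episode + 1) * int(np.prod(state_shape))
n_bytes = n_elems * np.dtype('float32').itemsize
terabytes = n_bytes / 1e12  # roughly 36 TB for the states memmap alone
```

A file that size also exceeds common filesystem per-file limits, which could explain the OSError from `fid.seek(bytes - 1, 0)` despite the free space shown, and why shrinking the hyperparameters makes the error go away.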

lucidrains commented on June 3, 2024

@ramkumarkoppu could you try changing your Agent initialization to

agent = Agent(
    model,
    environment = env,
    num_episodes = 100,
    max_num_steps_per_episode = 10,
)

lucidrains commented on June 3, 2024

@justindujardin thanks Justin, think you are right: https://stackoverflow.com/a/54205996. i'll probably need to have some wrapper around memmap and auto-manage the 2GB limitation

lucidrains commented on June 3, 2024

@justindujardin @ramkumarkoppu it works fine on my m1 mac book, just as an aside

ramkumarkoppu commented on June 3, 2024

Running the modified code as per your suggestion on Ubuntu with an Nvidia GPU with 24 GB, after 100 episodes it throws this:

episode 99
 90%|████████████████████████████████████████████████████████████████████▍       | 9/10 [00:09<00:01,  1.02s/it]
completed, memories stored to /home/ram/github/q-transformer/replay_memories_data
/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "/home/ram/github/q-transformer/example1.py", line 66, in <module>
    q_learner()
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ram/github/q-transformer/q_transformer/q_learner.py", line 598, in forward
    loss, (td_loss, conservative_reg_loss) = self.learn(
  File "/home/ram/github/q-transformer/q_transformer/q_learner.py", line 525, in learn
    td_loss, q_intermediates = self.autoregressive_q_learn_handle_single_timestep(*args, **q_learn_kwargs)
  File "/home/ram/github/q-transformer/q_transformer/q_learner.py", line 403, in autoregressive_q_learn_handle_single_timestep
    return self.autoregressive_q_learn(text_embeds, states, actions, next_states, rewards, dones, monte_carlo_return = monte_carlo_return)
  File "/home/ram/github/q-transformer/q_transformer/q_learner.py", line 462, in autoregressive_q_learn
    q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds, actions = actions)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py", line 138, in inner
    return fn_maybe_with_text(self, *args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py", line 131, in fn_maybe_with_text
    return fn(self, *args, **kwargs)
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 1210, in forward
    encoded_state = self.encode_state(
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 1169, in encode_state
    tokens = self.vit(
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<@beartype(q_transformer.q_robotic_transformer.MaxViT.forward) at 0x7f4cad184c10>", line 78, in forward
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 471, in forward
    x = windowed_attn(x, rotary_emb = rotary_emb)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 136, in forward
    return self.fn(x, **kwargs) + x
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 319, in forward
    q = apply_rotary_pos_emb(rotary_emb, q)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 84, in apply_rotary_pos_emb
    return t * pos.cos() + rotate_half(t) * pos.sin()
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.15 GiB. GPU 0 has a total capacty of 23.64 GiB of which 1.12 GiB is free. Including non-PyTorch memory, this process has 22.52 GiB memory in use. Of the allocated memory 22.22 GiB is allocated by PyTorch, and 94.04 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

How much GPU memory is needed?

lucidrains commented on June 3, 2024

@ramkumarkoppu instantiate the QLearner as

q_learner = QLearner(
    model,
    dataset = ReplayMemoryDataset(),
    num_train_steps = 10000,
    learning_rate = 3e-4,
    batch_size = 1,
    grad_accum_every = 16,
)
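For context, `batch_size = 1` with `grad_accum_every = 16` keeps the effective batch at 16 while only one sample's activations are resident at a time, which is what brings peak GPU memory down. A generic PyTorch sketch of the pattern (the linear model and loop are illustrative, not the QLearner internals):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)               # illustrative stand-in model
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
data, target = torch.randn(16, 8), torch.randn(16, 1)

grad_accum_every = 16
w_before = model.weight.detach().clone()

opt.zero_grad()
for i in range(16):                          # micro-batches of size 1
    x, y = data[i:i + 1], target[i:i + 1]
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / grad_accum_every).backward()     # scale so gradients average
    if (i + 1) % grad_accum_every == 0:
        opt.step()                           # one update per 16 samples
        opt.zero_grad()

w_after = model.weight.detach().clone()
```

The gradient of the averaged micro-batch losses matches a full batch of 16, so the update is (up to batch-norm-style statistics) equivalent at a fraction of the memory.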

lucidrains commented on June 3, 2024

@ramkumarkoppu just going to only allow for training on 64 bit systems for now, sorry!

lucidrains commented on June 3, 2024

@ramkumarkoppu yea i see, ok i'll invest some time for a chunked memmap solution

as for the latter error, retry on the latest version
