Comments (15)
@ramkumarkoppu just curious Ram, are you interested in this repo for academia or for some company you work at?
This is for a personal research project.
from q-transformer.
@ramkumarkoppu could you try changing your Agent initialization to

agent = Agent(
    model,
    environment = env,
    num_episodes = 100,
    max_num_steps_per_episode = 10,
)
@lucidrains just tried locally and reproduced the error. Your fix works for me
Hi @lucidrains, thanks for your support, but I do have this info to share on this issue: my system is 64-bit, and after training completes, I get this:
training completed
Traceback (most recent call last):
File "/home/ram/github/q-transformer/usage.py", line 78, in <module>
actions = model.get_optimal_actions(video, instructions)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 1107, in get_optimal_actions
encoded_state = self.encode_state(*args, **kwargs)
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 1169, in encode_state
tokens = self.vit(
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<@beartype(q_transformer.q_robotic_transformer.MaxViT.forward) at 0x7f9f8e77faf0>", line 78, in forward
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 447, in forward
x = self.conv_stem(img)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/container.py", line 215, in forward
input = module(input)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 460, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
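This RuntimeError is generic PyTorch, not specific to q-transformer: the input tensor lives on the CPU while the model's weights live on the GPU. A minimal sketch of the usual fix (using a stand-in `nn.Linear` rather than the real model) is to move inputs to the model's parameter device before the forward pass:

```python
import torch
import torch.nn as nn

# Stand-in for the real model; the device-alignment pattern is the same.
model = nn.Linear(4, 2)
device = next(model.parameters()).device    # cpu here; cuda:0 once the model is on GPU

x = torch.randn(3, 4).to(device)            # align the input's device with the weights
y = model(x)
print(tuple(y.shape))                       # (3, 2)
```

In the usage script this would correspond to something like `video = video.to(device)` before calling `get_optimal_actions` (hypothetical placement; the fix that later landed in the repo may differ).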
@ramkumarkoppu hi Ram, and thanks for your interest in this repo
i couldn't reproduce your bug, but found another one when i ran the example (should be fixed in the latest version)
could you try the latest version and see if it incidentally resolves your issue?
@ramkumarkoppu do you have enough disk space for the memmap files? I had to decrease the demo hyperparams because my disk wasn't big enough to hold them. I think the error was similar to this.
Hi @lucidrains,
Tried the new version, but I still have the same problem:
pip show q-transformer
Name: q-transformer
Version: 0.1.9
Summary: Q-Transformer
Home-page: https://github.com/lucidrains/q-transformer
Author: Phil Wang
Author-email: [email protected]
License: MIT
Location: /home/ram/github/q-transformer
Editable project location: /home/ram/github/q-transformer
Requires: accelerate, beartype, classifier-free-guidance-pytorch, einops, ema-pytorch, numpy, torch, torchtyping
Required-by:
python example1.py
using memory efficient attention
/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return self.fget.__get__(instance, owner)()
Traceback (most recent call last):
File "/home/ram/github/q-transformer/example1.py", line 47, in <module>
agent = Agent(
File "<@beartype(q_transformer.agent.Agent.__init__) at 0x7f5dbc1953a0>", line 145, in __init__
File "/home/ram/github/q-transformer/q_transformer/agent.py", line 208, in __init__
self.states = open_memmap(str(states_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *state_shape))
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/numpy/lib/format.py", line 945, in open_memmap
marray = numpy.memmap(filename, dtype=dtype, shape=shape, order=order,
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/numpy/core/memmap.py", line 254, in __new__
fid.seek(bytes - 1, 0)
OSError: [Errno 22] Invalid argument
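A back-of-the-envelope check of the size the Agent's states memmap requests, assuming float32 and a `(num_episodes, max_steps, *state_shape)` layout (the real layout may differ slightly). With the README defaults it asks for tens of terabytes, so `fid.seek(bytes - 1)` can fail with EINVAL on filesystems that cap file size (e.g. 16 TiB on ext4), whereas the reduced settings suggested in this thread fit on an ordinary disk:

```python
import numpy as np

# Size of one states memmap for the default vs. the reduced hyperparameters.
state_shape = (3, 6, 224, 224)          # from the usage example

for episodes, steps in [(10_000, 1_000), (100, 10)]:
    n_bytes = episodes * steps * int(np.prod(state_shape)) * 4   # float32 = 4 bytes
    print(f"{episodes} episodes x {steps} steps -> {n_bytes / 1e9:,.1f} GB")
```

The defaults come to roughly 36 TB for this one file, while 100 episodes of 10 steps is about 3.6 GB.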
My Python environment is:
conda list
# packages in environment at /home/ram/anaconda3/envs/q-transformer:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
accelerate 0.26.1 pypi_0 pypi
anyio 4.2.0 py39h06a4308_0
argon2-cffi 23.1.0 pypi_0 pypi
argon2-cffi-bindings 21.2.0 py39h7f8727e_0
arrow 1.3.0 pypi_0 pypi
asttokens 2.4.1 pypi_0 pypi
async-lru 2.0.4 py39h06a4308_0
attrs 23.2.0 pypi_0 pypi
babel 2.14.0 pypi_0 pypi
backcall 0.2.0 pyhd3eb1b0_0
beartype 0.17.0 pypi_0 pypi
beautifulsoup4 4.12.3 pypi_0 pypi
bleach 6.1.0 pypi_0 pypi
brotli-python 1.0.9 py39h6a678d5_7
ca-certificates 2023.12.12 h06a4308_0
certifi 2023.11.17 py39h06a4308_0
cffi 1.16.0 py39h5eee18b_0
chardet 5.2.0 pypi_0 pypi
charset-normalizer 3.3.2 pypi_0 pypi
classifier-free-guidance-pytorch 0.5.2 pypi_0 pypi
comm 0.2.1 pypi_0 pypi
cryptography 41.0.7 py39hdda0065_0
cyrus-sasl 2.1.28 h52b45da_1
dbus 1.13.18 hb2f20db_0
debugpy 1.8.0 pypi_0 pypi
decorator 5.1.1 pyhd3eb1b0_0
defusedxml 0.7.1 pyhd3eb1b0_0
einops 0.7.0 pypi_0 pypi
ema-pytorch 0.3.2 pypi_0 pypi
exceptiongroup 1.2.0 py39h06a4308_0
executing 2.0.1 pypi_0 pypi
expat 2.5.0 h6a678d5_0
fastjsonschema 2.19.1 pypi_0 pypi
filelock 3.13.1 pypi_0 pypi
fontconfig 2.14.1 h4c34cd2_2
fqdn 1.5.1 pypi_0 pypi
freetype 2.12.1 h4a9f257_0
fsspec 2023.12.2 pypi_0 pypi
ftfy 6.1.3 pypi_0 pypi
glib 2.69.1 he621ea3_2
gst-plugins-base 1.14.1 h6a678d5_1
gstreamer 1.14.1 h5eee18b_1
huggingface-hub 0.20.3 pypi_0 pypi
icu 73.1 h6a678d5_0
idna 3.6 pypi_0 pypi
importlib-metadata 7.0.1 py39h06a4308_0
importlib_metadata 7.0.1 hd3eb1b0_0
ipykernel 6.29.0 pypi_0 pypi
ipython 8.18.1 pypi_0 pypi
ipywidgets 8.1.1 pypi_0 pypi
isoduration 20.11.0 pypi_0 pypi
jedi 0.19.1 pypi_0 pypi
jinja2 3.1.3 pypi_0 pypi
jpeg 9e h5eee18b_1
json5 0.9.14 pypi_0 pypi
jsonpointer 2.4 pypi_0 pypi
jsonschema 4.21.1 pypi_0 pypi
jsonschema-specifications 2023.12.1 pypi_0 pypi
jupyter 1.0.0 py39h06a4308_8
jupyter-core 5.7.1 pypi_0 pypi
jupyter-events 0.9.0 pypi_0 pypi
jupyter-lsp 2.2.2 pypi_0 pypi
jupyter-server 2.12.5 pypi_0 pypi
jupyter-server-terminals 0.5.2 pypi_0 pypi
jupyter_client 8.6.0 py39h06a4308_0
jupyter_console 6.6.3 py39h06a4308_0
jupyter_core 5.5.0 py39h06a4308_0
jupyter_events 0.8.0 py39h06a4308_0
jupyter_server 2.10.0 py39h06a4308_0
jupyter_server_terminals 0.4.4 py39h06a4308_1
jupyterlab 4.0.11 pypi_0 pypi
jupyterlab-pygments 0.3.0 pypi_0 pypi
jupyterlab-server 2.25.2 pypi_0 pypi
jupyterlab_pygments 0.1.2 py_0
jupyterlab_server 2.25.1 py39h06a4308_0
jupyterlab_widgets 3.0.9 py39h06a4308_0
krb5 1.20.1 h143b758_1
ld_impl_linux-64 2.38 h1181459_1
libclang 14.0.6 default_hc6dbbc7_1
libclang13 14.0.6 default_he11475f_1
libcups 2.4.2 h2d74bed_1
libedit 3.1.20230828 h5eee18b_0
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libllvm14 14.0.6 hdb19cb5_3
libpng 1.6.39 h5eee18b_0
libpq 12.17 hdbd6064_0
libsodium 1.0.18 h7b6447c_0
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
libxcb 1.15 h7f8727e_0
libxkbcommon 1.0.1 h5eee18b_1
libxml2 2.10.4 hf1b16e4_1
lz4-c 1.9.4 h6a678d5_0
markupsafe 2.1.4 pypi_0 pypi
matplotlib-inline 0.1.6 py39h06a4308_0
mistune 3.0.2 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
mysql 5.7.24 h721c034_2
nbclient 0.9.0 pypi_0 pypi
nbconvert 7.14.2 pypi_0 pypi
nbformat 5.9.2 py39h06a4308_0
ncurses 6.4 h6a678d5_0
nest-asyncio 1.6.0 pypi_0 pypi
networkx 3.2.1 pypi_0 pypi
notebook 7.0.7 pypi_0 pypi
notebook-shim 0.2.3 py39h06a4308_0
numpy 1.26.3 pypi_0 pypi
nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
nvidia-nccl-cu12 2.18.1 pypi_0 pypi
nvidia-nvjitlink-cu12 12.3.101 pypi_0 pypi
nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
open-clip-torch 2.24.0 pypi_0 pypi
openssl 3.0.12 h7f8727e_0
overrides 7.6.0 pypi_0 pypi
packaging 23.2 pypi_0 pypi
pandocfilters 1.5.1 pypi_0 pypi
parso 0.8.3 pyhd3eb1b0_0
pcre 8.45 h295c915_0
pexpect 4.9.0 pypi_0 pypi
pickleshare 0.7.5 pyhd3eb1b0_1003
pillow 10.2.0 pypi_0 pypi
pip 23.3.1 py39h06a4308_0
platformdirs 4.1.0 pypi_0 pypi
ply 3.11 py39h06a4308_0
prometheus-client 0.19.0 pypi_0 pypi
prometheus_client 0.14.1 py39h06a4308_0
prompt-toolkit 3.0.43 py39h06a4308_0
prompt_toolkit 3.0.43 hd3eb1b0_0
protobuf 4.25.2 pypi_0 pypi
psutil 5.9.8 pypi_0 pypi
ptyprocess 0.7.0 pyhd3eb1b0_2
pure_eval 0.2.2 pyhd3eb1b0_0
pycparser 2.21 pyhd3eb1b0_0
pygments 2.17.2 pypi_0 pypi
pyopenssl 23.2.0 py39h06a4308_0
pyqt 5.15.10 py39h6a678d5_0
pyqt5-sip 12.13.0 py39h5eee18b_0
pysocks 1.7.1 py39h06a4308_0
python 3.9.18 h955ad1f_0
python-dateutil 2.8.2 pyhd3eb1b0_0
python-fastjsonschema 2.16.2 py39h06a4308_0
python-json-logger 2.0.7 py39h06a4308_0
pytz 2023.3.post1 py39h06a4308_0
pyyaml 6.0.1 py39h5eee18b_0
pyzmq 25.1.2 py39h6a678d5_0
q-transformer 0.1.9 dev_0 <develop>
qt-main 5.15.2 h53bd1ea_10
qtconsole 5.5.1 pypi_0 pypi
qtpy 2.4.1 py39h06a4308_0
readline 8.2 h5eee18b_0
referencing 0.32.1 pypi_0 pypi
regex 2023.12.25 pypi_0 pypi
requests 2.31.0 py39h06a4308_0
rfc3339-validator 0.1.4 py39h06a4308_0
rfc3986-validator 0.1.1 py39h06a4308_0
rpds-py 0.17.1 pypi_0 pypi
safetensors 0.4.2 pypi_0 pypi
send2trash 1.8.2 py39h06a4308_0
sentencepiece 0.1.99 pypi_0 pypi
setuptools 68.2.2 py39h06a4308_0
sip 6.7.12 py39h6a678d5_0
six 1.16.0 pyhd3eb1b0_1
sniffio 1.3.0 py39h06a4308_0
soupsieve 2.5 py39h06a4308_0
sqlite 3.41.2 h5eee18b_0
stack-data 0.6.3 pypi_0 pypi
stack_data 0.2.0 pyhd3eb1b0_0
sympy 1.12 pypi_0 pypi
terminado 0.18.0 pypi_0 pypi
timm 0.9.12 pypi_0 pypi
tinycss2 1.2.1 py39h06a4308_0
tk 8.6.12 h1ccaba5_0
tokenizers 0.15.1 pypi_0 pypi
tomli 2.0.1 py39h06a4308_0
torch 2.1.2 pypi_0 pypi
torchtyping 0.1.4 pypi_0 pypi
torchvision 0.16.2 pypi_0 pypi
tornado 6.4 pypi_0 pypi
tqdm 4.66.1 pypi_0 pypi
traitlets 5.14.1 pypi_0 pypi
transformers 4.37.1 pypi_0 pypi
triton 2.1.0 pypi_0 pypi
typeguard 4.1.5 pypi_0 pypi
types-python-dateutil 2.8.19.20240106 pypi_0 pypi
typing-extensions 4.9.0 py39h06a4308_1
typing_extensions 4.9.0 py39h06a4308_1
tzdata 2023d h04d1e81_0
uri-template 1.3.0 pypi_0 pypi
urllib3 2.1.0 pypi_0 pypi
wcwidth 0.2.13 pypi_0 pypi
webcolors 1.13 pypi_0 pypi
webencodings 0.5.1 pypi_0 pypi
websocket-client 1.7.0 pypi_0 pypi
wheel 0.41.2 py39h06a4308_0
widgetsnbextension 4.0.9 pypi_0 pypi
xz 5.4.5 h5eee18b_0
yaml 0.2.5 h7b6447c_0
zeromq 4.3.5 h6a678d5_0
zipp 3.17.0 py39h06a4308_0
zlib 1.2.13 h5eee18b_0
zstd 1.5.5 hc292b87_0
python --version
Python 3.9.18
The code I am using is copied into example1.py from the Usage section of this repo:
import torch

from q_transformer import (
    QRoboticTransformer,
    QLearner,
    Agent,
    ReplayMemoryDataset
)

# the attention model

model = QRoboticTransformer(
    vit = dict(
        num_classes = 1000,
        dim_conv_stem = 64,
        dim = 64,
        dim_head = 64,
        depth = (2, 2, 5, 2),
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1
    ),
    num_actions = 8,
    action_bins = 256,
    depth = 1,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2,
    dueling = True
)

# you need to supply your own environment, by overriding BaseEnvironment

from q_transformer.mocks import MockEnvironment

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),
    text_embed_shape = (768,)
)

# env.init() should return instructions and initial state: Tuple[str, Tensor[*state_shape]]
# env(actions) should return rewards, next state, and done flag: Tuple[Tensor[()], Tensor[*state_shape], Tensor[()]]

# agent is a class that allows the q-model to interact with the environment
# to generate a replay memory dataset for learning

agent = Agent(
    model,
    environment = env,
    num_episodes = 10000,
    max_num_steps_per_episode = 1000,
)

agent()

# Q learning on the replay memory dataset on the model

q_learner = QLearner(
    model,
    dataset = ReplayMemoryDataset(),
    num_train_steps = 10000,
    learning_rate = 3e-4,
    batch_size = 32
)

q_learner()

# after much learning
# your robot should be better at selecting optimal actions

video = torch.randn(2, 3, 6, 224, 224)

instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

actions = model.get_optimal_actions(video, instructions)
@ramkumarkoppu do you have enough disk space for the memmap files? I had to decrease the demo hyperparams because my disk wasn't big enough to hold them. I think the error was similar to this.
It seems it has enough disk space:
df -h .
Filesystem Size Used Avail Use% Mounted on
/dev/nvme0n1p5 288G 163G 111G 60% /
@ramkumarkoppu could you try changing your Agent
initialization to
agent = Agent(
model,
environment = env,
num_episodes = 100,
max_num_steps_per_episode = 10,
)
@justindujardin thanks Justin, think you are right https://stackoverflow.com/a/54205996. i'll probably need to have some wrapper around memmap and auto manage the 2GB limitation
@justindujardin @ramkumarkoppu it works fine on my M1 MacBook, just as an aside
Running the modified code as per your suggestion on Ubuntu with an Nvidia GPU with 24 GB, after running 100 episodes, it throws this:
episode 99
90%|████████████████████████████████████████████████████████████████████▍ | 9/10 [00:09<00:01, 1.02s/it]
completed, memories stored to /home/ram/github/q-transformer/replay_memories_data
/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
Traceback (most recent call last):
File "/home/ram/github/q-transformer/example1.py", line 66, in <module>
q_learner()
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ram/github/q-transformer/q_transformer/q_learner.py", line 598, in forward
loss, (td_loss, conservative_reg_loss) = self.learn(
File "/home/ram/github/q-transformer/q_transformer/q_learner.py", line 525, in learn
td_loss, q_intermediates = self.autoregressive_q_learn_handle_single_timestep(*args, **q_learn_kwargs)
File "/home/ram/github/q-transformer/q_transformer/q_learner.py", line 403, in autoregressive_q_learn_handle_single_timestep
return self.autoregressive_q_learn(text_embeds, states, actions, next_states, rewards, dones, monte_carlo_return = monte_carlo_return)
File "/home/ram/github/q-transformer/q_transformer/q_learner.py", line 462, in autoregressive_q_learn
q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds, actions = actions)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py", line 138, in inner
return fn_maybe_with_text(self, *args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py", line 131, in fn_maybe_with_text
return fn(self, *args, **kwargs)
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 1210, in forward
encoded_state = self.encode_state(
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 1169, in encode_state
tokens = self.vit(
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<@beartype(q_transformer.q_robotic_transformer.MaxViT.forward) at 0x7f4cad184c10>", line 78, in forward
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 471, in forward
x = windowed_attn(x, rotary_emb = rotary_emb)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 136, in forward
return self.fn(x, **kwargs) + x
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 319, in forward
q = apply_rotary_pos_emb(rotary_emb, q)
File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/home/ram/github/q-transformer/q_transformer/q_robotic_transformer.py", line 84, in apply_rotary_pos_emb
return t * pos.cos() + rotate_half(t) * pos.sin()
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.15 GiB. GPU 0 has a total capacty of 23.64 GiB of which 1.12 GiB is free. Including non-PyTorch memory, this process has 22.52 GiB memory in use. Of the allocated memory 22.22 GiB is allocated by PyTorch, and 94.04 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
How much GPU memory is needed?
@ramkumarkoppu instantiate the QLearner as
q_learner = QLearner(
model,
dataset = ReplayMemoryDataset(),
num_train_steps = 10000,
learning_rate = 3e-4,
batch_size = 1,
grad_accum_every = 16,
)
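For context on why this helps, here is a generic sketch of gradient accumulation (not q-transformer's internals): losses from 16 micro-batches of size 1 are scaled and accumulated before a single optimizer step, so the update approximates a batch of 16 while peak activation memory is that of one sample.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)                 # stand-in for the q-transformer model
opt = torch.optim.SGD(model.parameters(), lr = 3e-4)
grad_accum_every = 16

opt.zero_grad()
for _ in range(grad_accum_every):
    x, y = torch.randn(1, 8), torch.randn(1, 1)    # micro-batch of size 1
    loss = nn.functional.mse_loss(model(x), y)
    (loss / grad_accum_every).backward()           # scale so the grads average
opt.step()                                         # one update per 16 micro-batches
```

The trade-off is wall-clock time: 16 forward/backward passes per optimizer step instead of one.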
@ramkumarkoppu just going to only allow for training on 64 bit systems for now, sorry!
@ramkumarkoppu yea i see, ok i'll invest some time for a chunked memmap solution
as for the latter error, retry on the latest version