
dreamerv3's Introduction

Mastering Diverse Domains through World Models

A reimplementation of DreamerV3, a scalable and general reinforcement learning algorithm that masters a wide range of applications with fixed hyperparameters.

[Figure: DreamerV3 Tasks]

If you find this code useful, please cite it in your paper:

@article{hafner2023dreamerv3,
  title={Mastering Diverse Domains through World Models},
  author={Hafner, Danijar and Pasukonis, Jurgis and Ba, Jimmy and Lillicrap, Timothy},
  journal={arXiv preprint arXiv:2301.04104},
  year={2023}
}


DreamerV3

DreamerV3 learns a world model from experience and uses it to train an actor-critic policy from imagined trajectories. The world model encodes sensory inputs into categorical representations and predicts future representations and rewards given actions.
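As a rough illustration of what "categorical representations" means, the sketch below turns a flat vector of logits into several one-hot samples, one per latent group. The group and class counts (4 groups of 3 classes) are arbitrary toy values, not DreamerV3's configuration, and the helper names are hypothetical. A real model also needs gradients to flow through the sample (DreamerV2 used straight-through gradients for this), which plain Python cannot show.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_categorical_latent(logits, groups, classes, rng):
    """Split a flat logit vector into `groups` chunks of `classes` logits,
    sample one class per chunk, and return the concatenated one-hot codes."""
    assert len(logits) == groups * classes
    latent = []
    for g in range(groups):
        probs = softmax(logits[g * classes:(g + 1) * classes])
        # Draw one class index from this group's distribution.
        idx = rng.choices(range(classes), weights=probs, k=1)[0]
        one_hot = [0.0] * classes
        one_hot[idx] = 1.0
        latent.extend(one_hot)
    return latent


rng = random.Random(0)
logits = [rng.gauss(0.0, 1.0) for _ in range(4 * 3)]
z = sample_categorical_latent(logits, groups=4, classes=3, rng=rng)
print(z)  # 12 entries, exactly one 1.0 in each block of 3
```

The result is a sparse, discrete code: each group commits to one of its classes, and the concatenation of all groups forms the latent state.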

[Figure: DreamerV3 Method Diagram]

DreamerV3 masters a wide range of domains with a fixed set of hyperparameters, outperforming specialized methods. Removing the need for tuning reduces the amount of expert knowledge and computational resources needed to apply reinforcement learning.

[Figure: DreamerV3 Benchmark Scores]

Due to its robustness, DreamerV3 shows favorable scaling properties. Notably, using larger models consistently increases not only its final performance but also its data efficiency. Increasing the number of gradient steps further increases data efficiency.

[Figure: DreamerV3 Scaling Behavior]

Instructions

Package

If you just want to run DreamerV3 on a custom environment, you can pip install dreamerv3 and copy example.py from this repository as a starting point.

Docker

If you want to make modifications to the code, you can either use the provided Dockerfile that contains instructions or follow the manual instructions below.

Manual

Install JAX and then the other dependencies:

pip install -r requirements.txt

Simple training script:

python example.py

Flexible training script:

python dreamerv3/train.py \
  --logdir ~/logdir/$(date "+%Y%m%d-%H%M%S") \
  --configs crafter --batch_size 16 --run.train_ratio 32
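To get a feel for --run.train_ratio, here is a back-of-the-envelope calculation. It assumes that the ratio counts replayed (trained-on) time steps per environment step, so that each gradient update consumes batch_size × batch_length time steps, and it assumes a batch_length of 64; check configs.yaml for the actual default. The helper name is hypothetical.

```python
def updates_per_env_step(train_ratio, batch_size, batch_length):
    """Gradient updates per environment step, assuming train_ratio counts
    replayed (trained-on) time steps per environment step."""
    batch_steps = batch_size * batch_length  # time steps consumed by one update
    return train_ratio / batch_steps


# Values from the command above; batch_length=64 is an assumption.
rate = updates_per_env_step(train_ratio=32, batch_size=16, batch_length=64)
print(rate)      # 0.03125 updates per environment step
print(1 / rate)  # 32.0, i.e. one update every 32 environment steps
```

Under these assumptions, doubling --run.train_ratio doubles the number of gradient updates for the same amount of environment interaction, while changing --batch_size alone changes how many steps each update sees but not how many replayed steps are trained on overall.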

Tips

  • All config options are listed in configs.yaml and you can override them from the command line.
  • The debug config block reduces the network size, batch size, duration between logs, and so on for fast debugging (but does not learn a good model).
  • By default, the code tries to run on GPU. You can switch to CPU or TPU with the --jax.platform flag, for example --jax.platform cpu. Note that multi-GPU support is untested.
  • You can run with multiple config blocks that will override defaults in the order they are specified, for example --configs crafter large.
  • By default, metrics are printed to the terminal, appended to a JSON lines file, and written as TensorBoard summaries. Other outputs like WandB can be enabled in the training script.
  • If you get a Too many leaves for PyTreeDef error, it means you're reloading a checkpoint that is not compatible with the current config. This often happens when reusing an old logdir by accident.
  • If you are getting CUDA errors, scroll up because the cause is often just an error that happened earlier, such as out of memory or incompatible JAX and CUDA versions.
  • You can use the small, medium, or large config blocks to reduce memory requirements. The default is xlarge. See the scaling graph above for how this affects performance.
  • Many environments are included, some of which require installing additional packages. See the installation scripts in scripts and the Dockerfile for reference.
  • When running on custom environments, make sure to specify the observation keys the agent should use via encoder.mlp_keys, encoder.cnn_keys, decoder.mlp_keys, and decoder.cnn_keys.
  • To log metrics from environments without showing them to the agent or storing them in the replay buffer, return them as observation keys with log_ prefix and enable logging via the run.log_keys_... options.
  • To continue stopped training runs, simply run the same command line again and make sure that the --logdir points to the same directory.
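For custom environments, the split between CNN and MLP inputs can be surprising: in the Unity issue below, a 16-dimensional vector named image ends up under "Encoder MLP shapes". The sketch below mimics that behavior under two assumptions: that the *_keys options are regular expressions matched against observation names, and that only rank-3 (height, width, channels) inputs go to the CNN while rank-1 inputs go to the MLP. The helper split_obs_keys and the patterns used are hypothetical, not the repository's code.

```python
import re


def split_obs_keys(shapes, cnn_keys, mlp_keys):
    """Hypothetical helper: route each observation to the CNN or MLP encoder.

    Assumes cnn_keys / mlp_keys are regex patterns matched against the key
    name, that CNN inputs must be rank 3 (H, W, C) and MLP inputs rank 1.
    Scalar entries such as reward or is_first match neither rank and are
    therefore never encoded.
    """
    cnn, mlp = {}, {}
    for key, shape in shapes.items():
        # log_ keys are meant for logging only (see the tips above).
        if key.startswith("log_"):
            continue
        if re.match(cnn_keys, key) and len(shape) == 3:
            cnn[key] = shape
        elif re.match(mlp_keys, key) and len(shape) == 1:
            mlp[key] = shape
    return cnn, mlp


shapes = {"image": (16,), "reward": (), "is_first": (), "log_score": (1,)}
cnn, mlp = split_obs_keys(shapes, cnn_keys=r"image", mlp_keys=r".*")
print(cnn)  # {}
print(mlp)  # {'image': (16,)}
```

The takeaway, under these assumptions, is that the key name alone does not decide the encoder: a vector observation named image still goes to the MLP, so naming observations after their actual content makes the key patterns easier to reason about.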

Disclaimer

This repository contains a reimplementation of DreamerV3 based on the open source DreamerV2 code base. It is unrelated to Google or DeepMind. The implementation has been tested to reproduce the official results on a range of environments.

dreamerv3's People

Contributors

anatlavitzkovitz, artemzholus, danijar, eltociear, hibikaze-git, isaacsst, signalprime, vint-1


dreamerv3's Issues

Request for access to the Crafter model

Hello! I'm interested in accessing the Crafter model for my master thesis. I've been searching online through this repository for the baseline but haven't been able to find it. Would it be possible for me to get access to it?

Thank you for considering my request, and I hope that you can help me.

MARL Support

I know the authors test multi-agent support on Hanabi. Are there future plans to implement similar MARL support in this repo for custom environments? Thank you for this great work!

[Question] Regarding training the reward head

Hello,

I had a quick question regarding the reward head that I was hoping you could help clarify. In the diagrams in the paper, you seem to predict the reward of the current transition from the successive state, i.e. for the transition (s_1, a_1, r_1, s_2) you would predict r_1 using the posterior state encoded from s_2. This makes sense to me, since the action and the hidden state need to be processed to get the reward. In the code, however, the reward seems to be shifted back by 1: you process the sequence of observations into a sequence of posteriors and then train the reward head on the reward sequence using these posteriors, where the reward sequence starts from r_1 but the first posterior in the sequence is z_1. So I was wondering if you pad the reward sequence at the beginning or if I am misunderstanding how the world model works. I've attached an image to illustrate my confusion:

[Image: dreamer]

In my image, the _init variables are just the initial masked variables and/or the learned starting state. I assume the prev_action in this case is just zeros.

Thank you.
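To make the "pad at the beginning" reading from the question concrete, the sketch below converts (s_t, a_t, r_t, s_{t+1}) transitions into step-aligned sequences in which each reward is stored alongside the observation it leads to, with a padded zero reward and an empty previous action at the first step. This only illustrates the alignment the question describes; it is not a statement of what the repository does, and the function and field names are hypothetical.

```python
def to_step_sequence(first_obs, transitions):
    """Hypothetical conversion of transitions into step-aligned sequences.

    transitions: list of (action, reward, next_obs) tuples following first_obs.
    Step t holds obs_t, the action that led to it, and the reward received on
    arriving at obs_t. Step 0 has no preceding transition, so its action and
    reward are padded (None and 0.0 here).
    """
    obs, prev_action, reward = [first_obs], [None], [0.0]
    for action, r, next_obs in transitions:
        obs.append(next_obs)
        prev_action.append(action)
        reward.append(r)
    return obs, prev_action, reward


obs, prev_action, reward = to_step_sequence(
    "s1", [("a1", 1.0, "s2"), ("a2", 0.5, "s3")])
print(list(zip(obs, prev_action, reward)))
# [('s1', None, 0.0), ('s2', 'a1', 1.0), ('s3', 'a2', 0.5)]
```

In this layout r_1 shares an index with s_2 and a_1, so a reward head reading the posterior at that index would predict r_1 from a state computed after taking a_1, which matches the paper-diagram reading in the question.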

Plotting script

What indirs path does the plotting script expect? I am passing in "~/logsdirs/run_1/" but am getting an error that there is nothing to plot.

QUESTION: fast and slow critic vs. weight EMA.

The paper states "We compute λ-returns using the fast critic network and regularize the critic outputs towards those of its own weight EMA instead of computing returns using the slow critic. However, both approaches perform similarly in practice."

I am confused about the definition of fast and slow critics. Is the slow critic similar to a target network? Also, what was used in this repo? The regularizer term in the critic loss seems to involve a slow critic. Can you please explain?
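The "weight EMA" in the quoted sentence can be illustrated with plain numbers. The sketch below keeps an exponential moving average of a parameter vector (the "slow" critic) and adds a regularizer that pulls the fast critic's output toward the slow critic's output, with the λ-return passed in as computed from the fast critic. The toy linear critic, the decay of 0.5 (chosen for readable numbers; decays much closer to 1 are typical), and the squared-error losses are assumptions standing in for whatever the repository actually uses.

```python
def ema_update(slow, fast, decay):
    """Move each slow (EMA) weight a small step toward the fast weight."""
    return [decay * s + (1.0 - decay) * f for s, f in zip(slow, fast)]


def critic(weights, features):
    """Toy linear critic: value = dot(weights, features)."""
    return sum(w * x for w, x in zip(weights, features))


def critic_loss(fast, slow, features, lambda_return, reg_scale=1.0):
    """Regress the fast critic toward the lambda-return and regularize it
    toward the slow critic's prediction (squared error as a stand-in)."""
    value = critic(fast, features)
    slow_value = critic(slow, features)  # treated as a constant target
    return (value - lambda_return) ** 2 + reg_scale * (value - slow_value) ** 2


fast = [1.0, 2.0]
slow = [0.0, 0.0]
for _ in range(3):
    slow = ema_update(slow, fast, decay=0.5)
print(slow)  # [0.875, 1.75]
print(critic_loss(fast, slow, features=[1.0, 1.0], lambda_return=3.0))  # 0.140625
```

In this framing the slow critic plays a role similar to a target network, but it only enters through the regularizer; the returns themselves come from the fast critic.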

How to use batch env

Hello.
Thank you for providing us with a great algorithm.
I have a question about batch env.
I have specified two envs in BatchEnv because the game agent I am currently training progresses slowly.
However, when I checked the progress in TensorBoard, nothing had changed: training was not twice as fast.
Is this the correct way to use BatchEnv? Also, what does the parallel parameter represent?

Thank you in advance.
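Whether batching several environments speeds things up depends on whether they are stepped one after another or concurrently. The generic sketch below (plain Python, not the embodied.BatchEnv implementation) steps a batch of slow toy environments sequentially and then through a thread pool. ToySlowEnv and its sleep delay are hypothetical stand-ins for an expensive simulator; time.sleep releases the GIL, which is why threads help here, whereas CPU-bound simulators typically need separate processes for a real speedup.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class ToySlowEnv:
    """Hypothetical environment whose step() blocks for `delay` seconds."""

    def __init__(self, delay):
        self.delay = delay
        self.t = 0

    def step(self, action):
        time.sleep(self.delay)  # stands in for an expensive simulator step
        self.t += 1
        return {"obs": self.t, "reward": float(action)}


def step_sequential(envs, actions):
    # Each env waits for the previous one: total time grows with batch size.
    return [env.step(a) for env, a in zip(envs, actions)]


def step_threaded(pool, envs, actions):
    futures = [pool.submit(env.step, a) for env, a in zip(envs, actions)]
    return [f.result() for f in futures]  # results keep the env order


envs = [ToySlowEnv(0.05) for _ in range(4)]
start = time.perf_counter()
seq = step_sequential(envs, [1, 2, 3, 4])
seq_time = time.perf_counter() - start

with ThreadPoolExecutor(max_workers=len(envs)) as pool:
    start = time.perf_counter()
    par = step_threaded(pool, envs, [1, 2, 3, 4])
    par_time = time.perf_counter() - start

print(f"sequential {seq_time:.2f}s, threaded {par_time:.2f}s")
```

With sequential stepping, adding a second environment doubles the wall-clock time per batch step, so the number of environment steps per second stays roughly the same.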

Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: interpreter,cpu

Steps to replicate:
1.0 docker run --gpus all -it --rm --shm-size=1g --ulimit memlock=-1 nvcr.io/nvidia/tensorflow:23.01-tf2-py3
1.1 git clone https://github.com/danijar/dreamerv3.git
1.2 cd dreamerv3 && vim requirements.txt
1.3.1 Comment out tensorflow-cpu #tensorflow-cpu
1.4 pip install -r requirements.txt
1.5 python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
2023-02-25 19:15:14.739235: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-25 19:15:16.102908: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:15:16.103217: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:15:16.130538: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:15:16.130794: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:15:16.130965: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:15:16.131128: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]

1.6 TF_CPP_MIN_LOG_LEVEL=0 python3 example.py
2023-02-25 19:17:59.624066: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-25 19:18:01.559920: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:18:01.560153: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:18:01.581002: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:18:01.581257: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:18:01.581430: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-25 19:18:01.581594: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
2023-02-25 19:18:01.584024: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x39d03b0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-02-25 19:18:01.584040: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177] StreamExecutor device (0): Interpreter,
2023-02-25 19:18:01.602018: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:215] TfrtCpuClient created.
2023-02-25 19:18:01.602537: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:266] Libtpu path is: libtpu.so
2023-02-25 19:18:01.602697: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.cc:73] No TPU platform found.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /workspace/dreamerv3/example.py:53 in <module> │
│ │
│ 50 │
│ 51 │
│ 52 if __name__ == '__main__': │
│ ❱ 53 main() │
│ 54 │
│ │
│ /workspace/dreamerv3/example.py:42 in main │
│ │
│ 39 env = dreamerv3.wrap_env(env, config) │
│ 40 env = embodied.BatchEnv([env], parallel=False) │
│ 41 │
│ ❱ 42 agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config) │
│ 43 replay = embodied.replay.Uniform( │
│ 44 │ config.batch_length, config.replay_size, logdir / 'replay') │
│ 45 args = embodied.Config( │
│ │
│ /workspace/dreamerv3/dreamerv3/jaxagent.py:20 in __init__
│ │
│ 17 │ configs = agent_cls.configs │
│ 18 │ inner = agent_cls │
│ 19 │ def __init__(self, *args, **kwargs): │
│ ❱ 20 │ super().__init__(agent_cls, *args, **kwargs) │
│ 21 return Agent │
│ 22 │
│ 23 │
│ │
│ /workspace/dreamerv3/dreamerv3/jaxagent.py:35 in __init__
│ │
│ 32 │ self.agent = agent_cls(obs_space, act_space, step, config, name='agent') │
│ 33 │ self.rng = np.random.default_rng(config.seed) │
│ 34 │ │
│ ❱ 35 │ available = jax.devices(self.config.platform) │
│ 36 │ self.policy_devices = [available[i] for i in self.config.policy_devices] │
│ 37 │ self.train_devices = [available[i] for i in self.config.train_devices] │
│ 38 │ self.single_device = (self.policy_devices == self.train_devices) and ( │
│ │
│ /usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:551 in devices │
│ │
│ 548 Returns: │
│ 549 │ List of Device subclasses. │
│ 550 """ │
│ ❱ 551 return get_backend(backend).devices() │
│ 552 │
│ 553 │
│ 554 def default_backend() -> str: │
│ │
│ /usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:493 in get_backend │
│ │
│ 490 │
│ 491 @lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence. │
│ 492 def get_backend(platform=None): │
│ ❱ 493 return _get_backend_uncached(platform) │
│ 494 │
│ 495 │
│ 496 def get_device_backend(device=None): │
│ │
│ /usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:479 in _get_backend_uncached │
│ │
│ 476 │
│ 477 bs = backends() │
│ 478 if platform is not None: │
│ ❱ 479 │ platform = canonicalize_platform(platform) │
│ 480 │ backend = bs.get(platform, None) │
│ 481 │ if backend is None: │
│ 482 │ if platform in _backends_errors: │
│ │
│ /usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:359 in canonicalize_platform │
│ │
│ 356 for p in platforms: │
│ 357 │ if p in b.keys(): │
│ 358 │ return p │
│ ❱ 359 raise RuntimeError(f"Unknown backend: '{platform}' requested, but no " │
│ 360 │ │ │ │ │ f"platforms that are instances of {platform} are present. " │
│ 361 │ │ │ │ │ "Platforms present are: " + ",".join(b.keys())) │
│ 362 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: interpreter,cpu
2023-02-25 19:18:02.025428: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient destroyed.

Not sure if it's tied to google/jax#9859?
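A quick way to check which platforms the installed JAX build actually sees, before choosing a --jax.platform value, is sketched below. It relies only on jax.devices(platform), which raises a RuntimeError for unavailable backends, as in the traceback above; the helper name pick_platform is hypothetical. If it returns "cpu" on a machine where TensorFlow lists GPUs, the installed jaxlib is likely a CPU-only build.

```python
import importlib.util


def pick_platform(preferred=("gpu", "tpu", "cpu")):
    """Return the first JAX platform that has devices, or None if JAX is
    not installed or no listed platform is available."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    for platform in preferred:
        try:
            if jax.devices(platform):
                return platform
        except (RuntimeError, ValueError):
            # Unknown or uninitialized backend, as in the error above.
            continue
    return None


print(pick_platform())
```

Usage: run this inside the same container or environment as training; the printed value is what the tips above suggest passing via --jax.platform.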

Performance on simple gym environments

Hey Danijar,

I am a big fan of the Dreamer algorithms, especially this paper, but I was wondering if you have any insight into why the world models seemingly can't be learnt for simple Gym environments in my testing. I've been running this code base on CartPole and Lunar Lander (pixel observations), and whilst background information is learnt, the movement of the cart and spaceship is never included in the world model's predictions. Additionally, the policy never learns to perform well. I am using the small model size. Is it possible that these environments' observations are too simple, or that I am not training long enough (2+ hours, but on a CPU-only machine)? Please let me know if you have any thoughts as to why this could be the case. Thanks so much.

Singularity and Atari

I converted the Docker container to a Singularity container to run it on a cluster. Most environments work. However, when I run Atari Breakout experiments, I get a libc6 version error, and the libc6 version that is required is incompatible with the Ubuntu base image of the Dockerfile. Do you have any experience running these experiments on a SLURM cluster with Singularity?

Shape error with custom Unity env

Hello,
first of all I would also like to thank you for publicly sharing your research's code.

I am currently trying to run DreamerV3 on my custom environment, which was built using Unity3D's ML-Agents and wrapped as a Gym environment. After having some issues with the shape of my action and observation spaces, which I think are fixed now, I am still having issues with the dimensions of checkpoints. The issue occurs right at the beginning of training, when the agent prefills its train dataset and the first checkpoint is saved.
The full output is below:

python3 example.py 
[UnityMemory] Configuration Parameters - Can be set up in boot.config
  "memorysetup-bucket-allocator-granularity=16"
  "memorysetup-bucket-allocator-bucket-count=8"
  "memorysetup-bucket-allocator-block-size=4194304"
  "memorysetup-bucket-allocator-block-count=1"
  "memorysetup-main-allocator-block-size=16777216"
  "memorysetup-thread-allocator-block-size=16777216"
  "memorysetup-gfx-main-allocator-block-size=16777216"
  "memorysetup-gfx-thread-allocator-block-size=16777216"
  "memorysetup-cache-allocator-block-size=4194304"
  "memorysetup-typetree-allocator-block-size=2097152"
  "memorysetup-profiler-bucket-allocator-granularity=16"
  "memorysetup-profiler-bucket-allocator-bucket-count=8"
  "memorysetup-profiler-bucket-allocator-block-size=4194304"
  "memorysetup-profiler-bucket-allocator-block-count=1"
  "memorysetup-profiler-allocator-block-size=16777216"
  "memorysetup-profiler-editor-allocator-block-size=1048576"
  "memorysetup-temp-allocator-size-main=4194304"
  "memorysetup-job-temp-allocator-block-size=2097152"
  "memorysetup-job-temp-allocator-block-size-background=1048576"
  "memorysetup-job-temp-allocator-reduction-small-platforms=262144"
  "memorysetup-temp-allocator-size-background-worker=32768"
  "memorysetup-temp-allocator-size-job-worker=262144"
  "memorysetup-temp-allocator-size-preload-manager=262144"
  "memorysetup-temp-allocator-size-nav-mesh-worker=65536"
  "memorysetup-temp-allocator-size-audio-worker=65536"
  "memorysetup-temp-allocator-size-cloud-worker=32768"
  "memorysetup-temp-allocator-size-gfx=262144"
[WARNING] The environment contains multiple observations. You must define allow_multiple_obs=True to receive them all. Otherwise, only the first visual observation (or vector observation ifthere are no visual observations) will be provided in the observation.
/home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/gym/spaces/box.py:73: UserWarning: WARN: Box bound precision lowered by casting to float32
logger.warn(
Encoder CNN shapes: {}
Encoder MLP shapes: {'image': (16,)}
Decoder CNN shapes: {}
Decoder MLP shapes: {'image': (16,)}
JAX devices (1): [CpuDevice(id=0)]
Policy devices: TFRT_CPU_0
Train devices:  TFRT_CPU_0
Tracing train function.
Optimizer model_opt has 16,451,344 variables.
Optimizer actor_opt has 1,052,676 variables.
Optimizer critic_opt has 1,181,439 variables.
Logdir logdir/run1
Observation space:
image            Space(dtype=float32, shape=(16,), low=-inf, high=inf)
reward           Space(dtype=float32, shape=(), low=-inf, high=inf)
is_first         Space(dtype=bool, shape=(), low=False, high=True)
is_last          Space(dtype=bool, shape=(), low=False, high=True)
is_terminal      Space(dtype=bool, shape=(), low=False, high=True)
Action space:
action           Space(dtype=float32, shape=(2,), low=-1.0, high=1.0)
reset            Space(dtype=bool, shape=(), low=False, high=True)
Prefill train dataset.
Episode has 61 steps and return 0.1.
Episode has 55 steps and return 0.1.
Episode has 73 steps and return 0.1.
Episode has 55 steps and return 0.1.
Episode has 74 steps and return 0.1.
Episode has 96 steps and return 0.4.
Episode has 56 steps and return 0.1.
Episode has 50 steps and return 0.1.
Episode has 50 steps and return 0.0.
Episode has 45 steps and return 0.0.
Episode has 62 steps and return 0.1.
Episode has 76 steps and return 0.2.
Episode has 53 steps and return 0.1.
Episode has 57 steps and return 0.1.
Episode has 84 steps and return 0.2.
Saved chunk: 20230224T123638F338048-7fz2YQGpaWRhCvMc8sKBIS-4NtbeiuY5nlHsbS33AqTMd-1024.npz
Episode has 69 steps and return 0.1.
──────────────────────────────────────────────────────────────────────────────────────────────────── Step 1100 ────────────────────────────────────────────────────────────────────────────────────────────────────
episode/length 69 / episode/score 0.13 / episode/sum_abs_reward 0.13 / episode/reward_rate 0

Creating new TensorBoard event file writer.
Did not find any checkpoint.
Writing checkpoint: logdir/run1/checkpoint.ckpt
Start training loop.
Saved chunk: 20230224T123818F856929-4NtbeiuY5nlHsbS33AqTMd-0000000000000000000000-76.npz
Wrote checkpoint: logdir/run1/checkpoint.ckpt
Error writing summary: stats/policy_image
Tracing policy function.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/fabian/Desktop/dreamerv3/example.py:53 in <module>                                         │
│                                                                                                  │
│   50                                                                                             │
│   51                                                                                             │
│   52 if __name__ == '__main__':                                                                  │
│ ❱ 53   main()                                                                                    │
│   54                                                                                             │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/example.py:48 in main                                             │
│                                                                                                  │
│   45   args = embodied.Config(                                                                   │
│   46 │     **config.run, logdir=config.logdir,                                                   │
│   47 │     batch_steps=config.batch_size * config.batch_length)                                  │
│ ❱ 48   embodied.run.train(agent, env, replay, logger, args)                                      │
│   49   # embodied.run.eval_only(agent, env, logger, args)                                        │
│   50                                                                                             │
│   51                                                                                             │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/embodied/run/train.py:108 in train                      │
│                                                                                                  │
│   105   policy = lambda *args: agent.policy(                                                     │
│   106 │     *args, mode='explore' if should_expl(step) else 'train')                             │
│   107   while step < args.steps:                                                                 │
│ ❱ 108 │   driver(policy, steps=100)                                                              │
│   109 │   if should_save(step):                                                                  │
│   110 │     checkpoint.save()                                                                    │
│   111   logger.write()                                                                           │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/embodied/core/driver.py:42 in __call__                  │
│                                                                                                  │
│   39   def __call__(self, policy, steps=0, episodes=0):                                          │
│   40 │   step, episode = 0, 0                                                                    │
│   41 │   while step < steps or episode < episodes:                                               │
│ ❱ 42 │     step, episode = self._step(policy, step, episode)                                     │
│   43                                                                                             │
│   44   def _step(self, policy, step, episode):                                                   │
│   45 │   assert all(len(x) == len(self._env) for x in self._acts.values())                       │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/embodied/core/driver.py:50 in _step                     │
│                                                                                                  │
│   47 │   obs = self._env.step(acts)                                                              │
│   48 │   obs = {k: convert(v) for k, v in obs.items()}                                           │
│   49 │   assert all(len(x) == len(self._env) for x in obs.values()), obs                         │
│ ❱ 50 │   acts, self._state = policy(obs, self._state, **self._kwargs)                            │
│   51 │   acts = {k: convert(v) for k, v in acts.items()}                                         │
│   52 │   if obs['is_last'].any():                                                                │
│   53 │     mask = 1 - obs['is_last']                                                             │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/embodied/run/train.py:105 in <lambda>                   │
│                                                                                                  │
│   102   should_save(step)  # Register that we jused saved.                                       │
│   103                                                                                            │
│   104   print('Start training loop.')                                                            │
│ ❱ 105   policy = lambda *args: agent.policy(                                                     │
│   106 │     *args, mode='explore' if should_expl(step) else 'train')                             │
│   107   while step < args.steps:                                                                 │
│   108 │   driver(policy, steps=100)                                                              │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/jaxagent.py:62 in policy                                │
│                                                                                                  │
│    59 │     state = tree_map(                                                                    │
│    60 │   │     np.asarray, state, is_leaf=lambda x: isinstance(x, list))                        │
│    61 │     state = self._convert_inps(state, self.policy_devices)                               │
│ ❱  62 │   (outs, state), _ = self._policy(varibs, rng, obs, state, mode=mode)                    │
│    63 │   outs = self._convert_outs(outs, self.policy_devices)                                   │
│    64 │   # TODO: Consider keeping policy states in accelerator memory.                          │
│    65 │   state = self._convert_outs(state, self.policy_devices)                                 │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:199 in wrapper                                │
│                                                                                                  │
│   196 │   statics = tuple(sorted([(k, v) for k, v in kw.items() if k in static]))                │
│   197 │   kw = {k: v for k, v in kw.items() if k not in static}                                  │
│   198 │   if not hasattr(wrapper, 'keys'):                                                       │
│ ❱ 199 │     created = init(statics, rng, *args, **kw)                                            │
│   200 │     wrapper.keys = set(created.keys())                                                   │
│   201 │     for key, value in created.items():                                                   │
│   202 │   │   if key not in state:                                                               │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/traceback_util.py:16 │
│ 3 in reraise_with_filtered_traceback                                                             │
│                                                                                                  │
│   160   def reraise_with_filtered_traceback(*args, **kwargs):                                    │
│   161 │   __tracebackhide__ = True                                                               │
│   162 │   try:                                                                                   │
│ ❱ 163 │     return fun(*args, **kwargs)                                                          │
│   164 │   except Exception as e:                                                                 │
│   165 │     mode = filtering_mode()                                                              │
│   166 │     if is_under_reraiser(e) or mode == "off":                                            │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:237 in       │
│ cache_miss                                                                                       │
│                                                                                                  │
│    234                                                                                           │
│    235   @api_boundary                                                                           │
│    236   def cache_miss(*args, **kwargs):                                                        │
│ ❱  237 │   outs, out_flat, out_tree, args_flat = _python_pjit_helper(                            │
│    238 │   │   fun, infer_params_fn, *args, **kwargs)                                            │
│    239 │                                                                                         │
│    240 │   executable = _read_most_recent_pjit_call_executable()                                 │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:180 in       │
│ _python_pjit_helper                                                                              │
│                                                                                                  │
│    177                                                                                           │
│    178                                                                                           │
│    179 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):                           │
│ ❱  180   args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(                           │
│    181 │     *args, **kwargs)                                                                    │
│    182   for arg in args_flat:                                                                   │
│    183 │   dispatch.check_arg(arg)                                                               │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/api.py:443 in        │
│ infer_params                                                                                     │
│                                                                                                  │
│    440 │   │     static_argnames=static_argnames, donate_argnums=donate_argnums,                 │
│    441 │   │     device=device, backend=backend, keep_unused=keep_unused,                        │
│    442 │   │     inline=inline, resource_env=None)                                               │
│ ❱  443 │     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)                    │
│    444 │                                                                                         │
│    445 │   has_explicit_sharding = pjit._pjit_explicit_sharding(                                 │
│    446 │   │   in_shardings, out_shardings, device, backend)                                     │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:520 in       │
│ common_infer_params                                                                              │
│                                                                                                  │
│    517 │     hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,    │
│    518 │     tuple(isinstance(a, GDA) for a in args_flat), resource_env)                         │
│    519                                                                                           │
│ ❱  520   jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(                          │
│    521 │     flat_fun, hashable_pytree(out_shardings), global_in_avals,                          │
│    522 │     HashableFunction(out_tree, closure=()),                                             │
│    523 │     ('jit' if resource_env is None else 'pjit'))                                        │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/linear_util.py:301   │
│ in memoized_fun                                                                                  │
│                                                                                                  │
│   298 │     ans, stores = result                                                                 │
│   299 │     fun.populate_stores(stores)                                                          │
│   300 │   else:                                                                                  │
│ ❱ 301 │     ans = call(fun, *args)                                                               │
│   302 │     cache[key] = (ans, fun.stores)                                                       │
│   303 │                                                                                          │
│   304 │   return ans                                                                             │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:932 in       │
│ _pjit_jaxpr                                                                                      │
│                                                                                                  │
│    929 │   with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "     │
│    930 │   │   │   │   │   │   │   │      "for pjit in {elapsed_time} sec",                      │
│    931 │   │   │   │   │   │   │   │   │   event=dispatch.JAXPR_TRACE_EVENT):                    │
│ ❱  932 │     jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(                        │
│    933 │   │     fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))            │
│    934   finally:                                                                                │
│    935 │   pxla.positional_semantics.val = prev_positional_val                                   │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/profiler.py:314 in   │
│ wrapper                                                                                          │
│                                                                                                  │
│   311   @wraps(func)                                                                             │
│   312   def wrapper(*args, **kwargs):                                                            │
│   313 │   with TraceAnnotation(name, **decorator_kwargs):                                        │
│ ❱ 314 │     return func(*args, **kwargs)                                                         │
│   315 │   return wrapper                                                                         │
│   316   return wrapper                                                                           │
│   317                                                                                            │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:1985 in trace_to_jaxpr_dynamic                                                               │
│                                                                                                  │
│   1982 │   │   │   │   │   │      keep_inputs: Optional[List[bool]] = None):                     │
│   1983   with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore            │
│   1984 │   main.jaxpr_stack = ()  # type: ignore                                                 │
│ ❱ 1985 │   jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(                                 │
│   1986 │     fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)                │
│   1987 │   del main, fun                                                                         │
│   1988   return jaxpr, out_avals, consts                                                         │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:2002 in trace_to_subjaxpr_dynamic                                                            │
│                                                                                                  │
│   1999 │   trace = DynamicJaxprTrace(main, core.cur_sublevel())                                  │
│   2000 │   in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)                          │
│   2001 │   in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]                 │
│ ❱ 2002 │   ans = fun.call_wrapped(*in_tracers_)                                                  │
│   2003 │   out_tracers = map(trace.full_raise, ans)                                              │
│   2004 │   jaxpr, consts = frame.to_jaxpr(out_tracers)                                           │
│   2005 │   del fun, main, trace, frame, in_tracers, out_tracers, ans                             │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/linear_util.py:165   │
│ in call_wrapped                                                                                  │
│                                                                                                  │
│   162 │   gen = gen_static_args = out_store = None                                               │
│   163 │                                                                                          │
│   164 │   try:                                                                                   │
│ ❱ 165 │     ans = self.f(*args, **dict(self.params, **kwargs))                                   │
│   166 │   except:                                                                                │
│   167 │     # Some transformations yield from inside context managers, so we have to             │
│   168 │     # interrupt them before reraising the exception. Otherwise they will only            │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:184 in init                                   │
│                                                                                                  │
│   181   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│   182   def init(statics, rng, *args, **kw):                                                     │
│   183 │   # Return only state so JIT can remove dead code for fast initialization.               │
│ ❱ 184 │   s = fun({}, rng, *args, ignore=True, **dict(statics), **kw)[1]                         │
│   185 │   return s                                                                               │
│   186                                                                                            │
│   187   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:95 in purified                                │
│                                                                                                  │
│    92 │   │   rng = jax.random.PRNGKey(rng)                                                      │
│    93 │     context = Context(state.copy(), rng, create, modify, ignore, [], name)               │
│    94 │     CONTEXT[threading.get_ident()] = context                                             │
│ ❱  95 │     out = fun(*args, **kwargs)                                                           │
│    96 │     state = dict(context)                                                                │
│    97 │     return out, state                                                                    │
│    98 │   finally:                                                                               │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                                │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/agent.py:56 in policy                                   │
│                                                                                                  │
│    53 │   obs = self.preprocess(obs)                                                             │
│    54 │   (prev_latent, prev_action), task_state, expl_state = state                             │
│    55 │   embed = self.wm.encoder(obs)                                                           │
│ ❱  56 │   latent, _ = self.wm.rssm.obs_step(                                                     │
│    57 │   │   prev_latent, prev_action, embed, obs['is_first'])                                  │
│    58 │   self.expl_behavior.policy(latent, expl_state)                                          │
│    59 │   task_outs, task_state = self.task_behavior.policy(latent, task_state)                  │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                                │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/nets.py:105 in obs_step                                 │
│                                                                                                  │
│   102 │   #   # prior['deter'] has shape (1, 512) but embed has shape (1, 1, 1024). Need to sq   │
│   103 │   #   embed = jnp.squeeze(embed, axis=1)                                                 │
│   104 │   # print('aft: prior deter', prior['deter'], embed, prior['deter'].ndim == embed.ndim   │
│ ❱ 105 │   x = jnp.concatenate([prior['deter'], embed], -1)                                       │
│   106 │   x = self.get('obs_out', Linear, **self._kw)(x)                                         │
│   107 │   stats = self._stats('obs_stats', x)                                                    │
│   108 │   dist = self.get_dist(stats)                                                            │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:1 │
│ 845 in concatenate                                                                               │
│                                                                                                  │
│   1842   # (https://github.com/google/jax/issues/653).                                           │
│   1843   k = 16                                                                                  │
│   1844   while len(arrays_out) > 1:                                                              │
│ ❱ 1845 │   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)                                │
│   1846 │   │   │   │     for i in range(0, len(arrays_out), k)]                                  │
│   1847   return arrays_out[0]                                                                    │
│   1848                                                                                           │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:1 │
│ 845 in <listcomp>                                                                                │
│                                                                                                  │
│   1842   # (https://github.com/google/jax/issues/653).                                           │
│   1843   k = 16                                                                                  │
│   1844   while len(arrays_out) > 1:                                                              │
│ ❱ 1845 │   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)                                │
│   1846 │   │   │   │     for i in range(0, len(arrays_out), k)]                                  │
│   1847   return arrays_out[0]                                                                    │
│   1848                                                                                           │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lax/lax.py:644 in    │
│ concatenate                                                                                      │
│                                                                                                  │
│    641 │   op, = operands                                                                        │
│    642 │   if isinstance(op, Array):                                                             │
│    643 │     return type_cast(Array, op)                                                         │
│ ❱  644   return concatenate_p.bind(*operands, dimension=dimension)                               │
│    645                                                                                           │
│    646                                                                                           │
│    647 class _enum_descriptor:                                                                   │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/core.py:343 in bind  │
│                                                                                                  │
│    340   def bind(self, *args, **params):                                                        │
│    341 │   assert (not config.jax_enable_checks or                                               │
│    342 │   │   │   all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args     │
│ ❱  343 │   return self.bind_with_trace(find_top_trace(args), args, params)                       │
│    344                                                                                           │
│    345   def bind_with_trace(self, trace, args, params):                                         │
│    346 │   out = trace.process_primitive(self, map(trace.full_raise, args), params)              │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/core.py:346 in       │
│ bind_with_trace                                                                                  │
│                                                                                                  │
│    343 │   return self.bind_with_trace(find_top_trace(args), args, params)                       │
│    344                                                                                           │
│    345   def bind_with_trace(self, trace, args, params):                                         │
│ ❱  346 │   out = trace.process_primitive(self, map(trace.full_raise, args), params)              │
│    347 │   return map(full_lower, out) if self.multiple_results else full_lower(out)             │
│    348                                                                                           │
│    349   def def_impl(self, impl):                                                               │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:1721 in process_primitive                                                                    │
│                                                                                                  │
│   1718   def process_primitive(self, primitive, tracers, params):                                │
│   1719 │   if primitive in custom_staging_rules:                                                 │
│   1720 │     return custom_staging_rules[primitive](self, *tracers, **params)                    │
│ ❱ 1721 │   return self.default_process_primitive(primitive, tracers, params)                     │
│   1722                                                                                           │
│   1723   def default_process_primitive(self, primitive, tracers, params):                        │
│   1724 │   avals = [t.aval for t in tracers]                                                     │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:1725 in default_process_primitive                                                            │
│                                                                                                  │
│   1722                                                                                           │
│   1723   def default_process_primitive(self, primitive, tracers, params):                        │
│   1724 │   avals = [t.aval for t in tracers]                                                     │
│ ❱ 1725 │   out_avals, effects = primitive.abstract_eval(*avals, **params)                        │
│   1726 │   out_avals = [out_avals] if not primitive.multiple_results else out_avals              │
│   1727 │   source_info = source_info_util.current()                                              │
│   1728 │   out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]           │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/core.py:379 in       │
│ abstract_eval_                                                                                   │
│                                                                                                  │
│    376                                                                                           │
│    377 def _effect_free_abstract_eval(abstract_eval):                                            │
│    378   def abstract_eval_(*args, **kwargs):                                                    │
│ ❱  379 │   return abstract_eval(*args, **kwargs), no_effects                                     │
│    380   return abstract_eval_                                                                   │
│    381                                                                                           │
│    382 # -------------------- lifting --------------------                                       │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lax/utils.py:66 in   │
│ standard_abstract_eval                                                                           │
│                                                                                                  │
│    63 │   out = prim.impl(*[x.val for x in avals], **kwargs)                                     │
│    64 │   return core.ConcreteArray(out.dtype, out, weak_type=weak_type)                         │
│    65   elif least_specialized is core.ShapedArray:                                              │
│ ❱  66 │   return core.ShapedArray(shape_rule(*avals, **kwargs),                                  │
│    67 │   │   │   │   │   │   │   dtype_rule(*avals, **kwargs), weak_type=weak_type,             │
│    68 │   │   │   │   │   │   │   named_shape=named_shape_rule(*avals, **kwargs))                │
│    69   elif least_specialized is core.DShapedArray:                                             │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lax/lax.py:3014 in   │
│ _concatenate_shape_rule                                                                          │
│                                                                                                  │
│   3011 │   raise TypeError(msg.format(type(op)))                                                 │
│   3012   if len({operand.ndim for operand in operands}) != 1:                                    │
│   3013 │   msg = "Cannot concatenate arrays with different numbers of dimensions: got {}."       │
│ ❱ 3014 │   raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))                │
│   3015   if not 0 <= dimension < operands[0].ndim:                                               │
│   3016 │   msg = "concatenate dimension out of bounds: dimension {} for shapes {}."              │
│   3017 │   raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands])))   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnfilteredStackTrace: TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 512), (1, 1, 1024).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/fabian/Desktop/dreamerv3/example.py:53 in <module>                                         │
│                                                                                                  │
│   50                                                                                             │
│   51                                                                                             │
│   52 if __name__ == '__main__':                                                                  │
│ ❱ 53   main()                                                                                    │
│   54                                                                                             │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/example.py:48 in main                                             │
│                                                                                                  │
│   45   args = embodied.Config(                                                                   │
│   46 │     **config.run, logdir=config.logdir,                                                   │
│   47 │     batch_steps=config.batch_size * config.batch_length)                                  │
│ ❱ 48   embodied.run.train(agent, env, replay, logger, args)                                      │
│   49   # embodied.run.eval_only(agent, env, logger, args)                                        │
│   50                                                                                             │
│   51                                                                                             │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/embodied/run/train.py:108 in train                      │
│                                                                                                  │
│   105   policy = lambda *args: agent.policy(                                                     │
│   106 │     *args, mode='explore' if should_expl(step) else 'train')                             │
│   107   while step < args.steps:                                                                 │
│ ❱ 108 │   driver(policy, steps=100)                                                              │
│   109 │   if should_save(step):                                                                  │
│   110 │     checkpoint.save()                                                                    │
│   111   logger.write()                                                                           │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/embodied/core/driver.py:42 in __call__                  │
│                                                                                                  │
│   39   def __call__(self, policy, steps=0, episodes=0):                                          │
│   40 │   step, episode = 0, 0                                                                    │
│   41 │   while step < steps or episode < episodes:                                               │
│ ❱ 42 │     step, episode = self._step(policy, step, episode)                                     │
│   43                                                                                             │
│   44   def _step(self, policy, step, episode):                                                   │
│   45 │   assert all(len(x) == len(self._env) for x in self._acts.values())                       │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/embodied/core/driver.py:50 in _step                     │
│                                                                                                  │
│   47 │   obs = self._env.step(acts)                                                              │
│   48 │   obs = {k: convert(v) for k, v in obs.items()}                                           │
│   49 │   assert all(len(x) == len(self._env) for x in obs.values()), obs                         │
│ ❱ 50 │   acts, self._state = policy(obs, self._state, **self._kwargs)                            │
│   51 │   acts = {k: convert(v) for k, v in acts.items()}                                         │
│   52 │   if obs['is_last'].any():                                                                │
│   53 │     mask = 1 - obs['is_last']                                                             │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/embodied/run/train.py:105 in <lambda>                   │
│                                                                                                  │
│   102   should_save(step)  # Register that we jused saved.                                       │
│   103                                                                                            │
│   104   print('Start training loop.')                                                            │
│ ❱ 105   policy = lambda *args: agent.policy(                                                     │
│   106 │     *args, mode='explore' if should_expl(step) else 'train')                             │
│   107   while step < args.steps:                                                                 │
│   108 │   driver(policy, steps=100)                                                              │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/jaxagent.py:62 in policy                                │
│                                                                                                  │
│    59 │     state = tree_map(                                                                    │
│    60 │   │     np.asarray, state, is_leaf=lambda x: isinstance(x, list))                        │
│    61 │     state = self._convert_inps(state, self.policy_devices)                               │
│ ❱  62 │   (outs, state), _ = self._policy(varibs, rng, obs, state, mode=mode)                    │
│    63 │   outs = self._convert_outs(outs, self.policy_devices)                                   │
│    64 │   # TODO: Consider keeping policy states in accelerator memory.                          │
│    65 │   state = self._convert_outs(state, self.policy_devices)                                 │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:199 in wrapper                                │
│                                                                                                  │
│   196 │   statics = tuple(sorted([(k, v) for k, v in kw.items() if k in static]))                │
│   197 │   kw = {k: v for k, v in kw.items() if k not in static}                                  │
│   198 │   if not hasattr(wrapper, 'keys'):                                                       │
│ ❱ 199 │     created = init(statics, rng, *args, **kw)                                            │
│   200 │     wrapper.keys = set(created.keys())                                                   │
│   201 │     for key, value in created.items():                                                   │
│   202 │   │   if key not in state:                                                               │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:184 in init                                   │
│                                                                                                  │
│   181   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│   182   def init(statics, rng, *args, **kw):                                                     │
│   183 │   # Return only state so JIT can remove dead code for fast initialization.               │
│ ❱ 184 │   s = fun({}, rng, *args, ignore=True, **dict(statics), **kw)[1]                         │
│   185 │   return s                                                                               │
│   186                                                                                            │
│   187   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:95 in purified                                │
│                                                                                                  │
│    92 │   │   rng = jax.random.PRNGKey(rng)                                                      │
│    93 │     context = Context(state.copy(), rng, create, modify, ignore, [], name)               │
│    94 │     CONTEXT[threading.get_ident()] = context                                             │
│ ❱  95 │     out = fun(*args, **kwargs)                                                           │
│    96 │     state = dict(context)                                                                │
│    97 │     return out, state                                                                    │
│    98 │   finally:                                                                               │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                                │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/agent.py:56 in policy                                   │
│                                                                                                  │
│    53 │   obs = self.preprocess(obs)                                                             │
│    54 │   (prev_latent, prev_action), task_state, expl_state = state                             │
│    55 │   embed = self.wm.encoder(obs)                                                           │
│ ❱  56 │   latent, _ = self.wm.rssm.obs_step(                                                     │
│    57 │   │   prev_latent, prev_action, embed, obs['is_first'])                                  │
│    58 │   self.expl_behavior.policy(latent, expl_state)                                          │
│    59 │   task_outs, task_state = self.task_behavior.policy(latent, task_state)                  │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                                │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/fabian/Desktop/dreamerv3/dreamerv3/nets.py:105 in obs_step                                 │
│                                                                                                  │
│   102 │   #   # prior['deter'] has shape (1, 512) but embed has shape (1, 1, 1024). Need to sq   │
│   103 │   #   embed = jnp.squeeze(embed, axis=1)                                                 │
│   104 │   # print('aft: prior deter', prior['deter'], embed, prior['deter'].ndim == embed.ndim   │
│ ❱ 105 │   x = jnp.concatenate([prior['deter'], embed], -1)                                       │
│   106 │   x = self.get('obs_out', Linear, **self._kw)(x)                                         │
│   107 │   stats = self._stats('obs_stats', x)                                                    │
│   108 │   dist = self.get_dist(stats)                                                            │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:1 │
│ 845 in concatenate                                                                               │
│                                                                                                  │
│   1842   # (https://github.com/google/jax/issues/653).                                           │
│   1843   k = 16                                                                                  │
│   1844   while len(arrays_out) > 1:                                                              │
│ ❱ 1845 │   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)                                │
│   1846 │   │   │   │     for i in range(0, len(arrays_out), k)]                                  │
│   1847   return arrays_out[0]                                                                    │
│   1848                                                                                           │
│                                                                                                  │
│ /home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:1 │
│ 845 in <listcomp>                                                                                │
│                                                                                                  │
│   1842   # (https://github.com/google/jax/issues/653).                                           │
│   1843   k = 16                                                                                  │
│   1844   while len(arrays_out) > 1:                                                              │
│ ❱ 1845 │   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)                                │
│   1846 │   │   │   │     for i in range(0, len(arrays_out), k)]                                  │
│   1847   return arrays_out[0]                                                                    │
│   1848                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 512), (1, 1, 1024).

Thanks
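The failing call is the `jnp.concatenate` in `obs_step`. The deterministic state has rank 2 and the embedding has rank 3, and concatenation requires equal ranks. The commented-out lines in nets.py hint at squeezing axis 1 of `embed`. Below is a minimal sketch of that workaround. It uses NumPy, whose `concatenate` enforces the same rank rule but raises `ValueError` rather than `TypeError`. `concat_deter_and_embed` is a hypothetical helper. Squeezing only treats the symptom, so it may still be worth checking whether the observations reaching the agent carry an extra axis.

```python
import numpy as np


def concat_deter_and_embed(deter, embed):
    """Concatenate an RSSM deterministic state with an encoder embedding.

    Hypothetical helper: if `embed` has exactly one more dimension than
    `deter` and that extra axis 1 is a singleton, drop it first (mirroring
    the commented-out squeeze in nets.py), then join on the last axis.
    """
    if embed.ndim == deter.ndim + 1 and embed.shape[1] == 1:
        embed = np.squeeze(embed, axis=1)
    return np.concatenate([deter, embed], axis=-1)


deter = np.zeros((1, 512), np.float32)
embed = np.zeros((1, 1, 1024), np.float32)

try:
    np.concatenate([deter, embed], axis=-1)  # ranks 2 and 3: rejected
except ValueError as err:
    print('rank mismatch:', err)

x = concat_deter_and_embed(deter, embed)
print(x.shape)  # (1, 1536)
```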

Normalizing advantages or returns

Hi Danijar,

Thanks so much for developing and sharing this algorithm. I've been following your work for some time and I think it's really great.

I'm attempting a from-scratch reimplementation in PyTorch, and I have a question about the actor loss.

In the paper, the loss is given as:

[image: actor loss equation from the paper]

However, in the code I see something that looks more like an advantage estimate:

      rew, ret, base = critic.score(traj, self.actor)
      offset, invscale = self.retnorms[key](ret)
      normed_ret = (ret - offset) / invscale
      normed_base = (base - offset) / invscale
      advs.append((normed_ret - normed_base) * self.scales[key] / total)

If I'm interpreting the code correctly, normed_base seems to come from the value of the state the actor was in prior to transition, and normed_ret is the percentile scaled return as per the paper.
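This reading can be checked numerically. The sketch below rests on several assumptions. It assumes `offset` is the 5th percentile of the returns and `invscale` is max(1, P95 - P5), which is the paper's return scale S. It computes both on the current batch, without any moving average the implementation may apply. `percentile_scale` is a hypothetical name. Because the same offset is subtracted from both terms, it cancels, and the actor effectively receives `(ret - base) / invscale`: an advantage with percentile-based scaling.

```python
import numpy as np


def percentile_scale(ret, low=5, high=95):
    """Return (offset, invscale) from batch percentiles of the returns.

    Assumption: plain batch percentiles, no EMA smoothing.
    """
    lo, hi = np.percentile(ret, [low, high])
    offset = lo
    invscale = max(1.0, hi - lo)  # never amplify small returns
    return offset, invscale


rng = np.random.default_rng(0)
ret = rng.normal(10.0, 5.0, size=(15, 16))          # lambda-returns (time, batch)
base = ret - rng.normal(0.0, 1.0, size=ret.shape)   # stand-in critic values

offset, invscale = percentile_scale(ret)
normed_ret = (ret - offset) / invscale
normed_base = (base - offset) / invscale
adv = normed_ret - normed_base

# The offset cancels: only the scale survives in the difference.
assert np.allclose(adv, (ret - base) / invscale)
```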

There is also a little trick at the end:

    loss *= sg(traj['weight'])[:-1]

where 'weight' is computed by exponentially discounting the model's future predictions:

    cont = self.heads['cont'](traj).mode()
    traj['cont'] = jnp.concatenate([first_cont[None], cont[1:]], 0)
    discount = 1 - 1 / self.config.horizon
    traj['weight'] = jnp.cumprod(discount * traj['cont'], 0) / discount
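The weighting above can be reproduced on its own. In the NumPy sketch below, `trajectory_weights` is a hypothetical name. The value `horizon=333` is an assumed default, not one taken from this thread. Dividing by `discount` makes the first weight equal to the first continuation flag. Every later weight shrinks by a factor of `discount` per step. Once a continuation flag is zero, that weight and all later ones are zero.

```python
import numpy as np


def trajectory_weights(cont, horizon=333):
    """Cumulative discounted continuation weights along the time axis.

    cont: array of shape (T,) or (T, B) with continuation flags or
    probabilities, where cont[0] is already the "first" continuation.
    horizon=333 is an assumed default for illustration.
    """
    discount = 1 - 1 / horizon
    return np.cumprod(discount * cont, axis=0) / discount


print(trajectory_weights(np.ones(4), horizon=10))  # ~[1.0, 0.9, 0.81, 0.729]
```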

This all makes sense and is a good policy gradient, but the paper mentions that it's important to keep the scale of the policy gradient loss proportional to the entropy.

I noticed a big difference in policy gradient scale between using the returns as presented in the paper and what we have here.

It would be great if you could clarify. Did I just misread the paper, or are these simply implementation details that don't matter much in practice?

UnfilteredStackTrace: AssertionError: ((1024, 64, 64, 3), (63, 64, 3))

Hi,

Congrats on the excellent work.

I'm using the example.py file with a custom Gym environment. My env's observations have shape (64, 64, 3), and I'm receiving this error:

UnfilteredStackTrace: AssertionError: ((1024, 64, 64, 3), (63, 64, 3))

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
def main():

  import warnings
  import dreamerv3
  from dreamerv3 import embodied
  warnings.filterwarnings('ignore', '.*truncated to dtype int32.*')

  # See configs.yaml for all options.
  config = embodied.Config(dreamerv3.configs['defaults'])
  # config = config.update(dreamerv3.configs['medium'])
  config = config.update({
      'logdir': '~/logdir/run1',
      'run.train_ratio': 64,
      'run.log_every': 30,  # Seconds
      'batch_size': 16,
      'jax.prealloc': False,
      'encoder.mlp_keys': '$^',
      'decoder.mlp_keys': '$^',
      'encoder.cnn_keys': 'image',
      'decoder.cnn_keys': 'image',
      'jax.platform': 'cpu',
  })
  config = embodied.Flags(config).parse()

  logdir = embodied.Path(config.logdir)
  step = embodied.Counter()
  logger = embodied.Logger(step, [
      embodied.logger.TerminalOutput(),
      #embodied.logger.JSONLOutput(logdir, 'metrics.jsonl'),
      embodied.logger.TensorBoardOutput(logdir),
      # embodied.logger.WandBOutput(logdir.name, config),
      # embodied.logger.MLFlowOutput(logdir.name),
  ])

  import gymnasium as gym
  import gym_custom
  from embodied.envs import from_gym
  env = gym.make('gym_custom:MyEnv-v0')
  env = from_gym.FromGym(env, obs_key='image')  # Or obs_key='vector'.
  # env = from_gym.FromGym(env, obs_key='state_vec')  # Or obs_key='vector'.
  # env = from_gym.FromGym(env, obs_key='vector')  # Or obs_key='vector'.
  env = dreamerv3.wrap_env(env, config)
  # env = embodied.BatchEnv([env], parallel=False)

  agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
  replay = embodied.replay.Uniform(
      config.batch_length, config.replay_size, logdir / 'replay')
  args = embodied.Config(
      **config.run, logdir=config.logdir,
      batch_steps=config.batch_size * config.batch_length)
  embodied.run.train(agent, env, replay, logger, args)
  # embodied.run.eval_only(agent, env, logger, args)


if __name__ == '__main__':
  main()

My observation:

>>> obs
array([[[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [ 75, 214,  16],
        [ 75, 214,  16],
        [ 75, 214,  16]],

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [ 75, 214,  16],
        [ 75, 214,  16],
        [ 75, 214,  16]],

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [ 75, 214,  16],
        [ 75, 214,  16],
        [ 75, 214,  16]],

       ...,

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [ 95, 211,  64],
        [ 95, 211,  64],
        [ 95, 211,  64]],

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [ 95, 211,  64],
        [ 95, 211,  64],
        [ 95, 211,  64]],

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [ 95, 211,  64],
        [ 95, 211,  64],
        [ 95, 211,  64]]], dtype=uint8)

Any idea how to fix this?
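The second shape in the assertion, (63, 64, 3), differs from the declared (64, 64, 3) in its first axis. One thing worth ruling out is therefore that some observations returned by the env do not match its declared observation space. The helper below is hypothetical and not part of dreamerv3. It can be called on every observation returned by `reset` and `step` so that a mismatch fails early with a readable message:

```python
import numpy as np


def check_obs(obs, expected_shape=(64, 64, 3), expected_dtype=np.uint8):
    """Raise early if an observation deviates from the declared space.

    Hypothetical debugging helper: expected_shape/expected_dtype should be
    set to whatever the env's observation space declares.
    """
    obs = np.asarray(obs)
    if obs.shape != tuple(expected_shape):
        raise ValueError(
            f'obs shape {obs.shape} != declared {tuple(expected_shape)}')
    if obs.dtype != expected_dtype:
        raise ValueError(
            f'obs dtype {obs.dtype} != declared {np.dtype(expected_dtype)}')
    return obs
```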

Expected FPS on V100?

Hey @danijar
Thanks a lot for open-sourcing DreamerV3!

I'm currently reproducing your Crafter results and was wondering whether 4.6 FPS during training sounds right to you.
The training is executed on 32 cores and one V100 GPU. The log states that the GPU is used.
Given 4.6 FPS, it would take about 25 days to reach 10M training steps (I'm referring to Figure 6).

python dreamerv3/train.py --logdir logdir/train --configs crafter --run.script train
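The 25-day estimate follows directly from the measured throughput. At 4.6 environment steps per second, 10M steps take about 2.17M seconds. The sketch below is a trivial check of that arithmetic. `days_to_reach` is a hypothetical name:

```python
def days_to_reach(steps, fps):
    """Wall-clock days needed to collect `steps` env steps at `fps` steps/s."""
    seconds_per_day = 86_400
    return steps / fps / seconds_per_day


print(round(days_to_reach(10_000_000, 4.6), 1))  # 25.2
```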

Encoder decoder loss plot

Hello, is it possible to plot the loss of the encoder-decoder? I am new to JAX, so I'm not sure how to do it.
Regards
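The image reconstruction loss is already reported during training (as `train/image_loss_mean` in the terminal output). Plotting it therefore does not require touching JAX. One option reads it back from the JSON-lines log written by `embodied.logger.JSONLOutput(logdir, 'metrics.jsonl')` in example.py. This sketch makes two assumptions. First, the JSONL logger is enabled. Second, each line is a JSON object with a `'step'` field plus scalar metrics keyed as in the terminal report. `load_metric` is a hypothetical helper. Alternatively, since example.py also attaches `TensorBoardOutput(logdir)`, pointing TensorBoard at the logdir may show the same scalar.

```python
import json
from pathlib import Path


def load_metric(path, key='train/image_loss_mean'):
    """Return (steps, values) for one scalar from a JSON-lines metrics file.

    Assumes each non-empty line is a JSON object with a 'step' field;
    lines that do not contain `key` are skipped.
    """
    steps, values = [], []
    for line in Path(path).expanduser().read_text().splitlines():
        if not line.strip():
            continue
        record = json.loads(line)
        if key in record:
            steps.append(record['step'])
            values.append(record[key])
    return steps, values
```

For example, `steps, values = load_metric('~/logdir/run1/metrics.jsonl')` followed by `matplotlib.pyplot.plot(steps, values)` would draw the curve. The path here is illustrative.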

Interactive rendering for SafetyGym

Hi there,

Thanks for the codebase; it's an awesome algorithm to play around with. I'm trying to render the env to sanity-check the agent and am wondering how this is done. I've tried calling the render function in the from_gym.py script, but that didn't work. Simply calling env.render in the step function didn't work either.

Any help appreciated, thanks.

Feature Request: Gymnasium Compatibility

Hi, would it be possible to allow this repo to be used with Gymnasium environments? Certain libraries like MineRL are not yet compatible with Gymnasium, but it would be great to allow environments of either type, so the repo could be used with the wide range of new libraries that are based on Gymnasium.

For background, Gymnasium is a maintained fork of OpenAI Gym and is designed as a drop-in replacement (import gym -> import gymnasium as gym). Many popular RL training libraries have switched (rllib, tianshou, CleanRL, stable-baselines3), along with many environment repositories (see third party environments). We are also considering compiling a list of repositories for popular training libraries and model implementations that can be used with Gymnasium, so this could potentially be listed on the website.

For information about upgrading and compatibility, see the migration guide and gym compatibility. The main difference is that the API has switched to returning truncated and terminated rather than done, in order to give more information and mitigate edge-case issues. For example, many popular tutorials/implementations of Q-learning using Gym were actually incorrect because of done. An upcoming blog post on the Farama site (https://farama.org/blog) will explain this in more detail.
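The terminated/truncated split maps naturally onto the `is_last` and `is_terminal` flags that dreamerv3's observation dicts carry alongside `is_first`. The adapter below is hypothetical and is not the repo's from_gym code. It shows how one Gymnasium `step` result `(obs, reward, terminated, truncated, info)` could be converted. The old single `done` flag corresponds to `terminated or truncated`. That flag loses exactly the distinction that `is_terminal` preserves.

```python
def convert_step(obs, reward, terminated, truncated, info):
    """Map a Gymnasium 5-tuple step result onto dreamerv3-style flags.

    Hypothetical adapter: the episode ends (is_last) on termination or
    truncation, but only true termination (is_terminal) should stop the
    agent from bootstrapping with its value estimate.
    """
    return {
        'obs': obs,
        'reward': float(reward),
        'is_first': False,  # set to True only for observations from reset()
        'is_last': bool(terminated or truncated),
        'is_terminal': bool(terminated),
    }


step = convert_step('obs', 1.0, terminated=False, truncated=True, info={})
print(step['is_last'], step['is_terminal'])  # True False
```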

Exporting the JAX policy as a TF model

Congratulations Danijar on this project and your paper!
Again, not really an issue per se. Having read and run example.py and most of the code, I understand that this project doesn't use TensorFlow in the way I'm familiar with and instead uses JAX. I will endeavour to understand it myself, but I was wondering if it would be "simple" to use jax2tf to obtain a SavedModel, ideally after training. Then, if I'm feeling really brave, I intend to use tfjs-converter to run inference in a web demo.

Update:
A) I realise I can go straight from JAX to tfjs.
B) I also realise (or think I understand) that I'm actually going to have to convert 3 nets: the world model, the actor, and the critic. Then I'd implement Dreamer in client-side JavaScript.
I'm becoming ever more doubtful of my ability to pull this off, but the payoff has this occupying my full attention (calendar emptied for the next 3 days).

Problem in example.py

Hi, thanks for your nice work! However, when I run

> python example.py

I didn't change any source code and got the following error:

> TypeError: dot_general requires contracting dimensions to have the same shape, got (1041,) and (1029,).

Is there a bug somewhere?

Actions and states return nan

This is really great work. Thanks for sharing it.

When I run the training script

python dreamerv3/train.py --logdir ~/logdir/$(date "+%Y%m%d-%H%M%S") --configs crafter small --batch_size 16 --run.train_ratio 32

the code runs stably at first but raises an error after training for a while:
Exception: Traceback (most recent call last):
  File "/home/kevin/dreamerv3-latest/dreamerv3/embodied/core/worker.py", line 202, in _loop
    state, result = function(state, *args, **kwargs)
  File "/home/kevin/dreamerv3-latest/dreamerv3/embodied/core/parallel.py", line 40, in _respond
    result = getattr(state, name)(*args, **kwargs)
  File "/home/kevin/dreamerv3-latest/dreamerv3/embodied/core/wrappers.py", line 158, in step
    obs = self.env.step(action)
  File "/home/kevin/dreamerv3-latest/dreamerv3/embodied/core/wrappers.py", line 113, in step
    assert action[self._key].min() == 0.0, action
AssertionError: {'action': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32), 'reset': False}

Could you tell me how to avoid this and how to modify the code?
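The assertion in wrappers.py checks that the one-hot action has a minimum of 0.0. Once a NaN reaches the policy output, `min()` returns NaN, and any comparison with NaN is False. The check therefore fails even though the actual problem lies upstream, in whatever produced non-finite values in the actor's output. The stricter check below is hypothetical. It reports the two failure modes separately, which may make logs easier to read:

```python
import numpy as np


def looks_like_onehot(action):
    """Return True only for a finite one-hot vector (hypothetical helper)."""
    action = np.asarray(action, dtype=np.float32)
    if not np.all(np.isfinite(action)):
        # NaN/inf: min() would be NaN, and NaN == 0.0 is always False.
        return False
    return bool(action.min() == 0.0 and action.max() == 1.0
                and action.sum() == 1.0)


nan_action = np.full(17, np.nan, np.float32)
print(nan_action.min() == 0.0)           # False, so the original assert fires
print(looks_like_onehot(nan_action))     # False (non-finite)
print(looks_like_onehot(np.eye(17)[3]))  # True
```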

[QUESTION] Why does multi-GPU result in lower FPS?

I tested multiple GPUs in various configurations, but every modification I made resulted in lower FPS. The fastest configuration I found was a single GPU with no modifications.
I verified that the policy and train devices were configured correctly in every instance, e.g.:

Policy devices: gpu:0
Train devices:  gpu:1

I added config = config.update(dreamerv3.configs['multicpu']) to example.py
Here are the various configurations of configs.yaml I tried:

#No modifications including no multicpu in example.py
#Results in ~67fps with single GPU
# Results in ~47fps with multi-GPU
multicpu:

  jax:
    logical_cpus: 2
    policy_devices: [0]
    train_devices: [1]
  run:
    actor_batch: 4
  envs:
    amount: 8
  batch_size: 12
  batch_length: 10
#Results in ~ 26fps with multi-GPU
multicpu:

  jax:
    logical_cpus: 2
    policy_devices: [0]
    train_devices: [0, 1]
  run:
    actor_batch: 4
  envs:
    amount: 8
  batch_size: 12
  batch_length: 10
#Results in error
multicpu:

  jax:
    logical_cpus: 2
    policy_devices: [0, 1]
    train_devices: [0, 1]
  run:
    actor_batch: 4
  envs:
    amount: 8
  batch_size: 12
  batch_length: 10

ValueError: Batch must by divisible by 2 devices: {'image': (1, 64, 64, 3), 'is_first': (1,), 'is_last': (1,), 'is_terminal': (1,), 'reward': (1,)}
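The error reflects a general constraint of data-parallel execution: the leading batch axis is split evenly across devices. Here the observation batch has size 1, so it cannot be divided between two policy devices. The NumPy sketch below is not the repo's implementation, and `shard_batch` is a hypothetical name. It shows where the divisibility requirement comes from. With `policy_devices: [0, 1]`, the number of observations passed per policy call would presumably need to be a multiple of 2.

```python
import numpy as np


def shard_batch(batch, num_devices):
    """Split the leading (batch) axis of every array into per-device chunks.

    Sketch of data-parallel sharding: each array of shape (B, ...) becomes
    (num_devices, B // num_devices, ...), which only works if B divides evenly.
    """
    sharded = {}
    for key, value in batch.items():
        if value.shape[0] % num_devices:
            raise ValueError(
                f'Batch must be divisible by {num_devices} devices: '
                f'{key} has shape {value.shape}')
        sharded[key] = value.reshape(num_devices, -1, *value.shape[1:])
    return sharded


obs = {'image': np.zeros((4, 64, 64, 3), np.uint8), 'reward': np.zeros((4,))}
print({k: v.shape for k, v in shard_batch(obs, 2).items()})
```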

CUDA check failed with custom env

Issue Description

When training with a custom Gym env, there seems to be either a dimension mismatch in the observation tensor batch or some other issue (related to ninjax?) that causes a cuDNN crash with the following error:

F external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:632] Check failed: cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd, dims.data(), strides.data()) == CUDNN_STATUS_SUCCESS (9 vs. 0)batch_descriptor: {count: 1024 feature_map_count: 48 spatial: 259 259  value_min: 0.000000 value_max: 0.000000 layout: BatchYXDepth}

See the attached dreamerv3-train-crash.txt for the full log.
Could you suggest a resolution or point to the code line or module that could cause this? Unfortunately, the Python debugger didn't produce any trace.

Additional Notes

  1. The agent's model, actor, and critic networks are initialized with the correct dimensions based on the observation space
  2. The replay buffer prefill steps that gather the training dataset execute fine

Steps to reproduce the error

Minimal changes to example.py to reproduce the above error:

  1. Replace the following lines:

dreamerv3/example.py, lines 35 to 37 in 51556cd:

import crafter
from embodied.envs import from_gym
env = crafter.Env() # Replace this with your Gym env.

With

    import gym
    import playwrightgym
    from embodied.envs import from_gym
    from gym.wrappers.resize_observation import ResizeObservation
    env = playwrightgym.playwright_env.PlaywrightEnv()
    env = ResizeObservation(env, (256, 256))
  2. Install additional Python env requirements:
python -m pip install playwrightgym playwright opencv-contrib-python && \
python -m playwright install
  3. Run: python example.py

Thanks in advance for your help!

My rewards are not crossing 0.

Hi Danijar,

This is more a question than an issue.
My custom environment reaches a reward of 10 with other models/algorithms (like MuZero). But when using Dreamer (v2 or v3), I get this behavior (the reward goes to zero). I would love to use DreamerV3.

Have you seen this result before? Maybe I need to change some hyperparameters?

[image: plot of the reward during training]

I'm using this config, and my observation is a 64x64 image.

  config = config.update(dreamerv3.configs['small'])
  config = config.update({
      'logdir': '~/logdir/run1',
      'run.train_ratio': 32,
      'run.log_every': 30,  # Seconds
      'batch_size': 8,
      'jax.prealloc': False,
      'encoder.mlp_keys': '$^',
      'decoder.mlp_keys': '$^',
      'encoder.cnn_keys': 'image',
      'decoder.cnn_keys': 'image',
      'jax.platform': 'cpu',
  })

This is a report from training:

────────────────────────────────────────────────────────────────────────────── Step 381545 ───────────────────────────────────────────────────────────────────────────────
train/action_mag 9 / train/action_max 9 / train/action_mean 7.97 / train/action_min 0 / train/action_std 0.42 / train/actor_opt_actor_opt_grad_overflow 0 /
train/actor_opt_actor_opt_grad_scale 1e4 / train/actor_opt_grad_norm 3.1e-6 / train/actor_opt_grad_steps 2.4e4 / train/actor_opt_loss -1356.35 / train/adv_mag 12.31 /
train/adv_max 0.88 / train/adv_mean -0.07 / train/adv_min -12.31 / train/adv_std 0.79 / train/cont_avg 0.99 / train/cont_loss_mean 1.3e-7 / train/cont_loss_std 2.9e-7 /
train/cont_neg_acc 1 / train/cont_neg_loss 8.6e-7 / train/cont_pos_acc 1 / train/cont_pos_loss 1.2e-7 / train/cont_pred 0.99 / train/cont_rate 0.99 / train/dyn_loss_mean 1.43 /
train/dyn_loss_std 2.75 / train/extr_critic_critic_opt_critic_opt_grad_overflow 0 / train/extr_critic_critic_opt_critic_opt_grad_scale 1e4 /
train/extr_critic_critic_opt_grad_norm 0.09 / train/extr_critic_critic_opt_grad_steps 2.4e4 / train/extr_critic_critic_opt_loss 1.2e4 / train/extr_critic_mag 4.86 /
train/extr_critic_max -5.8e-3 / train/extr_critic_mean -0.26 / train/extr_critic_min -4.86 / train/extr_critic_std 0.27 / train/extr_return_normed_mag 12.03 /
train/extr_return_normed_max 0.52 / train/extr_return_normed_mean 0.19 / train/extr_return_normed_min -12.03 / train/extr_return_normed_std 0.85 / train/extr_return_rate 0.06 /
train/extr_return_raw_mag 12.55 / train/extr_return_raw_max -1.6e-4 / train/extr_return_raw_mean -0.34 / train/extr_return_raw_min -12.55 /
train/extr_return_raw_std 0.85 / train/extr_reward_mag 5.37 / train/extr_reward_max -6.1e-6 / train/extr_reward_mean -0.02 / train/extr_reward_min -5.37 /
train/extr_reward_std 0.28 / train/image_loss_mean 0.06 / train/image_loss_std 0.17 / train/model_loss_mean 0.97 / train/model_loss_std 1.93 / train/model_opt_grad_norm 16.26 /
train/model_opt_grad_steps 2.4e4 / train/model_opt_loss 9693.77 / train/model_opt_model_opt_grad_overflow 0 / train/model_opt_model_opt_grad_scale 1e4 /
train/policy_entropy_mag 0.08 / train/policy_entropy_max 0.08 / train/policy_entropy_mean 0.07 / train/policy_entropy_min 0.07 / train/policy_entropy_std 8.1e-5 /
train/policy_logprob_mag 6.91 / train/policy_logprob_max -9e-3 / train/policy_logprob_mean -0.07 / train/policy_logprob_min -6.91 / train/policy_logprob_std 0.65 /
train/policy_randomness_mag 0.03 / train/policy_randomness_max 0.03 / train/policy_randomness_mean 0.03 / train/policy_randomness_min 0.03 / train/policy_randomness_std 3.5e-5 /
train/post_ent_mag 54.24 / train/post_ent_max 54.24 / train/post_ent_mean 11.24 / train/post_ent_min 6.25 / train/post_ent_std 7.23 / train/prior_ent_mag 61.4 /
train/prior_ent_max 61.4 / train/prior_ent_mean 13.41 / train/prior_ent_min 7.81 / train/prior_ent_std 8.98 / train/rep_loss_mean 1.43 / train/rep_loss_std 2.75 /
train/reward_avg -0.04 / train/reward_loss_mean 0.05 / train/reward_loss_std 0.46 / train/reward_max_data 5 / train/reward_max_pred 4.67 / train/reward_neg_acc 1 /
train/reward_neg_loss 0.05 / train/reward_pos_acc nan / train/reward_pos_loss nan / train/reward_pred -0.03 / train/reward_rate 0 / replay/size 3.8e5 / replay/inserts 32 /
replay/samples 16 / replay/insert_wait_avg 2.9e-6 / replay/insert_wait_frac 1 / replay/sample_wait_avg 1.5e-6 / replay/sample_wait_frac 0.75 / timer/duration 53.88 /
timer/env.step_count 32 / timer/env.step_total 0.02 / timer/env.step_frac 3.8e-4 / timer/env.step_avg 6.4e-4 / timer/env.step_min 3.9e-4 / timer/env.step_max 1.8e-3 /
timer/replay.add_count 32 / timer/replay.add_total 3.6e-3 / timer/replay.add_frac 6.6e-5 / timer/replay.add_avg 1.1e-4 / timer/replay.add_min 6.4e-5 /
timer/replay.add_max 4.4e-4 / timer/logger.write_count 1 / timer/logger.write_total 8e-3 / timer/logger.write_frac 1.5e-4 / timer/logger.write_avg 8e-3 /
timer/logger.write_min 8e-3 / timer/logger.write_max 8e-3 / timer/checkpoint.save_count 0 / timer/checkpoint.save_total 0 / timer/checkpoint.save_frac 0 /
timer/agent.save_count 0 / timer/agent.save_total 0 / timer/agent.save_frac 0 / timer/replay.save_count 0 / timer/replay.save_total 0 / timer/replay.save_frac 0 /
timer/agent.policy_count 32 / timer/agent.policy_total 0.35 / timer/agent.policy_frac 6.4e-3 / timer/agent.policy_avg 0.01 / timer/agent.policy_min 9e-3 /
timer/agent.policy_max 0.03 / timer/dataset_count 2 / timer/dataset_total 2.4e-4 / timer/dataset_frac 4.5e-6 / timer/dataset_avg 1.2e-4 / timer/dataset_min 9.2e-5 /
timer/dataset_max 1.5e-4 / timer/agent.train_count 2 / timer/agent.train_total 49.07 / timer/agent.train_frac 0.91 / timer/agent.train_avg 24.54 / timer/agent.train_min
24.4 / timer/agent.train_max 24.68 / timer/agent.report_count 1 / timer/agent.report_total 4.42 / timer/agent.report_frac 0.08 / timer/agent.report_avg 4.42 /
timer/agent.report_min 4.42 / timer/agent.report_max 4.42 / fps 0.59

Episode has 67 steps and return -3.6.

Meaning of each row in the `report/openl_image` video

Here is the GIF I am referring to. A similar one is logged when training on the Crafter environment. This is the BoxWorld environment from pycolab.


The code that creates it is in https://github.com/danijar/dreamerv3/blob/main/dreamerv3/agent.py#L200.

Based on the image, it seems like something isn't working properly in this environment: the model solves the first task shown in the top row but then fails to solve the next one, because the second row fails to update to match the new task (see the second column from the right). The printed reward is usually 11 (the maximum) but is occasionally 0 when the agent exhausts the number of permitted steps.

Any suggestions of what I'm doing wrong or what I should change? Thanks!
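The video is built at the `agent.py` line linked above. As a hedged reading aid, here is a NumPy sketch of one common layout for such open-loop videos. It assumes, without checking against this version of the code, that the rows are, top to bottom: the ground-truth frames, the model's predictions (reconstructions for the first few context frames, open-loop imagination afterwards), and an error image. `openl_video` is a hypothetical name. Under that assumption, a second row that stops following the first after the task changes would mean the world model's open-loop predictions miss the new task.

```python
import numpy as np

def openl_video(truth, model):
    """Stack frames into a [T, 3*H, W, C] video: truth / model / error.

    truth, model: float arrays in [0, 1] with shape [T, H, W, C].
    The error row maps (model - truth) from [-1, 1] into [0, 1], so
    mid-gray (0.5) means a perfect prediction.
    """
    assert truth.shape == model.shape
    error = (model - truth + 1.0) / 2.0
    # axis=1 is the height axis for [T, H, W, C], so rows appear top to bottom.
    return np.concatenate([truth, model, error], axis=1)

truth = np.random.default_rng(0).uniform(size=(10, 8, 8, 3))
video = openl_video(truth, truth.copy())
print(video.shape)  # (10, 24, 8, 3)
```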

Training on Gym Pendulum

Thanks for sharing the code in general, and even more thanks for providing the Crafter example, which 'just works' on the first try!

That being said, I was trying to find the most minimal environment that would let me run and step through the code. To that end, I tried a few things, such as replacing

  import crafter
  env = crafter.Env()  # Replace this with your Gym env.

With

  from gym.envs.classic_control import pendulum
  env = pendulum.PendulumEnv()

This, however, runs into an error that I have not solved yet (see [^1] below).

What would be really nice is a gymnax wrapper: gymnax implements a bunch of simple environments in JAX, so for getting something to converge fast with minimal dependencies, I think that would be ideal.

  from gymnax.environments.classic_control import Pendulum
  env = Pendulum()

It has a Gym-like API, but I did not get it to work as a drop-in replacement either. I hope to get it working, though; if successful, I'll put in the work to make a tidy PR out of it.

[^1] Here's the error on the Gym classic Pendulum. I suspect the default config for image vs. MLP inputs is wrong, but I can't find any documentation on the config object.

│   214 │     self._mlp = MLP(None, mlp_layers, mlp_units, dist='none', **mlp_ │
│   215                                                                        │
│   216   def __call__(self, data):                                            │
│ ❱ 217 │   some_key, some_shape = list(self.shapes.items())[0]                │
│   218 │   batch_dims = data[some_key].shape[:-len(some_shape)]               │
│   219 │   data = {                                                           │
│   220 │   │   k: v.reshape((-1,) + v.shape[len(batch_dims):])                │
╰──────────────────────────────────────────────────────────────────────────────╯
IndexError: list index out of range
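The traceback fails at `list(self.shapes.items())[0]`, which raises `IndexError` when `self.shapes` is empty. In other words, no observation key was selected for either the CNN or the MLP input. Below is a pure-Python sketch of how that can happen. It assumes the encoder keeps only rank-3 observations whose key matches a CNN regex and rank-1/2 observations whose key matches an MLP regex. It also assumes the config routes only `image` to the CNN and nothing (`'$^'`) to the MLP. `split_keys` and these regex values are illustrative assumptions, so look up the actual encoder/decoder key options in the repository's config before relying on them.

```python
import re

def split_keys(shapes, cnn_keys, mlp_keys):
    """Hypothetical re-creation of encoder input selection.

    Rank-3 shapes (H, W, C) whose key matches `cnn_keys` go to the CNN;
    rank-1/2 shapes whose key matches `mlp_keys` go to the MLP.
    """
    cnn = {k: v for k, v in shapes.items()
           if len(v) == 3 and re.match(cnn_keys, k)}
    mlp = {k: v for k, v in shapes.items()
           if len(v) in (1, 2) and re.match(mlp_keys, k)}
    return cnn, mlp

# Pendulum yields a 3-dim vector (cos theta, sin theta, angular velocity).
cnn, mlp = split_keys({'image': (3,)}, cnn_keys='image', mlp_keys='$^')
print({**cnn, **mlp})  # {} -> list(shapes.items())[0] raises IndexError

# Storing it under 'vector' and routing that key to the MLP selects it.
cnn, mlp = split_keys({'vector': (3,)}, cnn_keys='$^', mlp_keys='vector')
print(mlp)  # {'vector': (3,)}
```

If this matches your setup, the corresponding change would be to wrap the env with `obs_key='vector'`, as in the Parallel example further down this page, and point the MLP key pattern at `vector` instead of `'$^'`.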

Using non-standard image sizes

Thank you and your team for sharing this with the scientific community. The significance of your work cannot be over-estimated: +9e9999 karma to you all, friend.

If we use an image of any size other than 32x32, we run into an assertion error. Could you suggest how we might update ImageDecoderResnet to handle other image resolutions? Should we focus on the Linear class?

Thanks for any suggestions, and more than anything, thanks for the brilliant contribution.
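A small helper for checking a resolution. It assumes, as in many ResNet-style image decoders, that the decoder starts from a `minres` x `minres` feature map and doubles the resolution at each stage. Under that assumption, the image side must equal `minres * 2**k`. Both the helper and the default `minres=4` are illustrative, so check the assertion in `ImageDecoderResnet` for the exact condition in your version.

```python
import math

def decoder_stages(size, minres=4):
    """Number of 2x upsampling stages from `minres` to `size`.

    Raises ValueError if `size` is not minres * 2**k for an integer k >= 0.
    """
    stages = round(math.log2(size / minres))
    if stages < 0 or minres * 2 ** stages != size:
        raise ValueError(
            f'{size} is not minres * 2**k; closest valid size is '
            f'{minres * 2 ** max(stages, 0)}')
    return stages

print(decoder_stages(64))            # 4  (4 -> 8 -> 16 -> 32 -> 64)
print(decoder_stages(32))            # 3
print(decoder_stages(96, minres=6))  # 4  (6 -> 12 -> 24 -> 48 -> 96)
```

Under this assumption, there are two ways to support other sizes. You can resize observations to a power-of-two multiple of `minres`, or you can pick a `minres` that divides the target size by a power of two (e.g. 6 for 96x96).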

Fully deterministic runs

Awesome repo. Quick question:

I ran the DMC WalkerWalk experiment three times with the same seeds and got three different learning curves. How can I get reproducible experiments?

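Below is a minimal sketch that seeds the Python-side random number generators and asks XLA for deterministic GPU kernels via `XLA_FLAGS=--xla_gpu_deterministic_ops=true`. Whether that XLA option covers every op used here depends on your JAX/XLA version. Sources outside the agent can still make runs diverge, such as environment physics, replay sampling in background threads, or asynchronous environment workers. `seed_everything` is a hypothetical helper, not part of this repository.

```python
import os
import random

import numpy as np

def seed_everything(seed):
    """Seed Python/NumPy RNGs and request deterministic XLA GPU kernels.

    Call this before importing JAX so XLA picks up the flag.
    """
    flags = os.environ.get('XLA_FLAGS', '')
    if '--xla_gpu_deterministic_ops' not in flags:
        os.environ['XLA_FLAGS'] = (flags + ' --xla_gpu_deterministic_ops=true').strip()
    random.seed(seed)
    np.random.seed(seed)
    # JAX randomness is explicit: derive all keys from jax.random.PRNGKey(seed).
    return seed

seed_everything(0)
first = (random.random(), float(np.random.rand()))
seed_everything(0)
second = (random.random(), float(np.random.rand()))
print(first == second)  # True
```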

Wow

No issue, just incredible work. Amazing code too: kind of messy and hard to interpret, but that's because it's extremely general and efficient.

Great work!!!!!!!

XlaRuntimeError: INTERNAL: RET_CHECK failure

Hi,
I am trying to run the code in Docker.
Unfortunately, I get a JAX-related error:

UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure 
(external/xla/xla/service/gpu/gemm_algorithm_picker.cc:380) 
stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms)

Steps to reproduce:
Install the NVIDIA Container Toolkit: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html
I changed line 33 in the Dockerfile to:

COPY dreamerv3/embodied/scripts scripts

Build the Docker image and run the container:

docker build -f  dreamerv3/Dockerfile -t dreamer-v3:$USER . && \
 docker run -it --rm --gpus all -v ~/logdir:/logdir dreamer-v3:$USER \
   sh ../scripts/xvfb_run.sh python3 dreamerv3/train.py \
   --logdir "/logdir/$(date +%Y%m%d-%H%M%S)" \
   --configs atari small --task atari_pong

My local nvidia-smi output:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.78.01    Driver Version: 525.78.01    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro RTX 4000     Off  | 00000000:01:00.0  On |                  N/A |
| 30%   30C    P8    10W / 125W |    995MiB /  8192MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

The output of the provided nvidia docker test:

docker run -it --rm --gpus all nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04 nvidia-smi

==========
== CUDA ==
==========

CUDA Version 11.4.2

Container image Copyright (c) 2016-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

*************************
** DEPRECATION NOTICE! **
*************************
THIS IMAGE IS DEPRECATED and is scheduled for DELETION.
    https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md

Tue May  2 12:15:16 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.78.01    Driver Version: 525.78.01    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro RTX 4000     Off  | 00000000:01:00.0  On |                  N/A |
| 30%   30C    P8    11W / 125W |    995MiB /  8192MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

Do you think I need to change my local CUDA version to get DreamerV3 running correctly inside the container?
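The `CUDA Version` in the host's `nvidia-smi` header is the newest CUDA version the installed driver supports. It is not a toolkit version that the container has to match. A container built on an older CUDA runtime (11.4.2 in the test above) generally runs on a newer driver (12.0 here). The hypothetical helper below parses that header and performs this check. It only rules out a driver/runtime mismatch and does not diagnose the `RET_CHECK` failure itself.

```python
import re

def driver_max_cuda(nvidia_smi_text):
    """Extract 'CUDA Version: X.Y' from nvidia-smi output as (major, minor)."""
    m = re.search(r'CUDA Version:\s*(\d+)\.(\d+)', nvidia_smi_text)
    if m is None:
        raise ValueError('no CUDA version found in nvidia-smi output')
    return int(m.group(1)), int(m.group(2))

def runtime_supported(runtime, nvidia_smi_text):
    """True if a CUDA runtime 'X.Y[.Z]' is not newer than the driver supports."""
    major, minor = (int(p) for p in runtime.split('.')[:2])
    return (major, minor) <= driver_max_cuda(nvidia_smi_text)

header = '| NVIDIA-SMI 525.78.01    Driver Version: 525.78.01    CUDA Version: 12.0     |'
print(driver_max_cuda(header))              # (12, 0)
print(runtime_supported('11.4.2', header))  # True
```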

Questions about train_ratio and optimization procedure

I noticed that driver.on_step(train_step) registers train_step as a callback that is triggered each time the agent interacts with the real environment. If I understood correctly, then after collecting a frame from the real environment, the agent does the following train_ratio times (say, 64 in the example):

  • Sample 16*64 obs, actions, etc. from the replay buffer
  • Train the world model on this data, and use the 16*64 obs/actions as starting context for imagination
  • Train the actor-critic by imagining trajectories that start from that context

My question is:

  • Is it necessary to train the world model repeatedly on the same data? As far as I know, this may lead to converging to a local optimum.

The code is quite complicated, and I'm not sure whether some mechanism leads to different behavior than my description above.
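The `embodied.Config(...)` calls elsewhere on this page set `batch_steps=config.batch_size * config.batch_length`. Below is a minimal sketch of the bookkeeping, assuming `train_ratio` counts replayed steps per environment step. Under that reading, training runs on average `train_ratio / batch_steps` times per env step, not `train_ratio` times. `RatioScheduler` is a hypothetical helper, not the library's own scheduler.

```python
class RatioScheduler:
    """Hypothetical: how many train calls to run at each env step.

    `ratio` is train calls per env step; fractional credit carries over.
    """

    def __init__(self, ratio):
        self.ratio = ratio
        self.credit = 0.0

    def __call__(self):
        self.credit += self.ratio
        calls = int(self.credit)
        self.credit -= calls
        return calls

train_ratio, batch_size, batch_length = 64, 16, 64
batch_steps = batch_size * batch_length                   # 1024 replayed steps per update
should_train = RatioScheduler(train_ratio / batch_steps)  # 0.0625 updates per env step

env_steps = 1600
updates = sum(should_train() for _ in range(env_steps))
print(updates)                            # 100 (one update every 16 env steps)
print(updates * batch_steps / env_steps)  # 64.0 replayed steps per env step
```

If this reading is right, each sampled step is reused about `train_ratio` times on average over training. That would be far fewer repeats than 64 full updates after every frame.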

BrokenPipeError using Parallel class

I am attempting to use the Parallel class to run Gym-like environments in parallel.
Training runs successfully with fewer than 5 instances but crashes with a BrokenPipeError mid-training when num_instances > ~10.

Do you have any ideas what would cause this type of error?

===========================ENV WRAPPING CODE (after gym.make)=======================
from functools import partial

import dreamerv3
from dreamerv3 import embodied
from dreamerv3.embodied.envs import from_gym
# `wrappers` and `ComputeManager` are assumed to be imported elsewhere in the project.

def make_batched_parallel_env(envs, manager: ComputeManager, config):
    # Wrap each raw gym env so it exposes the embodied interface.
    envs = map(lambda e: from_gym.FromGym(e, obs_key='vector'), envs)
    envs = map(lambda e: dreamerv3.wrap_env(e, config), envs)

    if config.envs.parallel != 'none':
        # Turn each env into a constructor that runs it in a parallel worker.
        envs = map(lambda e: partial(embodied.Parallel, lambda: e, config.envs.parallel), envs)
    else:
        raise ValueError("Invalid parallel type for config.envs.parallel")
    if config.envs.restart:
        envs = map(lambda e: partial(wrappers.RestartOnException, e), envs)

    envs = [e() for e in envs]

    return embodied.BatchEnv(envs, parallel=config.envs.parallel != 'none')

def prepare_for_dreamer(envs, manager: ComputeManager, config: embodied.Config, parallel: bool):
    env = make_batched_parallel_env(envs, manager, config)
    step = embodied.Counter()
    logger = embodied.Logger(step, [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(config.logdir + '/metrics', 'metrics.jsonl'),
        embodied.logger.TensorBoardOutput(config.logdir + '/tensorboard'),
    ])

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)

    replay = embodied.replay.Uniform(
        config.batch_length, config.replay_size, config.logdir + '/replay'
    )

    args = embodied.Config(
        **config.run,
        logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length,
    )

    return agent, env, replay, logger, args

=======================================ERROR MESSAGE================
Error inside process worker: Traceback (most recent call last):
File "/home/joeag/Documents/Group-Project/jax-venv/lib/python3.8/site-packages/dreamerv3/embodied/core/worker.py", line 195, in _loop
message, callid, payload = pipe.recv()
File "/usr/lib/python3.8/multiprocessing/connection.py", line 250, in recv
buf = self._recv_bytes()
File "/usr/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
buf = self._recv(4)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer

The error occurred while tracing the function init at ninjax.py

Running the demo in Docker:

sh dreamerv3/embodied/scripts/xvfb_run.sh python3 dreamerv3/train.py   --configs dmc_vision --task dmc_walker_walk 
│ /dreamerv3/dreamerv3/ninjax.py:245 in scan                                                       │
│                                                                                                  │
│   242 @jax.named_scope('scan')                                                                   │
│   243 def scan(fun, carry, xs, reverse=False, unroll=1, modify=False):                           │
│   244   fun = pure(fun, nested=True)                                                             │
│ ❱ 245   _prerun(fun, carry, jax.tree_util.tree_map(lambda x: x[0], xs))                          │
│   246   length = len(jax.tree_util.tree_leaves(xs)[0])                                           │
│   247   rngs = rng(length)                                                                       │
│   248   if modify:                                                                               │
│                                                                                                  │
│ /usr/lib/python3.8/contextlib.py:75 in inner                                                     │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /dreamerv3/dreamerv3/ninjax.py:272 in _prerun                                                    │
│                                                                                                  │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│ ❱ 272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│                                                                                                  │
│ /dreamerv3/dreamerv3/ninjax.py:273 in <lambda>                                                   │
│                                                                                                  │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│ ❱ 273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│   276                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float16[16,4096])>with<DynamicJaxprTrace(level=1/0)>
The delete() method was called on the JAX Tracer object Traced<ShapedArray(float16[16,4096])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function init at /dreamerv3/dreamerv3/ninjax.py:163 for jit. This concrete value was not available in Python because it depends on the values of the arguments 'statics', 
'rng', and 'args'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

BSuite Implementation

Hi,

It seems like there is no implementation in the embodied package for running the bsuite experiments. The configuration exists in configs.yaml, but executing it fails because bsuite is not listed in the env dict in the training script.

I suppose this only requires creating a BSuite class in embodied/envs analogous to the DMC environment. I will try to make a PR for this.
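bsuite environments follow the `dm_env` interface, where each step returns a `TimeStep(step_type, reward, discount, observation)`. Below is a sketch of the core conversion such a class would need. It assumes the embodied convention of returning a flat dict with `reward`, `is_first`, `is_last`, and `is_terminal` keys. The `is_last` key appears in the driver code quoted elsewhere on this page; the other keys are assumptions. `TimeStep` and `StepType` below are stand-ins so the snippet runs without bsuite or dm_env installed, and `convert_timestep` is hypothetical.

```python
import collections
import enum

import numpy as np

class StepType(enum.IntEnum):
    # Mirrors dm_env.StepType.
    FIRST = 0
    MID = 1
    LAST = 2

TimeStep = collections.namedtuple(
    'TimeStep', ['step_type', 'reward', 'discount', 'observation'])

def convert_timestep(ts, obs_key='vector'):
    """Flatten a dm_env-style TimeStep into an embodied-style observation dict."""
    is_last = ts.step_type == StepType.LAST
    return {
        obs_key: np.asarray(ts.observation, dtype=np.float32),
        # dm_env reports reward=None on the first step of an episode.
        'reward': np.float32(0.0 if ts.reward is None else ts.reward),
        'is_first': ts.step_type == StepType.FIRST,
        'is_last': is_last,
        # Zero discount at the last step marks a true terminal state;
        # a nonzero discount means the episode was truncated.
        'is_terminal': bool(is_last and ts.discount == 0.0),
    }

first = convert_timestep(TimeStep(StepType.FIRST, None, None, [0.0, 1.0]))
last = convert_timestep(TimeStep(StepType.LAST, 1.0, 0.0, [1.0, 0.0]))
print(first['is_first'], last['is_terminal'])  # True True
```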

Difference between model_loss and model_opt_loss

Hey,

I was just going through the logs for one of the runs generated with DreamerV3 and noticed that the code logs two different model losses: (1) model_loss and (2) model_opt_loss. I found that model_opt_loss is generally much higher than model_loss for the Crafter environment. Could you explain the difference between the two? If you could point me to the equations in the paper they refer to, that would be great. Thanks!
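The training log excerpt earlier on this page (from a different run and environment) contains a relation that may help. Treat it as a hint rather than a confirmed explanation. That log reports `model_loss_mean 0.97`, `model_opt_loss 9693.77`, and `model_opt_model_opt_grad_scale 1e4`. Dividing the optimizer loss by the gradient scale recovers the model loss. This is consistent with `model_opt_loss` being logged after multiplication by the loss scale used for mixed-precision training. If so, the two values differ only by the current scale factor, which can change during a run.

```python
import math

# Values copied from the training log excerpt earlier on this page.
model_loss_mean = 0.97
model_opt_loss = 9693.77
model_opt_grad_scale = 1e4

unscaled = model_opt_loss / model_opt_grad_scale
print(round(unscaled, 2))                                      # 0.97
print(math.isclose(unscaled, model_loss_mean, rel_tol=1e-2))   # True
```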

Flag parsing fails in Colab

Hi Danijar,

Excellent work! I'm trying to run example.py and get stuck on the config file.

It reads the config file fine until the following line, then it breaks:
disag_models: 8 (int)

│ ❱ 56 main() │
....
│ ❱ 25 config = embodied.Flags(config).parse() │
│ 26 │
│ 27 logdir = embodied.Path(config.logdir) │
│ 28 step = embodied.Counter() │
│ │
│ /store/.local/lib/python3.9/site-packages/dreamerv3/embodied/core/flags.py:17 in parse │
│ │
│ 14 │ for flag in remaining: │
│ 15 │ if flag.startswith('--'): │
│ 16 │ │ raise ValueError(f"Flag '{flag}' did not match any config keys.") │
│ ❱ 17 │ assert not remaining, remaining
│ 18 │ return parsed │
│ 19 │
│ 20 def parse_known(self, argv=None, help_exists=False): │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError: ['-f', '/store/.local/share/jupyter/runtime/kernel-00a1b6b5-8b7b-4902-91bc-f2a9d86e76db.json']
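The leftover `['-f', '.../kernel-....json']` is the connection-file argument that Jupyter passes to its kernel. It ends up in `sys.argv`, and the flag parser then rejects it. A minimal workaround sketch is to strip it before parsing. `strip_jupyter_args` is a hypothetical helper. Assigning the cleaned list back to `sys.argv` avoids assuming anything about the parser's signature.

```python
import sys

def strip_jupyter_args(argv):
    """Remove the '-f <connection-file>' pair that Jupyter/Colab add to sys.argv."""
    cleaned = []
    skip_next = False
    for arg in argv:
        if skip_next:
            skip_next = False  # this is the kernel connection file path
            continue
        if arg == '-f':
            skip_next = True
            continue
        cleaned.append(arg)
    return cleaned

# In a notebook cell, before calling embodied.Flags(config).parse():
sys.argv = sys.argv[:1] + strip_jupyter_args(sys.argv[1:])
```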

Custom encoder and decoder, error with policy initialization.

Hi there, thank you for releasing the code. I'm trying to replace the encoder and decoder with custom-written ones. However, there's one error whose cause I could not figure out. This line seems to be the source, but agent.policy can correctly return

out, updated = apply(statics, selected, rng, *args, **kw)

During the policy initialization call, I did not call or use any state variables, but I got the following error:

│ /home/maksim/Project/dreamerv3/example.py:63 in <module>                                         │
│                                                                                                  │
│   60                                                                                             │
│   61                                                                                             │
│   62 if __name__ == '__main__':                                                                  │
│ ❱ 63   main()                                                                                    │
│   64                                                                                             │
│                                                                                                  │
│ /home/maksim/Project/dreamerv3/example.py:58 in main                                             │
│                                                                                                  │
│   55   args = embodied.Config(                                                                   │
│   56 │     **config.run, logdir=config.logdir,                                                   │
│   57 │     batch_steps=config.batch_size * config.batch_length)                                  │
│ ❱ 58   embodied.run.train(agent, env, replay, logger, args)                                      │
│   59   # embodied.run.eval_only(agent, env, logger, args)                                        │
│   60                                                                                             │
│   61                                                                                             │
│                                                                                                  │
│ /home/maksim/Project/dreamerv3/dreamerv3/embodied/run/train.py:110 in train                      │
│                                                                                                  │
│   107   policy = lambda *args: agent.policy(                                                     │
│   108 │     *args, mode='explore' if should_expl(step) else 'train')                             │
│   109   while step < args.steps:                                                                 │
│ ❱ 110 │   driver(policy, steps=100)                                                              │
│   111 │   if should_save(step):                                                                  │
│   112 │     checkpoint.save()                                                                    │
│   113   logger.write()                                                                           │
│                                                                                                  │
│ /home/maksim/Project/dreamerv3/dreamerv3/embodied/core/driver.py:42 in __call__                  │
│                                                                                                  │
│   39   def __call__(self, policy, steps=0, episodes=0):                                          │
│   40 │   step, episode = 0, 0                                                                    │
│   41 │   while step < steps or episode < episodes:                                               │
│ ❱ 42 │     step, episode = self._step(policy, step, episode)                                     │
│   43                                                                                             │
│   44   def _step(self, policy, step, episode):                                                   │
│   45 │   assert all(len(x) == len(self._env) for x in self._acts.values())                       │
│                                                                                                  │
│ /home/maksim/Project/dreamerv3/dreamerv3/embodied/core/driver.py:50 in _step                     │
│                                                                                                  │
│   47 │   obs = self._env.step(acts)                                                              │
│   48 │   obs = {k: convert(v) for k, v in obs.items()}                                           │
│   49 │   assert all(len(x) == len(self._env) for x in obs.values()), obs                         │
│ ❱ 50 │   acts, self._state = policy(obs, self._state, **self._kwargs)                            │
│   51 │   acts = {k: convert(v) for k, v in acts.items()}                                         │
│   52 │   if obs['is_last'].any():                                                                │
│   53 │     mask = 1 - obs['is_last']                                                             │
│                                                                                                  │
│ /home/maksim/Project/dreamerv3/dreamerv3/embodied/run/train.py:107 in <lambda>                   │
│                                                                                                  │
│   104                                                                                            │
│   105   print('Start training loop.')                                                            │
│   106   print(args)                                                                              │
│ ❱ 107   policy = lambda *args: agent.policy(                                                     │
│   108 │     *args, mode='explore' if should_expl(step) else 'train')                             │
│   109   while step < args.steps:                                                                 │
│   110 │   driver(policy, steps=100)                                                              │
│                                                                                                  │
│ /usr/lib/python3.8/contextlib.py:75 in inner                                                     │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/maksim/Project/dreamerv3/dreamerv3/jaxagent.py:63 in policy                                │
│                                                                                                  │
│    60 │   │     np.asarray, state, is_leaf=lambda x: isinstance(x, list))                        │
│    61 │     state = self._convert_inps(state, self.policy_devices)                               │
│    62 │   print('call self.policy')                                                              │
│ ❱  63 │   (outs, state), _ = self._policy(varibs, rng, obs, state, mode=mode)                    │
│    64 │   print('end of calling self.policy')                                                    │
│    65 │   outs = self._convert_outs(outs, self.policy_devices)                                   │
│    66 │   # TODO: Consider keeping policy states in accelerator memory.                          │
│                                                                                                  │
│ /home/maksim/Project/dreamerv3/dreamerv3/ninjax.py:227 in wrapper                                │
│                                                                                                  │
│   224 │     print(selected.keys(), kw.keys())                                                    │
│   225 │     # apply(statics, selected, rng, *args, **kw)                                         │
│   226 │     # print('applied')                                                                   │
│ ❱ 227 │     out, updated = apply(statics, selected, rng, *args, **kw)                            │
│   228 │     return out, {**state, **updated}                                                     │
│   229   return wrapper                                                                           │
│   230                                                                                            │
│                                                                                                  │
│ /home/maksim/.local/lib/python3.8/site-packages/jax/_src/array.py:631 in                         │
│ _array_mlir_constant_handler                                                                     │
│                                                                                                  │
│   628                                                                                            │
│   629                                                                                            │
│   630 def _array_mlir_constant_handler(val, canonicalize_types=True):                            │
│ ❱ 631   return mlir.ir_constants(val._value,                                                     │
│   632 │   │   │   │   │   │      canonicalize_types=canonicalize_types)                          │
│   633 mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)                    │
│   634                                                                                            │
│                                                                                                  │
│ /home/maksim/.local/lib/python3.8/site-packages/jax/_src/array.py:496 in _value                  │
│                                                                                                  │
│   493 │                                                                                          │
│   494 │   if self._npy_value is None:                                                            │
│   495 │     if self.is_fully_replicated:                                                         │
│ ❱ 496 │   │   self._npy_value = np.asarray(self._arrays[0])  # type: ignore                      │
│   497 │   │   self._npy_value.flags.writeable = False                                            │
│   498 │   │   return cast(np.ndarray, self._npy_value)                                           │
│   499                                                                                            │

XlaRuntimeError: INVALID_ARGUMENT: Disallowed device-to-host transfer: shape=(3, 3), dtype=float32, device=gpu:0

I did not modify the RSSM code in nets.py. I'm a complete beginner with JAX, and this error is really strange to me.

distributed training

Hi!

I am trying to run training across 2 GPUs using multi-GPU distributed training. Does this codebase support it? If so, how can I do it? Thank you so much!

Bests,
Cristian

[Question] Is there a reason that all linear layers except output layers do not use bias term?

Hi Danijar,

I was wondering whether it's intentional that none of the intermediate layers use bias terms, only the output layers. Additionally, all output layers use uniform initialisation; is this also intended? The following is not so important, but in case you are interested, here are the config parameters each layer is using:

{'_path': 'agent/wm/rssm/img_out', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/rssm/img_stats', '_submodules': {}, '_units': (1024,), '_act': <function get_act.<locals>.<lambda> at 0x7f783823f820>, '_norm': 'none', '_bias': True, '_outscale': 1.0, '_outnorm': False, '_winit': 'uniform', '_fan': 'avg'}
Tracing train function.
{'_path': 'agent/wm/enc/mlp/h0', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/enc/mlp/h1', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/enc/mlp/h2', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/enc/mlp/h3', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/enc/mlp/h4', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/rssm/img_in', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/rssm/gru', '_submodules': {}, '_units': (3072,), '_act': <function get_act.<locals>.<lambda> at 0x7f77f1fbc9d0>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/rssm/obs_out', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/rssm/obs_stats', '_submodules': {}, '_units': (1024,), '_act': <function get_act.<locals>.<lambda> at 0x7f77ec723280>, '_norm': 'none', '_bias': True, '_outscale': 1.0, '_outnorm': False, '_winit': 'uniform', '_fan': 'avg'}
{'_path': 'agent/wm/dec/mlp/h0', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/dec/mlp/h1', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/dec/mlp/h2', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/dec/mlp/h3', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/dec/mlp/h4', '_submodules': {}, '_units': (1024,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/dec/mlp/dist_vector/out', '_submodules': {}, '_units': (4,), '_act': <function get_act.<locals>.<lambda> at 0x7f77ec723ca0>, '_norm': 'none', '_bias': True, '_outscale': 1.0, '_outnorm': False, '_winit': 'uniform', '_fan': 'avg'}
{'_path': 'agent/wm/rew/h0', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/rew/h1', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/rew/h2', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/rew/dist_out/out', '_submodules': {}, '_units': (255,), '_act': <function get_act.<locals>.<lambda> at 0x7f77ec5e4790>, '_norm': 'none', '_bias': True, '_outscale': 0.0, '_outnorm': False, '_winit': 'uniform', '_fan': 'avg'}
{'_path': 'agent/wm/cont/h0', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/cont/h1', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/cont/h2', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/wm/cont/dist_out/out', '_submodules': {}, '_units': (1,), '_act': <function get_act.<locals>.<lambda> at 0x7f77ec6faee0>, '_norm': 'none', '_bias': True, '_outscale': 1.0, '_outnorm': False, '_winit': 'uniform', '_fan': 'avg'}
Optimizer model_opt has 24,004,356 variables.
{'_path': 'agent/task_behavior/ac/actor/h0', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/ac/actor/h1', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/ac/actor/h2', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/ac/actor/dist_out/out', '_submodules': {}, '_units': (2,), '_act': <function get_act.<locals>.<lambda> at 0x7f77e40c61f0>, '_norm': 'none', '_bias': True, '_outscale': 1.0, '_outnorm': False, '_winit': 'uniform', '_fan': 'avg'}
{'_path': 'agent/task_behavior/critic/net/h0', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/critic/net/h1', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/critic/net/h2', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/critic/net/dist_out/out', '_submodules': {}, '_units': (255,), '_act': <function get_act.<locals>.<lambda> at 0x7f77c8694550>, '_norm': 'none', '_bias': True, '_outscale': 0.0, '_outnorm': False, '_winit': 'uniform', '_fan': 'avg'}
Optimizer actor_opt has 2,135,042 variables.
{'_path': 'agent/task_behavior/critic/slow/h0', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/critic/slow/h1', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/critic/slow/h2', '_submodules': {}, '_units': (640,), '_act': <PjitFunction of <function silu at 0x7f822c97bc10>>, '_norm': 'layer', '_bias': False, '_outscale': 1.0, '_outnorm': False, '_winit': 'normal', '_fan': 'avg'}
{'_path': 'agent/task_behavior/critic/slow/dist_out/out', '_submodules': {}, '_units': (255,), '_act': <function get_act.<locals>.<lambda> at 0x7f77b85914c0>, '_norm': 'none', '_bias': True, '_outscale': 0.0, '_outnorm': False, '_winit': 'uniform', '_fan': 'avg'}
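One property relevant to this question, offered as a hedged sketch and not as the author's stated rationale: in the dump above, every layer with `'_bias': False` also has `'_norm': 'layer'`, while the output layers with `'_bias': True` have `'_norm': 'none'`. Layer normalization subtracts the mean across features, so any constant added equally to all pre-norm activations cancels out, and a learnable per-feature offset applied after normalization (assumed here, as in standard layer norm) can take over the role of a bias. Note that only the part of a bias that is equal across features cancels; a general per-feature bias does not. The pure-Python `layer_norm` below is a hypothetical stand-in, not the repository's implementation, and its `eps` value is illustrative.

```python
import math


def layer_norm(x, scale=None, offset=None, eps=1e-5):
    """Normalize activations across features, then apply a per-feature scale and offset."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    normed = [(v - mean) / math.sqrt(var + eps) for v in x]
    if scale is None:
        scale = [1.0] * n
    if offset is None:
        offset = [0.0] * n
    return [s * v + o for v, s, o in zip(normed, scale, offset)]


x = [0.5, -1.2, 3.0, 0.7]
shifted = [v + 10.0 for v in x]  # the same constant added to every feature

a = layer_norm(x)
b = layer_norm(shifted)
# Mean subtraction removes the shared constant, so a and b agree up to rounding.
print(max(abs(u - v) for u, v in zip(a, b)))

# A per-feature offset applied after normalization still shifts the output.
c = layer_norm(x, offset=[0.1, 0.2, 0.3, 0.4])
print([round(v - u, 6) for u, v in zip(a, c)])  # → [0.1, 0.2, 0.3, 0.4]
```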

Running out of RAM

I am trying to run the task dmc_manip_reach_site, but the run terminates early:

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Step 1372562 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
episode/length 250 / episode/score 2.82 / episode/sum_abs_reward 2.82 / episode/reward_rate 0 / train/action_mag 5.06 / train/action_max 4.85 / train/action_mean -0.15 / train/action_min -4.92 / train/action_std 1.17 / train/actor_opt_actor_opt_grad_overflow 0 / 
train/actor_opt_actor_opt_grad_scale 1e4 / train/actor_opt_grad_norm 0.07 / train/actor_opt_grad_steps 2e4 / train/actor_opt_loss -86.24 / train/adv_mag 1.21 / train/adv_max 1.2 / train/adv_mean 6e-3 / train/adv_min -0.58 / train/adv_std 0.1 / train/cont_avg 1 / 
train/cont_loss_mean 2.4e-9 / train/cont_loss_std 2e-9 / train/cont_neg_acc nan / train/cont_neg_loss nan / train/cont_pos_acc 1 / train/cont_pos_loss 2.4e-9 / train/cont_pred 1 / train/cont_rate 1 / train/dyn_loss_mean 14.81 / train/dyn_loss_std 6.95 / 
train/extr_critic_critic_opt_critic_opt_grad_overflow 0 / train/extr_critic_critic_opt_critic_opt_grad_scale 1e4 / train/extr_critic_critic_opt_grad_norm 1.02 / train/extr_critic_critic_opt_grad_steps 2e4 / train/extr_critic_critic_opt_loss 1.3e4 / train/extr_critic_mag 
5.71 / train/extr_critic_max 5.71 / train/extr_critic_mean 2.75 / train/extr_critic_min 1.16 / train/extr_critic_std 1.02 / train/extr_return_normed_mag 2.1 / train/extr_return_normed_max 2.1 / train/extr_return_normed_mean 0.4 / train/extr_return_normed_min -0.07 / 
train/extr_return_normed_std 0.32 / train/extr_return_rate 1 / train/extr_return_raw_mag 8.65 / train/extr_return_raw_max 8.65 / train/extr_return_raw_mean 2.77 / train/extr_return_raw_min 1.16 / train/extr_return_raw_std 1.11 / train/extr_reward_mag 0.66 / 
train/extr_reward_max 0.66 / train/extr_reward_mean 7.8e-3 / train/extr_reward_min 3.3e-7 / train/extr_reward_std 0.04 / train/image_loss_mean 22.79 / train/image_loss_std 15.72 / train/model_loss_mean 31.74 / train/model_loss_std 18.99 / train/model_opt_grad_norm 34.37 / 
train/model_opt_grad_steps 2e4 / train/model_opt_loss 8.6e4 / train/model_opt_model_opt_grad_overflow 0 / train/model_opt_model_opt_grad_scale 2713.41 / train/policy_entropy_mag 12.59 / train/policy_entropy_max 12.59 / train/policy_entropy_mean 9.17 / 
train/policy_entropy_min -1.27 / train/policy_entropy_std 2.6 / train/policy_logprob_mag 24.24 / train/policy_logprob_max 3.47 / train/policy_logprob_mean -9.17 / train/policy_logprob_min -24.24 / train/policy_logprob_std 3.36 / train/policy_randomness_mag 0.99 / 
train/policy_randomness_max 0.99 / train/policy_randomness_mean 0.83 / train/policy_randomness_min 0.32 / train/policy_randomness_std 0.13 / train/post_ent_mag 50.2 / train/post_ent_max 50.2 / train/post_ent_mean 37.2 / train/post_ent_min 18.75 / train/post_ent_std 6.65 / 
train/prior_ent_mag 54.55 / train/prior_ent_max 54.55 / train/prior_ent_mean 52.08 / train/prior_ent_min 49.31 / train/prior_ent_std 0.76 / train/rep_loss_mean 14.81 / train/rep_loss_std 6.95 / train/reward_avg 9.2e-3 / train/reward_loss_mean 0.06 / train/reward_loss_std 
0.3 / train/reward_max_data 0.55 / train/reward_max_pred 0.43 / train/reward_neg_acc 0.99 / train/reward_neg_loss 0.02 / train/reward_pos_acc 0.55 / train/reward_pos_loss 2.45 / train/reward_pred 8.6e-3 / train/reward_rate 0.02 / stats/mean_log_entropy 8.56 / replay/size 
6.9e5 / replay/inserts 7.9e4 / replay/samples 3.9e4 / replay/insert_wait_avg 6.6e-7 / replay/insert_wait_frac 1 / replay/sample_wait_avg 6.8e-7 / replay/sample_wait_frac 1 / timer/duration 343.21 / timer/env.step_count 2e4 / timer/env.step_total 73.98 / timer/env.step_frac 
0.22 / timer/env.step_avg 3.8e-3 / timer/env.step_min 2.7e-3 / timer/env.step_max 1.22 / timer/replay.add_count 7.9e4 / timer/replay.add_total 3.77 / timer/replay.add_frac 0.01 / timer/replay.add_avg 4.8e-5 / timer/replay.add_min 1.4e-5 / timer/replay.add_max 0.03 / 
timer/logger.write_count 1 / timer/logger.write_total 7.8e-3 / timer/logger.write_frac 2.3e-5 / timer/logger.write_avg 7.8e-3 / timer/logger.write_min 7.8e-3 / timer/logger.write_max 7.8e-3 / timer/checkpoint.save_count 0 / timer/checkpoint.save_total 0 / 
timer/checkpoint.save_frac 0 / timer/agent.save_count 0 / timer/agent.save_total 0 / timer/agent.save_frac 0 / timer/replay.save_count 0 / timer/replay.save_total 0 / timer/replay.save_frac 0 / timer/agent.policy_count 2e4 / timer/agent.policy_total 38.96 / 
timer/agent.policy_frac 0.11 / timer/agent.policy_avg 2e-3 / timer/agent.policy_min 9.1e-4 / timer/agent.policy_max 6.22 / timer/dataset_count 2464 / timer/dataset_total 0.14 / timer/dataset_frac 4e-4 / timer/dataset_avg 5.5e-5 / timer/dataset_min 4.6e-5 / timer/dataset_max
2.3e-3 / timer/agent.train_count 2464 / timer/agent.train_total 185.67 / timer/agent.train_frac 0.54 / timer/agent.train_avg 0.08 / timer/agent.train_min 0.07 / timer/agent.train_max 1.6 / timer/agent.report_count 1 / timer/agent.report_total 35.25 / timer/agent.report_frac
0.1 / timer/agent.report_avg 35.25 / timer/agent.report_min 35.25 / timer/agent.report_max 35.25 / fps 452.33

fish: Job 1, 'python dreamerv3/train.py \…' terminated by signal -… (SIGKILL)
fish: Job Forced quit, '' terminated by signal  ()

I tried a second time, but it always stops at around step 1.3e6, even though the maximum number of steps is set to 1e10.
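A SIGKILL with no Python traceback is often the operating system's out-of-memory killer, and one likely consumer is the replay buffer, which keeps observations in RAM until it reaches its capacity (`config.replay_size`, the value passed to `embodied.replay.Uniform` in the example scripts). The sketch below is a rough, hedged estimate of that footprint. It assumes 64×64×3 uint8 image observations stored once per step without compression and ignores the smaller per-step fields (actions, rewards, flags). The function name is hypothetical and the capacities in the loop are illustrative, not the defaults of any particular config.

```python
def replay_image_bytes(num_steps, height=64, width=64, channels=3, bytes_per_value=1):
    """Approximate bytes used by image observations stored in a replay buffer."""
    return num_steps * height * width * channels * bytes_per_value


# replay/size from the log above (6.9e5 stored steps)
print(f"{replay_image_bytes(690_000) / 1e9:.1f} GB")  # → 8.5 GB

# Hypothetical capacities, to see how the upper bound scales
for capacity in (1_000_000, 100_000):
    print(capacity, f"{replay_image_bytes(capacity) / 1e9:.1f} GB")
```

If the estimate approaches the machine's RAM, lowering `replay_size` bounds the buffer; passing it as `--replay_size` on the command line is an assumption based on the flag parser matching `--` flags to config keys.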

Error when training on Pong in Colab

I am trying this in Google Colab on a CPU instance, so please let me know if this is not possible.

I installed DreamerV3 using:

!pip install dreamerv3

Then I copied the example.py file with a couple of changes:

import warnings
import dreamerv3
from dreamerv3 import embodied
warnings.filterwarnings('ignore', '.*truncated to dtype int32.*')

# See configs.yaml for all options.
config = embodied.Config(dreamerv3.configs['defaults'])
config = config.update(dreamerv3.configs['small'])
config = config.update({
    'logdir': '~/logdir/run3',
    'run.train_ratio': 64,
    'run.eval_eps': 10,
    'run.log_every': 30,  # Seconds
    'batch_size': 16,
    'jax.prealloc': False,
    'encoder.mlp_keys': '$^',
    'decoder.mlp_keys': '$^',
    'encoder.cnn_keys': 'image',
    'decoder.cnn_keys': 'image',
    'jax.platform': 'cpu',
})
config = embodied.Flags(config).parse()

logdir = embodied.Path(config.logdir)
step = embodied.Counter()
logger = embodied.Logger(step, [
    embodied.logger.TerminalOutput(),
    embodied.logger.JSONLOutput(logdir, 'metrics.jsonl'),
    embodied.logger.TensorBoardOutput(logdir),
    # embodied.logger.WandBOutput(logdir.name, config),
    # embodied.logger.MLFlowOutput(logdir.name),
])

import gym
from embodied.envs import from_gym
from gym.wrappers import ResizeObservation
env = gym.make("PongNoFrameskip-v4")  # Replace this with your Gym env.
env = ResizeObservation(env, 128)
env = from_gym.FromGym(env)
env = dreamerv3.wrap_env(env, config)
env = embodied.BatchEnv([env], parallel=False)

agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
replay = embodied.replay.Uniform(
    config.batch_length, config.replay_size, logdir / 'replay')
args = embodied.Config(
    **config.run, logdir=config.logdir,
    batch_steps=config.batch_size * config.batch_length)
embodied.run.train(agent, env, replay, logger, args)

The error I receive seems to be linked to the line:

config = embodied.Flags(config).parse()

The error is:

Traceback (most recent call last)
in /usr/local/lib/python3.8/dist-packages/dreamerv3/embodied/core/flags.py:17 in parse
  for flag in remaining:
    if flag.startswith('--'):
      raise ValueError(f"Flag '{flag}' did not match any config keys.")
  assert not remaining, remaining
  return parsed
def parse_known(self, argv=None, help_exists=False):

AssertionError: ['-f', '/root/.local/share/jupyter/runtime/kernel-0d0baf68-d3f6-463e-b3c8-d0320085b92e.json']
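The list in this AssertionError is the kernel's own command line: Jupyter starts the kernel with `-f` followed by its connection file, and `embodied.Flags(config).parse()` appears to read `sys.argv` when given no arguments, so it sees these kernel arguments and rejects them. A hedged workaround sketch: filter the kernel arguments out first. The helper `strip_kernel_args` is hypothetical, and passing the result to `parse` assumes it accepts an `argv` list the way `parse_known(self, argv=None, ...)` in the traceback does; in a notebook, passing an empty list would serve the same purpose.

```python
import sys


def strip_kernel_args(argv):
    """Drop the '-f <connection-file>' pair that Jupyter/IPython kernels receive."""
    cleaned, skip_next = [], False
    for arg in argv:
        if skip_next:
            skip_next = False  # this item is the kernel's connection file
            continue
        if arg == '-f':
            skip_next = True
            continue
        cleaned.append(arg)
    return cleaned


argv = strip_kernel_args(sys.argv[1:])
# Assumption: parse() forwards argv like parse_known() does, e.g.
# config = embodied.Flags(config).parse(argv)
```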

I then tried commenting out that line, but received the following error:

Traceback (most recent call last)
in /usr/local/lib/python3.8/dist-packages/dreamerv3/embodied/run/train.py:106 in train

policy = lambda *args: agent.policy(*args, mode='explore' if should_expl(step) else 'train')
while step < args.steps:
  driver(policy, steps=100)
  if should_save(step):
    checkpoint.save()
logger.write()

/usr/local/lib/python3.8/dist-packages/dreamerv3/embodied/core/driver.py:42 in __call__

def __call__(self, policy, steps=0, episodes=0):
  step, episode = 0, 0
  while step < steps or episode < episodes:
    step, episode = self._step(policy, step, episode)
def _step(self, policy, step, episode):
  assert all(len(x) == len(self._env) for x in self._acts.values())

/usr/local/lib/python3.8/dist-packages/dreamerv3/embodied/core/driver.py:65 in _step

for i in range(len(self._env)):
  trn = {k: v[i] for k, v in trns.items()}
  [self._eps[i][k].append(v) for k, v in trn.items()]
  [fn(trn, i, **self._kwargs) for fn in self._on_steps]
  step += 1
if obs['is_last'].any():
  for i, done in enumerate(obs['is_last']):

/usr/local/lib/python3.8/dist-packages/dreamerv3/embodied/core/driver.py:65 in <listcomp>

for i in range(len(self._env)):
  trn = {k: v[i] for k, v in trns.items()}
  [self._eps[i][k].append(v) for k, v in trn.items()]
  [fn(trn, i, **self._kwargs) for fn in self._on_steps]
  step += 1
if obs['is_last'].any():
  for i, done in enumerate(obs['is_last']):

/usr/local/lib/python3.8/dist-packages/dreamerv3/embodied/run/train.py:75 in train_step
for _ in range(should_train(step)):
  with timer.scope('dataset'):
    batch[0] = next(dataset)
  outs, state[0], mets = agent.train(batch[0], state[0])
  metrics.add(mets, prefix='train')
  if 'priority' in outs:
    replay.prioritize(outs['key'], outs['priority'])

/usr/lib/python3.8/contextlib.py:75 in inner

@wraps(func)
def inner(*args, **kwds):
    with self._recreate_cm():
        return func(*args, **kwds)
return inner

/usr/local/lib/python3.8/dist-packages/dreamerv3/jaxagent.py:76 in train

  self.once = False
  assert jaxutils.Optimizer.PARAM_COUNTS
  for name, count in jaxutils.Optimizer.PARAM_COUNTS.items():
    mets[f'params_{name}'] = float(count)
return outs, state, mets

def report(self, data):

TypeError: float() argument must be a string or a number, not 'NoneType'

Hopefully I have provided enough information. I'll be happy to provide more if needed.

Thanks,

Load trained weights into agent and get predicted actions

Hello, thanks for sharing this amazing piece of work!

Is there an easy way to load the trained weights from checkpoint.pkl into an agent and get the predicted action from it (agent.policy(obs, state, mode='eval')['action'])? The idea would be to visualize it online, for instance in a standard pygame loop.

Looking at the code, I guess the easiest way would be to use the dreamerv3.Agent class, but I don't understand how to load the weights from the pickle file 😅
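A first step, sketched with the standard library only: open the pickle and see what it contains before wiring it into an agent. This assumes `checkpoint.pkl` is a pickled dictionary keyed by component name; that layout is an assumption, so inspect the returned keys rather than relying on it. The helper `describe_checkpoint` and the path in the comment are hypothetical. For an end-to-end evaluation loop, the `embodied.run.eval_only(agent, env, logger, args)` entry point, which appears (commented out) in another issue on this page, may be a better starting point than hand-rolled loading.

```python
import pickle
from pathlib import Path


def describe_checkpoint(path):
    """Load a pickled checkpoint and return {key: type name} for its top level."""
    with Path(path).expanduser().open('rb') as f:
        data = pickle.load(f)  # only unpickle files you trust
    if isinstance(data, dict):
        return {key: type(value).__name__ for key, value in data.items()}
    return {'<root>': type(data).__name__}


# Example (hypothetical path inside your logdir):
# print(describe_checkpoint('~/logdir/run1/checkpoint.pkl'))
```

Unpickling a checkpoint that contains NumPy or JAX arrays requires those libraries to be importable in the same environment.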

unsupported dtype: object

Hi Danijar,

First, congratulations on the excellent work.

I'm trying to run dreamerv3 with a custom gym environment that has an image observation:

>>> env.observation_space
Box(0, 255, (64, 64, 3), uint8)

But I'm getting this error:

Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
JAX devices (1): [CpuDevice(id=0)]
Policy devices: TFRT_CPU_0
Train devices:  TFRT_CPU_0
Tracing train function.
Optimizer model_opt has 181,562,755 variables.
Optimizer actor_opt has 9,457,674 variables.
Optimizer critic_opt has 9,708,799 variables.
Logdir /Users/fernando/logdir/run1
Observation space:
  image            Space(dtype=uint8, shape=(64, 64, 3), low=0, high=255)
  reward           Space(dtype=float32, shape=(), low=-inf, high=inf)
  is_first         Space(dtype=bool, shape=(), low=False, high=True)
  is_last          Space(dtype=bool, shape=(), low=False, high=True)
  is_terminal      Space(dtype=bool, shape=(), low=False, high=True)
Action space:
  action           Space(dtype=float32, shape=(10,), low=0, high=1)
  reset            Space(dtype=bool, shape=(), low=False, high=True)
Prefill train dataset.
/Users/fernando/Documents/dev/projects/dreamerv3/dreamerv3/embodied/envs/from_gym.py:72: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  obs = {k: np.asarray(v) for k, v in obs.items()}
{'image': array([[array([[[  0,   0,   0],
                [  0,   0,   0],
                [  0,   0,   0],
                ...,
                [ 75, 214,  16],
                [ 75, 214,  16],
                [ 75, 214,  16]],

               [[  0,   0,   0],
                [  0,   0,   0],
                [  0,   0,   0],
                ...,
                [ 75, 214,  16],
                [ 75, 214,  16],
                [ 75, 214,  16]],

               [[  0,   0,   0],
                [  0,   0,   0],
                [  0,   0,   0],
                ...,
                [ 75, 214,  16],
                [ 75, 214,  16],
                [ 75, 214,  16]],

               ...,

               [[  0,   0,   0],
                [  0,   0,   0],
                [  0,   0,   0],
                ...,
                [ 95, 211,  64],
                [ 95, 211,  64],
                [ 95, 211,  64]],

               [[  0,   0,   0],
                [  0,   0,   0],
                [  0,   0,   0],
                ...,
                [ 95, 211,  64],
                [ 95, 211,  64],
                [ 95, 211,  64]],

               [[  0,   0,   0],
                [  0,   0,   0],
                [  0,   0,   0],
                ...,
                [ 95, 211,  64],
                [ 95, 211,  64],
                [ 95, 211,  64]]], dtype=uint8), {}]], dtype=object), 'reward': array([0.], dtype=float32), 'is_first': array([ True]), 'is_last': array([False]), 'is_terminal': array([False])}
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /Users/fernando/Documents/dev/projects/dreamerv3/myenv.py:56 in <module>                        │
│                                                                                                  │
│   53                                                                                             │
│   54                                                                                             │
│   55 if __name__ == '__main__':                                                                  │
│ ❱ 56   main()                                                                                    │
│   57                                                                                             │
│                                                                                                  │
│ /Users/fernando/Documents/dev/projects/dreamerv3/myenv.py:51 in main                            │
│                                                                                                  │
│   48   args = embodied.Config(                                                                   │
│   49 │     **config.run, logdir=config.logdir,                                                   │
│   50 │     batch_steps=config.batch_size * config.batch_length)                                  │
│ ❱ 51   embodied.run.train(agent, env, replay, logger, args)                                      │
│   52   # embodied.run.eval_only(agent, env, logger, args)                                        │
│   53                                                                                             │
│   54                                                                                             │
│                                                                                                  │
│ /Users/fernando/Documents/dev/projects/dreamerv3/dreamerv3/embodied/run/train.py:65 in train     │
│                                                                                                  │
│    62   print('Prefill train dataset.')                                                          │
│    63   random_agent = embodied.RandomAgent(env.act_space)                                       │
│    64   while len(replay) < max(args.batch_steps, args.train_fill):                              │
│ ❱  65 │   driver(random_agent.policy, steps=100)                                                 │
│    66   logger.add(metrics.result())                                                             │
│    67   logger.write()                                                                           │
│    68                                                                                            │
│                                                                                                  │
│ /Users/fernando/Documents/dev/projects/dreamerv3/dreamerv3/embodied/core/driver.py:42 in         │
│ __call__                                                                                         │
│                                                                                                  │
│   39   def __call__(self, policy, steps=0, episodes=0):                                          │
│   40 │   step, episode = 0, 0                                                                    │
│   41 │   while step < steps or episode < episodes:                                               │
│ ❱ 42 │     step, episode = self._step(policy, step, episode)                                     │
│   43                                                                                             │
│   44   def _step(self, policy, step, episode):                                                   │
│   45 │   assert all(len(x) == len(self._env) for x in self._acts.values())                       │
│                                                                                                  │
│ /Users/fernando/Documents/dev/projects/dreamerv3/dreamerv3/embodied/core/driver.py:49 in _step   │
│                                                                                                  │
│   46 │   acts = {k: v for k, v in self._acts.items() if not k.startswith('log_')}                │
│   47 │   obs = self._env.step(acts)                                                              │
│   48 │   print(obs)                                                                              │
│ ❱ 49 │   obs = {k: convert(v) for k, v in obs.items()}                                           │
│   50 │   assert all(len(x) == len(self._env) for x in obs.values()), obs                         │
│   51 │   acts, self._state = policy(obs, self._state, **self._kwargs)                            │
│   52 │   acts = {k: convert(v) for k, v in acts.items()}                                         │
│                                                                                                  │
│ /Users/fernando/Documents/dev/projects/dreamerv3/dreamerv3/embodied/core/driver.py:49 in         │
│ <dictcomp>                                                                                       │
│                                                                                                  │
│   46 │   acts = {k: v for k, v in self._acts.items() if not k.startswith('log_')}                │
│   47 │   obs = self._env.step(acts)                                                              │
│   48 │   print(obs)                                                                              │
│ ❱ 49 │   obs = {k: convert(v) for k, v in obs.items()}                                           │
│   50 │   assert all(len(x) == len(self._env) for x in obs.values()), obs                         │
│   51 │   acts, self._state = policy(obs, self._state, **self._kwargs)                            │
│   52 │   acts = {k: convert(v) for k, v in acts.items()}                                         │
│                                                                                                  │
│ /Users/fernando/Documents/dev/projects/dreamerv3/dreamerv3/embodied/core/basics.py:32 in convert │
│                                                                                                  │
│    29 │   │     value = value.astype(dst)                                                        │
│    30 │   │   break                                                                              │
│    31 │   else:                                                                                  │
│ ❱  32 │     raise TypeError(f"Object '{value}' has unsupported dtype: {value.dtype}")            │
│    33   return value                                                                             │
│    34                                                                                            │
│    35                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: Object '[[array([[[  0,   0,   0],
          [  0,   0,   0],
          [  0,   0,   0],
          ...,
          [ 75, 214,  16],
          [ 75, 214,  16],
          [ 75, 214,  16]],

         [[  0,   0,   0],
          [  0,   0,   0],
          [  0,   0,   0],
          ...,
          [ 75, 214,  16],
          [ 75, 214,  16],
          [ 75, 214,  16]],

         [[  0,   0,   0],
          [  0,   0,   0],
          [  0,   0,   0],
          ...,
          [ 75, 214,  16],
          [ 75, 214,  16],
          [ 75, 214,  16]],

         ...,

         [[  0,   0,   0],
          [  0,   0,   0],
          [  0,   0,   0],
          ...,
          [ 95, 211,  64],
          [ 95, 211,  64],
          [ 95, 211,  64]],

         [[  0,   0,   0],
          [  0,   0,   0],
          [  0,   0,   0],
          ...,
          [ 95, 211,  64],
          [ 95, 211,  64],
          [ 95, 211,  64]],

         [[  0,   0,   0],
          [  0,   0,   0],
          [  0,   0,   0],
          ...,
          [ 95, 211,  64],
          [ 95, 211,  64],
          [ 95, 211,  64]]], dtype=uint8) {}]]' has unsupported dtype: object

I suspect this is a gym version problem: my environment uses gym==0.26.2, and the version you use, gym==0.19.0, is not available for Mac M1.

Do you have any idea how to fix this?
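The printed observation supports that suspicion: `'image'` holds `[array(...), {}]`, i.e. an `(observation, info)` pair. In gym 0.26, `reset()` returns `(obs, info)` and `step()` returns five values `(obs, reward, terminated, truncated, info)`, whereas the `from_gym` wrapper here appears to expect the older API (`reset()` returns `obs`, `step()` returns `(obs, reward, done, info)`). A hedged adapter sketch, assuming that is the only mismatch: wrap the new-style env before passing it to `from_gym.FromGym`. The class name `OldStepAPI` is hypothetical. It merges `terminated` and `truncated` into a single `done` flag, which loses the distinction between the two; the raw flags are kept in `info` only for reference.

```python
class OldStepAPI:
    """Adapt a gym>=0.26 environment to the pre-0.26 reset/step signatures."""

    def __init__(self, env):
        self.env = env

    def __getattr__(self, name):
        # Forward everything else (observation_space, action_space, render, ...).
        if name == 'env':
            raise AttributeError(name)  # avoid recursion before __init__ runs
        return getattr(self.env, name)

    def reset(self, **kwargs):
        obs, _info = self.env.reset(**kwargs)  # new API returns (obs, info)
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        info = dict(info, terminated=terminated, truncated=truncated)
        return obs, reward, terminated or truncated, info


# Hypothetical usage with the code from this issue:
# env = from_gym.FromGym(OldStepAPI(my_gym_env))
```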

Potential to Release (Some) Pretrained Models?

Hi Danijar,

Great research and implementation. Do you plan to release pre-trained model weights for some environments? This could aid research on transferring world models, mechanistically interpreting model components, creating offline datasets, and other ideas related to https://reincarnating-rl.github.io/.

Daniel

Out of memory error on cluster

I have been running DMLab Goals small & Crafter experiments on a cluster, and after ~10 hours I get an out-of-memory error. I am allocating 32 GB of memory for the job. The error doesn't happen at a fixed time, so I don't believe a forward pass through any of the networks is the issue. Could there be a memory leak? Do you have any idea what could be causing this?

Thx
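One hypothesis to rule out before suspecting a leak: if the replay buffer keeps its transitions in host RAM, memory grows steadily as the buffer fills, so an out-of-memory error can appear hours into training rather than at a fixed step. Below is a rough back-of-the-envelope sketch of the raw payload. It assumes Crafter-style spaces matching the observation and action spaces logged further down in this thread: a uint8 image of shape (64, 64, 3) and a float32 action of shape (17,). It also assumes a hypothetical capacity of 1e6 steps; check `config.replay_size` in your own config. The helper `replay_bytes` is illustrative only, and the real buffer may store additional keys or use a different layout.

```python
import numpy as np


def replay_bytes(capacity, shapes_and_dtypes):
  """Estimate the bytes needed to hold `capacity` steps in memory.

  `shapes_and_dtypes` maps each stored key to (shape, dtype). Only the raw
  array payload is counted, not indexing structures or Python overhead.
  """
  per_step = sum(
      int(np.prod(shape)) * np.dtype(dtype).itemsize
      for shape, dtype in shapes_and_dtypes.values())
  return int(capacity) * per_step


# Hypothetical example: Crafter-like spaces and an assumed capacity of 1e6.
spaces = {
    'image': ((64, 64, 3), np.uint8),
    'action': ((17,), np.float32),
    'reward': ((), np.float32),
    'is_first': ((), bool),
    'is_last': ((), bool),
    'is_terminal': ((), bool),
}
print(f'{replay_bytes(1e6, spaces) / 1e9:.1f} GB')  # prints "12.4 GB"
```

Under these assumptions, the image storage alone is a sizable fraction of a 32 GB job allocation before counting model state, so comparing this estimate against the process's memory over time can help separate expected growth from a genuine leak.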

Hi, first of all, a BIG thank you for your great work.

[BUG] Bug with the JAX GPU version

My machine runs Ubuntu 20.04 with NVIDIA driver 515.86.01 (NVIDIA-SMI 515.86.01), CUDA 11.7, cuDNN 870, and Python 3.8.
pip list:

When I try to run 'python example.py', I get the following error:

(dreamerv3) weidong@user-NULL:~/dreamerv3$ python example.py
2023-02-19 16:52:05.740320: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 16:52:05.740423: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 16:52:05.740432: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/weidong/dreamerv3/example.py:48 in <module>                                                │
│                                                                                                  │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│ ❱ 48   main()                                                                                    │
│   49                                                                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/example.py:40 in main                                                    │
│                                                                                                  │
│   37   env = dreamerv3.wrap_env(env, config.wrapper)                                             │
│   38   env = embodied.BatchEnv([env], parallel=False)                                            │
│   39                                                                                             │
│ ❱ 40   agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)                       │
│   41   replay = embodied.replay.Uniform(                                                         │
│   42 │     config.batch_length, config.replay_size, logdir / 'replay')                           │
│   43   args = config.run.update(batch_steps=config.batch_size * config.batch_length)             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:20 in __init__                                     │
│                                                                                                  │
│    17 │   configs = agent_cls.configs                                                            │
│    18 │   inner = agent_cls                                                                      │
│    19 │   def __init__(self, obs_space, act_space, step, config):                                │
│ ❱  20 │     super().__init__(agent_cls, obs_space, act_space, step, config)                      │
│    21   return Agent                                                                             │
│    22                                                                                            │
│    23                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:28 in __init__                                     │
│                                                                                                  │
│    25                                                                                            │
│    26   def __init__(self, agent_cls, obs_space, act_space, step, config):                       │
│    27 │   self.config = config.jax                                                               │
│ ❱  28 │   self.setup()                                                                           │
│    29 │   self.agent = agent_cls(obs_space, act_space, step, config, name='agent')               │
│    30 │   self.rng = jaxutils.RNG(config.seed)                                                   │
│    31 │   self.varibs = {}                                                                       │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:73 in setup                                        │
│                                                                                                  │
│    70 │   if self.config.platform == 'cpu':                                                      │
│    71 │     jax.config.update('jax_disable_most_optimizations', self.config.debug)               │
│    72 │   jaxutils.COMPUTE_DTYPE = getattr(jnp, self.config.precision)                           │
│ ❱  73 │   print(f'JAX DEVICES ({jax.local_device_count()}):', jax.devices())                     │
│    74                                                                                            │
│    75   def train(self, data, state=None):                                                       │
│    76 │   data = self._convert_inps(data)                                                        │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:45 │
│ 7 in local_device_count                                                                          │
│                                                                                                  │
│   454                                                                                            │
│   455 def local_device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:           │
│   456   """Returns the number of devices addressable by this process."""                         │
│ ❱ 457   return int(get_backend(backend).local_device_count())                                    │
│   458                                                                                            │
│   459                                                                                            │
│   460 def devices(backend: Optional[Union[str, XlaBackend]] = None) -> List[xla_client.Device]   │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:42 │
│ 5 in get_backend                                                                                 │
│                                                                                                  │
│   422                                                                                            │
│   423 @lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.     │
│   424 def get_backend(platform=None):                                                            │
│ ❱ 425   return _get_backend_uncached(platform)                                                   │
│   426                                                                                            │
│   427                                                                                            │
│   428 def get_device_backend(device=None):                                                       │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:41 │
│ 1 in _get_backend_uncached                                                                       │
│                                                                                                  │
│   408                                                                                            │
│   409   bs = backends()                                                                          │
│   410   if platform is not None:                                                                 │
│ ❱ 411 │   platform = canonicalize_platform(platform)                                             │
│   412 │   backend = bs.get(platform, None)                                                       │
│   413 │   if backend is None:                                                                    │
│   414 │     if platform in _backends_errors:                                                     │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:29 │
│ 4 in canonicalize_platform                                                                       │
│                                                                                                  │
│   291   for p in platforms:                                                                      │
│   292 │   if p in b.keys():                                                                      │
│   293 │     return p                                                                             │
│ ❱ 294   raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "                   │
│   295 │   │   │   │   │    f"platforms that are instances of {platform} are present. "           │
│   296 │   │   │   │   │    "Platforms present are: " + ",".join(b.keys()))                       │
│   297                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. 
Platforms present are: interpreter,cpu

After reading the JAX GPU guide, I installed the GPU versions of JAX and TensorFlow:

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install tensorflow
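The first traceback fails because the config requests the 'gpu' platform while the installed jaxlib only provides CPU ("Platforms present are: interpreter,cpu"). After reinstalling, it can help to confirm which platforms JAX actually sees before launching training. Below is a small sketch that probes each platform with `jax.devices(name)`, which raises `RuntimeError` when the backend is missing, as the traceback shows. It then picks a fallback. The helper names are hypothetical. How the chosen value reaches dreamerv3's `config.jax.platform` (read in `jaxagent.py` in the traceback above) is left to your launcher.

```python
def available_platforms(candidates=('cpu', 'gpu', 'tpu')):
  """Return the subset of `candidates` for which JAX has devices."""
  try:
    import jax
  except ImportError:
    return set()
  found = set()
  for name in candidates:
    try:
      if jax.devices(name):
        found.add(name)
    except RuntimeError:
      # JAX raises this when no backend of that kind is present.
      pass
  return found


def choose_platform(requested, available):
  """Use `requested` if it is available, otherwise fall back to CPU."""
  if requested in available:
    return requested
  print(f"'{requested}' not available (found {sorted(available)}); "
        "falling back to 'cpu'.")
  return 'cpu'


if __name__ == '__main__':
  platforms = available_platforms()
  print('JAX platforms:', sorted(platforms))
  print('Using:', choose_platform('gpu', platforms))
```

If only 'cpu' is listed after installing the CUDA wheel, the installed jaxlib likely doesn't match the local CUDA and cuDNN versions.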

But then I got another error:

(dreamerv3) weidong@user-NULL:~/dreamerv3$ python example.py
2023-02-19 17:07:51.659700: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 17:07:51.659805: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 17:07:51.659814: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
JAX DEVICES (8): [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0)]
Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
Logdir /home/weidong/logdir/run1
Observation space:
  image            Space(dtype=uint8, shape=(64, 64, 3), low=0, high=255)
  reward           Space(dtype=float32, shape=(), low=-inf, high=inf)
  is_first         Space(dtype=bool, shape=(), low=False, high=True)
  is_last          Space(dtype=bool, shape=(), low=False, high=True)
  is_terminal      Space(dtype=bool, shape=(), low=False, high=True)
Action space:
  action           Space(dtype=float32, shape=(17,), low=0, high=1)
  reset            Space(dtype=bool, shape=(), low=False, high=True)
Fill train dataset (1024 steps).
Episode has 147 steps and return 2.1.
Episode has 305 steps and return 2.1.
Episode has 110 steps and return 1.1.
Episode has 176 steps and return 0.1.
Episode has 140 steps and return 0.1.
───────────────────────────────────────────────── Step 1024 ─────────────────────────────────────────────────
episode/length 140 / episode/score 0.1 / episode/sum_abs_reward 2.1 / episode/reward_rate 0.01

Creating new TensorBoard event file writer.
Saved chunk: 20230219T170756F065708-18SVIAerO9mVKu8c3SI3e2-1Oa8uUlL3lQQnPBO7aaVU7-1024.npz
Tracing train function.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/weidong/dreamerv3/example.py:48 in <module>                                                │
│                                                                                                  │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│ ❱ 48   main()                                                                                    │
│   49                                                                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/example.py:44 in main                                                    │
│                                                                                                  │
│   41   replay = embodied.replay.Uniform(                                                         │
│   42 │     config.batch_length, config.replay_size, logdir / 'replay')                           │
│   43   args = config.run.update(batch_steps=config.batch_size * config.batch_length)             │
│ ❱ 44   embodied.run.train(agent, env, replay, logger, args)                                      │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/embodied/run/train.py:79 in train                              │
│                                                                                                  │
│    76   for _ in range(args.pretrain):                                                           │
│    77 │   with timer.scope('dataset'):                                                           │
│    78 │     batch = next(dataset)                                                                │
│ ❱  79 │   _, state[0], _ = agent.train(batch, state[0])                                          │
│    80                                                                                            │
│    81   batch = [None]                                                                           │
│    82   def train_step(tran, worker):                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:80 in train                                        │
│                                                                                                  │
│    77 │   rng = self._next_rngs(mirror=not self.varibs)                                          │
│    78 │   if state is None:                                                                      │
│    79 │     state, self.varibs = self._init_train(self.varibs, rng, data['is_first'])            │
│ ❱  80 │   (outs, state, mets), self.varibs = self._train(                                        │
│    81 │   │   self.varibs, rng, data, state)                                                     │
│    82 │   outs = self._convert_outs(outs)                                                        │
│    83 │   mets = self._convert_mets(mets)                                                        │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:181 in wrapper                                       │
│                                                                                                  │
│   178 │   statics = tuple(sorted([(k, v) for k, v in kwargs.items() if k in static]))            │
│   179 │   kwargs = {k: v for k, v in kwargs.items() if k not in static}                          │
│   180 │   if not hasattr(wrapper, 'keys'):                                                       │
│ ❱ 181 │     created = init(statics, rng, *args, **kwargs)                                        │
│   182 │     wrapper.keys = set(created.keys())                                                   │
│   183 │     for key, value in created.items():                                                   │
│   184 │   │   if key not in state:                                                               │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/traceback_util.py:16 │
│ 3 in reraise_with_filtered_traceback                                                             │
│                                                                                                  │
│   160   def reraise_with_filtered_traceback(*args, **kwargs):                                    │
│   161 │   __tracebackhide__ = True                                                               │
│   162 │   try:                                                                                   │
│ ❱ 163 │     return fun(*args, **kwargs)                                                          │
│   164 │   except Exception as e:                                                                 │
│   165 │     mode = filtering_mode()                                                              │
│   166 │     if is_under_reraiser(e) or mode == "off":                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:237 in       │
│ cache_miss                                                                                       │
│                                                                                                  │
│    234                                                                                           │
│    235   @api_boundary                                                                           │
│    236   def cache_miss(*args, **kwargs):                                                        │
│ ❱  237 │   outs, out_flat, out_tree, args_flat = _python_pjit_helper(                            │
│    238 │   │   fun, infer_params_fn, *args, **kwargs)                                            │
│    239 │                                                                                         │
│    240 │   executable = _read_most_recent_pjit_call_executable()                                 │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:180 in       │
│ _python_pjit_helper                                                                              │
│                                                                                                  │
│    177                                                                                           │
│    178                                                                                           │
│    179 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):                           │
│ ❱  180   args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(                           │
│    181 │     *args, **kwargs)                                                                    │
│    182   for arg in args_flat:                                                                   │
│    183 │   dispatch.check_arg(arg)                                                               │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/api.py:443 in        │
│ infer_params                                                                                     │
│                                                                                                  │
│    440 │   │     static_argnames=static_argnames, donate_argnums=donate_argnums,                 │
│    441 │   │     device=device, backend=backend, keep_unused=keep_unused,                        │
│    442 │   │     inline=inline, resource_env=None)                                               │
│ ❱  443 │     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)                    │
│    444 │                                                                                         │
│    445 │   has_explicit_sharding = pjit._pjit_explicit_sharding(                                 │
│    446 │   │   in_shardings, out_shardings, device, backend)                                     │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:520 in       │
│ common_infer_params                                                                              │
│                                                                                                  │
│    517 │     hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,    │
│    518 │     tuple(isinstance(a, GDA) for a in args_flat), resource_env)                         │
│    519                                                                                           │
│ ❱  520   jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(                          │
│    521 │     flat_fun, hashable_pytree(out_shardings), global_in_avals,                          │
│    522 │     HashableFunction(out_tree, closure=()),                                             │
│    523 │     ('jit' if resource_env is None else 'pjit'))                                        │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/linear_util.py:301   │
│ in memoized_fun                                                                                  │
│                                                                                                  │
│   298 │     ans, stores = result                                                                 │
│   299 │     fun.populate_stores(stores)                                                          │
│   300 │   else:                                                                                  │
│ ❱ 301 │     ans = call(fun, *args)                                                               │
│   302 │     cache[key] = (ans, fun.stores)                                                       │
│   303 │                                                                                          │
│   304 │   return ans                                                                             │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:932 in       │
│ _pjit_jaxpr                                                                                      │
│                                                                                                  │
│    929 │   with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "     │
│    930 │   │   │   │   │   │   │   │      "for pjit in {elapsed_time} sec",                      │
│    931 │   │   │   │   │   │   │   │   │   event=dispatch.JAXPR_TRACE_EVENT):                    │
│ ❱  932 │     jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(                        │
│    933 │   │     fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))            │
│    934   finally:                                                                                │
│    935 │   pxla.positional_semantics.val = prev_positional_val                                   │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/profiler.py:314 in   │
│ wrapper                                                                                          │
│                                                                                                  │
│   311   @wraps(func)                                                                             │
│   312   def wrapper(*args, **kwargs):                                                            │
│   313 │   with TraceAnnotation(name, **decorator_kwargs):                                        │
│ ❱ 314 │     return func(*args, **kwargs)                                                         │
│   315 │   return wrapper                                                                         │
│   316   return wrapper                                                                           │
│   317                                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:1985 in trace_to_jaxpr_dynamic                                                               │
│                                                                                                  │
│   1982 │   │   │   │   │   │      keep_inputs: Optional[List[bool]] = None):                     │
│   1983   with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore            │
│   1984 │   main.jaxpr_stack = ()  # type: ignore                                                 │
│ ❱ 1985 │   jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(                                 │
│   1986 │     fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)                │
│   1987 │   del main, fun                                                                         │
│   1988   return jaxpr, out_avals, consts                                                         │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:2002 in trace_to_subjaxpr_dynamic                                                            │
│                                                                                                  │
│   1999 │   trace = DynamicJaxprTrace(main, core.cur_sublevel())                                  │
│   2000 │   in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)                          │
│   2001 │   in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]                 │
│ ❱ 2002 │   ans = fun.call_wrapped(*in_tracers_)                                                  │
│   2003 │   out_tracers = map(trace.full_raise, ans)                                              │
│   2004 │   jaxpr, consts = frame.to_jaxpr(out_tracers)                                           │
│   2005 │   del fun, main, trace, frame, in_tracers, out_tracers, ans                             │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/linear_util.py:165   │
│ in call_wrapped                                                                                  │
│                                                                                                  │
│   162 │   gen = gen_static_args = out_store = None                                               │
│   163 │                                                                                          │
│   164 │   try:                                                                                   │
│ ❱ 165 │     ans = self.f(*args, **dict(self.params, **kwargs))                                   │
│   166 │   except:                                                                                │
│   167 │     # Some transformations yield from inside context managers, so we have to             │
│   168 │     # interrupt them before reraising the exception. Otherwise they will only            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:166 in init                                          │
│                                                                                                  │
│   163   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│   164   def init(statics, rng, *args, **kwargs):                                                 │
│   165 │   # Return only state so JIT can remove dead code for fast initialization.               │
│ ❱ 166 │   s = fun({}, rng, *args, ignore=True, **dict(statics), **kwargs)[1]                     │
│   167 │   return s                                                                               │
│   168                                                                                            │
│   169   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified                                       │
│                                                                                                  │
│    74 │   before = CONTEXT                                                                       │
│    75 │   try:                                                                                   │
│    76 │     CONTEXT = Context(state.copy(), rng, create, modify, ignore, [])                     │
│ ❱  77 │     out = fun(*args, **kwargs)                                                           │
│    78 │     state = dict(CONTEXT)                                                                │
│    79 │     return out, state                                                                    │
│    80 │   finally:                                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:80 in train                                           │
│                                                                                                  │
│    77 │   self.config.jax.jit and print('Tracing train function.')                               │
│    78 │   metrics = {}                                                                           │
│    79 │   data = self.preprocess(data)                                                           │
│ ❱  80 │   state, wm_outs, mets = self.wm.train(data, state)                                      │
│    81 │   metrics.update(mets)                                                                   │
│    82 │   context = {**data, **wm_outs['post']}                                                  │
│    83 │   start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context)               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:151 in train                                          │
│                                                                                                  │
│   148                                                                                            │
│   149   def train(self, data, state):                                                            │
│   150 │   modules = [self.encoder, self.rssm, *self.heads.values()]                              │
│ ❱ 151 │   mets, (state, outs, metrics) = self.opt(                                               │
│   152 │   │   modules, self.loss, data, state, has_aux=True)                                     │
│   153 │   metrics.update(mets)                                                                   │
│   154 │   return state, outs, metrics                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:410 in __call__                                    │
│                                                                                                  │
│   407 │   │   loss *= sg(self.grad_scale.read())                                                 │
│   408 │     return loss, aux                                                                     │
│   409 │   metrics = {}                                                                           │
│ ❱ 410 │   loss, params, grads, aux = nj.grad(                                                    │
│   411 │   │   wrapped, modules, has_aux=True)(*args, **kwargs)                                   │
│   412 │   if not self.PARAM_COUNTS[self.path]:                                                   │
│   413 │     count = sum([np.prod(x.shape) for x in params.values()])                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:142 in wrapper                                       │
│                                                                                                  │
│   139   backward = jax.value_and_grad(forward, has_aux=True)                                     │
│   140   @functools.wraps(backward)                                                               │
│   141   def wrapper(*args, **kwargs):                                                            │
│ ❱ 142 │   _prerun(fun, *args, **kwargs)                                                          │
│   143 │   assert all(isinstance(x, (str, Module)) for x in keys)                                 │
│   144 │   strs = [x for x in keys if isinstance(x, str)]                                         │
│   145 │   mods = [x for x in keys if isinstance(x, Module)]                                      │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:271 in _prerun                                       │
│                                                                                                  │
│   268 def _prerun(fun, *args, **kwargs):                                                         │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│ ❱ 271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified                                       │
│                                                                                                  │
│    74 │   before = CONTEXT                                                                       │
│    75 │   try:                                                                                   │
│    76 │     CONTEXT = Context(state.copy(), rng, create, modify, ignore, [])                     │
│ ❱  77 │     out = fun(*args, **kwargs)                                                           │
│    78 │     state = dict(CONTEXT)                                                                │
│    79 │     return out, state                                                                    │
│    80 │   finally:                                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:402 in wrapped                                     │
│                                                                                                  │
│   399                                                                                            │
│   400   def __call__(self, modules, lossfn, *args, has_aux=False, **kwargs):                     │
│   401 │   def wrapped(*args, **kwargs):                                                          │
│ ❱ 402 │     outs = lossfn(*args, **kwargs)                                                       │
│   403 │     loss, aux = outs if has_aux else (outs, None)                                        │
│   404 │     assert loss.dtype == jnp.float32, (self.name, loss.dtype)                            │
│   405 │     assert loss.shape == (), (self.name, loss.shape)                                     │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:161 in loss                                           │
│                                                                                                  │
│   158 │   prev_latent, prev_action = state                                                       │
│   159 │   prev_actions = jnp.concatenate([                                                       │
│   160 │   │   prev_action[:, None], data['action'][:, :-1]], 1)                                  │
│ ❱ 161 │   post, prior = self.rssm.observe(                                                       │
│   162 │   │   embed, prev_actions, data['is_first'], prev_latent)                                │
│   163 │   dists = {}                                                                             │
│   164 │   feats = {**post, 'embed': embed}                                                       │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/nets.py:60 in observe                                          │
│                                                                                                  │
│    57 │   step = lambda prev, inputs: self.obs_step(prev[0], *inputs)                            │
│    58 │   inputs = swap(action), swap(embed), swap(is_first)                                     │
│    59 │   start = state, state                                                                   │
│ ❱  60 │   post, prior = jaxutils.scan(step, inputs, start, self._unroll)                         │
│    61 │   post = {k: swap(v) for k, v in post.items()}                                           │
│    62 │   prior = {k: swap(v) for k, v in prior.items()}                                         │
│    63 │   return post, prior                                                                     │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:73 in scan                                         │
│                                                                                                  │
│    70 def scan(fn, inputs, start, unroll=True, modify=False):                                    │
│    71   fn2 = lambda carry, inp: (fn(carry, inp),) * 2                                           │
│    72   if not unroll:                                                                           │
│ ❱  73 │   return nj.scan(fn2, start, inputs, modify=modify)[1]                                   │
│    74   length = len(jax.tree_util.tree_leaves(inputs)[0])                                       │
│    75   carrydef = jax.tree_util.tree_structure(start)                                           │
│    76   carry = start                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:245 in scan                                          │
│                                                                                                  │
│   242 @jax.named_scope('scan')                                                                   │
│   243 def scan(fun, carry, xs, reverse=False, unroll=1, modify=False):                           │
│   244   fun = pure(fun, nested=True)                                                             │
│ ❱ 245   _prerun(fun, carry, jax.tree_util.tree_map(lambda x: x[0], xs))                          │
│   246   length = len(jax.tree_util.tree_leaves(xs)[0])                                           │
│   247   rngs = rng(length)                                                                       │
│   248   if modify:                                                                               │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:272 in _prerun                                       │
│                                                                                                  │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│ ❱ 272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/tree_util.py:207 in  │
│ tree_map                                                                                         │
│                                                                                                  │
│   204   """                                                                                      │
│   205   leaves, treedef = tree_flatten(tree, is_leaf)                                            │
│   206   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]                         │
│ ❱ 207   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))                              │
│   208                                                                                            │
│   209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:                                        │
│   210   return treedef.from_iterable_tree(xs)                                                    │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/tree_util.py:207 in  │
│ <genexpr>                                                                                        │
│                                                                                                  │
│   204   """                                                                                      │
│   205   leaves, treedef = tree_flatten(tree, is_leaf)                                            │
│   206   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]                         │
│ ❱ 207   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))                              │
│   208                                                                                            │
│   209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:                                        │
│   210   return treedef.from_iterable_tree(xs)                                                    │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:273 in <lambda>                                      │
│                                                                                                  │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│ ❱ 273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│   276                                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/core.py:734 in       │
│ delete                                                                                           │
│                                                                                                  │
│    731 │     f"The 'copy_to_host_async' method is not available on the JAX Tracer object {self}  │
│    732                                                                                           │
│    733   def delete(self):                                                                       │
│ ❱  734 │   raise ConcretizationTypeError(self,                                                   │
│    735 │     f"The delete() method was called on the JAX Tracer object {self}")                  │
│    736                                                                                           │
│    737   def device(self):                                                                       │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where 
concrete value is expected: Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The delete() method was called on the JAX Tracer object 
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function init at /home/weidong/dreamerv3/dreamerv3/ninjax.py:163 for 
jit. This concrete value was not available in Python because it depends on the values of the arguments 
'statics', 'rng', and 'args'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/weidong/dreamerv3/example.py:48 in <module>                                                │
│                                                                                                  │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│ ❱ 48   main()                                                                                    │
│   49                                                                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/example.py:44 in main                                                    │
│                                                                                                  │
│   41   replay = embodied.replay.Uniform(                                                         │
│   42 │     config.batch_length, config.replay_size, logdir / 'replay')                           │
│   43   args = config.run.update(batch_steps=config.batch_size * config.batch_length)             │
│ ❱ 44   embodied.run.train(agent, env, replay, logger, args)                                      │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/embodied/run/train.py:79 in train                              │
│                                                                                                  │
│    76   for _ in range(args.pretrain):                                                           │
│    77 │   with timer.scope('dataset'):                                                           │
│    78 │     batch = next(dataset)                                                                │
│ ❱  79 │   _, state[0], _ = agent.train(batch, state[0])                                          │
│    80                                                                                            │
│    81   batch = [None]                                                                           │
│    82   def train_step(tran, worker):                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:80 in train                                        │
│                                                                                                  │
│    77 │   rng = self._next_rngs(mirror=not self.varibs)                                          │
│    78 │   if state is None:                                                                      │
│    79 │     state, self.varibs = self._init_train(self.varibs, rng, data['is_first'])            │
│ ❱  80 │   (outs, state, mets), self.varibs = self._train(                                        │
│    81 │   │   self.varibs, rng, data, state)                                                     │
│    82 │   outs = self._convert_outs(outs)                                                        │
│    83 │   mets = self._convert_mets(mets)                                                        │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:181 in wrapper                                       │
│                                                                                                  │
│   178 │   statics = tuple(sorted([(k, v) for k, v in kwargs.items() if k in static]))            │
│   179 │   kwargs = {k: v for k, v in kwargs.items() if k not in static}                          │
│   180 │   if not hasattr(wrapper, 'keys'):                                                       │
│ ❱ 181 │     created = init(statics, rng, *args, **kwargs)                                        │
│   182 │     wrapper.keys = set(created.keys())                                                   │
│   183 │     for key, value in created.items():                                                   │
│   184 │   │   if key not in state:                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:166 in init                                          │
│                                                                                                  │
│   163   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│   164   def init(statics, rng, *args, **kwargs):                                                 │
│   165 │   # Return only state so JIT can remove dead code for fast initialization.               │
│ ❱ 166 │   s = fun({}, rng, *args, ignore=True, **dict(statics), **kwargs)[1]                     │
│   167 │   return s                                                                               │
│   168                                                                                            │
│   169   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified                                       │
│                                                                                                  │
│    74 │   before = CONTEXT                                                                       │
│    75 │   try:                                                                                   │
│    76 │     CONTEXT = Context(state.copy(), rng, create, modify, ignore, [])                     │
│ ❱  77 │     out = fun(*args, **kwargs)                                                           │
│    78 │     state = dict(CONTEXT)                                                                │
│    79 │     return out, state                                                                    │
│    80 │   finally:                                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:80 in train                                           │
│                                                                                                  │
│    77 │   self.config.jax.jit and print('Tracing train function.')                               │
│    78 │   metrics = {}                                                                           │
│    79 │   data = self.preprocess(data)                                                           │
│ ❱  80 │   state, wm_outs, mets = self.wm.train(data, state)                                      │
│    81 │   metrics.update(mets)                                                                   │
│    82 │   context = {**data, **wm_outs['post']}                                                  │
│    83 │   start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context)               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:151 in train                                          │
│                                                                                                  │
│   148                                                                                            │
│   149   def train(self, data, state):                                                            │
│   150 │   modules = [self.encoder, self.rssm, *self.heads.values()]                              │
│ ❱ 151 │   mets, (state, outs, metrics) = self.opt(                                               │
│   152 │   │   modules, self.loss, data, state, has_aux=True)                                     │
│   153 │   metrics.update(mets)                                                                   │
│   154 │   return state, outs, metrics                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:410 in __call__                                    │
│                                                                                                  │
│   407 │   │   loss *= sg(self.grad_scale.read())                                                 │
│   408 │     return loss, aux                                                                     │
│   409 │   metrics = {}                                                                           │
│ ❱ 410 │   loss, params, grads, aux = nj.grad(                                                    │
│   411 │   │   wrapped, modules, has_aux=True)(*args, **kwargs)                                   │
│   412 │   if not self.PARAM_COUNTS[self.path]:                                                   │
│   413 │     count = sum([np.prod(x.shape) for x in params.values()])                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:142 in wrapper                                       │
│                                                                                                  │
│   139   backward = jax.value_and_grad(forward, has_aux=True)                                     │
│   140   @functools.wraps(backward)                                                               │
│   141   def wrapper(*args, **kwargs):                                                            │
│ ❱ 142 │   _prerun(fun, *args, **kwargs)                                                          │
│   143 │   assert all(isinstance(x, (str, Module)) for x in keys)                                 │
│   144 │   strs = [x for x in keys if isinstance(x, str)]                                         │
│   145 │   mods = [x for x in keys if isinstance(x, Module)]                                      │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:271 in _prerun                                       │
│                                                                                                  │
│   268 def _prerun(fun, *args, **kwargs):                                                         │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│ ❱ 271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified                                       │
│                                                                                                  │
│    74 │   before = CONTEXT                                                                       │
│    75 │   try:                                                                                   │
│    76 │     CONTEXT = Context(state.copy(), rng, create, modify, ignore, [])                     │
│ ❱  77 │     out = fun(*args, **kwargs)                                                           │
│    78 │     state = dict(CONTEXT)                                                                │
│    79 │     return out, state                                                                    │
│    80 │   finally:                                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:402 in wrapped                                     │
│                                                                                                  │
│   399                                                                                            │
│   400   def __call__(self, modules, lossfn, *args, has_aux=False, **kwargs):                     │
│   401 │   def wrapped(*args, **kwargs):                                                          │
│ ❱ 402 │     outs = lossfn(*args, **kwargs)                                                       │
│   403 │     loss, aux = outs if has_aux else (outs, None)                                        │
│   404 │     assert loss.dtype == jnp.float32, (self.name, loss.dtype)                            │
│   405 │     assert loss.shape == (), (self.name, loss.shape)                                     │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:161 in loss                                           │
│                                                                                                  │
│   158 │   prev_latent, prev_action = state                                                       │
│   159 │   prev_actions = jnp.concatenate([                                                       │
│   160 │   │   prev_action[:, None], data['action'][:, :-1]], 1)                                  │
│ ❱ 161 │   post, prior = self.rssm.observe(                                                       │
│   162 │   │   embed, prev_actions, data['is_first'], prev_latent)                                │
│   163 │   dists = {}                                                                             │
│   164 │   feats = {**post, 'embed': embed}                                                       │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/nets.py:60 in observe                                          │
│                                                                                                  │
│    57 │   step = lambda prev, inputs: self.obs_step(prev[0], *inputs)                            │
│    58 │   inputs = swap(action), swap(embed), swap(is_first)                                     │
│    59 │   start = state, state                                                                   │
│ ❱  60 │   post, prior = jaxutils.scan(step, inputs, start, self._unroll)                         │
│    61 │   post = {k: swap(v) for k, v in post.items()}                                           │
│    62 │   prior = {k: swap(v) for k, v in prior.items()}                                         │
│    63 │   return post, prior                                                                     │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:73 in scan                                         │
│                                                                                                  │
│    70 def scan(fn, inputs, start, unroll=True, modify=False):                                    │
│    71   fn2 = lambda carry, inp: (fn(carry, inp),) * 2                                           │
│    72   if not unroll:                                                                           │
│ ❱  73 │   return nj.scan(fn2, start, inputs, modify=modify)[1]                                   │
│    74   length = len(jax.tree_util.tree_leaves(inputs)[0])                                       │
│    75   carrydef = jax.tree_util.tree_structure(start)                                           │
│    76   carry = start                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:245 in scan                                          │
│                                                                                                  │
│   242 @jax.named_scope('scan')                                                                   │
│   243 def scan(fun, carry, xs, reverse=False, unroll=1, modify=False):                           │
│   244   fun = pure(fun, nested=True)                                                             │
│ ❱ 245   _prerun(fun, carry, jax.tree_util.tree_map(lambda x: x[0], xs))                          │
│   246   length = len(jax.tree_util.tree_leaves(xs)[0])                                           │
│   247   rngs = rng(length)                                                                       │
│   248   if modify:                                                                               │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:272 in _prerun                                       │
│                                                                                                  │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│ ❱ 272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:273 in <lambda>                                      │
│                                                                                                  │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│ ❱ 273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│   276                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: 
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The delete() method was called on the JAX Tracer object 
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function init at /home/weidong/dreamerv3/dreamerv3/ninjax.py:163 for 
jit. This concrete value was not available in Python because it depends on the values of the arguments 
'statics', 'rng', and 'args'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
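Reading the traceback, ninjax's _prerun helper (ninjax.py:273) calls delete() on every leaf of the discarded state that has such a method. Under jax.jit those leaves are tracers, and calling delete() on a tracer raises the ConcretizationTypeError above. The sketch below is a possible local workaround. It assumes that skipping tracers is harmless, since there are no device buffers to free while tracing. It guards the call with an isinstance check against jax.core.Tracer. safe_delete is a hypothetical helper name. It is not part of ninjax and not the upstream fix. Pinning JAX to the version the repository was tested against may be the simpler route.

```python
import jax
import jax.numpy as jnp


def safe_delete(tree):
  """Hypothetical helper: free buffers of concrete arrays, skip tracers.

  Inside jax.jit the leaves are jax.core.Tracer objects. Calling delete()
  on them raises ConcretizationTypeError, so they are left untouched.
  """
  for leaf in jax.tree_util.tree_leaves(tree):
    if isinstance(leaf, jax.core.Tracer):
      continue  # Abstract value during tracing: nothing to free.
    if hasattr(leaf, 'delete'):
      leaf.delete()  # Concrete jax.Array: release its device buffer.


# Assumed patch site (ninjax.py:272-273), replacing the tree_map/lambda:
#   safe_delete(discarded)
```

Used inside a jitted function, the helper is a no-op on tracers. Outside jit, it still frees concrete arrays as the original lambda intended.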

I think a JAX installation and version-compatibility guide should be added to the README.

My current pip list is as follows:

(dreamerv3) weidong@user-NULL:~/dreamerv3$ pip list
Package                      Version
---------------------------- --------------------
absl-py                      1.4.0
astunparse                   1.6.3
cachetools                   5.3.0
certifi                      2022.12.7
charset-normalizer           3.0.1
chex                         0.1.6
cloudpickle                  1.6.0
crafter                      1.8.0
decorator                    5.1.1
dm-tree                      0.1.8
flatbuffers                  23.1.21
gast                         0.4.0
google-auth                  2.16.1
google-auth-oauthlib         0.4.6
google-pasta                 0.2.0
grpcio                       1.51.1
gym                          0.19.0
h5py                         3.8.0
idna                         3.4
imageio                      2.25.1
importlib-metadata           6.0.0
jax                          0.4.4
jaxlib                       0.4.4+cuda11.cudnn86
keras                        2.11.0
libclang                     15.0.6.1
llvmlite                     0.39.1
Markdown                     3.4.1
markdown-it-py               2.1.0
MarkupSafe                   2.1.2
mdurl                        0.1.2
numba                        0.56.4
numpy                        1.23.5
oauthlib                     3.2.2
opensimplex                  0.4.4
opt-einsum                   3.3.0
optax                        0.1.4
packaging                    23.0
Pillow                       9.4.0
pip                          23.0.1
protobuf                     3.19.6
pyasn1                       0.4.8
pyasn1-modules               0.2.8
Pygments                     2.14.0
python-version               0.0.2
requests                     2.28.2
requests-oauthlib            1.3.1
rich                         13.3.1
rsa                          4.9
ruamel.yaml                  0.17.21
ruamel.yaml.clib             0.2.7
scipy                        1.10.0
setuptools                   65.6.3
six                          1.16.0
tensorboard                  2.11.2
tensorboard-data-server      0.6.1
tensorboard-plugin-wit       1.8.1
tensorflow                   2.11.0
tensorflow-cpu               2.11.0
tensorflow-estimator         2.11.0
tensorflow-io-gcs-filesystem 0.30.0
tensorflow-probability       0.19.0
termcolor                    2.2.0
toolz                        0.12.0
typing_extensions            4.5.0
urllib3                      1.26.14
Werkzeug                     2.2.3
wheel                        0.38.4
wrapt                        1.14.1
zipp                         3.14.0
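Alongside the full pip list, a short report of the JAX-specific details can make issues like this easier to triage. It shows the installed jax and jaxlib versions and the backend and devices JAX actually picked up. The sketch below uses only jax.__version__, jax.default_backend(), jax.devices() and the standard library's importlib.metadata. The function name jax_env_report is hypothetical.

```python
import importlib.metadata

import jax


def jax_env_report():
  """Hypothetical helper: collect JAX details worth pasting into an issue."""
  return {
      'jax': jax.__version__,
      # jaxlib carries the CUDA/cuDNN build tag, e.g. "+cuda11.cudnn86".
      'jaxlib': importlib.metadata.version('jaxlib'),
      'backend': jax.default_backend(),  # "cpu", "gpu" or "tpu"
      'devices': [str(d) for d in jax.devices()],
  }


if __name__ == '__main__':
  for key, value in jax_env_report().items():
    print(f'{key}: {value}')
```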
