My environment: Ubuntu 20.04, NVIDIA driver 515.86.01 (from nvidia-smi), CUDA 11.7, cuDNN 8.7.0, Python 3.8.
pip list:
(dreamerv3) weidong@user-NULL:~/dreamerv3$ python example.py
2023-02-19 17:07:51.659700: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 17:07:51.659805: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 17:07:51.659814: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
JAX DEVICES (8): [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0)]
Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
Logdir /home/weidong/logdir/run1
Observation space:
image Space(dtype=uint8, shape=(64, 64, 3), low=0, high=255)
reward Space(dtype=float32, shape=(), low=-inf, high=inf)
is_first Space(dtype=bool, shape=(), low=False, high=True)
is_last Space(dtype=bool, shape=(), low=False, high=True)
is_terminal Space(dtype=bool, shape=(), low=False, high=True)
Action space:
action Space(dtype=float32, shape=(17,), low=0, high=1)
reset Space(dtype=bool, shape=(), low=False, high=True)
Fill train dataset (1024 steps).
Episode has 147 steps and return 2.1.
Episode has 305 steps and return 2.1.
Episode has 110 steps and return 1.1.
Episode has 176 steps and return 0.1.
Episode has 140 steps and return 0.1.
───────────────────────────────────────────────── Step 1024 ─────────────────────────────────────────────────
episode/length 140 / episode/score 0.1 / episode/sum_abs_reward 2.1 / episode/reward_rate 0.01
Creating new TensorBoard event file writer.
Saved chunk: 20230219T170756F065708-18SVIAerO9mVKu8c3SI3e2-1Oa8uUlL3lQQnPBO7aaVU7-1024.npz
Tracing train function.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/weidong/dreamerv3/example.py:48 in <module> │
│ │
│ 45 │
│ 46 │
│ 47 if __name__ == '__main__': │
│ ❱ 48 main() │
│ 49 │
│ │
│ /home/weidong/dreamerv3/example.py:44 in main │
│ │
│ 41 replay = embodied.replay.Uniform( │
│ 42 │ config.batch_length, config.replay_size, logdir / 'replay') │
│ 43 args = config.run.update(batch_steps=config.batch_size * config.batch_length) │
│ ❱ 44 embodied.run.train(agent, env, replay, logger, args) │
│ 45 │
│ 46 │
│ 47 if __name__ == '__main__': │
│ │
│ /home/weidong/dreamerv3/dreamerv3/embodied/run/train.py:79 in train │
│ │
│ 76 for _ in range(args.pretrain): │
│ 77 │ with timer.scope('dataset'): │
│ 78 │ batch = next(dataset) │
│ ❱ 79 │ _, state[0], _ = agent.train(batch, state[0]) │
│ 80 │
│ 81 batch = [None] │
│ 82 def train_step(tran, worker): │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner │
│ │
│ 72 │ │ @wraps(func) │
│ 73 │ │ def inner(*args, **kwds): │
│ 74 │ │ │ with self._recreate_cm(): │
│ ❱ 75 │ │ │ │ return func(*args, **kwds) │
│ 76 │ │ return inner │
│ 77 │
│ 78 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:80 in train │
│ │
│ 77 │ rng = self._next_rngs(mirror=not self.varibs) │
│ 78 │ if state is None: │
│ 79 │ state, self.varibs = self._init_train(self.varibs, rng, data['is_first']) │
│ ❱ 80 │ (outs, state, mets), self.varibs = self._train( │
│ 81 │ │ self.varibs, rng, data, state) │
│ 82 │ outs = self._convert_outs(outs) │
│ 83 │ mets = self._convert_mets(mets) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:181 in wrapper │
│ │
│ 178 │ statics = tuple(sorted([(k, v) for k, v in kwargs.items() if k in static])) │
│ 179 │ kwargs = {k: v for k, v in kwargs.items() if k not in static} │
│ 180 │ if not hasattr(wrapper, 'keys'): │
│ ❱ 181 │ created = init(statics, rng, *args, **kwargs) │
│ 182 │ wrapper.keys = set(created.keys()) │
│ 183 │ for key, value in created.items(): │
│ 184 │ │ if key not in state: │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/traceback_util.py:16 │
│ 3 in reraise_with_filtered_traceback │
│ │
│ 160 def reraise_with_filtered_traceback(*args, **kwargs): │
│ 161 │ __tracebackhide__ = True │
│ 162 │ try: │
│ ❱ 163 │ return fun(*args, **kwargs) │
│ 164 │ except Exception as e: │
│ 165 │ mode = filtering_mode() │
│ 166 │ if is_under_reraiser(e) or mode == "off": │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:237 in │
│ cache_miss │
│ │
│ 234 │
│ 235 @api_boundary │
│ 236 def cache_miss(*args, **kwargs): │
│ ❱ 237 │ outs, out_flat, out_tree, args_flat = _python_pjit_helper( │
│ 238 │ │ fun, infer_params_fn, *args, **kwargs) │
│ 239 │ │
│ 240 │ executable = _read_most_recent_pjit_call_executable() │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:180 in │
│ _python_pjit_helper │
│ │
│ 177 │
│ 178 │
│ 179 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs): │
│ ❱ 180 args_flat, _, params, in_tree, out_tree, _ = infer_params_fn( │
│ 181 │ *args, **kwargs) │
│ 182 for arg in args_flat: │
│ 183 │ dispatch.check_arg(arg) │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/api.py:443 in │
│ infer_params │
│ │
│ 440 │ │ static_argnames=static_argnames, donate_argnums=donate_argnums, │
│ 441 │ │ device=device, backend=backend, keep_unused=keep_unused, │
│ 442 │ │ inline=inline, resource_env=None) │
│ ❱ 443 │ return pjit.common_infer_params(pjit_info_args, *args, **kwargs) │
│ 444 │ │
│ 445 │ has_explicit_sharding = pjit._pjit_explicit_sharding( │
│ 446 │ │ in_shardings, out_shardings, device, backend) │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:520 in │
│ common_infer_params │
│ │
│ 517 │ hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics, │
│ 518 │ tuple(isinstance(a, GDA) for a in args_flat), resource_env) │
│ 519 │
│ ❱ 520 jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( │
│ 521 │ flat_fun, hashable_pytree(out_shardings), global_in_avals, │
│ 522 │ HashableFunction(out_tree, closure=()), │
│ 523 │ ('jit' if resource_env is None else 'pjit')) │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/linear_util.py:301 │
│ in memoized_fun │
│ │
│ 298 │ ans, stores = result │
│ 299 │ fun.populate_stores(stores) │
│ 300 │ else: │
│ ❱ 301 │ ans = call(fun, *args) │
│ 302 │ cache[key] = (ans, fun.stores) │
│ 303 │ │
│ 304 │ return ans │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:932 in │
│ _pjit_jaxpr │
│ │
│ 929 │ with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " │
│ 930 │ │ │ │ │ │ │ │ "for pjit in {elapsed_time} sec", │
│ 931 │ │ │ │ │ │ │ │ │ event=dispatch.JAXPR_TRACE_EVENT): │
│ ❱ 932 │ jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic( │
│ 933 │ │ fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name)) │
│ 934 finally: │
│ 935 │ pxla.positional_semantics.val = prev_positional_val │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/profiler.py:314 in │
│ wrapper │
│ │
│ 311 @wraps(func) │
│ 312 def wrapper(*args, **kwargs): │
│ 313 │ with TraceAnnotation(name, **decorator_kwargs): │
│ ❱ 314 │ return func(*args, **kwargs) │
│ 315 │ return wrapper │
│ 316 return wrapper │
│ 317 │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:1985 in trace_to_jaxpr_dynamic │
│ │
│ 1982 │ │ │ │ │ │ keep_inputs: Optional[List[bool]] = None): │
│ 1983 with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore │
│ 1984 │ main.jaxpr_stack = () # type: ignore │
│ ❱ 1985 │ jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( │
│ 1986 │ fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) │
│ 1987 │ del main, fun │
│ 1988 return jaxpr, out_avals, consts │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:2002 in trace_to_subjaxpr_dynamic │
│ │
│ 1999 │ trace = DynamicJaxprTrace(main, core.cur_sublevel()) │
│ 2000 │ in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) │
│ 2001 │ in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] │
│ ❱ 2002 │ ans = fun.call_wrapped(*in_tracers_) │
│ 2003 │ out_tracers = map(trace.full_raise, ans) │
│ 2004 │ jaxpr, consts = frame.to_jaxpr(out_tracers) │
│ 2005 │ del fun, main, trace, frame, in_tracers, out_tracers, ans │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/linear_util.py:165 │
│ in call_wrapped │
│ │
│ 162 │ gen = gen_static_args = out_store = None │
│ 163 │ │
│ 164 │ try: │
│ ❱ 165 │ ans = self.f(*args, **dict(self.params, **kwargs)) │
│ 166 │ except: │
│ 167 │ # Some transformations yield from inside context managers, so we have to │
│ 168 │ # interrupt them before reraising the exception. Otherwise they will only │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:166 in init │
│ │
│ 163 @bind(jax.jit, static_argnums=[0], **kwargs) │
│ 164 def init(statics, rng, *args, **kwargs): │
│ 165 │ # Return only state so JIT can remove dead code for fast initialization. │
│ ❱ 166 │ s = fun({}, rng, *args, ignore=True, **dict(statics), **kwargs)[1] │
│ 167 │ return s │
│ 168 │
│ 169 @bind(jax.jit, static_argnums=[0], **kwargs) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified │
│ │
│ 74 │ before = CONTEXT │
│ 75 │ try: │
│ 76 │ CONTEXT = Context(state.copy(), rng, create, modify, ignore, []) │
│ ❱ 77 │ out = fun(*args, **kwargs) │
│ 78 │ state = dict(CONTEXT) │
│ 79 │ return out, state │
│ 80 │ finally: │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:80 in train │
│ │
│ 77 │ self.config.jax.jit and print('Tracing train function.') │
│ 78 │ metrics = {} │
│ 79 │ data = self.preprocess(data) │
│ ❱ 80 │ state, wm_outs, mets = self.wm.train(data, state) │
│ 81 │ metrics.update(mets) │
│ 82 │ context = {**data, **wm_outs['post']} │
│ 83 │ start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:151 in train │
│ │
│ 148 │
│ 149 def train(self, data, state): │
│ 150 │ modules = [self.encoder, self.rssm, *self.heads.values()] │
│ ❱ 151 │ mets, (state, outs, metrics) = self.opt( │
│ 152 │ │ modules, self.loss, data, state, has_aux=True) │
│ 153 │ metrics.update(mets) │
│ 154 │ return state, outs, metrics │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:410 in __call__ │
│ │
│ 407 │ │ loss *= sg(self.grad_scale.read()) │
│ 408 │ return loss, aux │
│ 409 │ metrics = {} │
│ ❱ 410 │ loss, params, grads, aux = nj.grad( │
│ 411 │ │ wrapped, modules, has_aux=True)(*args, **kwargs) │
│ 412 │ if not self.PARAM_COUNTS[self.path]: │
│ 413 │ count = sum([np.prod(x.shape) for x in params.values()]) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:142 in wrapper │
│ │
│ 139 backward = jax.value_and_grad(forward, has_aux=True) │
│ 140 @functools.wraps(backward) │
│ 141 def wrapper(*args, **kwargs): │
│ ❱ 142 │ _prerun(fun, *args, **kwargs) │
│ 143 │ assert all(isinstance(x, (str, Module)) for x in keys) │
│ 144 │ strs = [x for x in keys if isinstance(x, str)] │
│ 145 │ mods = [x for x in keys if isinstance(x, Module)] │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner │
│ │
│ 72 │ │ @wraps(func) │
│ 73 │ │ def inner(*args, **kwds): │
│ 74 │ │ │ with self._recreate_cm(): │
│ ❱ 75 │ │ │ │ return func(*args, **kwds) │
│ 76 │ │ return inner │
│ 77 │
│ 78 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:271 in _prerun │
│ │
│ 268 def _prerun(fun, *args, **kwargs): │
│ 269 if not context().create: │
│ 270 │ return │
│ ❱ 271 discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs) │
│ 272 jax.tree_util.tree_map( │
│ 273 │ lambda x: hasattr(x, 'delete') and x.delete(), discarded) │
│ 274 context().update(state) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified │
│ │
│ 74 │ before = CONTEXT │
│ 75 │ try: │
│ 76 │ CONTEXT = Context(state.copy(), rng, create, modify, ignore, []) │
│ ❱ 77 │ out = fun(*args, **kwargs) │
│ 78 │ state = dict(CONTEXT) │
│ 79 │ return out, state │
│ 80 │ finally: │
│ │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:402 in wrapped │
│ │
│ 399 │
│ 400 def __call__(self, modules, lossfn, *args, has_aux=False, **kwargs): │
│ 401 │ def wrapped(*args, **kwargs): │
│ ❱ 402 │ outs = lossfn(*args, **kwargs) │
│ 403 │ loss, aux = outs if has_aux else (outs, None) │
│ 404 │ assert loss.dtype == jnp.float32, (self.name, loss.dtype) │
│ 405 │ assert loss.shape == (), (self.name, loss.shape) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:161 in loss │
│ │
│ 158 │ prev_latent, prev_action = state │
│ 159 │ prev_actions = jnp.concatenate([ │
│ 160 │ │ prev_action[:, None], data['action'][:, :-1]], 1) │
│ ❱ 161 │ post, prior = self.rssm.observe( │
│ 162 │ │ embed, prev_actions, data['is_first'], prev_latent) │
│ 163 │ dists = {} │
│ 164 │ feats = {**post, 'embed': embed} │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/nets.py:60 in observe │
│ │
│ 57 │ step = lambda prev, inputs: self.obs_step(prev[0], *inputs) │
│ 58 │ inputs = swap(action), swap(embed), swap(is_first) │
│ 59 │ start = state, state │
│ ❱ 60 │ post, prior = jaxutils.scan(step, inputs, start, self._unroll) │
│ 61 │ post = {k: swap(v) for k, v in post.items()} │
│ 62 │ prior = {k: swap(v) for k, v in prior.items()} │
│ 63 │ return post, prior │
│ │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:73 in scan │
│ │
│ 70 def scan(fn, inputs, start, unroll=True, modify=False): │
│ 71 fn2 = lambda carry, inp: (fn(carry, inp),) * 2 │
│ 72 if not unroll: │
│ ❱ 73 │ return nj.scan(fn2, start, inputs, modify=modify)[1] │
│ 74 length = len(jax.tree_util.tree_leaves(inputs)[0]) │
│ 75 carrydef = jax.tree_util.tree_structure(start) │
│ 76 carry = start │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner │
│ │
│ 72 │ │ @wraps(func) │
│ 73 │ │ def inner(*args, **kwds): │
│ 74 │ │ │ with self._recreate_cm(): │
│ ❱ 75 │ │ │ │ return func(*args, **kwds) │
│ 76 │ │ return inner │
│ 77 │
│ 78 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:245 in scan │
│ │
│ 242 @jax.named_scope('scan') │
│ 243 def scan(fun, carry, xs, reverse=False, unroll=1, modify=False): │
│ 244 fun = pure(fun, nested=True) │
│ ❱ 245 _prerun(fun, carry, jax.tree_util.tree_map(lambda x: x[0], xs)) │
│ 246 length = len(jax.tree_util.tree_leaves(xs)[0]) │
│ 247 rngs = rng(length) │
│ 248 if modify: │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner │
│ │
│ 72 │ │ @wraps(func) │
│ 73 │ │ def inner(*args, **kwds): │
│ 74 │ │ │ with self._recreate_cm(): │
│ ❱ 75 │ │ │ │ return func(*args, **kwds) │
│ 76 │ │ return inner │
│ 77 │
│ 78 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:272 in _prerun │
│ │
│ 269 if not context().create: │
│ 270 │ return │
│ 271 discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs) │
│ ❱ 272 jax.tree_util.tree_map( │
│ 273 │ lambda x: hasattr(x, 'delete') and x.delete(), discarded) │
│ 274 context().update(state) │
│ 275 │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/tree_util.py:207 in │
│ tree_map │
│ │
│ 204 """ │
│ 205 leaves, treedef = tree_flatten(tree, is_leaf) │
│ 206 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest] │
│ ❱ 207 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) │
│ 208 │
│ 209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any: │
│ 210 return treedef.from_iterable_tree(xs) │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/tree_util.py:207 in │
│ <genexpr> │
│ │
│ 204 """ │
│ 205 leaves, treedef = tree_flatten(tree, is_leaf) │
│ 206 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest] │
│ ❱ 207 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) │
│ 208 │
│ 209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any: │
│ 210 return treedef.from_iterable_tree(xs) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:273 in <lambda> │
│ │
│ 270 │ return │
│ 271 discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs) │
│ 272 jax.tree_util.tree_map( │
│ ❱ 273 │ lambda x: hasattr(x, 'delete') and x.delete(), discarded) │
│ 274 context().update(state) │
│ 275 │
│ 276 │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/core.py:734 in │
│ delete │
│ │
│ 731 │ f"The 'copy_to_host_async' method is not available on the JAX Tracer object {self} │
│ 732 │
│ 733 def delete(self): │
│ ❱ 734 │ raise ConcretizationTypeError(self, │
│ 735 │ f"The delete() method was called on the JAX Tracer object {self}") │
│ 736 │
│ 737 def device(self): │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where
concrete value is expected: Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The delete() method was called on the JAX Tracer object
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function init at /home/weidong/dreamerv3/dreamerv3/ninjax.py:163 for
jit. This concrete value was not available in Python because it depends on the values of the arguments
'statics', 'rng', and 'args'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/weidong/dreamerv3/example.py:48 in <module> │
│ │
│ 45 │
│ 46 │
│ 47 if __name__ == '__main__': │
│ ❱ 48 main() │
│ 49 │
│ │
│ /home/weidong/dreamerv3/example.py:44 in main │
│ │
│ 41 replay = embodied.replay.Uniform( │
│ 42 │ config.batch_length, config.replay_size, logdir / 'replay') │
│ 43 args = config.run.update(batch_steps=config.batch_size * config.batch_length) │
│ ❱ 44 embodied.run.train(agent, env, replay, logger, args) │
│ 45 │
│ 46 │
│ 47 if __name__ == '__main__': │
│ │
│ /home/weidong/dreamerv3/dreamerv3/embodied/run/train.py:79 in train │
│ │
│ 76 for _ in range(args.pretrain): │
│ 77 │ with timer.scope('dataset'): │
│ 78 │ batch = next(dataset) │
│ ❱ 79 │ _, state[0], _ = agent.train(batch, state[0]) │
│ 80 │
│ 81 batch = [None] │
│ 82 def train_step(tran, worker): │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner │
│ │
│ 72 │ │ @wraps(func) │
│ 73 │ │ def inner(*args, **kwds): │
│ 74 │ │ │ with self._recreate_cm(): │
│ ❱ 75 │ │ │ │ return func(*args, **kwds) │
│ 76 │ │ return inner │
│ 77 │
│ 78 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:80 in train │
│ │
│ 77 │ rng = self._next_rngs(mirror=not self.varibs) │
│ 78 │ if state is None: │
│ 79 │ state, self.varibs = self._init_train(self.varibs, rng, data['is_first']) │
│ ❱ 80 │ (outs, state, mets), self.varibs = self._train( │
│ 81 │ │ self.varibs, rng, data, state) │
│ 82 │ outs = self._convert_outs(outs) │
│ 83 │ mets = self._convert_mets(mets) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:181 in wrapper │
│ │
│ 178 │ statics = tuple(sorted([(k, v) for k, v in kwargs.items() if k in static])) │
│ 179 │ kwargs = {k: v for k, v in kwargs.items() if k not in static} │
│ 180 │ if not hasattr(wrapper, 'keys'): │
│ ❱ 181 │ created = init(statics, rng, *args, **kwargs) │
│ 182 │ wrapper.keys = set(created.keys()) │
│ 183 │ for key, value in created.items(): │
│ 184 │ │ if key not in state: │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:166 in init │
│ │
│ 163 @bind(jax.jit, static_argnums=[0], **kwargs) │
│ 164 def init(statics, rng, *args, **kwargs): │
│ 165 │ # Return only state so JIT can remove dead code for fast initialization. │
│ ❱ 166 │ s = fun({}, rng, *args, ignore=True, **dict(statics), **kwargs)[1] │
│ 167 │ return s │
│ 168 │
│ 169 @bind(jax.jit, static_argnums=[0], **kwargs) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified │
│ │
│ 74 │ before = CONTEXT │
│ 75 │ try: │
│ 76 │ CONTEXT = Context(state.copy(), rng, create, modify, ignore, []) │
│ ❱ 77 │ out = fun(*args, **kwargs) │
│ 78 │ state = dict(CONTEXT) │
│ 79 │ return out, state │
│ 80 │ finally: │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:80 in train │
│ │
│ 77 │ self.config.jax.jit and print('Tracing train function.') │
│ 78 │ metrics = {} │
│ 79 │ data = self.preprocess(data) │
│ ❱ 80 │ state, wm_outs, mets = self.wm.train(data, state) │
│ 81 │ metrics.update(mets) │
│ 82 │ context = {**data, **wm_outs['post']} │
│ 83 │ start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:151 in train │
│ │
│ 148 │
│ 149 def train(self, data, state): │
│ 150 │ modules = [self.encoder, self.rssm, *self.heads.values()] │
│ ❱ 151 │ mets, (state, outs, metrics) = self.opt( │
│ 152 │ │ modules, self.loss, data, state, has_aux=True) │
│ 153 │ metrics.update(mets) │
│ 154 │ return state, outs, metrics │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:410 in __call__ │
│ │
│ 407 │ │ loss *= sg(self.grad_scale.read()) │
│ 408 │ return loss, aux │
│ 409 │ metrics = {} │
│ ❱ 410 │ loss, params, grads, aux = nj.grad( │
│ 411 │ │ wrapped, modules, has_aux=True)(*args, **kwargs) │
│ 412 │ if not self.PARAM_COUNTS[self.path]: │
│ 413 │ count = sum([np.prod(x.shape) for x in params.values()]) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:142 in wrapper │
│ │
│ 139 backward = jax.value_and_grad(forward, has_aux=True) │
│ 140 @functools.wraps(backward) │
│ 141 def wrapper(*args, **kwargs): │
│ ❱ 142 │ _prerun(fun, *args, **kwargs) │
│ 143 │ assert all(isinstance(x, (str, Module)) for x in keys) │
│ 144 │ strs = [x for x in keys if isinstance(x, str)] │
│ 145 │ mods = [x for x in keys if isinstance(x, Module)] │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner │
│ │
│ 72 │ │ @wraps(func) │
│ 73 │ │ def inner(*args, **kwds): │
│ 74 │ │ │ with self._recreate_cm(): │
│ ❱ 75 │ │ │ │ return func(*args, **kwds) │
│ 76 │ │ return inner │
│ 77 │
│ 78 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:271 in _prerun │
│ │
│ 268 def _prerun(fun, *args, **kwargs): │
│ 269 if not context().create: │
│ 270 │ return │
│ ❱ 271 discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs) │
│ 272 jax.tree_util.tree_map( │
│ 273 │ lambda x: hasattr(x, 'delete') and x.delete(), discarded) │
│ 274 context().update(state) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified │
│ │
│ 74 │ before = CONTEXT │
│ 75 │ try: │
│ 76 │ CONTEXT = Context(state.copy(), rng, create, modify, ignore, []) │
│ ❱ 77 │ out = fun(*args, **kwargs) │
│ 78 │ state = dict(CONTEXT) │
│ 79 │ return out, state │
│ 80 │ finally: │
│ │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:402 in wrapped │
│ │
│ 399 │
│ 400 def __call__(self, modules, lossfn, *args, has_aux=False, **kwargs): │
│ 401 │ def wrapped(*args, **kwargs): │
│ ❱ 402 │ outs = lossfn(*args, **kwargs) │
│ 403 │ loss, aux = outs if has_aux else (outs, None) │
│ 404 │ assert loss.dtype == jnp.float32, (self.name, loss.dtype) │
│ 405 │ assert loss.shape == (), (self.name, loss.shape) │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:161 in loss │
│ │
│ 158 │ prev_latent, prev_action = state │
│ 159 │ prev_actions = jnp.concatenate([ │
│ 160 │ │ prev_action[:, None], data['action'][:, :-1]], 1) │
│ ❱ 161 │ post, prior = self.rssm.observe( │
│ 162 │ │ embed, prev_actions, data['is_first'], prev_latent) │
│ 163 │ dists = {} │
│ 164 │ feats = {**post, 'embed': embed} │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper │
│ │
│ 350 def wrapper(self, *args, **kwargs): │
│ 351 │ with scope(self._path, absolute=True): │
│ 352 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 353 │ │ return method(self, *args, **kwargs) │
│ 354 return wrapper │
│ 355 │
│ 356 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/nets.py:60 in observe │
│ │
│ 57 │ step = lambda prev, inputs: self.obs_step(prev[0], *inputs) │
│ 58 │ inputs = swap(action), swap(embed), swap(is_first) │
│ 59 │ start = state, state │
│ ❱ 60 │ post, prior = jaxutils.scan(step, inputs, start, self._unroll) │
│ 61 │ post = {k: swap(v) for k, v in post.items()} │
│ 62 │ prior = {k: swap(v) for k, v in prior.items()} │
│ 63 │ return post, prior │
│ │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:73 in scan │
│ │
│ 70 def scan(fn, inputs, start, unroll=True, modify=False): │
│ 71 fn2 = lambda carry, inp: (fn(carry, inp),) * 2 │
│ 72 if not unroll: │
│ ❱ 73 │ return nj.scan(fn2, start, inputs, modify=modify)[1] │
│ 74 length = len(jax.tree_util.tree_leaves(inputs)[0]) │
│ 75 carrydef = jax.tree_util.tree_structure(start) │
│ 76 carry = start │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner │
│ │
│ 72 │ │ @wraps(func) │
│ 73 │ │ def inner(*args, **kwds): │
│ 74 │ │ │ with self._recreate_cm(): │
│ ❱ 75 │ │ │ │ return func(*args, **kwds) │
│ 76 │ │ return inner │
│ 77 │
│ 78 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:245 in scan │
│ │
│ 242 @jax.named_scope('scan') │
│ 243 def scan(fun, carry, xs, reverse=False, unroll=1, modify=False): │
│ 244 fun = pure(fun, nested=True) │
│ ❱ 245 _prerun(fun, carry, jax.tree_util.tree_map(lambda x: x[0], xs)) │
│ 246 length = len(jax.tree_util.tree_leaves(xs)[0]) │
│ 247 rngs = rng(length) │
│ 248 if modify: │
│ │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner │
│ │
│ 72 │ │ @wraps(func) │
│ 73 │ │ def inner(*args, **kwds): │
│ 74 │ │ │ with self._recreate_cm(): │
│ ❱ 75 │ │ │ │ return func(*args, **kwds) │
│ 76 │ │ return inner │
│ 77 │
│ 78 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:272 in _prerun │
│ │
│ 269 if not context().create: │
│ 270 │ return │
│ 271 discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs) │
│ ❱ 272 jax.tree_util.tree_map( │
│ 273 │ lambda x: hasattr(x, 'delete') and x.delete(), discarded) │
│ 274 context().update(state) │
│ 275 │
│ │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:273 in <lambda> │
│ │
│ 270 │ return │
│ 271 discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs) │
│ 272 jax.tree_util.tree_map( │
│ ❱ 273 │ lambda x: hasattr(x, 'delete') and x.delete(), discarded) │
│ 274 context().update(state) │
│ 275 │
│ 276 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The delete() method was called on the JAX Tracer object
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function init at /home/weidong/dreamerv3/dreamerv3/ninjax.py:163 for
jit. This concrete value was not available in Python because it depends on the values of the arguments
'statics', 'rng', and 'args'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
I think a JAX setup and troubleshooting guide should be added to the README.