Comments (11)
I didn't manage to make the above snippets work. I kept having issues with the parameters of the agent's initial_policy calls.
Here is what worked for me
Basically, I added a callback for each step in the embodied/run/eval_only.py script. It is a very minimal change to the original codebase, and it also decouples the inner workings of Dreamer quite nicely from the pygame rendering code.
Hope this helps.
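The shape of that change can be sketched generically. This is a toy illustration of the callback pattern only, not the real embodied API: the StepDriver class and its on_step method are made up here.

```python
# Toy illustration of decoupling rendering from the agent loop via a
# per-step callback. StepDriver is hypothetical, not the embodied API.

class StepDriver:
    def __init__(self):
        self._callbacks = []

    def on_step(self, fn):
        # Register a function to be called with each new observation.
        self._callbacks.append(fn)

    def run(self, policy, env, steps):
        obs = env.reset()
        for _ in range(steps):
            obs = env.step(policy(obs))
            for fn in self._callbacks:
                fn(obs)  # e.g. a pygame blit, fully outside the agent.
```

The evaluation script then only needs one extra line to register the renderer, and the agent code never sees any rendering logic.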
from dreamerv3.
After 2 long days, found the answer based on this issue! Leaving here for anyone who wants to render their DRL AIs playing:
In from_gym.py, add the matplotlib import and the four lines that start with plt:
import matplotlib.pyplot as plt

def step(self, action):
  if action['reset'] or self._done:
    self._done = False
    obs = self._env.reset()
    return self._obs(obs, 0.0, is_first=True)
  if self._act_dict:
    action = self._unflatten(action)
  else:
    action = action[self._act_key]
  obs, reward, self._done, self._info = self._env.step(action)
  plt.imshow(obs)
  plt.show(block=False)
  plt.pause(0.001)  # Pause briefly so the plot window updates.
  plt.clf()  # Clear the figure so the next frame replaces this one.
  return self._obs(
      obs, reward,
      is_last=bool(self._done),
      is_terminal=bool(self._info.get('is_terminal', self._done)))
Hey, I updated the checkpoint code and run scripts to make this easy. You can now train an agent as normal:
python dreamerv3/train.py --run.logdir ~/logdir/train --configs crafter --run.script train
And then load the agent to evaluate it in an environment without further training:
python dreamerv3/train.py --run.logdir ~/logdir/eval --configs crafter \
--run.script eval_only --run.from_checkpoint ~/logdir/train/checkpoint.pkl
You also asked for a minimal example to load the agent yourself. The relevant code is in dreamerv3/train.py and run/eval_only.py and boils down to:
env = ...
config = ...
step = embodied.Counter()
agent = Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load('path/to/checkpoint.pkl', keys=['agent'])
state = None
obs = ...  # Batched observation dict from the environment.
act, state = agent.policy(obs, state, mode='eval')
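The checkpoint usage above follows an attach-then-load pattern: restorable objects are assigned as attributes, and keys selects which of them to restore. A toy stand-in using pickle shows the shape of that API; the real embodied.Checkpoint delegates to each attached object's own save/load logic, whereas this sketch just pickles attributes.

```python
import pickle


class ToyCheckpoint:
    # Toy stand-in for the attach-then-load pattern; not the real
    # embodied.Checkpoint implementation.

    def save(self, path):
        # Persist every attached attribute by name.
        with open(path, 'wb') as f:
            pickle.dump(dict(vars(self)), f)

    def load(self, path, keys=None):
        # Restore either the requested keys or everything that was saved.
        with open(path, 'rb') as f:
            state = pickle.load(f)
        for k in (keys if keys is not None else state):
            setattr(self, k, state[k])
```

Attaching the agent as an attribute before calling load is what tells the checkpoint where to put the restored parameters.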
Great, thank you so much.
Hello! Any idea how the initial state should be formatted? I am trying to run the minimal code you provided above with a gym environment. However:
env = ...  # Replace this with your Gym env.
env = from_gym.FromGym(env)
obs = env._env.reset()
obs = env._obs(obs, 0.0, is_first=True)
obs = {k: embodied.convert(v) for k, v in obs.items()}
act, state = agent.policy(obs, state=[None], mode='eval')
returns an error.
From the policy() function I can see that it is expecting:
def policy(self, obs, state, mode='train'):
  self.config.jax.jit and print('Tracing policy function.')
  obs = self.preprocess(obs)
  (prev_latent, prev_action), task_state, expl_state = state  # <- fails here
Is there a way to initialize the state?
You can just pass in None as the first state and from then on pass back the state that it returns. This is done in jaxagent.py.
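That state-threading pattern is the usual one for recurrent policies: start with None, let the policy substitute its own initial state, and always pass back what it returned. A toy sketch of the pattern follows; the policy body here is made up, and the real initial state is constructed inside jaxagent.py.

```python
def policy(obs, state, mode='eval'):
    # Toy recurrent policy: the state here is just a step counter.
    if state is None:
        state = 0  # The real agent builds its initial latent state here.
    state += 1
    action = 0  # Placeholder action.
    return action, state


state = None  # First call: None tells the policy to initialize itself.
for obs in [1.0, 2.0, 3.0]:
    action, state = policy(obs, state, mode='eval')
```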
@ThomasRochefortB did you manage to run the minimal snippet successfully?
On my side, I run into an error that seems to come from the observation data not being formatted as expected when passed to the agent policy.
Here is what I did:
Training works well:
python dreamerv3/train.py --logdir ~/logdir/test_1 --configs crafter
And then, when I try to run the minimal snippet inference like that:
LOGDIR = Path('~/logdir/test_1')
config = embodied.Config.load(str(LOGDIR / 'config.yaml'))
env = crafter.Env() # Replace this with your Gym env.
env = from_gym.FromGym(env)
env = dreamerv3.wrap_env(env, config)
# env = embodied.BatchEnv([env], parallel=False)
step = embodied.Counter()
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load(str(LOGDIR / 'checkpoint.ckpt'), keys=['agent'])
obs = env._env.reset()
obs = env._obs(obs, 0.0, is_first=True)
obs = {k: embodied.convert(v) for k, v in obs.items()}
state = None
while True:
  act, state = agent.policy(obs, state, mode='eval')  # Error comes from this line.
  acts = {k: v for k, v in act.items() if not k.startswith('log_')}
  obs, reward, done, _ = env.step(acts)  # act['action']
I get an error from the line act, state = agent.policy(obs, state, mode='eval') that points to jaxagent.py line 144: IndexError: tuple index out of range.
@jobesu14 The easiest way is to take example.py and replace embodied.run.train(...) at the end with embodied.run.eval_only(agent, env, logger, args). You can also look at embodied/run/eval_only.py for the details and simplify it further as needed.
I think the issue in your snippet is that the policy expects a batch dimension. I think it should look something like the following, but I don't have the time to test it right now:
import numpy as np  # Needed for np.array below.

logdir = embodied.Path('~/logdir/test_1')
config = embodied.Config.load(logdir / 'config.yaml')
env = crafter.Env()
env = from_gym.FromGym(env)
env = dreamerv3.wrap_env(env, config)
step = embodied.Counter()
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load(logdir / 'checkpoint.ckpt', keys=['agent'])
state = None
act = {'action': env.act_space['action'].sample(), 'reset': np.array(True)}
while True:
  obs = env.step(act)
  obs = {k: v[None] for k, v in obs.items()}
  act, state = agent.policy(obs, state, mode='eval')
  act = {'action': act['action'][0], 'reset': obs['is_last'][0]}
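The key difference from the earlier failing snippet is the `v[None]` indexing, which prepends a batch axis of size one so a single environment's observation matches the batched layout the policy expects; indexing with `[0]` then strips that axis from the returned action. In NumPy terms (the observation keys here are just examples):

```python
import numpy as np

# A single-environment observation: an image and a scalar flag.
obs = {'image': np.zeros((64, 64, 3)), 'is_last': np.array(False)}

# v[None] is shorthand for np.expand_dims(v, 0): it adds a leading
# batch axis, turning (64, 64, 3) into (1, 64, 64, 3) and () into (1,).
batched = {k: v[None] for k, v in obs.items()}

# After the policy call, [0] removes the batch axis again.
unbatched = batched['image'][0]
```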
@danijar Any chance you've been able to figure out rendering from your snippet above? This still produces the error mentioned above for me. It would be amazing to have code for example.py that can generically render gym environments and roll out trained policies.
If your env returns an image key as part of the observation dictionary, it will already get rendered and can be viewed in TensorBoard. Does that work for your use case?
Thanks for such a great research algo @danijar! Wondering if there's any good way now to render
the AI playing the game?