
Comments (11)

jobesu14 avatar jobesu14 commented on May 20, 2024 1

I didn't manage to make the above snippets work. I kept having issues with the parameters of the agent's initial_policy calls.
Here is what worked for me:

Basically, it adds a callback for each step in the embodied/run/eval_only.py script. It is a very minimal change to the original codebase, and it decouples the inner workings of Dreamer from the pygame rendering code quite nicely.

Hope this helps.
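As a rough illustration of that approach, here is a minimal sketch of an eval loop that invokes an optional per-step callback; `run_eval`, `env_step`, and `policy` are hypothetical stand-ins, not the actual eval_only.py API:

```python
def run_eval(env_step, policy, on_step=None, max_steps=100):
    """Minimal eval loop that invokes an optional callback after every step."""
    state = None
    obs = env_step({'reset': True})          # first call resets the env
    for _ in range(max_steps):
        act, state = policy(obs, state)      # thread the policy state through
        obs = env_step(act)
        if on_step is not None:
            on_step(obs)                     # e.g. draw obs['image'] with pygame
        if obs.get('is_last'):
            break
    return obs
```

The pygame drawing then lives entirely in the callback, so the agent loop never has to import or know about the rendering code.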

from dreamerv3.

vtopiatech avatar vtopiatech commented on May 20, 2024 1

After 2 long days, I found the answer based on this issue! Leaving it here for anyone who wants to render their DRL agents playing:

In from_gym.py, add the import and the four lines that start with plt:

import matplotlib.pyplot as plt

  def step(self, action):
    if action['reset'] or self._done:
      self._done = False
      obs = self._env.reset()
      return self._obs(obs, 0.0, is_first=True)
    if self._act_dict:
      action = self._unflatten(action)
    else:
      action = action[self._act_key]
    obs, reward, self._done, self._info = self._env.step(action)
    plt.imshow(obs)
    plt.show(block=False)
    plt.pause(0.001)  # Pause to ensure the plot updates
    plt.clf()  # Clear the plot so that the next image replaces this one
    return self._obs(
        obs, reward,
        is_last=bool(self._done),
        is_terminal=bool(self._info.get('is_terminal', self._done)))
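One caveat with the snippet above: plt.imshow needs an HWC image array, while some envs return an observation dict. A small guard keeps the rendering robust; the 'image' key name here is an assumption about your env, not part of the dreamerv3 API:

```python
import numpy as np

def to_frame(obs):
    """Pick a renderable HWC image out of an observation.

    The 'image' key name is an assumption about the env's obs dict.
    """
    if isinstance(obs, dict):
        obs = obs['image']
    frame = np.asarray(obs)
    # plt.imshow expects height x width x channels, with 1, 3, or 4 channels.
    assert frame.ndim == 3 and frame.shape[-1] in (1, 3, 4), frame.shape
    return frame
```

With that in place, `plt.imshow(to_frame(obs))` works for both raw arrays and dict observations.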


danijar avatar danijar commented on May 20, 2024

Hey, I updated the checkpoint code and run scripts to make this easy. You can now train an agent as normal:

python dreamerv3/train.py --run.logdir ~/logdir/train --configs crafter --run.script train

And then load the agent to evaluate it in an environment without further training:

python dreamerv3/train.py --run.logdir ~/logdir/eval --configs crafter \
  --run.script eval_only --run.from_checkpoint ~/logdir/train/checkpoint.pkl

You also asked for a minimal example to load the agent yourself. The relevant code is in dreamerv3/train.py and run/eval_only.py and boils down to:

env = ...
config = ...
step = embodied.Counter()
agent = Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load('path/to/checkpoint.pkl', keys=['agent'])
state = None
obs = ...  # observation from the env, with a leading batch dimension
act, state = agent.policy(obs, state, mode='eval')


jobesu14 avatar jobesu14 commented on May 20, 2024

Great, thank you so much.


ThomasRochefortB avatar ThomasRochefortB commented on May 20, 2024

Hello! Any idea how the initial state should be formatted? I am trying to run the minimal code you provided above with a gym environment. However:

env =   # Replace this with your Gym env.
env = from_gym.FromGym(env)
obs=env._env.reset()
obs=env._obs(obs, 0.0, is_first=True)
obs = {k: embodied.convert(v) for k, v in obs.items()}
act, state = agent.policy(obs, state=[None], mode='eval')

returns an error.

From the policy() function I can see that it is expecting:

def policy(self, obs, state, mode='train'):
  self.config.jax.jit and print('Tracing policy function.')
  obs = self.preprocess(obs)
  (prev_latent, prev_action), task_state, expl_state = state  # <- line that fails

Is there a way to initialize the state?


danijar avatar danijar commented on May 20, 2024

You can just pass in None as the first state and from then on pass back the state that it returns.

This is done in jaxagent.py.
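In isolation, the pattern looks like this; `StubAgent` is a toy stand-in that just counts calls, not the real jaxagent wrapper:

```python
# Toy stand-in for the real agent; it only counts calls to illustrate
# the pass-None-first, then-thread-the-returned-state pattern.
class StubAgent:
    def policy(self, obs, state, mode='eval'):
        # On the first call state is None; the real agent would build
        # its initial latent state here.
        step = 0 if state is None else state + 1
        return {'action': step % 4}, step

agent = StubAgent()
state = None  # first call: let the agent initialize its own state
for obs in range(3):
    act, state = agent.policy(obs, state, mode='eval')
```

The caller never constructs the state itself; it only passes back whatever the previous `policy` call returned.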


jobesu14 avatar jobesu14 commented on May 20, 2024

@ThomasRochefortB did you manage to run the minimal snippet successfully?

On my side, I run into an error that seems to come from the observation data not being formatted as expected when passed to the agent policy.

Here is what I did:

Training works well:

python dreamerv3/train.py --logdir ~/logdir/test_1 --configs crafter

And then, when I try to run the minimal snippet inference like that:

LOGDIR = Path('~/logdir/test_1')
config = embodied.Config.load(str(LOGDIR / 'config.yaml'))
env = crafter.Env()  # Replace this with your Gym env.
env = from_gym.FromGym(env)
env = dreamerv3.wrap_env(env, config)
# env = embodied.BatchEnv([env], parallel=False)

step = embodied.Counter()
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load(str(LOGDIR / 'checkpoint.ckpt'), keys=['agent'])

obs = env._env.reset()
obs = env._obs(obs, 0.0, is_first=True)
obs = {k: embodied.convert(v) for k, v in obs.items()}
state = None

while True:
    act, state = agent.policy(obs, state, mode='eval')  # error comes from that line
    acts = {k: v for k, v in act.items() if not k.startswith('log_')}
    obs, reward, done, _ = env.step(acts)  # act['action']

I get an error from the line act, state = agent.policy(obs, state, mode='eval') that points to jaxagent.py line 144: IndexError: tuple index out of range.


danijar avatar danijar commented on May 20, 2024

@jobesu14 The easiest way is to take example.py and replace embodied.run.train(...) at the end with embodied.run.eval_only(agent, env, logger, args). You can also look at embodied/run/eval_only.py for the details and simplify that further as needed.

I think the issue in your snippet is that the policy expects a batch size. I think it should look something like the following but don't have the time to test it right now:

import numpy as np

logdir = embodied.Path('~/logdir/test_1')
config = embodied.Config.load(logdir / 'config.yaml')

env = crafter.Env()
env = from_gym.FromGym(env)
env = dreamerv3.wrap_env(env, config)

step = embodied.Counter()
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load(logdir / 'checkpoint.ckpt', keys=['agent'])

state = None
act = {'action': env.act_space['action'].sample(), 'reset': np.array(True)}
while True:
    obs = env.step(act)
    obs = {k: v[None] for k, v in obs.items()}
    act, state = agent.policy(obs, state, mode='eval')
    act = {'action': act['action'][0], 'reset': obs['is_last'][0]}
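For clarity, the `v[None]` indexing in the loop above prepends a batch axis of size 1, which is the batch dimension the policy expects; the shapes below assume a typical 64x64 RGB image observation:

```python
import numpy as np

# A typical single-env observation (shapes are illustrative).
obs = {'image': np.zeros((64, 64, 3), np.uint8), 'is_last': np.array(False)}

# v[None] prepends a batch axis of size 1, matching the policy's input format.
batched = {k: np.asarray(v)[None] for k, v in obs.items()}
# batched['image'].shape == (1, 64, 64, 3)
# batched['is_last'].shape == (1,)
```

Correspondingly, `act['action'][0]` in the loop strips that batch axis again before stepping the single env.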


cameronberg avatar cameronberg commented on May 20, 2024

@danijar Any chance you've been able to figure out rendering from your snippet above? This still produces the error mentioned above for me. It would be amazing to have code for example.py that can generically render gym environments + roll out trained policies.


danijar avatar danijar commented on May 20, 2024

If your env returns an image key as part of the observation dictionary, it will already get rendered and can be viewed in TensorBoard. Does that work for your use case?


vtopiatech avatar vtopiatech commented on May 20, 2024

Thanks for such a great research algo @danijar! Wondering if there's any good way now to render the AI playing the game?

