gymnax's Introduction

Hi there 🤗 I am Rob - a Research Scientist @SakanaAI & PhD student @ TU Berlin.
I work on meta-optimization, hardware-accelerated evolution & nature-inspired foundation models 🧬.
I maintain evosax 🦎, gymnax 🏋️ & the MLE-Infrastructure 🤹.
Previously, I was a Graduate Student Researcher @ Google DM (TKY 🗼) & interned @ Legacy DM (LDN 🇬🇧).

gymnax's People

Contributors

aidandos, clement-bonnet, davidslayback, edantoledo, jaronsgit, kerajli, luchris429, ludgerpaehler, roberttlange, smonsays, sotetsuk, ziksby

gymnax's Issues

Jittable `Environment` class

Similar to how distributions work in distrax, I want to change the API to work with a jittable environment class. E.g.

env = gymnax.make('env_name')
obs, state = env.reset(key)
obs, state, reward, done, info = env.step(key, state, action)

Hence, the environment parameters are "absorbed" in the class instance. This should not be too difficult as long as we are careful about the pytree.
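A minimal sketch of the proposed interface, using a made-up toy environment (`RandomWalk` and `WalkState` are hypothetical, not gymnax classes). Parameters are plain Python attributes on the instance, and the methods are pure functions of their array arguments, so `jax.jit` can wrap the bound methods directly.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class WalkState(NamedTuple):
    pos: jax.Array   # current cell index
    time: jax.Array  # steps taken so far


class RandomWalk:
    """Toy 1-D environment whose parameters are stored on the instance."""

    def __init__(self, size: int = 5, max_steps: int = 20):
        # Plain Python ints: jax.jit bakes them in as compile-time constants.
        self.size = size
        self.max_steps = max_steps

    def reset(self, key: jax.Array) -> Tuple[jax.Array, WalkState]:
        pos = jax.random.randint(key, (), 0, self.size)
        state = WalkState(pos=pos, time=jnp.zeros((), jnp.int32))
        return self.get_obs(state), state

    def step(self, key: jax.Array, state: WalkState, action):
        # Action 1 moves right, anything else moves left; the key is unused here.
        delta = jnp.where(action == 1, 1, -1)
        pos = jnp.clip(state.pos + delta, 0, self.size - 1)
        new_state = WalkState(pos=pos, time=state.time + 1)
        at_goal = pos == self.size - 1
        reward = at_goal.astype(jnp.float32)
        done = jnp.logical_or(at_goal, new_state.time >= self.max_steps)
        return self.get_obs(new_state), new_state, reward, done, {}

    def get_obs(self, state: WalkState) -> jax.Array:
        return jax.nn.one_hot(state.pos, self.size)


env = RandomWalk(size=5)
reset = jax.jit(env.reset)  # `self` is closed over, not traced
step = jax.jit(env.step)
```

One pytree caveat: because the attributes are baked in at trace time, changing `env.size` after jitting is not picked up by the already-compiled functions. Registering the class as a pytree, or passing parameters explicitly, avoids that.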

References:

Four Rooms (Sutton et al. 1999) environment

Implement the classic Four Rooms environment, starting from the old NumPy implementation from the HRL MSc thesis:

import numpy as np
import copy

# Action definitions
RIGHT = 0
UP    = 1
LEFT  = 2
DOWN  = 3


class RoomWorld():
    """The environment for Sutton's semi-MDP HRL.
    """
    def __init__(self, goal_position=(7, 9), env_noise=0.1):
        """Map of the rooms. -1 indicates wall, 0 indicates hallway,
           positive numbers indicate numbered rooms
        """
        self.numbered_map = np.array([
        [-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1],
        [-1, 1, 1, 1, 1, 1,-1, 2, 2, 2, 2, 2,-1],
        [-1, 1, 1, 1, 1, 1,-1, 2, 2, 2, 2, 2,-1],
        [-1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2,-1],
        [-1, 1, 1, 1, 1, 1,-1, 2, 2, 2, 2, 2,-1],
        [-1, 1, 1, 1, 1, 1,-1, 2, 2, 2, 2, 2,-1],
        [-1,-1, 0,-1,-1,-1,-1, 2, 2, 2, 2, 2,-1],
        [-1, 3, 3, 3, 3, 3,-1,-1,-1, 0,-1,-1,-1],
        [-1, 3, 3, 3, 3, 3,-1, 4, 4, 4, 4, 4,-1],
        [-1, 3, 3, 3, 3, 3,-1, 4, 4, 4, 4, 4,-1],
        [-1, 3, 3, 3, 3, 3, 0, 4, 4, 4, 4, 4,-1],
        [-1, 3, 3, 3, 3, 3,-1, 4, 4, 4, 4, 4,-1],
        [-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1]])
        self.walkability_map = np.array([
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
        self.state_space   = np.argwhere(self.walkability_map)
        self.action_space  = np.arange(4)
        self.goal_position = np.array(goal_position)
        self.action_success_rate = 1 - env_noise
        self.agents = [] # agents affect each other's observations, so should be included
        # Rewards
        self.step_reward      = 0.0 #-0.1 (Sutton used 0 and depended on discounting effect of gamma to push toward more efficient policies)
        self.collision_reward = 0.0 # was -0.1 at first, but spending a
                                    # timestep without moving is a penalty
        self.goal_reward      = 1.#10.
        self.invalid_plan_reward = 0.0#-10.


    def add_agent(self,agent):
        """Adds an agent to the environment after giving it an identifier
        """
        agent.sebango = len(self.agents) + 2
        self.agents.append(agent)

    def move_agent(self,direction,sebango=2):
        """Attempts moving an agent in a specified direction.
           If the move would put the agent in a wall, the agent remains
           where it is and the collision is reported, so that
           evaluate_reward can add collision_reward.
        """
        agent  = self.agents[sebango-2]
        new_pos = agent.move(direction)
        if self.walkability_map[tuple(new_pos)].all():
            agent.set_position(new_pos)
            collision = False
        else:
            collision = True
        return collision

    def evaluate_reward(self,sebango=2,collision=False):
        """Calculates the reward to be given for the current timestep after an
           action has been taken.
        """
        agent  = self.agents[sebango-2]
        reward = self.step_reward
        done   = False
        if collision:
            reward += self.collision_reward
        if (agent.get_position() == self.goal_position).all():
            reward += self.goal_reward
            done = True
        return reward, done

    def get_observation_map(self):
        """Returns the observation of the current state as a walkability map
           with agents (sebango) and goal position (-1) labeled
        """
        obs = copy.copy(self.walkability_map)
        for ag in self.agents:
            obs[tuple(ag.get_position())] = ag.sebango
        obs[tuple(self.goal_position)] = -1
        return obs

    def get_observation_pos(self,sebango):
        """Returns the observation of the current state as the position of the
           agent indicated by sebango.
           Assumes single agent and static goal location so only need agent pos
        """
        return self.agents[sebango-2].get_position()

    def step(self,direction,sebango=2):
        """Takes one timestep with a specific direction.
           Only deals with primitive actions.
           Determines the actual direction of motion stochastically
           Determines the reward and returns reward and observation.
           Observation is the walkability map + other info:
             - the agent indicated by its sebango (a number 2 or greater)
             - The goal is indicated as -1 in the observation map.
        """
        roll   = np.random.random()
        sr = self.action_success_rate
        fr = 1.0 - sr
        if roll <= sr:
            coll = self.move_agent(direction,sebango)
        elif roll <= sr+fr/3.:
            coll = self.move_agent((direction+1)%4,sebango)
        elif roll <= sr+fr*2./3.:
            coll = self.move_agent((direction+2)%4,sebango)
        else:
            coll = self.move_agent((direction+3)%4,sebango)
        obs = self.get_observation_pos(sebango)
        reward, done = self.evaluate_reward(sebango, collision=coll)
        return obs, reward, done

    def reset(self, random_placement=False):
        """Resets the state of the world, putting all registered  agents back
           to their initial positions (positions set at instantiation),
           unless random_placement = True
        """
        if random_placement:
            random_index     = np.random.randint(low=0,
                    high=self.state_space.shape[0],size=len(self.agents))
            for i,ag in enumerate(self.agents):
                ag.set_position(self.state_space[random_index[i]])
        else:
            for ag in self.agents:
                ag.set_position(ag.initial_position)
        obs = self.get_observation_pos(2)    # CURRENTLY ASSUMING ONE AGENT!
        return obs
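A possible starting point for the JAX port: a pure, jittable single-agent step that reproduces the map and the slip noise of `RoomWorld.step`. It only covers the single-agent case and leaves out multi-agent observations. Position is a `(row, col)` array. The direction offsets in `MOVES` are an assumption, because the original `Agent.move` is not shown.

```python
import jax
import jax.numpy as jnp

# Direction -> (d_row, d_col) for RIGHT=0, UP=1, LEFT=2, DOWN=3.
# Assumption: Agent.move is not shown, so "UP decreases the row index" is a guess.
MOVES = jnp.array([[0, 1], [-1, 0], [0, -1], [1, 0]])


def four_rooms_map() -> jax.Array:
    """Boolean walkability map equal to RoomWorld.walkability_map."""
    m = jnp.ones((13, 13), dtype=bool)
    m = m.at[0, :].set(False).at[-1, :].set(False)   # top/bottom border
    m = m.at[:, 0].set(False).at[:, -1].set(False)   # left/right border
    m = m.at[:, 6].set(False)                        # vertical wall
    m = m.at[6, :6].set(False).at[7, 6:].set(False)  # horizontal walls
    for r, c in [(3, 6), (10, 6), (6, 2), (7, 9)]:   # the four hallways
        m = m.at[r, c].set(True)
    return m


@jax.jit
def step(key, pos, action, walkable, goal, success_rate=0.9):
    """One noisy primitive step for a single agent."""
    key_slip, key_dir = jax.random.split(key)
    # With prob. 1 - success_rate, move in one of the other three directions uniformly.
    slip = jax.random.uniform(key_slip) > success_rate
    offset = jax.random.randint(key_dir, (), 1, 4)
    direction = jnp.where(slip, (action + offset) % 4, action)
    new_pos = pos + MOVES[direction]
    # Moves into a wall leave the agent in place (collision_reward is 0.0 in RoomWorld).
    pos = jnp.where(walkable[new_pos[0], new_pos[1]], new_pos, pos)
    done = jnp.all(pos == goal)
    reward = done.astype(jnp.float32)  # goal_reward = 1.0, step_reward = 0.0
    return pos, reward, done
```

The map has 104 walkable cells (four rooms plus four hallways), which is a quick sanity check against the NumPy version.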

Replace all `state` variables with dictionaries

All MinAtar environments have states that are stored in dictionaries. I want this to apply to all environments (including the classic control ones). We may lose a little bit of speed, but gain a lot of debuggability in return. Everything will be explicit.

Modifying optimal return parameter has no effect (bug)

Describe the bug

I have observed that modifying the optimal_return parameter of the DiscountingChain bsuite environment has no effect: the code consistently uses the literal value 1.1 instead of the optimal_return variable.

To Reproduce

  1. Run the following sample code.
import gymnax
import jax
from gymnax import environments as envs

rng = jax.random.PRNGKey(0)
env, env_params = gymnax.make("DiscountingChain-bsuite")
rng, key_reset = jax.random.split(rng, 2)

# Attempting to change the optimal_return parameter.
params_setting = {"optimal_return": 1.2}
env_params = envs.bsuite.discounting_chain.EnvParams(**params_setting)
print("The optimal return has now changed, as can be seen here:\n\n", env_params)

# Reset the environment.
obs, state = env.reset(key_reset, env_params)
print()
print("However, the state rewards still contain an optimal return of 1.1:")
print(state.rewards)

Expected behaviour

The optimal return in state.rewards should be 1.2.

Actual behaviour

The optimal return in state.rewards is still 1.1.
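A sketch of the kind of fix being asked for: wherever the reset logic writes the optimal reward, it should read the value from the parameters instead of a literal. `build_rewards` and the stand-in `EnvParams` below are hypothetical and are not gymnax's actual DiscountingChain code. The default of five actions is also an assumption.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvParams(NamedTuple):
    # Hypothetical stand-in for the environment's parameter container.
    optimal_return: float = 1.1


def build_rewards(key: jax.Array, params: EnvParams, n_actions: int = 5) -> jax.Array:
    """Per-action rewards: 1.0 everywhere except one randomly chosen action,
    which pays params.optimal_return instead of a hard-coded 1.1."""
    optimal_idx = jax.random.randint(key, (), 0, n_actions)
    return jnp.ones(n_actions).at[optimal_idx].set(params.optimal_return)
```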

Add Brax <-> Gymnax wrappers

Reminder todo after internship.

Add a simple wrapper that converts gymnax transition tuple to brax state-like objects. This can allow smooth integration of gymnax into brax codebases and vice versa.

Potential bug in Tuple space

Hi Rob,

Great library, thanks for all the hard work. I have been using some custom Gymnax environments in recent work (e.g. https://arxiv.org/abs/2303.10672).

There seem to be issues with the sample() and contains() methods for the Tuple space.

For example:

import jax
import gymnax.environments.spaces as spaces
s = spaces.Tuple([spaces.Discrete(5), spaces.Discrete(5)])
s.sample(rng=jax.random.PRNGKey(0))

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[95], line 3
      1 import gymnax.environments.spaces as spaces
      2 s = spaces.Tuple([spaces.Discrete(5), spaces.Discrete(5)])
----> 3 s.sample(rng=jax.random.PRNGKey(0))

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/spaces.py:119, in Tuple.sample(self, rng)
    116 """Sample random action from all subspaces."""
    117 key_split = jax.random.split(rng, self.num_spaces)
    118 return tuple(
--> 119     [self.spaces[k].sample(key_split[i]) for i, k in enumerate(self.spaces)]
    120 )

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/spaces.py:119, in <listcomp>(.0)
    116 """Sample random action from all subspaces."""
    117 key_split = jax.random.split(rng, self.num_spaces)
    118 return tuple(
--> 119     [self.spaces[k].sample(key_split[i]) for i, k in enumerate(self.spaces)]
    120 )

TypeError: list indices must be integers or slices, not Discrete
s.contains((1, 1))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[96], line 1
----> 1 s.contains((1, 1))

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/spaces.py:129, in Tuple.contains(self, x)
    127 out_of_space = 0
    128 for space in self.spaces:
--> 129     out_of_space += 1 - space.contains(x)
    130 return out_of_space == 0

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/spaces.py:44, in Discrete.contains(self, x)
     41 """Check whether specific object is within space."""
     42 # type_cond = isinstance(x, self.dtype)
     43 # shape_cond = (x.shape == self.shape)
---> 44 range_cond = jnp.logical_and(x >= 0, x < self.n)
     45 return range_cond

TypeError: '>=' not supported between instances of 'tuple' and 'int'

I think the following fixes it; I'm happy to raise a pull request if you want.

class Tuple(Space):
    """Minimal jittable class for tuple (product) of jittable spaces."""

    def __init__(self, spaces: Sequence[Space]):
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: chex.PRNGKey) -> Tuple[chex.Array]:
        """Sample random action from all subspaces."""
        key_split = jax.random.split(rng, self.num_spaces)
        return tuple(
            [s.sample(key_split[i]) for i, s in enumerate(self.spaces)]
        )

    def contains(self, x: jnp.int_) -> bool:
        """Check whether dimensions of object are within subspace."""
        # type_cond = isinstance(x, tuple)
        # num_space_cond = len(x) != len(self.spaces)
        # Check for each space individually
        out_of_space = 0
        for i,space in enumerate(self.spaces):
            out_of_space += 1 - space.contains(x[i])
        return out_of_space == 0
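To check the proposed fix in isolation, here is a self-contained version. It uses a simplified stand-in for `Discrete` (an assumption that mirrors only the `sample`/`contains` behaviour visible in the traceback). The `contains` check collapses the counter into an equivalent `jnp.all` over the per-subspace results.

```python
from typing import Sequence, Tuple as TypingTuple

import jax
import jax.numpy as jnp


class Discrete:
    """Simplified stand-in for gymnax's Discrete space (assumption)."""

    def __init__(self, n: int):
        self.n = n

    def sample(self, rng: jax.Array) -> jax.Array:
        return jax.random.randint(rng, (), 0, self.n, dtype=jnp.int32)

    def contains(self, x) -> jax.Array:
        return jnp.logical_and(x >= 0, x < self.n)


class Tuple:
    """Tuple space with the fixes: iterate over subspaces, index x by position."""

    def __init__(self, spaces: Sequence):
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: jax.Array) -> TypingTuple[jax.Array, ...]:
        keys = jax.random.split(rng, self.num_spaces)
        return tuple(s.sample(keys[i]) for i, s in enumerate(self.spaces))

    def contains(self, x) -> jax.Array:
        checks = jnp.array([s.contains(x[i]) for i, s in enumerate(self.spaces)])
        return jnp.all(checks)


s = Tuple([Discrete(5), Discrete(5)])
sample = s.sample(jax.random.PRNGKey(0))  # tuple of two int32 scalars in [0, 5)
```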

DQN rlax + bsuite vs rlax + gymnax

I would like to have a benchmark figure comparing the DQN example in rlax with a gymnax sped up version. Ideally, I want to compare the runtime for step transitions on different devices.

At the moment there is something wrong with the optimisation and/or evaluation. Figure out the bug 🐛.

The agents should all be in an experimental directory.

Feature request: Implement `__repr__` methods for many classes

I was debugging some Gymnax code and had a box, so I tried to view what kind of box it is:

>>> box
<gymnax.environments.spaces.Box object at 0x000001D589442FC0>

It would be easier to introspect Gymnax experiments if a __repr__ method were implemented for Box and for many other classes.
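A minimal illustration of the request, using a hypothetical simplified `Box` (not gymnax's class, whose attributes may differ). The `__repr__` returns a constructor-like string, so the REPL shows the space's bounds, shape and dtype instead of a memory address.

```python
import jax.numpy as jnp


class Box:
    """Hypothetical minimal Box space, used only to illustrate __repr__."""

    def __init__(self, low, high, shape, dtype=jnp.float32):
        self.low = low
        self.high = high
        self.shape = shape
        self.dtype = dtype

    def __repr__(self) -> str:
        return (
            f"Box(low={self.low}, high={self.high}, "
            f"shape={self.shape}, dtype={jnp.dtype(self.dtype).name})"
        )


box = Box(0.0, 1.0, (3,))
print(box)  # Box(low=0.0, high=1.0, shape=(3,), dtype=float32)
```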

Miscellaneous environments

I would like to also implement a set of classic miscellaneous environments, which are not part of the standard environment APIs. Here is a list to address:

  • Different Bernoulli bandits (as in Wang et al., RL^2).
  • Gaussian bandit (as in Lange & Sprekeler, 2020).
  • Sutton's Four Rooms environment.
  • Markov reward processes (as in MPG work).
  • Markov chain processes (as in MPG work).

`TrajectoryCollector` with discount masking if terminal

Write a class that collects trajectories and returns a NamedTuple of the collected data. This should include a buffer of state transition tuples (s_t, a_t, s_t_1, r_t, d_t). Open problem: how to make it general enough that other statistics (e.g. log_prob) can also be stored. Should the agent return these in actor_step?
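One way to sketch this, under stated assumptions: a `lax.scan` rollout that stores each transition in a `Transition` NamedTuple. It adds a `discount` field equal to `gamma * (1 - done)`, which masks bootstrapping after terminal steps. It also adds a free-form `extras` dict that the policy fills, for statistics such as `log_prob`. `collect`, `counter_step` and `random_policy` are hypothetical names. The toy step function returns `(next_obs, next_state, reward, done)` and auto-resets.

```python
from typing import Callable, Dict, NamedTuple

import jax
import jax.numpy as jnp


class Transition(NamedTuple):
    obs: jax.Array
    action: jax.Array
    next_obs: jax.Array
    reward: jax.Array
    done: jax.Array
    discount: jax.Array           # gamma * (1 - done): zero on terminal steps
    extras: Dict[str, jax.Array]  # whatever the policy reports, e.g. log_prob


def collect(key, obs, state, policy_fn: Callable, step_fn: Callable,
            num_steps: int, gamma: float = 0.99):
    """Roll out num_steps transitions with lax.scan (hypothetical helper)."""

    def body(carry, _):
        key, obs, state = carry
        key, key_act, key_step = jax.random.split(key, 3)
        action, extras = policy_fn(key_act, obs)
        next_obs, next_state, reward, done = step_fn(key_step, state, action)
        discount = gamma * (1.0 - done.astype(jnp.float32))
        t = Transition(obs, action, next_obs, reward, done, discount, extras)
        return (key, next_obs, next_state), t

    (_, last_obs, last_state), traj = jax.lax.scan(
        body, (key, obs, state), None, length=num_steps
    )
    return traj, last_obs, last_state


def counter_step(key, state, action):
    """Toy auto-resetting env: the episode ends every third step."""
    nxt = state + 1
    done = nxt >= 3
    next_state = jnp.where(done, 0, nxt)
    return next_state.astype(jnp.float32), next_state, jnp.float32(1.0), done


def random_policy(key, obs):
    logits = jnp.zeros(2)
    action = jax.random.categorical(key, logits)
    return action, {"log_prob": jax.nn.log_softmax(logits)[action]}
```

Letting the policy return an arbitrary dict sidesteps the question of which statistics to store. Each leaf simply gets a leading time axis from `lax.scan`.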

MinAtar Sticky Actions

Hello,

I was just wondering about sticky actions for the MinAtar environments. It doesn't look as though you have implemented it but I was just wondering if there is something I am missing.
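For reference, sticky actions in MinAtar's original Python implementation repeat the previous action with a fixed probability instead of executing the new one. A jittable version could look like the sketch below. `sticky_action` is a hypothetical helper, and the default probability of 0.1 is an assumption. The returned action would also need to be stored in the environment state as the next `last_action`.

```python
import jax
import jax.numpy as jnp


def sticky_action(key: jax.Array, action: jax.Array, last_action: jax.Array,
                  sticky_prob: float = 0.1) -> jax.Array:
    """With probability sticky_prob, repeat last_action instead of action."""
    repeat = jax.random.uniform(key) < sticky_prob
    return jnp.where(repeat, last_action, action)
```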

Potential bug due to lax.select usage in step function

I'm currently using the latest JAX version (0.4.8), and it reports:

  File "/home/stao/mambaforge/envs/robojax_brax/lib/python3.8/site-packages/gymnax/environments/environment.py", line 43, in step
    state = jax.tree_map(
  File "/home/stao/mambaforge/envs/robojax_brax/lib/python3.8/site-packages/gymnax/environments/environment.py", line 44, in <lambda>
    lambda x, y: jax.lax.select(done, x, y), state_re, state_st
TypeError: lax.select requires arguments to have the same dtypes, got float32, int32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).

Any idea what's going on? Using jnp.where instead seems like a simple fix.
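A sketch of the suggested auto-reset with `jnp.where`, applied leaf by leaf with `jax.tree_util.tree_map`. That function accepts several trees, which is the role the removed `jax.tree_multimap` used to play. `State` and `auto_reset` are hypothetical names. Note that `jnp.where` promotes mismatched dtypes, so the root cause (reset and step producing leaves of different dtypes) is hidden rather than fixed. Casting both to one dtype may still be needed, e.g. to keep `lax.scan` carries consistent.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    pos: jax.Array
    time: jax.Array


def auto_reset(done, state_reset: State, state_step: State) -> State:
    """Pick the reset state where done is True, leaf by leaf.

    Unlike jax.lax.select, jnp.where promotes mismatched dtypes instead of raising."""
    return jax.tree_util.tree_map(
        lambda x, y: jnp.where(done, x, y), state_reset, state_step
    )
```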

Automated tools for benchmarking

I would like to have a set of utilities that automatically generates jobs and schedules them on the different hardware platforms (e.g. CPU/GPU/TPU). This should probably all run on GCP, with Docker containers for reproducibility. Have a look at pyhpc for benchmarking standards and potentially integrate with the mle-toolbox GCP experiment setup.

AttributeError: module 'jax' has no attribute 'tree_multimap'

When I try to run the example:

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

I get the following error:

UnfilteredStackTrace: AttributeError: module 'jax' has no attribute 'tree_multimap'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/gymnax/environments/environment.py in step(self, key, state, action, params)
     41         obs_re, state_re = self.reset_env(key_reset, params)
     42         # Auto-reset environment based on termination
---> 43         state = jax.tree_multimap(
     44             lambda x, y: jax.lax.select(done, x, y), state_re, state_st
     45         )

AttributeError: module 'jax' has no attribute 'tree_multimap'

I don't think jax.tree_multimap is valid in any Jax version? Isn't it jax.tree_utils.tree_multimap? I might be wrong though, still very new to Jax.

Mention request for Pgx

Sorry for the comment out of the blue. gymnax's environment coverage is great, and I really like the visualization tool 👍

Today, we released Pgx, a collection of JAX-based RL environments dedicated to classic board games like Go. We have implemented over 15 environments, including Backgammon, Shogi, and Go, and confirmed that they are considerably faster than existing C++/Python implementations. We also plan to implement Chess and Contract Bridge in the coming weeks.

We believe gymnax and Pgx can complement each other like Gymnasium and PettingZoo. We would be happy if you could kindly mention Pgx in the README like other JAX-based RL environments if you like it. For example,

  • 💻 Pgx: JAX-based classic board game environments.

Thanks!

Add `Seaquest` MinAtar environment

Reminder todo after internship.

Add the (more) complicated logic for Seaquest. Ask @kenjyoung for advice/support. This will probably involve setting a maximum number of enemies and looping over this fixed set.
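The fixed-capacity idea can be sketched with arrays of length `MAX_ENEMIES` plus an `active` mask, so shapes never change under `jit`. Updates are vectorised over all slots instead of looping in Python. Everything here (`MAX_ENEMIES`, `Enemies`, `move_enemies`, `spawn_enemy`, the screen width of 10) is hypothetical and does not reproduce Seaquest's actual rules.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

MAX_ENEMIES = 8  # hypothetical fixed capacity


class Enemies(NamedTuple):
    x: jax.Array       # (MAX_ENEMIES,) int32 column
    y: jax.Array       # (MAX_ENEMIES,) int32 row
    dx: jax.Array      # (MAX_ENEMIES,) int32 direction, -1 or +1
    active: jax.Array  # (MAX_ENEMIES,) bool: which slots hold a live enemy


def move_enemies(e: Enemies, width: int = 10) -> Enemies:
    x = e.x + e.dx * e.active                  # inactive slots do not move
    on_screen = (x >= 0) & (x < width)
    return e._replace(x=x, active=e.active & on_screen)  # leaving frees the slot


def spawn_enemy(e: Enemies, y, from_left, width: int = 10) -> Enemies:
    free = ~e.active
    slot = jnp.argmax(free)                    # first free slot
    ok = free.any()                            # drop the spawn if all slots are busy
    x0 = jnp.where(from_left, 0, width - 1)
    dx0 = jnp.where(from_left, 1, -1)

    def put(arr, val):
        return arr.at[slot].set(jnp.where(ok, val, arr[slot]))

    return Enemies(x=put(e.x, x0), y=put(e.y, y), dx=put(e.dx, dx0),
                   active=put(e.active, True))
```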

ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field reward_timestep is not allowed: use default_factory

I'm getting the following error:

$ python3 -c "import gymnax"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.11/site-packages/gymnax/__init__.py", line 1, in <module>
    from .registration import make, registered_envs
  File "/usr/local/lib/python3.11/site-packages/gymnax/registration.py", line 1, in <module>
    from .environments import (
  File "/usr/local/lib/python3.11/site-packages/gymnax/environments/__init__.py", line 9, in <module>
    from .bsuite import (
  File "/usr/local/lib/python3.11/site-packages/gymnax/environments/bsuite/__init__.py", line 3, in <module>
    from .discounting_chain import DiscountingChain
  File "/usr/local/lib/python3.11/site-packages/gymnax/environments/bsuite/discounting_chain.py", line 17, in <module>
    @struct.dataclass
     ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/flax/struct.py", line 101, in dataclass
    data_clz = dataclasses.dataclass(frozen=True)(clz) # type: ignore
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/dataclasses.py", line 1210, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/dataclasses.py", line 958, in _process_class
    cls_fields.append(_get_field(cls, name, type, kw_only))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/dataclasses.py", line 815, in _get_field
    raise ValueError(f'mutable default {type(f.default)} for field '
ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field reward_timestep is not allowed: use default_factory

Version information:

  • python 3.11.2
  • jax 0.4.8
  • jaxlib 0.4.7
  • gymnax 0.0.5
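Since Python 3.11, `dataclasses` rejects any unhashable default value, not just lists, dicts and sets. JAX arrays are unhashable, so an array default now fails at import time. The usual fix is `default_factory`, which builds a fresh array per instance. The traceback shows that `flax.struct.dataclass` calls `dataclasses.dataclass`, so the same fix should carry over, e.g. via `flax.struct.field(default_factory=...)`. The sketch below uses the standard library only, and the reward horizons are illustrative values.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class EnvParams:
    # On Python 3.11+, `reward_timestep: jnp.ndarray = jnp.array([...])` raises
    # "mutable default ... use default_factory", because JAX arrays are unhashable.
    reward_timestep: jnp.ndarray = dataclasses.field(
        default_factory=lambda: jnp.array([1, 3, 10, 30, 100])
    )
    optimal_return: float = 1.1
```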

Refactor `agents`/`dojos` into `experimental`

The directories and utilities in agents and dojos are not supposed to be part of the core API. Refactor them into an experimental directory. Make sure that the examples still use correct imports.

Trained baselines values incl. active training

Very nice work presented here, well done!

Just a small question regarding the speed-up evaluation for gymnax environments: I guess that the reported execution times with a neural network policy are based on a fixed policy without active learning (i.e., without policy improvement steps), right?

Did you also benchmark the speed-up with active learning of the policy using standard algorithms like PPO or ES?

On the use of `jnp.int_`

Hi, thank you for your work on this great library. When I tried gymnax for the first time, I frequently encountered warnings due to the use of jnp.int_:

UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.

It looks like jnp.int_ is used in:

  • Catch-bsuite
  • Freeway-minatar
  • spaces.Discrete
  • spaces.contain

So my question is: do you have a specific reason to use jnp.int_? If not, I would like to contribute a change replacing these with jnp.int32, because I don't need JAX_ENABLE_X64 for anything else.

Issue: vmapped CartPole input shape does not match

Hello there. I am trying to run a vmapped CartPole step function. My environment state inputs are of the shape:

env_state:
[executor/0] x:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] x_dot:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] theta:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] theta_dot:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>

When I run jnp.array([env_state.x, env_state.x_dot, env_state.theta, env_state.theta_dot]) on the state before the environment step, I get:
Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=2/0)>

However when I try to run the step function I get:

obs, env_state, rewards, done, _ = self.env.step(key_step, env_state, action, self.env_params)
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/environment.py", line 38, in step
[executor/0]     obs_st, state_st, reward, done, info = self.step_env(
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/classic_control/cartpole.py", line 83, in step_env
[executor/0]     lax.stop_gradient(self.get_obs(state)),
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/classic_control/cartpole.py", line 108, in get_obs
[executor/0]     return jnp.array([state.x, state.x_dot, state.theta, state.theta_dot])
[executor/0]   File "/mava/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1889, in array
[executor/0]     out = stack([asarray(elt, dtype=dtype) for elt in object])
[executor/0]   File "/mava/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1634, in stack
[executor/0]     raise ValueError("All input arrays must have the same shape.")
[executor/0] ValueError: All input arrays must have the same shape.

Do you have any idea what might be causing this issue? Are the shapes somehow changing inside the step function? Thanks.
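One plausible cause, not confirmed from the report: the action batch has an extra trailing axis, e.g. `(N, 1)` instead of `(N,)`. In CartPole-style Euler updates, `x` is advanced with the old `x_dot` while `x_dot` is advanced with the action-dependent force. Inside `vmap`, `x` therefore stays a scalar while `x_dot` inherits the action's `(1,)` shape, and stacking them fails. The toy step below is not gymnax's CartPole (its dynamics are heavily simplified), but it reproduces that pattern.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CartState(NamedTuple):
    x: jax.Array
    x_dot: jax.Array
    theta: jax.Array
    theta_dot: jax.Array


def toy_step(key, state: CartState, action, tau: float = 0.02):
    """Simplified Euler update (no real pole dynamics); the key is unused."""
    force = jnp.where(action == 1, 1.0, -1.0)
    new = CartState(
        x=state.x + tau * state.x_dot,            # uses the old velocity
        x_dot=state.x_dot + tau * force,          # inherits the action's shape
        theta=state.theta + tau * state.theta_dot,
        theta_dot=state.theta_dot,
    )
    # Same pattern as get_obs in the traceback: stack four per-env values.
    obs = jnp.array([new.x, new.x_dot, new.theta, new.theta_dot])
    return obs, new


n_envs = 8
keys = jax.random.split(jax.random.PRNGKey(0), n_envs)
states = CartState(*(jnp.zeros(n_envs) for _ in range(4)))  # each leaf: (n_envs,)
batched_step = jax.vmap(toy_step, in_axes=(0, 0, 0))

obs, new_states = batched_step(keys, states, jnp.ones(n_envs, jnp.int32))  # obs: (8, 4)
```

Passing actions of shape `(n_envs, 1)` to `batched_step` instead reproduces a failure in the `jnp.array([...])` line, so squeezing the action batch to `(n_envs,)` is worth checking first.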

Gymnasium API update

Will gymnax update its API in line with Gymnasium? Namely, the recent breaking change in gym/Gymnasium where step returns terminated and truncated instead of done?
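Until such an API change exists, a rough adapter can derive the two Gymnasium flags from a single `done`, assuming a step counter and the episode limit are available. `split_done` is a hypothetical helper, not part of gymnax.

```python
import jax.numpy as jnp


def split_done(done, time, max_steps):
    """Split a single done flag into Gymnasium-style (terminated, truncated)."""
    truncated = jnp.logical_and(done, time >= max_steps)  # hit the time limit
    terminated = jnp.logical_and(done, jnp.logical_not(truncated))
    return terminated, truncated
```

This heuristic labels an episode that genuinely terminates on its last allowed step as truncated. A faithful split needs the environment itself to report termination separately from the time limit.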

How to use `env.render()` to visualize environment transitions frame by frame?

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset = jax.random.split(rng)

# Instantiate the environment & its settings.
env, env_params = gymnax.make('FourRooms-misc')

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

done = False
while not done:
    # Split fresh keys every step; reusing the same keys repeats the same action.
    rng, key_act, key_step = jax.random.split(rng, 3)
    # Sample a random action.
    action = env.action_space(env_params).sample(key_act)

    # Perform the step transition and carry the new state forward.
    obs, state, reward, done, _ = env.step(key_step, state, action, env_params)
    print(f"action: {action}, reward: {reward}, done: {done}")
    env.render(state, env_params)

I tried this simple example, but the matplotlib figure is never shown: the script hangs and the window never opens.

Follow-up: Why is gymnax using matplotlib for rendering and not pygame (like gym)?

Differentiate step function?

Hello,
is it possible to return the derivative of the step reward with respect to the action, at least for the simplest environments like Pendulum or CartPole?
Best, Jacek
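For continuous-action environments written in pure JAX, `jax.grad` can in principle differentiate the reward with respect to the action. The sketch below writes a gym-style Pendulum reward standalone, rather than importing gymnax's implementation. For discrete-action environments such as CartPole, the action is an integer and the reward is piecewise constant, so a gradient with respect to the action is not meaningful in the same way.

```python
import jax
import jax.numpy as jnp


def angle_normalize(x):
    return ((x + jnp.pi) % (2 * jnp.pi)) - jnp.pi


def pendulum_reward(th, thdot, u, max_torque=2.0):
    """Gym-style Pendulum reward: -(theta^2 + 0.1 * thetadot^2 + 0.001 * u^2)."""
    u = jnp.clip(u, -max_torque, max_torque)
    cost = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * u ** 2
    return -cost


# Derivative of the reward with respect to the torque u (argument index 2).
d_reward_d_u = jax.grad(pendulum_reward, argnums=2)
```

Because of the clipping, the gradient is zero once the torque is outside `[-max_torque, max_torque]`.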

Add CONTRIBUTING.md file

Think about different ways to encourage contributions and what we would like to add. Not any random environment, but staples? Point to templates and give guidelines for issues and PRs. Have a look at how others are doing this, e.g. OpenAI gym, rlax and neurolib.

Notebook links missing?

Thanks so much for releasing this repo, it looks great!

The top two links in the examples section of the README give 404 errors for me, e.g.:

The 'RobertTLange/gymnax' repository doesn't contain the 'notebooks/getting_started.ipynb' path in 'main'.

[Proposal] Environment API changes

Hey, first off, I love this project and the general idea of defining environments in JAX so that they can be easily batched and integrated into RL training loops!

I tend to do a lot of work with POMDPs and have built a few branches in my own fork that implement various POMDP environments. It works fine for my purposes, but I've run into a couple of instances where I just ignore the base Environment API.

Specifically:

  1. In typical POMDP formulations, the observation o is a function of state AND action (i.e., o ~ O(s,a)), but currently get_obs() only uses state
  2. Similarly, there are many instances where we need the state AND/OR action to determine if the episode is done (e.g., in the Tiger problem, it ends when you open a door, but the state is just which door the tiger is behind). is_terminal() and discount() only use state
  3. Finally, in the same way that your FourRooms example has noisy actions, there are many environments where a noisy observation is core to the environment, requiring get_obs to also have an RNG key

Obviously I'm just overriding the methods with the extra arguments as needed for my own environments, but some of this might be common enough to justify a different base API?
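To make the three points concrete, here is a sketch of the classic Tiger problem with the extended signatures. `get_obs` takes a key, the state and the action; `is_terminal` depends on the action. These function signatures are hypothetical extensions of the base API, not gymnax's current ones. The Tiger parameters (0.85 listening accuracy, rewards of -1 / -100 / +10) are the usual textbook values, used here for illustration.

```python
import jax
import jax.numpy as jnp

LISTEN, OPEN_LEFT, OPEN_RIGHT = 0, 1, 2  # tiger_pos: 0 = left door, 1 = right door


def get_obs(key, tiger_pos, action, accuracy=0.85):
    """o ~ O(s, a): after listening, hear the tiger's side (noisily); else obs 2."""
    correct = jax.random.uniform(key) < accuracy
    heard = jnp.where(correct, tiger_pos, 1 - tiger_pos)
    return jnp.where(action == LISTEN, heard, 2)


def is_terminal(tiger_pos, action):
    # Termination depends on the action, not on the (static) hidden state.
    return jnp.asarray(action) != LISTEN


def reward(tiger_pos, action):
    opened = action - 1  # door index for OPEN_LEFT / OPEN_RIGHT
    return jnp.where(action == LISTEN, -1.0,
                     jnp.where(opened == tiger_pos, -100.0, 10.0))
```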

`Pong-misc`: TypeError: select cases must have the same shapes, got [(30, 40), ()].

When running the Pong-misc environment, the following error is raised from move_paddles.

I tried both the example notebook and gymnax-blines to make sure it's not a usage error.

Below is the stack trace and the gymnax-blines configuration I have used.

$ python train.py -config agents/Pong-misc/ppo.yaml

PPO:   0%|          | 0/18751 [00:00<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/helge/Sandbox/pt/gymnax-blines/train.py", line 76, in <module>
    main(
  File "/home/helge/Sandbox/pt/gymnax-blines/train.py", line 24, in main
    log_steps, log_return, network_ckpt = train_fn(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 271, in train_ppo
    train_state, obs, state, batch, rng_step = get_transition(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 252, in get_transition
    next_obs, next_state, reward, done, _ = rollout_manager.batch_step(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 138, in batch_step
    return jax.vmap(self.env.step, in_axes=(0, 0, 0, None))(
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/environment.py", line 45, in step
    obs_st, state_st, reward, done, info = self.step_env(key, state, action, params)
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/misc/pong.py", line 75, in step_env
    state = move_paddles(
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/misc/pong.py", line 356, in move_paddles
    new_center_p2 = jax.lax.select(use_ai_policy, new_center_ai, new_center_self)
TypeError: select cases must have the same shapes, got [(30, 40), ()].

Configuration (copied from CartPole-v1):

train_config:
  train_type: "PPO"
  num_train_steps: 150000
  evaluate_every_epochs: 1000

  env_name: "Pong-misc"
  env_kwargs: {}
  env_params: {}
  num_test_rollouts: 164
  
  num_train_envs: 8  # Number of parallel env workers
  max_grad_norm: 0.5  # Global norm to clip gradients by
  gamma: 0.99  # Discount factor
  n_steps: 32 # "GAE n-steps"
  n_minibatch: 4 # "Number of PPO minibatches"
  lr_begin: 5e-04  # Start PPO learning rate
  lr_end: 5e-04 #  End PPO learning rate
  lr_warmup: 0.05 # Prop epochs until warmup is completed 
  epoch_ppo: 4  # "Number of PPO epochs on a single batch"
  clip_eps: 0.2 # "Clipping range"
  gae_lambda: 0.95 # "GAE lambda"
  entropy_coeff: 0.01 # "Entropy loss coefficient"
  critic_coeff: 0.5  # "Value loss coefficient"

  network_name: "Categorical-MLP"
  network_config:
    num_hidden_units: 64
    num_hidden_layers: 2

log_config:
  time_to_track: ["num_steps"]
  what_to_track: ["return"]
  verbose: false
  print_every_k_updates: 1
  overwrite: 1
  model_type: "jax"

device_config:
  num_devices: 1
  device_type: "gpu"

Replace all `params_env_name` with `FrozenDict`

Some parameter dictionaries specify the shape of the observation. Hence, when jitting, we need to mark them as static via static_argnums, which is only possible if the dictionary is immutable (and therefore hashable). I propose porting the flax FrozenDict and providing a helper function called update_env_params(params, x_name, x_value), which unfreezes the dictionary, changes the value and freezes it again.

In order to reduce dependencies, it may make sense to simply copy the file and use the same Apache License.

https://github.com/google/flax/blob/ac0f57419f32c9924e094e7e0dc82a15be228b5d/flax/core/frozen_dict.py

Go through all envs and update the parameter dictionaries.
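A minimal sketch of the proposal, assuming a simplified stand-in rather than the actual flax file: `FrozenDict` below is an immutable, hashable mapping (keys assumed sortable, values assumed hashable), and `update_env_params` follows the unfreeze / change / freeze recipe described above. Raising on unknown keys is an added design choice to catch typos.

```python
from collections.abc import Mapping
from typing import Any, Hashable, Iterator


class FrozenDict(Mapping):
    """Immutable, hashable mapping; a simplified stand-in for flax's FrozenDict.

    Assumes keys are sortable (e.g. all strings) and values are hashable.
    """

    __slots__ = ("_data", "_hash")

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self._data = dict(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key: Hashable) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __hash__(self) -> int:
        # Cache the hash, since static jit arguments are hashed for cache lookups.
        if self._hash is None:
            self._hash = hash(tuple(sorted(self._data.items())))
        return self._hash

    def __repr__(self) -> str:
        return f"FrozenDict({self._data!r})"

    def unfreeze(self) -> dict:
        """Return a mutable (shallow) copy."""
        return dict(self._data)


def update_env_params(params: FrozenDict, x_name: str, x_value: Any) -> FrozenDict:
    """Unfreeze, set x_name to x_value and freeze again; params itself is unchanged."""
    if x_name not in params:
        raise KeyError(f"Unknown environment parameter: {x_name!r}")
    updated = params.unfreeze()
    updated[x_name] = x_value
    return FrozenDict(updated)
```

Because the mapping is hashable, it can be passed to a function jitted with `static_argnums`; an updated copy has a different hash, so jit would compile a new version for the new values instead of silently reusing the old one.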

CPU/GPU/TPU Benchmarks

Benchmark the runtimes of the different environments and compare them with plain NumPy/PyTorch versions.

  • OpenAI Gym control environments
  • Behaviour Suite (bsuite) environments
  • MinAtar environments
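One framework-agnostic way to measure the runtimes is sketched below; `steps_per_second` is a hypothetical helper, and `step_fn` is assumed to close over the key, action and parameters so that it maps a state to the next state. Warm-up calls are excluded from timing so that jit compilation is not counted.

```python
import time
from typing import Any, Callable


def steps_per_second(
    step_fn: Callable[[Any], Any],
    state: Any,
    num_steps: int = 1_000,
    warmup: int = 10,
    sync: Callable[[Any], Any] = lambda x: x,
) -> float:
    """Time num_steps sequential calls of step_fn(state) -> state.

    The warm-up calls are not timed; for a jitted JAX step they absorb compilation.
    Pass sync=jax.block_until_ready for JAX so that asynchronous dispatch does not
    make the loop look faster than it is.
    """
    for _ in range(warmup):
        state = step_fn(state)
    sync(state)
    start = time.perf_counter()
    for _ in range(num_steps):
        state = step_fn(state)
    sync(state)
    return num_steps / (time.perf_counter() - start)
```

The same helper could time a jitted gymnax step on CPU, GPU or TPU and a plain NumPy/PyTorch step, which keeps the comparison on equal footing.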

env step accumulates memory

Hi Robert,

Thanks for this awesome library!

I use the gymnax library on CPU to collect data for the Breakout MinAtar environment. I generate thousands of random programs and want to execute them in the env. Memory usage grows over time until I run into RAM problems. Using the Python memory profiler, I found that each call to the env's step function adds about 10 MB. Do you know why that is the case? Is this maybe only the case when running JAX on CPU?

I had problems getting JAX and PyTorch to run in the same virtual env with CUDA, so I decided to run gymnax on the CPU to avoid CUDA issues. The memory is also not released in the next iteration of the loop or at the end of the function.

I used the code from the visualization notebook as a reference.

Thanks a lot for your answer!

Best wishes,
Manuel

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    75    899.3 MiB    899.3 MiB           1   @profile
    76                                         def rollout_episode(env_data, env_params, seq_length, program):
    77    899.3 MiB      0.0 MiB           1       env, obs, env_state = env_data
    78    899.3 MiB      0.0 MiB           1       rng = jax.random.PRNGKey(0)
    79    899.3 MiB      0.0 MiB           1       examples = []
    80    908.6 MiB      0.0 MiB           6       for _ in range(seq_length):
    81    908.6 MiB      0.0 MiB           6           rng, rng_act, rng_step = jax.random.split(rng, 3)
    82    908.6 MiB      0.0 MiB           6           obs_inp = convert_to_task_input(obs)
    83    908.6 MiB      0.0 MiB           6           input_ex = (obs_inp,)
    84                                                 # output ex is the action
    85    908.6 MiB      0.0 MiB          18           output_ex = runWithTimeout(lambda: program.runWithArguments(input_ex), None)
    86    908.6 MiB      0.0 MiB           6           examples.append((input_ex, output_ex))
    87                                         
    88    908.6 MiB      9.3 MiB          12           next_obs, next_env_state, reward, done, info = env.step(
    89    908.6 MiB      0.0 MiB           6               rng_step, env_state, output_ex, env_params
    90                                                 )
    91                                         
    92    908.6 MiB      0.0 MiB           6           if done:  # or t_counter == max_frames:
    93    908.6 MiB      0.0 MiB           1               break
    94                                                 else:
    95    908.6 MiB      0.0 MiB           5               env_state = next_env_state
    96    908.6 MiB      0.0 MiB           5               obs = next_obs
    97    908.6 MiB      0.0 MiB           1       return examples

Pendulum-v1, MountainCarContinuous-v0 and Reacher-misc return non-squeezed reward

Out of all environments, only Pendulum-v1, MountainCarContinuous-v0 and Reacher-misc return the reward as a JAX array with shape (1, ). This is inconsistent with all other environments, which return an array with shape (). It can lead to unexpected shape errors; for example, consider a case like this:

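import jax
import jax.numpy as jnp

num_envs = 3
weights = jnp.arange(num_envs)
# keys, states and actions are batched along axis 0; env_params is shared (in_axes=None)
_, _, rewards, _, _ = jax.vmap(env.step, in_axes=(0, 0, 0, None))(keys, states, actions, env_params)
weighted_rewards = weights * rewards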

If the reward returned by the environment has shape (1, ), the result of vmapping will have shape (3, 1) instead of (3, ), and therefore weighted_rewards will have shape (3, 3) instead of (3, ).
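Since `jax.numpy` follows NumPy's broadcasting rules, the shape problem can be reproduced with plain NumPy arrays standing in for the vmapped rewards. The sketch also shows one possible fix, squeezing the trailing axis; whether that happens inside the environments or after vmap is a design choice.

```python
import numpy as np

num_envs = 3
weights = np.arange(num_envs)            # shape (3,)
rewards_bad = np.ones((num_envs, 1))     # vmapped output when each env returns shape (1,)
rewards_good = np.ones((num_envs,))      # vmapped output when each env returns shape ()

weighted_bad = weights * rewards_bad     # (3,) * (3, 1) broadcasts to (3, 3)
weighted_good = weights * rewards_good   # elementwise, shape (3,)

# Squeezing the trailing axis (in the env or after vmap) restores the expected shape.
weighted_fixed = weights * rewards_bad.squeeze(-1)
```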

Elegant registration of environments

Currently, an environment is registered via an import and a string conditional. This is probably the worst way of doing it: it doesn't scale and makes spin-offs or easy integration of new envs impossible.

if env_id == "Pendulum-v0":

Mimic the full gym registry pipeline so that people can follow the standard registration setup.

https://github.com/openai/gym/blob/a5a6ae6bc0a5cfc0ff1ce9be723d59593c165022/gym/envs/registration.py#L73
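A minimal sketch of a gym-style registry, assuming hypothetical `register`/`make` functions and an `EnvSpec` record (this is not gymnax's current API). As in gym, an entry point may be a callable or a `"module.path:ClassName"` string, which is imported lazily only when the environment is made.

```python
import importlib
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Union

EntryPoint = Union[str, Callable[..., Any]]


@dataclass(frozen=True)
class EnvSpec:
    env_id: str
    entry_point: EntryPoint  # a class/factory, or a "module.path:ClassName" string
    kwargs: Dict[str, Any] = field(default_factory=dict)


_REGISTRY: Dict[str, EnvSpec] = {}


def register(env_id: str, entry_point: EntryPoint, **kwargs: Any) -> None:
    if env_id in _REGISTRY:
        raise ValueError(f"Environment {env_id!r} is already registered.")
    _REGISTRY[env_id] = EnvSpec(env_id, entry_point, kwargs)


def _load(entry_point: EntryPoint) -> Callable[..., Any]:
    if callable(entry_point):
        return entry_point
    # Lazy import: the module is only imported when the env is actually made.
    module_name, attr_name = entry_point.split(":")
    return getattr(importlib.import_module(module_name), attr_name)


def make(env_id: str, **kwargs: Any) -> Any:
    try:
        spec = _REGISTRY[env_id]
    except KeyError:
        raise ValueError(f"Unknown env {env_id!r}; registered: {sorted(_REGISTRY)}") from None
    # Call-site kwargs override the defaults stored at registration time.
    return _load(spec.entry_point)(**{**spec.kwargs, **kwargs})
```

With this pattern, third-party packages could call `register` at import time instead of editing a central string conditional.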

bsuite environment implementation

Implement the set of classic bsuite environments:

  • catch.py
  • bandit.py
  • deep_sea.py
  • discounting_chain.py
  • memory_chain.py
  • mnist.py
  • umbrella_chain.py
  • cartpole.py - Check if this is different from OpenAI gym version
  • mountain_car.py - Check if this is different from OpenAI gym version

Is there a method of making custom environments and registering them?

Hello,

I have been enjoying your project so far, but as someone new to the codebase I have a few questions about it.
Gymnasium supports creating custom environments and registering them so that gymnasium.make() can be used on the new environment. Does gymnax offer similar functionality?

When trying to add custom environments, I found that I had to modify the package source code itself. Is there an easier method, something along the lines of the system offered by Gymnasium?

My project needs to generate new environments dynamically, so modifying the source code was a bit of a hassle. Is there a native implementation or a possible workaround?
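One possible user-side workaround, sketched under the assumption that nothing in the gymnax package is patched: keep a small registry of factories in your own code and fall back to another maker (in practice `gymnax.make`) for built-in ids. `register_custom`, `make` and `MyGridEnv` are hypothetical names.

```python
from typing import Any, Callable, Dict

# Hypothetical user-side registry; nothing here modifies the gymnax package.
_CUSTOM_ENVS: Dict[str, Callable[..., Any]] = {}


def register_custom(env_id: str, factory: Callable[..., Any]) -> None:
    _CUSTOM_ENVS[env_id] = factory


def make(env_id: str, fallback: Callable[..., Any], **kwargs: Any) -> Any:
    """Build a custom env if one is registered, otherwise defer to fallback."""
    if env_id in _CUSTOM_ENVS:
        return _CUSTOM_ENVS[env_id](**kwargs)
    return fallback(env_id, **kwargs)


class MyGridEnv:  # stand-in for a user-defined Environment subclass
    def __init__(self, size: int) -> None:
        self.size = size


# Registering dynamically generated variants: bind `size` as a default argument,
# otherwise every lambda would see the loop variable's final value.
for size in (4, 8, 16):
    register_custom(f"MyGrid-{size}", lambda size=size: MyGridEnv(size))
```

If `gymnax.make` is used as the fallback, it is worth keeping return types consistent, for example by having custom factories also return an `(env, env_params)` pair as `gymnax.make` does.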

MinAtar Environment Implementation

The MinAtar environments are implemented in plain NumPy and capture cool dynamics. They should be "jax-able", so give it a go. Here is a TODO list:

  • Asterix

  • Breakout

  • Seaquest

  • Space Invaders

  • Freeway

  • Many of them are non-deterministic/require sampling of objects based on open slots. Come up with a good test for these special cases.

  • Also add a visualisation wrapper. This can make debugging a lot easier.

  • Decide which actions are allowed. 'n', for example, encodes a no-op. This has to be respected!
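For the items about sampling objects based on open slots, one jit-friendly pattern is to sample from a fixed-size probability vector that is zero on occupied slots, rather than filtering out the free indices (which changes array shapes). The sketch uses NumPy; `sample_open_slot` is a hypothetical helper, and the JAX analogue would be `jax.random.choice(key, n, p=probs)`.

```python
import numpy as np


def sample_open_slot(rng: np.random.Generator, occupied: np.ndarray) -> int:
    """Sample the index of a free slot uniformly at random.

    Uses a fixed-shape probability vector instead of filtering the free slots,
    which mirrors how this would be written under jit (e.g. with
    jax.random.choice(key, n, p=probs)). Assumes at least one slot is free.
    """
    free = ~occupied
    probs = free / free.sum()
    return int(rng.choice(occupied.shape[0], p=probs))
```

A simple test for these special cases: with a seeded generator, draw many samples and check that every sampled slot is free and that all free slots eventually appear.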

Observation/Action Space Information & Sampling

Add a wrapper that allows the user to sample random actions based on the action space of different environments. It may also make sense to store this data in params_env_names. This could also be implemented as an experimental agent wrapper RandomAgent.
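A minimal sketch of what space specifications with sampling and a `RandomAgent` could look like, assuming hypothetical `Discrete`/`Box` classes with NumPy generators (not gymnax's own `spaces` module):

```python
from dataclasses import dataclass
from typing import Any, Tuple

import numpy as np


@dataclass(frozen=True)
class Discrete:
    n: int

    def sample(self, rng: np.random.Generator) -> int:
        return int(rng.integers(self.n))  # uniform over {0, ..., n - 1}

    def contains(self, x: int) -> bool:
        return 0 <= x < self.n


@dataclass(frozen=True)
class Box:
    low: float
    high: float
    shape: Tuple[int, ...]

    def sample(self, rng: np.random.Generator) -> np.ndarray:
        return rng.uniform(self.low, self.high, size=self.shape)

    def contains(self, x: np.ndarray) -> bool:
        return x.shape == self.shape and bool(np.all((x >= self.low) & (x <= self.high)))


class RandomAgent:
    """Ignores the observation and samples uniformly from the action space."""

    def __init__(self, action_space: Any) -> None:
        self.action_space = action_space

    def act(self, rng: np.random.Generator, obs: Any = None) -> Any:
        return self.action_space.sample(rng)
```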

[Proposal] Gym conversion wrappers

Would you be interested in a PR with wrappers to convert an Environment instance into Gym and VectorGym instances? This would be similar to how the Brax wrappers work (I helped write that PR), where the Gym environment keeps track of the rng/state and vectorizes the underlying environment. I have rough implementations already and would be interested in contributing.
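A minimal sketch of the single-environment case, assuming a functional environment with `reset(key, params) -> (obs, state)` and `step(key, state, action, params) -> (obs, state, reward, done, info)` as in the gymnax calls shown earlier. `GymWrapper` is hypothetical, and the key-splitting function is injected so the sketch does not depend on JAX.

```python
from typing import Any, Callable, Tuple


class GymWrapper:
    """Stateful, classic-Gym-style facade over a functional environment.

    The wrapper owns the random key and the environment state, so callers
    only see reset() -> obs and step(action) -> (obs, reward, done, info).
    """

    def __init__(self, env: Any, params: Any, key: Any,
                 split_fn: Callable[[Any], Tuple[Any, Any]]) -> None:
        self._env = env
        self._params = params
        self._key = key
        self._split = split_fn
        self._state = None

    def _next_key(self) -> Any:
        # Keep one half of the split for future calls, hand out the other.
        self._key, subkey = self._split(self._key)
        return subkey

    def reset(self) -> Any:
        obs, self._state = self._env.reset(self._next_key(), self._params)
        return obs

    def step(self, action: Any) -> Tuple[Any, Any, Any, Any]:
        if self._state is None:
            raise RuntimeError("Call reset() before step().")
        obs, self._state, reward, done, info = self._env.step(
            self._next_key(), self._state, action, self._params
        )
        return obs, reward, done, info
```

With JAX, `split_fn` could be `lambda k: tuple(jax.random.split(k))`; a VectorGym variant would additionally `jax.vmap` the reset and step calls over a batch of keys and states.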

If `Environment.observation_space` requires the `EnvParams`, `Environment.get_obs` should too.

If params are designed to influence the observation space, then they should also be available when generating an observation. Currently, the two functions in question are defined like this:

def get_obs(self, state: EnvState) -> chex.Array: pass
def observation_space(self, params: EnvParams): pass

This isn't a fatal issue, but passing params to get_obs would be more consistent with the overall design of the environment interface.
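A minimal sketch of the proposed signature, using a hypothetical grid environment in which `grid_size` (a parameter) determines the observation shape, so `get_obs` cannot build the observation without `params`. `observation_shape` is an illustrative stand-in for `observation_space`.

```python
from dataclasses import dataclass
from typing import Tuple

import numpy as np


@dataclass(frozen=True)
class EnvParams:
    grid_size: int = 5


@dataclass(frozen=True)
class EnvState:
    agent_pos: Tuple[int, int]  # (row, col)


class GridEnv:
    def observation_shape(self, params: EnvParams) -> Tuple[int, int]:
        return (params.grid_size, params.grid_size)

    def get_obs(self, state: EnvState, params: EnvParams) -> np.ndarray:
        # The observation shape depends on params, so get_obs needs them too.
        obs = np.zeros(self.observation_shape(params), dtype=np.float32)
        obs[state.agent_pos] = 1.0  # one-hot marker at the agent's cell
        return obs
```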

Action wrappers

It may make sense to add wrappers for the action selection, e.g. sticky actions/action repeats. This will rely on the info provided in the action space specification and may require storing the previous action in the state dictionary. Alternatively, this may be left to the user. In numpy this would look like this:

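import numpy as np

def act(env, a, last_action, sticky_action_prob):
    # With probability sticky_action_prob, repeat the previous action instead of a
    if np.random.rand() < sticky_action_prob:
        a = last_action
    last_action = a
    return env.act(a), last_action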
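The NumPy version above branches in Python, which would not trace under jit. A branch-free variant, sketched here with NumPy and a hypothetical `sticky_action` helper, selects the action with `np.where`; a JAX wrapper could use `jax.random.uniform` and `jnp.where` the same way, storing the returned action as `last_action` in the environment state.

```python
import numpy as np


def sticky_action(rng: np.random.Generator, action: int, last_action: int,
                  sticky_action_prob: float) -> int:
    """Branch-free sticky action; the result should be stored as the new last_action.

    In a jittable JAX version, the comparison would use jax.random.uniform(key)
    and the selection jnp.where, with last_action kept in the environment state.
    """
    repeat = rng.uniform() < sticky_action_prob
    return int(np.where(repeat, last_action, action))
```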

Add different requirement files

Currently there is only one minimal set of requirements in the setup.py file. Each time users want to do something "more specialised", they will most likely get an import error telling them to install an additional package. That can be frustrating. Instead, I want a set of different requirements files, as for example in the distrax repository:

requirements.txt
requirements-full.txt
requirements-examples.txt
requirements-tests.txt

Also take a look at their setup.py and make ours more professional.
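One way setup.py could consume such files is sketched below, assuming the `requirements-<name>.txt` naming above maps onto `extras_require` keys; `parse_requirements` and `build_extras` are hypothetical helpers. Comments, blank lines and pip options such as `-r` are dropped.

```python
from pathlib import Path
from typing import Dict, List, Union

PathLike = Union[str, Path]


def parse_requirements(path: PathLike) -> List[str]:
    """Read one requirements file, dropping blank lines, comments and pip options (e.g. -r)."""
    reqs = []
    for raw in Path(path).read_text().splitlines():
        line = raw.split("#", 1)[0].strip()
        if line and not line.startswith("-"):
            reqs.append(line)
    return reqs


def build_extras(root: PathLike = ".") -> Dict[str, List[str]]:
    """Map each requirements-<name>.txt in root to an extras_require entry <name>."""
    return {
        path.stem[len("requirements-"):]: parse_requirements(path)
        for path in sorted(Path(root).glob("requirements-*.txt"))
    }
```

In setup.py these would be passed as `setup(..., install_requires=parse_requirements("requirements.txt"), extras_require=build_extras())`, so that for example `pip install ".[tests]"` pulls in the test dependencies.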
