
navix's Introduction

What is NAVIX?

NAVIX is a JAX-powered reimplementation of MiniGrid. Experiments that took one week now take 15 minutes.

A 200,000x speedup compared to MiniGrid and a throughput of 670 million steps/s are not just speed improvements. They open up a whole new paradigm, granting access to experiments that were previously impossible, e.g., those that would take years to run.

It changes the game.
Check out NAVIX's performance in more detail below, and see the documentation for more information.

Key features:

  • Performance Boost: NAVIX offers over a 1000x speed increase compared to the original Minigrid implementation, enabling faster experimentation and scaling. You can see a preliminary performance comparison here, and a full benchmark here.
  • XLA Compilation: Leverage the power of XLA to optimise NAVIX computations across accelerators. NAVIX can run on CPU, GPU, and TPU.
  • Autograd Support: Differentiate through environment transitions, opening up new possibilities such as learned world models.
  • Batched hyperparameter tuning: run thousands of experiments in parallel, enabling hyperparameter tuning at scale. Instantly find out whether your algorithm fails because of the choice of hyperparameters (see the sketch after this list).
  • Focus on research, not engineering: NAVIX takes care of the scaling so you can concentrate on the method.
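
For instance, hyperparameters can be batched just like random seeds. Here is a minimal sketch, where the train function and its return value are hypothetical placeholders for a jittable training run:

import jax
import jax.numpy as jnp

def train(seed, learning_rate):
    # ... a jittable training run (hypothetical placeholder) ...
    key = jax.random.PRNGKey(seed)
    return learning_rate * jax.random.uniform(key)  # stand-in for the final return

seeds = jnp.arange(8)
learning_rates = jnp.logspace(-4, -2, 8)

# vmap over seeds, then over learning rates: 8 x 8 = 64 runs in one program
batched = jax.vmap(jax.vmap(train, in_axes=(0, None)), in_axes=(None, 0))
results = jax.jit(batched)(seeds, learning_rates)  # shape (8, 8)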

The library is in active development, and we are working on adding more environments and features. If you want to join the development and contribute, please open a discussion and let's have a chat!

Installation

Install JAX

Follow the official installation guide for your OS and preferred accelerator: https://github.com/google/jax#installation.
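
For example, at the time of writing the CUDA 12 wheels can be installed as below; verify the exact extras against the guide above, as they change between releases:

pip install -U "jax[cuda12]"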

Install NAVIX

pip install navix

Or, for the latest version from source:

pip install git+https://github.com/epignatelli/navix

Performance

NAVIX improves on MiniGrid in both execution speed and throughput, making it possible to run more than 2048 PPO agents in parallel almost 10 times faster than a single PPO agent in the original MiniGrid.

[Figure: speedup per environment, NAVIX vs. MiniGrid]

In batch mode, NAVIX performs 668,734,693.88 steps per second (32,768M steps in 49 s, i.e., ∼670 million steps/s), while the original Minigrid implementation performs 1M steps in 318.01 s = 3,144.56 steps per second. This is a speedup of over 200,000x.

[Figure: PPO throughput comparison, NAVIX vs. MiniGrid]
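
The arithmetic can be checked in a few lines of Python:

# Checking the throughput arithmetic above
batched = 32_768 * 1_000_000 / 49   # ~6.69e8 steps/s in batch mode
single = 1_000_000 / 318.01         # ~3.14e3 steps/s in the original MiniGrid
print(batched / single)             # ~2.1e5, i.e., a speedup of over 200,000x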

Examples

You can view a full set of examples here (more coming), but here are the most common use cases.

Compiling a collection step

import jax
import navix as nx
import jax.numpy as jnp

N_TIMESTEPS = 1000  # length of each rollout


def run(seed):
    env = nx.make('MiniGrid-Empty-8x8-v0')  # Create the environment
    key = jax.random.PRNGKey(seed)
    timestep = env.reset(key)
    actions = jax.random.randint(key, (N_TIMESTEPS,), 0, env.action_space.n)

    def body_fun(timestep, action):
        timestep = env.step(timestep, action)  # Update the environment state
        return timestep, ()

    return jax.lax.scan(body_fun, timestep, actions)[0]

# Compile and batch the entire collection step for maximum performance
final_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000))
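
Note that jax.lax.scan keeps the rollout loop inside the compiled program, and jax.vmap batches the 1000 independent rollouts into a single XLA computation, so the whole collection step executes as one compiled call.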

Compiling a full training run

import jax
import navix as nx
import jax.numpy as jnp
from jax import random

def run_episode(seed, env, policy):
    """Collects a fixed-length episode with a given policy"""
    # A Python `while not done` loop does not compile under jit;
    # use a fixed-length jax.lax.scan and accumulate rewards instead
    key = random.PRNGKey(seed)
    timestep = env.reset(key)

    def body_fun(timestep, _):
        action = policy(timestep.observation)
        timestep = env.step(timestep, action)
        return timestep, timestep.reward

    _, rewards = jax.lax.scan(body_fun, timestep, None, length=100)
    return jnp.sum(rewards)

def train_policy(policy, num_episodes):
    """Trains a policy over multiple parallel episodes"""
    env = nx.make('MiniGrid-MultiRoom-N2-S4-v0')  # one env, shared across episodes
    seeds = jnp.arange(num_episodes)

    # Compile the entire batched collection with XLA; env and policy are
    # closed over, so only the seed is batched by vmap
    collect = jax.jit(jax.vmap(lambda seed: run_episode(seed, env, policy)))

    for _ in range(100):  # training iterations
        rewards = collect(seeds)
        # ... Update the policy based on rewards ...

# Hypothetical policy function
def policy(observation):
    # ... your policy logic ...
    return jnp.asarray(0)  # placeholder: always take action 0

# Start the training
train_policy(policy, num_episodes=100)

Backpropagation through the environment

import jax
import jax.numpy as jnp
import navix as nx
import flax.linen as nn
from jax import grad


class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    # ... your NN here; as a placeholder, a single dense layer
    x = x.astype(jnp.float32)
    return nn.Dense(features=x.shape[-1])(x)

model = Model()
env = nx.environments.Room(16, 16, 8)

def loss(params, timestep):
  action = jnp.asarray(0)
  pred_obs = model.apply(params, timestep.observation)  # predict the next observation
  timestep = env.step(timestep, action)
  return jnp.square(timestep.observation - pred_obs).mean()

key = jax.random.PRNGKey(0)
timestep = env.reset(key)
params = model.init(key, timestep.observation)

gradients = grad(loss)(params, timestep)
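
Here grad differentiates the loss with respect to params, so gradients has the same pytree structure as params. Because env.step is written in pure JAX, the transition is traceable end to end, which is what makes objectives that look through the environment (e.g., learned world models) possible.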

JAX ecosystem for RL

NAVIX is not alone: it is part of an ecosystem of JAX-powered modules for RL. Check out the following projects:

  • Environments:
    • Gymnax: a broad range of RL environments
    • Brax: a physics engine for robotics experiments
    • EnvPool: a set of various batched environments
    • Craftax: a JAX reimplementation of the game of Crafter
    • Jumanji: another set of diverse environments
    • PGX: board games commonly used for RL, such as backgammon, chess, shogi, and go
    • JAX-MARL: multi-agent RL environments in JAX
    • Xland-Minigrid: a set of JAX-reimplemented grid-world environments
    • Minimax: a JAX library for RL autocurricula with 120x faster baselines
  • Agents:
    • PureJaxRl: fully-jitted training routines
    • Rejax: a suite of diverse agents, including DDPG, DQN, PPO, SAC, and TD3
    • Stoix: useful implementations of popular single-agent RL algorithms in JAX
    • JAX-CORL: lean single-file implementations of offline RL algorithms with solid performance reports
    • Dopamine: a research framework for fast prototyping of reinforcement learning algorithms

Join Us!

NAVIX is actively developed. If you'd like to contribute to this open-source project, we welcome your involvement! Start a discussion or open a pull request.

Please consider starring the project if you like NAVIX!

Cite us, please!

If you use NAVIX, please cite it as:

@article{pignatelli2024navix,
  title={NAVIX: Scaling MiniGrid Environments with JAX},
  author={Pignatelli, Eduardo and Liesen, Jarek and Lange, Robert Tjarko and Lu, Chris and Castro, Pablo Samuel and Toni, Laura},
  journal={arXiv preprint arXiv:2407.19396},
  year={2024}
}

navix's People

Contributors

alanlivio, epignatelli, jysdoran, luchris429

navix's Issues

Entities should have only members, not `@property`

For clarity and to minimise entropy, entities should not implement @property accessors, but plain data members.

For example:

class Foo(Entity):
  walkable: Array = jnp.asarray(False)

While this can be problematic for broadcasting when entities are batched, the create method should take care of that (see the sketch below).
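
A minimal sketch of how create could handle the broadcasting, assuming entities are flax struct.PyTreeNode dataclasses batched over a leading axis (the convention here is an assumption):

import jax.numpy as jnp
from jax import Array
from flax import struct

class Foo(struct.PyTreeNode):
    walkable: Array = jnp.asarray(False)

    @classmethod
    def create(cls, n: int = 1, walkable: Array = jnp.asarray(False)) -> "Foo":
        # Broadcast scalar members to the batch size, so batched
        # entities always carry a leading batch dimension
        return cls(walkable=jnp.broadcast_to(walkable, (n,)))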

Broken link

Just a heads up.

Was forwarding this to a friend, and just noticed that the link to the performances notebook is broken in the readme, here:

You can find a partial performance comparison with [minigrid](https://github.com/Farama-Foundation/Minigrid) in the [docs](docs/profiling.ipynb).

The current solution for `State` is not scalable to other, new entities.

Consider replacing the separate sets of players, doors, keys, and goals with a single (fixed-length at compile time) collection of entities.

From this:

class State(struct.PyTreeNode):
    """The Markovian state of the environment"""

    key: KeyArray
    """The random number generator state"""
    grid: Array
    """The base map of the environment that remains constant throughout the training"""
    cache: RenderingCache
    """The rendering cache to speed up rendering"""
    players: Player = Player.create()
    """The player entity"""
    goals: Goal = Goal.create()
    """The goal entity, batched over the number of goals"""
    keys: Key = Key.create()
    """The key entity, batched over the number of keys"""
    doors: Door = Door.create()

To this:

class State(struct.PyTreeNode):
    """The Markovian state of the environment"""

    key: KeyArray
    """The random number generator state"""
    grid: Array
    """The base map of the environment that remains constant throughout the training"""
    cache: RenderingCache
    """The rendering cache to speed up rendering"""
    entities: Tuple[Entity, ...]

The main obstacle is the computational cost of iterating through the collection when we need to extract a specific entity, e.g., the player when applying an action (see the sketch below).
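
One mitigating observation: if entities is a tuple whose structure is fixed at compile time, the lookup happens at trace time and adds no runtime cost to the compiled program. A hypothetical helper:

from typing import Tuple, Type, TypeVar

T = TypeVar("T", bound="Entity")

def get_entity(entities: Tuple["Entity", ...], kind: Type[T]) -> T:
    # The tuple structure is static under jit, so this Python loop
    # runs once at trace time and not in the compiled program
    for entity in entities:
        if isinstance(entity, kind):
            return entity
    raise KeyError(kind.__name__)

# e.g., player = get_entity(state.entities, Player)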

Automate version upgrade

The version is currently updated manually before merging to main.

It would be convenient to update it automatically on every merge to main.
Note that you then need to pull the automatic commit again to be up to date with main.

`KeyDoor`: the door can spawn on the room wall

import matplotlib.pyplot as plt
import navix as nx
import jax


env = nx.environments.KeyDoor(12, 6, 100, observation_fn=nx.observations.rgb)

key = jax.random.PRNGKey(0)
timestep = env.reset(key)
plt.imshow(timestep.observation)
plt.show()

Rendering two sprites in the same cell

import jax.numpy as jnp
import navix as nx
import matplotlib.pyplot as plt

grid = jnp.zeros((1, 1, 32, 32, 3), dtype=jnp.uint8)
goal = nx.entities.Goal.create(position=jnp.asarray((0, 0)), probability=jnp.asarray(1.0))
player = nx.entities.Player.create(position=jnp.asarray((0, 0)))

positions = jnp.stack([goal.position[0], player.position])
sprites = jnp.stack([goal.get_sprite(nx.graphics.SPRITES_REGISTRY)[0], player.get_sprite(nx.graphics.SPRITES_REGISTRY)])
image = grid.at[tuple(positions.T)].set(sprites)
image = jnp.swapaxes(image, 1, 2)
image = image.reshape(32, 32, 3)
plt.imshow(image)


Implement more environments

Source: https://minigrid.farama.org/environments/minigrid/

  • MiniGrid-BlockedUnlockPickup-v0
  • MiniGrid-LavaCrossingS9N1-v0
  • MiniGrid-LavaCrossingS9N2-v0
  • MiniGrid-LavaCrossingS9N3-v0
  • MiniGrid-LavaCrossingS11N5-v0
  • MiniGrid-SimpleCrossingS9N1-v0
  • MiniGrid-SimpleCrossingS9N2-v0
  • MiniGrid-SimpleCrossingS9N3-v0
  • MiniGrid-SimpleCrossingS11N5-v0
  • MiniGrid-DistShift1-v0
  • MiniGrid-DistShift2-v0
  • MiniGrid-DoorKey-5x5-v0
  • MiniGrid-DoorKey-6x6-v0
  • MiniGrid-DoorKey-8x8-v0
  • MiniGrid-DoorKey-16x16-v0
  • MiniGrid-Dynamic-Obstacles-5x5-v0
  • MiniGrid-Dynamic-Obstacles-Random-5x5-v0
  • MiniGrid-Dynamic-Obstacles-6x6-v0
  • MiniGrid-Dynamic-Obstacles-Random-6x6-v0
  • MiniGrid-Dynamic-Obstacles-8x8-v0
  • MiniGrid-Dynamic-Obstacles-16x16-v0
  • MiniGrid-Empty-5x5-v0
  • MiniGrid-Empty-Random-5x5-v0
  • MiniGrid-Empty-6x6-v0
  • MiniGrid-Empty-Random-6x6-v0
  • MiniGrid-Empty-8x8-v0
  • MiniGrid-Empty-16x16-v0
  • MiniGrid-Fetch-5x5-N2-v0
  • MiniGrid-Fetch-6x6-N2-v0
  • MiniGrid-Fetch-8x8-N3-v0
  • MiniGrid-FourRooms-v0
  • MiniGrid-GoToDoor-5x5-v0
  • MiniGrid-GoToDoor-6x6-v0
  • MiniGrid-GoToDoor-8x8-v0
  • MiniGrid-GoToObject-6x6-N2-v0
  • MiniGrid-GoToObject-8x8-N2-v0
  • MiniGrid-KeyCorridorS3R1-v0
  • MiniGrid-KeyCorridorS3R2-v0
  • MiniGrid-KeyCorridorS3R3-v0
  • MiniGrid-KeyCorridorS4R3-v0
  • MiniGrid-KeyCorridorS5R3-v0
  • MiniGrid-KeyCorridorS6R3-v0
  • MiniGrid-LavaGapS5-v0
  • MiniGrid-LavaGapS6-v0
  • MiniGrid-LavaGapS7-v0
  • MiniGrid-LockedRoom-v0
  • MiniGrid-MemoryS17Random-v0
  • MiniGrid-MemoryS13Random-v0
  • MiniGrid-MemoryS13-v0
  • MiniGrid-MemoryS11-v0
  • MiniGrid-MemoryS9-v0
  • MiniGrid-MemoryS7-v0
  • MiniGrid-MultiRoom-N2-S4-v0
  • MiniGrid-MultiRoom-N4-S5-v0
  • MiniGrid-MultiRoom-N6-v0
  • MiniGrid-ObstructedMaze-1Dl-v0
  • MiniGrid-ObstructedMaze-1Dlh-v0
  • MiniGrid-ObstructedMaze-1Dlhb-v0
  • MiniGrid-ObstructedMaze-2Dl-v0
  • MiniGrid-ObstructedMaze-2Dlh-v0
  • MiniGrid-ObstructedMaze-2Dlhb-v0
  • MiniGrid-ObstructedMaze-1Q-v0
  • MiniGrid-ObstructedMaze-2Q-v0
  • MiniGrid-ObstructedMaze-Full-v0
  • MiniGrid-ObstructedMaze-2Dlhb-v1
  • MiniGrid-ObstructedMaze-1Q-v1
  • MiniGrid-ObstructedMaze-2Q-v1
  • MiniGrid-ObstructedMaze-Full-v1
  • MiniGrid-Playground-v0
  • MiniGrid-PutNear-6x6-N2-v0
  • MiniGrid-PutNear-8x8-N3-v0
  • MiniGrid-RedBlueDoors-6x6-v0
  • MiniGrid-RedBlueDoors-8x8-v0
  • MiniGrid-Unlock-v0
  • MiniGrid-UnlockPickup-v0

Agents walking through walls/lava?

Hi!

I have been training PPO agents in NAVIX environments and I keep seeing agents walk through walls/lava. I'm using NAVIX version 0.6.14 and a PPO agent from PureJaxRL (purejaxrl/ppo_minigrid.py). I train the agent, then load a network to render the behaviour, and I get the following:

from ppo_minigrid import *
from navix import observations, rewards, terminations, register_env
import navix

config = {
    "LR": 2.5e-4,
    "NUM_ENVS": 16,
    "NUM_STEPS": 256,
    "TOTAL_TIMESTEPS": 2e6,
    "UPDATE_EPOCHS": 1,
    "NUM_MINIBATCHES": 8,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.05,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "tanh",
    "ENV_NAME": "Navix-LavaGapS7-v0",
    "ANNEAL_LR": True,
    "DEBUG": True,
}

rng = jax.random.PRNGKey(30)
train_jit = jax.jit(make_train(config))
out = train_jit(rng)

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

env = NavixGymnaxWrapper("Navix-LavaGapS7-v0")
env = FlattenObservationWrapper(env)
env = LogWrapper(env)
key = jax.random.PRNGKey(30)
key, _key = jax.random.split(key)
obs, state = env.reset(_key, None)

def render(obs, title):
    plt.imshow(obs)
    plt.title(title)
    plt.axis('off')
    plt.savefig("image.png")

render(observations.rgb(state.env_state.state), "Initial observation")
network = ActorCritic(env.action_space(None).n, activation=config["ACTIVATION"])
params = out['runner_state'][0].params

for t in range(100):
    rng, _rng = jax.random.split(rng)
    pi, _ = network.apply(params, obs)
    action = pi.sample(seed=_rng)
    print(action)
    obs, state, reward, done, info = env.step(rng, state, action, None)
    render(observations.rgb(state.env_state.state), f"Step {t}")
    input('Continue...')

The initial observation: [image]

After 2 steps: [image]

After 3 steps: [image]

I had the same issue with the DoorKey environments, where agents simply walk through the dividing wall. Any ideas on what is happening?

Thanks!

Daniel
