
PureJaxRL (End-to-End RL Training in Pure Jax)

Code style: black

PureJaxRL is a high-performance, end-to-end JAX reinforcement learning (RL) implementation. When running many agents in parallel on GPUs, our implementation is over 1000x faster than standard PyTorch RL implementations. Unlike other JAX RL implementations, we implement the entire training pipeline in JAX, including the environment. This yields significant speedups through JIT compilation and by avoiding CPU-GPU data transfer. It also makes debugging easier, because the system is fully synchronous. More importantly, this code lets you use JAX to jit, vmap, pmap, and scan entire RL training pipelines. With this, we can:

  • ๐Ÿƒ Efficiently run tons of seeds in parallel on one GPU
  • ๐Ÿ’ป Perform rapid hyperparameter tuning
  • ๐ŸฆŽ Discover new RL algorithms with meta-evolution
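To illustrate what "the entire training pipeline in JAX, including the environment" means, here is a minimal sketch. It is not PureJaxRL's code: it assumes a hypothetical two-armed bandit environment (ARM_MEANS, env_step) and a plain REINFORCE learner (make_train) instead of PPO, with arbitrary reward values and learning rate. Because both the environment and the agent are pure JAX functions, the training loop can be written with jax.lax.scan and the whole thing compiled once with jax.jit.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment (not part of PureJaxRL): a two-armed bandit
# whose arm 1 has a higher mean reward than arm 0. Values are illustrative.
ARM_MEANS = jnp.array([0.2, 0.8])


def env_step(key, action):
    """Pure-JAX environment step: noisy reward for the chosen arm."""
    return ARM_MEANS[action] + 0.1 * jax.random.normal(key)


def make_train(num_updates=500, lr=0.1):
    """Build a train(key) function whose whole loop (env + agent) is JAX code."""

    def update(carry, _):
        logits, key = carry
        key, act_key, env_key = jax.random.split(key, 3)
        action = jax.random.categorical(act_key, logits)  # sample from the policy
        reward = env_step(env_key, action)  # the environment runs on-device too
        # REINFORCE: ascend reward * gradient of log pi(action) w.r.t. the logits
        grad = jax.grad(lambda l: reward * jax.nn.log_softmax(l)[action])(logits)
        return (logits + lr * grad, key), reward

    def train(key):
        init = (jnp.zeros(2), key)
        # lax.scan expresses the training loop inside JAX, so jit compiles it
        # as a single program instead of a Python loop of small kernels
        (logits, _), rewards = jax.lax.scan(update, init, None, length=num_updates)
        return logits, rewards

    return train


train = jax.jit(make_train())
logits, rewards = train(jax.random.PRNGKey(0))
print(jax.nn.softmax(logits))  # most probability mass should end up on arm 1
```

Since env_step is an ordinary JAX function, jit fuses the environment and the learner into one compiled program, and nothing is copied between CPU and GPU inside the loop.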

For more details, visit the accompanying blog post: https://chrislu.page/blog/meta-disco/

A Colab notebook walks through the basic usage.

Performance

Without vectorization, our implementation runs 10x faster than CleanRL's PyTorch baselines, as shown in the single-thread performance plot.

[Figures: single-thread training performance on CartPole and MinAtar-Breakout]

With vectorized training, we can train 2048 PPO agents in half the time it takes to train a single PyTorch PPO agent on a single GPU. The vectorized agent training allows for simultaneous training across multiple seeds, rapid hyperparameter tuning, and even evolutionary Meta-RL.

[Figures: vectorized training performance on CartPole and MinAtar-Breakout]
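The vectorized setting can be sketched under the same toy assumptions (a hypothetical two-armed bandit and REINFORCE rather than PPO; this is not PureJaxRL's code). jax.vmap turns a single-agent train function into one that trains a whole batch of independent agents, and nesting two vmaps gives a seed × learning-rate grid in a single compiled call. The count of 2048 seeds mirrors the text above; the learning rates and number of updates are arbitrary choices.

```python
import jax
import jax.numpy as jnp

ARM_MEANS = jnp.array([0.2, 0.8])  # hypothetical two-armed bandit (illustrative)


def train(key, lr, num_updates=200):
    """Train one REINFORCE agent; return its final probability of choosing arm 1."""

    def update(carry, _):
        logits, key = carry
        key, act_key, env_key = jax.random.split(key, 3)
        action = jax.random.categorical(act_key, logits)
        reward = ARM_MEANS[action] + 0.1 * jax.random.normal(env_key)
        # Analytic REINFORCE gradient: reward * (one_hot(action) - softmax(logits))
        grad = reward * (jax.nn.one_hot(action, 2) - jax.nn.softmax(logits))
        return (logits + lr * grad, key), None

    (logits, _), _ = jax.lax.scan(update, (jnp.zeros(2), key), None, length=num_updates)
    return jax.nn.softmax(logits)[1]


num_seeds = 2048
lrs = jnp.array([0.01, 0.1, 0.5])
keys = jax.random.split(jax.random.PRNGKey(0), num_seeds)

# Inner vmap: many seeds share one learning rate. Outer vmap: sweep the rates.
train_seeds = jax.vmap(train, in_axes=(0, None))
train_grid = jax.jit(jax.vmap(train_seeds, in_axes=(None, 0)))
p_best = train_grid(keys, lrs)  # shape (len(lrs), num_seeds)
print(p_best.mean(axis=1))  # mean final P(arm 1) for each learning rate
```

On a machine with several accelerators, jax.pmap could additionally split the seed axis across devices; that step is omitted here.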

Code Philosophy

PureJaxRL is inspired by CleanRL, providing high-quality single-file implementations with research-friendly features. Like CleanRL, this is not a modular library and is not meant to be imported. The repository focuses on simplicity and clarity in its implementations, making it an excellent resource for researchers and practitioners.

Installation

Install dependencies using the requirements.txt file:

pip install -r requirements.txt

Example Usage

examples/example_0.ipynb walks through the basic usage.

examples/example_1.ipynb walks through using PureJaxRL for Brax and MinAtar.

TODOs

The following improvements are planned for the PureJaxRL repository:

  1. More memory-efficient logging
  2. Integration with Weights & Biases (WandB) for experiment tracking
  3. Connecting to non-Jax environments like envpool

Related Work

PureJaxRL builds upon other tools in the Jax and RL ecosystems. Check out the following projects:

Citation

If you use PureJaxRL in your work, please cite the following paper:

@article{lu2022discovered,
    title={Discovered policy optimisation},
    author={Lu, Chris and Kuba, Jakub and Letcher, Alistair and Metz, Luke and Schroeder de Witt, Christian and Foerster, Jakob},
    journal={Advances in Neural Information Processing Systems},
    volume={35},
    pages={16455--16468},
    year={2022}
}
