low-precision-rl's Introduction

Low-Precision Reinforcement Learning: Running Soft Actor-Critic in Half Precision -- ICML 2021

Johan Bjorck, Xiangyu Chen, Christopher De Sa, Carla P. Gomes, Kilian Q. Weinberger

Overview

Low-precision training has become a popular approach for reducing compute requirements, memory footprint, and energy consumption in supervised learning. In our ICML 2021 paper, we consider training reinforcement learning agents in low precision. Naively training in fp16 does not work well; with six modifications, however, we demonstrate that low-precision RL trains stably while decreasing computational and memory demands. This codebase contains the code for our main experiments. Configuration and command-line arguments are handled via the excellent hydra framework.

[paper]
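
To make "naively training in fp16" concrete, the sketch below (illustrative PyTorch, not this repo's code; the toy critic and loss are hypothetical) casts a network, its inputs, and hence all gradient and optimizer arithmetic to half precision:

import torch
import torch.nn as nn

# A toy critic cast wholesale to fp16 -- the "naive" approach.
device = "cuda"  # this repo targets Nvidia GPUs
critic = nn.Sequential(nn.Linear(4, 256), nn.ReLU(), nn.Linear(256, 1)).half().to(device)
opt = torch.optim.Adam(critic.parameters(), lr=1e-3)

obs = torch.randn(32, 4, device=device, dtype=torch.float16)
loss = critic(obs).pow(2).mean()  # fp16 forward pass
loss.backward()                   # small fp16 gradients can underflow to zero
opt.step()                        # fp16 Adam suffers from eps and rounding issues

Small gradients underflowing and Adam's eps/rounding behavior are exactly the failure modes that the flags described in the Usage section address.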

Installation

  • You will need an Nvidia GPU with a reasonably recent CUDA version to run the code.

  • Create an environment from env.yml via:

conda env create -f env.yml
conda activate lowprec_rl

  • Install the DeepMind Control Suite as described here.

  • You will need to set appropriate environment variables, e.g., MUJOCO_GL=egl. You may also want to set HYDRA_FULL_ERROR=1 and OMP_NUM_THREADS=1; see the example below.
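
    For example, in a bash shell these can be exported before launching training (MUJOCO_GL=egl comes from the MuJoCo setup above; the other two lines are optional suggestions):

    export MUJOCO_GL=egl        # render MuJoCo headlessly through EGL
    export HYDRA_FULL_ERROR=1   # full hydra stack traces on errors
    export OMP_NUM_THREADS=1    # avoid CPU thread oversubscription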

Usage

  1. To run an experiment in fp32 on the finger_spin environment with seed 123, use:

    python train.py env=finger_spin seed=123
    

    Results will appear in a folder named runs.

  2. To use half precision (fp16) for the actor, critic, and alpha, use the command below. Note that this is expected to crash.

    python train.py env=finger_spin seed=123 \
        agent.params.actor_half=True agent.params.crit_half=True agent.params.alpha_half=True
    
  3. Without our proposed methods, the command above typically crashes. The methods can be toggled independently with the flags below; sketches of the loss-scaling and Kahan-summation ideas follow this list.

    Method                  Flags
    hAdam                   agent.params.use_num_adam=True
    compound loss scaling   agent.params.use_grad_scaler=True agent.params.adam_eps=0.0001
    normal-fix              diag_gaussian_actor.params.stable_normal=True
    softplus-fix            diag_gaussian_actor.params.tanh_threshold=10
    Kahan-momentum          agent.params.soft_update_scale=10000
    Kahan-gradients         agent.params.alpha_kahan=True agent.params.crit_kahan=True
  4. To apply all proposed methods at once, run:

    python train.py env=finger_spin seed=123 \
        agent.params.actor_half=True agent.params.crit_half=True agent.params.alpha_half=True \
        agent.params.use_grad_scaler=True agent.params.adam_eps=0.0001 agent.params.use_num_adam=True \
        diag_gaussian_actor.params.tanh_threshold=10 diag_gaussian_actor.params.stable_normal=True \
        agent.params.soft_update_scale=10000 agent.params.alpha_kahan=True agent.params.crit_kahan=True
    
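
The use_grad_scaler flag enables loss scaling. As background, the sketch below shows the standard PyTorch dynamic loss-scaling pattern (not the paper's compound variant; the toy model and loss are hypothetical): the loss is scaled up before backprop so fp16 gradients do not underflow, then the gradients are unscaled before the optimizer step.

import torch

model = torch.nn.Linear(4, 1).cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # maintains a dynamic loss scale

x = torch.randn(32, 4, device="cuda")
with torch.cuda.amp.autocast():
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()  # backprop through the scaled loss
scaler.step(opt)               # unscales gradients; skips the step on inf/nan
scaler.update()                # adjusts the scale for the next iteration

The Kahan-* flags are based on Kahan (compensated) summation: each low-precision accumulation stores the rounding error it incurred and feeds it back into the next update. The sketch below applies the idea to the soft (Polyak) target-network update; it is illustrative only, and the function soft_update_kahan and its buffer handling are hypothetical rather than this repo's API.

import torch
import torch.nn as nn

def soft_update_kahan(target, source, comp, tau=0.005):
    # Polyak update target <- target + tau * (source - target), with one Kahan
    # compensation buffer per parameter so bits lost to fp16 rounding are
    # re-injected into the next update instead of vanishing.
    for t, s, c in zip(target.parameters(), source.parameters(), comp):
        update = tau * (s.data - t.data) - c  # correct this step by the stored error
        new_t = t.data + update               # the low-precision add that rounds
        c.copy_((new_t - t.data) - update)    # capture what the add actually lost
        t.data.copy_(new_t)

# Usage: one zero-initialized compensation buffer per target parameter.
critic = nn.Linear(4, 256).half()
target = nn.Linear(4, 256).half()
comp = [torch.zeros_like(p) for p in target.parameters()]
soft_update_kahan(target, critic, comp)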

Citation

@inproceedings{bjorck2021low,
  title={Low-Precision Reinforcement Learning: Running Soft Actor-Critic in Half Precision},
  author={Bj{\"o}rck, Johan and Chen, Xiangyu and De Sa, Christopher and Gomes, Carla P and Weinberger, Kilian},
  booktitle={International Conference on Machine Learning},
  pages={980--991},
  year={2021},
  organization={PMLR}
}

Acknowledgements

The starting point for our codebase is pytorch_sac.


low-precision-rl's Issues

Towards Deeper Deep Reinforcement Learning

Hi @nilsjohanbjorck !

I was not able to find another way to contact you regarding one of your latest papers. I've got a few questions on your paper "Towards Deeper Deep Reinforcement Learning", which was a very interesting read, so thank you for your contribution!

Figure 3 shows the behavior of the gradients during training (log ||grad loss||). How did you extract/monitor these statistics? Did you take the gradient magnitude at the output layer of your network, or did you take the mean over every single layer?

Concerning spectral layer normalization, did you apply it to every layer of the critic?

Are you going to make your code open source?

Thanks
Marco

Hello, does the same technique work for multi-agent problems?

Hello, I'm using the multi-agent SAC algorithm and am trying to see whether the same modifications can be applied to multi-agent SAC in TensorFlow. Can you tell me the key points or changes needed to run the algorithm in lower precision? Should I rewrite the entire code in PyTorch, and could I use your implementation to measure the computational savings when using multiple agents? If you have any material on integrating this implementation with multiple agents, could you point me to it?

Thanks for open-sourcing the code. I found the paper a thoroughly interesting read.
