
Entity-Centric Reinforcement Learning (ECRL)

Official PyTorch code release of the paper "Entity-Centric Reinforcement Learning for Object Manipulation from Pixels" by Dan Haramati, Tal Daniel and Aviv Tamar.


Entity-Centric Reinforcement Learning for Object Manipulation from Pixels

ICLR 2024 - Spotlight (top 5%)

Goal-Conditioned Reinforcement Learning Workshop, NeurIPS 2023 - Spotlight

Zero-Shot Generalization from 3 to 12 Objects

Abstract

Manipulating objects is a hallmark of human intelligence, and an important task in domains such as robotics. In principle, Reinforcement Learning (RL) offers a general approach to learn object manipulation. In practice, however, domains with more than a few objects are difficult for RL agents due to the curse of dimensionality, especially when learning from raw image observations. In this work we propose a structured approach for visual RL that is suitable for representing multiple objects and their interaction, and use it to learn goal-conditioned manipulation of several objects. Key to our method is the ability to handle goals with dependencies between the objects (e.g., moving objects in a certain order). We further relate our architecture to the generalization capability of the trained agent, based on a theoretical result for compositional generalization, and demonstrate agents that learn with 3 objects but generalize to similar tasks with over 10 objects.

Citation

Haramati, Dan, Tal Daniel, and Aviv Tamar. "Entity-Centric Reinforcement Learning for Object Manipulation from Pixels." Proceedings of the Twelfth International Conference on Learning Representations (ICLR). 2024.

 @inproceedings{haramati2024entitycentric,
   title={Entity-Centric Reinforcement Learning for Object Manipulation from Pixels},
   author={Dan Haramati and Tal Daniel and Aviv Tamar},
   booktitle={The Twelfth International Conference on Learning Representations},
   year={2024},
   url={https://openreview.net/forum?id=uDxeSZ1wdI}
 }
In the Eyes of the Agent

Content

Entity-Centric Reinforcement Learning (ECRL)

  1. Prerequisites
  2. Environments
  3. Training
  4. Evaluation
  5. Repository Content Details
  6. Credits

1. Prerequisites

The following are the main libraries required to run this code:

Library            Version
Python             3.8
torch              2.1.2
stable-baselines3  1.5.0
isaacgym           Preview Release

For the full list of requirements, see the requirements.txt file.

For the simulation environment, download the Isaac Gym Preview release from the website, then follow the installation instructions in the documentation.
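
Once everything is installed, a quick way to verify the setup is the following minimal sketch (not part of the repo). Note that Isaac Gym Preview generally has to be imported before torch, hence the import order:

    import isaacgym  # Isaac Gym typically must be imported before torch
    import torch
    import stable_baselines3 as sb3

    print("torch:", torch.__version__)            # expected 2.1.2 per requirements.txt
    print("stable-baselines3:", sb3.__version__)  # expected 1.5.0
    print("CUDA available:", torch.cuda.is_available())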

2. Environments

Environments

The above figure describes the suite of training environments used in the paper.

N-Cubes: Push N different-colored cubes to their goal location.
Adjacent-Goals: A 3-Cubes setting where goals are sampled randomly on the table such that all cubes are adjacent. This task requires accounting for interactions between objects.
Ordered-Push: A 2-Cubes setting where a narrow corridor is set on top of the table, just wide enough to fit a single cube. We consider two possible goal configurations: red cube at the rear of the corridor and green cube at the front, or vice versa. This task requires fulfilling the goals in a certain order; otherwise the agent fails (pulling a block out of the corridor is not possible).
Small-Table: A 3-Cubes setting where the table is substantially smaller. This task requires accurately accounting for all objects in the scene at all times to avoid pushing blocks off the table.
Push-2T: Push 2 T-shaped blocks to a single goal orientation.

A configuration file for each environment, IsaacPandaPushConfig.yaml, is provided in the corresponding directory in config.

3. Training

Deep Latent Particles (DLP) Pretraining

We provide pretrained model checkpoints:

Model Dataset Download
DLP 5-Cubes Google Drive
DLP 6-Cubes Google Drive
DLP Push-T Google Drive
Slot-Attention 5-Cubes Google Drive
VAE Mixture of 1/2/3-Cubes Google Drive

Download a checkpoint and place it in the relevant directory under latent_rep_chkpts (e.g., the checkpoint of a DLP trained on data from the 5-Cubes environment should be placed in latent_rep_chkpts/dlp_push_5C).

In order to retrain the model:

  1. Collect image data using a random policy by running main.py -c <configuration_dir> with the desired environment (e.g., main.py -c config/n_cubes), setting collectData: True and collectDataNumTimesteps in the relevant Config.yaml. This will save a .npy file in the results directory (a quick way to inspect this file is sketched after this list).
  2. Process the data into a dataset by running dlp2/datasets/process_dlp_data.py (fill in the relevant paths in the beginning of the script).
  3. Configure config/TrainDLPConfig.yaml and run train_dlp.py.
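
Before step 2, it can be useful to sanity-check the collected file. This is a minimal sketch (not part of the repo); the file name and array layout below are assumptions, so check the results directory for the actual output:

    import numpy as np

    # Replace with the actual file written to the results directory in step 1.
    data = np.load("results/<collected_data_file>.npy", allow_pickle=True)
    print(type(data))
    print(getattr(data, "shape", None))  # e.g. image observations of shape (num_timesteps, H, W, C)
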
DLP Decomposition

RL Training

Run main.py -c <configuration_dir> with the desired configuration (e.g., main.py -c config/n_cubes).
Config.yaml contains agent and training parameters and IsaacPandaPushConfig.yaml contains environment parameters.

To reproduce the experiments in the paper, pass the corresponding configuration directory; the provided configurations already match those used in the paper. The parameters that still need to be set for the different experiment variants (e.g., 'State' or 'Image') are:

  • In Config.yaml: the Model parameters.
  • In IsaacPandaPushConfig.yaml: the numObjects parameter (for n_cubes and push_t); see the sketch below for one way to check these values before a run.
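
If helpful, the relevant values can be checked programmatically before launching a run. This is a sketch only, not part of the repo, and it assumes Model and numObjects are top-level keys in their respective YAML files (adjust the lookups if they are nested):

    import yaml

    # Sketch only: confirm which settings a run will pick up.
    with open("config/n_cubes/IsaacPandaPushConfig.yaml") as f:
        env_cfg = yaml.safe_load(f)
    print("numObjects:", env_cfg.get("numObjects"))

    with open("config/n_cubes/Config.yaml") as f:
        agent_cfg = yaml.safe_load(f)
    print("Model:", agent_cfg.get("Model"))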

To log training statistics and images/videos using Weights & Biases, set WANDB: log: True in Config.yaml and fill in your username in the wandb.init(entity="") line in the main.py script.
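
For reference, the line to edit in main.py is the wandb.init call, roughly as below; "your-username" is a placeholder for your W&B entity, and any other arguments in the actual script may differ:

    import wandb

    # Replace the entity with your W&B username or team name.
    wandb.init(entity="your-username")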

Agent model checkpoints and intermediate results are saved in the model_chkpts and results directories respectively.

EIT Architecture Outline

4. Evaluation

To evaluate an agent on a given environment, run policy_eval.py. Set the agent model_path and the desired configuration directory manually at the beginning of the script.
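
The values below are illustrative only and not necessarily the exact variable names or checkpoint format used in policy_eval.py; they just show the kind of edit meant here:

    # Hypothetical values -- check the top of policy_eval.py for the actual variables.
    model_path = "model_chkpts/<run_name>/<checkpoint>"  # trained agent checkpoint saved during training
    config_dir = "config/ordered_push"                   # configuration directory of the environment to evaluate on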

Evaluation on Zero-Shot Generalization

Cube Sorting: train on config/n_cubes with numObjects: 3 and evaluate on config/generalization_sort_push.

Different Number of Cubes than in Training: train on config/generalization_num_cubes with numObjects: 3 and evaluate with the same config and a varying number of objects.

Zero-Shot Generalization from 3 to 6 Objects

5. Repository Content Details

Filename Description
main.py main script for training the RL agent
policy_eval.py script for evaluating a trained agent on a given environment
td3_agent.py agent code implementing TD3 + HER with adjustments
multi_her_replay_buffer.py SB3 HER replay buffer adjusted for data from parallel environments
policies.py policy and Q-function neural networks for ECRL and baselines
isaac_env_wrappers.py wrappers for the IsaacGym environment for SB3 and goal compatibility
isaac_panda_push_env.py IsaacGym-based tabletop robotic pushing environment
/panda_controller directory containing code for the robotic arm controller
isaac_vec_env.py base IsaacGym environment
chamfer_reward.py Chamfer reward model
latent_classifier.py latent binary classifier for the Chamfer Reward filter
utils.py utility functions
/dlp2 directory containing the DLP code
train_dlp.py script for pretraining a DLP model
/vae directory containing the VAE code for the unstructured baseline
train_vae.py script for pretraining a VAE model (with VQ-VAE option)
/slot_attention directory containing the Slot-Attention code
train_sa.py script for pretraining a Slot-Attention model (with SLATE option)
/latent_rep_chkpts directory containing pre-trained representation model checkpoints
/latent_classifier_chkpts directory containing trained classifiers for the Chamfer Reward
/assets IsaacGym environment assets
/config directory containing configuration files for the different environments
requirements.txt library requirements file
/media directory containing images and gifs for this README

isaac_env_wrappers.py, isaac_panda_push_env.py, and chamfer_reward.py contain scripts for debugging the components separately.
latent_classifier.py contains an interactive script for tagging data and training the classifier.

6. Credits

  • Environments adapted from the IsaacGymEnvs FrankaCubeStack environment.
  • Controller for the Franka Panda arm adapted from the OSC implementation of OSCAR.
  • RL code modified from SB3 (mainly from the TD3 agent and HERReplayBuffer).
  • Entity Interaction Transformer (EIT) code built from components adapted from the DDLP Particle Interaction Transformer (PINT) which is based on minGPT.
  • SMORL re-implementation based on the official repository of the paper using the same codebase as the EIT.
  • DLP code modified from the official implementation of DLPv2 in the DDLP repository.
  • VAE code modified from Taming Transformers.
  • Slot-Attention code modified from Object Discovery PyTorch.


ecrl's Issues

Results expectation

Hello again,

I have a brief question about the results I am supposed to get on wandb, as well as from running main.py and policy_eval.py.

I'm attempting to train an ordered-push policy. I run main.py -c config/ordered_push. In the respective Config.yaml file I specify latentRepPath: 'latent_rep_chkpts/dlp_push_5C' for the pre-trained DLP and train with numObjects = 2. I then run policy_eval.py with config/ordered_push, the respective checkpoint found in model_chkpts, and the same number of objects.

I am unsure of what I am supposed to expect in the results. My issues are as follows:

  1. There are several .gif files that I obtained in wandb: Eval Goal Image, Goal Image - Frontview, Goal Image - Sideview, Episode Video - Frontview, Episode Video - Sideview, and Eval Episode Video. I'm getting videos for both main.py and policy_eval.py. I'm not sure what these videos represent and at which point in training they are recorded.

  2. I notice that the code keeps running after 800 episodes. I'm probably misunderstanding the code a little, but I expected it to stop at the point specified in the config file.

  3. I would be curious to know what the graphs for obs_rms_std and obs_rms_mean should look like. I see the values decrease, but only from 0.68 to 0.655 and from -0.03 to -0.04, respectively (which tells me that I'm probably not learning much). Would you mind telling me what results you got on that end?

Thank you very much for your help!

Hanna

P.S. Would it be possible to email you about some further questions I have?

Dependency issues

Hi,

While trying to run your code, I ran into several dependency issues.

In the prerequisites section, you mention torch 2.0.1 and stable-baselines3 1.5.0. However, your requirements.txt has torch 2.1.2. Furthermore, I can't install gym 0.21.0 without downgrading to Python 3.7. I have more similar issues that would require newer versions of stable-baselines3.

Also, I get a permission-denied error (related to wandb) when trying to run python main.py -c config/push_t.

Would you mind helping me to get started running your code in conda?

Thanks!
