dmcontrol-generalization-benchmark's Introduction

DMControl Generalization Benchmark

[07/01/2021] Added SVEA, DrQ, Distracting Control Suite, and reduced memory consumption by 5x

Benchmark for generalization in continuous control from pixels, based on DMControl.

Also contains official implementations of

Stabilizing Deep Q-Learning with ConvNets and Vision Transformers under Data Augmentation (SVEA)
Nicklas Hansen, Hao Su, Xiaolong Wang

[Paper] [Webpage]

and

Generalization in Reinforcement Learning by Soft Data Augmentation (SODA)
Nicklas Hansen, Xiaolong Wang

[Paper] [Webpage]

See this repository for SVEA implemented using Vision Transformers.

Test environments

The DMControl Generalization Benchmark provides two distinct benchmarks for visual generalization: random colors and video backgrounds.

Both benchmarks are offered in easy and hard variants. Samples are shown below.

color_easy

color_hard

video_easy

video_hard

This codebase also integrates a set of challenging test environments from the Distracting Control Suite (DistractingCS). Our implementation of DistractingCS includes environments of 8 gradually increasing randomization intensities. Note that our implementation is not equivalent to the original DistractingCS benchmark; the two differ in important ways: (1) we evaluate at a different set of intensities (and number of videos) that more closely matches the performance of current algorithms; (2) we reduce the randomization update frequency by a factor of 2 to account for frame skip (action repeat); (3) all TensorFlow dependencies have been replaced by PyTorch. By default, algorithms are trained for 500k frames and are continuously evaluated in both training and test environments. Environment randomization is seeded to promote reproducibility.

Algorithms

This repository contains implementations of several algorithms in a unified framework, using standardized architectures and hyper-parameters wherever applicable. If you want to add an algorithm, feel free to send a pull request.

Citation

If you find our work useful in your research, please consider citing our work as follows:

@article{hansen2021stabilizing,
  title={Stabilizing Deep Q-Learning with ConvNets and Vision Transformers under Data Augmentation},
  author={Nicklas Hansen and Hao Su and Xiaolong Wang},
  year={2021},
  eprint={2107.00644},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

for the SVEA method, and

@inproceedings{hansen2021softda,
  title={Generalization in Reinforcement Learning by Soft Data Augmentation},
  author={Nicklas Hansen and Xiaolong Wang},
  booktitle={International Conference on Robotics and Automation},
  year={2021},
}

for the SODA method and the DMControl Generalization Benchmark.

Setup

We assume that you have access to a GPU with CUDA >=9.2 support. All dependencies can then be installed with the following commands:

conda env create -f setup/conda.yaml
conda activate dmcgb
sh setup/install_envs.sh

Datasets

Part of this repository relies on external datasets. SODA uses the Places dataset for data augmentation, which can be downloaded by running

wget http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar

Distracting Control Suite uses the DAVIS dataset for video backgrounds, which can be downloaded by running

wget https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip

You should familiarize yourself with their terms before downloading. After downloading and extracting the data, add your dataset directory to the datasets list in setup/config.cfg.
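The extraction step mentioned above can also be scripted. The following is a minimal sketch, not part of this repository: the helper name extract_dataset and the target directories in the comments are hypothetical, and it only unpacks the archives. It does not edit setup/config.cfg, since that file's format is not shown here.

```python
import shutil
from pathlib import Path


def extract_dataset(archive_path, dataset_root):
    """Unpack a downloaded .tar or .zip archive into dataset_root.

    The archive format is inferred from the file extension by
    shutil.unpack_archive. Returns the destination directory as a Path.
    """
    dest = Path(dataset_root)
    dest.mkdir(parents=True, exist_ok=True)  # create the target directory if needed
    shutil.unpack_archive(str(archive_path), str(dest))
    return dest


# Example calls (the directory names are assumptions, pick your own):
# extract_dataset("places365standard_easyformat.tar", "datasets/places365")
# extract_dataset("DAVIS-2017-trainval-480p.zip", "datasets/davis")
```

The returned directory is the one you would then add to the datasets list in setup/config.cfg.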

The video_easy environment was proposed in PAD, and the video_hard environment uses a subset of the RealEstate10K dataset for background rendering. All test environments (including video files) are included in this repository, namely in the src/env/ directory.

Training & Evaluation

The scripts directory contains training and evaluation bash scripts for all the included algorithms. Alternatively, you can call the python scripts directly, e.g. for training call

python3 src/train.py \
  --algorithm svea \
  --seed 0

to run SVEA on the default task, walker_walk. This should give you an output of the form:

Working directory: logs/walker_walk/svea/0
Evaluating: logs/walker_walk/svea/0
| eval | S: 0 | ER: 26.2285 | ERTEST: 25.3730
| train | E: 1 | S: 250 | D: 70.1 s | R: 0.0000 | ALOSS: 0.0000 | CLOSS: 0.0000 | AUXLOSS: 0.0000

where ER and ERTEST correspond to the average return in the training and test environments, respectively. You can select the test environment used in evaluation with the --eval_mode argument, which accepts one of (train, color_easy, color_hard, video_easy, video_hard, distracting_cs, none). Use none if you want to disable continual evaluation of generalization. Note that not all combinations of arguments have been tested. Feel free to open an issue or send a pull request if you encounter a problem or would like to add support for new features.
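The eval and train rows in the output above are pipe-delimited KEY: value pairs, so they are easy to post-process. The following is a minimal sketch, not part of this repository: the helper name parse_log_line is hypothetical, and it assumes the console output keeps exactly the format shown above.

```python
import re

# Matches a leading "KEY: number" pair such as "ER: 26.2285" or "D: 70.1 s".
# Any trailing unit (the " s" after the duration) is ignored.
_FIELD = re.compile(r"([A-Z]+):\s*(-?\d+(?:\.\d+)?)")


def parse_log_line(line):
    """Parse one console row into (kind, fields).

    kind is "eval" or "train"; fields maps each key to a float.
    Returns None for lines that are not eval/train rows.
    """
    parts = [p.strip() for p in line.strip().strip("|").split("|")]
    if len(parts) < 2 or parts[0] not in ("eval", "train"):
        return None
    fields = {}
    for part in parts[1:]:
        match = _FIELD.match(part)
        if match:
            fields[match.group(1)] = float(match.group(2))
    return parts[0], fields


row = parse_log_line("| eval | S: 0 | ER: 26.2285 | ERTEST: 25.3730")
if row is not None:
    kind, fields = row
    # Gap between training-environment and test-environment return.
    print(kind, fields["ER"] - fields["ERTEST"])
```

Comparing ER against ERTEST at each evaluation step gives a quick view of the generalization gap for the chosen --eval_mode.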

Results

We provide test results for each of the SVEA, SODA, PAD, DrQ, RAD, and CURL methods. Results for color_hard and video_easy are shown below:

[Image: SODA results table]

See our paper for additional results.

Acknowledgements

We would like to thank the numerous researchers and engineers involved in the work on which this repository is based. This repository is a product of our work on SVEA, SODA and PAD. Our SAC implementation is based on this repository, the original DMControl is available here, and the gym wrapper for it is available here. The Distracting Control Suite environments were adapted from this implementation. PAD, RAD, CURL, and DrQ baselines are based on their official implementations provided here, here, here, and here, respectively.

dmcontrol-generalization-benchmark's People

Contributors

dexiongyung, haooooooqi, nicklashansen

dmcontrol-generalization-benchmark's Issues

Results on other domains

Hi, do we need to tune parameters for domains like humanoid?
I ran the code, and training on the walker domain comes close to the reported results. However, when I use the same parameters for humanoid, the reward doesn't even reach double digits. Is this expected, or do I need to consider additional factors?

TypeError: load() got an unexpected keyword argument 'setting_kwargs'

Hi @nicklashansen, when I ran the following command

python3 src/train.py \
  --algorithm svea \
  --seed 0

I got the following error -

 File "src/train.py", line 150, in <module>
    main(args)
  File "src/train.py", line 50, in main
    mode='train'
  File "/home/tejas/github/dmcontrol-generalization-benchmark/src/env/wrappers.py", line 48, in make_env
    background_dataset_paths=paths
  File "/home/tejas/github/dmcontrol-generalization-benchmark/src/env/dmc2gym/dmc2gym/__init__.py", line 64, in make
    return gym.make(env_id)
  File "/home/tejas/anaconda3/envs/dmcgb/lib/python3.7/site-packages/gym/envs/registration.py", line 235, in make
    return registry.make(id, **kwargs)
  File "/home/tejas/anaconda3/envs/dmcgb/lib/python3.7/site-packages/gym/envs/registration.py", line 129, in make
    env = spec.make(**kwargs)
  File "/home/tejas/anaconda3/envs/dmcgb/lib/python3.7/site-packages/gym/envs/registration.py", line 90, in make
    env = cls(**_kwargs)
  File "/home/tejas/github/dmcontrol-generalization-benchmark/src/env/dmc2gym/dmc2gym/wrappers.py", line 90, in __init__
    setting_kwargs=setting_kwargs
TypeError: load() got an unexpected keyword argument 'setting_kwargs'

I'm using the latest version of MuJoCo, i.e. 2.1.0. The error seems to come from dmc2gym, but I'm unable to resolve it.

Thanks!

Question about video background

When I run the program with the video_easy or video_hard command, the saved video file has a green background instead of the video background.
I want to ask how to solve this problem.

Question about robot-push

Could you give me some help with the implementation of the robotic manipulation task? I could not find such a task in this repository. Looking forward to your help.

Type Error: 'setting_kwargs'

Hello,

I am facing the same issue as the one described here. I ran this command python3 src/train.py --algorithm svea --seed 0 but got this error TypeError: load() got an unexpected keyword argument 'setting_kwargs'.

I installed dm_control, which comes with MuJoCo 2.2.0, and I don't see any way to install MuJoCo 2.0.0 from the MuJoCo download page.

@nicklashansen Any help would be appreciated.

RuntimeError: DataLoader worker (pid 384930) is killed by signal: Segmentation fault.

I ran your code of the SODA algorithm for 500k steps. The code ran till 211k steps and then it gave a segmentation fault error.

Evaluating: logs/walker_walk/soda/0
| eval | S: 210000 | ER: 604.9552 | ERTEST: 473.9582
| train | E: 841 | S: 210250 | D: 77.9 s | R: 676.1391 | ALOSS: -200.7976 | CLOSS: 19.3476 | AUXLOSS: 0.0003
| train | E: 842 | S: 210500 | D: 21.4 s | R: 630.4664 | ALOSS: -200.9594 | CLOSS: 19.7981 | AUXLOSS: 0.0003
| train | E: 843 | S: 210750 | D: 21.5 s | R: 575.7474 | ALOSS: -201.1477 | CLOSS: 19.6175 | AUXLOSS: 0.0003
| train | E: 844 | S: 211000 | D: 21.5 s | R: 587.5916 | ALOSS: -201.0205 | CLOSS: 19.8251 | AUXLOSS: 0.0003
| train | E: 845 | S: 211250 | D: 21.7 s | R: 600.5652 | ALOSS: -200.9775 | CLOSS: 19.5227 | AUXLOSS: 0.0003
| train | E: 846 | S: 211500 | D: 21.5 s | R: 617.4011 | ALOSS: -201.0966 | CLOSS: 19.4789 | AUXLOSS: 0.0003
| train | E: 847 | S: 211750 | D: 21.5 s | R: 670.3287 | ALOSS: -200.8488 | CLOSS: 19.7286 | AUXLOSS: 0.0003
ERROR: Unexpected segmentation fault encountered in worker.
Traceback (most recent call last):
  File "src/train.py", line 152, in <module>
    main(args)
  File "src/train.py", line 136, in main
    agent.update(replay_buffer, L, step)
  File "/home/kumars/Darshita/dmcontrol-generalization-benchmark/src/algorithms/soda.py", line 75, in update
    self.update_critic(obs, action, reward, next_obs, not_done, L, step)
  File "/home/kumars/Darshita/dmcontrol-generalization-benchmark/src/algorithms/sac.py", line 93, in update_critic
    current_Q1, current_Q2 = self.critic(obs, action)
  File "/home/kumars/anaconda3/envs/crc/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kumars/Darshita/dmcontrol-generalization-benchmark/src/algorithms/modules.py", line 248, in forward
    return self.Q1(x, action), self.Q2(x, action)
  File "/home/kumars/anaconda3/envs/crc/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kumars/Darshita/dmcontrol-generalization-benchmark/src/algorithms/modules.py", line 232, in forward
    return self.trunk(torch.cat([obs, action], dim=1))
  File "/home/kumars/anaconda3/envs/crc/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kumars/anaconda3/envs/crc/lib/python3.6/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/kumars/anaconda3/envs/crc/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kumars/anaconda3/envs/crc/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/kumars/anaconda3/envs/crc/lib/python3.6/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
  File "/home/kumars/anaconda3/envs/crc/lib/python3.6/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 384930) is killed by signal: Segmentation fault.

@nicklashansen Can you please help in this regard?

TypeError: make() got an unexpected keyword argument 'is_distracting_cs'

Hello,

When I ran the command python3 src/train.py --algorithm sac --seed 0, I got this error :

Traceback (most recent call last):
  File "../src/train.py", line 150, in <module>
    main(args)
  File "../src/train.py", line 50, in main
    mode='train'
  File "/home/mgz_21/0_Project/DMConrol-GB/src/env/wrappers.py", line 48, in make_env
    background_dataset_paths=paths
TypeError: make() got an unexpected keyword argument 'is_distracting_cs'

My main package versions are:

cudatoolkit               11.0.221
dm-control                0.0.318066097
numpy                     1.19.5
python                    3.7.6
torch                     1.7.1

Thanks!

TypeError: load() got an unexpected keyword argument 'setting_kwargs'

Hi. I just made a fresh install following the instructions on the readme file.
When I run

python3 src/train.py \
  --algorithm svea \
  --seed 0

I get
AttributeError: 'dict' object has no attribute 'env_specs'
This is easily solved by downgrading the gym version from 0.26.0 to 0.19.0.

Now, instead, I get the following:

/home/antonioricciardi/anaconda3/envs/dmcgb_orig/lib/python3.7/site-packages/glfw/__init__.py:916: GLFWError: (65544) b'X11: The DISPLAY environment variable is missing'
  warnings.warn(message, GLFWError)
Traceback (most recent call last):
  File "src/train.py", line 150, in <module>
    main(args)
  File "src/train.py", line 50, in main
    mode='train'
  File "/home/antonioricciardi/projects/dmcontrol-generalization-benchmark/src/env/wrappers.py", line 48, in make_env
    background_dataset_paths=paths
  File "/home/antonioricciardi/projects/dmcontrol-generalization-benchmark/src/env/dmc2gym/dmc2gym/__init__.py", line 64, in make
    return gym.make(env_id)
  File "/home/antonioricciardi/anaconda3/envs/dmcgb_orig/lib/python3.7/site-packages/gym/envs/registration.py", line 145, in make
    return registry.make(id, **kwargs)
  File "/home/antonioricciardi/anaconda3/envs/dmcgb_orig/lib/python3.7/site-packages/gym/envs/registration.py", line 90, in make
    env = spec.make(**kwargs)
  File "/home/antonioricciardi/anaconda3/envs/dmcgb_orig/lib/python3.7/site-packages/gym/envs/registration.py", line 60, in make
    env = cls(**_kwargs)
  File "/home/antonioricciardi/projects/dmcontrol-generalization-benchmark/src/env/dmc2gym/dmc2gym/wrappers.py", line 90, in __init__
    setting_kwargs=setting_kwargs
TypeError: load() got an unexpected keyword argument 'setting_kwargs'

Have any ideas of how I can solve this? Thank you!

Reproducibility of SODA and SVEA (conv)

Hi Nicklas,
Thank you for your high-quality repo.
We have trouble reproducing your results on finger spin with SODA and SVEA (we have between 500 and 600).
Even in training, we don't achieve the performance shown.
Are there any special settings or configurations for this environment?

Best regards

Question about data augmentation on target network

Thank you for your great work.

Could you please clarify whether the target network in SVEA receives any data augmentation, including random shift (i.e., weak augmentation)? I am unsure whether random shifts are applied in the target network.

Thank you.

Questions about std in SVEA paper

Hi, thanks for the great work!
I noticed this reply in a previous issue (#4): "Hi, we compute the standard deviation over the mean episode returns of each seed".
However, I'm still a bit confused. Could you please confirm if my understanding is correct?

  • (Fig.5 Top) Training performance: std of 5 seeds
  • (Fig.5 Bottom) Test performance: For each seed, run zero-shot evaluation 30 times (args.eval_episode) and calculate the mean from these 30 Return values (resulting in 1 mean value per seed). Then compute std using these 5 mean values.

Thank you!
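The procedure in the second bullet above can be sketched as follows. This is only an illustration of that reading, not a confirmed description of how the paper's numbers were produced: the helper name std_over_seed_means is hypothetical, the return values are made up, and the choice of sample standard deviation (statistics.stdev) rather than population standard deviation (statistics.pstdev) is an assumption.

```python
from statistics import mean, stdev


def std_over_seed_means(returns_per_seed):
    """Compute mean and std over per-seed mean episode returns.

    returns_per_seed holds one list of episode returns per seed.
    Each seed is first reduced to a single mean value; the standard
    deviation is then taken across those per-seed means.
    """
    seed_means = [mean(episode_returns) for episode_returns in returns_per_seed]
    # Sample std (n - 1 denominator) is an assumption; swap in pstdev if needed.
    return mean(seed_means), stdev(seed_means)


# Made-up returns for 3 seeds with 2 evaluation episodes each.
example = [[1.0, 3.0], [2.0, 4.0], [3.0, 5.0]]
overall_mean, std = std_over_seed_means(example)
print(overall_mean, std)
```

Note that this differs from pooling every episode return across seeds and taking a single standard deviation, which would also include within-seed variation.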

How to calculate the std mentioned in the article?

Hello! How is the standard deviation in the paper calculated? Should I record the rewards of every episode across the different seeds and calculate their standard deviation? Or should I just record the mean of 100 episodes for each seed and calculate the standard deviation among these mean values?

RAD implementation

Hi, thank you for this work! I'm interested in the RAD implementation but found that the RAD agent just inherits from SAC. May I know how RAD in this implementation differs from the SAC agent?

Any changes to DMControl source code?

Hi, great work!

I've noticed that this benchmark differs from the original implementation of Distracting Control Suite.

But how does the DMC source code here compare with the official repo? Are there any changes?

I think more prominent and detailed notes on this in the README would be necessary and welcome.

Confusion about Capacity of ReplayBuffer

Hello authors, thanks for your great work.
I have a question about the capacity of the replay buffer in this code. In the original DrQ code, the capacity is set by a hyperparameter whose default value is 100000. Is there a reason to set the capacity by train_steps here?

Unable to find the robotic manipulation environment.

Thank you for your contribution. I would like to use the robotic manipulation environment (pushing the cube to the location of the red disc) used in the paper "Generalization in Reinforcement Learning by Soft Data Augmentation". Is this environment publicly available for everyone to use?

@nicklashansen Request your help.
