cql's Introduction

CQL

A simple and modular implementation of the Conservative Q-Learning (CQL) and Soft Actor-Critic (SAC) algorithms in PyTorch.

If you like JAX, check out my reimplementation of this codebase in JAX, which runs 4 times faster.

Installation

  1. Install and use the included Anaconda environment:
$ conda env create -f environment.yml
$ source activate SimpleSAC

You'll need to get your own MuJoCo key if you want to use MuJoCo.

  2. Add this repo directory to your PYTHONPATH environment variable:
export PYTHONPATH="$PYTHONPATH:$(pwd)"

Run Experiments

You can run SAC experiments using the following command:

python -m SimpleSAC.sac_main \
    --env 'HalfCheetah-v2' \
    --logging.output_dir './experiment_output'

All available command options can be seen in SimpleSAC/sac_main.py and SimpleSAC/sac.py.

You can run CQL experiments using the following command:

python -m SimpleSAC.conservative_sac_main \
    --env 'halfcheetah-medium-v0' \
    --logging.output_dir './experiment_output'

If you want to run on CPU only, just add the --device='cpu' option. All available command options can be seen in SimpleSAC/conservative_sac_main.py and SimpleSAC/conservative_sac.py.

Visualize Experiments

You can visualize the experiment metrics with viskit:

python -m viskit './experiment_output'

and simply navigate to http://localhost:5000/

Weights and Biases Online Visualization Integration

This codebase can also log to the W&B online visualization platform. To log to W&B, you first need to set your W&B API key environment variable:

export WANDB_API_KEY='YOUR W&B API KEY HERE'

Then you can run experiments with W&B logging turned on:

python -m SimpleSAC.conservative_sac_main \
    --env 'halfcheetah-medium-v0' \
    --logging.output_dir './experiment_output' \
    --device='cuda' \
    --logging.online

Results of Running CQL on D4RL Environments

To save you time and compute resources, I've run a sweep of CQL on certain D4RL environments with various min Q weight values. The results can be seen here. You can choose which environment to visualize by filtering on env. The result for each cql.cql_min_q_weight value on each env is repeated across 3 random seeds and averaged.

Credits

The project organization is inspired by TD3. The SAC implementation is based on rlkit. The CQL implementation is based on the original CQL codebase. The viskit visualization is taken from viskit, which is taken from rllab.

cql's People

Contributors

young-geng

cql's Issues

Encountered an error when running this package, plus two general questions

Dear author,

Thanks for providing this excellent package!

When I try to run this package on my Linux server, I encounter the following error:

tensorflow.python.framework.errors_impl.NotFoundError: /home/.local/lib/python3.8/site-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so: undefined symbol: _ZN10tensorflow14kernel_factory17OpKernelRegistrar12InitInternalEPKNS_9KernelDefEN4absl12lts_2021032411string_viewESt10unique_ptrINS0_15OpKernelFactoryESt14default_deleteIS9_EE
...
wandb: ERROR Internal wandb error: file data was not synced
...
Exception: The wandb backend process has shutdown
Error in atexit._run_exitfuncs:
...
Exception: The wandb backend process has shutdown

The command I used is python -m SimpleSAC.conservative_sac_main --env 'halfcheetah-medium-v0' --logging.output_dir './experiment_output' --device "cuda:0", as per your recommendation.

How can I fix this error and run this project on my server?

I also have the following two general questions:

  1. I notice that this CQL implementation uses n_epochs=2000, which IMHO is longer than what typical offline RL methods use. Can we reduce the number of training epochs to, say, 1000 as in BCQ?
  2. Do you have recommendations for hyperparameter settings for the Maze2D and Adroit task domains in the D4RL dataset? I have tried to replicate the CQL results in the D4RL whitepaper using the original CQL repo, but was unsuccessful.

Results on Adroit tasks

Hi, thank you for sharing the PyTorch version of CQL! For most of the MuJoCo tasks, I can get results similar to those in the paper. However, for the Adroit tasks, I've tried lots of different parameters and versions of CQL, but still cannot reproduce the results. Do you have any suggestions? Thanks again :)

Some questions on CQL

1. For behavior cloning, the policy update is policy_loss = (alpha*log_pi - log_probs).mean(). Why does it use log_probs here rather than the Q-value?
2. When using the Lagrange version, do alpha_prime and cql_min_q_weight refer to the same thing? And according to Equation 30 of the CQL paper, shouldn't alpha_prime be updated before the Q loss is updated?
3. Is the twin Q function still essential? In my opinion, since the learned Q-value can be guaranteed to be a lower bound on the true Q-value, the twin Q function is unnecessary. Am I right?
4. What is cql_temp in the code? Its value is always 1; what is it used for if it takes a different value?

(I know some of this code is based on the original CQL codebase, but since its author is no longer active, I'm asking here.)
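A minimal plain-Python sketch of the two actor losses behind question 1, for a hypothetical 1-D Gaussian policy. The helper names (gaussian_log_prob, bc_policy_loss, sac_policy_loss, toy_bc_loss) and all numbers are illustrative assumptions, not code from this repo. It only shows the structural difference: the behavior-cloning loss scores the policy by the likelihood of dataset actions and never queries Q, while the usual SAC loss scores the policy's own sampled actions with Q.

```python
import math


def gaussian_log_prob(a, mean, std):
    # Log-density of a hypothetical 1-D Gaussian policy pi(a|s) = N(mean, std^2).
    return -0.5 * ((a - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def bc_policy_loss(alpha, log_pi_new, log_probs_data):
    # Mirrors (alpha * log_pi - log_probs).mean(): entropy term on the policy's
    # own samples minus the log-likelihood of the dataset actions. No Q involved.
    terms = [alpha * lp - ld for lp, ld in zip(log_pi_new, log_probs_data)]
    return sum(terms) / len(terms)


def sac_policy_loss(alpha, log_pi_new, q_new):
    # Mirrors (alpha * log_pi - Q(s, a_new)).mean(): the usual SAC actor loss.
    terms = [alpha * lp - q for lp, q in zip(log_pi_new, q_new)]
    return sum(terms) / len(terms)


def toy_bc_loss(mean, std=0.5, alpha=0.2):
    # Hypothetical dataset actions and fixed reparameterization noise.
    data_actions = [0.9, 1.1, 1.0]
    noise = [-0.2, 0.0, 0.3]
    new_actions = [mean + std * eps for eps in noise]
    log_pi_new = [gaussian_log_prob(a, mean, std) for a in new_actions]
    log_probs_data = [gaussian_log_prob(a, mean, std) for a in data_actions]
    return bc_policy_loss(alpha, log_pi_new, log_probs_data)


# A policy centered on the dataset actions gets a lower BC loss.
print(toy_bc_loss(1.0) < toy_bc_loss(0.0))  # True
```

Because the BC loss depends only on logged actions, it can shape the policy even while the Q-functions are still untrained; whether this repo uses it as a warm-up phase for that reason is an assumption here, not something the issue states.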

Question about the importance sampling

In Appendix F of the CQL paper, the log-sum-exp of Q(s,a) is computed with importance sampling using actions sampled only from Unif(a) and pi(a|s). Why does this code also sample actions from pi(a'|s')? This confuses me.

cql_cat_q1 = torch.cat(
    [cql_q1_rand - random_density,
     cql_q1_next_actions - cql_next_log_pis.detach(),
     cql_q1_current_actions - cql_current_log_pis.detach()],
    dim=1
)
cql_cat_q2 = torch.cat(
    [cql_q2_rand - random_density,
     cql_q2_next_actions - cql_next_log_pis.detach(),
     cql_q2_current_actions - cql_current_log_pis.detach()],
    dim=1
)
cql_min_qf1_loss = torch.logsumexp(cql_cat_q1 / self.config.cql_temp, dim=1).mean() * self.config.cql_min_q_weight * self.config.cql_temp
cql_min_qf2_loss = torch.logsumexp(cql_cat_q2 / self.config.cql_temp, dim=1).mean() * self.config.cql_min_q_weight * self.config.cql_temp
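One way to read the concatenation in the snippet above: each block of samples gives an importance-sampling estimate of the partition function Z(s) = ∫ exp(Q(s,a)) da, with each Q-value shifted by the log-density of the distribution that actually generated that action (random_density for the uniform samples, the policy's log-probabilities for the policy samples). Pooling several such blocks and taking logsumexp still estimates log Z(s), up to the constant log N, which does not change gradients. Under this reading, actions drawn from pi(a'|s') are simply another proposal for Q(s, ·), valid as long as their own log-probabilities are subtracted. The plain-Python sketch below checks this on a hypothetical 1-D Q-function, using a uniform proposal and a hypothetical linear-density proposal standing in for the policy; it is an illustration under those assumptions, not the repo's code.

```python
import math
import random


def logsumexp(xs):
    # Numerically stable log(sum(exp(x))).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def q_value(a):
    # Hypothetical Q(s, a) for one fixed state s, with actions in [-1, 1].
    return -2.0 * (a - 0.3) ** 2


def exact_log_partition(n=100_000):
    # Midpoint-rule reference value for log of the integral of exp(Q) on [-1, 1].
    h = 2.0 / n
    total = sum(math.exp(q_value(-1.0 + (i + 0.5) * h)) for i in range(n)) * h
    return math.log(total)


def sample_linear(rng):
    # Inverse-CDF sample from the hypothetical density p(a) = (1 + 0.5 a) / 2 on [-1, 1].
    return -2.0 + math.sqrt(1.0 + 8.0 * rng.random())


def linear_log_density(a):
    return math.log((1.0 + 0.5 * a) / 2.0)


def is_log_partition(n_per_proposal=20_000, seed=0):
    rng = random.Random(seed)
    terms = []
    # Proposal 1: Unif(-1, 1); log(0.5) plays the role of random_density.
    for _ in range(n_per_proposal):
        a = rng.uniform(-1.0, 1.0)
        terms.append(q_value(a) - math.log(0.5))
    # Proposal 2: stand-in for policy samples, minus their own log-density.
    for _ in range(n_per_proposal):
        a = sample_linear(rng)
        terms.append(q_value(a) - linear_log_density(a))
    # logsumexp over the pooled terms, minus log N to turn the sum into an average.
    return logsumexp(terms) - math.log(len(terms))


print(exact_log_partition(), is_log_partition())
```

The two printed values should agree closely; dropping the `- math.log(len(terms))` term, as the quoted loss effectively does, only shifts the estimate by a constant.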

Question about the CQL temperature

I am confused about the CQL temperature (cql_temp):

cql_concat_q1 = jnp.concatenate([
      jnp.squeeze(cql_random_q1) - random_density,
      jnp.squeeze(cql_q1) - cql_logp,
])
cql_concat_q2 = jnp.concatenate([
      jnp.squeeze(cql_random_q2) - random_density,
      jnp.squeeze(cql_q2) - cql_logp,
])
cql_qf1_ood = jax.scipy.special.logsumexp(cql_concat_q1 / self.config.cql_temp, axis=1) * self.config.cql_temp
cql_qf2_ood = jax.scipy.special.logsumexp(cql_concat_q2 / self.config.cql_temp, axis=1) * self.config.cql_temp

Shouldn't it be:

cql_concat_q1 = jnp.concatenate([
      jnp.squeeze(cql_random_q1) / self.config.cql_temp - random_density,
      jnp.squeeze(cql_q1) / self.config.cql_temp - cql_logp,
])
cql_concat_q2 = jnp.concatenate([
      jnp.squeeze(cql_random_q2) / self.config.cql_temp - random_density,
      jnp.squeeze(cql_q2) / self.config.cql_temp - cql_logp,
])
cql_qf1_ood = jax.scipy.special.logsumexp(cql_concat_q1, axis=1)
cql_qf2_ood = jax.scipy.special.logsumexp(cql_concat_q2, axis=1)
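A small plain-Python comparison of the two forms in this issue, assuming Q-values q and proposal log-densities log_p for a handful of sampled actions (the numbers are arbitrary placeholders). repo_form is cql_temp * logsumexp((Q - log p) / cql_temp), as in the first snippet; proposed_form is logsumexp(Q / cql_temp - log p), as in the second, which also drops the outer multiplication by cql_temp. The two coincide when cql_temp = 1 and diverge for other temperatures; as cql_temp shrinks, repo_form approaches max(Q - log p). This sketch does not settle which form is intended, it only makes the difference concrete.

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum(exp(x))).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def repo_form(q, log_p, temp):
    # temp * logsumexp((Q - log p) / temp): the form in the quoted code.
    return temp * logsumexp([(qi - lpi) / temp for qi, lpi in zip(q, log_p)])


def proposed_form(q, log_p, temp):
    # logsumexp(Q / temp - log p): the form proposed in the issue.
    return logsumexp([qi / temp - lpi for qi, lpi in zip(q, log_p)])


q = [1.0, 2.0, 0.5]          # placeholder Q-values
log_p = [-0.7, -1.2, -0.3]   # placeholder proposal log-densities

for temp in (1.0, 2.0, 0.01):
    print(temp, repo_form(q, log_p, temp), proposed_form(q, log_p, temp))
```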

Make checkpoints public

Hi, would it be possible to release the checkpoints for this implementation? I would be very grateful.
