
hindsight-experience-replay's Introduction

Hindsight-Experience-Replay

This repository provides a PyTorch implementation of Hindsight Experience Replay (HER) applied to the Deep Q-Network (DQN) and Deep Deterministic Policy Gradient (DDPG) algorithms.

Link to the paper: https://arxiv.org/pdf/1707.01495.pdf

Authors: Marcin Andrychowicz, Filip Wolski, Alex Ray, Jonas Schneider, Rachel Fong, Peter Welinder, Bob McGrew, Josh Tobin, Pieter Abbeel, Wojciech Zaremba
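
As a rough illustration of the core idea in the paper, the sketch below relabels each transition of an episode with the goal that was actually reached at the end of the episode (the "final" strategy) and recomputes the sparse reward. It is not taken from this repository: the transition layout, the function names, and the 0 / -1 reward convention are assumptions made for the example.

```python
from typing import List, Tuple

# (state, action, next_state, goal) for a single step; layout is illustrative only
Transition = Tuple[tuple, int, tuple, tuple]


def sparse_reward(next_state: tuple, goal: tuple) -> float:
    """Assumed reward convention: 0 when the goal is reached, -1 otherwise."""
    return 0.0 if next_state == goal else -1.0


def relabel_with_final_goal(episode: List[Transition]) -> List[tuple]:
    """Return the original transitions plus hindsight copies.

    Each hindsight copy replaces the intended goal with the state the
    episode actually ended in, then recomputes the reward for that goal.
    """
    achieved_goal = episode[-1][2]  # next_state of the last step
    buffer = []
    for state, action, next_state, goal in episode:
        # Original transition, rewarded against the goal the agent was given
        buffer.append(
            (state, action, sparse_reward(next_state, goal), next_state, goal)
        )
        # Hindsight transition, rewarded against the goal that was achieved
        buffer.append(
            (state, action, sparse_reward(next_state, achieved_goal),
             next_state, achieved_goal)
        )
    return buffer


# A two-step episode that never reaches its intended goal (1, 1)
episode = [
    ((0, 0), 1, (0, 1), (1, 1)),
    ((0, 1), 0, (0, 0), (1, 1)),
]
replay = relabel_with_final_goal(episode)
```

Even though the episode fails with respect to the intended goal, the last hindsight transition carries a reward of 0, which gives the agent a useful learning signal.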

Training

  • Train a model by running the corresponding main script:

    DQN With HER -> HERmain.py

    DDPG With HER -> DDPG_HER_main.py

    DQN Without HER -> main.py

  • Hyperparameters such as learning_rate, the discount factor (gamma), and epsilon are set where the agent variable is initialized in the files listed above.

Running the pre-trained model

  • Run the files listed in the Training section with the load_checkpoint variable set to True. This loads the saved model parameters and outputs the results. Update the paths to match the location of your saved results.
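
The snippet below is a minimal sketch of what saving and restoring parameters typically looks like in PyTorch. It is not the repository's code: the network, its layer sizes, and the checkpoint path are hypothetical stand-ins, and only the load_checkpoint name comes from the description above.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the repository's Q-network; sizes are illustrative only
q_network = nn.Linear(16, 8)

# Hypothetical checkpoint location; replace with your own saved results path
checkpoint_path = os.path.join(tempfile.mkdtemp(), "q_eval.pt")

# After training: persist only the learned parameters
torch.save(q_network.state_dict(), checkpoint_path)

# Later run: with load_checkpoint set to True, restore the saved parameters
load_checkpoint = True
restored = nn.Linear(16, 8)
if load_checkpoint:
    restored.load_state_dict(torch.load(checkpoint_path))
    restored.eval()  # evaluation mode, since no further training is intended
```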

Results


[Plots: results with averaging, and without averaging (contains spikes)]


hindsight-experience-replay's People

Contributors

hemilpanchiwala


hindsight-experience-replay's Issues

Data type problem

Thank you for sharing this code.

When I run this code, I get the following error in DQNAgentWithHER.py, line 106:

    q_pred = self.q_eval.forward(concat_state_goal)[batches, action]

    IndexError: tensors used as indices must be long, byte or bool tensors

I checked the dtypes and shapes and got the following:

    print(concat_state_goal.dtype)    # torch.float32
    print(batches.dtype)              # int64
    print(action.dtype)               # torch.float32

    print(concat_state_goal.shape)    # torch.Size([64, 16])
    print(batches.shape)              # (64,)
    print(action.shape)               # torch.Size([64, 8])

Could you tell me what the problem is? (What is the difference between your output and mine?)

Thank you.
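
The error message itself points at the likely cause: action is a float32 tensor of shape [64, 8], while PyTorch index tensors must be integer (long), byte, or bool. The sketch below reproduces the shapes reported above and shows one possible fix. It assumes each row of action is a one-hot vector over the 8 actions, which the issue does not confirm; q_values stands in for the output of q_eval.forward, and batches is built with torch.arange rather than the NumPy array used in the report.

```python
import torch

batch_size, n_actions = 64, 8
q_values = torch.randn(batch_size, n_actions)  # stand-in for q_eval.forward(...)
batches = torch.arange(batch_size)             # int64 row indices, shape [64]

# Assumption: each row of `action` is a one-hot vector over the 8 actions
action_ids = torch.randint(0, n_actions, (batch_size,))
action = torch.nn.functional.one_hot(action_ids, n_actions).float()  # [64, 8], float32

# q_values[batches, action] raises IndexError because `action` is float32.
# Collapse each row to a single integer index first; argmax returns int64.
action_index = action.argmax(dim=1)            # shape [64], dtype torch.int64
q_pred = q_values[batches, action_index]       # one Q-value per sample, shape [64]
```

If the stored actions are already scalar indices that were only saved as floats, converting them with action.long() would be enough, but that does not match the [64, 8] shape reported here.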
