hcpi

Overview

This is a Rust crate implementing a variant of High-Confidence Policy Improvement (HCPI). HCPI is a reinforcement learning algorithm that takes trajectories generated by a behavior policy and uses them to recommend a new policy that is better with high probability. The acceptable probability of a regression in policy performance is a user-tunable input to the algorithm. Such algorithms are intended to allow safe policy improvement in domains (e.g. medicine) where a competent behavior policy is known but on-policy exploration is prohibited because mistakes are costly.

HCPI uses black-box optimization to find policy parameters that maximize expected discounted return, subject to the constraint that the resulting policy is predicted to outperform the behavior policy with high probability. Expected returns for candidate policies are estimated with per-decision importance sampling, which reweights the observations collected under the behavior policy. To check that the resulting candidate policy really is likely to be an improvement, it is then subjected to a safety test on a holdout split of the input data, and is either returned or discarded depending on the outcome.
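As a rough illustration of the estimator and safety test described above, the sketch below computes a per-decision importance sampling (PDIS) estimate for each held-out episode and accepts a candidate only if a lower confidence bound on the mean estimate is at least a baseline. The README does not say which confidence bound the crate uses; this sketch assumes a one-sided Student's t bound, with the t quantile supplied by the caller (GSL, which this project already links, provides `gsl_cdf_tdist_Pinv` for that). `Step`, `pdis`, `t_lower_bound`, and `passes_safety_test` are hypothetical names, not the crate's API.

```rust
/// One logged time step from an episode generated by the behavior policy.
#[derive(Clone, Debug)]
struct Step {
    /// Probability the behavior policy assigned to the logged action.
    pi_b: f64,
    /// Probability the candidate policy assigns to the same action in the same state.
    pi_e: f64,
    /// Reward observed after the action.
    reward: f64,
}

/// Per-decision importance sampling estimate of the candidate's discounted return
/// from one episode: sum_t gamma^t * (prod_{j=0..=t} pi_e_j / pi_b_j) * r_t.
fn pdis(episode: &[Step], gamma: f64) -> f64 {
    let mut weight = 1.0; // running product of likelihood ratios up to step t
    let mut discount = 1.0; // gamma^t
    let mut total = 0.0;
    for step in episode {
        weight *= step.pi_e / step.pi_b;
        total += discount * weight * step.reward;
        discount *= gamma;
    }
    total
}

/// One-sided Student's t lower confidence bound on the mean of `samples`.
/// `t_quantile` is the (1 - delta) quantile of the t distribution with
/// n - 1 degrees of freedom (assumed bound; requires at least two samples).
fn t_lower_bound(samples: &[f64], t_quantile: f64) -> f64 {
    assert!(samples.len() >= 2, "need at least two samples");
    let n = samples.len() as f64;
    let mean = samples.iter().sum::<f64>() / n;
    let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
    mean - var.sqrt() / n.sqrt() * t_quantile
}

/// Hypothetical safety test: accept the candidate only if the lower bound on its
/// PDIS-estimated return over the held-out safety episodes is at least `baseline`.
fn passes_safety_test(safety_episodes: &[Vec<Step>], gamma: f64, t_quantile: f64, baseline: f64) -> bool {
    let estimates: Vec<f64> = safety_episodes.iter().map(|ep| pdis(ep, gamma)).collect();
    t_lower_bound(&estimates, t_quantile) >= baseline
}

fn main() {
    let step = |pi_b: f64, pi_e: f64, reward: f64| Step { pi_b, pi_e, reward };
    // Three one-step episodes in which the candidate matches the behavior policy.
    let safety = vec![
        vec![step(0.5, 0.5, 1.0)],
        vec![step(0.5, 0.5, 2.0)],
        vec![step(0.5, 0.5, 3.0)],
    ];
    // 2.92 approximates the 0.95 quantile of t with 2 degrees of freedom (delta = 0.05).
    println!("passes: {}", passes_safety_test(&safety, 1.0, 2.92, 0.0));
}
```

Here every importance weight is 1, so the estimates are just the rewards 1, 2, 3, and the bound is about 2 - 2.92/√3 ≈ 0.31. Note that multiplying likelihood ratios step by step is what makes the estimator "per-decision": a reward at step t is weighted only by the ratios of the actions taken up to t, not by the whole episode's ratio.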

The src/{mdp, data, hcpi}.rs files define an interface for running HCPI, and main.rs contains an example of using HCPI to generate 10 improved policies for a small MDP using behavior policy data contained in data.csv.

This code was written in Dec., 2019, as a project for CS 687 (Reinforcement Learning) at UMass Amherst, and released with permission from the instructor.

See also:

Input Data Format

The input data is stored in a CSV file in the datasets directory, containing the following rows:

  1. The number of state features.
  2. The number of actions.
  3. The Fourier basis order used by the behavior policy.
  4. The parameters of the behavior policy.
  5. The number of episodes N of data generated under the behavior policy.
  6. N rows of numbers, where each row indicates the full history of an episode.
  7. A list of (state, action) probabilities generated by the behavior policy, used to test that the HCPI policy representation is accurate.

See p. 151 of the course notes for full details.
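A minimal reader for the layout above might look like the following. It is a sketch under assumptions, not the crate's src/data.rs: it assumes comma-separated values, that each of rows 1 to 5 occupies one line, that the three counts are written as integers, and that everything after the N history rows belongs to item 7. Because the per-step layout of a history row is specified in the course notes rather than here, histories and probability rows are kept as raw numbers. `Dataset`, `parse_dataset`, `parse_row`, and `parse_single` are hypothetical names.

```rust
use std::str::FromStr;

/// Contents of a dataset file (field names are hypothetical).
#[derive(Debug)]
struct Dataset {
    num_state_features: usize,
    num_actions: usize,
    fourier_order: usize,
    behavior_params: Vec<f64>,
    /// One raw row per episode; the per-step layout is not decoded here.
    histories: Vec<Vec<f64>>,
    /// Remaining rows: the behavior policy's (state, action) probabilities, kept raw.
    probability_rows: Vec<Vec<f64>>,
}

/// Parse one comma-separated row, ignoring blank fields (e.g. a trailing comma).
fn parse_row<T: FromStr>(line: &str) -> Result<Vec<T>, String> {
    line.split(',')
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(|s| s.parse::<T>().map_err(|_| format!("could not parse {:?}", s)))
        .collect()
}

/// Parse a row that holds a single value.
fn parse_single<T: FromStr>(line: Option<&str>, what: &str) -> Result<T, String> {
    let line = line.ok_or_else(|| format!("missing {}", what))?;
    line.trim()
        .parse::<T>()
        .map_err(|_| format!("could not parse {} from {:?}", what, line))
}

fn parse_dataset(text: &str) -> Result<Dataset, String> {
    let mut lines = text.lines().filter(|l| !l.trim().is_empty());
    let num_state_features: usize = parse_single(lines.next(), "state feature count")?;
    let num_actions: usize = parse_single(lines.next(), "action count")?;
    let fourier_order: usize = parse_single(lines.next(), "Fourier order")?;
    let behavior_params: Vec<f64> = parse_row(lines.next().ok_or("missing behavior parameters")?)?;
    let num_episodes: usize = parse_single(lines.next(), "episode count")?;
    let mut histories: Vec<Vec<f64>> = Vec::with_capacity(num_episodes);
    for i in 0..num_episodes {
        let line = lines.next().ok_or_else(|| format!("missing history row {}", i))?;
        histories.push(parse_row(line)?);
    }
    // Whatever is left is treated as the probability check data (item 7).
    let probability_rows = lines.map(parse_row::<f64>).collect::<Result<Vec<_>, _>>()?;
    Ok(Dataset { num_state_features, num_actions, fourier_order, behavior_params, histories, probability_rows })
}

fn main() {
    // Tiny made-up file: 1 state feature, 2 actions, order 1, 4 parameters, 2 episodes.
    let text = "1\n2\n1\n0.1,-0.2,0.3,0.0\n2\n0.5,0,1.0\n0.2,1,0.0,0.7,0,1.0\n0.6,0.4\n";
    match parse_dataset(text) {
        Ok(d) => println!("{} episodes, {} behavior parameters", d.histories.len(), d.behavior_params.len()),
        Err(e) => eprintln!("parse error: {}", e),
    }
}
```

Returning `Result<_, String>` rather than panicking keeps a malformed or truncated file (for example, fewer history rows than row 5 promises) from being silently accepted.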

Running

  1. Install Rust.
  2. This project depends on FFI bindings to the GNU Scientific Library (GSL). Most package managers bundle GSL, so installing it should be painless. A scary compile error from the HCPI code probably means that rustc can't find GSL.
  3. (Optional) Generate a new dataset using tests/cartpole.py or some analogous RL code of your own.
  4. Symlink your dataset of choice to data.csv in the top level of the source directory (the level containing Cargo.toml).
  5. In the top level directory, run cargo run --release. The crate should compile without warnings.

Note: In the working directory, main.rs will create top-level directories called output and failed, which will be populated by CSV files containing policies as they are found. Policies that pass the safety test will be written to output, and policies that fail the test will be written to failed (this is mostly so that you can inspect the failed policy parameters to help tune the optimizer). The code will panic and prompt you to delete these directories if they already exist, in order to avoid accidentally overwriting policies from a previous run.
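The directory handling described in this note can be sketched with only the standard library. This is an illustration, not the code in main.rs: the function names `create_result_dirs` and `write_policy`, the `policy{i}.csv` file naming, and writing each policy as a single comma-separated row of parameters are all assumptions.

```rust
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};

/// Create `output` and `failed` under `root`, panicking if either already exists
/// so that policies from a previous run are never overwritten.
fn create_result_dirs(root: &Path) -> (PathBuf, PathBuf) {
    let output = root.join("output");
    let failed = root.join("failed");
    for dir in [&output, &failed] {
        if dir.exists() {
            panic!("{} already exists; delete it before rerunning", dir.display());
        }
        fs::create_dir(dir).expect("failed to create results directory");
    }
    (output, failed)
}

/// Write one policy's parameters as a single CSV row into `dir`
/// (`output` if it passed the safety test, `failed` otherwise).
fn write_policy(dir: &Path, index: usize, params: &[f64]) -> std::io::Result<PathBuf> {
    let path = dir.join(format!("policy{}.csv", index));
    let row: Vec<String> = params.iter().map(|p| p.to_string()).collect();
    let mut file = fs::File::create(&path)?;
    writeln!(file, "{}", row.join(","))?;
    Ok(path)
}

fn main() {
    // Use a scratch directory so the demo never touches a real run's results.
    let root = std::env::temp_dir().join(format!("hcpi_demo_{}", std::process::id()));
    fs::create_dir_all(&root).expect("could not create scratch directory");
    let (output, failed) = create_result_dirs(&root);
    let passed = write_policy(&output, 0, &[0.1, -0.2, 0.3]).expect("write failed");
    let rejected = write_policy(&failed, 1, &[1.5, 0.0, -2.0]).expect("write failed");
    println!("{}\n{}", passed.display(), rejected.display());
    fs::remove_dir_all(&root).expect("cleanup failed");
}
```

Failing fast before any optimization starts is cheaper than discovering a name collision after a long HCPI run.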

Testing

This repository includes the original CS 687 dataset at datasets/cs687.csv. This dataset is small, so it is useful for checking that the code runs successfully, but it can't be used to validate the algorithm because the dynamics of the MDP that generated the data were not provided in the class. Larger datasets can be generated on CartPole using the scripts in the tests directory. These have much longer episode horizons and higher dimensionality than cs687.csv, so HCPI takes much longer to run on them: they are better for testing that the algorithm works, but worse for quickly checking that the code runs without errors.

The tests directory contains the following Python code, useful for testing the behavior of the HCPI algorithm:

  • agents.py, an implementation of a simple hill-climbing agent.
  • policies.py, an implementation of a Fourier-basis policy representation with softmax action selection.
  • cartpole.py, a script that trains a mediocre behavior policy on the OpenAI Gym CartPole-v0 environment and then uses it to generate a dataset on which to run HCPI. The generated data file will be located at datasets/cartpole_deg{k}_ret{R}_eps{N}.csv, where k is the Fourier basis order, R is the mean return of the behavior policy over a configurable number of episodes, and N is the number of histories in the dataset.
  • eval.py, a script that loads the baseline policy in data.csv, as well as all the policies in output, runs them all for a configurable number of episodes, saves the results to tests/eval.csv, plots them, and saves the plot to eval.png.

Running cargo test in the top-level directory will execute a Rust test that ensures the HCPI policy representation matches the policy representation used to generate the dataset (this is what the last row of data.csv is for).
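The kind of policy representation this test checks (a Fourier basis with softmax action selection, as in policies.py) can be sketched as follows. This assumes the "full" Fourier basis (every coefficient vector in {0, ..., k}^d), states already normalized to [0, 1], and a parameter vector stored as one row of weights per action; the crate and policies.py may order or scale things differently. `fourier_coefficients`, `fourier_features`, `action_probs`, and `matches_recorded` are hypothetical names.

```rust
use std::f64::consts::PI;

/// Every coefficient vector c in {0, ..., order}^dims (the "full" Fourier basis).
fn fourier_coefficients(dims: usize, order: usize) -> Vec<Vec<usize>> {
    let base = order + 1;
    (0..base.pow(dims as u32))
        .map(|mut n| {
            // Read the digits of n in base (order + 1), least significant first.
            (0..dims)
                .map(|_| {
                    let digit = n % base;
                    n /= base;
                    digit
                })
                .collect::<Vec<usize>>()
        })
        .collect()
}

/// Fourier features cos(pi * c . s); assumes the state is already scaled to [0, 1]^d.
fn fourier_features(state: &[f64], coeffs: &[Vec<usize>]) -> Vec<f64> {
    coeffs
        .iter()
        .map(|c| {
            let dot: f64 = c.iter().zip(state).map(|(&ci, &si)| ci as f64 * si).sum();
            (PI * dot).cos()
        })
        .collect()
}

/// Softmax action probabilities. Assumes `theta` stores one row of
/// `features.len()` weights per action, rows in action order.
fn action_probs(theta: &[f64], features: &[f64], num_actions: usize) -> Vec<f64> {
    let n = features.len();
    assert_eq!(theta.len(), n * num_actions, "theta has the wrong length");
    let prefs: Vec<f64> = (0..num_actions)
        .map(|a| theta[a * n..(a + 1) * n].iter().zip(features).map(|(w, f)| w * f).sum::<f64>())
        .collect();
    // Subtract the largest preference before exponentiating to avoid overflow.
    let max = prefs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = prefs.iter().map(|p| (p - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Compare computed probabilities with recorded ones, within a tolerance.
fn matches_recorded(computed: &[f64], recorded: &[f64], tol: f64) -> bool {
    computed.len() == recorded.len()
        && computed.iter().zip(recorded).all(|(c, r)| (c - r).abs() <= tol)
}

fn main() {
    // One state feature, Fourier order 1, two actions: 2 features x 2 actions = 4 weights.
    let coeffs = fourier_coefficients(1, 1);
    let features = fourier_features(&[0.0], &coeffs);
    let probs = action_probs(&[1.0, 0.0, 0.0, 0.0], &features, 2);
    println!("{:?}", probs);
    // Made-up "recorded" probabilities standing in for the last row of data.csv.
    println!("matches: {}", matches_recorded(&probs, &[0.7311, 0.2689], 1e-3));
}
```

Comparing with a tolerance rather than exact equality allows for floating-point differences between the Python code that generated the dataset and the Rust implementation.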

Limitations

  • While HCPI works in a more general setting, this code only handles Fourier policies over finite action spaces.
  • This code was written to solve a specific problem, and makes no attempt to provide a general library API.
  • Consequently, some hyperparameters or constants may be hard-coded, though this should be mostly confined to main.rs.

Contributors

jwarley
