rurel's Introduction

Rurel

Rurel is a flexible, reusable reinforcement learning (Q learning) implementation in Rust.

In Cargo.toml:

rurel = "0.5.1"

An example is included. It teaches an agent on a 21x21 grid how to arrive at (10, 10), using the actions go left, go up, go right, and go down:

cargo run --example eucdist

Getting started

There are two main traits you need to implement: rurel::mdp::State and rurel::mdp::Agent.

A State is something that defines a Vec of actions that can be taken from it and has a certain reward. A State also needs to declare its corresponding action type A.

An Agent is something that has a current state and, given an action, can take that action and evaluate the next state.
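
Roughly, the two traits look like this (a sketch inferred from the example below; the exact trait bounds and the default methods rurel provides may differ):

trait State: Eq + std::hash::Hash + Clone {
	type A: Eq + std::hash::Hash + Clone; // the action type
	fn reward(&self) -> f64;              // reward for being in this state
	fn actions(&self) -> Vec<Self::A>;    // actions that can be taken from this state
}

trait Agent<S: State> {
	fn current_state(&self) -> &S;
	fn take_action(&mut self, action: &S::A); // move to the next state
}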

Example

Let's implement the example in cargo run --example eucdist. We want to make an agent that learns how to arrive at (10, 10) on a 21x21 grid.

First, let's define a State, which should represent a position on the 21x21 grid, and the corresponding Action, which is either up, down, left, or right.

use rurel::mdp::State;

#[derive(PartialEq, Eq, Hash, Clone)]
struct MyState { x: i32, y: i32 }
#[derive(PartialEq, Eq, Hash, Clone)]
struct MyAction { dx: i32, dy: i32 }

impl State for MyState {
	type A = MyAction;
	fn reward(&self) -> f64 {
		// Negative Euclidean distance
		-((((10 - self.x).pow(2) + (10 - self.y).pow(2)) as f64).sqrt())
	}
	fn actions(&self) -> Vec<MyAction> {
		vec![MyAction { dx: 0, dy: -1 },	// up
			 MyAction { dx: 0, dy: 1 },	// down
			 MyAction { dx: -1, dy: 0 },	// left
			 MyAction { dx: 1, dy: 0 },	// right
		]
	}
}

Then define the agent:

use rurel::mdp::Agent;

struct MyAgent { state: MyState }
impl Agent<MyState> for MyAgent {
	fn current_state(&self) -> &MyState {
		&self.state
	}
	fn take_action(&mut self, action: &MyAction) {
		let MyAction { dx, dy } = *action;
		// Move on the grid, wrapping around the edges
		self.state = MyState {
			x: (((self.state.x + dx) % 21) + 21) % 21, // (x + dx) mod 21
			y: (((self.state.y + dy) % 21) + 21) % 21, // (y + dy) mod 21
		}
	}
}

That's all. Now make a trainer and train the agent with Q learning, using a learning rate of 0.2, a discount factor of 0.01, and an initial Q value of 2.0. We let the trainer run for 100000 iterations, randomly exploring new states.
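
For reference, these three parameters feed the standard textbook Q-learning update (nothing rurel-specific):

Q(s, a) ← (1 − α) · Q(s, a) + α · (r + γ · max_a′ Q(s′, a′))

where α is the learning rate, γ the discount factor, r the observed reward, and state-action pairs that have not been seen yet start at the initial Q value.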

use rurel::AgentTrainer;
use rurel::strategy::learn::QLearning;
use rurel::strategy::explore::RandomExploration;
use rurel::strategy::terminate::FixedIterations;

let mut trainer = AgentTrainer::new();
let mut agent = MyAgent { state: MyState { x: 0, y: 0 }};
trainer.train(&mut agent,
              &QLearning::new(0.2, 0.01, 2.),
              &mut FixedIterations::new(100000),
              &RandomExploration::new());

After this, you can query the learned Q value for a given action in a given state with:

trainer.expected_value(&state, &action) // : Option<f64>
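
For example, here is a minimal sketch of a greedy-policy lookup (best_action is a hypothetical helper built on the example types above, not part of rurel; it assumes AgentTrainer is generic over the state type):

fn best_action(trainer: &AgentTrainer<MyState>, state: &MyState) -> Option<MyAction> {
	state
		.actions()
		.into_iter()
		// State-action pairs the trainer never visited yield None; skip them.
		.filter_map(|a| trainer.expected_value(state, &a).map(|q| (a, q)))
		// Pick the action with the highest learned value.
		.max_by(|(_, q1), (_, q2)| q1.partial_cmp(q2).unwrap_or(std::cmp::Ordering::Equal))
		.map(|(a, _)| a)
}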

Development

  • Run cargo fmt --all to format the code.
  • Run cargo clippy --all-targets --all-features -- -Dwarnings to lint the code.
  • Run cargo test to test the code.

rurel's People

Contributors

chriamue, lucanlepus, lucky4luuk, milanboers, mtib, nyurik, willow-iam

rurel's Issues

Example doesn't "learn" anything

Running the example and adding some debugging code, I'm finding that the agent is not learning anything at all.

    let mut trainer = AgentTrainer::new();
    let mut agent = MyAgent {
        state: MyState { x: 0, y: 0 },
    };
    trainer.train(
        &mut agent,
        &QLearning::new(0.2, 0.01, 2.),
        &mut FixedIterations::new(10000000),
        &RandomExploration::new(),
    );
    let state1 = MyState { x: 1, y: 0 };
    let state2 = MyState { x: 0, y: 1 };
    let actions = vec![MyAction { dx: 0, dy: -1 }, MyAction { dx: -1, dy: 0 }];
    for action in actions {
        println!(
            "1: {:?} {:?} {:?}",
            state1,
            action,
            trainer.expected_value(&state1, &action),
        );
        println!(
            "2: {:?} {:?} {:?}",
            state2,
            action,
            trainer.expected_value(&state2, &action),
        );
        println!();
    }
1: MyState { x: 1, y: 0 } MyAction { dx: 0, dy: -1 } Some(-13.582118848154376)
2: MyState { x: 0, y: 1 } MyAction { dx: 0, dy: -1 } Some(-14.27795681221249)

1: MyState { x: 1, y: 0 } MyAction { dx: -1, dy: 0 } Some(-14.27795681221249)
2: MyState { x: 0, y: 1 } MyAction { dx: -1, dy: 0 } Some(-13.582118848154376)

It seems that it hasn't learned that, even with x: 1 and y: 0, dx: -1 and dy: 0 is the best move. Am I misunderstanding the example, or missing something here?

Can't represent a state from which there are no actions

Hey! Cool crate, I'm playing around with it a bit. I realized that I can't seem to represent a state from which there are no actions. I'm training a model to play a game, and this game has some "end states" (like the end of the game, where you're scored on your performance), but if I have my State::actions function return an empty Vec, I get a div/0 error here, since there are no actions: https://github.com/milanboers/rurel/blob/40d0fa7116c528953780b74e0a19756182a70a72/src/mdp/mod.rs#L22C59-L22C59

Maybe I'm just misunderstanding how this is supposed to work :)

Empty list of actions panics

It seems that if State::actions() returns an empty vector, the whole system crashes. I am new to reinforcement learning, so I might be using it incorrectly. My setup:

  • the state is an NxN empty board
  • for each board state, actions() generates a list of actions available for that board state
  • in many cases the board state cannot be improved any further, and it should backtrack to try a different action

Usage for future values?

Question... can this library be used for something like Pong, i.e. where the reward isn't known right away, but rather becomes known eventually and is somehow backpropagated?

I guess the reward could sort of be immediately known, by checking the y-distance from the ball... but that's not a huge win over just manually making it chase the ball. It'd be nicer to have the AI figure out other things, like trying to hit the ball such that it makes the opponent chase it a longer distance.

Sorry for the naive question - I haven't jumped into ML yet and I'm just tinkering around with a webassembly pong thing and thought this library might be a nice way to drive the AI :)

Train by individual steps

Is it possible to train a single step at a time, so I can have it run within something like a bevy system?

It seems like the current train method is usually run for a certain number of iterations, but perhaps a step or train_step function would be convenient?
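
One possible workaround with the API shown in the README (a sketch, assuming the trainer keeps its learned values between train calls) is to drive training one iteration at a time from your own loop:

// Inside your own update loop (e.g. a bevy system):
trainer.train(
    &mut agent,
    &QLearning::new(0.2, 0.01, 2.),
    &mut FixedIterations::new(1), // a single training iteration per call
    &RandomExploration::new(),
);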

Eucdist example doesn't seem to be correct

I've modified the eucdist example to add Display for MyAction, which prints an arrow based on the action, and added a function entry_to_action which gets the most likely action for a given state (if I'm not wrong):

fn entry_to_action(entry: &HashMap<MyAction, f64>) -> Option<&MyAction> {
    entry
        .iter()
        .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap_or(Ordering::Equal))
        .map(|(a, _)| a)
}

And after running the example, it prints this:

→  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  ↑  
↓  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  ↑  ↑  
↓  ↓  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  ↑  ↑  ↑  
↓  ↓  ↓  ↓  →  →  →  →  →  →  →  →  →  →  →  →  →  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  →  →  →  →  →  →  →  →  →  →  →  →  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  →  →  →  →  →  →  →  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ↓  →  →  →  →  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ↓  →  →  →  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  →  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  
↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  
↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  
↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  
←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←

I expected that the arrows would all point toward the center, right?

Here's the code for printing the arrows:

impl fmt::Display for MyAction {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            MyAction::Move { dx, dy } => {
                match (dx, dy) {
                    (-1, 0) => write!(f, "{LEFT_ARROW}"),
                    (1, 0) => write!(f, "{RIGHT_ARROW}"),
                    (0, -1) => write!(f, "{UP_ARROW}"),
                    (0, 1) => write!(f, "{DOWN_ARROW}"),
                    _ => unreachable!()
                }
            }
        }
    }
}

Dumping results to file

I think the library should offer a way to dump the learned state to a file (basically just the HashMap of the AgentTrainer), in order to save the learned state, checkpoint it, or continue learning from an earlier state. This could be hacked from the outside with some unsafe code, though I think offering access to q within AgentTrainer would be useful.

Sink States

I want to use this with an MDP that has sink states. That is, states with reward and no possible actions. Can I do so with this, and if so, how?

How to make 2 networks play against each other?

I really love the approach of this library, but I can't seem to figure out how to make the program play against itself. Can I exchange states between two agents in any way, or has this not been accounted for yet?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.