
crux.jl's Introduction

Crux.jl


Deep RL library with concise implementations of popular algorithms, implemented using Flux.jl and compatible with the POMDPs.jl interface.

Supports CPU and GPU computation and implements the following algorithms:

Reinforcement Learning

Imitation Learning

Batch RL

Adversarial RL

Continual Learning

  • Experience Replay

Installation

  • Install POMDPGym
  • Install by opening Julia and running ] add Crux

To edit or contribute, run ] dev Crux; the repo will be cloned to ~/.julia/dev/Crux.

Maintained by Anthony Corso ([email protected])

crux.jl's People

Contributors

ancorso, c-j-cundy, jamgochiana, mossr, mykelk, smkatz12, whifflefish

crux.jl's Issues

Default compatibility should be with CommonRLInterface

The package was originally designed to work with a handful of POMDPs from the POMDPs.jl ecosystem. Now that it is more general, we should probably change the design so that it assumes a problem conforms to CommonRLInterface rather than to the POMDP/MDP interface. We can easily construct a CommonRLInterface environment from a POMDP or MDP; that's effectively what is happening here:

if sampler.mdp isa POMDP
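The conversion described above can be sketched without any dependencies. This is a hypothetical illustration, not Crux's or CommonRLInterface's implementation: ToyMDP, gen_step, mdp_actions and WrappedEnv are made-up names, and reset!, observe, actions, act! and terminated are defined locally to mirror the verbs CommonRLInterface uses. The point is that the wrapper owns the current state, so a solver written against these verbs never needs to branch on whether it was handed a POMDP or an MDP.

```julia
using Random

# Hypothetical generative MDP: integer position, actions -1/+1,
# reward 1.0 on reaching `goal`, which also ends the episode.
struct ToyMDP
    goal::Int
end

initialstate(m::ToyMDP, rng::AbstractRNG) = 0
mdp_actions(m::ToyMDP) = (-1, 1)

# Generative step (s, a) -> (sp, r), the kind of call a sampler makes on an MDP
function gen_step(m::ToyMDP, s::Int, a::Int)
    sp = s + a
    return sp, sp == m.goal ? 1.0 : 0.0
end

# Environment wrapper that owns the current state; callers only use the
# CommonRLInterface-style verbs below (defined locally, not imported)
mutable struct WrappedEnv
    mdp::ToyMDP
    s::Int
    rng::AbstractRNG
end

WrappedEnv(m::ToyMDP; rng::AbstractRNG = Random.default_rng()) =
    WrappedEnv(m, initialstate(m, rng), rng)

reset!(env::WrappedEnv) = (env.s = initialstate(env.mdp, env.rng); nothing)
observe(env::WrappedEnv) = env.s
actions(env::WrappedEnv) = mdp_actions(env.mdp)
terminated(env::WrappedEnv) = env.s == env.mdp.goal

# Advance the wrapped state and return only the reward
function act!(env::WrappedEnv, a)
    env.s, r = gen_step(env.mdp, env.s, a)
    return r
end

# A solver loop written only against the environment verbs
function run_episode!(env, policy)
    reset!(env)
    total = 0.0
    while !terminated(env)
        total += act!(env, policy(observe(env)))
    end
    return total
end

env = WrappedEnv(ToyMDP(3))
total = run_episode!(env, s -> 1)   # always step right
```

With the goal at 3, the episode takes three steps and only the last one is rewarded, so total is 1.0.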

Requesting some help: 2-DoF inverted pendulum

I was trying to extend the inverted pendulum model implemented in the examples to a two-degree-of-freedom (2DoF) inverted pendulum:

@with_kw struct Pendulum_2D_MDP <: MDP{Array{Float32}, Array{Float32}}
    failure_thresh::Union{Nothing, Float64} = nothing # if set, defines the operating range for the pendulum. Episode terminates if abs(θ) is larger than this. 
    θ0_x = Distributions.Uniform(-π, π) # Distribution to sample initial angular position in the x direction
    ω0_x = Distributions.Uniform(-1., 1.) # Distribution to sample initial angular velocity in the x direction
    θ0_y = Distributions.Uniform(-π, π) # Distribution to sample initial angular position in the y direction    
    ω0_y = Distributions.Uniform(-1., 1.) # Distribution to sample initial angular velocity in the y direction
    Rstep = 0 # Reward earned on each step of the simulation
    λcost = 1 # Coefficient to the traditional OpenAIGym Reward
    max_speed::Float64 = 8.
    max_torque::Float64 = 2.
    ashift::Float64 = 0.0
    dt::Float64 = .001
    g::Float64 = 9.81
    m::Float64 = 65.
    l::Float64 = 1.1
    γ::Float32 = 0.9
    ashift_x::Float64 = 0.0
    ashift_y::Float64 = 0.0
    actions::Vector{Float64} = [-1., 1.] # default must match the declared Vector{Float64} type
    include_time_in_state = false
    maxT = 99*dt
    px = nothing # Distribution over disturbances
end

The 2DoF extension is modelled by a system of equations implemented in pendulum_2d_dynamics_ODE, which returns the solution at the next time step.

The dynamics function, pendulum_2d_dynamics, is given below:


function pendulum_2d_dynamics(env, s, a, x = isnothing(env.px) ? 0 : rand(env.px); rng::AbstractRNG = Random.GLOBAL_RNG)
    # Deal with terminal states
    # Terminate if either angle leaves the operating range
    if (isnothing(env.failure_thresh) ? false : max(abs(s[1]), abs(s[3])) > env.failure_thresh)
        return fill(-100, size(s)), 0
    # With include_time_in_state, the elapsed time is stored in s[5]
    elseif env.include_time_in_state && (s[5] > env.maxT || s[5] ≈ env.maxT)
        return fill(100, size(s)), 0
    end
        
    θ_x = s[1]
    ω_x = s[2]
    θ_y = s[3]
    ω_y = s[4]
    dt, g, m, l = env.dt, env.g, env.m, env.l

    a_x = a[1]
    a_y = a[1] # NOTE: also reads a[1], so both axes receive the same torque; a two-output policy would read a[2]
    a_x = clamp(a_x, -env.max_torque, env.max_torque)
    a_y = clamp(a_y, -env.max_torque, env.max_torque)
    costs = angle_normalize(θ_x)^2 + 0.1f0 * ω_x^2 + 0.001f0 * a_x^2 + angle_normalize(θ_y)^2 + 0.1f0 * ω_y^2 + 0.001f0 * a_y^2
    
    a_x = a_x + env.ashift_x
    a_y = a_y + env.ashift_y
    
    a_x = a_x + x
    a_y = a_y + x

    
    θ_x, ω_x, θ_y, ω_y = pendulum_2d_dynamics_ODE(a_x, a_y, θ_x, ω_x, θ_y, ω_y, m, g, l)


    #=
    ω = ω + (-3. * g / (2 * l) * sin(θ + π) + 3. * a / (m * l^2)) * dt
    θ = angle_normalize(θ + ω * dt)
    ω = clamp(ω, -env.max_speed, env.max_speed)
    =#

    if env.include_time_in_state
        sp = Float32.([θ_x, ω_x, θ_y, ω_y, s[5] + dt])
    else
        sp = Float32.([θ_x, ω_x, θ_y, ω_y])
    end
    r = Float32(env.Rstep - env.λcost*costs)
    return sp, r
end
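The body of pendulum_2d_dynamics_ODE is not shown in the post. For comparison, here is a hypothetical version that treats the two axes as decoupled and applies to each axis the same explicit Euler update as the commented-out 1D block. The dt keyword, the helper pendulum_axis_step, and the angle_normalize definition (wrapping into [-π, π)) are assumptions. A real 2DoF pendulum generally has coupling terms between the axes, which this sketch ignores.

```julia
# Assumed definition: wrap an angle into [-π, π)
angle_normalize(θ) = mod(θ + π, 2π) - π

# One Euler step of a single-axis pendulum, same update order as the
# commented-out 1D block: update ω, then θ, then clamp ω
function pendulum_axis_step(θ, ω, a, m, g, l; dt = 0.001, max_speed = 8.0)
    ω = ω + (-3g / (2l) * sin(θ + π) + 3a / (m * l^2)) * dt
    θ = angle_normalize(θ + ω * dt)
    ω = clamp(ω, -max_speed, max_speed)
    return θ, ω
end

# Hypothetical 2DoF version: each axis evolves independently under its own torque
function pendulum_2d_dynamics_ODE(a_x, a_y, θ_x, ω_x, θ_y, ω_y, m, g, l; dt = 0.001)
    θ_x, ω_x = pendulum_axis_step(θ_x, ω_x, a_x, m, g, l; dt = dt)
    θ_y, ω_y = pendulum_axis_step(θ_y, ω_y, a_y, m, g, l; dt = dt)
    return θ_x, ω_x, θ_y, ω_y
end

# x-axis tilted by 0.1 rad, y-axis upright, no torque on either axis
θx, ωx, θy, ωy = pendulum_2d_dynamics_ODE(0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 65.0, 9.81, 1.1)
```

Because dt is a keyword whose default matches the default env.dt, the existing call site in pendulum_2d_dynamics would work unchanged; passing dt = env.dt keeps the two in sync if dt is changed.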

The deep RL setup is defined as follows:

A = [-2., -0.5, 0, 0.5, 2.]
𝔄 = Base.Iterators.product(A,A) |>collect |>vec
𝔄_2D = [collect(a) for a in 𝔄]

mdp = Pendulum_2D_MDP(actions=A)
as = [actions(mdp)...]
amin = [-2f0]
amax = [2f0]
rand_policy = FunctionPolicy((s) -> Float32.(rand.(Distributions.Uniform.(amin, amax))))
S = state_space(mdp, σ=[3.14f0, 8f0, 3.14f0, 8f0])


SG() = SquashedGaussianPolicy(ContinuousNetwork(Chain(Dense(4, 64, relu), Dense(64, 64, relu), Dense(64, 1)), 1,), zeros(Float32, 1,), 2f0)

𝒮_reinforce = REINFORCE(π=SG(), S=S, N=10000, ΔN=2048, a_opt=(batch_size=512,))
π_reinforce = POMDPs.solve(𝒮_reinforce, mdp)

The problem I am facing is how to set up an action space that now has two torques, a_x and a_y. I don't know whether the way I have done it will work; however, I am getting very good rewards:

Step: 0, undiscounted_return: 98.408066
Step: 2048, undiscounted_return: 98.030914, actor_batches_trained: 320, actor_loss: 5.8472886, entropy: 1.4157778, actor_grad_norm: 0.83697426, kl: -0.0006680194
Step: 4096, undiscounted_return: 98.45372, actor_batches_trained: 320, actor_loss: 5.8725863, entropy: 1.4154032, actor_grad_norm: 0.86363125, kl: -0.0013817283
Step: 6144, undiscounted_return: 97.804565, actor_batches_trained: 320, actor_loss: 5.89741, entropy: 1.4170549, actor_grad_norm: 1.0493901, kl: -0.00082278205
Step: 8192, undiscounted_return: 97.0133, actor_batches_trained: 320, actor_loss: 5.9142237, entropy: 1.4090993, actor_grad_norm: 1.1050732, kl: -0.00086959853
SquashedGaussianPolicy(ContinuousNetwork(Chain(Dense(4 => 64, relu), Dense(64 => 64, relu), Dense(64 => 1)), 1, Flux.cpu), ContinuousNetwork(Chain(ConstantLayer{Vector{Float32}}(Float32[-0.013098647])), 1, Flux.cpu), 2.0f0, false)

I am pretty sure something is seriously wrong. I would appreciate some help with this implementation. Thank you very much.
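In the posted setup the policy emits a single torque: the network ends in Dense(64, 1), the log-std vector is zeros(Float32, 1), and amin/amax have one entry each, while pendulum_2d_dynamics reads a[1] for both axes. A policy for two torques would presumably need two outputs throughout (for example Dense(64, 2), ContinuousNetwork(..., 2), zeros(Float32, 2), and two-entry amin/amax); this is a suggestion, not verified against Crux. The dependency-free sketch below is a generic illustration of a two-dimensional squashed-Gaussian action, not Crux's SquashedGaussianPolicy code: each component is sampled as ascale * tanh(μ + σε), so it stays within [-ascale, ascale].

```julia
using Random

# Sample one action from a diagonal Gaussian, squash it with tanh,
# and scale it to [-ascale, ascale]. μ and logσ would come from the policy
# network; here they are fixed example values.
function sample_squashed_gaussian(rng::AbstractRNG, μ::Vector{Float32},
                                  logσ::Vector{Float32}, ascale::Float32)
    ε = randn(rng, Float32, length(μ))   # standard normal noise, one entry per action dimension
    u = μ .+ exp.(logσ) .* ε             # unbounded Gaussian sample
    return ascale .* tanh.(u)            # squash and scale
end

rng = MersenneTwister(0)
μ = Float32[0.3, -0.5]      # two means: one per torque (a_x, a_y)
logσ = zeros(Float32, 2)    # two log standard deviations
a = sample_squashed_gaussian(rng, μ, logσ, 2f0)
a_x, a_y = a[1], a[2]       # the dynamics would read the second torque from a[2]
```

With a two-dimensional action, the dynamics can then use a[2] for a_y instead of reusing a[1].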

Reexport Flux

using Reexport
@reexport using Flux

That way, users don't have to call using Flux separately whenever they use Crux.
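The effect of @reexport can be shown with a plain export. In the hypothetical sketch below, Inner stands in for Flux and Outer for Crux. Re-exporting a name that Outer obtained via using makes that name visible to anyone who loads Outer. Reexport.jl's @reexport automates this for all names a module exports; the sketch itself does not use Reexport.jl.

```julia
# Hypothetical modules: Inner plays the role of Flux, Outer the role of Crux
module Inner
export double
double(x) = 2x
end

module Outer
using ..Inner      # like `using Flux` inside Crux
export double      # re-export: `using .Outer` now also brings `double` into scope
end

using .Outer       # no separate `using .Inner` needed
result = double(3)
```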

Compatibility with recent Julia versions

Hi,

I've noticed that installing Crux.jl under more recent Julia versions may result in compatibility issues, e.g., for Julia 1.8. Is there any plan to release an update that addresses this issue in the near term?

solve failure for POMDPs.jl models

using POMDPModels
using POMDPs
using Flux
using Crux

mdp = SimpleGridWorld()
as = actions(mdp)
S = state_space(mdp)

A() = DiscreteNetwork(Chain(Dense(Crux.dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)
V() = ContinuousNetwork(Chain(Dense(Crux.dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, 1)))

𝒮_ppo = PPO(π=ActorCritic(A(), V()), S=S, N=10_000, ΔN=1_000)
π_ppo = solve(𝒮_ppo, mdp)

Initially this fails because cpucall does not have a method for StaticArrays.

After adding a method to account for this (or covering all bases with AbstractArray), another failure occurs here:
https://github.com/ancorso/Crux.jl/blob/c32fd8ca94437c991eb0a9bb54686531d8542d23/src/sampler.jl#L93

This can be corrected by switching to sp, r = @gen(:sp,:r)(sampler.mdp, sampler.s, args...; kwargs...) from POMDPs.jl.

I'm not entirely certain what the side effects on certain POMDPGym environments may be, but at least this allows GridWorld to run.
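The first failure is a dispatch problem: a method restricted to one array type has no match for a StaticArrays type. The sketch below reproduces the pattern without StaticArrays or Crux. FixedVec3, to_cpu_narrow and to_cpu_wide are hypothetical names, not Crux's cpucall, and the wide version is one way to "cover all bases with AbstractArray" as the issue suggests.

```julia
# Minimal fixed-size vector type standing in for a StaticArrays SVector
struct FixedVec3 <: AbstractVector{Float64}
    data::NTuple{3, Float64}
end
Base.size(::FixedVec3) = (3,)
Base.getindex(v::FixedVec3, i::Int) = v.data[i]

# Narrow method: only plain Arrays match, so other array types hit a MethodError
to_cpu_narrow(x::Array) = x
# Wide method: any AbstractArray is materialized as a plain Array
to_cpu_wide(x::AbstractArray) = collect(x)

v = FixedVec3((1.0, 2.0, 3.0))
narrow_ok = applicable(to_cpu_narrow, v)   # false: no matching method
converted = to_cpu_wide(v)                 # a plain Vector{Float64}
```

Widening the signature to AbstractArray trades specificity for coverage; if a type needs special handling, a more specific method can still be added alongside it.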
