
deepqlearning.jl's Introduction

DeepQLearning


This package provides an implementation of the Deep Q-learning algorithm for solving MDPs. For more information see https://arxiv.org/pdf/1312.5602.pdf. It uses POMDPs.jl and Flux.jl.

It supports the following innovations:

  • Prioritized experience replay
  • Dueling network architecture
  • Double Q-learning
  • Recurrent Q-learning (DRQN)

Installation

using Pkg
Pkg.add("DeepQLearning")

Usage

using DeepQLearning
using POMDPs
using Flux
using POMDPModels
using POMDPSimulators
using POMDPTools

# load MDP model from POMDPModels or define your own!
mdp = SimpleGridWorld();

# Define the Q network (see Flux.jl documentation)
# the gridworld state is represented by a 2 dimensional vector.
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2))

solver = DeepQLearningSolver(qnetwork = model, max_steps = 10000,
                             exploration_policy = exploration,
                             learning_rate = 0.005, log_freq = 500,
                             recurrence = false, double_q = true, dueling = true, prioritized_replay = true)
policy = solve(solver, mdp)

sim = RolloutSimulator(max_steps=30)
r_tot = simulate(sim, mdp, policy)
println("Total discounted reward for 1 simulation: $r_tot")

Specifying exploration / evaluation policy

An exploration policy and evaluation policy can be specified in the solver parameters.

An exploration policy can be provided in the form of a function that must return an action. The function provided will be called as follows: f(policy, env, obs, global_step, rng) where policy is the NN policy being trained, env the environment, obs the observation at which to take the action, global_step the interaction step of the solver, and rng a random number generator. By default, this package provides an epsilon-greedy policy in which epsilon decreases linearly with global_step.
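
For illustration, a custom exploration function with this signature could look like the sketch below. The function name, the epsilon schedule, and the calls actions(env) and action(policy, obs) are assumptions about your environment and policy interfaces, not part of the package API.

# Minimal sketch of a custom exploration function (illustrative, not the package default)
function my_exploration(policy, env, obs, global_step, rng)
    eps = max(0.01, 1.0 - global_step / 10_000)   # linearly decaying epsilon (assumed schedule)
    if rand(rng) < eps
        return rand(rng, actions(env))            # random action (assumes actions(env) is defined)
    else
        return action(policy, obs)                # greedy action from the network being trained
    end
end

# passed to the solver via the exploration_policy option, e.g.
# solver = DeepQLearningSolver(qnetwork = model, exploration_policy = my_exploration)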

An evaluation policy can be provided in a similar manner. The function will be called as follows: f(policy, env, n_eval, max_episode_length, verbose) where policy is the NN policy being trained, env the environment, n_eval the number of evaluation episodes, max_episode_length the maximum number of steps in one episode, and verbose a boolean to enable or disable printing. The evaluation function must return three elements (a sketch is given after the list below):

  • Average total reward (Float), the average score per episode
  • Average number of steps (Float), the average number of steps taken per episode
  • Info, a dictionary mapping String to Float that can be used to log custom scalar values.
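
As mentioned above, here is a sketch of a custom evaluation function. The episode loop assumes a CommonRLInterface-style environment (reset!, observe, act!, terminated) and the POMDPs.jl action function; adapt these calls to your environment type.

using CommonRLInterface   # assumed environment interface

# Minimal sketch of a custom evaluation function (illustrative)
function my_evaluation(policy, env, n_eval, max_episode_length, verbose)
    avg_r, avg_steps = 0.0, 0.0
    for _ in 1:n_eval
        reset!(env)
        obs = observe(env)
        for t in 1:max_episode_length
            a = action(policy, obs)
            r = act!(env, a)          # act! returns the reward in CommonRLInterface
            obs = observe(env)
            avg_r += r
            avg_steps += 1
            terminated(env) && break
        end
    end
    avg_r /= n_eval
    avg_steps /= n_eval
    verbose && println("Evaluation over $n_eval episodes: avg reward = $avg_r")
    return avg_r, avg_steps, Dict{String, Float64}()   # third element: custom scalars to log
end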

Q-Network

The qnetwork option of the solver accepts any Chain object. It is expected that the network will be a multi-layer perceptron, or convolutional layers followed by dense layers. If the network ends with dense layers, the dueling option will split those trailing dense layers into the two streams of the dueling architecture.

If the observation is a multi-dimensional array (e.g. an image), you can use the flattenbatch function to flatten all of its dimensions except the batch dimension. This is useful, for example, to connect convolutional layers to dense layers.
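
As an example, a convolutional Q-network for image observations might look like the sketch below; the 84x84 grayscale input and the layer sizes are assumptions for illustration, not package defaults.

# Sketch: conv layers connected to dense layers via flattenbatch (assumes 84x84x1 observations)
cnn = Chain(Conv((8, 8), 1 => 16, relu, stride=4),
            Conv((4, 4), 16 => 32, relu, stride=2),
            flattenbatch,                       # flattens everything except the batch dimension
            Dense(2592, 256, relu),             # 9*9*32 = 2592 for an 84x84 input
            Dense(256, length(actions(mdp))))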

The input size of the network is problem dependent and must be specified when you create the q network.

This package exports the type AbstractNNPolicy, which represents neural network based policies. In addition to the functions from POMDPs.jl, AbstractNNPolicy objects support the following:

  • getnetwork(policy): returns the value network of the policy
  • resetstate!(policy): resets the hidden states of the policy (does nothing if it is not an RNN)
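
For example, using the policy returned by solve above:

qnet = getnetwork(policy)   # the underlying Flux Chain of the trained policy
resetstate!(policy)         # clears the hidden state of a recurrent policy; no-op otherwise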

Saving/Reloading model

See Flux.jl documentation for saving and loading models. The DeepQLearning solver saves the weights of the Q-network as a bson file in solver.logdir/"qnetwork.bson".
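
For instance, you could keep your own copy of the trained network with BSON, as suggested in the Flux documentation; the file name and variable name below are illustrative.

using BSON: @save, @load

qnet = getnetwork(policy)          # extract the trained Flux Chain
@save "my_qnetwork.bson" qnet      # save a copy alongside the solver's own checkpoint
@load "my_qnetwork.bson" qnet      # later: reload the network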

Logging

Logging is done through TensorBoardLogger.jl. A log directory can be specified in the solver options; to disable logging, set the logdir option to nothing.
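
For example (the directory name is illustrative):

solver = DeepQLearningSolver(qnetwork = model, max_steps = 10000,
                             exploration_policy = exploration,
                             logdir = "log/gridworld")   # or logdir = nothing to disable logging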

GPU Support

DeepQLearning.jl should support running the computations on GPUs through the CuArrays.jl package. You must check out the gpu-support branch. Note that it has not been tested thoroughly. To run the solver on a GPU, you must first load CuArrays and then proceed as usual.

using CuArrays
using DeepQLearning
using POMDPs
using Flux
using POMDPModels

mdp = SimpleGridWorld();

# the model weights will be sent to the GPU in the call to solve
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

solver = DeepQLearningSolver(qnetwork = model, max_steps = 10000,
                             learning_rate = 0.005, log_freq = 500,
                             recurrence = false, double_q = true, dueling = true, prioritized_replay = true)
policy = solve(solver, mdp)

Solver Options

Fields of the Q Learning solver:

  • qnetwork::Any = nothing Specify the architecture of the Q network
  • learning_rate::Float64 = 1e-4 learning rate
  • max_steps::Int64 total number of training steps default = 1000
  • target_update_freq::Int64 frequency at which the target network is updated default = 500
  • batch_size::Int64 batch size sampled from the replay buffer default = 32
  • train_freq::Int64 frequency at which the active network is updated default = 4
  • log_freq::Int64 frequency at which to log info default = 100
  • eval_freq::Int64 frequency at which to evaluate the network default = 100
  • num_ep_eval::Int64 number of episodes to evaluate the policy default = 100
  • eps_fraction::Float64 fraction of the training steps used for exploration default = 0.5
  • eps_end::Float64 value of epsilon at the end of the exploration phase default = 0.01
  • double_q::Bool double Q-learning update default = true
  • dueling::Bool dueling structure for the q network default = true
  • recurrence::Bool = false set to true to use DRQN; the solver will throw an error if you set it to false and pass a recurrent model (see the sketch after this list).
  • prioritized_replay::Bool enable prioritized experience replay default = true
  • prioritized_replay_alpha::Float64 default = 0.6
  • prioritized_replay_epsilon::Float64 default = 1e-6
  • prioritized_replay_beta::Float64 default = 0.4
  • buffer_size::Int64 size of the experience replay buffer default = 1000
  • max_episode_length::Int64 maximum length of a training episode default = 100
  • train_start::Int64 number of steps used to fill in the replay buffer initially default = 200
  • save_freq::Int64 save the model every save_freq steps, default = 1000
  • evaluation_policy::Function = basic_evaluation function used to evaluate the policy every eval_freq steps; the default is a rollout that returns the undiscounted average reward
  • exploration_policy::Any = linear_epsilon_greedy(max_steps, eps_fraction, eps_end) exploration strategy (default is epsilon greedy with linear decay)
  • rng::AbstractRNG random number generator default = MersenneTwister(0)
  • logdir::String = "" folder in which to save the model
  • verbose::Bool default = true
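
For instance, a recurrent (DRQN) configuration, as referenced from the recurrence option above, could look like the following sketch; the LSTM size is illustrative.

# Sketch: recurrent Q-network for a DRQN-style solver
drqn = Chain(LSTM(2, 32), Dense(32, length(actions(mdp))))

drqn_solver = DeepQLearningSolver(qnetwork = drqn,
                                  exploration_policy = exploration,
                                  max_steps = 10000,
                                  recurrence = true)    # required when the network contains recurrent layers
drqn_policy = solve(drqn_solver, mdp)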

deepqlearning.jl's People

Contributors

dylan-asmar, github-actions[bot], juliatagbot, lkruse, maximebouton, peggyyuchunwang, rejuvyesh, zsunberg

deepqlearning.jl's Issues

Switch to Flux.jl

The current package relies on TensorFlow.jl; it might be interesting to test out an implementation using Flux.jl, since it seems to be the future of deep learning in Julia.

TagBot trigger issue

This issue is used to trigger TagBot; feel free to unsubscribe.

If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.

If you'd like for me to do this for you, comment TagBot fix on this issue.
I'll open a PR within a few hours, please be patient!

Support of AbstractEnvironment

This solver uses some functions that are broader than the minimal interface defined in RLInterface and relies on internal fields such as env.problem in many places.
Ideally, the solver should support an RL environment defined using only RLInterface.jl, without necessarily having an MDP or POMDP object associated with it.

Compilation Error

The following error is thrown when attempting to use DeepQLearning.jl:
ERROR: LoadError: LoadError: UndefVarError: Tracker not defined

This appears to be a Flux issue.

Exploration Policy requires a (PO)MDP

It would be nice to be able to use this package with only CommonRLInterface and not need to know anything about POMDPs.jl. Currently, the main thing preventing this is the exploration policy.

Error: Can't differentiate loopinfo expression

Running the example given in the docs:

using DeepQLearning
using POMDPs
using Flux
using POMDPModels
using POMDPSimulators
using POMDPPolicies

# load MDP model from POMDPModels or define your own!
mdp = SimpleGridWorld();

# Define the Q network (see Flux.jl documentation)
# the gridworld state is represented by a 2 dimensional vector.
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2))

solver = DeepQLearningSolver(qnetwork = model, max_steps=10000, 
                             exploration_policy = exploration,
                             learning_rate=0.005,log_freq=500,
                             recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
policy = solve(solver, mdp)

sim = RolloutSimulator(max_steps=30)
r_tot = simulate(sim, mdp, policy)
println("Total discounted reward for 1 simulation: $r_tot")

produces the error below. Where can I look to fix this? I am using Julia 1.6.

ERROR: Can't differentiate loopinfo expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] macro expansion
    @ ./simdloop.jl:79 [inlined]
  [3] Pullback
    @ ./reduce.jl:243 [inlined]
  [4] (::typeof(∂(mapreduce_impl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [5] Pullback
    @ ./reduce.jl:257 [inlined]
  [6] (::typeof(∂(mapreduce_impl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./reduce.jl:415 [inlined]
  [8] (::typeof(∂(_mapreduce)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] Pullback
    @ ./reducedim.jl:318 [inlined]
 [10] Pullback (repeats 2 times)
    @ ./reducedim.jl:310 [inlined]
 [11] (::typeof(∂(mapreduce)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./reducedim.jl:878 [inlined]
 [13] (::typeof(∂(#_sum#682)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./reducedim.jl:878 [inlined]
 [15] (::typeof(∂(_sum)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [16] Pullback (repeats 2 times)
    @ ./reducedim.jl:874 [inlined]
 [17] (::typeof(∂(sum)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/DeepQLearning/jJkAu/src/solver.jl:223 [inlined]
 [19] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:252
 [21] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [22] batch_train!(solver::DeepQLearningSolver{EpsGreedyPolicy{LinearDecaySchedule{Float64}, Random._GLOBAL_RNG, NTuple{4, Symbol}}}, env::POMDPModelTools.MDPCommonRLEnv{AbstractArray{Float32, N} where N, SimpleGridWorld, StaticArrays.SVector{2, Int64}}, policy::NNPolicy{SimpleGridWorld, DeepQLearning.DuelingNetwork, Symbol}, optimizer::ADAM, target_q::DeepQLearning.DuelingNetwork, replay::PrioritizedReplayBuffer{Int32, Float32, CartesianIndex{2}, StaticArrays.SVector{2, Float32}, Matrix{Float32}}; discount::Float64)
    @ DeepQLearning ~/.julia/packages/DeepQLearning/jJkAu/src/solver.jl:219
 [23] batch_train!
    @ ~/.julia/packages/DeepQLearning/jJkAu/src/solver.jl:200 [inlined]
 [24] dqn_train!(solver::DeepQLearningSolver{EpsGreedyPolicy{LinearDecaySchedule{Float64}, Random._GLOBAL_RNG, NTuple{4, Symbol}}}, env::POMDPModelTools.MDPCommonRLEnv{AbstractArray{Float32, N} where N, SimpleGridWorld, StaticArrays.SVector{2, Int64}}, policy::NNPolicy{SimpleGridWorld, DeepQLearning.DuelingNetwork, Symbol}, replay::PrioritizedReplayBuffer{Int32, Float32, CartesianIndex{2}, StaticArrays.SVector{2, Float32}, Matrix{Float32}})
    @ DeepQLearning ~/.julia/packages/DeepQLearning/jJkAu/src/solver.jl:138
 [25] solve(solver::DeepQLearningSolver{EpsGreedyPolicy{LinearDecaySchedule{Float64}, Random._GLOBAL_RNG, NTuple{4, Symbol}}}, env::POMDPModelTools.MDPCommonRLEnv{AbstractArray{Float32, N} where N, SimpleGridWorld, StaticArrays.SVector{2, Int64}})
    @ DeepQLearning ~/.julia/packages/DeepQLearning/jJkAu/src/solver.jl:56
 [26] solve(solver::DeepQLearningSolver{EpsGreedyPolicy{LinearDecaySchedule{Float64}, Random._GLOBAL_RNG, NTuple{4, Symbol}}}, problem::SimpleGridWorld)
    @ DeepQLearning ~/.julia/packages/DeepQLearning/jJkAu/src/solver.jl:32
 [27] top-level scope
    @ REPL[11]:1

Avoid using env.state

The solver makes use of env.state to resolve conflicts between changes to the mutable object env during both exploration and evaluation.

See the comments in commit cf60925.

DQExperience should support AbstractArrays

From Piazza:

I'm troubleshooting in the deep Q learning package and I'm having a problem with the DQExperience object. The object is defined in 'prioritized_experience_replay.jl' and is as follows:

struct DQExperience{N <: Real, T <: Real, Q}
    s::Array{T, Q}
    a::N
    r::T
    sp::Array{T, Q}
    done::Bool
end

The problem is that our states are defined using the StaticArrays type.

Is there a reason that state is constrained to an Array of real numbers? Is there a way to create a subclass of DQExperience, or something that allows for the StaticArrays type?

RNN Performance

The solver seems too slow when using an RNN. This statement needs to be supported by benchmarks, of course.

Potential performance issues:

  • the replay buffer
  • computing the loss function

Fix avgR in terminal while training

Logging of avgR in terminal may be off, see below.

51000 / 1000000 eps 0.899 |  avgR -0.997 | Loss 9.240e-03 | Grad 5.585e-03 
51500 / 1000000 eps 0.898 |  avgR -0.997 | Loss 2.298e-02 | Grad 1.119e-02 
52000 / 1000000 eps 0.897 |  avgR -0.997 | Loss 8.478e-02 | Grad 7.080e-02 
52500 / 1000000 eps 0.896 |  avgR -0.997 | Loss 1.660e-02 | Grad 6.243e-03 
53000 / 1000000 eps 0.895 |  avgR -0.997 | Loss 1.105e-02 | Grad 4.012e-03 
53500 / 1000000 eps 0.894 |  avgR -0.997 | Loss 1.182e-02 | Grad 7.963e-03 

Use only RLInterface.jl interface

One of my students discovered that this package uses POMDPs.actionindex. Would it be possible to make the package only use functions from the RLInterface.jl interface? (i.e. we would need to construct our own action map)

Problem with reading log files

Hi,

I'm attempting to read in the log files generated by TensorBoardLogger, but am having some issues. When I try the method for de-serialization recommended in the TensorBoardLogger docs I get an error regarding crc headers, so I'm wondering if there's a specific method that works for reading the logs generated from this package. I've included the error message below.

Alternatively, if there's a way to plot learning curves without reading in the log files that would also be helpful.

Thanks

ERROR: AssertionError: crc_header == crc_header_ck
Stacktrace:
[1] read_event(::IOStream) at /home/ben/.julia/packages/TensorBoardLogger/gv4oF/src/Deserialization/deserialization.jl:16
[2] iterate(::TensorBoardLogger.TBEventFileIterator, ::Int64) at /home/ben/.julia/packages/TensorBoardLogger/gv4oF/src/Deserialization/deserialization.jl:84
[3] iterate at /home/ben/.julia/packages/TensorBoardLogger/gv4oF/src/Deserialization/deserialization.jl:83 [inlined]
[4] iterate(::TensorBoardLogger.TBEventFileCollectionIterator, ::Int64) at /home/ben/.julia/packages/TensorBoardLogger/gv4oF/src/Deserialization/deserialization.jl:59
[5] iterate at /home/ben/.julia/packages/TensorBoardLogger/gv4oF/src/Deserialization/deserialization.jl:52 [inlined]
[6] #map_summaries#158(::Bool, ::Nothing, ::Nothing, ::Bool, ::typeof(map_summaries), ::var"#6#7", ::String) at /home/ben/.julia/packages/TensorBoardLogger/gv4oF/src/Deserialization/deserialization.jl:211
[7] map_summaries(::Function, ::String) at /home/ben/.julia/packages/TensorBoardLogger/gv4oF/src/Deserialization/deserialization.jl:205
[8] top-level scope at REPL[36]:1

Dimension Mismatch

I am having trouble debugging my use of the DeepQLearning package and am looking for some help.

My problem is the mountain car problem where the state represents the position and velocity of the car and the action is the force you can apply in order to climb the mountain. The car starts in a valley and needs to climb out to get the reward. The force is not enough to climb to the top so you need to build up momentum to get up the hill.

The state is a 2-element StaticArrays.SArray{Tuple{2},Float64,1,2} with indices SOneTo(2).
The action space is RealInterval{Float64}(-1.0, 1.0), but I discretized this.

My network is as follows:
#Define the Q Network (input {state,action}, return Q(s,a))
activation = leakyrelu;
inputlayer = Dense(3,50,activation); # Input is the size of the state-action pair
hiddenlayer1 = Dense(50,50,activation);
outputlayer = Dense(50,1,activation);
model = Chain(inputlayer,hiddenlayer1,outputlayer)

The environment is an MDPEnvironment and my solver is a DeepQLearningSolver but running the following results in a dimension mismatch:
policy = solve(solver,env)
DimensionMismatch("A has dimensions (50,3) but B has dimensions (2,32)")

I followed the stacktrace and it's happening in the Flux library here:
function (a::Dense)(x::AbstractArray)
    W, b, σ = a.W, a.b, a.σ
    σ.(W*x .+ b)
end

But I have no idea why any of these matrices would be size (2,32)? I assume it's this AbstractArray x, since the weights would be size (50,3)... but shouldn't x just be the input size (3,1)? Is it trying to do some batch processing or something? But that doesn't explain how we get to (2,32). The only place I can imagine those numbers is that the weight matrix itself would be an Array{Float32,2}, which makes no sense but does match up if it's somehow getting transposed... I'm not sure if this is a bug or if I am implementing this incorrectly. Any thoughts would be greatly appreciated. Thanks!

The entire Stacktrace is below for reference:
DimensionMismatch("A has dimensions (50,3) but B has dimensions (2,32)")
Stacktrace:
[1] gemm_wrapper!(::Array{Float32,2}, ::Char, ::Char, ::Array{Float32,2}, ::Array{Float32,2}, ::LinearAlgebra.MulAddMul{true,true,Float32,Float32}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:545
[2] mul! at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:160 [inlined]
[3] mul! at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:203 [inlined]
[4] * at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:153 [inlined]
[5] (::Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}})(::Array{Float32,2}) at /Users/liamsmith/.julia/packages/Flux/NpkMm/src/layers/basic.jl:115
[6] applychain at /Users/liamsmith/.julia/packages/Flux/NpkMm/src/layers/basic.jl:126 [inlined]
[7] Chain at /Users/liamsmith/.julia/packages/Flux/NpkMm/src/layers/basic.jl:32 [inlined]
[8] batch_train!(::DeepQLearningSolver, ::MDPEnvironment{Array{Float32,1},QuickPOMDPs.QuickMDP{UUID("c4d31997-7cb6-478c-8b46-c104fdaf65ad"),StaticArrays.SArray{Tuple{2},Float64,1,2},Float64,NamedTuple{(:isterminal, :render, :initialstate, :gen, :actions, :discount),Tuple{DMUStudent.HW4.var"#3#10",DMUStudent.HW4.var"#4#11",DMUStudent.HW4.var"#2#9",DMUStudent.HW4.var"#1#8",DMUStudent.HW4.RealInterval{Float64},Float64}}},StaticArrays.SArray{Tuple{2},Float64,1,2},Random.MersenneTwister,false}, ::NNPolicy{QuickPOMDPs.QuickMDP{UUID("c4d31997-7cb6-478c-8b46-c104fdaf65ad"),StaticArrays.SArray{Tuple{2},Float64,1,2},Float64,NamedTuple{(:isterminal, :render, :initialstate, :gen, :actions, :discount),Tuple{DMUStudent.HW4.var"#3#10",DMUStudent.HW4.var"#4#11",DMUStudent.HW4.var"#2#9",DMUStudent.HW4.var"#1#8",DMUStudent.HW4.RealInterval{Float64},Float64}}},Chain{Tuple{Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}},Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}},Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}}}},Float64}, ::ADAM, ::Chain{Tuple{Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}},Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}},Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}}}}, ::PrioritizedReplayBuffer{Int32,Float32,CartesianIndex{2},1}) at /Users/liamsmith/.julia/packages/DeepQLearning/wF0rJ/src/solver.jl:208
[9] dqn_train!(::DeepQLearningSolver, ::MDPEnvironment{Array{Float32,1},QuickPOMDPs.QuickMDP{UUID("c4d31997-7cb6-478c-8b46-c104fdaf65ad"),StaticArrays.SArray{Tuple{2},Float64,1,2},Float64,NamedTuple{(:isterminal, :render, :initialstate, :gen, :actions, :discount),Tuple{DMUStudent.HW4.var"#3#10",DMUStudent.HW4.var"#4#11",DMUStudent.HW4.var"#2#9",DMUStudent.HW4.var"#1#8",DMUStudent.HW4.RealInterval{Float64},Float64}}},StaticArrays.SArray{Tuple{2},Float64,1,2},Random.MersenneTwister,false}, ::NNPolicy{QuickPOMDPs.QuickMDP{UUID("c4d31997-7cb6-478c-8b46-c104fdaf65ad"),StaticArrays.SArray{Tuple{2},Float64,1,2},Float64,NamedTuple{(:isterminal, :render, :initialstate, :gen, :actions, :discount),Tuple{DMUStudent.HW4.var"#3#10",DMUStudent.HW4.var"#4#11",DMUStudent.HW4.var"#2#9",DMUStudent.HW4.var"#1#8",DMUStudent.HW4.RealInterval{Float64},Float64}}},Chain{Tuple{Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}},Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}},Dense{typeof(leakyrelu),Array{Float32,2},Array{Float32,1}}}},Float64}, ::PrioritizedReplayBuffer{Int32,Float32,CartesianIndex{2},1}) at /Users/liamsmith/.julia/packages/DeepQLearning/wF0rJ/src/solver.jl:136
[10] solve(::DeepQLearningSolver, ::MDPEnvironment{Array{Float32,1},QuickPOMDPs.QuickMDP{UUID("c4d31997-7cb6-478c-8b46-c104fdaf65ad"),StaticArrays.SArray{Tuple{2},Float64,1,2},Float64,NamedTuple{(:isterminal, :render, :initialstate, :gen, :actions, :discount),Tuple{DMUStudent.HW4.var"#3#10",DMUStudent.HW4.var"#4#11",DMUStudent.HW4.var"#2#9",DMUStudent.HW4.var"#1#8",DMUStudent.HW4.RealInterval{Float64},Float64}}},StaticArrays.SArray{Tuple{2},Float64,1,2},Random.MersenneTwister,false}) at /Users/liamsmith/.julia/packages/DeepQLearning/wF0rJ/src/solver.jl:58
[11] top-level scope at In[26]:1

GPU support

Can a maintainer comment on the state of GPU support?
The README says that the gpu-support branch should be used, but that branch was last updated 5 years ago.
Given all the changes that have happened in the meantime, I guess it might be easier to make a new branch and support GPUs from scratch. Also, why a separate branch? Couldn't it be an option? It would be great if someone could provide some insight.
