Comments (11)

rdeits commented on August 16, 2024

Oh, I see. The generator needs access to the tape itself, not just the type information. A generated function won't work there.
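
To make the limitation concrete, here's a minimal sketch (with a deliberately simplified, hypothetical tape type, not ReverseDiff's actual one): a @generated function's body sees only the types of its arguments, so the recorded instructions, which are runtime data stored in the tape instance, are invisible to the generator.

struct SimplifiedTape{T}
    instructions::Vector{T}
end

@generated function forward_pass!(tape::SimplifiedTape)
    # Inside a @generated body, `tape` is bound to the *type*
    # SimplifiedTape{T}, not an instance, so there is no way to read
    # tape.instructions and unroll them into the generated code.
    return :(error("instructions are runtime data; the generator cannot see them"))
end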

rdeits commented on August 16, 2024

I guess this is happening because compile(::CompiledTape) creates the forward and backward pass methods by eval()-ing into the current module. I can work around it by calling @eval ReverseDiff.gradient!(result, tape, x), but that causes performance issues.

It seems like fixing this would require changing the way the forward and reverse pass methods are created. Could they perhaps be implemented with an @generated function instead? The logic inside compile() looks suspiciously similar to the way @generated functions behave.
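
For reference, here's a hedged sketch of the eval()-based generation being discussed (the function name comes from the comments in this thread, but the body is a guess at the general shape, not ReverseDiff's actual code): the tape is unrolled into one statement per instruction, with each instruction's concrete type baked in, and the resulting method definition is eval()-ed into the module.

# Hypothetical sketch: build a forward_pass! method that unrolls the tape,
# asserting each instruction's concrete type so dispatch is resolved at
# compile time, then eval() the definition into the current module.
function generate_forward_pass_method(T, tape)
    body = Expr(:block)
    for (i, instruction) in enumerate(tape)
        push!(body.args, :(forward_exec!(ct.tape[$i]::$(typeof(instruction)))))
    end
    return :(forward_pass!(ct::$T) = ($body; nothing))
end

eval(generate_forward_pass_method(typeof(compiled_tape), compiled_tape.tape))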

rdeits commented on August 16, 2024

A less-bad workaround is to define:

forward_pass!(compiled_tape::CompiledTape) = forward_pass!(compiled_tape.tape.tape)
reverse_pass!(compiled_tape::CompiledTape) = reverse_pass!(compiled_tape.tape.tape)

and then not bother calling generate_forward_pass_method and generate_reverse_pass_method. The disadvantage (and presumably the reason you didn't do that originally) is that each Instruction is a different type, so there will be some run-time dispatch.
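
To illustrate that run-time dispatch (with made-up instruction types, not ReverseDiff's actual ones): because the tape's element type is abstract, every forward_exec! call is resolved dynamically, once per instruction, on every pass.

abstract type AbstractInstruction end
struct AddInstruction <: AbstractInstruction end
struct MulInstruction <: AbstractInstruction end

forward_exec!(::AddInstruction) = nothing
forward_exec!(::MulInstruction) = nothing

# eltype(demo_tape) is the abstract AbstractInstruction, so each call in
# the loop below goes through dynamic dispatch.
const demo_tape = AbstractInstruction[AddInstruction(), MulInstruction()]

function forward_pass!(tape::Vector{AbstractInstruction})
    for instruction in tape
        forward_exec!(instruction)
    end
    return nothing
end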

Here's a crazy idea: if the tape just exists to call forward_exec! and reverse_exec!, could you store (in addition to the tape) a vector of FunctionWrappers around closures? Something like

FunctionWrapper{Void, Tuple{}}(() -> forward_exec!(instruction))

Then forward_pass! could just iterate through the wrappers (which are all of the same type) and call their wrapped functions (which would have the correct dispatch baked in already, I think).

Does that make any sense?

rdeits commented on August 16, 2024

Here's a more concrete version of what I'm suggesting:

julia> using FunctionWrappers: FunctionWrapper

julia> using BenchmarkTools

julia> # Basic function which will return Float64 for float or int input
       f(x) = x + 1.0
f (generic function with 1 method)

julia> # Non-concrete data vector
       y = Number[1, 2.0]
2-element Array{Number,1}:
 1
 2.0

julia> # Wrap the closures in a concretely-typed wrapper
       wrappers = [FunctionWrapper{Float64, Tuple{}}(() -> f(x)) for x in y];

julia> [w() for w in wrappers]
2-element Array{Float64,1}:
 2.0
 3.0

Calling the wrapped functions doesn't allocate:

julia> function forward_pass(wrappers)
           for w in wrappers
               w()
           end
           nothing
       end
forward_pass (generic function with 1 method)

julia> @benchmark forward_pass($wrappers)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     24.852 ns (0.00% GC)
  median time:      25.743 ns (0.00% GC)
  mean time:        28.313 ns (0.00% GC)
  maximum time:     128.195 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     996

whereas calling f on each element of y would allocate:

julia> function forward_pass(y)
           for x in y
               f(x)
           end
           nothing
       end
forward_pass (generic function with 1 method)

julia> @benchmark forward_pass($y)
BenchmarkTools.Trial:
  memory estimate:  32 bytes
  allocs estimate:  2
  --------------
  minimum time:     74.362 ns (0.00% GC)
  median time:      78.193 ns (0.00% GC)
  mean time:        84.891 ns (1.10% GC)
  maximum time:     973.529 ns (85.22% GC)
  --------------
  samples:          10000
  evals/sample:     973

rdeits commented on August 16, 2024

And here's a proof-of-concept implementation: https://github.com/JuliaDiff/ReverseDiff.jl/compare/master...rdeits:no-eval?expand=1 (requires FunctionWrappers.jl master). Tests seem to be passing, but I haven't looked at performance yet.
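
The gist of the branch, as a hedged sketch (field and type names here are hypothetical; see the diff above for the real thing): the compiled tape stores one concretely-typed FunctionWrapper per instruction, so both passes iterate over homogeneous vectors.

using FunctionWrappers: FunctionWrapper

# Void is the Julia 0.5/0.6-era name for today's Nothing.
struct WrappedTape{T}
    tape::T
    forward::Vector{FunctionWrapper{Void, Tuple{}}}  # in execution order
    reverse::Vector{FunctionWrapper{Void, Tuple{}}}  # in reverse execution order
end

function forward_pass!(t::WrappedTape)
    for w in t.forward
        w()  # concrete wrapper type: no dynamic dispatch in this loop
    end
    return nothing
end

function reverse_pass!(t::WrappedTape)
    for w in t.reverse
        w()
    end
    return nothing
end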

jrevels commented on August 16, 2024

Yup, this is a documented limitation of ReverseDiff.compile. I'll check out your proposed solution today and report back. Thanks for playing around with this!

jrevels commented on August 16, 2024

A less-bad workaround is to define:

forward_pass!(compiled_tape::CompiledTape) = forward_pass!(compiled_tape.tape.tape)
reverse_pass!(compiled_tape::CompiledTape) = reverse_pass!(compiled_tape.tape.tape)

and then not bother calling generate_forward_pass_method and generate_reverse_pass_method. The disadvantage (and presumably the reason you didn't do that originally) is that each Instruction is a different type, so there will be some run-time dispatch.

Note that this is exactly what the un-compiled pre-recorded tape API does. You're correct that avoiding runtime dispatch is the main advantage of using compile. That should probably be in the documentation to make things clearer...

jrevels commented on August 16, 2024

This is my first time messing about with the FunctionWrappers package; this approach looks really cool! Another benefit would be avoiding the crazy compilation times associated with the functions generated from large tapes. If this does end up working out, we wouldn't even need the compile function anymore - AFAICT, there'd be no reason not to wrap the instructions by default.

Unfortunately, it seems that there are still some performance issues to work through. Here's the performance of a toy benchmark on master:

julia> using ReverseDiff, BenchmarkTools

julia> # good benchmark for tape traversal performance
       function rosenbrock(x)
           a = one(eltype(x))
           b = 100 * a
           result = zero(eltype(x))
           for i in 1:length(x)-1
               result += (a - x[i])^2 + b*(x[i+1] - x[i]^2)^2
           end
           return result
       end
rosenbrock (generic function with 1 method)

julia> const tape = ReverseDiff.GradientTape(rosenbrock, rand(100))
ReverseDiff.GradientTape(rosenbrock)

julia> const ctape = ReverseDiff.compile(tape)
ReverseDiff.CompiledTape{##660}(rosenbrock)

julia> y, x = zeros(100), rand(100);

julia> @benchmark ReverseDiff.gradient!($y, $tape, $x)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     122.630 μs (0.00% GC)
  median time:      124.411 μs (0.00% GC)
  mean time:        124.321 μs (0.00% GC)
  maximum time:     171.988 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark ReverseDiff.gradient!($y, $ctape, $x)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     67.663 μs (0.00% GC)
  median time:      68.532 μs (0.00% GC)
  mean time:        68.719 μs (0.00% GC)
  maximum time:     161.848 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

And here's the performance of the compiled tape on the no-eval branch:

julia> @benchmark ReverseDiff.gradient!($y, $ctape, $x)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     199.394 μs (0.00% GC)
  median time:      200.976 μs (0.00% GC)
  mean time:        201.844 μs (0.00% GC)
  maximum time:     482.731 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

I haven't dug any deeper than this, but if I were to tinker with it, I'd use a callable struct with an instruction field instead of a closure, then define its call method with @inline. That should at least rule out performance problems due to weird code-lowering issues (though I'm not sure that's what's happening here).
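
A hedged sketch of that suggestion (the struct and its field are hypothetical names, assuming ReverseDiff's forward_exec! is in scope):

# A concrete callable struct replaces the anonymous closure; @inline on
# the call method keeps forward_exec! from hiding behind extra layers of
# lowering.
struct ForwardExecutor{I}
    instruction::I
end

@inline (e::ForwardExecutor)() = (forward_exec!(e.instruction); nothing)

# Wrapped exactly like the closures were (Void is the 0.6-era Nothing):
# FunctionWrapper{Void, Tuple{}}(ForwardExecutor(instruction))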

rdeits commented on August 16, 2024

Yes, the callable struct is a great idea! With that change, I'm seeing performance only slightly worse than master for compiled tapes (71 μs with master on v0.5, 80 μs with no-eval on v0.6).

I've pushed the updates, so you should be able to try it yourself.

jrevels commented on August 16, 2024

With that change, I'm seeing performance only slightly worse than master for compiled tapes (71 μs with master on v0.5, 80 μs with no-eval on v0.6).

That's a totally acceptable regression given the benefits of this approach - feel free to open a PR for no-eval! Once that PR is merged, I'll explore removing compile entirely and using the FunctionWrapper approach by default.

rdeits commented on August 16, 2024

Ok, done: #71
