jacobusmmsmit commented on July 18, 2024

After a call with Mohamed, I think the implementation we decided to try out will address scan quite nicely. I'll write up what we discussed here so it's public. For simplicity, I'm assuming that xs is a shaped array, so its shape is included in its type.

scan is a function that has a core computation, f, and some machinery around it to "catch" its outputs. In this case, scan's output shapes and concrete types are determined once f, xs, and init are known. We want to avoid recompiling the core computation f every time xs changes shape.

[The API for this part is a work in progress. So is this whole thing, but this part especially.]
To address this, we wrap f in a callable struct CachedReverseDiffBackend that contains f and a compiled tape:

using ReverseDiff
using ReverseDiff: GradientTape, compile

struct CachedReverseDiffBackend{F, T} # Could also be parametric in backend type
    func::F
    compiled_tape::T
    # Constructor: record a gradient tape for f at the example input x and compile it
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(GradientTape(f, x))
        T = typeof(compiled_tape)
        return new{F, T}(f, compiled_tape)
    end
end

compiled_f = CachedReverseDiffBackend(f, x) # where typeof(x) == eltype(xs)

Then we make the cached backend callable (with the caveat that @grad only accepts functions, so we define call_func too):

const CRDB = CachedReverseDiffBackend # alias for brevity

(b::CRDB)(y) = call_func(b, y)
call_func(b::CRDB, y) = b.func(y)

and define a custom rule for our CRDB structs:

function call_func(b::CRDB, y::ReverseDiff.TrackedArray)
    return ReverseDiff.track(call_func, b, y)
end

import AbstractDifferentiation as AD
ReverseDiff.@grad function call_func(b::CRDB, y)
    return AD.value_and_pullback_function(b, y) # to be implemented
end

Now we can pass compiled_f to scan instead: scan(compiled_f, xs, init). When we differentiate through it with ReverseDiff.gradient, ReverseDiff reaches compiled_f inside the loop and sees that there is a custom rule for it. That rule (built on AD) computes the gradient of compiled_f using the compiled tape that was created when compiled_f was constructed.

So in the end we have an outer uncompiled tape which contains calls to inner compiled tapes.

from reversediff.jl.

mohamed82008 commented on July 18, 2024

First, sorry for the really late response.

"In my testing, everything defined outside of the grad_branching function would be frozen, but I couldn't find any documentation on this in ReverseDiff."

Correct. Documentation PRs are welcome :)

The line yv = cb.func(xv) in the code below will break higher-order AD. I would use yv = cb(xv) instead, letting dispatch do its thing.

function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)
    
    function pullback_f(Δ)
        (Δ*ReverseDiff.gradient!(cb.compiled_tape, xv), ) # no space to cache output :/
    end
    return yv, pullback_f
end
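Applying both suggestions from this reply, a self-contained sketch of the cached backend might call the callable struct (c(xv)) instead of c.func(xv), and return one gradient slot per argument of the tracked function, with nothing for the struct itself. To keep it standalone it drops the AbstractDifferentiation dependency; CachedTape, call_cached, sum_sq and outer_cached are hypothetical names, and the assumption that ReverseDiff ignores the nothing slot for the untracked first argument is not confirmed in this thread.

```julia
using ReverseDiff
using ReverseDiff: GradientTape, compile, gradient!, @grad, track, value, TrackedArray

struct CachedTape{F,T}
    func::F
    compiled_tape::T
    # Record and compile a gradient tape for f at the example input x.
    function CachedTape(f::F, x) where {F}
        tape = compile(GradientTape(f, x))
        return new{F,typeof(tape)}(f, tape)
    end
end

(c::CachedTape)(x) = call_cached(c, x)
call_cached(c::CachedTape, x) = c.func(x)
# Tracked inputs go through the custom rule below.
call_cached(c::CachedTape, x::TrackedArray) = track(call_cached, c, x)

@grad function call_cached(c::CachedTape, x)
    xv = value(x)
    yv = c(xv)  # let dispatch pick the plain method for untracked xv
    function cached_pullback(Δ)
        # One entry per argument of call_cached: no gradient for c itself.
        return (nothing, Δ .* gradient!(c.compiled_tape, xv))
    end
    return yv, cached_pullback
end

# Usage: a scalar-valued inner function wrapped in a cached tape.
sum_sq(x) = sum(abs2, x)
const cg = CachedTape(sum_sq, [1.0, 2.0, 3.0])
outer_cached(x) = 2 * cg(x)
```

This still only handles scalar-valued inner functions, which is the limitation raised in the next paragraph of the reply.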

"Should this backend be a real backend, i.e. should it define a @primitive?"

Your implementation right now seems to work only for scalar-valued functions. You might want to generalise it, and then, yes, making it a primitive will give you all the other methods for free. Check the ReverseDiff backend implementation in AbstractDifferentiation for reference.
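One way to generalise the pullback to vector-valued functions is to return the vector-Jacobian product J' * Δ instead of Δ times a gradient. A minimal sketch, assuming a precompiled ReverseDiff JacobianTape is used to evaluate J; value_and_vjp and h are hypothetical names. It materialises the full Jacobian, so it is simple rather than efficient.

```julia
using ReverseDiff
using ReverseDiff: JacobianTape, compile, jacobian!

# Vector-valued example function (hypothetical).
h(x) = x .* x

# Pullback for a vector-valued f: the vector-Jacobian product J' * Δ,
# with J evaluated by replaying a precompiled Jacobian tape at x.
function value_and_vjp(f, compiled_jtape, x)
    y = f(x)
    vjp(Δ) = jacobian!(compiled_jtape, x)' * Δ
    return y, vjp
end

x0 = [1.0, 2.0, 3.0]
jtape = compile(JacobianTape(h, x0))
y, vjp = value_and_vjp(h, jtape, x0)
```

For a scalar-valued f the Jacobian is a single row, so this reduces to the Δ * gradient form used in the draft.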

"As talked about in JuliaDiff/AbstractDifferentiation.jl#41, caching interfaces should be addressed as this is, I think, where the performance difference comes from."

Try profiling to see where the performance difference comes from. Also try a function with more inputs, which might be more representative of how people use ReverseDiff; most people would not use ReverseDiff for a function of 3 variables. If allocations are the bottleneck in your function, then we need to consider reducing them, but let's first check 1) that this is the case, using profiling, and 2) that it's a real problem you will run into when using the package on real-sized problems.

mohamed82008 commented on July 18, 2024

Use the ReverseDiff.@grad macro to define an rrule for any function that has a branch. The rrule can use AbstractDifferentiation to call ReverseDiff again. This will essentially maintain a sub-tape for this particular function with dynamic control flow and will make it work even when the remaining functions' tape is compiled and cached. IIUC, this is roughly equivalent to what you are trying to do with very little engineering work.
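A minimal sketch of this suggestion, assuming the current ReverseDiff track/@grad/value API. The branching function is split into a plain implementation and a public entry point that carries the rule; the pullback calls ReverseDiff again on the implementation, so a fresh uncompiled sub-tape is traced on every reverse pass while the outer tape stays compiled. The names branchy_impl, branchy and outer_branchy are hypothetical, and this calls ReverseDiff directly rather than going through AbstractDifferentiation.

```julia
using ReverseDiff
using ReverseDiff: TrackedArray, track, @grad, value, GradientTape, compile, gradient!

# Plain implementation with a data-dependent branch (hypothetical example).
branchy_impl(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2

# Public entry point: untracked inputs call the implementation directly ...
branchy(x) = branchy_impl(x)
# ... while tracked inputs are routed through the custom rule, so the outer
# tape records one opaque instruction instead of the branch itself.
branchy(x::TrackedArray) = track(branchy, x)

@grad function branchy(x)
    xv = value(x)
    function branchy_pullback(Δ)
        # Trace branchy_impl afresh (uncompiled) on every reverse pass, so the
        # branch is chosen from the current input. Calling ReverseDiff on
        # branchy itself here would re-enter this rule and recurse forever.
        return (Δ .* ReverseDiff.gradient(branchy_impl, xv),)
    end
    return branchy_impl(xv), branchy_pullback
end

outer_branchy(x) = 2 * branchy(x)

# Compile the outer tape while recording the "high" branch.
outer_tape = compile(GradientTape(outer_branchy, [0.0, 1.1, 1.0]))
```

The split into branchy_impl and branchy is the design choice that matters: the rule must differentiate a function that does not itself carry the rule.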

jacobusmmsmit commented on July 18, 2024

Thanks for the reply. Forgive me for not understanding fully; could you expand a little on how @grad could be used in combination with AD to make a "sub-tape"? In this case, is the sub-tape also compiled and cached as in my toy implementation?

mohamed82008 commented on July 18, 2024

It would not be compiled by default but you can choose to compile 2 different tapes, one for each branch. I think you might also be able to do that lazily.
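A minimal sketch of the lazy variant, assuming the branch can be factored into a separate predicate and two branch-free functions (low_branch, high_branch, takes_high_branch, BRANCH_TAPES and branch_gradient are all hypothetical names). Each branch gets at most one compiled tape, built the first time an input takes that branch. It is shown as a standalone gradient function; calling it from inside an @grad pullback is untested here.

```julia
using ReverseDiff
using ReverseDiff: GradientTape, compile, gradient!

# Hypothetical split of a branching function into branch-free pieces.
low_branch(x) = sum(x)^2
high_branch(x) = sum(x)^3
takes_high_branch(x) = sum(x) > 1

# Lazily filled cache: at most one compiled tape per branch.
# Assumes every input has the same length as the one used to record.
const BRANCH_TAPES = Dict{Bool,Any}()

function branch_gradient(x)
    high = takes_high_branch(x)
    tape = get!(BRANCH_TAPES, high) do
        compile(GradientTape(high ? high_branch : low_branch, x))
    end
    return gradient!(tape, x)
end
```

Because the predicate is evaluated on plain values before a tape is chosen, neither compiled tape ever contains the branch.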

jacobusmmsmit commented on July 18, 2024

Could you give me a starting point that I could expand on? I'm not too familiar with AbstractDifferentiation but I'd love to build a usable MVP of this idea.

mohamed82008 commented on July 18, 2024

It's not easy. If you are already familiar with ReverseDiff, try reading https://github.com/JuliaDiff/AbstractDifferentiation.jl/blob/master/ext/AbstractDifferentiationReverseDiffExt.jl to understand the AD API. Then you will need to address JuliaDiff/AbstractDifferentiation.jl#41. After that, it should be easy to write an MWE. If you are interested in spending time on this, we can schedule a call to go through the work required to get it done.

jacobusmmsmit commented on July 18, 2024

I had a quick read of the above links as well as the AbstractDifferentiation PR about ReverseDiff. I see that it's a relatively difficult problem to solve at such a high level (for all backends) because of type stability. I'd be interested in working on it.

ToucheSir commented on July 18, 2024

Saw the GSoC idea this proposal is referring to, very interesting stuff. One question from me: would this help with being able to represent dynamically-bounded loops on the tape without requiring recompilation? I can think of a few cases related to sequence/time series modelling where it would be nice to not eat tracing + tape compilation latency every time the input length changes. Some mechanism for caching sub-tapes seems like a necessary prerequisite for that, but I'm not sure if it falls under the scope of this proposal.

jacobusmmsmit commented on July 18, 2024

Based on my (limited) understanding of the problem, I think the answer is no. That said, Mohamed may have a better idea of how to deal with it. Maybe Julia can do more than JAX in this regard?

mohamed82008 commented on July 18, 2024

"One question from me: would this help with being able to represent dynamically-bounded loops on the tape without requiring recompilation? I can think of a few cases related to sequence/time series modelling where it would be nice to not eat tracing + tape compilation latency every time the input length changes."

If you have a specific example, we can think about it.

ToucheSir commented on July 18, 2024

The ultimate use case I have in mind is an RNN, but here is a simpler dependency-free example:

using ReverseDiff

function f(xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        s += i * x
    end
    return s
end

julia> tape = ReverseDiff.GradientTape(f, ones(5))
typename(ReverseDiff.GradientTape)(f)

julia> ReverseDiff.gradient!(tape, ones(5))
5-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0
 5.0

julia> ReverseDiff.gradient!(tape, ones(3))
5-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0
 5.0

julia> ReverseDiff.gradient!(tape, ones(10))
ERROR: BoundsError: attempt to access 5-element Vector{Float64} at index [1:10]
Stacktrace:
  [1] throw_boundserror(A::Vector{Float64}, I::Tuple{UnitRange{Int64}})
    @ Base ./abstractarray.jl:744
  [2] checkbounds
    @ ./abstractarray.jl:709 [inlined]
  [3] _copyto_impl!(dest::Vector{Float64}, doffs::Int64, src::Vector{Float64}, soffs::Int64, n::Int64)
    @ Base ./array.jl:325
  [4] copyto!
    @ ./array.jl:319 [inlined]
  [5] copyto!
    @ ./array.jl:342 [inlined]
  [6] value!
    @ ~/.julia/packages/ReverseDiff/wIfrd/src/tracked.jl:156 [inlined]
  [7] seeded_forward_pass!
    @ ~/.julia/packages/ReverseDiff/wIfrd/src/api/tape.jl:41 [inlined]
  [8] gradient!
    @ ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:79 [inlined]
  [9] gradient!(tape::ReverseDiff.GradientTape{typeof(f), ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, input::Vector{Float64})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:63
 [10] top-level scope
    @ REPL[18]:1

It would be nice to have a way to specify "don't unroll this loop" when tracing so that the same tape could be re-used for different input lengths.
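Until something like that exists, one workaround is to keep one compiled tape per input length and build it the first time a new length is seen. This is a sketch, not the requested feature: the loop is still unrolled, and every new length still pays tracing plus compilation once. weighted_sum is the f above renamed; TAPES_BY_LENGTH and gradient_any_length are hypothetical names.

```julia
using ReverseDiff
using ReverseDiff: GradientTape, compile, gradient!

function weighted_sum(xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        s += i * x
    end
    return s
end

# Hypothetical cache: one compiled (unrolled) tape per input length.
const TAPES_BY_LENGTH = Dict{Int,Any}()

function gradient_any_length(xs)
    tape = get!(TAPES_BY_LENGTH, length(xs)) do
        compile(GradientTape(weighted_sum, xs))  # traced once per new length
    end
    return gradient!(tape, xs)
end
```

Unlike reusing the length-5 tape above, this never replays a tape recorded for a different length, so it avoids both the silently wrong result and the BoundsError.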

mohamed82008 commented on July 18, 2024

For loops cannot be intercepted by ReverseDiff because they are not functions, but if the loop is wrapped in a function, that function can be intercepted. In this case, you can define an rrule for that function which calls ReverseDiff without tape caching. This is possible now with AbstractDifferentiation.

for (i, x) in enumerate(xs)
  s += i * x
end

ToucheSir commented on July 18, 2024

Thanks Mohamed. I'm aware of the custom rule path, but the hope was to make use of tape caching (or I'd resort to using Zygote). Perhaps this example better describes my motivation:

function scan(f, xs, init)
  ys = empty(xs)
  h = init
  for x in xs
    h, y = f(h, x)
    push!(ys, y)
  end
  return h, ys
end

@jacobusmmsmit probably recognizes this as jax.lax.scan ;)
Per the suggestion, I would have to define an rrule for scan which calls ReverseDiff on f. The problem is that I would then have to give up on tape caching altogether. Is there a way to create a tape for f once, compile it, and then reuse it on every iteration of the loop (while keeping nice features like mutation tracking)? I could find very few resources on how to manipulate tapes, so I assumed it would require changes to ReverseDiff itself.

jacobusmmsmit commented on July 18, 2024

My previous comment discussed the case of a compiled tape inside an uncompiled tape, but the case of an uncompiled tape inside a compiled tape is easier to address. I'm leaving this comment as documentation that this is already possible, although it could use some development to make it easier to use.

At the end I have a question about how @grad works.

Example showing it's already possible

Some setup

using ReverseDiff
using ReverseDiff: TrackedArray, track, @grad, value, GradientTape, compile, gradient!, gradient

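First we define a function with branches. Compiling a tape for a function with branches is currently very dangerous: it compiles without complaint, but the compiled tape always replays the branch taken during recording, so it silently returns the wrong answer for inputs that should take the other branch.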

branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2
_branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2 # same function without a custom rule, used as a reference

Then we define a custom gradient with some logging to show that the right thing is happening each time.

branching_f(x::TrackedArray) = track(branching_f, x)
@grad function branching_f(x)
    xv = value(x)
    function grad_branching(Δ)
        @show sum(xv)
        if sum(xv) > 1
            println("High branch")
            return (3*sum(xv)^2*Δ, )
        else
            println("Low branch")
            return (2*sum(xv)*Δ, )
        end
    end
    return branching_f(xv), grad_branching
end

Now we construct the tapes and test that everything is running as expected:

# Construct and compile the tape
input = [0.0, 1.1, 1.0]
branching_tape = compile(GradientTape(branching_f, input))
_branching_tape = compile(GradientTape(_branching_f, input)) # This tape should ignore the branch

# One input for each branch in the function
input_low = [0.1, 0.2, 0.3]
input_high = [1.1, 1.2, 1.3]

# Test for correctness of implementation
grad_low = gradient(_branching_f, input_low)
grad_high = gradient(_branching_f, input_high)

grad_low == gradient(branching_f, input_low)
grad_high == gradient(branching_f, input_high)

# An example of the method working
grad_low == gradient!(branching_tape, input_low) # true
grad_low == gradient!(_branching_tape, input_low) # false
grad_high == gradient!(branching_tape, input_high) # true
grad_high == gradient!(_branching_tape, input_high) # true (but for the wrong reason)

Where to go from here

So, in a way, there we go: we can do modular tape caching already! But this is all very manual. It would be very nice if this could be done automatically, for example:

Automatic detection of branches and a warning

julia> compile(GradientTape(_branching_f, input))
Warning: woah buddy, you've got a branch in that function of yours, I don't think you want to compile it!

or automatic detection of branches and not compiling the branch sources (not ideal)

julia> compile(GradientTape(outer_function_with_inner_branch, my_input)) # Automatic modularisation
Warning: The tape of `outer_function_with_inner_branch` has branches because of `inner_function`,
this function was not compiled

or allowing users to define static arguments à la JAX

inner_function(x, y) = x > 0 ? 2y : 3y^2
sa_inner_function = @static_arguments(inner_function, x)

outer_function_with_inner_branch(z) = sum(z) * sa_inner_function(z[1], z[2])

or ultimately automatic detection of branches and not compiling the branch sources with respect to those arguments

inner_function(x, y) = x > 0 ? 2y : 3y^2
outer_function_with_inner_branch(z) = sum(z) * inner_function(z[1], z[2])
compile(GradientTape(outer_function_with_inner_branch, my_input)) # All good, works as if it were uncompiled but with compiled performance where possible.
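None of these options exist in ReverseDiff as far as this thread shows. As a crude stopgap, one could spot-check a compiled tape against a fresh, uncompiled trace at a given input: a mismatch means the tape replayed a different branch than the function would take there. tape_agrees and plain_branching_f are hypothetical names, and a check like this only detects problems at the inputs it is actually run on.

```julia
using ReverseDiff
using ReverseDiff: GradientTape, compile, gradient!

# Hypothetical helper: compare a compiled tape with a fresh trace of f at x.
function tape_agrees(tape, f, x; rtol = 1e-8)
    return isapprox(gradient!(tape, x), ReverseDiff.gradient(f, x); rtol = rtol)
end

plain_branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2

# Recorded with sum(x) = 2.1, so the compiled tape always replays the high branch.
stale_tape = compile(GradientTape(plain_branching_f, [0.0, 1.1, 1.0]))
```

Since it re-traces the function, running this check on every call would defeat the purpose of compiling; it is meant for tests or occasional sampling.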

A question

What I'd like to ask is about how @grad works: Which parts of the @grad function are "frozen" when the tape is compiled? In my testing, everything defined outside of the grad_branching function would be frozen, but I couldn't find any documentation on this in ReverseDiff.

@grad function branching_f(x)
    xv = value(x)
    sum_xv = sum(xv) # This part is constant when compiled
    function grad_branching(Δ)
        (sum_xv > 1 ? 3*sum_xv^2*Δ : 2*sum_xv*Δ,) # Doesn't work at all
    end
    return branching_f(xv), grad_branching
end
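The behaviour described can be observed directly with a small "spy" function. The assumption, consistent with the value! → copyto! frames in the stack trace earlier in this thread, is that replaying a tape copies the new input into the tracked input's existing value array rather than replacing it. The recording-time call captures that array (which therefore stays current) and a scalar computed from it (which does not). SEEN_ARRAY, SEEN_SUM and spy are hypothetical names.

```julia
using ReverseDiff
using ReverseDiff: GradientTape, gradient!

# Hypothetical spy: during recording, stash what value(x) returns and a
# plain Float64 computed from it, so both can be inspected after a replay.
const SEEN_ARRAY = Ref{Any}(nothing)
const SEEN_SUM = Ref{Any}(nothing)

function spy(x)
    SEEN_ARRAY[] = ReverseDiff.value(x)     # the tracked input's value array
    SEEN_SUM[] = sum(ReverseDiff.value(x))  # a snapshot taken at recording time
    return sum(x)
end

tape = GradientTape(spy, [1.0, 2.0, 3.0])  # spy runs once here, while recording
gradient!(tape, [4.0, 5.0, 6.0])           # replays instructions; spy is not re-run
```

If the assumption holds, the captured array now reads [4.0, 5.0, 6.0] while the captured sum is still 6.0. That would explain why reading sum(xv) inside grad_branching works but sum_xv computed outside it stays frozen.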

jacobusmmsmit commented on July 18, 2024

Ok, I've got a draft implementation for defining cached sub-tapes:

import AbstractDifferentiation as AD
using ReverseDiff

using ReverseDiff: @grad, compile, GradientTape
import AbstractDifferentiation: primal_value, pullback_function, value_and_pullback_function

struct CachedReverseDiffBackend{F,T} <: AD.AbstractBackend # Could also be parametric in backend type
    func::F
    compiled_tape::T
    # Constructor: record a gradient tape for f at the example input x and compile it
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(GradientTape(f, x))
        T = typeof(compiled_tape)
        return new{F,T}(f, compiled_tape)
    end
end

const CRDB = CachedReverseDiffBackend # alias for brevity

(b::CRDB)(x) = call_func(b, x)
call_func(b::CRDB, x) = b.func(x)

function call_func(b::CRDB, x::ReverseDiff.TrackedArray)
    return ReverseDiff.track(call_func, b, x)
end

@grad function call_func(b::CRDB, x)
    return value_and_pullback_function(b, x)
end

primal_value(::CRDB, xs, _) = primal_value(xs) # is this ok?

function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)
    
    function pullback_f(Δ)
        (Δ*ReverseDiff.gradient!(cb.compiled_tape, xv), ) # no space to cache output :/
    end
    return yv, pullback_f
end

Should this backend be a real backend i.e. should it define a @primitive?

Here's an example of how it would be used:

using BenchmarkTools
g(xs) = sum(abs2, xs)
xs = [1.0, 2.0, 3.0]
const crdb = CRDB(g, xs) # must be declared const otherwise type unstable when called
gt = compile(GradientTape(g, xs)) # RD code

# Check gradients work as intended :)
ReverseDiff.gradient(g, xs .+ 1)
ReverseDiff.gradient!(gt, xs .+ 1)
ReverseDiff.gradient!(crdb.compiled_tape, xs .+ 1)
# All return the same thing

# Define an outer function
f_nocompile(xs) = 2g(xs) # use the original `g`
f_compile(xs) = 2crdb(xs) # use the `g` with a compiled gradient

# Primal timings
@btime f_nocompile($xs) #  4.000 ns (0 allocations: 0 bytes)
@btime f_compile($xs) # 4.000 ns (0 allocations: 0 bytes)

# Gradient timings
@btime ReverseDiff.gradient(f_nocompile, $xs) # 961.750 ns (32 allocations: 1.34 KiB)
@btime ReverseDiff.gradient(f_compile, $xs) # 1.092 μs (17 allocations: 1008 bytes)

# Double-compile also works
fnc_tape = compile(GradientTape(f_nocompile, xs))
fc_tape = compile(GradientTape(f_compile, xs))

@btime ReverseDiff.gradient!(fnc_tape, $xs) # 521.266 ns (1 allocation: 80 bytes)
@btime ReverseDiff.gradient!(fc_tape, $xs) # 847.889 ns (3 allocations: 240 bytes)

As talked about in this issue, caching interfaces should be addressed as this is, I think, where the performance difference comes from.
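On the "no space to cache output" point in the pullback above: one possible direction is to let the cached object own a preallocated gradient buffer and use ReverseDiff's three-argument gradient!(result, tape, input), which writes into result. This is a sketch under that assumption; BufferedTape and buffered_gradient! are hypothetical names, not part of ReverseDiff or AbstractDifferentiation.

```julia
using ReverseDiff
using ReverseDiff: GradientTape, compile, gradient!

# Hypothetical variant of the cached backend that also owns a gradient buffer.
struct BufferedTape{F,T,G}
    func::F
    compiled_tape::T
    grad_buffer::G
end

BufferedTape(f, x) = BufferedTape(f, compile(GradientTape(f, x)), similar(x))

# Writes the gradient into the owned buffer and returns that same array.
# The buffer is overwritten on every call; copy it if the result must be kept.
buffered_gradient!(b::BufferedTape, x) = gradient!(b.grad_buffer, b.compiled_tape, x)

sq = BufferedTape(x -> sum(abs2, x), [1.0, 2.0, 3.0])
```

Inside a pullback, an expression such as Δ .* buffered_gradient!(b, xv) would still allocate for the product, so this only removes the allocation of the gradient itself.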
