Comments (16)
After a call with Mohamed, I think the implementation we decided to try out will address `scan` quite nicely. I'll write up what we discussed here so it's public. For simplicity, I'm assuming that `xs` is a shaped array, so its shape is included in its type.

`scan` is a function that has a core computation, `f`, and some machinery around it to "catch" its outputs. In this case, `scan`'s output shapes and concrete types are determined once `f`, `xs`, and `init` are known. We want to avoid recompiling the core computation `f` every time `xs` changes shape.
[The API for this part is a work in progress; I mean, so is this whole thing, but this part especially.]
To address this, we wrap `f` in a callable struct `CachedReverseDiffBackend` that contains `f` and a compiled tape:
struct CachedReverseDiffBackend{F, T} # Could also be parametric in backend type
    func::F
    compiled_tape::T

    # Constructor to compile the tape given inputs
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(construct_tape(f, x)) # pseudo RD code
        T = typeof(compiled_tape)
        return new{F, T}(f, compiled_tape)
    end
end

compiled_f = CachedReverseDiffBackend(f, x) # where typeof(x) == eltype(xs)
then we make the cached backend callable (with the caveat that `@grad` only accepts functions, so we define `call_func` too):
const CRDB = CachedReverseDiffBackend # alias for brevity
(b::CRDB)(y) = call_func(b, y)
call_func(b::CRDB, y) = b.func(y)
and define a custom rule for our `CRDB` structs:
import AbstractDifferentiation as AD

function call_func(b::CRDB, y::TrackedArray)
    return ReverseDiff.track(call_func, b, y)
end

ReverseDiff.@grad function call_func(b::CRDB, y)
    return AD.value_and_pullback_function(b, y) # to be implemented
end
Now we can pass `compiled_f` to `scan` instead: `scan(compiled_f, xs, init)`. When we try to differentiate through it with `ReverseDiff.gradient`, it will reach `compiled_f` inside the loop and see that there's a custom rule for it. The custom rule we defined (making use of `AD`) calls `ReverseDiff.gradient` on `compiled_f` and uses the compiled tape that we created when defining `compiled_f = ...`.
So in the end we have an outer uncompiled tape which contains calls to inner compiled tapes.
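To make the intended usage concrete, here is a rough sketch under the same hand-waving as above; the `loss` wrapper is only an illustrative way to get a scalar output, assuming `scan` returns the final carry and the collected outputs as in the definition later in the thread:

compiled_f = CachedReverseDiffBackend(f, first(xs))   # record and compile f's tape once
loss(xs) = sum(last(scan(compiled_f, xs, init)))      # the outer computation stays uncompiled
ReverseDiff.gradient(loss, xs)                        # the custom rule fires for each call to compiled_f inside the loop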
First, sorry for the really late response.
In my testing, everything defined outside of the grad_branching function would be frozen, but I couldn't find any documentation on this in ReverseDiff.
Correct. Documentation PRs are welcome :)
The `yv = cb.func(xv)` in the code below will break higher order AD. I would use `yv = cb(xv)` instead, letting dispatch do its thing.
function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)
    function pullback_f(Δ)
        (Δ * ReverseDiff.gradient!(cb.compiled_tape, xv),) # no space to cache output :/
    end
    return yv, pullback_f
end
Should this backend be a real backend, i.e. should it define a `@primitive`?
Your implementation right now seems to only work for scalar-valued functions. You might want to generalise it, and then yes, making it a primitive will give you all the other methods for free. Check the ReverseDiff backend implementation in AbstractDifferentiation for reference.
As talked about in JuliaDiff/AbstractDifferentiation.jl#41, caching interfaces should be addressed as this is, I think, where the performance difference comes from.
Try profiling to see where the performance difference comes from. Also try a function with more inputs which might be more representative of when people use ReverseDiff. Most people would not use ReverseDiff for a function of 3 variables. If allocations are the bottleneck in your function, then we need to consider reducing those but let's check first that: 1) that's the case with profiling, and 2) that's a real problem you will run into when using the package for real sized problems.
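A minimal sketch of that check, reusing `g` and the `CRDB` wrapper from the draft implementation further down the thread; the input size of 1_000 is an arbitrary stand-in for a more realistic problem:

using BenchmarkTools, Profile

xs_big = randn(1_000)
const crdb_big = CRDB(g, xs_big)      # const so the closures below stay type stable when calling it
f_nocompile_big(xs) = 2g(xs)          # plain g
f_compile_big(xs) = 2crdb_big(xs)     # g with the compiled inner tape

@btime ReverseDiff.gradient(f_nocompile_big, $xs_big)
@btime ReverseDiff.gradient(f_compile_big, $xs_big)

# Profile to see whether allocations (or something else) actually dominate
Profile.clear()
@profile for _ in 1:100
    ReverseDiff.gradient(f_compile_big, xs_big)
end
Profile.print(mincount = 10)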
Use the `ReverseDiff.@grad` macro to define an rrule for any function that has a branch. The rrule can use `AbstractDifferentiation` to call ReverseDiff again. This will essentially maintain a sub-tape for this particular function with dynamic control flow and will make it work even when the remaining functions' tape is compiled and cached. IIUC, this is roughly equivalent to what you are trying to do with very little engineering work.
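A minimal sketch of that pattern with an assumed toy function `branchy`; for brevity the rule calls ReverseDiff directly rather than going through AbstractDifferentiation:

using ReverseDiff
using ReverseDiff: @grad, TrackedArray, track, value

branchy(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2   # dynamic control flow

# Intercept tracked inputs so the outer (possibly compiled) tape records a single node,
# then differentiate branchy with a fresh, uncompiled ReverseDiff call inside the rule.
branchy(x::TrackedArray) = track(branchy, x)
@grad function branchy(x)
    xv = value(x)
    pullback(Δ) = (Δ .* ReverseDiff.gradient(branchy, xv),)  # sub-tape rebuilt on every call
    return branchy(xv), pullback
end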
Thanks for the reply. Forgive me for not understanding fully, but do you think you could expand a little on how `@grad` could be used in combination with `AD` to make a "sub-tape"? In this case, is the sub-tape also compiled and cached as in my toy implementation?
It would not be compiled by default but you can choose to compile 2 different tapes, one for each branch. I think you might also be able to do that lazily.
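A rough sketch of the two-tape idea with an assumed toy function; note that each compiled tape is tied to the input length it was recorded with:

using ReverseDiff
using ReverseDiff: @grad, TrackedArray, track, value, GradientTape, compile, gradient!

high(x) = sum(x)^3
low(x) = sum(x)^2
branchy(x) = sum(x) > 1 ? high(x) : low(x)

# One compiled tape per branch (both recorded here for length-3 inputs)
const high_tape = compile(GradientTape(high, ones(3)))
const low_tape = compile(GradientTape(low, ones(3)))

branchy(x::TrackedArray) = track(branchy, x)
@grad function branchy(x)
    xv = value(x)
    pullback(Δ) = (Δ .* gradient!(sum(xv) > 1 ? high_tape : low_tape, xv),)  # pick the branch's tape at call time
    return branchy(xv), pullback
end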
Could you give me a starting point that I could expand on? I'm not too familiar with `AbstractDifferentiation`, but I'd love to build a usable MVP of this idea.
It's not easy. If you are already familiar with ReverseDiff, try reading https://github.com/JuliaDiff/AbstractDifferentiation.jl/blob/master/ext/AbstractDifferentiationReverseDiffExt.jl to understand the AD API. Then you will need to address JuliaDiff/AbstractDifferentiation.jl#41. Then it should be easy to do a MWE. If you are interested in spending time on this, we can schedule a call to go through the work required to get it done.
I had a quick read of the above links as well as the AbstractDifferentiation PR about ReverseDiff. I see that it's a relatively difficult problem to solve at such a high level (for all backends) due to type stability. I'd be interested in working on it.
Saw the GSoC idea this proposal is referring to, very interesting stuff. One question from me: would this help with being able to represent dynamically-bounded loops on the tape without requiring recompilation? I can think of a few cases related to sequence/time series modelling where it would be nice to not eat tracing + tape compilation latency every time the input length changes. Some mechanism for caching sub-tapes seems like a necessary prerequisite for that, but I'm not sure if it falls under the scope of this proposal.
Based on my (limited) understanding of the problem, I think the answer is no. That said, Mohamed may have a better idea of how to deal with it. Maybe Julia can do more than JAX in this regard?
One question from me: would this help with being able to represent dynamically-bounded loops on the tape without requiring recompilation? I can think of a few cases related to sequence/time series modelling where it would be nice to not eat tracing + tape compilation latency every time the input length changes.
If you have a specific example, we can think about it.
The ultimate use case I have in mind is a RNN, but here is a simpler dependency-free example:
function f(xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        s += i * x
    end
    return s
end
julia> tape = ReverseDiff.GradientTape(f, ones(5))
typename(ReverseDiff.GradientTape)(f)
julia> ReverseDiff.gradient!(tape, ones(5))
5-element Vector{Float64}:
1.0
2.0
3.0
4.0
5.0
julia> ReverseDiff.gradient!(tape, ones(3))
5-element Vector{Float64}:
1.0
2.0
3.0
4.0
5.0
julia> ReverseDiff.gradient!(tape, ones(10))
ERROR: BoundsError: attempt to access 5-element Vector{Float64} at index [1:10]
Stacktrace:
[1] throw_boundserror(A::Vector{Float64}, I::Tuple{UnitRange{Int64}})
@ Base ./abstractarray.jl:744
[2] checkbounds
@ ./abstractarray.jl:709 [inlined]
[3] _copyto_impl!(dest::Vector{Float64}, doffs::Int64, src::Vector{Float64}, soffs::Int64, n::Int64)
@ Base ./array.jl:325
[4] copyto!
@ ./array.jl:319 [inlined]
[5] copyto!
@ ./array.jl:342 [inlined]
[6] value!
@ ~/.julia/packages/ReverseDiff/wIfrd/src/tracked.jl:156 [inlined]
[7] seeded_forward_pass!
@ ~/.julia/packages/ReverseDiff/wIfrd/src/api/tape.jl:41 [inlined]
[8] gradient!
@ ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:79 [inlined]
[9] gradient!(tape::ReverseDiff.GradientTape{typeof(f), ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, input::Vector{Float64})
@ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:63
[10] top-level scope
@ REPL[18]:1
It would be nice to have a way to specify "don't unroll this loop" when tracing so that the same tape could be re-used for different input lengths.
For loops are not possible to intercept with ReverseDiff because they are not functions, but if the loop is wrapped in a function, that function can be intercepted. In this case, you can define an rrule for this function which calls RD with no tape caching. This is possible now with AbstractDifferentiation.
for (i, x) in enumerate(xs)
    s += i * x
end
Thanks Mohamed. I'm aware of the custom rule path, but the hope was to make use of tape caching (or I'd resort to using Zygote). Perhaps this example better describes my motivation:
function scan(f, xs, init)
    ys = empty(xs)
    h = init
    for x in xs
        h, y = f(h, x)
        push!(ys, y)
    end
    return h, ys
end
@jacobusmmsmit probably recognizes this as `jax.lax.scan` ;)
Per the suggestion, I would have to define an rrule for `scan` which calls ReverseDiff on `f`. The problem is then that I need to give up on tape caching altogether. Is there a way to create a tape for `f` once, compile it, and then reuse it on every iteration of the loop (while keeping nice features like mutation tracking)? I could find very little in the way of resources on how to manipulate tapes, so I assumed it would require changes to ReverseDiff itself.
My previous comment was discussing the compiled-tape-inside-an-uncompiled-tape case, but the uncompiled-tape-inside-a-compiled-tape case is easier to address. I'm leaving this comment as some documentation of how this is already possible, though it could use some development to make it easier to use.
At the end I do have a question about how `@grad` works.
Example showing it's already possible
Some setup
using ReverseDiff
using ReverseDiff: TrackedArray, track, @grad, value, GradientTape, compile, gradient!, gradient
First we define a function with branches. Compiling a tape with branches on it is currently a very dangerous operation as it will compile without complaining but silently return the wrong answer.
branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2
_branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2 # function used as a reference
Then we define a custom gradient with some logging to show that the right thing is happening each time.
branching_f(x::TrackedArray) = track(branching_f, x)
@grad function branching_f(x)
    xv = value(x)
    function grad_branching(Δ)
        @show sum(xv)
        if sum(xv) > 1
            println("High branch")
            return (3 * sum(xv)^2 * Δ,)
        else
            println("Low branch")
            return (2 * sum(xv) * Δ,)
        end
    end
    return branching_f(xv), grad_branching
end
Now we construct the tapes and test that everything is running as expected:
# Construct and compile the tape
input = [0.0, 1.1, 1.0]
branching_tape = compile(GradientTape(branching_f, input))
_branching_tape = compile(GradientTape(_branching_f, input)) # This tape should ignore the branch
# One input for each branch in the function
input_low = [0.1, 0.2, 0.3]
input_high = [1.1, 1.2, 1.3]
# Test for correctness of implementation
grad_low = gradient(_branching_f, input_low)
grad_high = gradient(_branching_f, input_high)
grad_low == gradient(branching_f, input_low)
grad_high == gradient(branching_f, input_high)
# An example of the method working
grad_low == gradient!(branching_tape, input_low) # true
grad_low == gradient!(_branching_tape, input_low) # false
grad_high == gradient!(branching_tape, input_high) # true
grad_high == gradient!(_branching_tape, input_high) # true (but for the wrong reason)
Where to go from here
So, in a way, there we go. We can do modular tape caching already! But this is all very manual. It would be very nice if we could have this done automatically, such as:
Automatic detection of branches and a warning
julia> compile(GradientTape(_branching_f, input))
Warning: woah buddy, you've got a branch in that function of yours, I don't think you want to compile it!
or automatic detection of branches and not compiling the branch sources (not ideal)
julia> compile(GradientTape(outer_function_with_inner_branch, my_input)) # Automatic modularisation
Warning: The tape of `outer_function_with_inner_branch` has branches because of `inner_function`,
this function was not compiled
or allowing users to define static arguments à la JAX
inner_function(x, y) = x > 0 ? 2y : 3y^2
sa_inner_function = @static_arguments(inner_function, x)
outer_function_with_inner_branch(z) = sum(z) * sa_inner_function(z[1], z[2])
or ultimately automatic detection of branches and not compiling the branch sources with respect to those arguments
inner_function(x, y) = x > 0 ? 2y : 3y^2
outer_function_with_inner_branch(z) = sum(z) * inner_function(z[1], z[2])
compile(GradientTape(outer_function_with_inner_branch, my_input)) # All good, works as if it were uncompiled but with compiled performance where possible.
A question
What I'd like to ask is about how `@grad` works: which parts of the `@grad` function are "frozen" when the tape is compiled? In my testing, everything defined outside of the `grad_branching` function would be frozen, but I couldn't find any documentation on this in ReverseDiff.
@grad function branching_f(x)
    xv = value(x)
    sum_xv = sum(xv) # This part is constant when compiled
    function grad_branching(Δ)
        (sum_xv > 1 ? 3 * sum_xv^2 * Δ : 2 * sum_xv * Δ,) # Doesn't work at all
    end
    return branching_f(xv), grad_branching
end
Ok, I've got a draft implementation for defining cached sub-tapes:
import AbstractDifferentiation as AD
using ReverseDiff
using ReverseDiff: @grad, compile, GradientTape
import AbstractDifferentiation: primal_value, pullback_function, value_and_pullback_function

struct CachedReverseDiffBackend{F,T} <: AD.AbstractBackend # Could also be parametric in backend type
    func::F
    compiled_tape::T

    # Constructor to compile the tape given inputs
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(GradientTape(f, x))
        T = typeof(compiled_tape)
        return new{F,T}(f, compiled_tape)
    end
end

const CRDB = CachedReverseDiffBackend # alias for brevity

(b::CRDB)(x) = call_func(b, x)
call_func(b::CRDB, x) = b.func(x)

function call_func(b::CRDB, x::ReverseDiff.TrackedArray)
    return ReverseDiff.track(call_func, b, x)
end

@grad function call_func(b::CRDB, x)
    return value_and_pullback_function(b, x)
end

primal_value(::CRDB, xs, _) = primal_value(xs) # is this ok?

function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)
    function pullback_f(Δ)
        (Δ * ReverseDiff.gradient!(cb.compiled_tape, xv),) # no space to cache output :/
    end
    return yv, pullback_f
end
Should this backend be a real backend, i.e. should it define a `@primitive`?
Here's an example of how it would be used:
using BenchmarkTools
g(xs) = sum(abs2, xs)
xs = [1.0, 2.0, 3.0]
const crdb = CRDB(g, xs) # must be declared const otherwise type unstable when called
gt = compile(GradientTape(g, xs)) # RD code
# Check gradients work as intended :)
ReverseDiff.gradient(g, xs .+ 1)
ReverseDiff.gradient!(gt, xs .+ 1)
ReverseDiff.gradient!(crdb.compiled_tape, xs .+ 1)
# All return the same thing
# Define an outer function
f_nocompile(xs) = 2g(xs) # use the original `g`
f_compile(xs) = 2crdb(xs) # use the `g` with a compiled gradient
# Primal timings
@btime f_nocompile($xs) # 4.000 ns (0 allocations: 0 bytes)
@btime f_compile($xs) # 4.000 ns (0 allocations: 0 bytes)
# Gradient timings
@btime ReverseDiff.gradient(f_nocompile, $xs) # 961.750 ns (32 allocations: 1.34 KiB)
@btime ReverseDiff.gradient(f_compile, $xs) # 1.092 μs (17 allocations: 1008 bytes)
# Double-compile also works
fnc_tape = compile(GradientTape(f_nocompile, xs))
fc_tape = compile(GradientTape(f_compile, xs))
@btime ReverseDiff.gradient!(fnc_tape, $xs) # 521.266 ns (1 allocation: 80 bytes)
@btime ReverseDiff.gradient!(fc_tape, $xs) # 847.889 ns (3 allocations: 240 bytes)
As talked about in this issue, caching interfaces should be addressed as this is, I think, where the performance difference comes from.