Comments (11)
Oh, I see. The generator needs access to the tape itself, not just the type information. A generated function won't work there.
from reversediff.jl.
I guess this is happening because compile(::CompiledTape) creates the forward and backward pass methods by eval()-ing into the current module. I can work around it by calling @eval ReverseDiff.gradient!(result, tape, x), but that causes performance issues.
It seems like fixing this would require changing the way the forward and reverse pass methods are created. Could they perhaps be implemented with an @generated function instead? The logic inside compile() looks suspiciously similar to the way @generated functions behave.
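(As the reply above notes, this probably can't work: inside an @generated function only the argument types are visible, never the values, so a generator could specialize on the tape's type but could never inspect the recorded instructions. A minimal sketch, unrelated to ReverseDiff's actual API, illustrating the limitation:)

```julia
# Inside an @generated function, the argument names are bound to the
# arguments' *types*, not their values. The generator below can only
# see that `x` is e.g. Int or Float64 -- no value is available.
@generated function describe(x)
    # `x` here is a DataType; string(x) is its name.
    return :(string("value of type ", $(string(x))))
end

describe(1)    # "value of type Int64" (on a 64-bit system)
describe(1.0)  # "value of type Float64"
```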
A less-bad workaround is to define:
forward_pass!(compiled_tape::CompiledTape) = forward_pass!(compiled_tape.tape.tape)
reverse_pass!(compiled_tape::CompiledTape) = reverse_pass!(compiled_tape.tape.tape)
and then not bother calling generate_forward_pass_method and generate_reverse_pass_method. The disadvantage (and presumably the reason you didn't do that originally) is that each Instruction is a different type, so there will be some run-time dispatch.
Here's a crazy idea: if the tape just exists to call forward_exec! and reverse_exec!, could you store (in addition to the tape) a vector of FunctionWrappers around closures? Something like
FunctionWrapper{Void, Tuple{}}(() -> forward_exec!(instruction))
Then forward_pass! could just iterate through the wrappers (which are all of the same type) and call their wrapped functions (which would have the correct dispatch baked in already, I think).
Does that make any sense?
Here's a more concrete version of what I'm suggesting:
julia> using FunctionWrappers: FunctionWrapper
julia> using BenchmarkTools
julia> # Basic function which will return Float64 for float or int input
f(x) = x + 1.0
f (generic function with 1 method)
julia> # Non-concrete data vector
y = Number[1, 2.0]
2-element Array{Number,1}:
1
2.0
julia> # Wrap the closures in a concretely-typed wrapper
wrappers = [FunctionWrapper{Float64, Tuple{}}(() -> f(x)) for x in y];
julia> [w() for w in wrappers]
2-element Array{Float64,1}:
2.0
3.0
Calling the wrapped functions doesn't allocate:
julia> function forward_pass(wrappers)
           for w in wrappers
               w()
           end
           nothing
       end
forward_pass (generic function with 1 method)
julia> @benchmark forward_pass($wrappers)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 24.852 ns (0.00% GC)
median time: 25.743 ns (0.00% GC)
mean time: 28.313 ns (0.00% GC)
maximum time: 128.195 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 996
whereas calling f() on each element of y would allocate:
julia> function forward_pass(y)
           for x in y
               f(x)
           end
           nothing
       end
forward_pass (generic function with 1 method)
julia> @benchmark forward_pass($y)
BenchmarkTools.Trial:
memory estimate: 32 bytes
allocs estimate: 2
--------------
minimum time: 74.362 ns (0.00% GC)
median time: 78.193 ns (0.00% GC)
mean time: 84.891 ns (1.10% GC)
maximum time: 973.529 ns (85.22% GC)
--------------
samples: 10000
evals/sample: 973
And here's a proof-of-concept implementation: https://github.com/JuliaDiff/ReverseDiff.jl/compare/master...rdeits:no-eval?expand=1 (requires FunctionWrappers.jl master). Tests seem to be passing, but I haven't looked at performance yet.
Yup, this is a documented limitation of ReverseDiff.compile. I'll check out your proposed solution today and report back, thanks for playing around with this!
A less-bad workaround is to define:
forward_pass!(compiled_tape::CompiledTape) = forward_pass!(compiled_tape.tape.tape)
reverse_pass!(compiled_tape::CompiledTape) = reverse_pass!(compiled_tape.tape.tape)
and then not bother calling generate_forward_pass_method and generate_reverse_pass_method. The disadvantage (and presumably the reason you didn't do that originally) is that each Instruction is a different type, so there will be some run-time dispatch.
Note that this is exactly what using the un-compiled pre-recorded tape API will do. You're correct that avoiding runtime dispatch is the main advantage of using compile. That should probably be in the documentation to make things clearer...
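(To illustrate the dispatch point: because every recorded instruction has its own concrete type, a plain tape must be stored behind an abstract element type, and each call during traversal is then resolved at run time. A minimal sketch with hypothetical types, not ReverseDiff's actual internals:)

```julia
# Each instruction gets its own concrete type, so a tape of them must
# use an abstract element type (hypothetical names for illustration).
abstract type AbstractInstruction end
struct AddInstruction <: AbstractInstruction end
struct MulInstruction <: AbstractInstruction end

forward_exec!(::AddInstruction) = 1
forward_exec!(::MulInstruction) = 2

tape = AbstractInstruction[AddInstruction(), MulInstruction()]

# Julia cannot know each element's concrete type here, so every
# forward_exec! call below is dispatched dynamically:
function forward_pass!(tape)
    for instr in tape
        forward_exec!(instr)
    end
    return nothing
end
```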
This is my first time messing about with the FunctionWrappers package, and this approach looks really cool! Another benefit would be avoiding the crazy compilation times associated with the functions generated from large tapes. If this does end up working out, we wouldn't even need the compile function anymore - AFAICT, there'd be no reason not to wrap the instructions by default.
Unfortunately, it seems that there are still some performance issues to work through. Here's the performance of a toy benchmark on master:
julia> using ReverseDiff, BenchmarkTools
# good benchmark for tape traversal performance
julia> function rosenbrock(x)
           a = one(eltype(x))
           b = 100 * a
           result = zero(eltype(x))
           for i in 1:length(x)-1
               result += (a - x[i])^2 + b*(x[i+1] - x[i]^2)^2
           end
           return result
       end
rosenbrock (generic function with 1 method)
julia> const tape = ReverseDiff.GradientTape(rosenbrock, rand(100))
ReverseDiff.GradientTape(rosenbrock)
julia> const ctape = ReverseDiff.compile(tape)
ReverseDiff.CompiledTape{##660}(rosenbrock)
julia> y, x = zeros(100), rand(100);
julia> @benchmark ReverseDiff.gradient!($y, $tape, $x)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 122.630 μs (0.00% GC)
median time: 124.411 μs (0.00% GC)
mean time: 124.321 μs (0.00% GC)
maximum time: 171.988 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark ReverseDiff.gradient!($y, $ctape, $x)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 67.663 μs (0.00% GC)
median time: 68.532 μs (0.00% GC)
mean time: 68.719 μs (0.00% GC)
maximum time: 161.848 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 1
Performance of compile on the no-eval branch:
julia> @benchmark ReverseDiff.gradient!($y, $ctape, $x)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 199.394 μs (0.00% GC)
median time: 200.976 μs (0.00% GC)
mean time: 201.844 μs (0.00% GC)
maximum time: 482.731 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 1
I haven't dug any deeper than this, but if I were to tinker with it, I'd use a callable struct with an instruction field instead of a closure, then define its call overload with @inline. That should at least prevent any performance problems due to weird code lowering issues (though I'm not sure that's what's happening here).
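(The callable-struct variant might look something like the sketch below; DummyInstruction, ForwardExecutor, and this forward_exec! are stand-in names for illustration, not ReverseDiff's actual internals.)

```julia
# Stand-in instruction and forward_exec! (hypothetical, not the
# package's real API).
struct DummyInstruction
    value::Float64
end
forward_exec!(instr::DummyInstruction) = instr.value + 1.0

# Callable struct replacing the anonymous closure: the concrete
# instruction type is a type parameter, so the forward_exec! call is
# statically resolved for each wrapper.
struct ForwardExecutor{I}
    instruction::I
end
@inline (e::ForwardExecutor)() = forward_exec!(e.instruction)

# The tape would then store these behind FunctionWrappers, e.g.
#   FunctionWrapper{Void, Tuple{}}(ForwardExecutor(instruction))
# giving a homogeneous, concretely-typed vector to iterate.
executor = ForwardExecutor(DummyInstruction(1.0))
executor()  # 2.0
```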
Yes, the callable struct is a great idea! With that change, I'm seeing performance only slightly worse than master for compiled tapes (71us with master on v0.5, 80us with no-eval on v0.6).
I've pushed the updates, so you should be able to try it yourself.
With that change, I'm seeing performance only slightly worse than master for compiled tapes (71us with master on v0.5, 80us with no-eval on v0.6).
That's a totally acceptable regression compared to the benefits of this approach - feel free to open a PR for no-eval! After such a PR gets merged, I'll explore removing compile entirely and using the FunctionWrapper approach by default.
Ok, done: #71