I imagine the best thing to do at the moment is to splice in ForwardDiff for these.
from reversediff.jl.
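Here is a minimal sketch of why splicing in ForwardDiff copes with these branches, assuming a current ForwardDiff.jl release. It uses the same elu definition that appears later in this thread. Forward mode pushes dual numbers through the function on every call, so the branch is re-evaluated at each input point and the derivative follows whichever branch that input takes.

```julia
using ForwardDiff

# Branching ELU, as defined later in this thread
elu(x) = x > 0 ? x : expm1(x)

# Each call evaluates elu on a fresh dual number, so the `x > 0` test
# is checked against the current input every time.
d_pos = ForwardDiff.derivative(elu, 1.0)    # takes the `x` branch: 1.0
d_neg = ForwardDiff.derivative(elu, -1.0)   # takes the `expm1` branch: exp(-1) ≈ 0.367879
```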
Sorry, do you mean using ReverseDiff.@forward? I didn't know about it, but it seems to be working well:
julia> using ReverseDiff
julia> ReverseDiff.@forward elu(x) = x > 0 ? x : expm1(x)
elu (generic function with 1 method)
julia> x = ones(3)
3-element Array{Float64,1}:
1.0
1.0
1.0
julia> ∇elu! = ReverseDiff.compile_gradient(x -> sum(elu.(x)), x)
(::#301) (generic function with 1 method)
julia> ∇elu!(similar(x), x)
3-element Array{Float64,1}:
1.0
1.0
1.0
julia> ∇elu!(similar(x), -x)
3-element Array{Float64,1}:
0.367879
0.367879
0.367879
ReverseDiff.jl's re-recording API (gradient/gradient!, jacobian/jacobian!, etc.) does support branching, so that's what you should use if your function branches on differentiable input. The performance of this should be similar to (if not better than) other autograd-like approaches.
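A minimal sketch of the re-recording route, assuming the current ReverseDiff.jl API (which may differ from the version discussed in this thread). It applies ReverseDiff.gradient and ReverseDiff.gradient! to the branching elu from the REPL session. Because a fresh tape is recorded on every call, each evaluation follows the branch taken for that particular input.

```julia
using ReverseDiff

# Branching ELU from the REPL session above
elu(x) = x > 0 ? x : expm1(x)

# Scalar objective that branches element-wise on the differentiable input
f(x) = sum(elu.(x))

x = ones(3)

# gradient re-records the tape on each call, so positive and negative
# inputs each get the derivative of the branch they actually take.
g_pos = ReverseDiff.gradient(f, x)    # every entry 1.0
g_neg = ReverseDiff.gradient(f, -x)   # every entry exp(-1) ≈ 0.367879

# In-place variant: writes the gradient into a preallocated buffer
out = similar(x)
ReverseDiff.gradient!(out, f, -x)
```

The trade-off is that the tape is rebuilt on every call. That cost is what the pre-recording API avoids, at the price of fixing control flow at record time.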
If you would like to use the pre-recording API (ReverseDiff.compile_gradient!, ReverseDiff.compile_jacobian!, etc.), then you can splice the branch into ForwardDiff using @forward (as you've discovered).
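Here is a sketch of the pre-recording route under the current ReverseDiff.jl API. It assumes that ReverseDiff.GradientTape plus ReverseDiff.compile take the place of the older compile_gradient function from the REPL session. The name elu_fwd is hypothetical and is used only to keep it distinct from a plain elu. The tape is recorded once at x = ones(3) and then replayed at both x and -x.

```julia
using ReverseDiff

# @forward asks ReverseDiff to differentiate elu_fwd with ForwardDiff
# rather than recording its branching body onto the tape.
# (elu_fwd is a hypothetical name for this sketch.)
ReverseDiff.@forward elu_fwd(x) = x > 0 ? x : expm1(x)

f(x) = sum(elu_fwd.(x))

x = ones(3)

# Record the tape once, then compile it for fast replay
# (assumed modern equivalent of the thread's compile_gradient).
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))

# Replay the same compiled tape on different inputs
g_pos = ReverseDiff.gradient!(similar(x), tape, x)    # every entry 1.0
g_neg = ReverseDiff.gradient!(similar(x), tape, -x)   # every entry ≈ 0.367879
```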
In the future, we plan to provide an @source macro, which will allow you to splice in reverse-mode source-to-source AD and may be able to handle simple branches.