Comments (3)
The attached git patch can probably do the job.
I can make a PR if you are interested.
commit f22039512e6cbd10f1f19bb4a6368f87fcbab2ba
Author: ray@omni <[email protected]>
Date: Sun Mar 13 19:20:24 2022 -0400
Added `withgradient!`.
diff --git a/src/api/gradients.jl b/src/api/gradients.jl
index ed92cd4..ac58027 100644
--- a/src/api/gradients.jl
+++ b/src/api/gradients.jl
@@ -80,3 +80,39 @@ function gradient!(result, tape::Union{GradientTape,CompiledGradient}, input)
     seeded_reverse_pass!(result, tape)
     return result
 end
+
+"""
+    ReverseDiff.withgradient!(tape::Union{GradientTape,CompiledGradient}, input)
+
+If `input` is an `AbstractArray`, assume `tape` represents a function of the form
+`f(::AbstractArray)::Real` and return `(output, result)`, where `output` holds `f(input)` and `result` is `∇f(input)`.
+
+If `input` is a tuple of `AbstractArray`s, assume `tape` represents a function of the form
+`f(::AbstractArray...)::Real` and return `(output, result)`, where `result` is a `Tuple` whose
+`i`th element is the gradient of `f` w.r.t. `input[i]`.
+"""
+function withgradient!(tape::Union{GradientTape,CompiledGradient}, input)
+    result = construct_result(input_hook(tape))
+    tmp = copy(value(output_hook(tape)))
+    output = isa(tmp, Number) ? [tmp] : copy(tmp)
+    withgradient!(output, result, tape, input)
+    return output, result
+end
+
+"""
+    ReverseDiff.withgradient!(output, result, tape::Union{GradientTape,CompiledGradient}, input)
+
+Returns `(output, result)`. This method is like `ReverseDiff.gradient!(tape, input)`, except it
+stores the primal value in `output` and the gradient(s) in `result` rather than allocating new memory.
+
+`result` and `output` can each be an `AbstractArray` or a `Tuple` of `AbstractArray`s. The `result`
+(or any of its elements, if `isa(result, Tuple)`) can also be a `DiffResults.DiffResult`, in which
+case the primal value `f(input)` (or `f(input...)`, if `isa(input, Tuple)`) will be stored
+in it as well.
+"""
+function withgradient!(output, result, tape::Union{GradientTape,CompiledGradient}, input)
+    seeded_forward_pass!(tape, input)
+    seeded_reverse_pass!(result, tape)
+    copyto!(output, value(output_hook(tape)))
+    return output, result
+end
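For context, here is a hypothetical usage sketch of the proposed `withgradient!`, assuming the patch above is applied (`withgradient!` does not exist in released ReverseDiff; the tape construction follows ReverseDiff's existing `GradientTape`/`compile` API):

```julia
import ReverseDiff

# Record and compile a tape for a scalar-valued function of an array
# (standard ReverseDiff workflow).
f(x) = sum(sin, x)
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, rand(4)))

x = rand(4)

# Allocating form from the patch: returns (output, result), where output
# holds the primal value (wrapped in a 1-element array for scalar outputs)
# and result holds the gradient.
output, result = ReverseDiff.withgradient!(tape, x)

# In-place form: reuse preallocated buffers across repeated calls.
out_buf = similar(output)
grad_buf = similar(x)
ReverseDiff.withgradient!(out_buf, grad_buf, tape, x)
```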
from reversediff.jl.
This functionality is already supported through DiffResults:
julia> using ReverseDiff, DiffResults
julia> f(x) = sum(sin, x) + prod(tan, x) * sum(sqrt, x);
julia> x = rand(4);
julia> results = DiffResults.GradientResult(x);
julia> ReverseDiff.gradient!(results, f, x);
julia> f(x)
1.3363694662279235
julia> DiffResults.value(results)
1.3363694662279235
julia> ReverseDiff.gradient(f, x)
4-element Vector{Float64}:
1.0840368499697293
0.6648496332827732
1.1191780857095852
1.0526226868764572
julia> DiffResults.gradient(results)
4-element Vector{Float64}:
1.0840368499697293
0.6648496332827732
1.1191780857095852
1.0526226868764572
Thanks!