Comments (3)

tianrluo commented on July 18, 2024

The git patch below should do the job.
I can open a PR if you are interested.

commit f22039512e6cbd10f1f19bb4a6368f87fcbab2ba
Author: ray@omni <[email protected]>
Date:   Sun Mar 13 19:20:24 2022 -0400

    Added `withgradient!`.

diff --git a/src/api/gradients.jl b/src/api/gradients.jl
index ed92cd4..ac58027 100644
--- a/src/api/gradients.jl
+++ b/src/api/gradients.jl
@@ -80,3 +80,39 @@ function gradient!(result, tape::Union{GradientTape,CompiledGradient}, input)
     seeded_reverse_pass!(result, tape)
     return result
 end
+
+"""
+    ReverseDiff.withgradient!(tape::Union{GradientTape,CompiledGradient}, input)
+
+If `input` is an `AbstractArray`, assume `tape` represents a function of the form
+`f(::AbstractArray)::Real` and return `(output, ∇f(input))`, where `output` is a one-element vector holding `f(input)`.
+
+If `input` is a tuple of `AbstractArray`s, assume `tape` represents a function of the form
+`f(::AbstractArray...)::Real` and return `(output, grads)`, where the `i`th element of `grads`
+is the gradient of `f` w.r.t. `input[i]`.
+"""
+function withgradient!(tape::Union{GradientTape,CompiledGradient}, input)
+    result = construct_result(input_hook(tape))
+    tmp = copy(value(output_hook(tape)))
+    output = isa(tmp, Number) ? [tmp] : copy(tmp)
+    withgradient!(output, result, tape, input)
+    return output, result
+end
+
+"""
+    ReverseDiff.withgradient!(output, result, tape::Union{GradientTape,CompiledGradient}, input)
+
+Returns `(output, result)`. This method is like `ReverseDiff.withgradient!(tape, input)`, except it
+stores the primal value in `output` and the gradient(s) in `result` rather than allocating new memory.
+
+`result` and `output` can be an `AbstractArray` or a `Tuple` of `AbstractArray`s. The `result` (or any
+of its elements, if `isa(result, Tuple)`), can also be a `DiffResults.DiffResult`, in which
+case the primal value `f(input)` (or `f(input...)`, if `isa(input, Tuple)`) will be stored
+in it as well.
+"""
+function withgradient!(output, result, tape::Union{GradientTape,CompiledGradient}, input)
+    seeded_forward_pass!(tape, input)
+    seeded_reverse_pass!(result, tape)
+    copyto!(output, value(output_hook(tape)))
+    return output, result
+end
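For comparison, the existing `gradient!(result, tape, input)` method in the hunk context already accepts a `DiffResults.DiffResult` as `result`. According to the docstring text reused in the patch, it then stores the primal value alongside the gradient. The following is a minimal sketch of that path on a compiled tape. It assumes ReverseDiff.jl and DiffResults.jl are installed, and the variable names are illustrative only:

```julia
using ReverseDiff, DiffResults

# Same test function as in the DiffResults example later in this thread.
f(x) = sum(sin, x) + prod(tan, x) * sum(sqrt, x)

x = rand(4)

# Record the computation once, then compile the tape for repeated use.
tape = ReverseDiff.GradientTape(f, x)
ctape = ReverseDiff.compile(tape)

# A GradientResult has storage for both the primal value and the gradient.
result = DiffResults.GradientResult(x)

# One forward pass and one reverse pass fill in both fields.
ReverseDiff.gradient!(result, ctape, x)

val  = DiffResults.value(result)     # f(x)
grad = DiffResults.gradient(result)  # ∇f(x)
```

Compiling pays off when the same function is differentiated many times at inputs of the same shape. A recorded tape replays the control flow taken for the input it was built with, so it should not be reused for functions whose branches depend on the input values.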

from reversediff.jl.

devmotion commented on July 18, 2024

This functionality is already supported through DiffResults:

julia> using ReverseDiff, DiffResults

julia> f(x) = sum(sin, x) + prod(tan, x) * sum(sqrt, x);

julia> x = rand(4);

julia> results = DiffResults.GradientResult(x);

julia> ReverseDiff.gradient!(results, f, x);

julia> f(x)
1.3363694662279235

julia> DiffResults.value(results)
1.3363694662279235

julia> ReverseDiff.gradient(f, x)
4-element Vector{Float64}:
 1.0840368499697293
 0.6648496332827732
 1.1191780857095852
 1.0526226868764572

julia> DiffResults.gradient(results)
4-element Vector{Float64}:
 1.0840368499697293
 0.6648496332827732
 1.1191780857095852
 1.0526226868764572
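The same `DiffResults` approach extends to functions of several arrays. If you pass a `Tuple` of `GradientResult`s, each element receives the primal value and the gradient with respect to the matching input. The following is a minimal sketch. The two-argument function `g` is a hypothetical example, and ReverseDiff.jl and DiffResults.jl are assumed to be installed:

```julia
using ReverseDiff, DiffResults

# Hypothetical two-argument objective, used only for illustration.
g(a, b) = sum(a .* b) + sum(a .* a)

a, b = rand(3), rand(3)

# One GradientResult per input; each receives the primal value g(a, b)
# and the gradient with respect to its own input.
results = (DiffResults.GradientResult(a), DiffResults.GradientResult(b))

ReverseDiff.gradient!(results, g, (a, b))

DiffResults.value(results[1])     # g(a, b)
DiffResults.gradient(results[1])  # ∂g/∂a, equal to b .+ 2 .* a
DiffResults.gradient(results[2])  # ∂g/∂b, equal to a
```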

tianrluo commented on July 18, 2024

Thanks!
