I don't know whether this tackles the problem the way you'd like, but there are a few ad hoc workarounds.
First, compute the vector-matrix product as the column sums of the element-wise (broadcast) product:
julia> ReverseDiff.gradient(x->sum(x.*A,1) * x,x)
2-element Array{Float64,1}:
4.57215
10.7736
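Why the first workaround gives the same numbers: broadcasting x .* A scales row i of A by x[i], so summing down each column gives the entries of x'A, but as a plain 1×n Matrix rather than a RowVector. A minimal sketch checking this identity with plain arrays, assuming Julia 1.x (where the dimension is passed as the dims keyword instead of positionally) and a hypothetical test point x, since the x used in the session above is not shown:

```julia
using LinearAlgebra

A = [1.0 2.0; 2.0 5.0]
x = [0.3, 0.8]    # hypothetical test point; the x from the session is not shown

# x .* A scales row i of A by x[i]; summing down each column then gives
# sum_i x[i] * A[i, j], which is the j-th entry of x' * A.
# Julia 1.x spells the dimension as a keyword: dims=1 instead of sum(..., 1).
rowprod = sum(x .* A, dims=1)    # plain 1×2 Matrix{Float64}, no row-vector wrapper

# Multiplying that 1×2 matrix by x gives a 1-element Vector holding x'Ax.
q = rowprod * x
```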
Another option is to compute the products separately:
julia> ReverseDiff.gradient(x->Array(x.'*A) * x,x)
2-element Array{Float64,1}:
4.57215
10.7736
Without the Array(), the product returns a RowVector{Float64,Array{Float64,1}}, which makes ReverseDiff unhappy.
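RowVector belongs to Julia 0.6; in Julia 1.x it has been replaced by the lazy Adjoint and Transpose wrappers, and the x.' syntax has been removed. A minimal sketch of the second workaround's shape in current syntax, assuming Julia 1.x with the LinearAlgebra standard library and a hypothetical test point x. It only shows that Array() turns the lazy wrapper into an ordinary 1×2 Matrix before the final product; it makes no claim about how current ReverseDiff treats the wrapper types:

```julia
using LinearAlgebra

A = [1.0 2.0; 2.0 5.0]
x = [0.3, 0.8]    # hypothetical test point

# Julia 1.x spelling of the 0.6 expression x.' * A.
t = transpose(x) * A    # a lazy Transpose wrapper around a Vector, not an Array

# Array() materialises the wrapper as an ordinary 1×2 Matrix{Float64},
# playing the role Array() plays for RowVector in the session above.
m = Array(t)

q = m * x               # 1-element Vector holding x'Ax
```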
This is the classic "new array types which define ambiguous method definitions" problem. It might be fixed just by adding :RowVector to the ambiguity list.
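To make the ambiguity concrete, here is a minimal, self-contained sketch with two hypothetical types, Tracked (standing in for ReverseDiff.TrackedArray) and Row (standing in for RowVector), and a hypothetical function mul in place of *. One method is more specific in its first argument, the others in their second, so a call with both types has no single best match and Julia raises an ambiguity MethodError. Generating one more method from a list of type names, roughly what adding :RowVector to such a list would do, supplies a method more specific than every candidate. This is a simplified model: a real library signature may carry extra type constraints, so adding a type to the list is not guaranteed to remove every tie.

```julia
# Hypothetical stand-ins: Tracked ~ ReverseDiff.TrackedArray, Row ~ RowVector.
struct Tracked <: AbstractVector{Float64}
    value::Vector{Float64}
end
Base.size(t::Tracked) = size(t.value)
Base.getindex(t::Tracked, i::Int) = t.value[i]

struct Row <: AbstractMatrix{Float64}
    parent::Vector{Float64}
end
Base.size(r::Row) = (1, length(r.parent))
function Base.getindex(r::Row, i::Int, j::Int)
    @boundscheck checkbounds(r, i, j)
    return r.parent[j]
end

# "Library" method: specific in the first argument only.
mul(x::Row, y::AbstractVector) = :library

# "Package" methods generated from a list of left-hand type names
# (a hypothetical analogue of a list like ARRAY_TYPES); Row is not listed yet.
for T in (:AbstractArray, :AbstractMatrix)
    @eval mul(x::$T, y::Tracked) = :package
end

r, t = Row([1.0, 2.0]), Tracked([3.0, 4.0])
was_ambiguous = try
    mul(r, t)
    false
catch err
    err isa MethodError    # Julia reports "mul(::Row, ::Tracked) is ambiguous"
end

# Adding :Row to the list generates mul(::Row, ::Tracked), which is more
# specific than all earlier candidates and so breaks the tie.
for T in (:Row,)
    @eval mul(x::$T, y::Tracked) = :package
end
result = mul(r, t)
```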
@ovimo Thanks. I'm working around the problem with techniques similar to the ones you suggested.
@jrevels I've tried your suggestion, but there is still an ambiguity error:
diff --git a/src/ReverseDiff.jl b/src/ReverseDiff.jl
index 6b0635c..e472e73 100644
--- a/src/ReverseDiff.jl
+++ b/src/ReverseDiff.jl
@@ -21,7 +21,7 @@ end
# Not all operations will be valid over all of these types, but that's okay; such cases
# will simply error when they hit the original operation in the overloaded definition.
-const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix)
+const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix, :RowVector)
const REAL_TYPES = (:Bool, :Integer, :Rational, :BigFloat, :BigInt, :AbstractFloat, :Real, :Dual)
const FORWARD_UNARY_SCALAR_FUNCS = (ForwardDiff.AUTO_DEFINED_UNARY_FUNCS..., :-, :abs, :conj)
julia> using ReverseDiff
INFO: Recompiling stale cache file /Users/kenta/.julia/lib/v0.6/ReverseDiff.ji for module ReverseDiff.
julia> const A = [1.0 2.0; 2.0 5.0]
2×2 Array{Float64,2}:
1.0 2.0
2.0 5.0
julia> quadratic(x) = x' * A * x
quadratic (generic function with 1 method)
julia> ReverseDiff.gradient(quadratic, ones(2))
ERROR: MethodError: *(::RowVector{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) is ambiguous. Candidates:
*(x::RowVector, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
*(x::AbstractArray{T,2} where T, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
*(x::AbstractArray, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
*(rowvec::RowVector{T,V} where V<:(AbstractArray{T,1} where T), vec::AbstractArray{T,1}) where T<:Real in Base.LinAlg at linalg/rowvector.jl:170
Possible fix, define
*(::RowVector{ReverseDiff.TrackedReal{V,D,ReverseDiff.TrackedArray{V,D,1,VA,DA}},V} where V<:(AbstractArray{T,1} where T), ::ReverseDiff.TrackedArray{V,D,1,VA,DA})
Stacktrace:
[1] * at ./operators.jl:424 [inlined]
[2] quadratic(::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) at ./REPL[3]:1
[3] Type at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/tape.jl:199 [inlined]
[4] gradient(::Function, ::Array{Float64,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}) at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/gradients.jl:22 (repeats 2 times)
This no longer gives an error. But if A is not symmetric, it gives wrong results, as seen in this Discourse thread:
julia> const A3 = [1.0 2.0; 7.0 5.0];
julia> quadratic3(x) = x' * A3 * x;
julia> ReverseDiff.gradient(quadratic3, ones(2)) # wrong
2-element Vector{Float64}:
16.0
14.0
julia> ForwardDiff.gradient(quadratic3, ones(2))
2-element Vector{Float64}:
11.0
19.0
julia> Zygote.gradient(quadratic3, ones(2))[1]
2-element Vector{Float64}:
11.0
19.0
julia> ReverseDiff.gradient(x -> dot(x, A3, x), ones(2)) # dot works
2-element Vector{Float64}:
11.0
19.0
(jl_N4cJfW) pkg> st
...
[37e2e3b7] ReverseDiff v1.14.1
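For a quadratic form the gradient is (A + A')x, which for A3 and x = ones(2) gives [11.0, 19.0], matching ForwardDiff, Zygote and the dot version. The incorrect [16.0, 14.0] coincides with 2A'x, which equals (A + A')x only when A is symmetric; that is consistent with only non-symmetric matrices being affected. A minimal sketch checking these numbers with plain Julia 1.x and LinearAlgebra, using a hypothetical central-difference helper fd_gradient as an independent check (it does not call ReverseDiff):

```julia
using LinearAlgebra

A3 = [1.0 2.0; 7.0 5.0]    # the non-symmetric matrix from the session above
x  = ones(2)
f(v) = dot(v, A3 * v)       # same value as v' * A3 * v

# Gradient of a quadratic form: d/dv (v'Av) = (A + A')v.
g_true = (A3 + A3') * x

# Hypothetical helper: central finite differences. For a quadratic these are
# exact up to floating-point rounding, so they give an independent check.
function fd_gradient(f, v; h = 1e-6)
    g = similar(v)
    for i in eachindex(v)
        e = zeros(length(v))
        e[i] = h
        g[i] = (f(v + e) - f(v - e)) / (2h)
    end
    return g
end
g_fd = fd_gradient(f, x)

# The value ReverseDiff v1.14.1 returned in the session, and 2A'x for comparison.
g_reported = [16.0, 14.0]
g_2At = 2 * (A3' * x)
```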