Comments (13)
Yes, this is mentioned in the README; you can call `Zygote.refresh()` to avoid it.
from zygote.jl.
Ah, I missed that. Worth leaving this open to track it?
I feel like hooking into the Julia #265 stuff should be doable?
> I feel like hooking into the Julia #265 stuff should be doable?
IIRC we hit a roadblock with making the world-age mechanism hookable when we realized that the compiler would need to backprop the new world age bounds in a way it didn't have to before. So it's going to take some redesign of the world-age mechanism to make that feasible.
Also note that there are still some compiler bugs that make it quite possible to trigger miscompiles that manifest as 265-style issues (ref JuliaLabs/Cassette.jl#6, JuliaLang/julia#28595). It's possible that fixing those bugs will fix some of the more blatant 265-style problems without requiring any world age redesign.
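To make "world age" concrete: Julia keeps a global world counter, and defining or replacing a method advances it. This tiny sketch only shows that counter moving. `Base.get_world_counter` is an unexported internal (assumed to be available in your Julia version) and `square` is a throwaway name.

```julia
w0 = Base.get_world_counter()   # current world age (unexported internal)
square(x) = x^2                 # defining a method creates a newer world
w1 = Base.get_world_counter()
w1 > w0                         # true
```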
How about this as a workaround? We don't want to call `refresh()` when it is not going to change the answer, since it triggers an expensive recompile, but we do want to call it when it would change the answer. So we track the world ages of the functions that we have operated on:
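```julia
import Zygote

# Newest world age among the methods of `f` (the max of their `min_world`s).
# `min_world` is the field name in the Julia version this was written for;
# newer versions may call it something else (e.g. `primary_world`).
min_world(f) = maximum(mm.min_world for mm in methods(f).ms)

const _tracked_worldages = Dict{Function,Int}()

"""
Refreshes Zygote if that is required to work correctly with `f`.
This avoids expensive recompilation when it is not required.
Returns `true` if a refresh was done.
"""
function Zygote.refresh(f)
    cur_world = min_world(f)
    old_world = get!(_tracked_worldages, f, cur_world)
    if cur_world > old_world
        _tracked_worldages[f] = cur_world
        Zygote.refresh()  # trigger the full (expensive) refresh
        true
    else
        false
    end
end
```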
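For experimenting without Zygote, here is a hypothetical variant of the same bookkeeping. It swaps in a different change test: rather than reading each method's world field (whose name has varied across Julia versions), it remembers the `Method` objects of `f`, relying on the fact that redefining a method replaces its `Method` object. `expensive_refresh` is a stand-in counter for `Zygote.refresh()`, and all names here are made up for illustration.

```julia
# Hypothetical sketch: remember which `Method` objects each function had.
const _seen_methods = IdDict{Any,Vector{Method}}()
const refresh_calls = Ref(0)   # how often the expensive step ran

# Stand-in for the expensive `Zygote.refresh()`.
expensive_refresh() = (refresh_calls[] += 1; nothing)

function refresh_if_changed!(f)
    current = Method[m for m in methods(f)]
    previous = get!(_seen_methods, f, current)  # first sighting just records
    if current != previous                      # some method was replaced or added
        _seen_methods[f] = current
        expensive_refresh()
        return true
    end
    return false
end
```

As in the original, the first time a function is seen nothing is refreshed; only later changes trigger the expensive step.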
Then, inside each of the functions affected by this (just `gradient(f)`?), insert a call to `refresh(f)`.
Demonstration
As a proof of concept, I'm just doing this for my own function:
```julia
function der(f::Function)
    function (x)
        Zygote.refresh(f)  # only refreshes if `f` was redefined
        derivative(f, x)
    end
end
```
Demo:
```julia
julia> h(x) = 2x
h (generic function with 1 method)

julia> dh = der(h)
#7 (generic function with 1 method)

julia> dh(1)
2

julia> dh(1)
2

julia> h(x) = 10x
h (generic function with 1 method)

julia> dh(1)
2

julia> g(x) = 2x
g (generic function with 1 method)

julia> h(x) = 2g(x)
h (generic function with 1 method)

julia> dh(1)
4

julia> g(x) = 10x
g (generic function with 1 method)

julia> dh(1)
20
```
Hmmm, that is not quite as reliable as it seemed. It will catch redefinitions of `h` (always, I think), but it will only sometimes catch redefinitions of `g`.
The following is more reliable, but slower: it recursively searches the code for the functions being called. I think there is a way to speed it up, especially if we only need to worry about functions that call functions we have previously called `gradient` on; then we can short-circuit in `i_min_world(ff::Function)`. Also, for at least the outer call we have type information, so we don't need to check all methods.
```julia
function min_world(f)
    visited = Set{Function}()
    ########################################
    # Inner dispatches
    function i_min_world(ff::Function)
        # We take the max over all the mins, so returning 0 is fine:
        # the true value for `ff` has already been included.
        ff in visited && return 0
        push!(visited, ff)
        meths = methods(ff).ms
        isempty(meths) && return 0
        maximum(i_min_world.(meths))
    end
    function i_min_world(mm::Method)
        if isdefined(mm, :generator)
            # I don't know how to deal with generated functions;
            # we can't search all their bodies since there are infinitely many.
            mm.min_world
        else
            max(mm.min_world, i_min_world(Base.uncompressed_ast(mm)))
        end
    end
    i_min_world(u_ast::Core.CodeInfo) = maximum(i_min_world.(u_ast.code))
    # Recurse into every sub-expression, not only `:call`s, so that calls nested
    # inside assignments such as `y = g(x)` are found too.
    i_min_world(expr::Expr) = foldl(max, i_min_world.(expr.args); init = 0)
    function i_min_world(gr::GlobalRef)
        func = try
            eval(gr)
            # This will occasionally error, but never (AFAIK) in the case
            # we care about, which is when the result is a function.
        catch err
            err isa UndefVarError || rethrow()
            return 0
        end
        i_min_world(func)
    end
    i_min_world(x::Any) = 0
    ################################################
    i_min_world(f)
end
```
Thoughts on the code I posted before? I think it can be made smart enough to solve this without a large overhead.
So the idea is that we call `refresh()` at the entry point of `gradient`? And then that can call `eval`/`invokelatest` on `_forward` to get the latest definition.
It seems hard to do this in general. I could see it working in the case that everything is fully well-typed, but what about if the functions that `f` calls are not known at compile time? We either make that significantly more expensive (call `refresh` again at the boundary) or just push this issue into more complex code (which might actually make it more surprising when it comes up). I could see something like this working, but it would be a relatively invasive change.
> So the idea is that we call `refresh()` at the entry point of `gradient`?
Yes, we call a version of `refresh` before the existing `gradient`, where this version only refreshes if required according to world ages.
And calculating world ages has to be done recursively, since only the functions that were directly changed have their `min_world` field updated. (I am pretty sure you understand that already, but I'm noting it for anyone else reading this in the future, like future me.)
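A minimal sketch of that point, which is also why redefinitions of `g` are missed by the non-recursive check (the names `g2`/`h2` are illustrative, mirroring `g`/`h` from the demo): redefining the callee gives it a new `Method`, while the caller's `Method`, and therefore its recorded world, stays the same.

```julia
g2(x) = 2x
h2(x) = 2g2(x)          # h2 calls g2, like `h` and `g` in the demo

h_method = only(methods(h2))
g_method = only(methods(g2))

g2(x) = 10x             # redefine only the callee

only(methods(h2)) === h_method   # true: the caller's Method (and its world) is unchanged
only(methods(g2)) === g_method   # false: the callee now has a new Method
h2(1)                            # 20: ordinary dispatch already sees the new g2
```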
> I could see it working in the case that everything is fully well-typed, but what about if the functions that `f` calls are not known at compile time?
I assume you mean _if the methods called are not known at compile time_, since the functions always are (excluding maybe some kind of `eval` hackery that I'm not sure is actually possible).
Yes, if the method called is not known at compile time, the world age needs to be (recursively) checked for all possibilities. That set can be narrowed somewhat by things like argument counting (which doesn't work for splatting) and by knowing some of the argument types, even if not all of them.
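Narrowing the candidate set with what is known at a call site can be sketched with `methods(f, types)`, which returns only the methods applicable to the given argument types. The arity filter reads the `nargs` and `isva` fields of `Method` (internal fields, where `nargs` also counts the function itself); `area` is a made-up example function.

```julia
area(r::Real) = π * r^2
area(w::Real, h::Real) = w * h
area(s::AbstractString) = length(s)    # unrelated method, for contrast

all_methods = collect(methods(area))   # 3: everything we would check blindly

# Only the argument types are known, e.g. a call `area(::Int)`.
by_types = collect(methods(area, (Int,)))   # 1: just area(::Real)

# Only the argument count is known (2 arguments); vararg methods are kept conservatively.
by_arity = filter(m -> m.nargs - 1 == 2 || m.isva, all_methods)   # 1: area(::Real, ::Real)
```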
The code in #22 (comment) always checks all methods; it does none of the narrowing just mentioned, and so is quite slow.
But yes, the other way would be to call `refresh` at the boundary; I think that makes sense.
There is a timing trade-off between how long it takes to just do a `refresh` that is not required, vs. how long it takes to recurse through the AST of all methods to show that it is not required.
A heuristic to make this much faster would be a maximum nesting depth for how deep to check for changes in called methods. Only checking the world age of the function being called is really fast (no recursion, as in `min_world(f) = maximum(mm.min_world for mm in methods(f).ms)`).
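The depth-limit heuristic can be sketched as follows. This is a hypothetical variant, not code from the thread: instead of comparing world ages, it collects the `Method` objects reachable within `depth` calls (a redefinition replaces the object, so two snapshots differ when something in reach was redefined or gained a method). It assumes only functions owned by one module (`Main` by default) are worth following, on the assumption that Base and Core are not redefined during a session, and it uses the unexported internal `Base.uncompressed_ast`, as the recursive code above does.

```julia
# Hypothetical helpers: gather every GlobalRef in a lowered statement.
collect_refs!(acc, x) = acc
collect_refs!(acc, gr::GlobalRef) = push!(acc, gr)
function collect_refs!(acc, ex::Expr)
    foreach(a -> collect_refs!(acc, a), ex.args)
    return acc
end

# Methods of `f`, plus those of `mod`-owned functions it references,
# followed at most `depth` calls deep.
function reachable_methods(f; depth::Int = 2, mod::Module = Main,
                           acc::Set{Method} = Set{Method}())
    for m in methods(f)
        m in acc && continue          # already visited (also stops recursion cycles)
        push!(acc, m)
        depth > 0 || continue         # budget spent: don't look inside this method
        ci = try
            Base.uncompressed_ast(m)  # lowered code (unexported internal)
        catch err
            err isa ErrorException || rethrow()
            continue                  # e.g. generated methods have no plain body
        end
        refs = GlobalRef[]
        foreach(stmt -> collect_refs!(refs, stmt), ci.code)
        for gr in refs
            isdefined(gr.mod, gr.name) || continue
            callee = getfield(gr.mod, gr.name)
            (callee isa Function && parentmodule(callee) === mod) || continue
            reachable_methods(callee; depth = depth - 1, mod = mod, acc = acc)
        end
    end
    return acc
end

# Usage: redefining a callee within reach changes the snapshot.
g3(x) = 2x
h3(x) = 2g3(x)
before = reachable_methods(h3)
g3(x) = 10x
after = reachable_methods(h3)
before != after   # true: g3's Method was replaced
```

With `depth = 0` this degenerates to the cheap, non-recursive check on `f`'s own methods.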
> I assume you mean _if the methods called are not known at compile time_. Since the functions always are ...
Slightly contrived, but for example:
```julia
function foo(fs)
    f = pop!(fs)
    f(1, 2)
end
foo(Any[+])
```
Ah, yes, I see what you mean. If we can't work out the type of `f`, there is nothing we can do at compile time.
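To see why static inspection cannot help here, one can scan `foo`'s lowered code (via `code_lowered`) for the global names it references: `pop!` shows up, but `+` never does, because it only arrives at run time inside `fs`. The helper `collect_globals!` is a hypothetical name, and the exact shape of lowered code varies between Julia versions, so treat this as a sketch.

```julia
function foo(fs)
    f = pop!(fs)
    f(1, 2)
end

# Hypothetical helper: record the name of every GlobalRef in a lowered statement.
collect_globals!(acc, x) = acc
collect_globals!(acc, gr::GlobalRef) = push!(acc, gr.name)
function collect_globals!(acc, ex::Expr)
    foreach(a -> collect_globals!(acc, a), ex.args)
    return acc
end

ci = only(code_lowered(foo, (Vector{Any},)))
seen_names = Symbol[]
foreach(stmt -> collect_globals!(seen_names, stmt), ci.code)

:pop! in seen_names   # true: the call to pop! is visible statically
:+ in seen_names      # false: `+` is only known once `foo` runs
```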
Would it make sense to call the generators with `Core._apply_pure()` as is done in https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L116? You don't need to invalidate later like StagedFunctions.jl does, but it would at least free things up so that, if you define the relevant rules before calling the generator, you wouldn't need to `refresh()`. This is the approach I'm taking with willow-ahrens/Finch.jl#176, and it seems to work okay for me so far.
There are some changes to world age handling being discussed in FluxML/IRTools.jl#109, but I lack the know-how to say how they fit into your proposal.
Related Issues (20)
- `sort(x; rev=true)` is not supported
- Incorrect gradients for `plan_rfft(x) * x`
- Gradient of scalar function of gradient giving mutating array error
- `sum` with CUDA and view on array errors
- Cannot take gradient of sort on 2D CuArray
- `Ref` and broadcasting issue
- Strip zygote frames from mutation error stack trace
- Zygote in Julia 1.10+ not reading rrules for default constructors
- Few issues in the Zygote Home Page documentation
- OOM when computing the gradient in an embarrassingly parallel problem
- Pullback over jacobian
- `withjacobian` flattens the output when it is a matrix
- Gradient wrt to a sparse matrix is mathematically wrong
- Increasing memory usage in each call of gradient
- Precompilation error in Julia nightly
- Gradient involving `LinearAlgebra.tr` errors
- Device-to-host copies with GPU code
- Inferring `Any` on gradient w.r.t. wrapper of recursive type
- Manually changing the Flux parameters and optimizing using Zygote
- Support try/catch by assuming try branch