Comments (13)

MikeInnes commented on August 24, 2024

Yes, this is mentioned in the readme; you can call Zygote.refresh() to avoid it.

from zygote.jl.

oxinabox commented on August 24, 2024

Ah, I missed that.
Is it worth leaving this open to track it?

I feel like hooking into the julia #265 stuff,
should be doable?

jrevels commented on August 24, 2024

> I feel like hooking into the julia #265 stuff,
> should be doable?

ref JuliaLang/julia#27073

IIRC we hit a roadblock with making the world-age mechanism hookable when we realized that the compiler would need to backprop the new world age bounds in a way it didn't have to before. So it's going to take some redesign of the world-age mechanism to make that feasible.

Also note that there are still some compiler bugs that make it quite possible to trigger miscompiles that manifest as 265-style issues (ref JuliaLabs/Cassette.jl#6, JuliaLang/julia#28595). It's possible that fixing those bugs will fix some of the more blatant 265-style problems without requiring any world age redesign.

oxinabox commented on August 24, 2024

How about this as a workaround?

We don't want to call refresh() when it is not going to change the answer, since it triggers an expensive recompile.
We do want to call it when it would change the answer.

So we track the world ages of the functions
that we have operated on.

using Zygote

# Newest world age among the methods of `f` (non-recursive).
min_world(f) = maximum(mm.min_world for mm in methods(f).ms)

# World age recorded for each function the last time it was checked.
const _tracked_worldages = Dict{Function,Int}()

"""
Refreshes Zygote if that is required for it to work correctly with `f`.
This avoids expensive recompilation when it is not required.
Returns true if a refresh was done.
"""
function Zygote.refresh(f)
	cur_world = min_world(f)
	old_world = get!(_tracked_worldages, f, cur_world)
	if cur_world > old_world
		_tracked_worldages[f] = cur_world
		Zygote.refresh() # trigger refresh
		true
	else
		false
	end
end

Then inside each of the calls to functions affected by this (just gradients(f)?)
insert a call to refresh(f).

Demonstration

As a proof of concept,
I am doing this only for my own function:

function der(f::Function)
	function (x)
		Zygote.refresh(f)
		derivative(f, x)
	end
end

demo

julia> h(x)  = 2x
h (generic function with 1 method)

julia> dh = der(h)
#7 (generic function with 1 method)

julia> dh(1)
2

julia> dh(1)
2

julia> h(x) = 10x
h (generic function with 1 method)

julia> dh(1)
2

julia> g(x) = 2x
g (generic function with 1 method)

julia> h(x) = 2g(x)
h (generic function with 1 method)

julia> dh(1)
4

julia> g(x) = 10x
g (generic function with 1 method)

julia> dh(1)
20

oxinabox commented on August 24, 2024

Hmmm, that is not quite as reliable as it seemed.
I think it will always catch redefinitions of h,
but it will only sometimes catch redefinitions of g.

oxinabox commented on August 24, 2024

The following is more reliable, but slower.
It recursively searches the code, looking for functions being called.

I think there is a way to speed it up,
especially if we only need to worry about functions that are calling functions we have called gradient on before.
Then we can short-circuit in i_min_world(ff::Function).
Also, for at least the outer call we have type information, so we don't need to check all methods.

function min_world(f)
	# Functions already seen, so that recursive call graphs terminate
	visited = Set{Function}()

	########################################
	# Inner Dispatches
	function i_min_world(ff::Function)
		ff in visited && return 0
		# we are maxing over all mins, so returning 0 is fine: we'll have the true value already in
		push!(visited, ff)

		meths = methods(ff).ms
		isempty(meths) && return 0
		maximum(i_min_world.(meths))
	end

	function i_min_world(mm::Method)
		if isdefined(mm, :generator)
			# I don't know how to deal with generated functions
			# Can't search all methods since there are an infinite number of them
			mm.min_world
		else
			# Newest of the method itself and everything its body calls
			max(mm.min_world, i_min_world(Base.uncompressed_ast(mm)))
		end
	end

	i_min_world(u_ast::Core.CodeInfo) = maximum(i_min_world.(u_ast.code))

	# Only :call expressions are searched; every other expression contributes 0
	i_min_world(expr::Expr) = expr.head == :call ? maximum(i_min_world.(expr.args)) : 0

	function i_min_world(gr::GlobalRef)
		func = try
			eval(gr)
			# This will occasionally error,
			# but never for the case we care about AFAIK
			# which is when the result is a function
		catch err
			err isa UndefVarError || rethrow()
			return 0
		end

		i_min_world(func)
	end

	# Literals, slots and anything else that is not a call
	i_min_world(x::Any) = 0

	################################################
	i_min_world(f)
end

oxinabox commented on August 24, 2024

Any thoughts on the code I posted above?
I think it can be made sufficiently smart to solve this without a large overhead.

MikeInnes commented on August 24, 2024

So the idea is that we call refresh() at the entry point of gradient? And then that can call eval/invokelatest on _forward to get the latest definition.

It seems hard to do this in general. I could see it working in the case that everything is fully well-typed, but what about if the functions that f calls are not known at compile time? We either make that significantly more expensive (call refresh again at the boundary) or just push this issue into more complex code (which might actually make it more surprising when it comes up). I could see something like this working, but it would be a relatively invasive change.
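
For readers unfamiliar with why invokelatest comes up here, the following is a minimal sketch of the underlying world-age behaviour, independent of Zygote. The names `answer` and `redefine_and_call` are made up for illustration; only `@eval` and `Base.invokelatest` are existing Julia features.

```julia
# Hypothetical function used only for this illustration.
answer() = 1

function redefine_and_call()
    # Replace the method while this function is already running.
    @eval answer() = 2
    # A direct call still runs in the world age this function was entered in,
    # so it sees the old method.
    old = answer()
    # invokelatest looks the method up in the newest world age.
    new = Base.invokelatest(answer)
    return (old, new)
end

result = redefine_and_call()
```

Run once from the top level, `result` should be `(1, 2)`: the direct call is stale, which is the same kind of staleness the refresh discussion is trying to detect cheaply.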

oxinabox commented on August 24, 2024

> So the idea is that we call refresh() at the entry point of gradient?

Yes, we call a version of refresh before the existing gradient, where this version only refreshes if required according to the world ages.

Calculating world ages has to be done recursively, since only functions that were directly changed have the min_world field updated. (I am pretty sure you understand that already, but this is for anyone else reading this in the future. Like future-me.)

> I could see it working in the case that everything is fully well-typed, but what about if the functions that f calls are not known at compile time?

I assume you mean if the methods called are not known at compile time,
since the functions always are (excluding maybe some kind of eval hackery that I'm not sure is actually possible).

Yes, if the method called is not known at compile time, the world age needs to be (recursively) checked for all possibilities. That set can be reduced somewhat by things like argument counting (which doesn't work for splatting) and by knowing some of the argument types, even if not all of them.

The code in #22 (comment)
always checks all methods; it does none of the narrowing just mentioned,
and so is quite slow.

But yes, the other way would be to call refresh at the boundary; I think that makes sense.
There is a timing trade-off between how long it takes to do a refresh that is not required and how long it takes to recurse through the AST of all methods to show that it is not required.

A heuristic to make this much faster would be a maximum nesting depth for how deep to check for changes in called methods.
Checking only the world age of the function being called is really fast.
(No recursion, as in min_world(f) = maximum(mm.min_world for mm in methods(f).ms))
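
The narrowing by argument types mentioned above could look roughly like the sketch below. `world_of`, `shallow_world` and `k` are hypothetical names, not part of Zygote. The sketch assumes that the two-argument form `methods(f, types)` is an acceptable way to select candidate methods, and that Julia versions newer than the one used in this thread store the world age in a field called `primary_world` rather than `min_world`.

```julia
# World age at which a method became active. This thread uses the `min_world`
# field; the sketch assumes newer Julia versions call it `primary_world`.
world_of(m::Method) =
    :primary_world in fieldnames(Method) ? m.primary_world : m.min_world

# Newest world age among the methods of `f` that could match `argtypes`
# (a tuple type such as Tuple{Int}). No recursion into callees.
function shallow_world(f, argtypes)
    ms = collect(methods(f, argtypes))
    isempty(ms) ? zero(UInt) : maximum(world_of, ms)
end

# Hypothetical example function.
k(x::Int) = 2x
int_world = shallow_world(k, Tuple{Int})

# Adding an unrelated method should not change the answer for Int arguments.
k(x::Float64) = 10x
unchanged = shallow_world(k, Tuple{Int}) == int_world
```

The design choice here is to trade completeness for speed: changes in functions called by `k` are not seen at all, so this would only be suitable as the cheap first check before falling back to a full refresh.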

MikeInnes commented on August 24, 2024

> I assume you mean if the methods called are not known at compile time. Since the functions always are ...

Slightly contrived, but for example:

function foo(fs)
  f = pop!(fs)
  f(1, 2)
end

foo(Any[+])

oxinabox commented on August 24, 2024

Ah, yes, I see what you mean.
If we can't work out the type of f there is nothing we can do at compile time.

willow-ahrens commented on August 24, 2024

Would it make sense to call the generators with Core._apply_pure() as is done in https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L116? You don't need to invalidate later like StagedFunctions.jl does, but it would at least free things up so that if you define relevant rules before calling the generator you wouldn't need to refresh(). This is the approach I'm taking with willow-ahrens/Finch.jl#176 and it seems to work okay for me so far.

ToucheSir commented on August 24, 2024

There are some changes to world age handling being discussed in FluxML/IRTools.jl#109, but I lack the know-how to say how they fit into your proposal.
