
Comments (9)

oxinabox commented on August 26, 2024

I think this is something to keep in mind.
Having rewritten the core of ChainRules once this month,
I'd rather not do it again for a few minor versions.
I also want to put out the metaprogramming helper macros (#44)
before we tackle this,
since solving it will likely require using them.

This also relates to the fact that you can't actually pass AbstractDifferentials as input to most pullbacks; they normally need to be externed first anyway.

Interesting idea: if we did go to pullback as a global function, taking a signature, and (Ȳ) and some extra information, we might be able to encode that extra information as a closure that has the default implementation.

Ideally, we would dispatch the rrule itself based on what would be passed to its pullback,
since we might also want to do the forward pass differently and capture different information to give to that pullback.

But that requires global information, some of which cannot be computed in general (it runs into the halting problem).
So we can't do that,
but we should think about how useful the ability to add new pullbacks without adding new rrules is.

The input to the pullback (Ȳ) has to be a very similar type to the output of the forward pass (Y); I think we need to be able to subtract them.
And we haven't really implemented #8 yet, so I don't entirely know what that looks like, or whether, for example, we will end up with NamedTuples that want to be used as Matrices.

This issue is definitely one to think on.

from chainrulescore.jl.

oxinabox commented on August 26, 2024

How about this:
rather than using our pullback directly as the canonical way to propagate the gradient,
we have a function backpropagate
that takes the sig, the pullback, and the Ȳ.
By default it just calls the pullback, but it can be overloaded.

It can solve a few problems:

  1. Externing before calling pullbacks.
  2. Not wasting time calling the pullback if Ȳ is Zero.
  3. Overloading backpropagate for different Ȳ types, and just using the pullback closure as a collection of fields.

We also need a function forward to help track signature types.

function forward(f::F, args...) where {F}
    ret, default_pullback = rrule(f, args...)
    sig = Tuple{F, typeof.(args)...}
    return ret, default_pullback, sig
end 

So this is what it might look like:

backpropagate(sig, default_pullback, Ȳ) = default_pullback(extern(Ȳ))
backpropagate(sig, pullback_info, Ȳ::Zero) = Zero()

function backpropagate(sig::Type{Tuple{typeof(*), Special1, Special2}}, pullback_info, Ȳ)
    # Even though pullback_info is a closure, we never call it,
    # it might as well be a NamedTuple.

    Ā = @thunk(g(Ȳ, pullback_info.A))
    B̄ = h(Ȳ, pullback_info.B)
    return NO_FIELDS, Ā, B̄
end
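
To make the proposal concrete, here is a minimal runnable sketch of the default path only. It deliberately does not load ChainRulesCore: `Zero`, `NO_FIELDS`, `extern` and the scalar `rrule` below are hypothetical stand-ins for the names used above, and the behaviour of `extern` is an assumption.

```julia
# Hypothetical stand-ins for the names used in this thread.
struct Zero end
const NO_FIELDS = nothing
extern(x) = x   # assumption: turns a differential into a plain value

# Toy rule for scalar multiplication: returns the primal and a pullback closure.
function rrule(::typeof(*), a::Real, b::Real)
    times_pullback(Ȳ) = (NO_FIELDS, Ȳ * b, Ȳ * a)
    return a * b, times_pullback
end

# Run the rule and record the call signature as a Tuple type.
function forward(f::F, args...) where {F}
    ret, default_pullback = rrule(f, args...)
    sig = Tuple{F, typeof.(args)...}
    return ret, default_pullback, sig
end

# Default: extern the cotangent, then call the closure.
backpropagate(sig, default_pullback, Ȳ) = default_pullback(extern(Ȳ))
# Shortcut: a Zero cotangent never reaches the pullback.
backpropagate(sig, pullback_info, Ȳ::Zero) = Zero()

y, pb, sig = forward(*, 2.0, 3.0)
grads = backpropagate(sig, pb, 1.0)
```

One thing the sketch leaves out: a method specialised on `sig` with an untyped Ȳ, combined with the `Ȳ::Zero` method above, would probably need an extra method to resolve the ambiguity when both apply.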

Bonus fact, which may or may not apply to storing sig as part of a NamedTuple:
storing it as a Tuple means getting covariant types.

It might kind of be part of replacing accumulate, as I am not sure how that works in the new world. (It isn't broken, I am just unsure how useful it is.)


MikeInnes commented on August 26, 2024

I'd be interested in more specifics of the use case for this and the kind of extensibility you need. My main issue with a separate pullback function is that it simulates closures (i.e. a bundle of data + code) anyway; you're going to end up with something equivalent but much less nice to use.

function forward(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    return A * B, C -> pullback((signature=(typeof(*), typeof(A), typeof(B)), A=A, B=B), C)
end

Of course, writing things out this way isn't that helpful if you have to do it for every rule. But the only real difference is that the pullback has a name, which gives you an interface to overload it. We could get the same effect with something like the following (with appropriate sugar):

P = pullback_name(Tuple{typeof(*),AbstractMatrix,AbstractMatrix})
(::typeof(P))(Y) = ...

However, while this would solve the problem in a general way, I'm sceptical that even this is really needed. Why can't you make the adjoint a named type, rather than a named tuple, and overload * and +? That's all you need to support almost every rule in one go, and it seems unlikely that you'd want to change the actual meaning of the pullback, as opposed to just making linear maps work on your custom type.
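
As a rough illustration of the named-type approach, the sketch below defines an adjoint type and overloads `*` and `+` for it. `PairTangent`, its fields and `scale_add_pullback` are made-up names for this example, not anything from ChainRules.

```julia
# Hypothetical adjoint type; the name and fields are invented for illustration.
struct PairTangent
    a::Float64
    b::Float64
end

# Tangents form a vector space: addition and scaling by a real number.
Base.:+(x::PairTangent, y::PairTangent) = PairTangent(x.a + y.a, x.b + y.b)
Base.:*(s::Real, x::PairTangent) = PairTangent(s * x.a, s * x.b)
Base.:*(x::PairTangent, s::Real) = s * x

# A generic pullback written only in terms of * and +.
# It knows nothing about PairTangent.
scale_add_pullback(Ȳ, α, β) = α * Ȳ + β * Ȳ

t = scale_add_pullback(PairTangent(1.0, 2.0), 2.0, 3.0)
```

The pullback itself is untouched; only the two operators were taught about the new type, which is the point being argued here.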


oxinabox commented on August 26, 2024

> Why can't you make the adjoint a named type, rather than a named tuple, and overload * and +?

That will be the case with #8; I am calling that DNamedTuple.


> as opposed to just making linear maps work on your custom type.

The main case is that, in practice, being a linear map is not enough for some pullbacks.
E.g. some of the LinearAlgebra ones want you to support various BLAS operations
or factorizations.
So conceptually I can imagine that some types that might show up at some point don't support the operations in the default pullbacks but have their own weird ways to do the same thing,
or that for them the same thing is actually much faster if expressed in a different way.


willtebbutt commented on August 26, 2024

> Why can't you make the adjoint a named type, rather than a named tuple, and overload * and +?

This is tricky to do when mixing and matching custom adjoints with automatically derived ones, as automatically derived rules will always produce a NamedTuple.


MikeInnes commented on August 26, 2024

> This is tricky to do when mixing-and-matching custom adjoints with automatically derived ones as automatically derived rules will always produce a NamedTuple.

Sure, but there's a finite number of such adjoints (I think just getproperty – are there others?) and an unlimited number of pullbacks that would otherwise need to have their behaviour overloaded.

Customising your adjoint type is something we can define a clear interface for and it'll work with custom adjoints defined in outside packages, whereas if you manually override each pullback it's only going to work with the set you specifically overloaded.


oxinabox commented on August 26, 2024

I am in favour of waiting and seeing.

We will certainly have a way to convert a NamedTuple to a DNamedTuple as part of #8 anyway.
If we need more, we will deal with that then.

Particularly since, under my plan, the extra info you need will be housed in the default closure propagator anyway.
(I am still a big fan of this notion, since it makes it easy for rule writers to include the relevant information: they use it there and then.)


oxinabox commented on August 26, 2024

Here is an instance of this in the wild, in Zygote:
https://github.com/FluxML/Zygote.jl/blob/9af896e5eb9539adf7161ca3cadf4af9dfce0723/src/lib/array.jl#L388-L393

It special-cases AbstractMatrix and NamedTuple,
but what is to say some future package won't want a similar kind of special treatment?

So I think we should do
#53 (comment)


oxinabox commented on August 26, 2024

For mathematical reasons, it is very rare for an unexpected type to be passed to the pullback.
We thus do not, in general, have an extensibility problem.
The tangent type provided is pretty much determined by the output primal type,
which in turn is determined by the primal input types for type-stable functions.

Thus generally one just adds another rrule.

There is a bit more to this story w.r.t. arrays, but we have more of that story encoded elsewhere, e.g. in ProjectTo.
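
A small sketch of what "just adding another rrule" can look like: dispatch on the primal input type selects the rule, and with it the kind of tangent that comes back. `toy_rrule` is a hypothetical stand-in so the example runs without ChainRulesCore, and both rules are illustrative rather than the library's own.

```julia
using LinearAlgebra

# Generic rule: the cotangent for A is a dense outer product.
function toy_rrule(::typeof(*), A::AbstractMatrix, x::AbstractVector)
    mul_pullback(Ȳ) = (nothing, Ȳ * x', A' * Ȳ)
    return A * x, mul_pullback
end

# Extra rule for a structured primal: keep only the diagonal of the cotangent.
function toy_rrule(::typeof(*), D::Diagonal, x::AbstractVector)
    diag_mul_pullback(Ȳ) = (nothing, Diagonal(Ȳ .* x), D' * Ȳ)
    return D * x, diag_mul_pullback
end

D = Diagonal([2.0, 3.0])
x = [1.0, 4.0]

y_diag, pb_diag = toy_rrule(*, D, x)
_, D̄, x̄ = pb_diag([1.0, 1.0])

y_dense, pb_dense = toy_rrule(*, Matrix(D), x)
_, Ā, _ = pb_dense([1.0, 1.0])
```

Here the pullback was never overloaded on the type of Ȳ; the second method was picked purely from the primal inputs.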

