
Comments (5)

ChrisRackauckas commented on July 17, 2024

Hey, that's a pretty cool project. We'd love to help out. I was actually just running the torchsde benchmarks today to see how it turned out and those benchmarks convinced me that we should really make sure to contribute more to the Python community. So I was planning to try and get some things up and running during this year's JuliaCon.

One of the main things we'd like to do is make the installation more automatic. @christopher-dG do you know much about Python build systems? I am wondering if we can somehow get pyjulia vendoring Julia itself, kind of like Conda.jl, so diffeqpy could be a full installation from pip. I was looking to see if PackageCompiler can do a static compilation of the ODE solvers too, but let's ignore that for now and look at the lower-hanging fruit.

@Zymrael do you know much about direct definitions of adjoints for PyTorch? We refactored the Julia side a bit ago in preparation for this combination, so what it looks like is this. solve has a quick step that lowers to solve_up:

https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L98-L103

and then the adjoint is defined as:

https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L262-L266

The internal function _solve_adjoint takes care of the rest of the plumbing in DiffEqSensitivity.jl, doing:

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/concrete_solve.jl#L33-L136

for example as the "standard" adjoint (covering the big 3 we will want: QuadratureAdjoint, InterpolatingAdjoint, and BacksolveAdjoint) with all of the keyword argument and saving handling. So given how that's refactored, I think we could do this in just two function definitions in PyTorch; I just need to figure out that interface.

Now the next difficulty involved is going to be that these adjoints will use Zygote for the vjp calculations by default. The defaults are handled at the top of that file:

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/concrete_solve.jl#L7-L26

Since this is using Zygote or ReverseDiff vjps, this might not directly work at first. So step one I think we'd test this out where the function is eval'd to live in Julia and get that working, and then try to get torch working in the vjps (which has been demonstrated before). That should be all it takes.
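To make the "two function definitions" idea concrete, here is a minimal pure-Python sketch of a forward solve paired with a backward adjoint pass, in the spirit of a backsolve-style adjoint. It is only an illustration under stated assumptions: the toy ODE du/dt = -p*u, the fixed-step RK4 integrator, and the names `rk4`, `solve_forward`, and `solve_backward` are all hypothetical and are not the DiffEqSensitivity.jl or PyTorch interfaces. In a real binding, the two functions would presumably become the forward and backward halves of a custom autograd function.

```python
import math


def rk4(f, state, t0, t1, steps):
    """Fixed-step RK4 for d(state)/dt = f(state, t); t1 < t0 integrates backward."""
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = f(state, t)
        k2 = f([s + 0.5 * h * k for s, k in zip(state, k1)], t + 0.5 * h)
        k3 = f([s + 0.5 * h * k for s, k in zip(state, k2)], t + 0.5 * h)
        k4 = f([s + h * k for s, k in zip(state, k3)], t + h)
        state = [
            s + h / 6.0 * (a + 2.0 * b + 2.0 * c + d)
            for s, a, b, c, d in zip(state, k1, k2, k3, k4)
        ]
        t += h
    return state


def solve_forward(u0, p, T, steps=200):
    """First definition: solve du/dt = -p*u on [0, T] and return u(T)."""
    return rk4(lambda s, t: [-p * s[0]], [u0], 0.0, T, steps)[0]


def solve_backward(uT, p, T, grad_uT, steps=200):
    """Second definition: map dL/du(T) to (dL/du0, dL/dp)."""

    def augmented(s, t):
        u, lam, _ = s
        # With f = -p*u: du/dt = f, dlam/dt = -lam * df/du = p*lam,
        # and dg/dt = -lam * df/dp = lam*u, all integrated from T down to 0.
        return [-p * u, p * lam, lam * u]

    _, lam0, g0 = rk4(augmented, [uT, grad_uT, 0.0], T, 0.0, steps)
    return lam0, g0


# Toy loss L = u(T), so dL/du(T) = 1.
u0, p, T = 2.0, 0.5, 1.0
uT = solve_forward(u0, p, T)
grad_u0, grad_p = solve_backward(uT, p, T, grad_uT=1.0)

# Compare against the closed form u(T) = u0 * exp(-p*T).
print(uT, u0 * math.exp(-p * T))
print(grad_u0, math.exp(-p * T))
print(grad_p, -T * u0 * math.exp(-p * T))
```

The backward pass re-integrates the state alongside the adjoint instead of storing the forward trajectory, which is the trade-off a backsolve-style method makes. The vjp terms (`df/du`, `df/dp`) are written by hand here; in the setup discussed above they would come from Zygote, ReverseDiff, or torch.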

Do you think the overheads would be small enough such that a switch to a Julia diffeq solvers backend would benefit the PyTorch userbase?

I think there's a two-pronged approach we can take. For small problems, problems which are non-heterogeneous, like chemical combustion or quantitative systems pharmacology models, torch JIT doesn't seem to do very well. On these problems, what I want to do is hijack the functional form via ModelingToolkit.jl (https://github.com/SciML/ModelingToolkit.jl) to then directly compile the version in Julia with sparsity and all of that jazz (see my coming JuliaCon talk for more details on this system). This is the aspect I've been working on the most in terms of pyjulia performance.

That would cover a lot of scientific modeling, but I don't think that is necessary for your case, which is more about big heterogeneous matmul neural ODE models. In that case, the overhead should be minimal to non-existent since pyjulia passes to Julia by reference and not by copying, so as long as the two AD systems connect well for the vjp we should be in the asymptotically large matmul case. For GPUs I think we might need to connect to https://github.com/TuringLang/ThArrays.jl to give it the right overloads in Julia, but that shouldn't be too difficult either (pinging @KDr2, who may be interested in helping).

from diffeqpy.

christopher-dG commented on July 17, 2024

While I can't claim to be an expert, I do know some things. I could certainly look into helping out here.


KDr2 commented on July 17, 2024

It's my pleasure to help too if necessary.


Zymrael commented on July 17, 2024

Thank you for the detailed response!

A first huge step would be being able to access the .jl solvers for solve steps within our current API. Before redesigning everything to include additional options / kwargs (e.g. callbacks for event handling), we should verify that the two AD systems can connect well.

We also have an adjoint class in torchdyn (https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/sensitivity/adjoint.py). It shouldn't be too big of a problem for us to refactor the API (or provide another option), modifying it to follow DiffEqBase.jl conventions.
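As a rough picture of what following DiffEqBase.jl conventions could mean on the Python side, the sketch below mirrors the problem/solve split and the out-of-place `f(u, p, t)` signature. The `ODEProblem` dataclass and `solve` function here are hypothetical stand-ins written for illustration (with explicit Euler in place of a real solver); they are not the torchdyn or diffeqpy API.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence, Tuple


@dataclass
class ODEProblem:
    """Hypothetical container mirroring ODEProblem(f, u0, tspan, p)."""
    f: Callable[[Sequence[float], Any, float], Sequence[float]]  # f(u, p, t)
    u0: Sequence[float]
    tspan: Tuple[float, float]
    p: Any = None


def solve(prob, steps=1000):
    """Explicit Euler stand-in for a real solver; returns (ts, us)."""
    t0, t1 = prob.tspan
    h = (t1 - t0) / steps
    t, u = t0, list(prob.u0)
    ts, us = [t], [u]
    for _ in range(steps):
        du = prob.f(u, prob.p, t)
        u = [x + h * d for x, d in zip(u, du)]
        t += h
        ts.append(t)
        us.append(u)
    return ts, us


# The problem definition is kept separate from the choice of solver.
prob = ODEProblem(lambda u, p, t: [-p * u[0]], [1.0], (0.0, 1.0), 0.5)
ts, us = solve(prob)
print(us[-1][0])
```

Keeping the problem object separate from `solve` means extra options such as callbacks could later be added as keyword arguments without changing how models are defined.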


ChrisRackauckas commented on July 17, 2024

The MTK performance discussion can continue in #57, and the discussion about vendoring Julia can move to JuliaPy/pyjulia#118.

