Comments (5)
Hey, that's a pretty cool project. We'd love to help out. I was actually just running the torchsde benchmarks today to see how they turned out, and those benchmarks convinced me that we should really make sure to contribute more to the Python community. So I was planning to try to get some things up and running during this year's JuliaCon.
One of the main things we'd like to do is make the installation more automatic. @christopher-dG do you know much about Python build systems? I am wondering if we can somehow get pyjulia vendoring Julia itself, kind of like Conda.jl, so diffeqpy could be a full installation from pip. I was looking to see if PackageCompiler can do a static compilation of the ODE solvers too, but let's ignore that for now and look at the lower-hanging fruit.
@Zymrael do you know much about direct definitions of adjoints for PyTorch? We refactored the Julia side a while ago in preparation for this combination, so what it looks like is this: `solve` has a quick step that lowers to `solve_up`:

https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L98-L103

and then the adjoint is defined as:

https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L262-L266

The internal function `_solve_adjoint` takes care of the rest of the plumbing in DiffEqSensitivity.jl, doing, for example, the "standard" adjoint (covering the big 3 we will want: `QuadratureAdjoint`, `InterpolatingAdjoint`, and `BacksolveAdjoint`) with all of the keyword argument and saving handling. So given how that's refactored, I think we could do this in just two function definitions in PyTorch; I just need to figure out that interface.
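To make the "two function definitions" idea concrete, here is a hedged sketch of what the PyTorch side could look like: a `torch.autograd.Function` whose forward would call into the Julia `solve` and whose backward would call into `_solve_adjoint`. The names `JuliaSolve` and `euler_solve` are hypothetical, and a toy Euler loop plus analytic gradients stand in for the actual cross-language calls:

```python
import torch

def euler_solve(y0, p, T, n=1000):
    # Stand-in for the Julia solver, integrating dy/dt = -p*y with explicit Euler.
    # In the real version this call would cross into Julia via pyjulia.
    dt = T / n
    y = y0
    for _ in range(n):
        y = y + dt * (-p * y)
    return y

class JuliaSolve(torch.autograd.Function):
    # Hypothetical wrapper: the "two function definitions" interface.
    @staticmethod
    def forward(ctx, y0, p):
        T = 1.0
        with torch.no_grad():
            yT = euler_solve(y0, p, T)  # would be DiffEqBase.solve
        ctx.save_for_backward(y0, p, yT)
        return yT

    @staticmethod
    def backward(ctx, grad_out):
        # Stand-in for _solve_adjoint: analytic vjps for y(T) = y0 * exp(-p*T).
        y0, p, yT = ctx.saved_tensors
        T = 1.0
        grad_y0 = grad_out * torch.exp(-p * T)
        grad_p = grad_out * (-T) * yT
        return grad_y0, grad_p

# Usage: gradients of y(T) with respect to the initial condition and parameter.
y0 = torch.tensor(2.0, requires_grad=True)
p = torch.tensor(0.5, requires_grad=True)
loss = JuliaSolve.apply(y0, p)
loss.backward()
```

In the real version, forward and backward would each hand off to Julia, which is why only these two definitions should be needed on the PyTorch side.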
Now the next difficulty is going to be that these adjoints will use Zygote for the vjp calculations by default. The defaults are handled at the top of that file. Since this is using Zygote or ReverseDiff vjps, it might not directly work at first. So for step one I think we'd test this out with the function eval'd to live in Julia and get that working, and then try to get torch working in the vjps (which has been demonstrated before). That should be all it takes.
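For intuition about what the adjoint pass has to compute regardless of which vjp backend (Zygote, ReverseDiff, or torch) ends up supplying the vector-Jacobian products, here is a small self-contained NumPy/SciPy sketch of a backsolve-style adjoint for the scalar ODE dy/dt = -p*y. The names `grad_backsolve` and `aug` are made up for illustration; the real `BacksolveAdjoint` handles vector states, checkpointing, and pluggable vjps:

```python
import numpy as np
from scipy.integrate import solve_ivp

def f(y, p):
    return -p * y

def grad_backsolve(y0, p, T):
    # Forward pass: solve the ODE to get y(T).
    fwd = solve_ivp(lambda t, y: f(y[0], p), (0.0, T), [y0],
                    rtol=1e-10, atol=1e-12)
    yT = fwd.y[0, -1]

    # Backward pass: integrate the augmented state z = [y, a, g] from T to 0:
    #   dy/dt = f(y, p)
    #   da/dt = -a * df/dy = a * p        (adjoint variable)
    #   dg/dt = -a * df/dp = a * y        (parameter-gradient accumulator)
    def aug(t, z):
        y, a, g = z
        return [f(y, p), a * p, a * y]

    # a(T) = dL/dy(T) = 1 for the loss L = y(T); g(T) = 0, g(0) = dL/dp.
    bwd = solve_ivp(aug, (T, 0.0), [yT, 1.0, 0.0], rtol=1e-10, atol=1e-12)
    return yT, bwd.y[2, -1]

# For y(t) = y0*exp(-p*t), the analytic gradient is dL/dp = -T*y0*exp(-p*T).
yT, grad_p = grad_backsolve(2.0, 0.5, 1.0)
```

The `-a * df/dy` and `-a * df/dp` products are exactly the vjps the comment above is talking about: swapping Zygote for torch only changes how those two terms are evaluated, not the structure of the backward solve.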
Do you think the overheads would be small enough that a switch to a Julia diffeq solver backend would benefit the PyTorch userbase?
I think there's a two-pronged approach we can take. For small problems and problems which are non-heterogeneous, like chemical combustion or quantitative systems pharmacology models, the torch JIT doesn't seem to do very well. On these problems, what I want to do is hijack the functional form via ModelingToolkit.jl (https://github.com/SciML/ModelingToolkit.jl) to then directly compile a version in Julia with sparsity and all of that jazz (see my upcoming JuliaCon talk for more details on this system). This is the aspect I've been working on the most in terms of pyjulia performance.
That would cover a lot of scientific modeling, but I don't think that is necessary for your case, which is more big heterogeneous matmul neural ODE models. In that case, the overhead should be minimal to non-existent, since pyjulia passes to Julia by reference and not by copying, and so as long as the two AD systems connect well for the vjp we should be in the asymptotically large matmul case. For GPUs I think we might need to connect to https://github.com/TuringLang/ThArrays.jl to give it the right overloads in Julia, but that shouldn't be too difficult either (pinging @KDr2 who may be interested in helping).
from diffeqpy.
While I can't claim to be an expert, I do know some things. I could certainly look into helping out here.
I'd be happy to help too if needed.
Thank you for the detailed response!
A first huge step would be being able to access the `.jl` solvers for `solve` steps within our current API. Before redesigning everything to include additional options / kwargs (e.g. Callbacks for event handling), we should verify that the two AD systems can connect well.

We also have an adjoint class in torchdyn (https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/sensitivity/adjoint.py). It shouldn't be too big of a problem for us to refactor the API (or provide another option), modifying it to follow `DiffEqBase.jl` conventions.
MTK performance discussion can continue in #57, and the vendoring-of-Julia discussion in JuliaPy/pyjulia#118.