Thanks.
I am most concerned about wrapping external libraries that are already implemented on other computing backends. For example, we use a set of such operators for particle-mesh simulations and FFTs.
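A minimal sketch of the wrapping pattern, assuming a recent JAX version that provides `jax.custom_vjp` and `jax.pure_callback` (neither is mentioned in this thread). NumPy's `cumsum` stands in for the external library. Both the forward routine and its hand-written adjoint are called through `jax.pure_callback`, and `custom_vjp` tells reverse-mode autodiff to use the adjoint instead of trying to differentiate the callback. The names `ext_cumsum`, `_external_cumsum`, and `_external_cumsum_vjp` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def _external_cumsum(x):
    # Stand-in for an operator from a non-JAX library; it only sees NumPy arrays.
    x = np.asarray(x)
    return np.cumsum(x, dtype=x.dtype)


def _external_cumsum_vjp(g):
    # Hand-written adjoint of cumsum: a reversed cumulative sum of the cotangent.
    g = np.asarray(g)
    return np.ascontiguousarray(np.cumsum(g[::-1], dtype=g.dtype)[::-1])


@jax.custom_vjp
def ext_cumsum(x):
    # Call out to the external routine; JAX only needs the output shape/dtype.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_external_cumsum, out_spec, x)


def ext_cumsum_fwd(x):
    # cumsum is linear, so its adjoint needs no residuals from the forward pass.
    return ext_cumsum(x), ()


def ext_cumsum_bwd(_residuals, g):
    out_spec = jax.ShapeDtypeStruct(g.shape, g.dtype)
    # Return one cotangent per positional argument of ext_cumsum.
    return (jax.pure_callback(_external_cumsum_vjp, out_spec, g),)


ext_cumsum.defvjp(ext_cumsum_fwd, ext_cumsum_bwd)

x = jnp.arange(1.0, 5.0)
w = jnp.array([1.0, -2.0, 3.0, 0.5])
grad_ext = jax.grad(lambda v: jnp.sum(w * ext_cumsum(v)))(x)
```

JAX has no derivative rule for the callback used in the backward pass, so higher-order derivatives that need to differentiate through `_external_cumsum_vjp` are not available with this pattern.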
When it comes to custom primitives, we also need to tell the autodiff machinery which vector-space operations (addition, scaling, inner products, and so on) to apply to the operands. How is this currently done in JAX?
In autograd, I believe this was done with a vector-space object that knows how to serialize any operand into a NumPy array, after which NumPy functions are used for inner products and similar operations. This is not always desirable: if the data has to be partitioned across several MPI ranks, serializing it onto a single rank will not even fit into memory. We weren't able to use autograd for this reason.
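For comparison, JAX does not concatenate operands into one flat vector: tangents and cotangents have the same pytree structure as the primal arguments, and each leaf stays a separate array. A sketch using `jax.custom_vjp` on a dict of arrays; the function `coupling_energy` and the field names `density` and `potential` are hypothetical, and whether this is sufficient for data partitioned across MPI ranks is not something this example addresses.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def coupling_energy(fields):
    # fields is a pytree (here a dict) of arrays; nothing is concatenated.
    return jnp.sum(fields["density"] * fields["potential"])


def coupling_energy_fwd(fields):
    # Save the inputs as residuals for the backward pass.
    return coupling_energy(fields), fields


def coupling_energy_bwd(fields, g):
    # The cotangent for `fields` must have the same dict structure as the input.
    grad_fields = {
        "density": g * fields["potential"],
        "potential": g * fields["density"],
    }
    return (grad_fields,)


coupling_energy.defvjp(coupling_energy_fwd, coupling_energy_bwd)

fields = {
    "density": jnp.array([1.0, 2.0, 3.0]),
    "potential": jnp.array([0.5, -1.0, 2.0]),
}
# grads is a dict with the same keys and leaf shapes as `fields`.
grads = jax.grad(coupling_energy)(fields)
```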
Another concern is whether these custom primitives support higher-order differentiation. If the VJP function itself has to be an external routine (rather than a composition of primitives), then higher-order differentiation and automatic JVPs are presumably both broken? Is this a supported case?
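A sketch of how this looks in JAX, assuming a version that provides `jax.custom_jvp`: when the derivative rule is itself written in JAX operations, JAX can differentiate the rule, so reverse mode and higher-order derivatives both follow from a single forward-mode rule. An opaque external routine inside the rule would not get this for free. Softplus is only an illustrative choice of function.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def softplus(x):
    return jnp.logaddexp(x, 0.0)


@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = softplus(x)
    # d/dx softplus(x) = sigmoid(x). The rule uses JAX ops only,
    # so JAX can differentiate the rule itself.
    return y, jax.nn.sigmoid(x) * t


x = 0.0
first = jax.grad(softplus)(x)                   # reverse mode derived from the JVP rule
second = jax.grad(jax.grad(softplus))(x)        # reverse-over-reverse
second_fwd = jax.jacfwd(jax.grad(softplus))(x)  # forward-over-reverse
```

The design trade-off: a `custom_jvp` rule lets JAX derive both forward and reverse mode, whereas a function defined with `jax.custom_vjp` cannot be used with forward-mode differentiation such as `jax.jvp`.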
Not sure if this is the right place, but: how can I define custom vmap rules for a primitive? In my case I am calling an external function that already supports batches, and I want to vmap the code surrounding the call to this external primitive.
@jonasrauber that question is from a while ago, but the short answer is that `custom_transforms` (as in `from jax import custom_transforms`) is for doing this. To be improved and documented...
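Later JAX versions added `jax.custom_batching.custom_vmap` for exactly this case. A minimal sketch, assuming a JAX version that ships it; `external_batched_norm` is a hypothetical stand-in (written in `jax.numpy` so the example runs) for an external routine that already accepts a leading batch axis. The vmap rule receives the batch size, a per-argument flag saying which inputs are batched (with the batch axis moved to the front), and the arguments; it returns the output together with a flag saying whether the output is batched.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


def external_batched_norm(xs):
    # Hypothetical stand-in for an external routine that already handles a
    # leading batch axis: row-wise L2 norm of a (batch, n) array.
    return jnp.sqrt(jnp.sum(xs * xs, axis=-1))


@custom_vmap
def norm(x):
    # Unbatched semantics: x has shape (n,); add and strip a batch axis of size 1.
    return external_batched_norm(x[None, :])[0]


@norm.def_vmap
def norm_vmap_rule(axis_size, in_batched, x):
    (x_batched,) = in_batched
    if not x_batched:
        # Broadcast an unbatched operand so the routine sees a batch axis.
        x = jnp.broadcast_to(x, (axis_size,) + x.shape)
    # Call the batched routine once for the whole batch; the output is batched.
    return external_batched_norm(x), True


xs = jnp.array([[3.0, 4.0], [6.0, 8.0], [0.0, 1.0]])
out = jax.vmap(norm)(xs)
```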
Just sketched out a custom VJP API last night: https://gist.github.com/mattjj/2ba580930472e8e04c1759737268af92
The example there is trivial, and there's a bit more bookkeeping to be done to handle general code. But our initial thinking is that we can have a `defvjp` and a `defvjp_all`, to be used with `@custom_transforms`, where the former lets you specify a VJP function for each positional argument and the latter lets you specify a VJP for all arguments at once. (Maybe we can also provide a `defvjp_all_staged` if you want to compute some reduced residual information on the forward pass, rather than saving all the argument values.)
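Current JAX provides `jax.custom_vjp`, whose `defvjp(fwd, bwd)` has a different shape from the names proposed above: the backward function returns cotangents for all positional arguments at once (similar in spirit to `defvjp_all`), and the forward function chooses which residuals to save (similar in spirit to `defvjp_all_staged`). A sketch, assuming that API; the function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x, y):
    return x * jnp.sin(y)


def f_fwd(x, y):
    # Save precomputed quantities the backward pass needs, not the raw inputs.
    sin_y = jnp.sin(y)
    x_cos_y = x * jnp.cos(y)
    return x * sin_y, (sin_y, x_cos_y)


def f_bwd(res, g):
    sin_y, x_cos_y = res
    # One cotangent per positional argument, returned together.
    return (g * sin_y, g * x_cos_y)


f.defvjp(f_fwd, f_bwd)

x, y = 2.0, 0.7
dx, dy = jax.grad(f, argnums=(0, 1))(x, y)
```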
The funny bookkeeping in that gist is due to the fact that in JAX we usually don't specify VJPs (reverse-mode rules) directly, and instead only specify forward-mode rules; JAX generates reverse-mode autodiff through a composition of forward-mode, partial evaluation, and transposition transformations. But if you want to specify a VJP rule directly, that gist shows a trick to do it.
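The composition described here can be reproduced with public JAX transformations, assuming `jax.linearize` and `jax.linear_transpose` (not mentioned in the thread): `linearize` performs forward mode plus partial evaluation and returns the JVP as a linear function of the tangent, and `linear_transpose` transposes that linear function into a VJP. This illustrates the idea; it is not JAX's internal code path.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * x


x = jnp.array([0.5, 1.3, -2.0])

# Forward mode + partial evaluation: y = f(x) and the linear map t -> J t.
y, f_jvp = jax.linearize(f, x)

# Transposition: turn t -> J t into the VJP map c -> J^T c.
f_vjp = jax.linear_transpose(f_jvp, x)
(ct,) = f_vjp(jnp.ones_like(y))

# Reference VJP computed directly by JAX.
_, vjp_ref = jax.vjp(f, x)
(ct_ref,) = vjp_ref(jnp.ones_like(y))
```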
That was all a work-in-progress. We've got something better now!
Whoops, I didn't mean to close this in #818!