Comments (5)
> how does `.at` work with boolean indices when not compiled?

It converts them to integer indices, then passes those integer indices to `scatter` or `gather`. The conversion happens here:
jax/jax/_src/numpy/lax_numpy.py
Line 6952 in c313cac
You can see how tricky it is to get the semantics of boolean indices correct in numpy-style indexing expressions...
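As a rough illustration of the conversion described above (a sketch of the observable behavior, not the actual code at the linked line), a concrete boolean mask can be turned into integer indices with `jnp.nonzero`, and the integer-indexed update gives the same result as the mask-indexed one when run eagerly. The array values here are made up for the example.

```python
import jax.numpy as jnp

x = jnp.arange(6.0)          # [0., 1., 2., 3., 4., 5.]
mask = x > 2                 # concrete boolean mask (not traced)

# Eager mode: the boolean mask is accepted directly.
y_mask = x.at[mask].set(0.0)

# Roughly what happens underneath: the concrete mask becomes integer
# indices, and those integer indices drive the scatter.
(idx,) = jnp.nonzero(mask)
y_idx = x.at[idx].set(0.0)

print(y_mask)
print(y_idx)
```

The number of entries in `idx` depends on the values in `mask`, which is why this step needs a concrete mask.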
Currently the update op (`at[...].set(...)`) lowers to a scatter, which needs to know statically how many values are being updated.
I agree that it would be nice to jit-compile updates with dynamic indices, but I'm not sure if there's a clean way to do it at the moment. We could conditionally lower the operation to a `select` when the operand is a scalar, which would allow it to be jit-compiled as described. This might not be very performant, however, if we potentially have to materialize a mask over a large array.
Tagging @jakevdp for any additional opinions.
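For the scalar case mentioned above, a user-side version of the select-based approach might look like the following minimal sketch. `masked_fill` is a hypothetical helper name, not a JAX API; it uses `jnp.where`, whose output shape is fixed by its inputs, so it traces under `jit`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_fill(x, mask, value):
    # Output has the same shape as x regardless of how many mask
    # entries are True, so no dynamic shape is involved.
    return jnp.where(mask, value, x)

x = jnp.arange(5)
mask = x % 2 == 0
print(masked_fill(x, mask, -1))
```

Note that the mask here is a full-size array the same shape as `x`, which is the materialization cost raised in the comment above.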
I've thought about this on several occasions, but the added flexibility has never seemed worth the added complexity in the code.
As @justinjfu mentioned, `arr.at[...].*` is currently a way of spelling `lax.gather` and `lax.scatter`. These APIs do not support dynamic shapes, and so boolean indices cannot be used within JIT.
It's important to note that `arr.at[mask].set(arr2)` is inherently a dynamically-shaped operation, because semantically `arr2` must be broadcast-compatible with the entries of `arr` selected by `mask`. It's true that there is a special case when `arr2` has size 1, because a size-1 axis can broadcast to any array shape, but the shape semantics are still dynamic.
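To make the dynamic-shape point concrete, here is a small sketch (example values are made up) showing that the shape of the entries selected by a mask depends on the mask's values, and that boolean-mask indexing of a traced array raises `jax.errors.NonConcreteBooleanIndexError` under `jit`.

```python
import jax
import jax.numpy as jnp

arr = jnp.arange(5.0)
mask_a = jnp.array([True, False, True, False, False])
mask_b = jnp.array([True, True, True, True, False])

# Same input shapes, different output shapes: the result depends on
# the *values* in the mask, not just its shape.
shape_a = arr[mask_a].shape   # (2,)
shape_b = arr[mask_b].shape   # (4,)

@jax.jit
def select_entries(a, m):
    return a[m]  # output shape unknown at trace time

try:
    select_entries(arr, mask_a)
    raised = False
except jax.errors.NonConcreteBooleanIndexError:
    raised = True

print(shape_a, shape_b, raised)
```

An `arr2` passed to `arr.at[mask].set(arr2)` would have to broadcast against `shape_a` or `shape_b`, which is only known once the mask values are.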
In the simplest case of a 1D array with a 1D mask, you could imagine forking the logic and lowering to `select` rather than `scatter`; but immediately someone would ask about `arr.at[0, mask].set(arr2)` or `arr.at[indices, mask].set(arr2)` or `arr.at[..., mask1, :, 0, mask2].set(arr2)`, or all the other variants of combined indices that it may be possible to lower to `select` (depending on the shape of `arr2`). To get all of that logic correct would be tricky, and would add a lot of complexity both to the implementation, and to the mental overhead of the user trying to figure out when an operation may or may not be JIT-compatible.
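As one illustration of why the combined-index variants add complexity, here is a sketch of what a hand-written `select` version of `arr.at[0, mask].set(scalar)` could look like for a 2D array. `set_row0_where` is a hypothetical helper, and the full-size mask construction is specific to this one index pattern; other patterns would each need their own.

```python
import jax
import jax.numpy as jnp

arr = jnp.arange(12).reshape(3, 4)
mask = jnp.array([True, False, True, False])   # mask over axis 1

# Eager reference: integer index on axis 0, boolean mask on axis 1.
expected = arr.at[0, mask].set(-1)

@jax.jit
def set_row0_where(a, m, value):
    # Build a full-size mask that is True only in row 0 where m is True.
    row_is_0 = (jnp.arange(a.shape[0]) == 0)[:, None]   # shape (3, 1)
    full_mask = row_is_0 & m[None, :]                   # shape (3, 4)
    return jnp.where(full_mask, value, a)

result = set_row0_where(arr, mask, -1)
print(result)
```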
By contrast, the current rule is very simple: boolean indices are not supported by JIT, and if you want to use `select` instead of `gather`/`scatter`, you can do so explicitly. In this case, I think simpler is better.
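Using `select` explicitly with an array-valued update might look like the sketch below. `masked_update` is a hypothetical name, and it assumes the update array has the same shape and dtype as `arr`, so values are aligned by position rather than packed in mask order as they would be for `arr.at[mask].set(arr2)`.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def masked_update(arr, mask, updates):
    # lax.select takes `updates` where mask is True and `arr` elsewhere;
    # all three operands have the same static shape.
    return lax.select(mask, updates, arr)

arr = jnp.zeros(4)
mask = jnp.array([True, False, True, False])
updates = jnp.array([10.0, 20.0, 30.0, 40.0])

out = masked_update(arr, mask, updates)
print(out)
```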
Thanks for the answers. Just one more question - how does `.at` work with boolean indices when not compiled?
ooh I see. Well, would still be nice to have at some point...