Comments (3)
I have been running some tests on different batched solvers and decompositions; see below.
Cholesky (and probably LU) seems to be the main issue here. I tried swapping to a QR solver, but it is somehow ~300x slower than Cholesky at large batch sizes, and ~100x slower than SVD. SVD seems to be the most reliable, though I have seen it fail too.
```python
import time

import jax
import jax.numpy as jnp

device = jax.local_devices()[0]
print('on device:', device)

m = 10
for n in [1e5, 1e6, 1e7]:
    n = int(n)
    A = jnp.repeat(jnp.identity(m)[None], n, axis=0).block_until_ready()
    print("n=", n)

    st_time = time.time()
    U, S, Vh = jax.scipy.linalg.svd(A)
    A2 = jax.lax.batch_matmul(U * S[..., None, :], Vh)
    print(f"SVD error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2)) / jnp.linalg.norm(A, axis=(-1, -2)))}, time = {time.time() - st_time} ")

    st_time = time.time()
    L = jax.scipy.linalg.cholesky(A)
    A2 = jax.lax.batch_matmul(L, L.swapaxes(-1, -2)).block_until_ready()
    print(f"Cholesky error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2)) / jnp.linalg.norm(A, axis=(-1, -2)))}, time = {time.time() - st_time} ")

    if n <= 1e6:  # QR is too slow to run at the largest batch size
        st_time = time.time()
        Q, R = jnp.linalg.qr(A)
        A2 = jax.lax.batch_matmul(Q, R).block_until_ready()
        print(f"QR error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2)) / jnp.linalg.norm(A, axis=(-1, -2)))}, time = {time.time() - st_time} ")
```
Output:

```
on device: cuda:0
n= 100000
SVD error 0.0, time = 0.6100842952728271
Cholesky error 0.0, time = 0.15522980690002441
QR error 0.0, time = 3.7535462379455566
n= 1000000
SVD error 0.0, time = 0.5560173988342285
Cholesky error 0.0, time = 0.1310713291168213
QR error 0.0, time = 35.838584184646606
n= 10000000
SVD error 0.0, time = 2.056659460067749
Cholesky error nan, time = 0.27480244636535645
```
Note that JAX doesn't throw an error or warning when the factorization fails; the result is silently filled with NaNs.
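Since the failure is silent, one workaround is to check the factor for NaNs explicitly. A minimal sketch (the `cholesky_or_raise` helper name is mine; the NaN check forces a host sync, so it is only suitable for debugging):

```python
import jax
import jax.numpy as jnp


def cholesky_or_raise(A):
    """Batched Cholesky that raises instead of silently returning NaNs."""
    L = jax.scipy.linalg.cholesky(A, lower=True)
    # A failed factorization shows up as NaNs in the factor.
    if bool(jnp.any(jnp.isnan(L))):
        raise RuntimeError("batched Cholesky produced NaNs")
    return L


# Identity matrices factor cleanly, so this should not raise.
A = jnp.repeat(jnp.identity(3)[None], 4, axis=0)
L = cholesky_or_raise(A)
```

Alternatively, running with `jax.config.update("jax_debug_nans", True)` should make JAX error out as soon as a NaN is produced, at some performance cost.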
Here is an SVD-based solver to use in the meantime, in case someone else needs a stopgap:
```python
def solve_by_SVD(A, b):
    """Batched solve of A x = b via SVD: x = V diag(1/S) U^H b."""
    U, S, Vh = jax.scipy.linalg.svd(A)
    # Promote a batch of vectors to a batch of column matrices.
    if b.ndim == A.ndim - 1:
        expand = True
        b = b[..., None]
    else:
        expand = False
    Uhb = jax.lax.batch_matmul(jnp.conj(U.swapaxes(-1, -2)), b) / S[..., None]
    x = jax.lax.batch_matmul(jnp.conj(Vh.swapaxes(-1, -2)), Uhb)
    if expand:
        x = x[..., 0]
    return x
```
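For reference, a self-contained sanity check of the solver above against `jnp.linalg.solve` on a random SPD batch (a sketch; the batch shapes and the diagonal shift used to make the matrices well-conditioned are just example choices):

```python
import jax
import jax.numpy as jnp


def solve_by_SVD(A, b):
    """Batched solve of A x = b via SVD: x = V diag(1/S) U^H b."""
    U, S, Vh = jax.scipy.linalg.svd(A)
    if b.ndim == A.ndim - 1:
        expand = True
        b = b[..., None]
    else:
        expand = False
    Uhb = jax.lax.batch_matmul(jnp.conj(U.swapaxes(-1, -2)), b) / S[..., None]
    x = jax.lax.batch_matmul(jnp.conj(Vh.swapaxes(-1, -2)), Uhb)
    if expand:
        x = x[..., 0]
    return x


# Build a batch of 8 well-conditioned 5x5 SPD matrices: B B^T + 5 I.
B = jax.random.normal(jax.random.PRNGKey(0), (8, 5, 5))
A = jax.lax.batch_matmul(B, B.swapaxes(-1, -2)) + 5.0 * jnp.identity(5)
b = jax.random.normal(jax.random.PRNGKey(1), (8, 5, 1))

x_svd = solve_by_SVD(A, b)
x_ref = jnp.linalg.solve(A, b)
err = float(jnp.max(jnp.abs(x_svd - x_ref)))  # small in float32
```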
As far as I can tell, the LU/Cholesky bugs are fixed by updating JAX to 0.4.23. See patrick-kidger/lineax#79 (comment) for more info.
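Since the fix reportedly landed by 0.4.23, a quick version guard can flag older installs (a sketch; `version_tuple` is a hypothetical helper, and `packaging.version` is more robust for pre-release suffixes):

```python
import jax


def version_tuple(v):
    """Keep only the leading numeric components: '0.4.23' -> (0, 4, 23)."""
    parts = []
    for p in v.split("."):
        if not p.isdigit():
            break
        parts.append(int(p))
    return tuple(parts)


if version_tuple(jax.__version__) < (0, 4, 23):
    print(f"JAX {jax.__version__} may still hit the batched Cholesky NaN bug; "
          "consider upgrading.")
```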
Well, fixed is fixed, I guess. We could dig into why, but it would mostly be of historical interest. Please reopen if it happens again!