
Comments (3)

ma-gilles commented on June 18, 2024

I have been running some tests on different batched solvers and decompositions; see below.

Cholesky (and probably LU) seems to be the main issue here. I tried swapping to a QR solver, but for large batch sizes it is somehow about 300 times slower than Cholesky and ~100 times slower than SVD. SVD seems to be the most reliable, though I have seen it fail too.

import time

import jax
import jax.numpy as jnp

device = jax.local_devices()[0]
print('on device:', device)

m = 10  # each matrix is m x m

for n in [1e5, 1e6, 1e7]:
    n = int(n)
    # Batch of n identity matrices, shape (n, m, m)
    A = jnp.repeat(jnp.identity(m)[None], n, axis=0).block_until_ready()
    print("n=", n)

    # SVD: rebuild A as U @ diag(S) @ Vh and report the mean relative error
    st_time = time.time()
    U, S, Vh = jax.scipy.linalg.svd(A)
    A2 = jax.lax.batch_matmul(U * S[..., None, :], Vh).block_until_ready()
    print(f"SVD error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2))/jnp.linalg.norm(A, axis=(-1, -2)))}, time = {time.time() - st_time} ")

    # Cholesky: lower=True so that A = L @ L.T (the default returns the upper factor)
    st_time = time.time()
    L = jax.scipy.linalg.cholesky(A, lower=True)
    A2 = jax.lax.batch_matmul(L, L.swapaxes(-1, -2)).block_until_ready()
    print(f"Cholesky error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2))/jnp.linalg.norm(A, axis=(-1, -2)))}, time = {time.time() - st_time} ")

    # QR is only run up to n = 1e6 (it is much slower than the other two)
    if n <= 1e6:
        st_time = time.time()
        Q, R = jnp.linalg.qr(A)
        A2 = jax.lax.batch_matmul(Q, R).block_until_ready()
        print(f"QR error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2))/jnp.linalg.norm(A, axis=(-1, -2)))}, time = {time.time() - st_time} ")

Output:

on device: cuda:0
n= 100000
SVD error 0.0, time = 0.6100842952728271 
Cholesky error 0.0, time = 0.15522980690002441 
QR error 0.0, time = 3.7535462379455566 
n= 1000000
SVD error 0.0, time = 0.5560173988342285 
Cholesky error 0.0, time = 0.1310713291168213 
QR error 0.0, time = 35.838584184646606 
n= 10000000
SVD error 0.0, time = 2.056659460067749 
Cholesky error nan, time = 0.27480244636535645 

Note that JAX doesn't raise an error or a warning when the Cholesky result is NaN.
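Since nothing is raised, one way to catch this is to check the factor for non-finite entries yourself. The following is only a minimal sketch: `checked_cholesky` is a hypothetical helper name, and the two-matrix batch is made up for illustration (it is not the failing case from the run above).

```python
import jax.numpy as jnp


def checked_cholesky(A):
    """Batched lower Cholesky factor plus a per-matrix success flag.

    The flag is False wherever the factor contains NaN or Inf, which is how
    a failed factorization shows up here, since JAX does not raise.
    """
    L = jnp.linalg.cholesky(A)
    ok = jnp.all(jnp.isfinite(L), axis=(-1, -2))
    return L, ok


# Illustrative batch: one positive definite matrix and one that is not.
A = jnp.stack([jnp.identity(3), -jnp.identity(3)])
L, ok = checked_cholesky(A)
print(ok)
if not bool(jnp.all(ok)):
    print(f"{int(jnp.sum(~ok))} of {ok.shape[0]} factorizations failed")
```

The check costs one extra pass over the factor, which is small next to the factorization itself.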

Here is an SVD-based solver to use in the meantime, in case someone else needs a stopgap:

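def solve_by_SVD(A, b):
    # Solve A x = b for a batch of matrices via A = U diag(S) Vh,
    # so that x = Vh^H diag(1/S) U^H b.
    U, S, Vh = jax.scipy.linalg.svd(A)

    # Treat a batch of vectors as a batch of single-column matrices
    if b.ndim == A.ndim - 1:
        expand = True
        b = b[..., None]
    else:
        expand = False

    Uhb = jax.lax.batch_matmul(jnp.conj(U.swapaxes(-1, -2)), b) / S[..., None]
    x = jax.lax.batch_matmul(jnp.conj(Vh.swapaxes(-1, -2)), Uhb)

    # Drop the column axis again if b came in as vectors
    if expand:
        x = x[..., 0]

    return x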
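A quick sanity check of this kind of solver might look like the sketch below. The solver is restated compactly so the block runs on its own; the test problem (batch size, matrix size, and the random well-conditioned matrices) is an assumption made up for illustration, not something from the runs above.

```python
import jax
import jax.numpy as jnp


def solve_by_SVD(A, b):
    # Same approach as the solver above: x = Vh^H diag(1/S) U^H b.
    U, S, Vh = jax.scipy.linalg.svd(A)
    expand = b.ndim == A.ndim - 1
    if expand:
        b = b[..., None]
    Uhb = jax.lax.batch_matmul(jnp.conj(U.swapaxes(-1, -2)), b) / S[..., None]
    x = jax.lax.batch_matmul(jnp.conj(Vh.swapaxes(-1, -2)), Uhb)
    return x[..., 0] if expand else x


# Hypothetical test problem: a small batch of well-conditioned SPD matrices.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
n, m = 4, 5
M = jax.random.normal(key_a, (n, m, m))
A = jax.lax.batch_matmul(M, M.swapaxes(-1, -2)) + m * jnp.identity(m)
b = jax.random.normal(key_b, (n, m))

x = solve_by_SVD(A, b)

# Relative residual ||A x - b|| / ||b|| for each system in the batch.
Ax = jnp.einsum("nij,nj->ni", A, x)
rel_residual = jnp.linalg.norm(Ax - b, axis=-1) / jnp.linalg.norm(b, axis=-1)
print("max relative residual:", float(jnp.max(rel_residual)))
```

Dividing by `S` assumes no singular value is zero or close to it, so this is only appropriate for well-conditioned systems.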


ma-gilles commented on June 18, 2024

As far as I can tell, the LU/Cholesky bugs are fixed by updating JAX to 0.4.23. See patrick-kidger/lineax#79 (comment) for more info.
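If it helps, a guard along these lines could flag older installs. It is only a sketch: `version_tuple` and `REPORTED_FIX` are hypothetical names, and the 0.4.23 threshold is simply the version reported here, not something the code verifies.

```python
import re

import jax


def version_tuple(version):
    """Parse the leading 'X.Y.Z' of a version string into a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


# Version reported in this thread; the check itself confirms nothing about the fix.
REPORTED_FIX = (0, 4, 23)

if version_tuple(jax.__version__) < REPORTED_FIX:
    print(f"jax {jax.__version__} predates 0.4.23; consider upgrading")
```

Comparing tuples of ints avoids the string-comparison trap where "0.4.9" would sort after "0.4.23".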


hawkinsp commented on June 18, 2024

Well, fixed is fixed, I guess. We could dig into why, but it would mostly be of historical interest. Please reopen if it happens again!
