
Comments (9)

PabloAMC commented on May 20, 2024

I am investigating whether this google/jax#5461 (comment) could be a simpler yet effective solution. It seems to allow two training iterations before breaking down (having set SCF iterations = 2). However, it returns erroneous eigenvalues, as it ends up not conserving the charge.

from graddft.

PabloAMC commented on May 20, 2024

I have created an example to play around with this problem:
https://github.com/XanaduAI/DiffDFT/blob/main/examples/basic_examples/example_neural_scf_training.py
It should be deleted if we can't solve this in time. Changing the number of DIIS iterations in https://github.com/XanaduAI/DiffDFT/blob/9bbab42cfe95bbb6031c45613b5ab718ca9e9e5f/grad_dft/evaluate.py#L524 also has a minor effect.


PabloAMC commented on May 20, 2024

I have realized that https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e provides a hand-written definition of the eigh derivatives, which is probably what is used in https://github.com/XanaduAI/DiffDFT/blob/9bbab42cfe95bbb6031c45613b5ab718ca9e9e5f/grad_dft/external/eigh_impl.py#L84
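For context: hand-written eigh derivatives typically follow the standard first-order perturbation formulas for a symmetric eigendecomposition. A minimal, illustrative sketch of the idea via `jax.custom_jvp` (not the actual code in `eigh_impl.py`) could look like:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def eigh_sym(A):
    # Return a plain (eigenvalues, eigenvectors) tuple so the custom
    # tangent structure matches the primal output structure.
    w, V = jnp.linalg.eigh(A)
    return w, V

@eigh_sym.defjvp
def eigh_sym_jvp(primals, tangents):
    (A,), (dA,) = primals, tangents
    w, V = jnp.linalg.eigh(A)
    # Rotate the (symmetrized) perturbation into the eigenbasis.
    S = V.T @ ((dA + dA.T) / 2.0) @ V
    # First-order eigenvalue perturbation: dw_i = S_ii.
    dw = jnp.diag(S)
    # F_ij = 1 / (w_j - w_i) off the diagonal, 0 on it; these factors
    # diverge when eigenvalues are (nearly) degenerate.
    n = A.shape[0]
    F = 1.0 / (w[None, :] - w[:, None] + jnp.eye(n)) - jnp.eye(n)
    dV = V @ (F * S)
    return (w, V), (dw, dV)
```

Note the `1 / (w_j - w_i)` factors: this is why any such hand-defined rule is only as stable as the spectrum is well-separated.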


jackbaker1001 commented on May 20, 2024

I'm going to experiment a bit with this now, but I think we can just do:

```python
import jax.numpy as jnp

def generalized_eigh(A, B):
    # Reduce A x = lambda B x to a standard problem via B = L L^T.
    L = jnp.linalg.cholesky(B)
    L_inv = jnp.linalg.inv(L)
    A_redo = L_inv.dot(A).dot(L_inv.T)
    return jnp.linalg.eigh(A_redo)
```


jackbaker1001 commented on May 20, 2024

Ah, sorry, this is the solution you already posted. Let me see if there is a better way of doing this...


jackbaker1001 commented on May 20, 2024

Ok @PabloAMC, I think this code fixes the problem:

```python
import numpy as np

def generalized_to_standard_eig(A, B):
    L = np.linalg.cholesky(B)
    L_inv = np.linalg.inv(L)
    C = L_inv @ A @ L_inv.T
    eigenvalues, eigenvectors_transformed = np.linalg.eigh(C)
    # Map the eigenvectors back to the original, non-orthogonal basis.
    eigenvectors_original = L_inv.T @ eigenvectors_transformed
    return eigenvalues, eigenvectors_original
```

Basically, the eigenvectors were not transformed back to the original basis.
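As a sanity check (restating the fix so the snippet is self-contained): since the returned eigenvectors are `V = L_inv.T @ Y` with `Y` orthonormal, each pair should satisfy the generalized equation `A v = w B v` and `V` should be B-orthonormal:

```python
import numpy as np

def generalized_to_standard_eig(A, B):
    L = np.linalg.cholesky(B)
    L_inv = np.linalg.inv(L)
    C = L_inv @ A @ L_inv.T
    eigenvalues, eigenvectors_transformed = np.linalg.eigh(C)
    eigenvectors_original = L_inv.T @ eigenvectors_transformed
    return eigenvalues, eigenvectors_original

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
A = A + A.T                          # symmetric "Fock-like" matrix
B = rng.normal(size=(4, 4))
B = B @ B.T + 4 * np.eye(4)          # symmetric positive-definite "overlap"
w, V = generalized_to_standard_eig(A, B)
assert np.allclose(A @ V, (B @ V) * w)      # A v_i = w_i B v_i, column-wise
assert np.allclose(V.T @ B @ V, np.eye(4))  # B-orthonormality
```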


PabloAMC commented on May 20, 2024

You are right, that works, thanks @jackbaker1001. Unfortunately, it still gives errors when backpropagating, even with this version of the generalized eigenvalue solver: I get NaN errors when I change the number of cycles from 0 to 1 in the additional example https://github.com/XanaduAI/DiffDFT/blob/main/examples/basic_examples/example_neural_scf_training.py I included in the basic folder (specifically https://github.com/XanaduAI/DiffDFT/blob/81444ce067a809e3e69bffdf81b4e7c1b5ca4d2b/examples/basic_examples/example_neural_scf_training.py#L138).


jackbaker1001 commented on May 20, 2024

Ok @PabloAMC, since this issue was about a differentiable eigh implementation, which we now have, I will put in a PR for this and merge. I will keep the previous custom eigh implementation in the code in case it comes in handy later; it may also be removed later.

I will make a new issue about differentiating through the SCF iterator (the different DIIS implementations at present), as this is now the problem!
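On that note, one standard way to make an SCF-style loop differentiable without backpropagating through every iteration is implicit differentiation of the fixed point, along the lines of the fixed-point example in the JAX custom-derivatives documentation. A scalar sketch (illustrative only, not DiffDFT's actual iterator; the Newton-for-sqrt map stands in for the SCF update):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    # Forward pass: plain fixed-point iteration x <- f(a, x).
    def cond_fun(carry):
        x_prev, x = carry
        return jnp.abs(x_prev - x) > 1e-6
    def body_fun(carry):
        _, x = carry
        return x, f(a, x)
    _, x_star = jax.lax.while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
    return x_star

def fixed_point_fwd(f, a, x_guess):
    x_star = fixed_point(f, a, x_guess)
    return x_star, (a, x_star)

def rev_iter(f, packed, u):
    # Adjoint fixed point u = v + (df/dx)^T u, from the implicit
    # function theorem at x* = f(a, x*).
    a, x_star, x_star_bar = packed
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    return x_star_bar + vjp_x(u)[0]

def fixed_point_rev(f, res, x_star_bar):
    a, x_star = res
    u = fixed_point(partial(rev_iter, f), (a, x_star, x_star_bar), x_star_bar)
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    a_bar, = vjp_a(u)
    return a_bar, jnp.zeros_like(x_star)

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)

# Toy stand-in for SCF: Newton iteration for sqrt(a); fixed point is sqrt(a).
f = lambda a, x: 0.5 * (x + a / x)
```

The gradient then only needs the converged solution, not the iteration history, which sidesteps blow-up from unrolling many DIIS/SCF steps.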


PabloAMC commented on May 20, 2024

I think we should reopen this. The problem is numerical stability during the calculation of gradients, not the lack of a differentiable implementation, and that does not seem to be solved yet.
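A minimal illustration of the instability (an assumption about the mechanism, consistent with google/jax#5461): the eigh derivative contains `1 / (lambda_i - lambda_j)` factors, so gradients of eigenvector-dependent quantities blow up as two eigenvalues approach each other, and become inf/NaN at exact degeneracy:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def first_offdiag_eigenvector_entry(A):
    # Symmetrize so eigh's assumptions hold under perturbation.
    _, V = jnp.linalg.eigh((A + A.T) / 2.0)
    return V[0, 1]

# Two nearly degenerate eigenvalues, separated by eps.
eps = 1e-12
A = jnp.diag(jnp.array([1.0, 1.0 + eps, 2.0]))
g = jax.grad(first_offdiag_eigenvector_entry)(A)
# The gradient magnitude grows like 1/eps; at eps = 0 it degenerates
# to inf/NaN, which is the kind of failure seen in SCF training.
print(jnp.abs(g).max())
```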

