
Comments (6)

cpgaffney1 avatar cpgaffney1 commented on June 9, 2024 1

You are not the first person to request this, so I'll just say that the issue is on our radar, but has not yet risen to a high priority. In the meantime, you'll probably have to convert the key to a jax.Array to save it.

Thanks for reporting though, this will affect our prioritization going forward.
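As a rough sketch of the workaround mentioned above (converting the key to a jax.Array), one option is to unwrap the typed key with `jax.random.key_data` before saving and re-wrap it with `jax.random.wrap_key_data` after restoring. The save/restore step itself is only indicated by a comment here, and the default PRNG implementation is assumed on both sides.

```python
import jax
import jax.numpy as jnp

# New-style typed PRNG key; this is the object that could not be saved directly.
key = jax.random.key(0)

# Unwrap the key into its underlying uint32 data, which is an ordinary jax.Array.
raw = jax.random.key_data(key)

# ... save `raw` with the checkpoint here, and load it back later ...
# (Placeholder for the round trip; no checkpointing library is called.)
restored_raw = jnp.asarray(raw)

# Re-wrap the raw data as a typed key (assumes the default PRNG implementation).
restored_key = jax.random.wrap_key_data(restored_raw)

# Both keys should drive the same random stream.
same_stream = bool(
    jnp.array_equal(
        jax.random.normal(key, (3,)),
        jax.random.normal(restored_key, (3,)),
    )
)
```

If a non-default PRNG implementation is in use, the same `impl` would need to be passed to `jax.random.wrap_key_data` when re-wrapping.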

from orbax.

ChromeHearts avatar ChromeHearts commented on June 9, 2024 1

Thanks for sharing. We have a new release, orbax-checkpoint 0.5.1, which includes two new handlers: JaxRandomKeyCheckpointHandler and NumpyRandomKeyCheckpointHandler. We recommend using these outside of the train state PyTree, because random keys are more like metadata.

Documentation is here. Usage examples can be found here


hylkedonker avatar hylkedonker commented on June 9, 2024

Great, thanks for your response!


ChromeHearts avatar ChromeHearts commented on June 9, 2024

@hylkedonker I am looking into adding support for storing jax.random.key in Orbax. I have a couple of questions about how these keys are used and stored:

  1. How often should the random keys be saved? Is it necessary to store them in every training step?
  2. Do all machines share the same keys, or does each machine need to store its own ones?


hylkedonker avatar hylkedonker commented on June 9, 2024

Thanks for getting in touch.
I use pseudorandom number generator (PRNG) keys to train variational inference (VI) models. Concretely, at each training step I consume a PRNG key to make a Monte Carlo estimate of the ELBO (evidence lower bound). In practice, I make the key part of Flax's TrainState, which I checkpoint every now and then.
So to get back to your questions:

  1. The key needs to be tracked every training step, but I don't save the TrainState every training step.
  2. I currently don't have a lot of experience sharding the computation across different machines. But I would imagine that one might pmap the Monte Carlo estimate over different machines (so that each machine gets its own key).

I hope this helps. If not, let me know how I can further clarify.
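To make the two answers above concrete, here is a minimal sketch in plain JAX. The function `elbo_like_estimate` is a hypothetical stand-in for a Monte Carlo ELBO estimate, not the actual model, and the per-device split at the end is only one possible way to give each device its own key.

```python
import jax
import jax.numpy as jnp


def elbo_like_estimate(key, mu, num_samples=64):
    # Hypothetical stand-in for a Monte Carlo ELBO estimate:
    # average a simple function over samples drawn with `key`.
    eps = jax.random.normal(key, (num_samples,))
    return jnp.mean(-((mu + eps) ** 2))


def train_step(key, mu):
    # Split once per step: keep `key` for the future, consume `subkey` now.
    key, subkey = jax.random.split(key)
    estimate = elbo_like_estimate(subkey, mu)
    return key, estimate


key = jax.random.key(42)
mu = jnp.float32(0.5)

# The key changes every step, but it only needs to be written out
# whenever the rest of the training state is checkpointed.
for step in range(3):
    key, estimate = train_step(key, mu)

# Raw uint32 data of the current key, suitable for saving as a plain array.
checkpoint_ready_key = jax.random.key_data(key)

# One possible multi-device setup: derive a distinct key per local device.
device_keys = jax.random.split(key, jax.local_device_count())
```

Only the carried `key` has to survive a restart; the consumed subkeys can always be re-derived from it.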


hylkedonker avatar hylkedonker commented on June 9, 2024

Great work, thanks!

