Comments (6)
You are not the first person to request this, so I'll just say that the issue is on our radar, but has not yet risen to a high priority. In the meantime, you'll probably have to convert the key to a jax.Array to save it.
Thanks for reporting though, this will affect our prioritization going forward.
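A minimal sketch of the workaround mentioned above, assuming a recent JAX version that provides `jax.random.key_data` and `jax.random.wrap_key_data`, and assuming the key was created with the default PRNG implementation. The variable names are illustrative only. The raw `uint32` array is what you would hand to the checkpointer, and the typed key is rebuilt from it after restoring:

```python
import jax
import jax.numpy as jnp
import numpy as np

# A new-style typed PRNG key, which the checkpointer could not store directly.
key = jax.random.key(0)

# Unwrap the key into a plain uint32 jax.Array; this array can be checkpointed.
key_data = jax.random.key_data(key)

# After restoring the array, wrap it back into a typed key.
# Assumption: the default PRNG implementation is the same at save and restore time.
restored_key = jax.random.wrap_key_data(key_data)

# Both keys produce the same random stream.
a = jax.random.normal(key, (3,))
b = jax.random.normal(restored_key, (3,))
```

If a non-default PRNG implementation is used, the same implementation has to be passed to `wrap_key_data` through its `impl` argument when restoring.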
from orbax.
Thanks for sharing. We have a new release, orbax-checkpoint 0.5.1, that includes two new handlers: JaxRandomKeyCheckpointHandler and NumpyRandomKeyCheckpointHandler. We recommend using these outside of the train state PyTree, because random keys are closer to metadata than to model state.
Documentation is here. Usage examples can be found here
Great, thanks for your response!
@hylkedonker I am looking into adding support for storing jax.random.key in Orbax. I have a couple of questions about how these keys are used and stored:
- How often should the random keys be saved? Is it necessary to store them in every training step?
- Do all machines share the same keys, or does each machine need to store its own ones?
Thanks for getting in touch.
I use the pseudo-random number generator keys to train variational inference (VI) models. Concretely, each training step I consume a PRNG key to make a Monte Carlo estimate of the ELBO (evidence lower bound). In practice, I make the key part of Flax's TrainState, which I checkpoint every now and then.
So to get back to your questions:
- The key needs to be tracked every training step, but I don't save the TrainState every training step.
- I currently don't have a lot of experience sharding the computation across different machines. But I would imagine that one might pmap the Monte Carlo estimate over different machines (so that each machine gets its own key).
I hope this helps. If not, let me know how I can further clarify.
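To make the two answers above concrete, here is a rough sketch of the key handling, not the actual training code. The function names `mc_estimate` and `train_step` and the quadratic integrand are hypothetical placeholders standing in for a real ELBO term. The key is split once per step, so the carried key is the state that would need to be checkpointed, and one key per local device is derived with a further split:

```python
import jax
import jax.numpy as jnp

def mc_estimate(key, mu, log_sigma, num_samples=64):
    """Monte Carlo estimate of E[f(z)] for z ~ N(mu, sigma^2) via reparameterization."""
    eps = jax.random.normal(key, (num_samples,))
    z = mu + jnp.exp(log_sigma) * eps
    # Placeholder integrand; a real model would evaluate its ELBO terms here.
    return jnp.mean(-0.5 * z**2)

@jax.jit
def train_step(key, mu, log_sigma):
    # Consume one subkey per step and carry the new key forward.
    key, subkey = jax.random.split(key)
    value = mc_estimate(subkey, mu, log_sigma)
    return key, value

key = jax.random.key(0)
values = []
for _ in range(3):
    key, value = train_step(key, 0.0, 0.0)
    values.append(float(value))

# One independent key per local device, e.g. as the mapped argument to jax.pmap.
device_keys = jax.random.split(key, jax.local_device_count())
```

Because the carried `key` changes every step, restoring a checkpoint without it would replay or alter the random stream, even if the checkpoint itself is only written occasionally.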
Great work, thanks!