Comments (10)
If you only have a checkpoint file in the directory, that means all your values were saved with aggregate=True, so they are all stored in one file. Otherwise, they are stored separately using Tensorstore, in which case they are represented by PLACEHOLDER in the structure file. You're right that this difference is a bit annoying, and that it's not easy to get metadata for all arrays (the shapes, for example) without doing a full restoration.
Calling get() on every value does materialize the entire checkpoint. If you want to look up the shapes without doing so, you have to parse the .zarray files yourself, or read them with Tensorstore. Again, we're working on a more user-friendly way of doing this. In the future, LazyArray will also have properties like shape and dtype even before materialization.
from orbax.
I think this may be a consequence of using the Flax restore_args_from_target function, which makes certain assumptions about how on-device, fully replicated arrays should be restored. Try using construct_restore_args instead.
from orbax.
Thanks, it appears to work.
Here is a follow-up question: in the documentation, the target object passed to both construct_restore_args and CheckpointManager.restore has the same shapes as the train_state that we are trying to restore. This seems less than ideal, considering that it may take up quite a lot of memory.
In the Flax documentation here, it seems to be possible to pass a smaller target reference. However, I've tried creating one with the same axis dimension as my shardings, but I get an error like
ValueError: Cannot intersect index domain { [0, 40*) } with index domain { }: Ranks do not match [source locations='tensorstore/index_space/index_transform.cc:484']
What is the exact logic? How can I create a reference smaller than the actual train state?
from orbax.
For starters, the item argument for CheckpointManager.restore is optional, and is unnecessary in your case since it's just a dict, and not a custom PyTree.
Secondly, it is possible to skip the initialization of target if you are constrained by memory. Simply use restore_args=dict(sharded=orbax.checkpoint.ArrayRestoreArgs(sharding=...), unsharded=orbax.checkpoint.ArrayRestoreArgs(sharding=...))
I'm not actually sure what the Flax documentation means when it says the reference may be smaller than the actual train state. The construct_restore_args function is really only intended for use when you already have the entire PyTree initialized with arrays of the correct shape and sharding. If you have jax.ShapeDtypeStruct, for example, you would just use similar logic to initialize the restore_args.
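As a rough illustration of the jax.ShapeDtypeStruct route mentioned above, the sketch below builds an abstract tree with jax.eval_shape, so no full-size arrays are allocated, and then maps over it to collect per-leaf shape and dtype. The init_state function, its keys, and the array sizes are hypothetical, and the plain dicts only stand in for whatever per-leaf restore arguments you construct (for example orbax.checkpoint.ArrayRestoreArgs together with your sharding); this is not Orbax's own recipe.

```python
import jax
import jax.numpy as jnp


def init_state():
    # Hypothetical initializer standing in for the real train-state constructor.
    return {
        "sharded": jnp.zeros((40, 128), dtype=jnp.float32),
        "unsharded": jnp.zeros((), dtype=jnp.uint32),
    }


# eval_shape traces init_state abstractly and returns a tree of
# jax.ShapeDtypeStruct leaves; the (40, 128) array is never allocated.
abstract_state = jax.eval_shape(init_state)


def describe(leaf):
    # Plain dict used as a placeholder for a per-leaf restore argument.
    return {"global_shape": leaf.shape, "dtype": leaf.dtype}


restore_spec = jax.tree_util.tree_map(describe, abstract_state)
```

Each entry of restore_spec carries the information a per-leaf restore argument would need besides the sharding, which still has to come from your own mesh setup.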
from orbax.
Fantastic, seems to have worked like a charm. Thanks a lot!
from orbax.
Reopening the issue since I have a follow-up problem. In my code, the restore is called inside a jitted function. Unfortunately, I get the following concretization error:
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(uint32[])>with<DynamicJaxprTrace(level=2/0)>
This value became a tracer due to JAX operations on these lines:
operation a:u32[1] = host_local_array_to_global_array[
global_mesh=Mesh(device_ids=array([[0]]), axis_names=('processes', 'local_devices'))
pspec=PartitionSpec('processes',)
] b
Any idea how I can work around this issue?
from orbax.
This is an easy answer: you can't restore from within a jitted function. This is an issue we encountered before with a few other Flax users, and we reached the conclusion that they would just need to move their restore outside the jitted function. Sorry!
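A minimal sketch of that restructuring, assuming the restored state can simply be passed into the jitted function as an argument. Here load_params is a hypothetical stand-in for the actual checkpoint restore call, and forward is an illustrative function, not code from this thread.

```python
import jax
import jax.numpy as jnp
import numpy as np


def load_params():
    # Hypothetical stand-in for the checkpoint restore; it runs eagerly on
    # the host, outside of any traced function.
    return {
        "w": np.ones((4, 2), dtype=np.float32),
        "b": np.zeros((2,), dtype=np.float32),
    }


@jax.jit
def forward(params, x):
    # Only pure array computation happens under jit.
    return x @ params["w"] + params["b"]


params = load_params()  # restore first, outside jit
y = forward(params, jnp.ones((3, 4), dtype=jnp.float32))
```

The jitted function only ever sees the already-restored arrays, so nothing inside the trace needs a concrete value that exists only at restore time.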
from orbax.
Alright, then perhaps there is a way to obtain a tree of shapes corresponding to the checkpoint, without actually restoring the checkpoint? I tried using jax.eval_shape in combination with CheckpointManager.restore, but it turns out that jax.eval_shape traces the function in the same way jit does, so this approach did not work.
from orbax.
There is, but it's currently not well integrated into the API. Orbax does support getting the pytree structure of the checkpoint via the structure API or using lazy restore. However, if you want the shapes of the arrays, we currently don't have an API for that, though we're working on it. What you can do is parse through the .zarray files, which store the shape (and other metadata) for each parameter.
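One way to do that parsing with only the standard library is sketched below. It assumes each parameter lives in its own subdirectory containing a zarr v2 .zarray JSON file with "shape" and "dtype" keys; the helper name read_zarray_shapes and the directory layout are assumptions, so check them against what your checkpoint directory actually contains.

```python
import json
from pathlib import Path


def read_zarray_shapes(checkpoint_dir):
    """Map each parameter directory (relative path) to its shape and dtype."""
    root = Path(checkpoint_dir)
    metadata = {}
    # Every .zarray file is a small JSON document describing one array.
    for zarray_path in sorted(root.rglob(".zarray")):
        with zarray_path.open() as f:
            meta = json.load(f)
        name = zarray_path.parent.relative_to(root).as_posix()
        metadata[name] = {
            "shape": tuple(meta["shape"]),
            "dtype": meta["dtype"],
        }
    return metadata
```

Calling read_zarray_shapes on a step directory returns a flat dict keyed by parameter directory name; only the small metadata files are read, never the array data itself.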
from orbax.
Thanks. Related to structure, I've noticed that when there is only one checkpoint in the directory, structure creates a pytree with the same structure as the checkpoint, with numpy arrays as values. However, when multiple checkpoints are available, it returns a pytree with lists of placeholder strings as values. I find this difference a little odd. For example, it would seem I could get shapes in the first case, but not in the second?
Anyway, in order to get shapes, I'm currently restoring the whole checkpoint lazily (using ArrayRestoreArgs(lazy=True)), then reading the shape of each value in a tree_map as follows:
shapes = tree_map(lambda v: v.get().shape, restored_checkpoint)
Would you say this method allocates the full checkpoint in memory?
from orbax.