

cpgaffney1 commented on May 23, 2024

If you only have a checkpoint file in the directory, that means all your values were saved with aggregate=True, so they are all stored in one file. Otherwise, they are stored separately using TensorStore, in which case they are represented with PLACEHOLDER in the structure file. You're right that this difference is a bit annoying, and it's not easy to get metadata for all arrays (to inspect things like their shapes) without doing a full restoration.

This does materialize the entire checkpoint. If you want to look up the shapes without doing so, you have to parse the .zarray files or use TensorStore directly. Again, we're working on a more user-friendly way of doing this. In the future, LazyArray will also expose properties like shape and dtype even before materialization.
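A minimal sketch of the .zarray route, assuming the zarr v2 layout in which each TensorStore-saved parameter sits in its own subdirectory containing a .zarray JSON file with "shape" and "dtype" keys (check this against your own checkpoint directory; the function name read_zarray_metadata is hypothetical). Only the small metadata files are read, so no array data is materialized:

```python
import json
import os


def read_zarray_metadata(ckpt_dir):
    """Collect shape and dtype for every zarr array under ckpt_dir.

    Assumption: each parameter is stored in its own subdirectory that holds
    a .zarray JSON metadata file (zarr v2 layout). Only these small files
    are opened; the chunk data is never touched.
    """
    metadata = {}
    for root, _dirs, files in os.walk(ckpt_dir):
        if ".zarray" not in files:
            continue  # not an array directory (e.g. the checkpoint root)
        with open(os.path.join(root, ".zarray")) as f:
            zarray = json.load(f)
        # Key each entry by its directory path relative to the checkpoint.
        name = os.path.relpath(root, ckpt_dir)
        metadata[name] = {
            "shape": tuple(zarray["shape"]),
            "dtype": zarray["dtype"],
        }
    return metadata
```

If the directory only holds the single aggregated checkpoint file (aggregate=True), there are no .zarray files and the function returns an empty dict.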


cpgaffney1 commented on May 23, 2024

I think this may be a consequence of using the Flax restore_args_from_target function, which makes certain assumptions about how on-device, fully replicated arrays should be restored. Try using construct_restore_args instead.


gianlucadetommaso commented on May 23, 2024

Thanks, it appears to work.

Here is a follow-up question: in the documentation, the target object passed to both construct_restore_args and CheckpointManager.restore has the same shapes as the train_state that we are trying to restore. This seems less than ideal, since it may take up a lot of memory.

In the Flax documentation here, it seems to be possible to pass a smaller target reference. However, I've tried creating one with the same axis dimensions as my shardings, and I get an error like

ValueError: Cannot intersect index domain { [0, 40*) } with index domain { }: Ranks do not match [source locations='tensorstore/index_space/index_transform.cc:484']

What is the exact logic? How can I create a reference smaller than the actual train state?


cpgaffney1 commented on May 23, 2024

For starters, the item argument for CheckpointManager.restore is optional, and is quite unnecessary in your case since it's just a dict, and not a custom PyTree.

Secondly, it is possible to skip the initialization of target if you are constrained by memory. Simply use restore_args=dict(sharded=orbax.checkpoint.ArrayRestoreArgs(sharding=...), unsharded=orbax.checkpoint.ArrayRestoreArgs(sharding=...))

I'm not actually sure what the Flax documentation means when it says the reference may be smaller than the actual train state. The construct_restore_args function is really only intended for use when you already have the entire PyTree initialized with arrays of the correct shape and sharding. If you have jax.ShapeDtypeStruct objects, for example, you would just use similar logic to build the restore_args yourself.
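For the memory concern, one option (a sketch, not a documented Orbax recipe) is jax.eval_shape: it traces an initializer abstractly and returns a tree of jax.ShapeDtypeStruct leaves without allocating the arrays. The init_params function below is a hypothetical stand-in for your own train-state initializer:

```python
import jax
import jax.numpy as jnp


def init_params(rng):
    # Hypothetical initializer standing in for the real train-state setup.
    k1, k2 = jax.random.split(rng)
    return {
        "dense": {
            "kernel": jax.random.normal(k1, (1024, 1024)),
            "bias": jnp.zeros((1024,)),
        },
        "embed": jax.random.normal(k2, (50000, 512)),
    }


# jax.eval_shape traces init_params with abstract values: no parameter
# arrays are allocated, and each leaf comes back as a jax.ShapeDtypeStruct.
abstract_params = jax.eval_shape(init_params, jax.random.PRNGKey(0))
```

Each ShapeDtypeStruct leaf could then be mapped with jax.tree_util.tree_map to an orbax.checkpoint.ArrayRestoreArgs carrying the appropriate sharding, following the restore_args pattern described earlier in the thread; that step is left out here because the right sharding depends on your mesh.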


gianlucadetommaso commented on May 23, 2024

Fantastic, seems to have worked like a charm. Thanks a lot!


gianlucadetommaso commented on May 23, 2024

Reopening the issue since I have a follow-up problem. In my code, the restore call happens inside a jitted function. Unfortunately, I get the following concretization error:

jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(uint32[])>with<DynamicJaxprTrace(level=2/0)>

This value became a tracer due to JAX operations on these lines:
  operation a:u32[1] = host_local_array_to_global_array[
  global_mesh=Mesh(device_ids=array([[0]]), axis_names=('processes', 'local_devices'))
  pspec=PartitionSpec('processes',)
] b

Any idea how I can work around this issue?


cpgaffney1 commented on May 23, 2024

This is an easy answer: you can't restore from within a jitted function. We have run into this before with a few other Flax users, and we concluded that they would need to move the restore outside the jitted function. Sorry!
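The workaround is to keep all checkpoint I/O eager and pass only concrete arrays into the jitted code. A minimal sketch, using a NumPy .npz file as a hypothetical stand-in for the Orbax checkpoint (restore_params, predict and run are illustrative names, not Orbax API):

```python
import numpy as np
import jax
import jax.numpy as jnp


def restore_params(path):
    # Stand-in for a checkpoint restore (e.g. CheckpointManager.restore).
    # It does host-side I/O, so it runs eagerly, outside jax.jit.
    with np.load(path) as data:
        return {name: jnp.asarray(data[name]) for name in data.files}


@jax.jit
def predict(params, x):
    # Only pure array computation happens inside the jitted function.
    return x @ params["w"] + params["b"]


def run(path, x):
    params = restore_params(path)  # concrete arrays, restored eagerly
    return predict(params, x)      # jit traces predict, not the restore
```

With this split, jax.jit only traces predict; the restore runs once in plain Python, so no tracer ever reaches the restore code.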


gianlucadetommaso commented on May 23, 2024

Alright, then is there perhaps a way to obtain a tree of shapes corresponding to the checkpoint without actually restoring it? I tried using jax.eval_shape in combination with CheckpointManager.restore, but jax.eval_shape traces the function the same way jit does, so this approach did not work.


cpgaffney1 commented on May 23, 2024

There is, but it's currently not well integrated into the API. Orbax does support getting the pytree structure of the checkpoint via the structure API or using lazy restore. However, if you want the shapes of the arrays, we currently don't have an API for that, though we're working on it. What you can do is parse through the .zarray files, which store the shape (and other metadata) for each parameter.


gianlucadetommaso commented on May 23, 2024

Thanks. Related to structure, I've noticed that when there is only one checkpoint in the directory, structure creates a pytree with the same structure as the checkpoint, with numpy arrays as values. However, when multiple checkpoints are available, it returns a pytree with lists of placeholder strings as values. I find this difference a little odd. For example, it would seem I could get shapes in the first case, but not in the second?

Anyway, to get shapes, I'm currently restoring the whole checkpoint lazily (using ArrayRestoreArgs(lazy=True)) and then reading the shape of each value in a tree_map as follows:

from jax.tree_util import tree_map
shapes = tree_map(lambda v: v.get().shape, restored_checkpoint)

Would you say this method allocates the full checkpoint in memory?

