
maxtext's People

Contributors

a9isha, abhinavclemson, abhinavgoel95, aireenmei, chajath, cheshire, cmto1983, gobbleturk, jonb377, khatwanimohit, kocchop, michelle-yooh, morgandu, ninacai, obliviour, patemotter, priyanka-ganesha, raymondzouu, reedwm, rissyran, roshanin, rwitten, shaurya89, singh-mitali, ssusie, surbhijainusc, tonyjohnchen, xuefgu, yashs97, zhiyuli-goog

maxtext's Issues

Supported features

I mainly wanted to start by thanking you for making MaxText available. I have been using it for a few days, and my first impression is fantastic. Getting started was really easy, it seems very stable, and the performance is excellent. It seems to scale very nicely.

There are a few things I have not been able to figure out yet. This might be due to a lack of documentation, or simply because they are not implemented.

  • Is there any support for Flash attention, or are there plans to implement it? This has been a major area where GPUs have been ahead of TPUs. I have noticed that there is now at least an experimental implementation from the JAX team: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py.

  • Training directly from TFDS seemed straightforward. However, I was a bit confused about how to implement more advanced data loader features, for instance probability sampling as explained here. This can be somewhat tricky to do efficiently on multiple TPUs. What is the sensible approach here? Manually sampling into a TFDS dataset does not seem very efficient. Are there external libraries that are compatible with MaxText?

  • Are there plans for implementing DPO/RLHF?
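A minimal sketch of weighted (probability) sampling across several data sources, in plain Python. It assumes each host already reads its own shard of every source; the seeded generator only decides which source the next example comes from, so the mixture is reproducible across restarts. `sample_from_sources` is a hypothetical helper, not part of MaxText. For tf.data pipelines, TensorFlow's `tf.data.Dataset.sample_from_datasets` provides this behavior natively.

```python
import random
from typing import Iterable, Iterator, Sequence


def sample_from_sources(
    sources: Sequence[Iterable],
    weights: Sequence[float],
    seed: int,
) -> Iterator:
    """Interleave examples from several sources, choosing the source of each
    next example with probability proportional to its weight.

    Stops as soon as the chosen source is exhausted, so the realized mixture
    never drifts away from the requested weights.
    """
    if len(sources) != len(weights):
        raise ValueError("need exactly one weight per source")
    iterators = [iter(s) for s in sources]
    # A dedicated, seeded generator makes the choice sequence reproducible
    # and identical on every host that uses the same seed.
    rng = random.Random(seed)
    indices = range(len(iterators))
    while True:
        (choice,) = rng.choices(indices, weights=weights, k=1)
        try:
            yield next(iterators[choice])
        except StopIteration:
            return
```

Stopping when a source runs dry is one design choice; cycling exhausted sources instead keeps the stream going but oversamples small datasets.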

I also shamelessly wanted to point you to my own repo: https://github.com/peregilk/ttconnect. It is a very simple bash script that should ideally be run on a VM in the same zone. It automatically opens synchronised tmux windows to all the VMs in the pod and lets you type the same command into all of them. This makes it even easier to go from a single TPU to pods.

[Question] Loading in a HF Dataset

Hi team, thanks for the great work and public release!

I had a question about training on my own dataset. I see that the current README directs users to download datasets from the allennlp TensorFlow datasets. What would the process be for training MaxText on a HuggingFace dataset?

I assume we would need to convert the HuggingFace dataset into TFDS format. Is tokenization done on the fly when loading from the GCS bucket?

Thanks!
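One common way to tokenize on the fly, whatever the storage format, is to tokenize each record as it is read and pack the tokens into fixed-length training sequences. The sketch below is plain Python and assumes records arrive as an iterable of strings (for example, from a Hugging Face dataset opened with `streaming=True`). `pack_on_the_fly` and `toy_tokenize` are hypothetical, the toy tokenizer stands in for a real one, and nothing here describes MaxText's own input pipeline.

```python
from typing import Callable, Iterable, Iterator, List


def pack_on_the_fly(
    texts: Iterable[str],
    tokenize: Callable[[str], List[int]],
    seq_len: int,
    eos_id: int,
) -> Iterator[List[int]]:
    """Tokenize records lazily and emit sequences of exactly seq_len tokens.

    Documents are concatenated with an EOS separator; a trailing partial
    chunk is dropped.
    """
    buffer: List[int] = []
    for text in texts:
        buffer.extend(tokenize(text))
        buffer.append(eos_id)
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]
            buffer = buffer[seq_len:]


def toy_tokenize(text: str) -> List[int]:
    # Hypothetical stand-in for a real tokenizer: one id per character.
    return [ord(c) for c in text]
```

For example, `list(pack_on_the_fly(["ab", "cde"], toy_tokenize, seq_len=3, eos_id=0))` returns `[[97, 98, 0], [99, 100, 101]]`.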

setup.sh runs `rm ~/jax`

maxtext/setup.sh

Lines 91 to 94 in 490f75b

# Delete jax folder if it exists
if [[ -d $HOME/jax ]]; then
    rm -rf $HOME/jax
fi

This seems dangerous!

Maybe it is only intended to run on worker nodes rather than on users' development machines, and maybe it is only meant to clean up its own cached directories. Even so, it could be made safer, for example by installing into less common directory names (perhaps everything under ~/maxtext_cache, so ~/maxtext_cache/jax rather than ~/jax).
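A minimal Python sketch of the safer pattern suggested above: only delete a directory if it resolves to a location inside a dedicated cache root. `remove_cached_dir` and the ~/maxtext_cache layout are hypothetical; setup.sh itself is a bash script, so this only illustrates the check.

```python
import shutil
from pathlib import Path


def remove_cached_dir(cache_root: Path, name: str) -> bool:
    """Delete cache_root / name only if it resolves to a directory strictly
    inside cache_root. Returns True if a directory was removed."""
    root = Path(cache_root).expanduser().resolve()
    target = (root / name).resolve()
    # Refuse anything that escapes the cache root, such as name="../jax" or a
    # symlink inside the cache that points somewhere else.
    if target == root or root not in target.parents:
        raise ValueError(f"refusing to delete {target}: not inside {root}")
    if not target.is_dir():
        return False
    shutil.rmtree(target)
    return True
```

For example, `remove_cached_dir(Path("~/maxtext_cache"), "jax")` removes ~/maxtext_cache/jax if it exists, and raises instead of deleting for names such as "../jax" that escape the cache root.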

Problems with a parameter checkpoint after training llama2-7b

I’ve trained a llama2-7b model with int8 quantization. The resulting checkpoints directory has the following structure:

- 0/:
-- commit_success.txt

-- default/:
--- _METADATA
--- _sharding
--- checkpoint
--- commit_success.txt
--- d/:
---- cacfeed1cb58f5...
---- ed63ba6da8ada7...
--- manifest.ocdbt
--- ocdbt.process_0/:
---- d/:
----- 028478f3dcd5b9...
----- 551917150ca2a8...
----- 8111cad0633e85...
---- manifest.ocdbt
--- ocdbt.process_1/
--- ...

-- metrics/

- 10000/

- 20000/

- ...

Following this example, I’m trying to generate a parameter checkpoint for further decoding:

export CHECKPOINT_PATH=gs://.../checkpoints/10000/default
export PARAMETER_CHECKPOINT_RUN=<MY_RUN_NAME>
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml load_full_state_path=$CHECKPOINT_PATH run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true

I get this output:

...
I0214 21:09:09.466111 140399213561856 checkpointer.py:167] Finished restoring checkpoint from gs://.../.../.../.../checkpoints/10000/default.
In input checkpoint Number of model params=6.738 billion
Save decode checkpoint at: gs://.../.../.../param_checkpoint/checkpoints/
I0214 21:09:10.367621 140399213561856 async_checkpointer.py:210] Commit thread error check finished successfully
I0214 21:09:12.452971 140399213561856 async_checkpointer.py:293] Async saving item to gs://.../.../.../param_checkpoint/checkpoints/0.
W0214 21:09:13.010293 140399213561856 type_handlers.py:399] SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024.
I0214 21:09:23.297378 140316597466688 async_checkpointer.py:150] Starting commit to storage layer by process: 0
I0214 21:09:23.300883 140399213561856 checkpoint_manager.py:811] Beginning async checkpoint finalize.
saved an decode checkpoint at gs://.../.../.../param_checkpoint/checkpoints/
Successfully generated decode checkpoint at: gs://.../.../.../param_checkpoint/checkpoints/0/default
I0214 21:10:41.916716 140316597466688 async_checkpointer.py:158] Finished committing to storage layer by process: 0
I0214 21:10:41.917720 140317441869376 async_checkpointer.py:207] Commit thread joined successfully

The resulting dir gs://.../.../.../param_checkpoint/checkpoints/0/default with a parameter checkpoint has the following structure:

- _METADATA
- _sharding
- checkpoint
- ocdbt.process_0/:
-- d/:
--- 14159648ab13db...
--- 6bb79329658hjl...
--- caa39c6cfcajsk...
--- d852816bl798dj...
-- manifest.ocdbt

Following this example further, I try to load a parameter checkpoint into decode.py:

python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=gs://.../.../.../param_checkpoint/checkpoints/0/default run_name=<MY_RUN_NAME> per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 max_prefill_predict_length=4  max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product scan_layers=false

I get this error:

restoring params from load_parameters_from_path='gs://.../.../.../param_checkpoint/checkpoints/0/default'
Traceback (most recent call last):
  File "/home/.../2024-02-14-19-59-41/MaxText/decode.py", line 278, in <module>
    app.run(main)
  File "/home/.../.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/.../.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/.../2024-02-14-19-59-41/MaxText/decode.py", line 275, in main
    decode_loop(pyconfig.config)
  File "/home/.../2024-02-14-19-59-41/MaxText/decode.py", line 188, in decode_loop
    state, state_mesh_annotations = max_utils.setup_decode_state(
  File "/home/.../2024-02-14-19-59-41/MaxText/max_utils.py", line 334, in setup_decode_state
    return setup_initial_state(model, None, config, rng, mesh, checkpoint_manager, is_training)
  File "/home/.../2024-02-14-19-59-41/MaxText/max_utils.py", line 362, in setup_initial_state
    state, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
  File "/home/.../2024-02-14-19-59-41/MaxText/checkpointing.py", line 116, in load_state_if_possible
    full_restored_state = checkpointer.restore(p, item = abstract_param_train_state,\
  File "/home/.../.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 163, in restore
    raise ValueError(f'Found incomplete checkpoint at {directory}.')
ValueError: Found incomplete checkpoint at gs://.../.../.../param_checkpoint/checkpoints/0/default.

Diving into the code, I realized that this error appears when there is no commit_success.txt file in the checkpoint directory. generate_param_only_checkpoint.py did not create this file, despite printing "Successfully generated decode checkpoint". When I added the file manually, I got this error:

I0216 11:10:42.486244 140477023528960 checkpointer.py:164] Restoring item from gs://.../.../param_checkpoint/checkpoints/0/default.
Traceback (most recent call last):
  File "/home/.../2024-02-14-19-59-41/MaxText/decode.py", line 278, in <module>
    app.run(main)
  File "/home/.../.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/.../.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/.../2024-02-14-19-59-41/MaxText/decode.py", line 275, in main
    decode_loop(pyconfig.config)
  File "/home/.../2024-02-14-19-59-41/MaxText/decode.py", line 188, in decode_loop
    state, state_mesh_annotations = max_utils.setup_decode_state(
  File "/home/.../2024-02-14-19-59-41/MaxText/max_utils.py", line 334, in setup_decode_state
    return setup_initial_state(model, None, config, rng, mesh, checkpoint_manager, is_training)
  File "/home/.../2024-02-14-19-59-41/MaxText/max_utils.py", line 362, in setup_initial_state
    state, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
  File "/home/.../2024-02-14-19-59-41/MaxText/checkpointing.py", line 116, in load_state_if_possible
    full_restored_state = checkpointer.restore(p, item = abstract_param_train_state,\
  File "/home/.../.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 166, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
  File "/home/.../.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1073, in restore
    restored_item = asyncio.run(
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/.../.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 903, in _maybe_deserialize
    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
  File "/home/.../.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 1179, in deserialize
    results = await super().deserialize(infos, args)
  File "/home/.../.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 1117, in deserialize
    await _assert_parameter_files_exist(
  File "/home/.../.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 260, in _assert_parameter_files_exist
    raise FileNotFoundError(
FileNotFoundError: Individual parameter subdirectory not found at path: gs://.../.../param_checkpoint/checkpoints/0/default/step

So it is looking for a gs://.../.../param_checkpoint/checkpoints/0/default/step directory, which does not exist.

Passing the original training checkpoint path gs://.../checkpoints/10000/default to decode.py did not work either:

ValueError: Dict key mismatch; expected keys: ['decoder_norm', 'layers', 'logits_dense']; dict: {'decoder_norm': {'scale': ArrayRestoreArgs(restore_type=None, dtype=dtype('float32'), mesh=Mesh(device_ids=array([[[[[0],
          [2],
          [1],
          [3]]]]]), axis_names=('data', 'fsdp', 'sequence', 'tensor', 'autoregressive')), mesh_axes=PartitionSpec(('fsdp', 'sequence'),), sharding=None, global_shape=None)}, 'layers_0': {'mlp': {'wi_0': {'kernel': ArrayRestoreArgs(restore_type=None, dtype=dtype('float32'), mesh=Mesh(device_ids=array([[[[[0],
          [2],
          [1],
          [3]]]]]), axis_names=('data', 'fsdp', 'sequence', 'tensor', 'autoregressive')), mesh_axes=PartitionSpec(('fsdp', 'sequence'), ('tensor', 'autoregressive')), sharding=None, global_shape=None)}, 'wi_1': {'kernel': ArrayRestoreArgs(restore_type=None, dtype=dtype('float32'), mesh=Mesh(device_ids=array([[[[[0],
...

I also tried training a model with async_checkpointing: False, thinking that asynchronous checkpointing might be the problem, but got the same result. What is the problem? Should I pass the training step number of the checkpoint?

I'd be grateful for any advice on how to resolve the issue.
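When debugging a layout mismatch like this, it can help to list which directories of a checkpoint tree contain commit_success.txt, since the report observes that the restore fails when that file is missing. The sketch below works on a local copy of the checkpoint (e.g. downloaded with `gsutil -m cp -r`). `dirs_with_commit_marker` is a hypothetical helper; it only reports where the marker exists, not which directory decode.py expects.

```python
from pathlib import Path
from typing import List


def dirs_with_commit_marker(root: Path, marker: str = "commit_success.txt") -> List[str]:
    """List directories under root that contain the marker file, as sorted
    POSIX paths relative to root ("." for root itself)."""
    root = Path(root)
    found = {
        p.parent.relative_to(root).as_posix()
        for p in root.rglob(marker)
        if p.is_file()
    }
    return sorted(found)
```

Comparing the output for the training checkpoint tree with the output for the parameter checkpoint tree shows at a glance which level is missing the marker.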

load_parameters_path=gs:// deletes directory

  • When load_parameters_path=$HOME/, everything is fine.

  • When load_parameters_path=gs://, it not only fails but also deletes the directory from storage.

  • Forked from maxtext 3 weeks ago.

  • orbax==0.2.6

Repro Sketch

SRC=/home/sam/heather_400m_reshard_d25d4fdd-999c-4d0b-90d5-b345cbcabd9e
DST=gs://uscentral2_user/sam/checkpoints/heather_400m_reshard_ea151096-f3e7-4776-94db-3234dc897d17
gsutil -m rsync $SRC $DST
gcloud storage ls $DST
# gs://uscentral2_user/sam/checkpoints/heather_400m_reshard_ea151096-f3e7-4776-94db-3234dc897d17/1/
# This works
bash train.sh exp.load_parameters_path=$SRC optim.steps=10 

# This deletes $DST
bash train.sh exp.load_parameters_path=$DST optim.steps=10 

Cannot do inference in float32

If we try to perform inference in float32, we get the error:

AssertionError: Key and Value Dtypes should match

This error comes from this line.

The root cause is that the cache dtype is set to jnp.int8 if quantize_kvcache else jnp.bfloat16, and never to jnp.float32.
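A sketch of one possible fix, assuming the cache dtype should follow the model's activation dtype (the config's dtype) whenever the KV cache is not quantized. `select_kv_cache_dtype` is hypothetical; in MaxText the arguments would be jnp dtypes, e.g. `select_kv_cache_dtype(jnp.float32, False, jnp.int8)`.

```python
def select_kv_cache_dtype(activation_dtype, quantize_kvcache: bool, int8_dtype):
    """Return the dtype for the key/value cache.

    A quantized cache uses int8; otherwise the cache follows the activation
    dtype (e.g. float32 for float32 inference) instead of a fixed bfloat16.
    """
    return int8_dtype if quantize_kvcache else activation_dtype
```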

`nextrng` not checkpointed, consider using `fold_in(config.seed, step)`

It looks like the nextrng value is not saved to the checkpoint:

state, metrics, nextrng = p_train_step(

This doesn't matter in most cases, since most people don't use dropout or stochastic rounding. But where it does matter, it's cleaner to derive the RNG for each training step with jax.random.fold_in(jax.random.PRNGKey(config.seed), state.step). This way no checkpointing is required, and there are some other side benefits, listed in https://twitter.com/cgarciae88/status/1615022554992738315.
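A minimal JAX sketch of the stateless scheme proposed above: the key for a step is derived only from the seed and the step number, so a resumed run regenerates exactly the keys an uninterrupted run would have used. `step_rng` and `noise_for_step` are hypothetical names; jax.random.PRNGKey, jax.random.fold_in and jax.random.split are standard JAX APIs.

```python
import jax
import jax.numpy as jnp


def step_rng(seed: int, step) -> jax.Array:
    """Derive the PRNG key for a training step from (seed, step) alone."""
    base_key = jax.random.PRNGKey(seed)
    # fold_in mixes the step number into the base key; step may be a traced
    # integer such as state.step inside a jitted train step.
    return jax.random.fold_in(base_key, step)


@jax.jit
def noise_for_step(step):
    # Split the per-step key when several independent streams are needed,
    # e.g. one for dropout and one for stochastic rounding.
    dropout_key, rounding_key = jax.random.split(step_rng(0, step))
    return jax.random.uniform(dropout_key, (4,)), jax.random.uniform(rounding_key, (4,))
```

Resuming at step N then only needs state.step, which is already part of the checkpoint.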

XlaRuntimeError when training with bfloat16 activations on TPU v3-8

Hello,

I tried testing MaxText on a TPU v3-8 (using the base.yml config with "dataset_type: synthetic") but am encountering the following error when running train.py:

Traceback (most recent call last):
  File "/home/user/maxtext/MaxText/train.py", line 335, in <module>
    app.run(main)
  File "/home/user/jax_env/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/jax_env/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/maxtext/MaxText/train.py", line 331, in main
    train_loop(config)
  File "/home/user/maxtext/MaxText/train.py", line 282, in train_loop
    state, metrics, nextrng = p_train_step(
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to compile a Mosaic module: 'tpu.matmul' op Unsupported input data type in matrix multiplication. (extra diagnostics trimmed...)

The issue seems to lie with the bfloat16 option for activations (dtype: "bfloat16"), since training works with dtype: "float32" (which requires scaling the parameter count way down).
Interestingly, trying this small model with bfloat16 results in a slightly different error:

# same traceback as before except for last line:
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to compile a Mosaic module: failed to legalize operation 'tpu.mask_cast'

For reference, this was tested in a fresh Python 3.10 virtualenv.

Support for T5

Do you have plans to support encoder-decoder models like T5? It would be great to have T5 with flash attention 😃

52B example sharding error

Hi,

I was trying to run the 1x v4-384 52B model example following MaxText/configs/1xv4-384.sh on a v4-384 slice and hit the following error:

Traceback (most recent call last): 
  File "maxtext/MaxText/train.py", line 334, in <module> 
    app.run(main) 
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run 
    _run_main(main, args) 
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main 
    sys.exit(main(argv)) 
  File "maxtext/MaxText/train.py", line 330, in main 
    train_loop(pyconfig.config) 
  File "maxtext/MaxText/train.py", line 277, in train_loop 
    state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, init_rng, mesh, checkpoint_manager) 
  File "/home/wx/maxtext/MaxText/max_utils.py", line 159, in setup_initial_state 
    state = pjit( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback 
    return fun(*args, **kwargs) 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 208, in cache_miss 
    outs, out_flat, out_tree, args_flat = _python_pjit_helper( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper 
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 735, in infer_params 
    return common_infer_params(pjit_info_args, *args, **kwargs) 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 474, in common_infer_params 
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 937, in _pjit_jaxpr 
    canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 920, in _check_and_canonicalize_out_shardings 
    pjit_check_aval_sharding( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 973, in pjit_check_aval_sharding 
    raise ValueError(f"One of {what_aval}{name_str} was given the sharding " 
jax._src.traceback_util.UnfilteredStackTrace: ValueError: One of pjit outputs with pytree key path .params['decoder']['decoder']['self_attention']['key_layer_norm']['scale'].value was given the sharding of NamedSharding(mesh={'data': 1, 'fsdp': 192, 'tensor': 1}, spec=PartitionSpec('fsdp', None)), which implies that the global size of its dimension 0 should be divisible by 192, but it is equal to 256 (full shape: (256, 32)) 

The stack trace below excludes JAX-internal frames. 
The preceding is the original exception that occurred, unmodified. 
  
-------------------- 
  
The above exception was the direct cause of the following exception: 
  
Traceback (most recent call last): 
  File "maxtext/MaxText/train.py", line 334, in <module> 
    app.run(main) 
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run 
    _run_main(main, args) 
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main 
    sys.exit(main(argv)) 
  File "maxtext/MaxText/train.py", line 330, in main 
    train_loop(pyconfig.config) 
  File "maxtext/MaxText/train.py", line 277, in train_loop 
    state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, init_rng, mesh, checkpoint_manager) 
  File "/home/wx/maxtext/MaxText/max_utils.py", line 159, in setup_initial_state 
    state = pjit( 
ValueError: One of pjit outputs with pytree key path .params['decoder']['decoder']['self_attention']['key_layer_norm']['scale'].value was given the sharding of NamedSharding(mesh={'data': 1, 'fsdp': 192, 'tensor': 1}, spec=PartitionSpec('fsdp', None)), which implies that the global size of its dimension 0 should be divisible by 192, but it is equal to 256 (full shape: (256, 32))

Does this mean the sharding spec is incompatible with the tensor shape? Below are the commands I used to set up (on the main branch with jax-0.4.10) and run the experiment. Any ideas on what went wrong here?

$ gcloud compute tpus tpu-vm ssh tpuv4 --zone=us-central2-b --worker=all --command="git clone https://github.com/google/maxtext.git" 
$ gcloud compute tpus tpu-vm ssh tpuv4 --zone=us-central2-b --worker=all --command="cd maxtext; sudo bash setup.sh" 
$ gcloud compute tpus tpu-vm ssh tpuv4 --zone=us-central2-b --worker=all --command="export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE'" 
$ gcloud compute tpus tpu-vm ssh tpuv4 --zone=us-central2-b --worker=all --command="python3 maxtext/MaxText/train.py maxtext/MaxText/configs/base.yml run_name=max_52B base_output_directory=gs://wx/max/ dataset_path=gs://maxtext_dt/ enable_profiler=true enable_checkpointing=false steps=10 ici_fsdp_parallelism=192 ici_tensor_parallelism=1 scale=4 base_num_decoder_layers=8 per_device_batch_size=10 remat_policy=full base_emb_dim=3072 base_mlp_dim=12288 learning_rate=1e-8" 
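The error message states the constraint directly: a dimension sharded over a mesh axis must be divisible by that axis's size, and 256 is not a multiple of 192 (256 = 192 + 64). A small pure-Python sketch of the same check can test a parallelism setting against parameter shapes before launching on the slice. `sharding_errors` is hypothetical; the mesh sizes and the (256, 32) shape come from the traceback above.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union

AxisSpec = Optional[Union[str, Tuple[str, ...]]]


def sharding_errors(
    shape: Sequence[int],
    spec: Sequence[AxisSpec],
    mesh: Dict[str, int],
) -> List[str]:
    """Report dimensions whose size is not divisible by the number of shards
    implied by spec (the product of the named mesh axis sizes)."""
    errors = []
    # Dimensions beyond len(spec) are replicated, as with a short PartitionSpec.
    for dim, (size, axes) in enumerate(zip(shape, spec)):
        if axes is None:
            continue  # replicated along this dimension
        names = (axes,) if isinstance(axes, str) else tuple(axes)
        num_shards = 1
        for name in names:
            num_shards *= mesh[name]
        if size % num_shards != 0:
            errors.append(f"dim {dim}: size {size} not divisible by {num_shards} ({'*'.join(names)})")
    return errors
```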

Can AQT be used to calculate qk score?

Thanks to the authors for this contribution. I see that the code only applies AQT to computations involving parameters. Can AQT be used to compute the QK attention scores, or scores * V, in the attention calculation?

I am training the model with jax==0.4.23 on a TPU v5p-8.
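To make the question concrete, here is what quantizing the QK score computation means numerically, using plain symmetric per-tensor int8 quantization in pure Python. This is not the AQT API and says nothing about whether MaxText supports it; `quantize_int8` and `int8_qk_scores` are hypothetical. The same idea would apply to scores * V.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_int8(x: Matrix) -> Tuple[List[List[int]], float]:
    """Symmetric per-tensor int8 quantization: x is approximately q * scale."""
    max_abs = max(abs(v) for row in x for v in row) or 1.0
    scale = max_abs / 127.0
    q = [[max(-127, min(127, round(v / scale))) for v in row] for row in x]
    return q, scale


def int8_qk_scores(query: Matrix, key: Matrix) -> Matrix:
    """Approximate query @ key.T: integer dot products, then one rescale."""
    q_int, q_scale = quantize_int8(query)
    k_int, k_scale = quantize_int8(key)
    return [
        [sum(a * b for a, b in zip(q_row, k_row)) * q_scale * k_scale for k_row in k_int]
        for q_row in q_int
    ]
```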

`attend_dtype` not used

Here it seems that a hard-coded bfloat16 is used instead of attend_dtype, and query is not cast at all. I guess the correct behavior would be to cast both query and self.embedding to attend_dtype?
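A minimal JAX sketch of the behavior the issue proposes, casting both operands to attend_dtype before the matmul. `attend` here is a standalone function with hypothetical argument names, not MaxText's actual Embed.attend method.

```python
import jax.numpy as jnp


def attend(query, embedding, attend_dtype):
    """Logits = query @ embedding.T, with both operands cast to attend_dtype.

    query: [..., embed_dim]; embedding: [vocab_size, embed_dim].
    """
    query = jnp.asarray(query, dtype=attend_dtype)
    embedding = jnp.asarray(embedding, dtype=attend_dtype)
    return jnp.dot(query, embedding.T)
```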

Issues running test_llama2_7b.sh on TPU VM v3-8

Hi!
I was trying to run test_llama2_7b.sh following the default instructions on a TPU VM with v3-8 TPUs.
I was able to successfully run the script up to the fine-tuning step, which uses this command:
python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items run_name=runner_finetuning_2024-04-01-01-24 base_output_directory=gs://MY_BUCKET_NAME dataset_path=gs://MY_BUCKET_NAME async_checkpointing=false per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 steps=10 max_target_length=1024 per_device_batch_size=1 checkpoint_period=5

I got the following traceback:

2024-04-01 02:06:24.369879: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Updating keys from env and command line: ['run_name', 'model_name', 'load_parameters_path', 'async_checkpointing', 'checkpoint_period', 'base_output_directory', 'ici_tensor_parallelism', 'dataset_path', 'per_device_batch_size', 'steps', 'max_target_length']
Running Model: llama2-7b
Updating following parameters in config

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 32
base_mlp_dim: 11008
base_num_decoder_layers: 32
head_dim: 128
mlp_activations: ['silu', 'linear']
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1e-05
decoder_block: llama2
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block']
2024-04-01 02:06:31.698078: I external/xla/xla/pjrt/pjrt_c_api_client.cc:137] PjRtCApiClient created.
System Information: Jax Version: 0.4.25
System Information: Jaxlib Version: 0.4.25
System Information: Jax Backend: PJRT C API
TFRT TPU v3
Built on Feb 24 2024 03:12:26 (1708773146) cl/609954703
Config param adam_b1: 0.9
Config param adam_b2: 0.95
Config param adam_eps: 1e-08
Config param adam_eps_root: 0.0
Config param adam_weight_decay: 0.1
Config param async_checkpointing: False
Config param attention: autoselected
Config param autoregressive_decode_assert:
Config param base_emb_dim: 4096
Config param base_mlp_dim: 11008
Config param base_num_decoder_layers: 32
Config param base_num_kv_heads: 32
Config param base_num_query_heads: 32
Config param base_output_directory: gs://MY_BUCKET_NAME
Config param checkpoint_dir: gs://MY_BUCKET_NAME/runner_finetuning_2024-04-01-01-24/checkpoints/
Config param checkpoint_period: 5
Config param collect_stack_trace: False
Config param compile_topology:
Config param compile_topology_num_slices: -1
Config param compiled_trainstep_file:
Config param cosine_learning_rate_final_fraction: 0.1
Config param data_sharding: (('data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)
Config param data_shuffle_seed: 0
Config param dataset_name: c4/en:3.0.1
Config param dataset_path: gs://MY_BUCKET_NAME
Config param dataset_type: c4
Config param dcn_autoregressive_parallelism: 1
Config param dcn_data_parallelism: -1
Config param dcn_fsdp_parallelism: 1
Config param dcn_fsdp_transpose_parallelism: 1
Config param dcn_sequence_parallelism: 1
Config param dcn_tensor_parallelism: 1
Config param decode_sampling_nucleus_p: -1
Config param decode_sampling_strategy: greedy
Config param decode_sampling_temperature: 1.0
Config param decode_sampling_top_k: 0
Config param decoder_block: llama2
Config param dropout_rate: 0
Config param dtype: bfloat16
Config param emb_dim: 4096
Config param enable_checkpointing: True
Config param enable_data_shuffling: True
Config param enable_dropout: False
Config param enable_profiler: False
Config param enable_single_replica_ckpt_restoring: False
Config param eval_dataset_name: c4/en:3.0.1
Config param eval_interval: -1
Config param eval_per_device_batch_size: 0
Config param eval_split: validation
Config param force_unroll: False
Config param fused_mlp: False
Config param fused_qkv: False
Config param gcs_metrics: False
Config param global_batch_size_to_load: 8
Config param global_batch_size_to_train_on: 8
Config param global_parameter_scale: 1
Config param gradient_clipping_threshold: 1.0
Config param grain_worker_count: 4
Config param hardware: tpu
Config param head_dim: 128
Config param ici_autoregressive_parallelism: 1
Config param ici_data_parallelism: 1
Config param ici_fsdp_parallelism: -1
Config param ici_fsdp_transpose_parallelism: 1
Config param ici_sequence_parallelism: 1
Config param ici_tensor_parallelism: 4
Config param init_weights_seed: 0
Config param jax_cache_dir: ~/jax_cache
Config param learning_rate: 3e-05
Config param learning_rate_schedule_steps: 10
Config param load_from_prefill_dir: False
Config param load_full_state_path:
Config param load_parameters_path: gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items
Config param log_period: 100
Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('kv', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()))
Config param logits_dot_in_fp32: True
Config param logits_via_embedding: False
Config param max_corpus_chars: 10000000
Config param max_prefill_predict_length: 64
Config param max_target_length: 1024
Config param mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
Config param metrics_dir: gs://MY_BUCKET_NAME/runner_finetuning_2024-04-01-01-24/metrics/
Config param metrics_file:
Config param mlp_activations: ['silu', 'linear']
Config param mlp_dim: 11008
Config param model_name: llama2-7b
Config param normalization_layer_epsilon: 1e-05
Config param normalize_embedding_logits: True
Config param num_decoder_layers: 32
Config param num_experts: 1
Config param num_experts_per_tok: 1
Config param num_kv_heads: 32
Config param num_query_heads: 32
Config param num_slices: 1
Config param opt_type: adamw
Config param param_scan_axis: 1
Config param per_device_batch_size: 1.0
Config param prefill_cache_dir:
Config param profiler_steps: 5
Config param prompt: I love to
Config param quantization:
Config param quantization_local_shard_count: 1
Config param quantize_kvcache: False
Config param record_internal_nn_metrics: 0
Config param remat_policy: full
Config param reuse_example_batch: 0
Config param run_name: runner_finetuning_2024-04-01-01-24
Config param save_config_to_gcs: False
Config param scan_layers: True
Config param skip_first_n_steps_for_profiler: 1
Config param stack_trace_interval_seconds: 600
Config param stack_trace_to_cloud: False
Config param steps: 10
Config param target_eval_loss: 0.0
Config param tensorboard_dir: gs://MY_BUCKET_NAME/runner_finetuning_2024-04-01-01-24/tensorboard/
Config param tokenizer_path: assets/tokenizer.llama2
Config param trainable_position_size: -1
Config param upload_all_profiler_results: False
Config param use_iota_embed: False
Config param use_untrainable_positional_embedding: False
Config param vocab_size: 32000
Config param warmup_steps_fraction: 0.1
Config param weight_dtype: float32
Creating checkpoint manager...
I0401 02:06:32.866911 140289813645312 checkpoint_manager.py:1040] Found 0 checkpoint steps in gs://MY_BUCKET_NAME/runner_finetuning_2024-04-01-01-24/checkpoints
I0401 02:06:32.867159 140289813645312 checkpoint_manager.py:484] jax.process_index=0, primary_host=0. CheckpointManager created: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7f960fb3dff0>
Checkpoint manager created!
I0401 02:06:32.867487 140289813645312 mesh_utils.py:73] Reordering mesh to physical ring order on single-tray TPU v2/v3.
Num_devices: 8, shape (1, 2, 1, 1, 4, 1)
I0401 02:06:33.556910 140289813645312 dataset_info.py:610] Load dataset info from gs://MY_BUCKET_NAME/c4/en/3.0.1
I0401 02:06:34.487275 140289813645312 dataset_info.py:702] For 'c4/en/3.0.1': fields info.[splits] differ on disk and in the code. Keeping the one from code.
I0401 02:06:34.662644 140289813645312 reader.py:261] Creating a tf.data.Dataset reading 1024 files located in folders: gs://MY_BUCKET_NAME/c4/en/3.0.1.
I0401 02:06:34.819458 140289813645312 logging_logger.py:49] Constructing tf.data.Dataset c4 for split train, from gs://MY_BUCKET_NAME/c4/en/3.0.1
I0401 02:06:35.186302 140289813645312 dataset_info.py:610] Load dataset info from gs://MY_BUCKET_NAME/c4/en/3.0.1
I0401 02:06:36.488089 140289813645312 dataset_info.py:702] For 'c4/en/3.0.1': fields info.[splits] differ on disk and in the code. Keeping the one from code.
I0401 02:06:36.609976 140289813645312 reader.py:261] Creating a tf.data.Dataset reading 8 files located in folders: gs://MY_BUCKET_NAME/c4/en/3.0.1.
I0401 02:06:36.685018 140289813645312 logging_logger.py:49] Constructing tf.data.Dataset c4 for split validation, from gs://MY_BUCKET_NAME/c4/en/3.0.1
Tokenizer path: assets/tokenizer.llama2
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
checkpoint manager exists so trying to load this run's existing checkpoint
restoring params from load_parameters_from_path='gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items'
I0401 02:06:40.339250 140289813645312 checkpointer.py:166] Restoring item from gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1711937201.278496   28101 gcs_resource.cc:109] Using default AdmissionQueue with limit 32
I0000 00:00:1711937201.282484   29547 google_auth_provider.cc:180] Running on GCE, using service account [email protected]
W0401 02:07:25.916970 140289813645312 transform_utils.py:229] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0401 02:07:25.947791 140289813645312 checkpointer.py:169] Finished restoring checkpoint from gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items.
number parameters: 6.738 billion
Per train step:
 Total TFLOPs: 42.23
 split as 98.05% learnable weight flops and 1.95% attention flops
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/usr/maxtext/MaxText/train.py", line 497, in <module>
    app.run(main)
  File "/home/usr/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/usr/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/usr/maxtext/MaxText/train.py", line 493, in main
    train_loop(config)
  File "/home/usr/maxtext/MaxText/train.py", line 433, in train_loop
    state, metrics = p_train_step(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.

The MLIR operation involved:
  %1481 = "tpu.matmul"(%1478, %1479, %1480) {transpose_lhs = false, transpose_rhs = true} : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.

I've also printed an example batch from the code that caused the error:

{'inputs': Array([[    1,  2023, 14606, ...,     0,     0,     0],
       [    1,  2567,   393, ...,     0,     0,     0],
       [    1,   390,  2965, ...,     0,     0,     0],
       ...,
       [    1,   887, 30010, ...,   367, 16010,   746],
       [    1, 12547,   393, ..., 29915, 29885, 10932],
       [    1,  1383,   279, ...,     0,     0,     0]], dtype=int32), 'inputs_position': Array([[   0,    1,    2, ...,    0,    0,    0],
       [   0,    1,    2, ...,    0,    0,    0],
       [   0,    1,    2, ...,    0,    0,    0],
       ...,
       [   0,    1,    2, ..., 1021, 1022, 1023],
       [   0,    1,    2, ..., 1021, 1022, 1023],
       [   0,    1,    2, ...,    0,    0,    0]], dtype=int32), 'inputs_segmentation': Array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32), 'targets': Array([[ 2023, 14606,   437, ...,     0,     0,     0],
       [ 2567,   393,   366, ...,     0,     0,     0],
       [  390,  2965, 15444, ...,     0,     0,     0],
       ...,
       [  887, 30010,   345, ..., 16010,   746,   372],
       [12547,   393,   385, ..., 29885, 10932,   393],
       [ 1383,   279,  9010, ...,     0,     0,     0]], dtype=int32), 'targets_position': Array([[   0,    1,    2, ...,    0,    0,    0],
       [   0,    1,    2, ...,    0,    0,    0],
       [   0,    1,    2, ...,    0,    0,    0],
       ...,
       [   0,    1,    2, ..., 1021, 1022, 1023],
       [   0,    1,    2, ..., 1021, 1022, 1023],
       [   0,    1,    2, ...,    0,    0,    0]], dtype=int32), 'targets_segmentation': Array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)}

I was using default settings (except bucket names, of course)
Could you please point me in the right direction to fix this error?

How to use GPT2 tokenizer

Thank you so much for releasing this amazing repo! I'm wondering how I can use the GPT-2 tokenizer in your data processing pipeline. Thank you very much for your time and help.

[Bug] adam_pax has reuse donated buffer warning

Hi, I noticed that using adam_pax instead of adamw as the optimizer produces a "reuse donated buffer" warning. I am wondering whether this is expected, and why the code uses adam_pax instead of the standard optax.adam, as it does for adamw.

Thank you very much for your help! @rwitten

Issues running end_to_end/test_mistral.sh

When running bash end_to_end/test_mistral.sh, this line

gsutil -m cp -r gs://maxtext-external/mistral-7B-v0.1 /tmp

leads to the following error:

AccessDeniedException: 403 [email protected] does not have storage.objects.list access to the Google Cloud Storage bucket. Permission 'storage.objects.list' denied on resource (or it may not exist).
CommandException: 1 file/object could not be transferred.

Has the path changed, or can I get the needed files from somewhere else?

Edit:
Running

git clone https://huggingface.co/mistralai/Mistral-7B-v0.1

and using the corresponding path for --base-model-path leads to this error:

Loading the base model from /home/MMP/Mistral-7B-v0.1
Traceback (most recent call last):
  File "/home/MMP/maxtext/MaxText/llama_or_mistral_ckpt.py", line 309, in <module>
    convert(args.base_model_path, args.maxtext_model_path, args.model_size)
  File "/home/MMP/maxtext/MaxText/llama_or_mistral_ckpt.py", line 127, in convert
    'scale': pytorch_vars[0]['norm.weight'].type(torch.float16).numpy()
IndexError: list index out of range
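The IndexError means pytorch_vars is an empty list, i.e. the converter loaded no checkpoint shards from the given directory. A cloned Hugging Face repo contains .safetensors/.bin files, so a converter that looks for another file type would find nothing. Below is a minimal sketch of a defensive check that fails with a readable message instead; the "*.pth" glob pattern is an assumption about what the converter expects, and find_checkpoint_shards is a hypothetical helper, not part of MaxText.

```python
import pathlib


def find_checkpoint_shards(base_model_path, pattern="*.pth"):
    """Return sorted checkpoint shard paths, failing loudly if none match.

    Hypothetical helper: the "*.pth" pattern is an assumption about the
    files the converter expects, not taken from MaxText itself.
    """
    root = pathlib.Path(base_model_path)
    paths = sorted(root.glob(pattern))
    if not paths:
        # List what is actually there so the mismatch is obvious.
        found = sorted(p.name for p in root.iterdir())
        raise FileNotFoundError(
            f"No files matching {pattern!r} in {base_model_path}; found: {found}"
        )
    return paths
```

Calling this before indexing into the loaded shards turns "list index out of range" into a message listing the files that were actually present.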

Job names in Kubernetes exceed the limit of 40 characters

I have been trying to delete a job, but XPK requires job names to be at most 40 characters long.
In my case, the GKE job name had 44 characters. To work around the problem, I bumped the limit to 45 and was able to delete the job with python3 xpk/xpk.py workload delete.

I don't know if the job name limitation has changed recently, but I wanted to give a heads-up about this issue.

if not match or len(match.group(0)) > 40:
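The quoted line rejects any name whose match is longer than 40 characters. Below is a minimal sketch of an equivalent validator with the limit as a parameter, so the 44-character case from this report can be checked directly. The regex (a lowercase, DNS-label-style name) and the function name are assumptions for illustration, not XPK's actual code.

```python
import re

# Assumed pattern: a lowercase letter first, then lowercase letters, digits,
# or '-', ending in a letter or digit (DNS-label style). Not copied from XPK.
NAME_PATTERN = re.compile(r"[a-z]([-a-z0-9]*[a-z0-9])?")


def validate_workload_name(name, max_len=40):
    """Return True if `name` fully matches the pattern and fits within max_len."""
    match = NAME_PATTERN.fullmatch(name)
    # Mirrors the quoted check: reject on no match or when longer than max_len.
    if not match or len(match.group(0)) > max_len:
        return False
    return True
```

With the default limit, a 44-character name is rejected; passing max_len=45 reproduces the local workaround described above.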

Any tests on larger models?

  1. Have you run efficiency tests on models with more than 100B parameters?

  2. If so, could I get a sample configuration?

TPUv2-8 multislice

I'm trying to create two preemptible TPU v2-8 slices but got:

gcloud alpha compute tpus queued-resources create mega --accelerator-type v2-8 --runtime-version tpu-ubuntu2204-base --node-count 2 --node-prefix slice
ERROR: (gcloud.alpha.compute.tpus.queued-resources.create) INVALID_ARGUMENT: Cloud TPU was unable to complete the operation. Please try again, or contact support if the problem persists. [EID: 0x98e5db490ee44df9]

Converting checkpoints

Are there any scripts available for converting trained Gemma/Llama/Mistral MaxText checkpoints to HuggingFace?

[Question] Are there any training replication results?

Hi,

Thanks for the library! I'm new to the JAX+LLM ecosystem and trying to understand which library I should be using.

I see a lot of (very impressive) computational efficiency benchmarks for MaxText, but I can't find any benchmarks of model quality. Do you have perplexity or eval results for a model trained with MaxText in standard settings? For example, nanoGPT on WikiText (evaluated with perplexity or MMLU), or Llama fine-tuning on Vicuna or Alpaca data? I think this would be very useful for deciding which JAX library to use for training LLMs!

Thank you for your help!

A pip error occurs when running setup.sh.

Environment: GCP TPU node, v2-8, tpu-vm-base, us-central1-c

pip fails while resolving dependencies:

Successfully installed pip-24.0
/usr/local/lib/python3.8/dist-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.1.36ubuntu1 is an invalid version and will not be supported in a future release
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.23ubuntu1 is an invalid version and will not be supported in a future release
  warnings.warn(
Collecting git+https://github.com/mlperf/logging.git (from -r requirements.txt (line 24))
  Cloning https://github.com/mlperf/logging.git to /tmp/pip-req-build-3nnn5u08
  Running command git clone -q https://github.com/mlperf/logging.git /tmp/pip-req-build-3nnn5u08
/usr/local/lib/python3.8/dist-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.1.36ubuntu1 is an invalid version and will not be supported in a future release
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.23ubuntu1 is an invalid version and will not be supported in a future release
  warnings.warn(
ERROR: Could not find a version that satisfies the requirement jax>=0.4.23 (from -r requirements.txt (line 1)) (from versions: 0.0, 0.1, 0.1.1, 0.1.2, 0.1.3, 0.1.4, 0.1.5, 0.1.6, 0.1.7, 0.1.8, 0.1.9, 0.1.10, 0.1.11, 0.1.12, 0.1.13, 0.1.14, 0.1.15, 0.1.16, 0.1.18, 0.1.19, 0.1.20, 0.1.21, 0.1.22, 0.1.23, 0.1.24, 0.1.25, 0.1.26, 0.1.27, 0.1.28, 0.1.29, 0.1.30, 0.1.31, 0.1.32, 0.1.33, 0.1.34, 0.1.35, 0.1.36, 0.1.37, 0.1.38, 0.1.39, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.45, 0.1.46, 0.1.47, 0.1.48, 0.1.49, 0.1.50, 0.1.51, 0.1.52, 0.1.53, 0.1.54, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.1.77, 0.2.0, 0.2.1, 0.2.2, 0.2.3, 0.2.4, 0.2.5, 0.2.6, 0.2.7, 0.2.8, 0.2.9, 0.2.10, 0.2.11, 0.2.12, 0.2.13, 0.2.14, 0.2.15, 0.2.16, 0.2.17, 0.2.18, 0.2.19, 0.2.20, 0.2.21, 0.2.22, 0.2.23, 0.2.24, 0.2.25, 0.2.26, 0.2.27, 0.2.28, 0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5, 0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11, 0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.3.17, 0.3.18, 0.3.19, 0.3.20, 0.3.21, 0.3.22, 0.3.23, 0.3.24, 0.3.25, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5, 0.4.6, 0.4.7, 0.4.8, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13)
ERROR: No matching distribution found for jax>=0.4.23 (from -r requirements.txt (line 1))

It is strange that it does not run in the default TPU environment.
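The paths in the log show Python 3.8 (/usr/local/lib/python3.8), and the newest jax version pip can find is 0.4.13, which is consistent with newer jax releases not being published for that interpreter. A minimal pre-install check follows, assuming (not verified here) that the jax>=0.4.23 requirement needs Python 3.9 or newer; MIN_PYTHON and python_supported are illustrative names, not part of setup.sh.

```python
import sys

# Assumption: jax>=0.4.23 is not installable on Python older than 3.9.
MIN_PYTHON = (3, 9)


def python_supported(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the interpreter version is at least `minimum`."""
    return tuple(version_info[:2]) >= minimum
```

Running python_supported() at the top of an install script (and exiting early when it returns False) would surface the interpreter mismatch before pip prints the long version list.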

multihost_runner.py: number of devices does not match the product of the parallelism

I'm testing out multihost training, following along with the README section "Quick Experiments on Multiple Slices" (except that I'm not using queued resources, because those also do not work for me).
I have two v2-8 TPU VMs running, and another non-TPU VM that I'm running commands on.

This is what happens, however:
$ python3 multihost_runner.py --TPU_PREFIX='tpu' --COMMAND='python3 MaxText/train.py MaxText/configs/base.yml run_name=RUN-2-SLICES dcn_data_parallelism=2'

Starting multihost runner...
2 slices found.
[t=0.00, SCP] Completed 0/2, slice 0 worker 0 still working...
[t=1.00, SCP] Completed 0/2, slice 0 worker 0 still working...
[t=2.01, SCP] Completed 0/2, slice 0 worker 0 still working...
SCP: Attempting to connect to worker 0...
[t=3.01, SCP] Completed 0/2, slice 0 worker 0 still working...
[t=4.01, SCP] Completed 0/2, slice 0 worker 0 still working...
script_dir_zip_2023-05-06-11-18-22.tar.gz                                                                                                                                        100%  417KB  60.7MB/s   00:00
[t=5.01, SCP] Completed 0/2, slice 0 worker 0 still working...
[t=6.01, SCP] Completed 2/2...
Running main command, logs located in: /tmp/2023-05-06-11-18-22/
[t=0.01, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
...

Searching for existing processes on device accel0...
No existing processes found.
[t=5.01, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
...

2023-05-06 11:18:36.790902: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:269] Libtpu path is: /home/rlrs_alexandra_gmail_com/.local/lib/python3.8/site-packages/libtpu/libtpu.so
[t=9.02, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
...

2023-05-06 11:18:43.741793: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x358f380 initialized for platform TPU (this does not guarantee that XLA will be used). Devices:
2023-05-06 11:18:43.741840: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): TPU, 2a886c8
2023-05-06 11:18:43.741851: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (1): TPU, 2a886c8
2023-05-06 11:18:43.741861: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (2): TPU, 2a886c8
[t=15.02, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
2023-05-06 11:18:43.741872: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (3): TPU, 2a886c8
2023-05-06 11:18:43.741882: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (4): TPU, 2a886c8
2023-05-06 11:18:43.741891: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (5): TPU, 2a886c8
2023-05-06 11:18:43.741901: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (6): TPU, 2a886c8
2023-05-06 11:18:43.741911: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (7): TPU, 2a886c8
2023-05-06 11:18:43.765380: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient created.
[t=16.02, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
...
2023-05-06 11:19:03.221482: I tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:266] Libtpu path is: /home/rlrs_alexandra_gmail_com/.local/lib/python3.8/site-packages/libtpu/libtpu.so
[t=35.05, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
2023-05-06 11:19:04.591554: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[t=36.05, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
...

Initialized persistent compilation cache at /home/rlrs_alexandra_gmail_com/jax_cache
Found 8 devices.
Creating checkpoint manager...
I0506 11:19:17.302592 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit_convert_element_type' because it took < 1.00 seconds to compile (0.03s)
I0506 11:19:17.325799 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit_broadcast_in_dim' because it took < 1.00 seconds to compile (0.02s)
I0506 11:19:17.455217 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'pmap__psum' because it took < 1.00 seconds to compile (0.12s)
I0506 11:19:17.484924 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit_broadcast_in_dim' because it took < 1.00 seconds to compile (0.02s)
I0506 11:19:17.611627 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'pmap__psum' because it took < 1.00 seconds to compile (0.11s)
I0506 11:19:17.614650 140423533554752 distributed.py:68] Starting JAX distributed service on 10.128.0.11:33429
2023-05-06 11:19:17.615705: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:522] Experimental coordination service is enabled.
2023-05-06 11:19:17.615970: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:554] Jax service listening on 10.128.0.11:33429
I0506 11:19:17.616165 140423533554752 distributed.py:79] Connecting to JAX distributed service on 10.128.0.11:33429
2023-05-06 11:19:17.616817: I external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service.cc:525] /job:jax_worker/replica:0/task:0 has connected to coordination service. Incarnation: 7064756694357617212
2023-05-06 11:19:17.616980: I external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.cc:298] Coordination agent has successfully connected.
2023-05-06 11:19:17.617244: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:508] Connected to distributed JAX controller
[t=49.06, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
I0506 11:19:18.599849 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'pmap__psum' because it took < 1.00 seconds to compile (0.12s)
[t=50.06, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
Checkpoint manager created!
I0506 11:19:18.818627 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit_shift_right_logical' because it took < 1.00 seconds to compile (0.02s)
I0506 11:19:18.839082 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit_convert_element_type' because it took < 1.00 seconds to compile (0.02s)
I0506 11:19:18.860087 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit_reshape' because it took < 1.00 seconds to compile (0.02s)
I0506 11:19:18.882867 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit__lambda_' because it took < 1.00 seconds to compile (0.02s)
I0506 11:19:18.906052 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit_concatenate' because it took < 1.00 seconds to compile (0.02s)
I0506 11:19:19.128638 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit__threefry_split_original' because it took < 1.00 seconds to compile (0.11s)
I0506 11:19:19.153988 140423533554752 dispatch.py:1131] Not writing persistent cache entry for 'jit__unstack' because it took < 1.00 seconds to compile (0.02s)
Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)] (num_devices: 8)
Traceback (most recent call last):
  File "MaxText/train.py", line 349, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "MaxText/train.py", line 345, in main
    train_loop(pyconfig.config)
  File "MaxText/train.py", line 278, in train_loop
    devices_array = max_utils.create_device_mesh(config)
  File "/home/rlrs_alexandra_gmail_com/2023-05-06-11-18-22/MaxText/max_utils.py", line 78, in create_device_mesh
    assert (
**AssertionError: Number of devices 8         does not match the product of the parallelism 16**
2023-05-06 11:19:19.174563: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:516] Distributed task shutdown initiated.
2023-05-06 11:19:19.174587: I external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.cc:464] Coordination agent has initiated Shutdown().
2023-05-06 11:19:19.174862: I external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service.cc:1136] Shutdown barrier in coordination service has passed.
2023-05-06 11:19:19.174911: I external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service.cc:630] /job:jax_worker/replica:0/task:0 has disconnected from coordination service.
2023-05-06 11:19:19.175025: I external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.cc:483] Coordination agent has successfully shut down.
2023-05-06 11:19:19.175146: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:518] Distributed task shutdown result: OK
2023-05-06 11:19:19.175165: I external/org_tensorflow/tensorflow/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:166] Cancelled call to retrieve preemption notice. This is expected upon program shutdown.
2023-05-06 11:19:19.175203: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:576] Jax service shutting down
2023-05-06 11:19:19.176923: I external/org_tensorflow/tensorflow/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:139] Preemption sync protocol cancelled by notifier: CANCELLED: Preemption notifier is being deleted.. This is expected during program shutdown.
[t=51.07, MAIN COMMAND] Completed 0/2, slice 0 worker 0 still working...
Exception ignored in: <function GCSRecordWriter.__del__ at 0x7fb4b803daf0>
Traceback (most recent call last):
... etc

It seems the TPU node is not aware that it is part of a multislice setup. What am I doing wrong?

Edit:
A little update from my side: I realize that jax.distributed may not actually support connecting TPUs across pods. Even if I manually connect two v2-8 VMs with jax.distributed.initialize(), they end up with jax.process_count() returning 1.
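The assertion compares the number of devices the process can see (8) with the product of all configured parallelism degrees (16). A minimal sketch of that arithmetic follows, assuming the product is taken over the DCN degrees times the ICI degrees; check_mesh is a hypothetical re-creation, not the actual max_utils code. With dcn_data_parallelism=2 and an ICI product of 8 (implied by the 16 in the message), the check only passes if one process sees the devices of both slices.

```python
import math


def check_mesh(num_devices, dcn_parallelism, ici_parallelism):
    """Raise if the device count differs from the product of all parallelism degrees.

    Hypothetical re-creation of the check behind the AssertionError above.
    """
    product = math.prod(dcn_parallelism) * math.prod(ici_parallelism)
    if num_devices != product:
        raise AssertionError(
            f"Number of devices {num_devices} does not match "
            f"the product of the parallelism {product}"
        )
    return product
```

Here each process sees only its own 8 devices (consistent with jax.process_count() returning 1), so check_mesh(8, [2], [8]) fails exactly as in the log, while check_mesh(16, [2], [8]) would pass.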

FAILED_PRECONDITION: TPU platform already registered for platform hardware version

With various minor or trivial modifications to train.py, I quickly hit the error

FAILED_PRECONDITION: TPU platform already registered for platform hardware version

This happens either on the first random.split if checkpointing is disabled, or inside of the checkpointer if it is enabled.

It looks like calling jax.devices() very early on (in the imports) solves this issue.

Minimal Repro: See https://github.com/google/maxtext/tree/import-fun
The difference between an error and working code is a single print("hello") statement.

TFDS Data Processing Pipeline

Hi, I'm trying to understand the TFDS data processing pipeline in your repo, and I'm confused about the following points:

In _tfds_data_processing.py:

(1) The truncate_to_max_allowable_length function truncates each sequence to be less than max_length. Does this mean we will simply discard and waste the text that exceeds max_length?

(2) The shuffle_buffer_size defaults to 1024, which seems very small compared to the total number of sequences in a modern dataset.

(3) The dataset is first shuffled and then repeated for num_epochs time. But shouldn't we first repeat and then shuffle to guarantee different orders in different epochs?
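Questions (2) and (3) can be illustrated with a small pure-Python model of a buffered shuffle. This is a sketch that assumes the buffer behaves like tf.data.Dataset.shuffle (fill a buffer, emit a random element, refill from the stream); it is not the tf.data implementation. In this model, an element can be emitted at most buffer_size - 1 positions earlier than its input position, so a small buffer only shuffles locally. Also note that tf.data's shuffle reshuffles on every iteration by default (reshuffle_each_iteration=True), so shuffle-then-repeat still gives a different order per epoch while keeping epoch boundaries; shuffle_then_repeat below models that.

```python
import random


def buffered_shuffle(stream, buffer_size, rng):
    """Yield items from `stream` in a locally shuffled order (tf.data-style model)."""
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit one randomly chosen element.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # End of stream: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


def shuffle_then_repeat(data, buffer_size, num_epochs, seed=0):
    """Shuffle each epoch independently, modeling reshuffle_each_iteration=True."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        yield from buffered_shuffle(iter(data), buffer_size, rng)
```

Each epoch produced by shuffle_then_repeat contains every element exactly once; repeating first and then shuffling would instead let elements from neighboring epochs mix inside the buffer.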

In sequence_packing.py:

(4) The comment says pack_dataset.py does greedy sequence packing. What does that mean? Does it pack the portions of different sequences that exceed max_length into a new sequence?

(5) map_fn is commented as "Internal function to flat_map over". Is this code working well externally?
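For question (4), "greedy" packing usually means filling each output row with whole sequences in arrival order and starting a new row when the next sequence no longer fits. Below is a minimal pure-Python sketch under that assumption; it is not the code in sequence_packing.py, and greedy_pack is a hypothetical name. It produces the same kind of segmentation and position arrays as the example batch printed earlier on this page, with 0 marking padding, and assumes sequences were already truncated to max_length as in question (1).

```python
def greedy_pack(sequences, max_length):
    """Pack token lists into rows of max_length using next-fit greedy packing.

    Returns a list of (tokens, segment_ids, positions) rows, each padded with 0.
    Assumes every sequence already has len <= max_length (truncated upstream).
    """
    rows = []
    tokens, segments, positions = [], [], []
    segment_id = 0
    for seq in sequences:
        if len(seq) > max_length:
            raise ValueError("sequence longer than max_length; truncate first")
        if len(tokens) + len(seq) > max_length:
            # Next sequence does not fit: close the current row, start a new one.
            rows.append((tokens, segments, positions))
            tokens, segments, positions = [], [], []
            segment_id = 0
        segment_id += 1  # segment ids start at 1; 0 is reserved for padding
        tokens += seq
        segments += [segment_id] * len(seq)
        positions += list(range(len(seq)))
    if tokens:
        rows.append((tokens, segments, positions))

    def pad(xs):
        return xs + [0] * (max_length - len(xs))

    return [tuple(pad(x) for x in row) for row in rows]
```

In this model no sequence is split across rows; leftover space at the end of a row is simply padded.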

Thank you very much for your time and help!

Support beam search

Hi,

It would be nice to support beam search.

There is a reference Flax implementation in the WMT example and an equivalent one in transformers.

I am guessing that we could:

  • duplicate inputs per num_beams initially
  • at each step we do:
    • decode_step
    • select top beams
    • overwrite entire past cache per selected beams
    • update cache with new selected tokens

So maybe the extra step here is to add the "overwrite entire past cache per selected beams"?
I'm curious whether you have suggestions for the implementation.
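The per-step loop described above can be sketched in NumPy for a single example (batch size 1). This is an illustrative sketch, not the Flax WMT or transformers implementation; beam_search_step is hypothetical, and the KV cache is modeled as an array whose leading axis is the beam axis. The cache reorder (cache[parent]) is the "overwrite entire past cache per selected beams" step: each surviving beam receives a copy of its parent's cache, possibly duplicated.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    m = logits.max(axis=-1, keepdims=True)
    return logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))


def beam_search_step(beam_scores, logits, cache):
    """One beam-search step for a single example.

    beam_scores: (num_beams,) cumulative log-probabilities.
    logits:      (num_beams, vocab) next-token logits from decode_step.
    cache:       (num_beams, ...) per-beam KV cache (any trailing shape).
    Returns new scores, chosen tokens, parent beam indices, and reordered cache.
    """
    num_beams, vocab = logits.shape
    candidates = beam_scores[:, None] + log_softmax(logits)  # (num_beams, vocab)
    flat = candidates.reshape(-1)
    # Top num_beams candidates overall (stable sort keeps ties deterministic).
    top = np.argsort(-flat, kind="stable")[:num_beams]
    parent = top // vocab
    token = top % vocab
    # Overwrite the cache: each new beam takes a copy of its parent's cache.
    new_cache = cache[parent]
    return flat[top], token, parent, new_cache
```

The returned tokens are then fed to the next decode_step, which writes them into the reordered cache ("update cache with new selected tokens").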

Compatibility issue with tensorflow>=2.15.1 on GPU

Hi team,

I'm having an issue launching the pretraining job with TensorFlow 2.15 or above. TensorFlow 2.15 segfaults immediately. With the latest TensorFlow, 2.16.1, I see unbounded (near 100%) GPU memory growth in one of the data-loading processes, leading to CUDA OOM and cascading failures. One quick workaround is to locally install a lower TensorFlow version, e.g.

pip install tensorflow==2.13.1 tensorflow-text==2.13.0

Also works:

pip install tensorflow==2.14.1 tensorflow-text==2.14.0

Grain vs. `tf.data` Input Pipeline

Hello MaxText Team,

I'm in the process of deciding between Grain and tf.data for the input pipeline of my next research project, which will be based on Flax. While I recognize the potential benefits of both, I'm particularly interested in understanding the advantages of Grain over tf.data from your experience, especially regarding runtime performance.

Could you share any benchmarks or insights on the runtime performance differences you observed between the two input pipelines? Additionally, any details on scenarios where Grain particularly outshines tf.data or vice versa would be incredibly helpful for my decision-making process.

Thank you very much for your time and for any information you can provide.

All the best,

Convert Gemma weights

Hi,

Could you confirm which commit of google/grain to use when converting the Gemma weights?

It returns an error when using the latest commits of both maxtext and grain, as reported in google/grain#333:

~/maxtext$ python MaxText/convert_gemma_chkpt.py --base_model_path $CHKPT_BUCKET/2b --maxtext_model_path $MODEL_BUCKET/2b --model_size 2b
Traceback (most recent call last):
  File "/home/boris/maxtext/MaxText/convert_gemma_chkpt.py", line 33, in <module>
    import checkpointing
  File "/home/boris/maxtext/MaxText/checkpointing.py", line 25, in <module>
    import grain.python as grain
  File "/home/boris/grain/grain/python.py", line 21, in <module>
    from . import python_experimental as experimental
  File "/home/boris/grain/grain/python_experimental.py", line 22, in <module>
    from . import python_lazy_dataset as lazy_dataset
  File "/home/boris/grain/grain/python_lazy_dataset.py", line 54, in <module>
    from ._src.python.lazy_dataset.transformations.shuffle import ShuffleLazyMapDataset
  File "/home/boris/grain/grain/_src/python/lazy_dataset/transformations/shuffle.py", line 18, in <module>
    from grain._src.python.experimental.index_shuffle.python import index_shuffle_module as index_shuffle
ImportError: cannot import name 'index_shuffle_module' from 'grain._src.python.experimental.index_shuffle.python' (unknown location)

Convert Gemma weights with scan layers

Hi,

It would be nice to be able to convert the Gemma checkpoints to support scan layers.
This will allow faster compilation for training & inference.

Thanks
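Converting to scanned layers mainly means stacking each per-layer parameter along a new layer axis so that nn.scan can iterate over it. A minimal NumPy sketch follows, assuming the unscanned checkpoint is a list of per-layer parameter trees (nested dicts) with identical structure. The function name is hypothetical, and which axis the scanned model expects for the layer dimension is an assumption you should check against your config.

```python
import numpy as np


def stack_layers(per_layer_params, axis=0):
    """Stack a list of identically structured per-layer param trees.

    Each leaf becomes one array with a new layer dimension inserted at `axis`.
    The axis is an assumption; it must match what the scanned model expects.
    """
    first = per_layer_params[0]
    if isinstance(first, dict):
        # Recurse into nested dicts key by key.
        return {
            key: stack_layers([layer[key] for layer in per_layer_params], axis)
            for key in first
        }
    return np.stack([np.asarray(leaf) for leaf in per_layer_params], axis=axis)
```

For example, three layers with a (4, 8) "wi" matrix each become a single (3, 4, 8) array with axis=0, or (4, 3, 8) with axis=1.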

Local development instructions don't work

Hello! I'm not able to use the local development instructions. Using a v4-8, I have cloned the repo, created a new conda env, and run the setup script inside the new env:

$ git clone git@github.com:google/maxtext.git
$ conda create --name maxtext python=3.10
$ bash setup.sh

However, when I try to run decoding, it fails:

$ python3 MaxText/decode.py MaxText/configs/base.yml run_name=test
... (long traceback)
AssertionError: Failed to construct dataset c4Dataset c4 cannot be loaded at version 3.0.1, only: 2.3.0, 2.2.1, 2.2.0.

Do you have a workaround for this? Are there perhaps instructions for a Docker-based install?

How to calculate MFU as shown in readme?

A question for the developers: the README page shows the MFU achieved on TPU VMs.

If I make modifications and run the same workload on a TPU VM, how can I calculate the MFU so that I can benchmark the performance gain or loss from my changes?

PS: I looked through the code for a script that calculates MFU but couldn't find one. If I have missed something, please point me to it.
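MFU is commonly computed as achieved model FLOP/s divided by the hardware's peak FLOP/s. A minimal sketch follows, assuming you already have the model FLOPs per step (for example, the "Total TFLOPs" figure MaxText logs per train step, shown earlier on this page; confirm whether it is per device or for the whole job) and a measured step time. The peak FLOP/s per chip must come from the hardware spec and is a parameter here. The 6 * params * tokens helper is the standard dense-transformer approximation for weight FLOPs (forward plus backward) and ignores attention FLOPs unless you pass them in; both function names are illustrative.

```python
def training_flops_per_step(num_params, tokens_per_step, attention_flops=0.0):
    """Approximate training FLOPs per step: 6 * N * tokens (fwd + bwd) + attention."""
    return 6.0 * num_params * tokens_per_step + attention_flops


def model_flops_utilization(flops_per_step, step_time_s, num_chips, peak_flops_per_chip):
    """MFU = achieved FLOP/s divided by the aggregate peak FLOP/s of all chips."""
    achieved = flops_per_step / step_time_s
    return achieved / (num_chips * peak_flops_per_chip)
```

If your FLOPs figure is already per device, pass num_chips=1; comparing MFU before and after a change then isolates the effect of the change at a fixed workload.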

Consider installing local CUDA variant when building GPU image

We currently install jax[cuda12_pip] in setup.sh. However, since the base image comes with its own CUDA installation, this creates a version conflict, generating warnings such as:

2024-03-11 22:25:16.557390: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:517] Loaded cuDNN version 8907
2024-03-11 22:25:16.563077: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

Installing jax[cuda12_local] seems to mitigate this problem, although we should be careful to pin the aqtp version when doing so (see #500).

Issues running decode example from readme

Running

python3 MaxText/decode.py MaxText/configs/base.yml run_name=MY_JOB_NAME

from https://github.com/google/maxtext?tab=readme-ov-file#getting-started-local-development-for-single-host leads to the following error:

I0213 14:26:13.638493 140101499770880 logging_logger.py:49] Constructing tf.data.Dataset c4 for split validation, from gs://xyz_tpu/c4/en/3.0.1
Model path: assets/tokenizer
No existing checkpoints found, not restoring checkpoint.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/MMP/maxtext/MaxText/decode.py", line 278, in <module>
    app.run(main)
  File "/home/MMP/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/MMP/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/MMP/maxtext/MaxText/decode.py", line 275, in main
    decode_loop(pyconfig.config)
  File "/home/MMP/maxtext/MaxText/decode.py", line 191, in decode_loop
    kv_cache_annotations = max_utils.get_kv_cache_annotations(model, config, rng, mesh)
  File "/home/MMP/maxtext/MaxText/max_utils.py", line 539, in get_kv_cache_annotations
    abstract_state = jax.eval_shape(init_kv_cache_partial)
  File "/home/MMP/maxtext/MaxText/max_utils.py", line 531, in init_kv_cache
    model_vars = model.init({'params': rng, 'dropout': rng, 'aqt': rng},
  File "/home/MMP/maxtext/MaxText/layers/models.py", line 349, in __call__
    logits = self.decoder(
  File "/home/MMP/maxtext/MaxText/layers/models.py", line 241, in __call__
    y, _ = nn.scan(
  File "/home/MMP/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 148, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/home/MMP/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 120, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/home/MMP/maxtext/MaxText/layers/models.py", line 93, in __call__
    attention_lnx = attention_layer(
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 849, in __call__
    out = attention_op(query, key, value, decoder_segment_ids, model_mode)
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 634, in __call__
    prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 169, in apply_attention
    return self.tpu_flash_attention(query, key, value, decoder_segment_ids), None, None
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 242, in tpu_flash_attention
    x = wrap_flash_attention(query, key, value, decoder_segment_ids)
  File "/home/MMP/maxtext/MaxText/layers/attentions.py", line 235, in wrap_flash_attention
    return jax.vmap(splash_kernel)(query,key,value, segment_ids = decoder_segment_ids)
  File "/home/MMP/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2303, in __call__
    return _splash_attention(
  File "/home/MMP/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2269, in _splash_attention
    return _splash_attention_custom(
  File "/home/MMP/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 1176, in _splash_attention_custom
    return _splash_attention_forward(  # pytype: disable=wrong-arg-types
  File "/home/MMP/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 913, in _splash_attention_forward
    raise ValueError(f"{bkv=} must be a multiple of {NUM_LANES}.")
ValueError: bkv=64 must be a multiple of 128.
2024-02-13 14:26:15.940961: I external/xla/xla/pjrt/distributed/client.cc:134] Distributed task shutdown initiated.
2024-02-13 14:26:15.941302: I external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:1193] Shutdown barrier in coordination service has passed.
2024-02-13 14:26:15.941335: I external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:684] /job:jax_worker/replica:0/task:0 has disconnected from coordination service.
2024-02-13 14:26:15.941496: I external/xla/xla/pjrt/distributed/client.cc:136] Distributed task shutdown result: OK
2024-02-13 14:26:15.941510: I external/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:166] Cancelled call to retrieve preemption notice. This is expected upon program shutdown.
2024-02-13 14:26:15.941699: I external/xla/xla/pjrt/distributed/service.cc:118] Jax service shutting down
2024-02-13 14:26:15.942328: I external/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:139] Preemption sync protocol cancelled by notifier: CANCELLED: Preemption notifier is being deleted.. This is expected during program shutdown.

Before that, I only ran this to set up everything:

git clone https://github.com/google/maxtext.git
bash setup.sh
bash download_dataset.sh tpu-cluster gs://xyz_tpu

I also set the correct dataset_path in MaxText/configs/base.yml, and I'm running this on a v5p-8.
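The error comes from the splash attention kernel, which requires the KV block size (bkv) to be a multiple of NUM_LANES (128). One plausible reading, offered as an assumption and not a confirmed diagnosis, is that the block size is clamped to a short key length somewhere upstream (64 here, e.g. from a short configured sequence or prefill length). The plain-Python sketch below only reproduces that arithmetic. `clamped_kv_block`, `round_up_to_lanes` and the 512 default are hypothetical names and values, not MaxText or JAX APIs.

```python
NUM_LANES = 128  # lane width quoted in the ValueError above


def clamped_kv_block(kv_seq_len: int, requested_block: int = 512) -> int:
    """Hypothetical re-creation of the failing arithmetic: the KV block size is
    assumed to be clamped to the key sequence length, and the kernel then
    requires the result to be a multiple of NUM_LANES."""
    bkv = min(requested_block, kv_seq_len)
    if bkv % NUM_LANES != 0:
        raise ValueError(f"{bkv=} must be a multiple of {NUM_LANES}.")
    return bkv


def round_up_to_lanes(length: int) -> int:
    """Smallest multiple of NUM_LANES that is >= length (e.g. a padded length)."""
    return -(-length // NUM_LANES) * NUM_LANES
```

Under this assumption, any sequence length of at least 512, or a shorter one that is a multiple of 128, passes the check, while 64 reproduces the error message exactly.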

Create a user-friendly inference demo

This is a feature request.

I like MaxText because it is very customizable and efficient for training.
The main issue I'm having is hacking together an inference function. The code is quite complex, so this is not straightforward.
The simple decode.py works, but it seems to be mainly an experimental development for streaming.

I think streaming will be really cool, but we would also benefit from an easy model.generate(input_ids, attention_mask, params) function:

  • it should allow prefill based on the length of input_ids (it is the user's responsibility not to supply too many distinct shapes, to avoid recompilation)
  • it should allow batched input, with left padding to support different input lengths
  • it should be compilable with jit/pjit
  • it should allow a few common sampling strategies: greedy, sampling (with temperature, top-k, top-p), and beam search
  • it should be usable without a separate engine/service, in case we want to make it part of a larger function that includes multiple models
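The sampling item in the list above could look roughly like this in JAX. This is a minimal, hypothetical helper (`sample_next_token` is not a MaxText function) that takes last-position logits of shape [batch, vocab] and covers greedy, temperature, top-k and top-p; beam search is not covered. `greedy` and `top_k` are marked static so the function stays jit-compatible; changing them triggers a recompile, much like new input shapes do.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("greedy", "top_k"))
def sample_next_token(logits, rng, temperature=1.0, top_k=0, top_p=1.0, greedy=False):
    """Hypothetical sampler. logits: [batch, vocab]. Returns [batch] int32 ids."""
    if greedy:
        return jnp.argmax(logits, axis=-1).astype(jnp.int32)
    logits = logits / temperature
    if top_k > 0:
        # Mask every logit below the k-th largest one in its row.
        kth = jax.lax.top_k(logits, top_k)[0][..., -1:]
        logits = jnp.where(logits < kth, -jnp.inf, logits)
    # Top-p (nucleus): keep the smallest set of highest-probability tokens
    # whose cumulative probability reaches top_p.
    sorted_logits = jnp.sort(logits, axis=-1)[..., ::-1]
    cum_probs = jnp.cumsum(jax.nn.softmax(sorted_logits, axis=-1), axis=-1)
    keep = jnp.sum(cum_probs < top_p, axis=-1, keepdims=True) + 1
    keep = jnp.minimum(keep, logits.shape[-1])
    cutoff = jnp.take_along_axis(sorted_logits, keep - 1, axis=-1)
    logits = jnp.where(logits < cutoff, -jnp.inf, logits)
    return jax.random.categorical(rng, logits, axis=-1).astype(jnp.int32)
```

For example, `sample_next_token(logits, jax.random.PRNGKey(0), temperature=0.7, top_k=40, top_p=0.9)` combines all three filters, while `greedy=True` ignores the rng.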

This PR looked interesting: #402
I think it was mainly intended for benchmarking, though: it didn't stop once every sequence in the batch had produced EOS, but it did have nice prefill functionality.
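Stopping once the whole batch has hit EOS can be expressed with jax.lax.while_loop, so the loop stays compilable. The sketch below is hypothetical: `step_fn` stands in for a model call that consumes the previous token (in a real setup it would also carry a KV cache), and decoding is greedy for brevity. Rows that finish early keep emitting eos_id until every row is done or max_new_tokens is reached.

```python
import jax
import jax.numpy as jnp


def greedy_decode(step_fn, first_token, max_new_tokens, eos_id):
    """Hypothetical loop. step_fn(tokens [batch], step) -> logits [batch, vocab].

    Returns (tokens [batch, max_new_tokens], number of steps actually run)."""
    first_token = first_token.astype(jnp.int32)
    batch = first_token.shape[0]
    out = jnp.full((batch, max_new_tokens), eos_id, dtype=jnp.int32)
    done = first_token == eos_id

    def cond(state):
        i, _, _, done = state
        # Continue only while there is budget left and some row is unfinished.
        return (i < max_new_tokens) & ~jnp.all(done)

    def body(state):
        i, tok, out, done = state
        nxt = jnp.argmax(step_fn(tok, i), axis=-1).astype(jnp.int32)
        nxt = jnp.where(done, eos_id, nxt)  # finished rows keep emitting EOS
        out = out.at[:, i].set(nxt)
        return i + 1, nxt, out, done | (nxt == eos_id)

    init = (jnp.array(0, dtype=jnp.int32), first_token, out, done)
    steps, _, out, _ = jax.lax.while_loop(cond, body, init)
    return out, steps
```

Because the loop condition checks `jnp.all(done)`, a batch whose rows all finish early costs only as many steps as its longest row.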

Long sequences are dropped rather than trimmed

From reading the dataset code, it looks like long sequences are dropped (by def length_filter(max_len):) before they get a chance to be trimmed (at the "# trim to length" step). This seems undesirable to me: some of the best documents in the corpus are likely to be long.

For comparison, seqio seems to trim but not discard long sequences (https://github.com/google/seqio/blob/515d917bf58da4103a2bbf39c3716213c36aff03/seqio/feature_converters.py#L535) and tensor2tensor seems to chop long sequences into multiple training examples (https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L612-L614). The tensor2tensor approach seems best, since we don't waste tokens from some of the (hypothesized) best documents in our corpus.
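The difference between the three behaviours can be seen on plain lists of token ids. This is a toy illustration only: the function names are hypothetical, and the real pipeline operates on dataset objects rather than Python lists.

```python
from typing import Iterable, List

Tokens = List[int]


def drop_long(docs: Iterable[Tokens], max_len: int) -> List[Tokens]:
    """Filter before trimming: documents longer than max_len disappear entirely."""
    return [d for d in docs if len(d) <= max_len]


def trim_long(docs: Iterable[Tokens], max_len: int) -> List[Tokens]:
    """seqio-style: keep every document, truncated to its first max_len tokens."""
    return [d[:max_len] for d in docs]


def chunk_long(docs: Iterable[Tokens], max_len: int) -> List[Tokens]:
    """tensor2tensor-style: split each document into consecutive max_len pieces."""
    return [d[i:i + max_len] for d in docs for i in range(0, len(d), max_len)]
```

With docs = [[1, 2, 3], [4, 5, 6, 7, 8, 9, 10]] and max_len = 3, dropping keeps 3 of 10 tokens, trimming keeps 6, and chunking keeps all 10.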

Should non-pod multihost be possible on TPU v2s/v3s?

I'm trying to use MaxText's multihost_runner to run across multiple TPU v2-8/v3-8 VMs.
When I try to set up multihost, I run into the same issue described in google/jax#16708.

I have changed the jax.distributed.initialize() call to take a host IP and the number of processes as inputs. This works without crashing if I set the environment variable JAX_PLATFORM_NAME=cpu: each process sees the global CPUs from the other processes and has its own jax.process_index().
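For reference, explicit arguments to jax.distributed.initialize (coordinator_address, num_processes, process_id) can be assembled and checked as below. This is a hypothetical sketch: distributed_kwargs, topology_problems, initialize_and_check, the environment variable names, and port 1234 are illustrative choices, not MaxText code. It will not by itself fix the TPU-backend behaviour described here; it only turns the reported symptoms (process_index 0 on every VM, only local devices visible) into an explicit check.

```python
import os


def distributed_kwargs(coordinator_ip, num_processes, process_id, port=1234):
    """Build explicit arguments for jax.distributed.initialize (port is arbitrary)."""
    if not 0 <= process_id < num_processes:
        raise ValueError(f"process_id={process_id} out of range for {num_processes} processes")
    return {
        "coordinator_address": f"{coordinator_ip}:{port}",
        "num_processes": num_processes,
        "process_id": process_id,
    }


def topology_problems(process_count, process_index, device_count,
                      local_device_count, expected_processes):
    """Return symptoms suggesting the processes did not join a single job."""
    problems = []
    if process_count != expected_processes:
        problems.append(f"process_count={process_count}, expected {expected_processes}")
    if expected_processes > 1 and device_count == local_device_count:
        problems.append("global device count equals local device count: no remote devices")
    if not 0 <= process_index < process_count:
        problems.append(f"process_index={process_index} out of range")
    return problems


def initialize_and_check():
    """Run on every VM; env var names here are hypothetical."""
    import jax

    kwargs = distributed_kwargs(os.environ["COORDINATOR_IP"],
                                int(os.environ["NUM_PROCESSES"]),
                                int(os.environ["PROCESS_ID"]))
    jax.distributed.initialize(**kwargs)  # call before any JAX computation
    return topology_problems(jax.process_count(), jax.process_index(),
                             jax.device_count(), jax.local_device_count(),
                             kwargs["num_processes"])
```

On a healthy two-VM v3-8 job, each process should report an empty list; the situation described in this issue would yield two problems on each VM.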

But when running with the TPU backend, each VM only sees its 8 local devices, and each VM has jax.process_index()==0. This causes MaxText's multihost_runner to fail, since multiple processes try to write checkpoints, etc. Even if I disable writing, the processes aren't actually communicating; they are basically just running their own local copies of the program.

One potential culprit is that I did not use queued resources to start my TPUs. But when I try running

gcloud alpha compute tpus queued-resources create my-queued-resource --accelerator-type=v3-8 --runtime-version=tpu-vm-base --node-count=2 --zone=us-central1-a --project=myproject-123456

I get

(gcloud.alpha.compute.tpus.queued-resources.create) INVALID_ARGUMENT: Cloud TPU was unable to complete the operation. Please try again, or contact support if the problem persists. [EID: 0x9640cfdf11d084b6]

Allocating them one at a time with queued resources (using --node-id instead of --node-count) works without issue. This leads me to suspect that --node-count does something to connect the TPUs being spun up, and that this isn't supported for v2s/v3s.

Is this expected behavior?
