
Comments (7)

adarob commented on August 22, 2024

FYI @blester125


adarob commented on August 22, 2024

Can you paste in your config?


dptam commented on August 22, 2024
python3 -m t5x.eval \
  --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \
  --gin_file="prompt_tuning/configs/models/t5_1_1_xl_prompt.gin" \
  --gin_file="prompt_tuning/configs/runs/prompt_eval.gin" \
  --gin.MIXTURE_OR_TASK_NAME="'glue_rte_32_shot_32_seed'" \
  --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.few_glue'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" \
  --gin.CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \
  --gin.EVAL_OUTPUT_DIR="'${EVAL_DIR}'" \
  --gin.utils.DatasetConfig.split="'validation'" \
  --gin.utils.DatasetConfig.batch_size="128" \
  --gin.USE_CACHED_TASKS="False" \
  --gin.partitioning.ModelBasedPjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" \
  --gin.PROMPT_FILE="'${PROMPT_FILE}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

I also tried removing the gin.partitioning.ModelBasedPjitPartitioner.model_parallel_submesh value, but got the same error. Thanks.


adarob commented on August 22, 2024

This looks like a bug in the prompt-tuning configs -- can you file a bug there?
A given model_parallel_submesh won't work for every TPU size. I'd recommend using num_partitions instead.
FYI @blester125 @nconstant-google
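
For example, a minimal sketch of that change against the command above (assuming ModelBasedPjitPartitioner accepts num_partitions like the base PjitPartitioner; the value 2 is only illustrative), replacing the submesh override:

  --gin.partitioning.ModelBasedPjitPartitioner.num_partitions="2" \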


dptam commented on August 22, 2024

Yup, I can file a bug there and try num_partitions.

I used that same model_parallel_submesh value for training on the same machine and it worked. Should it work for inference if it worked for training?


adarob commented on August 22, 2024

Yes, that likely isn't the issue -- although I would recommend @blester125 update their example configs regardless.

Could you log the trainstate that is being passed to

  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/utils.py", line 365, in __init__
    self.train_state_axes = partitioner.get_mesh_axes(

and share it?
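
For instance, a minimal sketch of one way to dump it, assuming you can edit your local copy of t5x/utils.py (the surrounding lines come from the traceback; the print call itself is only illustrative):

  # Illustrative: print the shape pytree just before the call shown in the traceback.
  print(self.global_train_state_shape)
  self.train_state_axes = partitioner.get_mesh_axes(
      self.global_train_state_shape)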


dptam commented on August 22, 2024

Here is the printout of self.global_train_state_shape on line 365 that is being passed into partitioner.get_mesh_axes:

pre_cross_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_self_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            self_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
        },
        layers_3: {
            encoder_decoder_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_cross_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_self_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            self_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
        },
        layers_4: {
            encoder_decoder_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_cross_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_self_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            self_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
        },
        layers_5: {
            encoder_decoder_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_cross_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_self_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            self_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
        },
        layers_6: {
            encoder_decoder_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_cross_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_self_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            self_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
        },
        layers_7: {
            encoder_decoder_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_cross_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_self_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            self_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
        },
        layers_8: {
            encoder_decoder_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_cross_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_self_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            self_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
        },
        layers_9: {
            encoder_decoder_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_cross_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_self_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            self_attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
        },
        logits_dense: {
            kernel: ShapeDtypeStruct(shape=(2048, 32128), dtype=float32),
        },
        relpos_bias: {
            rel_embedding: ShapeDtypeStruct(shape=(32, 32), dtype=float32),
        },
    },
    encoder: {
        encoder_norm: {
            scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
        },
        layers_0: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_1: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_10: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_11: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_12: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_13: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_14: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_15: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_16: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_17: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_18: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_19: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_2: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_20: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_21: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_22: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_23: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_3: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_4: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_5: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_6: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_7: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_8: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        layers_9: {
            attention: {
                key: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                out: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                query: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
                value: {
                    kernel: ShapeDtypeStruct(shape=(2048, 2048), dtype=float32),
                },
            },
            mlp: {
                wi_0: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wi_1: {
                    kernel: ShapeDtypeStruct(shape=(2048, 5120), dtype=float32),
                },
                wo: {
                    kernel: ShapeDtypeStruct(shape=(5120, 2048), dtype=float32),
                },
            },
            pre_attention_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
            pre_mlp_layer_norm: {
                scale: ShapeDtypeStruct(shape=(2048,), dtype=float32),
            },
        },
        prompt: {
            prompt: {
                prompt: ShapeDtypeStruct(shape=(100, 2048), dtype=float32),
            },
        },
        relpos_bias: {
            rel_embedding: ShapeDtypeStruct(shape=(32, 32), dtype=float32),
        },
    },
    token_embedder: {
        embedding: ShapeDtypeStruct(shape=(32128, 2048), dtype=float32),
    },
}), params_axes=FrozenDict({
    decoder: {
        decoder_norm: {
            scale_axes: AxisMetadata(names=('embed',)),
        },
        layers_0: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_1: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_10: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_11: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_12: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_13: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_14: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_15: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_16: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_17: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_18: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_19: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_2: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_20: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_21: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_22: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_23: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_3: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_4: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_5: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_6: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_7: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_8: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        layers_9: {
            encoder_decoder_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_cross_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_self_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            self_attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
        },
        logits_dense: {
            kernel_axes: AxisMetadata(names=('embed', 'vocab')),
        },
        relpos_bias: {
            rel_embedding_axes: AxisMetadata(names=('heads', 'relpos_buckets')),
        },
    },
    encoder: {
        encoder_norm: {
            scale_axes: AxisMetadata(names=('embed',)),
        },
        layers_0: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_1: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_10: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_11: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_12: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_13: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_14: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_15: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_16: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_17: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_18: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_19: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_2: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_20: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_21: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_22: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_23: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_3: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_4: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_5: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_6: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_7: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_8: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        layers_9: {
            attention: {
                key: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                out: {
                    kernel_axes: AxisMetadata(names=('joined_kv', 'embed')),
                },
                query: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
                value: {
                    kernel_axes: AxisMetadata(names=('embed', 'joined_kv')),
                },
            },
            mlp: {
                wi_0: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wi_1: {
                    kernel_axes: AxisMetadata(names=('embed', 'mlp')),
                },
                wo: {
                    kernel_axes: AxisMetadata(names=('mlp', 'embed')),
                },
            },
            pre_attention_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
            pre_mlp_layer_norm: {
                scale_axes: AxisMetadata(names=('embed',)),
            },
        },
        prompt: {
            prompt: {
                prompt_axes: AxisMetadata(names=('prompt', 'embed')),
            },
        },
        relpos_bias: {
            rel_embedding_axes: AxisMetadata(names=('heads', 'relpos_buckets')),
        },
    },
    token_embedder: {
        embedding_axes: AxisMetadata(names=('vocab', 'embed')),
    },
}), flax_mutables=FrozenDict({}))
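
For reference, a minimal sketch of how a dump like this can be reproduced, assuming the call chain shown in the traceback above (t5x's utils.py evaluating the train-state shape and handing it to partitioner.get_mesh_axes). The helper name log_train_state_axes and the zero-argument init_fn are illustrative assumptions, not part of the t5x API:

import jax

def log_train_state_axes(partitioner, init_fn):
  # Evaluate the train-state initializer abstractly: this yields a pytree of
  # ShapeDtypeStructs without allocating any device memory.
  train_state_shape = jax.eval_shape(init_fn)
  print(train_state_shape)
  # Ask the partitioner for the mesh axes of that structure; this is the
  # get_mesh_axes call where the reported error is raised.
  train_state_axes = partitioner.get_mesh_axes(train_state_shape)
  print(train_state_axes)
  return train_state_axes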

from t5x.
