matformer-olmo's Introduction

MatFormer-OLMo

This is a public reproduction and open-source release of MatFormer's language modeling experiments (MatLM).

MatFormer-OLMo is built on an older version of AI2's OLMo codebase and releases 3 model checkpoints trained on the compute cluster of the Kempner Institute at Harvard University.

A similar open source release for the vision encoder experiments of MatFormer (MatViT) can be found in the MatViT project of the Scenic library.

Setup

After cloning this repository, first install the latest PyTorch according to the official instructions relevant to your environment. Then install the remaining dependencies and the codebase by running:

pip install -e .[dev]

Set up the paths appropriately: HF_HOME in scripts/env_pile_test.sh, the SLURM setup in scripts/pile_test.sh, and the data and save paths in configs/pile-tiny.yaml.

All models are trained on the Pile corpus, tokenized with the EleutherAI/gpt-neox-20b tokenizer.

Running LM pre-training jobs

Our training script is scripts/train.py, which should be launched either through torchrun or Slurm (see below) since it only supports distributed training (on GPUs). The first argument to the training script is a path to a training configuration file. Then it takes any number of optional arguments that can be used to override values from the configuration file using dot notation. For example, to change the learning rate you'd pass --optimizer.learning_rate=0.0001.
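To illustrate the dot-notation overrides described above, here is a minimal sketch in plain Python. It is not the repository's implementation (the training script loads its configuration through its own TrainConfig class); the function name apply_overrides and the sample config values are hypothetical, chosen only to show how a dotted key maps onto a nested configuration.

```python
import ast


def apply_overrides(config, overrides):
    """Apply ``--a.b=value`` style overrides to a nested dict in place."""
    for item in overrides:
        # Split "--optimizer.learning_rate=0.0001" into key and raw value.
        key, _, raw = item.lstrip("-").partition("=")
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            # Walk (or create) the intermediate sections.
            node = node.setdefault(name, {})
        try:
            # Interpret numbers, booleans, etc.; fall back to the raw string.
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw
        node[leaf] = value
    return config


# Hypothetical starting configuration, standing in for a YAML file.
config = {"optimizer": {"learning_rate": 0.0004}, "matformer_factor": 1}
apply_overrides(config, ["--optimizer.learning_rate=0.0001", "--matformer_factor=8"])
print(config)
```

Each dotted segment selects one level of nesting, so --optimizer.learning_rate only touches the learning_rate entry inside the optimizer section and leaves every other value from the file unchanged.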

To use the MatFormer structure, pass the --matformer_factor flag. Setting --matformer_factor=1 yields the vanilla baseline model, while --matformer_factor=8 gives 4 exponentially spaced granularities in the MLP: {h, h/2, h/4, h/8}.
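The relationship between the factor and the nested MLP widths can be sketched as follows. This is only an illustration of the arithmetic stated above, not code from the repository: the helper name mlp_granularities, the example hidden size of 2048, and the assumption that the factor is a power of two are all hypothetical.

```python
def mlp_granularities(hidden_size, matformer_factor):
    """List the nested MLP widths h, h/2, ..., h/matformer_factor."""
    # Assumption: the factor is a power of two (1, 2, 4, 8, ...).
    if matformer_factor < 1 or matformer_factor & (matformer_factor - 1):
        raise ValueError("matformer_factor is assumed to be a power of two")
    widths = []
    divisor = 1
    while divisor <= matformer_factor:
        widths.append(hidden_size // divisor)
        divisor *= 2
    return widths


print(mlp_granularities(2048, 1))  # [2048]
print(mlp_granularities(2048, 8))  # [2048, 1024, 512, 256]
```

With a factor of 1 the list contains only the full width, which corresponds to the baseline model; a factor of 8 gives the four widths h, h/2, h/4 and h/8.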

Please check the matformer-expample-commands for commands that pretrain baseline and MatFormer models and finetune from a released checkpoint.

MatFormer-OLMo Checkpoints

Name #Parameters #Tokens Checkpoint
MatFormer-OLMo-180M 180M 20B Link
MatFormer-OLMo-460M 460M 40B Link
MatFormer-OLMo-1300M 1.3B 160B Link

You can load a checkpoint like this:

from olmo import Olmo, Tokenizer

checkpoint = "MatFormer-OLMo-1300M"
model = Olmo.from_checkpoint(checkpoint, device="cuda")
tokenizer = Tokenizer.from_checkpoint(checkpoint)

Generating text

You can use the generate() method to produce text using beam search with a variety of options.

For example:

import torch

# Prepare inputs.
# Note: we don't want the EOS token added to the end of the input, hence
# the `add_special_tokens=False`.
input_ids = tokenizer.encode("I'm a large language model, ", add_special_tokens=False)
# `model.generate()` expects a batch.
input_tensor = torch.tensor(input_ids).unsqueeze(0)

# Run beam search.
outputs = model.generate(input_tensor, max_steps=3, beam_size=3)

# The output token IDs are shape (batch_size, beam_size, max_steps)
best_generation = outputs.token_ids[0][0].tolist()
print(tokenizer.decode(best_generation))

matformer-olmo's People

Contributors

adityakusupati

matformer-olmo's Issues

Some difference on benchmark score with lm-eval compared with the original paper

Hello. I was testing the "MatFormer-OLMo-1300M" checkpoint from this repo. When I measured its performance on several benchmarks and compared the numbers with the paper (Table 14), I found that some benchmarks give quite different results.
Result
Generally, the results I get from MatFormer-OLMo-1300M with lm-eval show lower accuracy on most benchmarks.

Here is the code I used to collect data and produce the image: CollectData.zip
To use the code, set the environment variable MATFORMER_OLMO_CHECKPOINT_PATH to the checkpoint directory and run python collect_data.py to collect the data. Then run python draw_matformer.py . to produce the image.
In the code, I wrapped the model in a Hugging Face PreTrainedModel (in the file "utils/LoadModel.py") so that it works with the lm-eval interface.

Any idea why the results differ?

Issues with FineTuning Checkpoint

We were trying to finetune a MatFormer checkpoint (MatFormer-OLMo-180M Link).

We used the following command to call the training script:

python train.py ../configs/pile-tiny.yaml \
    --matformer_factor=8 \
    --model.d_model=512 \
    --model.n_heads=16 \
    --model.n_layers=8 \
    --model.max_sequence_length=2048 \
    --device_train_microbatch_size=8 \
    --global_train_batch_size=128 \
    --max_duration=75000  \
    --optimizer.learning_rate=1.0e-3 \
    --console_log_interval=10 \
    --load_path=:"/raid/ganesh/namitha/Skill_localization_experiment/ckpt_paths/MatFormer-OLMo-180M" \
    --run_name="matformer-olmo-180M-finetune"

where the folder given in load_path was downloaded from the link in the README for MatFormer-OLMo-180M.

However, running this gives us the following error:

[2024-04-18 09:09:04] CRITICAL [root, rank=0] Uncaught ValueError: Must flatten tensors on the same device but got both cuda:0 and meta
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:229 in <modul │
│                                                                                                  │
│   226 │   │   raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")                │
│   227 │   print([clean_opt(s) for s in args_list])                                               │
│   228 │   cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])                   │
│ ❱ 229 │   main(cfg)                                                                              │
│   230                                                                                            │
│                                                                                                  │
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:108 in main   │
│                                                                                                  │
│   105 │   log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embeddin │
│   106 │   torch.distributed.init_process_group(backend='nccl',rank=0, world_size=1)              │
│   107 │   # Wrap the model in FSDP.                                                              │
│ ❱ 108 │   fsdp_model = FSDP(                                                                     │
│   109 │   │   olmo_model,                                                                        │
│   110 │   │   sharding_strategy=cfg.fsdp.sharding_strategy,                                      │
│   111 │   │   mixed_precision=MixedPrecision(  # equivalent to MosaicML's "PURE"                 │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    474 │   │   │   │   # process groups.                                                         │
│    475 │   │   │   │   root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)  │
│    476 │   │   │                                                                                 │
│ ❱  477 │   │   │   _auto_wrap(                                                                   │
│    478 │   │   │   │   module,                                                                   │
│    479 │   │   │   │   auto_wrap_policy,                                                         │
│    480 │   │   │   │   self._ignored_modules,                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    98 │   │   )                                                                                  │
│    99 │   │   recursive_wrap_kwargs["auto_wrap_policy"] = policy                                 │
│   100 │   │   _warn_on_overridden_mixed_precision(overridden_module_classes)                     │
│ ❱ 101 │   _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]      │
│   102                                                                                            │
│   103                                                                                            │
│   104 def _check_nested_wrapping(root_module: nn.Module):                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   558 │   │   │   module=module, recurse=False, nonwrapped_numel=remainder                       │
│   559 │   │   ):                                                                                 │
│   560 │   │   │   # Leaf node or final wrapping of the remainder both happen here.               │
│ ❱ 561 │   │   │   return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel                  │
│   562 │   │   else:                                                                              │
│   563 │   │   │   return module, total_wrapped_numel                                             │
│   564 │   return module, 0                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   487 │   │   overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]         │
│   488 │   │   return wrapper_cls(module, **overrides)                                            │
│   489 │                                                                                          │
│ ❱ 490 │   return wrapper_cls(module, **kwargs)                                                   │
│   491                                                                                            │
│   492                                                                                            │
│   493 def _recursive_wrap(                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    500 │   │   _init_buffer_state(self, module)                                                  │
│    501 │   │   # extension needs to be set before `_init_param_handle_from_module()`             │
│    502 │   │   _init_extension(self, device_mesh)                                                │
│ ❱  503 │   │   _init_param_handle_from_module(                                                   │
│    504 │   │   │   self,                                                                         │
│    505 │   │   │   module,                                                                       │
│    506 │   │   │   device_id,                                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    587 │   │   │   _sync_module_params_and_buffers(                                              │
│    588 │   │   │   │   fully_sharded_module, managed_params, state._inter_node_pg                │
│    589 │   │   │   )                                                                             │
│ ❱  590 │   _init_param_handle_from_params(state, managed_params, fully_sharded_module)           │
│    591 │   return state                                                                          │
│    592                                                                                           │
│    593                                                                                           │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    599 ):                                                                                        │
│    600 │   if len(params) == 0:                                                                  │
│    601 │   │   return                                                                            │
│ ❱  602 │   handle = FlatParamHandle(                                                             │
│    603 │   │   params,                                                                           │
│    604 │   │   fully_sharded_module,                                                             │
│    605 │   │   state.compute_device,                                                             │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    570 │   │   │   else 0                                                                        │
│    571 │   │   )                                                                                 │
│    572 │   │   self._fsdp_extension = fsdp_extension                                             │
│ ❱  573 │   │   self._init_flat_param_and_metadata(                                               │
│    574 │   │   │   params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: i │
│    575 │   │   )                                                                                 │
│    576 │   │   self._use_unsharded_views(as_params=False)                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    620 │   │   │   dtype,                                                                        │
│    621 │   │   │   flat_param_requires_grad,                                                     │
│    622 │   │   │   device,                                                                       │
│ ❱  623 │   │   ) = self._validate_tensors_to_flatten(params)                                     │
│    624 │   │   params_set = set(params)                                                          │
│    625 │   │   # For alignment padding, only `numels` gets strictly non-`None`                   │
│    626 │   │   # elements, and all other lists get `None` elements for padding.                  │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    773 │   │   │   │   │   "`use_orig_params=False`"                                             │
│    774 │   │   │   │   )                                                                         │
│    775 │   │   │   if device is not None and tensor.device != device:                            │
│ ❱  776 │   │   │   │   raise ValueError(                                                         │
│    777 │   │   │   │   │   "Must flatten tensors on the same device but got both "               │
│    778 │   │   │   │   │   f"{device} and {tensor.device}"                                       │
│    779 │   │   │   │   )                                                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Must flatten tensors on the same device but got both cuda:0 and meta

We have been unable to resolve this issue.

We tried adding the following line to torch/distributed/fsdp/_init_utils.py:

tensor.to("cuda:0")  

But this gives another error:

│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    753 │   │   device: Optional[torch.device] = None                                             │
│    754 │   │   # For `use_orig_params=True`, permit non-uniform `requires_grad`                  │
│    755 │   │   for tensor in tensors:                                                            │
│ ❱  756 │   │   │   tensor.to("cuda:0")                                                           │
│    757 │   │   │   if isinstance(tensor, FlatParameter):                                         │
│    758 │   │   │   │   raise ValueError("Cannot flatten a `FlatParameter`")                      │
│    759 │   │   │   if dtype is None and not tensor.is_floating_point():                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: Cannot copy out of meta tensor; no data!
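For context on the two errors above, here is a minimal sketch of how meta tensors behave, assuming a recent PyTorch release. It uses a small nn.Linear on the CPU so that it runs without a GPU; it does not reproduce the FSDP setup from the traceback and is not a fix for the repository's loading code. It illustrates two general points: a meta tensor has a shape and dtype but no data, so it cannot be copied to a real device, and Tensor.to() returns a new tensor rather than moving the original in place.

```python
import torch
from torch import nn

# A module created on the "meta" device has parameter shapes but no storage.
layer = nn.Linear(4, 2, device="meta")
print(layer.weight.device)  # meta

# Copying out of a meta tensor fails because there is no data to copy.
try:
    layer.weight.to("cpu")
except NotImplementedError as err:
    print(f"NotImplementedError: {err}")

# Even for ordinary tensors, .to() returns a new tensor; calling it without
# using the result (as in the patched line above) leaves the original as is.

# to_empty() allocates uninitialised storage on a real device. The values must
# then be filled in, here with reset_parameters(); when finetuning they would
# instead come from the checkpoint.
layer = layer.to_empty(device="cpu")
layer.reset_parameters()
print(layer.weight.device)  # cpu
```

The design point is that materializing a meta module is a two-step process: allocate real storage, then initialize or load the values.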

We have also made changes to pile-tiny.yaml, scripts/train.py and scripts/util.py to make the code compatible for training.
I am attaching a zip of those files here:
changes.zip

Apart from this, we are facing another issue:

[2024-04-18 09:30:56] CRITICAL [root, rank=0] Uncaught AttributeError: 'LayerNorm' object has no attribute 'reset_parameters'
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:229 in <modul │
│                                                                                                  │
│   226 │   │   raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")                │
│   227 │   print([clean_opt(s) for s in args_list])                                               │
│   228 │   cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])                   │
│ ❱ 229 │   main(cfg)                                                                              │
│   230                                                                                            │
│                                                                                                  │
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:108 in main   │
│                                                                                                  │
│   105 │   log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embeddin │
│   106 │   torch.distributed.init_process_group(backend='nccl',rank=0, world_size=1)              │
│   107 │   # Wrap the model in FSDP.                                                              │
│ ❱ 108 │   fsdp_model = FSDP(                                                                     │
│   109 │   │   olmo_model,                                                                        │
│   110 │   │   sharding_strategy=cfg.fsdp.sharding_strategy,                                      │
│   111 │   │   mixed_precision=MixedPrecision(  # equivalent to MosaicML's "PURE"                 │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    474 │   │   │   │   # process groups.                                                         │
│    475 │   │   │   │   root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)  │
│    476 │   │   │                                                                                 │
│ ❱  477 │   │   │   _auto_wrap(                                                                   │
│    478 │   │   │   │   module,                                                                   │
│    479 │   │   │   │   auto_wrap_policy,                                                         │
│    480 │   │   │   │   self._ignored_modules,                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    98 │   │   )                                                                                  │
│    99 │   │   recursive_wrap_kwargs["auto_wrap_policy"] = policy                                 │
│   100 │   │   _warn_on_overridden_mixed_precision(overridden_module_classes)                     │
│ ❱ 101 │   _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]      │
│   102                                                                                            │
│   103                                                                                            │
│   104 def _check_nested_wrapping(root_module: nn.Module):                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   558 │   │   │   module=module, recurse=False, nonwrapped_numel=remainder                       │
│   559 │   │   ):                                                                                 │
│   560 │   │   │   # Leaf node or final wrapping of the remainder both happen here.               │
│ ❱ 561 │   │   │   return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel                  │
│   562 │   │   else:                                                                              │
│   563 │   │   │   return module, total_wrapped_numel                                             │
│   564 │   return module, 0                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   487 │   │   overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]         │
│   488 │   │   return wrapper_cls(module, **overrides)                                            │
│   489 │                                                                                          │
│ ❱ 490 │   return wrapper_cls(module, **kwargs)                                                   │
│   491                                                                                            │
│   492                                                                                            │
│   493 def _recursive_wrap(                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    500 │   │   _init_buffer_state(self, module)                                                  │
│    501 │   │   # extension needs to be set before `_init_param_handle_from_module()`             │
│    502 │   │   _init_extension(self, device_mesh)                                                │
│ ❱  503 │   │   _init_param_handle_from_module(                                                   │
│    504 │   │   │   self,                                                                         │
│    505 │   │   │   module,                                                                       │
│    506 │   │   │   device_id,                                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    549 │   │   │   fully_sharded_module, param_init_fn, state._ignored_modules                   │
│    550 │   │   )                                                                                 │
│    551 │   elif is_meta_module:                                                                  │
│ ❱  552 │   │   _materialize_meta_module(                                                         │
│    553 │   │   │   fully_sharded_module, device_id, state._ignored_modules                       │
│    554 │   │   )                                                                                 │
│    555 │   elif is_torchdistX_deferred_init:                                                     │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    881 │   │   │   f"device with error {str(e)}. Please ensure that your module of"              │
│    882 │   │   │   f"type {type(module)} implements a `reset_parameters()` method."              │
│    883 │   │   )                                                                                 │
│ ❱  884 │   │   raise e                                                                           │
│    885                                                                                           │
│    886                                                                                           │
│    887 def _get_modules_to_materialize(                                                          │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    874 │   │   │   │   has_module_states = len(list(module_state_iter)) > 0                      │
│    875 │   │   │   │   if has_module_states:                                                     │
│    876 │   │   │   │   │   module.to_empty(device=materialization_device, recurse=False)         │
│ ❱  877 │   │   │   │   │   module.reset_parameters()  # type: ignore[operator]                   │
│    878 │   except BaseException as e:                                                            │
│    879 │   │   warnings.warn(                                                                    │
│    880 │   │   │   "Unable to call `reset_parameters()` for module on meta "                     │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/nn/modules/modu │
│                                                                                                  │
│   1685 │   │   │   modules = self.__dict__['_modules']                                           │
│   1686 │   │   │   if name in modules:                                                           │
│   1687 │   │   │   │   return modules[name]                                                      │
│ ❱ 1688 │   │   raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") │
│   1689 │                                                                                         │
│   1690 │   def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:             │
│   1691 │   │   def remove_from(*dicts_or_sets):                                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'LayerNorm' object has no attribute 'reset_parameters'

However, we worked around this issue by commenting out the raise e statement (in torch/distributed/fsdp/_init_utils.py) as follows:

except BaseException as e:
    warnings.warn(
        "Unable to call `reset_parameters()` for module on meta "
        f"device with error {str(e)}. Please ensure that your module of"
        f"type {type(module)} implements a `reset_parameters()` method."
    )
    # raise e

I have attached the entire file in changes.zip, just in case.
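
The patch above only downgrades the failure to a warning. The module that lacks reset_parameters() may still be left with uninitialized parameters unless something else initializes them later. The sketch below is a torch-free illustration of that difference, not the FSDP implementation: the class names, the materialize helper, and its reraise flag are all hypothetical stand-ins, and in real code the class would subclass torch.nn.Module. It contrasts swallowing the error with giving the module the reset_parameters() method that the warning message asks for.

```python
import warnings


class LayerNormWithoutReset:
    """Hypothetical stand-in for a module that lacks reset_parameters()."""

    def __init__(self, size):
        # Placeholder values standing in for uninitialized storage.
        self.weight = [0.0] * size


class LayerNormWithReset(LayerNormWithoutReset):
    """Same stand-in, with the method the warning message asks for."""

    def reset_parameters(self):
        # A layer norm gain is conventionally initialized to ones.
        self.weight = [1.0] * len(self.weight)


def materialize(module, reraise):
    """Mimic the try/except from the issue; return True if the reset ran."""
    try:
        module.reset_parameters()
        return True
    except AttributeError as e:
        warnings.warn(f"Unable to call `reset_parameters()`: {e}")
        if reraise:
            raise  # original behavior: the error propagates
        return False  # patched behavior: continue with untouched weights


if __name__ == "__main__":
    patched = LayerNormWithoutReset(3)
    print(materialize(patched, reraise=False), patched.weight)

    fixed = LayerNormWithReset(3)
    print(materialize(fixed, reraise=True), fixed.weight)
```

With reraise=False the call returns without error but the weights are never reset. Defining reset_parameters() on the module avoids editing the installed torch package at all, which also keeps the fix inside the repository rather than in site-packages.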
