Comments (8)

adarob commented on August 22, 2024

Thanks for reporting this -- it looks like a GPU-specific bug in T5X. Can you try setting

partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1)
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

adarob commented on August 22, 2024

@jekbradbury as FYI

Also, I think your activation/parameter partitioning dims options will be no-ops since you're not using model parallelism.
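For reference, those options are the activation_partitioning_dims / parameter_partitioning_dims arguments to standard_logical_axis_rules. Here is a minimal Python sketch of where they sit (illustrative values only, not a recommended config):

# Sketch only: with model_parallel_submesh = (1, 1) there is no 'model' mesh axis
# to shard over, so these two arguments have no visible effect.
from t5x import partitioning

rules = partitioning.standard_logical_axis_rules(
    activation_partitioning_dims=1,  # number of mesh axes activations are sharded over
    parameter_partitioning_dims=1,   # number of mesh axes parameters are sharded over
)
partitioner = partitioning.PjitPartitioner(
    model_parallel_submesh=(1, 1),
    logical_axis_rules=rules,
)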

Namco0816 commented on August 22, 2024

Thanks for your reply!
I've modified the config with:

partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1)
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

However, I still get the error:

Traceback (most recent call last):
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 659, in <module>
    gin_utils.run(main)
  File "/mnt/cache/namco/t5x/t5x/gin_utils.py", line 105, in run
    app.run(
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/mnt/cache/namco/t5x/t5x/train.py", line 637, in main
    _main(argv)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 657, in _main
    train_using_gin()
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 507, in train
    trainer.compile_train(first_batch)
  File "/mnt/cache/namco/t5x/t5x/trainer.py", line 549, in compile_train
    self._compiled_train_step = self._partitioner.compile(
  File "/mnt/cache/namco/t5x/t5x/partitioning.py", line 779, in compile
    return partitioned_fn.lower(*args).compile()
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/stages.py", line 174, in compile
    self._lowering.compile(), self.in_tree, self.in_avals,
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2280, in compile
    self._executable = MeshExecutable.from_hlo(
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2371, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 583, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 537, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: UNIMPLEMENTED: Requested AllReduce not implemented on GPU; replica_count: 1; partition_count: 8, group_mode: kCrossReplicaAndPartition, operand_count: 26; NCCL support: 1; first operand array element-type: BF16
  In call to configurable 'train' (<function train at 0x7f3d01a81240>)
Rewritten gin arg: --gin_bindings=MODEL_DIR = "/mnt/lustre/namco/jax-model/t5-base"

I am also confused about the partition rules in partitioning.py. I noticed that get_gpu_mesh returns the mesh for GPU, and on my single-host machine with 8 A100 GPUs it returns a (1, 8) mesh: the first dimension (1) is the number of hosts and the second dimension (8) is the number of GPUs. If I understand correctly, based on the partition rules, 'data' is assigned to the first mesh axis and 'model' to the second. However, simple DDP shards the input data across all 8 GPUs, which would mean the partition rule should be ("batch", "model"), whereas the partitioning example defines data parallelism as ("batch", "data"). I am really confused by this part.
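To make the shapes concrete, here is a minimal sketch of the two meshes I am describing. It is not t5x's actual code path, just plain JAX (jax.sharding.Mesh) on a single host with 8 GPUs:

import numpy as np
import jax

devices = np.array(jax.devices())  # 8 GPU devices on one host

# What get_gpu_mesh effectively gives me: (hosts, local GPUs) = (1, 8),
# with the mesh axes named ('data', 'model').
mesh_reported = jax.sharding.Mesh(devices.reshape(1, 8), ('data', 'model'))

# What simple DDP would need: all 8 GPUs on the 'data' axis.
mesh_data_parallel = jax.sharding.Mesh(devices.reshape(8, 1), ('data', 'model'))

# standard_logical_axis_rules maps the logical 'batch' axis to the 'data' mesh axis,
# so with the (1, 8) mesh the batch is never actually split across the 8 GPUs.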

Thanks for your help!

ibulu commented on August 22, 2024

I think I am getting a related error when trying to fine-tune a LongT5 model on CPU:

ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
In call to configurable 'train' (<function train at 0x17a77b940>)

sudhakarsingh27 commented on August 22, 2024

TL;DR:

@adarob For data + model parallelism on GPUs, is model_parallel_submesh = (1, 1, 1, <#GPUs used for model parallelism>) the way to go in a single-node, multi-GPU case (this line in the code seems to suggest so as well)?
Thanks!


More context:

I'm working with this config: t5_1_1/base.gin, and I ran into the same error as the OP when using the default partitioning config. (My intention is to run data + model parallelism.)

I followed @adarob's suggestion below but couldn't get it running:

Thanks for reporting this -- it looks like a GPU-specific bug in T5X. Can you try setting

partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1)
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

After playing around, I found that the following partitioning rule:

partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1, 1, 2)
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

seems to produce the following [4, 2] mesh of devices (mapped to ('data', 'model')), which is what I wanted:

[[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=0)]
 [GpuDevice(id=2, process_index=0) GpuDevice(id=3, process_index=0)]
 [GpuDevice(id=4, process_index=0) GpuDevice(id=5, process_index=0)]
 [GpuDevice(id=6, process_index=0) GpuDevice(id=7, process_index=0)]]
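In case it helps anyone reproduce this, here is a hedged sketch of how I inspect the resulting mesh. It assumes the partitioner exposes the mesh it builds via a mesh property (true for BasePjitPartitioner in the t5x checkout I'm looking at, but verify against yours):

from t5x import partitioning

# Sketch only: build the partitioner with the submesh above and print its mesh.
partitioner = partitioning.PjitPartitioner(
    model_parallel_submesh=(1, 1, 1, 2),
    logical_axis_rules=partitioning.standard_logical_axis_rules(),
)
print(partitioner.mesh.devices.shape)  # expected: (4, 2) on a single 8-GPU host
print(partitioner.mesh.axis_names)     # expected: ('data', 'model')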

sudhakarsingh27 commented on August 22, 2024

Also, as @Namco0816 pointed out, when using data parallelism only, the GPU mesh returned is

[[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=0)
  GpuDevice(id=2, process_index=0) GpuDevice(id=3, process_index=0)
  GpuDevice(id=4, process_index=0) GpuDevice(id=5, process_index=0)
  GpuDevice(id=6, process_index=0) GpuDevice(id=7, process_index=0)]]

which has shape [1, 8] and maps to the ('data', 'model') axes, so there is effectively no data parallelism even when data-only parallelism is selected. I think this function is completely agnostic of the data/model parallel axes/partitions, and that is why we see the issue.
Can someone confirm this? @adarob @jekbradbury

StephennFernandes commented on August 22, 2024

@adarob, some of us are actually trying to pretrain and fine-tune locally on GPUs, and I fear a reduced batch size could hurt generalization. Is there a DeepSpeed ZeRO integration in t5x?

adarob commented on August 22, 2024

Fixed by #643
