Comments (5)
Update: I figured out a solution by:
- passing --overwrite on the command line
- adding
dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
to the main() function
Then I ran into the error below:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/xxx/miniconda3/envs/a2p_env/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/xxx/code/audio2photoreal/train/train_diffusion.py", line 77, in main
    TrainLoop(
  File "/home/xxx/code/audio2photoreal/train/training_loop.py", line 151, in run_loop
    self.run_step(motion, cond)
  File "/home/xxx/code/audio2photoreal/train/training_loop.py", line 175, in run_step
    self.forward_backward(batch, cond)
  File "/home/xxx/code/audio2photoreal/train/training_loop.py", line 201, in forward_backward
    losses = compute_losses()
  File "/home/xxx/code/audio2photoreal/diffusion/respace.py", line 110, in training_losses
    return super().training_losses(self._wrap_model(model), *args, **kwargs)
  File "/home/xxx/code/audio2photoreal/diffusion/respace.py", line 121, in _wrap_model
    return _WrappedModel(
  File "/home/xxx/code/audio2photoreal/diffusion/respace.py", line 135, in __init__
    self.add_frame_cond = model.add_frame_cond
  File "/home/xxx/miniconda3/envs/a2p_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'add_frame_cond'
My modified train_diffusion.py is shown below for your reference:
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import json
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from data_loaders.get_data import get_dataset_loader, load_local_data
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from train.train_platforms import ClearmlPlatform, NoPlatform, TensorboardPlatform
from train.training_loop import TrainLoop
from utils.diff_parser_utils import train_args
from utils.misc import cleanup, fixseed, setup_dist
from utils.model_util import create_model_and_diffusion


def main(rank: int, world_size: int):
    args = train_args()
    fixseed(args.seed)
    train_platform_type = eval(args.train_platform_type)
    train_platform = train_platform_type(args.save_dir)
    train_platform.report_args(args, name="Args")
    setup_dist(args.device)

    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '9000'
    # Initialize the distributed environment
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)

    if rank == 0:
        if args.save_dir is None:
            raise FileNotFoundError("save_dir was not specified.")
        elif os.path.exists(args.save_dir) and not args.overwrite:
            raise FileExistsError("save_dir [{}] already exists.".format(args.save_dir))
        elif not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        args_path = os.path.join(args.save_dir, "args.json")
        with open(args_path, "w") as fw:
            json.dump(vars(args), fw, indent=4, sort_keys=True)

    if not os.path.exists(args.data_root):
        args.data_root = args.data_root.replace("/home/", "/derived/")

    data_dict = load_local_data(args.data_root, audio_per_frame=1600)
    print("creating data loader...")
    data = get_dataset_loader(args=args, data_dict=data_dict)

    print("creating logger...")
    writer = SummaryWriter(args.save_dir)

    print("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(args, split_type="train")
    model.to(rank)

    if world_size > 1:
        model = DDP(
            model, device_ids=[rank], output_device=rank, find_unused_parameters=True
        )

    # DDP hides the original model behind .module, so unwrap it to count parameters.
    params = (
        model.module.parameters_w_grad()
        if world_size > 1
        else model.parameters_w_grad()
    )
    print("Total params: %.2fM" % (sum(p.numel() for p in params) / 1000000.0))
    print("Training...")
    TrainLoop(
        args, train_platform, model, diffusion, data, writer, rank, world_size
    ).run_loop()
    train_platform.close()
    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"using {world_size} gpus")
    if world_size > 1:
        mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
    else:
        main(rank=0, world_size=1)
from audio2photoreal.
Update: I finally got the code to run. The last modification was:
- comment out this line, which is the one that raises in the traceback above:
audio2photoreal/diffusion/respace.py
Line 135 in 3a94699
- replace it with
self.add_frame_cond = False
Note that this is only a solution for training tasks with data_format face, since it hard-codes the flag to False instead of reading it from the model.
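A possibly less invasive alternative, sketched here and untested against the repository: DistributedDataParallel keeps the original model in its .module attribute (main() above already relies on this via model.module.parameters_w_grad()), so the attribute could be read from the unwrapped model instead of being hard-coded. The helper name unwrap_model and the two stand-in classes below are hypothetical and exist only so the sketch runs without GPUs.

```python
def unwrap_model(model):
    """Return the wrapped model if `model` looks like a DDP-style wrapper.

    Hypothetical helper, not part of audio2photoreal. It relies on
    DistributedDataParallel storing the original model as `.module`.
    Caveat: a plain model that has its own attribute named `module`
    would be unwrapped by mistake.
    """
    return getattr(model, "module", model)


class ToyModel:
    """Stand-in for the real model; only carries the flag of interest."""

    def __init__(self, add_frame_cond):
        self.add_frame_cond = add_frame_cond


class FakeDDP:
    """Stand-in for DistributedDataParallel: it holds the model in `.module`
    and does not forward custom attributes such as `add_frame_cond`."""

    def __init__(self, module):
        self.module = module


plain = ToyModel(add_frame_cond=True)
wrapped = FakeDDP(plain)

# Reading the flag through the helper works for both the wrapped and plain model.
flag_from_wrapped = unwrap_model(wrapped).add_frame_cond
flag_from_plain = unwrap_model(plain).add_frame_cond
print(flag_from_wrapped, flag_from_plain)
```

Under these assumptions, line 135 of respace.py could become self.add_frame_cond = unwrap_model(model).add_frame_cond, which would keep the model's own setting for data formats other than face.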
One more hint for those who are training on multiple GPUs: you may change
audio2photoreal/model/diffusion.py
Line 276 in 3a94699
to
cp = torch.load(cp_path, map_location='cpu')
This will save a lot of memory on GPU #0.
:)
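Some background on why this helps, as a small sketch: by default torch.load puts tensors back on the device they were saved from, so a checkpoint written from cuda:0 would be loaded onto GPU #0 by every spawned process. Passing map_location='cpu' keeps the loaded tensors in host memory instead. The function name load_checkpoint_on_cpu and the toy checkpoint below are made up for illustration; the sketch runs on CPU only.

```python
import os
import tempfile

import torch


def load_checkpoint_on_cpu(cp_path):
    # map_location="cpu" remaps every stored tensor to host memory,
    # regardless of the device it was saved from.
    return torch.load(cp_path, map_location="cpu")


with tempfile.TemporaryDirectory() as tmp_dir:
    # Toy checkpoint standing in for the real one.
    cp_path = os.path.join(tmp_dir, "toy_checkpoint.pt")
    torch.save({"weight": torch.ones(2, 3)}, cp_path)

    cp = load_checkpoint_on_cpu(cp_path)
    loaded_device = cp["weight"].device.type
    print(loaded_device)
```

The weights only reach a GPU afterwards, when the model that received them is moved with something like model.to(rank).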
Hi!! Sorry for the delay, and thanks so much for debugging and finding the solutions to this issue! I really appreciate the active effort on this! :) Please let me know if there's anything else I can help with.
What is the total training time and how many GPUs are used? What GPUs are these?
Thanks