Comments (13)
Dear all,
I just added support for multi-GPU training. You are welcome to check it out.
- About memory usage:
You still need two GPUs with 10 GB+ memory for now. I have not added fp16 support for multi-GPU training yet (I will consider supporting it in the near future).
Some losses are still calculated on the first GPU, so the memory usage of the first GPU is larger than that of the second. The main reason is that copy.deepcopy does not currently support multi-GPU models, so I keep some losses and forward functions running on the first GPU.
- About speed:
I tested it on my two P6000s (their speed is close to a GTX 1080).
A single GPU takes about 1.1 s per iteration at the beginning; two GPUs take about 0.9 s per iteration at the beginning.
(Since we add the teacher model calculation at the 30000th iteration, training slows down after the 30000th iteration.)
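For context, the larger footprint on the first GPU is consistent with how nn.DataParallel behaves: outputs are gathered back onto device_ids[0] by default, so any loss computed on the gathered outputs lives on the first GPU. A minimal sketch of that behaviour (the linear layer and tensors here are illustrative, not DG-Net modules):

import torch
import torch.nn as nn

# Illustrative module only; assumes at least two visible GPUs.
net = nn.DataParallel(nn.Linear(256, 128).cuda(), device_ids=[0, 1])

x = torch.randn(8, 256).cuda()   # the batch is scattered as 4 + 4 across the two GPUs
y = net(x)                       # replicas run in parallel, outputs are gathered on cuda:0
loss = y.pow(2).mean()           # the loss (and its graph) therefore sits on the first GPU
print(y.device, loss.device)     # both report cuda:0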
Hi @layumi, if you are still here, please elaborate once more on how you made the adaptive instance normalization layer work in multi-GPU mode with nn.DataParallel. I looked through the code and version history, but I didn't see any substantial changes compared to the first commit.
Thank you
trainer is a high-level container. We need to wrap the leaf models inside trainer individually:
if num_gpu > 1:
    # Wrap each leaf sub-module of the trainer in DataParallel.
    # trainer.teacher_model = torch.nn.DataParallel(trainer.teacher_model, gpu_ids)
    trainer.id_a = torch.nn.DataParallel(trainer.id_a, gpu_ids)
    trainer.gen_a.enc_content = torch.nn.DataParallel(trainer.gen_a.enc_content, gpu_ids)
    trainer.gen_a.mlp_w1 = torch.nn.DataParallel(trainer.gen_a.mlp_w1, gpu_ids)
    trainer.gen_a.mlp_w2 = torch.nn.DataParallel(trainer.gen_a.mlp_w2, gpu_ids)
    trainer.gen_a.mlp_w3 = torch.nn.DataParallel(trainer.gen_a.mlp_w3, gpu_ids)
    trainer.gen_a.mlp_w4 = torch.nn.DataParallel(trainer.gen_a.mlp_w4, gpu_ids)
    trainer.gen_a.mlp_b1 = torch.nn.DataParallel(trainer.gen_a.mlp_b1, gpu_ids)
    trainer.gen_a.mlp_b2 = torch.nn.DataParallel(trainer.gen_a.mlp_b2, gpu_ids)
    trainer.gen_a.mlp_b3 = torch.nn.DataParallel(trainer.gen_a.mlp_b3, gpu_ids)
    trainer.gen_a.mlp_b4 = torch.nn.DataParallel(trainer.gen_a.mlp_b4, gpu_ids)
    # Re-assign into the list so the wrapped discriminators are actually used;
    # a plain "dis_model = ..." inside the loop would not update trainer.dis_a.cnns.
    for i, dis_model in enumerate(trainer.dis_a.cnns):
        trainer.dis_a.cnns[i] = torch.nn.DataParallel(dis_model, gpu_ids)
This code works on multiple GPUs; you may give it a try.
Note that you also need to modify the model-saving code to save model.module.
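A minimal sketch of that change, assuming a state_dict-style checkpoint (the unwrap helper and the file path are illustrative, not the repository's exact save code):

import torch
import torch.nn as nn

def unwrap(model):
    # nn.DataParallel keeps the real network under .module; saving that lets the
    # checkpoint be loaded later with or without the DataParallel wrapper.
    return model.module if isinstance(model, nn.DataParallel) else model

torch.save(unwrap(trainer.id_a).state_dict(), 'outputs/id_a.pt')  # illustrative path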
However, it is not the best solution; we are still working on it.
You might notice that I did not include the decoder:
trainer.gen_a.dec = torch.nn.DataParallel(trainer.gen_a.dec, gpu_ids)
This is due to the adaptive instance normalisation layer, which cannot be replicated across multiple GPUs.
Thank you! This is really helpful.
I notice that F.batch_norm() is used in the AdaptiveInstanceNorm2d class; is that the reason?
Not really. It is due to the values of w and b in the adaptive instance normalisation layer:
https://github.com/NVlabs/DG-Net/blob/master/networks.py#L822-L823
We access w and b on the fly and use assign_adain_params to obtain the current parameters:
https://github.com/NVlabs/DG-Net/blob/master/networks.py#L236
PyTorch DataParallel splits the batch into several parts and replicates the network on all GPUs, which does not match the size of w and b.
For example, with a mini-batch of 8 samples and two GPUs, each GPU receives 4 samples, but w and b still correspond to all 8 samples, since they are copied from the original full model.
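A toy illustration of that mismatch (the shapes are made up and it runs on CPU, just to show the bookkeeping): the AdaIN parameters are produced once for the full batch, while DataParallel hands each replica only its chunk of the content batch.

import torch
import torch.nn as nn

batch, num_gpu = 8, 2
mlp = nn.Linear(64, 32)                      # stand-in for the MLP that predicts w and b
appearance = torch.randn(batch, 64)

adain_params = mlp(appearance)               # shape (8, 32): one row per sample
content = torch.randn(batch, 32)
chunks = content.chunk(num_gpu, dim=0)       # DataParallel-style scatter: (4, 32) per GPU

# Each replica would receive a (4, 32) content chunk but a full (8, 32) copy of the
# AdaIN parameters, so the per-sample pairing inside the replica no longer lines up.
print(chunks[0].shape, adain_params.shape)   # torch.Size([4, 32]) torch.Size([8, 32])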
@layumi Thank you, it's really helpful. Is there any reference for how to modify the code?
Hi @chenxingjian
I am working on it and checking the results. If everything goes well, I will upload the code next week.
It seems to work with multiple GPUs when you put the assign_adain_params function into the Decoder class.
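If I understand the suggestion correctly, the point is that the AdaIN parameters are then computed inside forward(), i.e. from the chunk each replica actually receives. A rough sketch of the idea, not the exact DG-Net code; it assumes an assign_adain_params(params, model) helper along the lines of the function linked above (whose real signature may differ):

import torch.nn as nn

class DecoderWithAdain(nn.Module):
    # Illustrative wrapper: the MLP and the AdaIN assignment run inside forward(),
    # so each DataParallel replica derives w and b from its own scattered chunk.
    def __init__(self, decoder, mlp):
        super().__init__()
        self.decoder = decoder
        self.mlp = mlp

    def forward(self, content, appearance_code):
        adain_params = self.mlp(appearance_code)          # computed per replica
        assign_adain_params(adain_params, self.decoder)   # fill this replica's AdaIN layers
        return self.decoder(content)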
@FreemanG
Yes. You are right.
We can treat the encoder and decoder as one function and replicate it together at the beginning, so there is no problem with mismatched dimensions.
In fact, I have already written the code, and I am checking the results before I release it.
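As a rough illustration of that idea (the class and attribute names are assumptions, not the released code): bundling the content encoder and decoder into one nn.Module means DataParallel scatters a single input batch, and everything derived from it inside each replica stays consistent.

import torch.nn as nn

class EncodeDecode(nn.Module):
    # Hypothetical wrapper: one forward() covers encode + decode, so DataParallel
    # replicates and scatters the whole path together.
    def __init__(self, enc_content, decoder):
        super().__init__()
        self.enc_content = enc_content
        self.decoder = decoder

    def forward(self, image, appearance_code):
        content = self.enc_content(image)
        return self.decoder(content, appearance_code)

# e.g. nn.DataParallel(EncodeDecode(trainer.gen_a.enc_content, trainer.gen_a.dec), gpu_ids)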
Great 👍
Is it possible to use nn.parallel.replicate instead of deepcopy?
Hi, thank you very much for implementing the multi-GPU training version. May I ask where the method you mentioned (treating the encoder and decoder as one function at the beginning) is reflected in the code? I did not find it in your latest version. Thank you very much.