Comments (14)

yaxingwang commented on June 1, 2024

@VincieD, I have fixed it. In get_data_loaders of utils.py, you should set start_itr = 0 when you use the pretrained model, since with a pretrained model it is not 0. I guess start_itr refers to batchnorm or something. I did not create 1000 classes; in fact, you do not need to. In the generator the label is projected into a vector by an embedding (like word2vec), if I understand it correctly, so you can set the number of classes yourself.
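
Roughly what I mean, as a sketch based on my reading of train.py and utils.py (the exact call in the repo may differ):

# Sketch: when fine-tuning on a new dataset, reset the iteration counter so
# the data loader / multi-epoch sampler starts from zero instead of from the
# value restored out of the pretrained checkpoint.
state_dict['itr'] = 0
loaders = utils.get_data_loaders(**{**config,
                                    'batch_size': config['batch_size'],
                                    'start_itr': state_dict['itr']})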

isaac-dunn commented on June 1, 2024

Hi, I think this is a problem at your end rather than with this repo. When I first tried, I was making a mistake that also caused training to collapse, so I'd check carefully that everything is set up and running properly! (Sorry if that's vague.)

yaxingwang commented on June 1, 2024

Thanks. I am checking one by one. It would be very helpful if you could point out which parts. In fact, I also reset the learning rates of G and D to 0 and only train one batch, and the generated images still suffer from the same problem.
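
For what it's worth, the sanity check I have in mind (a rough sketch using the helper names from train.py and utils.py, if I read them correctly): with the learning rate at 0, a freshly loaded pretrained G should already produce clean samples, so if it does not, the loading itself is broken.

import torch

# Rough sanity check: sample from the just-loaded pretrained G before any
# training step. 'utils', 'config' and 'G' follow the names used in train.py.
G.eval()
with torch.no_grad():
    z_, y_ = utils.prepare_z_y(config['batch_size'], G.dim_z,
                               config['n_classes'], device='cuda',
                               fp16=config['G_fp16'])
    z_.sample_()
    y_.sample_()
    fake = G(z_, G.shared(y_))  # should already look like ImageNet classes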

VincieD commented on June 1, 2024

@yaxingwang I'm wondering how you solved the problem that ImageNet has 1000 classes. Did you create a dataset of your own with 1000 classes as well? Otherwise I'm getting an error while loading the weights and don't know what would be the best way to solve it, since training from scratch is a pain. Thank you for any suggestion.

VincieD commented on June 1, 2024

@yaxingwang Thank you for your reply. I did as you proposed and hardcoded the number of classes to 1000 in BigGAN.py. The weights are loaded; let's see what the result will be. Did you experiment with the hyperparameter settings, or just leave the original ones?

yaxingwang commented on June 1, 2024

I kept the original settings for fine-tuning. The key point is computing resources: if you run it in parallel on multiple GPUs, the results are interesting; otherwise, with a single 12 GB GPU it is difficult. In fact, we compared the same setting on one GPU versus 8x V100 in a DGX-1, and the gap is large. To be honest, it also depends on your dataset size.

yaxingwang commented on June 1, 2024

Let me know if you gain any new insights when you train.

janzd commented on June 1, 2024

@VincieD Do you get a mismatch error? It's caused by the embedding layer. The embedding layer in the pretrained model has a shape of (1000, 128), so you can't use it if you use a different dataset. First, I removed the layer from the state dict that is loaded (the weights of the networks). PyTorch has a strict argument that allows you to load weights only for layers that are in both the pretrained model and your target model, which is useful for transferring weights, but here both models contain the embedding layer so it tries to load it. The strict argument only checks for the name of the layer, not the shape. But the shape is different so it will throw an error.
Removing the layer from the state dict wasn't enough. The load_weights function also loads the state of the optimizer and the optimizer uses a running average of gradients, which is tied to the parameters, that is the shape of the layers. So I just disabled the loading of the optimizer state and use a freshly initialized Adam.
Finally, when you use the resume flag, the network parameters are not initialized because it expects that they will all be loaded from the checkpoint. Since I don't load the weights for the embedding layers, they are not initialized. I don't know what values they hold when you just declare them in the model declaration and don't initialize their weights, but when they are not initialized it seems to totally cripple the network.
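
One option I have not tried, sketched here only as an illustration: if you did keep skip_init, you could re-initialize just the class embedding that gets filtered out of the checkpoint. G.shared is how the shared embedding is named in this repo's BigGAN.py, if I read it correctly, and orthogonal init mirrors the default --G_init ortho.

import torch.nn as nn

# Untested sketch: give the class embedding fresh weights, since it was
# dropped from the loaded state dict and would otherwise stay uninitialized.
nn.init.orthogonal_(G.shared.weight)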
I'm still figuring out the code and have just started some experiments so I don't know how the changes I made affect the performance when finetuning, but at least I'm able to run it on my own dataset with an arbitrary number of classes. By the way, the scores (IS and FID) will probably be meaningless if you use a different dataset and use the InceptionV3 network trained on ImageNet. Swapping that with a network trained on my dataset is another thing I plan to do.
I didn't change start_itr to 0. I don't know if that's necessary. It seems to affect the decay and I'm not sure if it's better to reset it when finetuning or leave it as it is.

Here is my version of the load_weights function. The strict argument is set to False and skip_load_optim is set to True when calling it, and in train.py I make sure that skip_init is not set to True. There are various ways you could do it; for example, I defined a new argument finetune in addition to resume to differentiate between the two processes and modified the code accordingly.

def load_weights(G, D, state_dict, weights_root, experiment_name,
                 name_suffix=None, G_ema=None, strict=True, skip_load_optim=False):
  root = '/'.join([weights_root, experiment_name])
  if name_suffix:
    print('Loading %s weights from %s...' % (name_suffix, root))
  else:
    print('Loading weights from %s...' % root)
  if G is not None:
    pretrained_model = torch.load('%s/%s.pth' % (root, join_strings('_', ['G', name_suffix])))
    # Keep only parameters that exist in G and whose shapes match, so the
    # (1000, 128) class embedding from the ImageNet checkpoint is dropped.
    pretrained_model = {k: v for k, v in pretrained_model.items() if (k in G.state_dict()) and (G.state_dict()[k].shape == pretrained_model[k].shape)}
    G.load_state_dict(pretrained_model, strict=strict)
    if not skip_load_optim:
      G.optim.load_state_dict(
        torch.load('%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix]))))
  if D is not None:
    pretrained_model = torch.load('%s/%s.pth' % (root, join_strings('_', ['D', name_suffix])))
    # Same shape-based filtering for the discriminator.
    pretrained_model = {k: v for k, v in pretrained_model.items() if (k in D.state_dict()) and (D.state_dict()[k].shape == pretrained_model[k].shape)}
    D.load_state_dict(pretrained_model, strict=strict)
    if not skip_load_optim:
      D.optim.load_state_dict(
        torch.load('%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix]))))
  # Load the training state dict (iteration counter, best scores, etc.)
  for item in state_dict:
    state_dict[item] = torch.load('%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))[item]
  if G_ema is not None:
    pretrained_model = torch.load('%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix])))
    pretrained_model = {k: v for k, v in pretrained_model.items() if (k in G_ema.state_dict()) and (G_ema.state_dict()[k].shape == pretrained_model[k].shape)}
    G_ema.load_state_dict(pretrained_model, strict=strict)
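
And this is roughly how I call it from train.py (the finetune flag is my own addition described above, not something in the upstream repo, so treat this as a sketch):

# Sketch: 'resume' keeps the original behaviour; 'finetune' relaxes strictness
# and starts from a freshly initialized Adam instead of the checkpointed state.
utils.load_weights(G, D, state_dict,
                   config['weights_root'], experiment_name,
                   config['load_weights'] if config['load_weights'] else None,
                   G_ema if config['ema'] else None,
                   strict=not config['finetune'],
                   skip_load_optim=config['finetune'])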

tobschu94 commented on June 1, 2024

Hello,
I'm trying to fine-tune BigGAN together with @VincieD. We use the weights at 100k iterations provided by the author and try to fine-tune on the INRIA person dataset, but only on the pedestrian class, so we have one class. We fixed the mismatch error by hardcoding the number of classes to 1000 as mentioned earlier, regardless of how many classes our new dataset has. We train on 128x128 images on 2 RTX 2070s, so we are able to use batch_size=24 and num_G_accumulations=8, and also 8 for the discriminator. Here is our full hyperparameter setting:

[Screenshot: full hyperparameter settings (2020-01-15)]
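
For context, with batch_size=24 and 8 accumulations each for G and D, every parameter update effectively sees about 24 x 8 = 192 images, which is still far below the roughly 2048 effective batch used for the published 128x128 ImageNet results (if we remember the paper's setting correctly), so some gap versus the released checkpoints is expected.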

After a few tests we found that we get the best results after only 3k more iterations; see the first figure. Some pedestrians are identifiable.

[Image: fixed_samples103000 (iteration 3000)]

So we started using EMA for another 6k iterations to develop the finer structures better, but then it always generates the same image, which looks like mode collapse or something.
Does anyone have an idea how to get better results?

[Image: fixed_samples109000 (iteration 9000)]

And thanks @kurapan for your ideas. How are your results looking? Do you have any new information about your approach with the embedding layer?

JanineCHEN commented on June 1, 2024

@tobschu94 Hi, I am having the same issue as you. I am training with my own dataset, which has a fixed number of classes that I cannot alter, so I tried to train from scratch with the following parameter settings:

#!/bin/bash
python train.py \
--dataset I128_hdf5 --parallel --shuffle  --num_workers 8 --batch_size 128 --load_in_mem \
--num_G_accumulations 16 --num_D_accumulations 16 --num_epochs 10000 \
--num_D_steps 1 --G_lr 5e-5 --D_lr 2e-4 --D_B2 0.999 --G_B2 0.999 \
--G_attn 32 --D_attn 32 \
--G_nl inplace_relu --D_nl inplace_relu \
--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
--G_ortho 0.0 \
--G_shared \
--G_init ortho --D_init ortho \
--hier --dim_z 120 --shared_dim 128 \
--G_eval_mode \
--which_best FID \
--G_ch 32 --D_ch 32 \
--ema --use_ema --ema_start 200 \
--test_every 50 --save_every 50 --num_best_copies 5 --num_save_copies 2 --seed 0 \
--use_multiepoch_sampler
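
(My own rough numbers for context: batch_size=128 with 16 accumulations gives an effective batch of about 128 x 16 = 2048 per update, so the effective batch roughly matches the reference BigGAN runs; but G_ch=32 and D_ch=32 make the networks much narrower than the 64 to 96 channel multipliers used in the reference configurations, if I read the repo's launch scripts correctly, which might also matter here.)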

The generated images also seem to show mode collapse after 13k iterations:
[Screenshot: generated samples at ~13k iterations (2020-09-29)]

I am not sure if being a bit more patient and training for a bit longer would help, or whether it would just stay stuck in mode collapse...

tobschu94 commented on June 1, 2024

@JanineCHEN Neither fixing the number of classes, training from scratch, nor fine-tuning BigGAN worked in my case; at some point it collapses. Training it for a long time, more than 120k iterations, did not improve the results either. So I switched to a different GAN by NVIDIA called ProGAN. That GAN worked quite well in my case, and both the structure of their code and the principle of their GAN are easy to understand.

I hope this helps you achieve your desired results.

JanineCHEN commented on June 1, 2024

Thank you @tobschu94 for your kind advice. May I check one thing: as far as I know, BigGAN is a conditional model and is layerwise stochastic (meaning the model receives a stochastic code at each layer), and this is something I need. Do you happen to know if ProGAN also has these traits?

nnajeh commented on June 1, 2024

@tobschu94 how did you fine-tune the model on your dataset?

thusinh1969 commented on June 1, 2024

I think you would need to triple the capacity of the network to learn such a complex dataset. Same here; it never worked for a complex dataset. GAN capability (no matter whether NVIDIA's or Google's) is simply not there yet, friends. Wait a while longer.

Steve
