
gan-compression's Introduction

GAN Compression

[NEW!] GAN Compression has been accepted by T-PAMI! We released our T-PAMI version as arXiv v4!

[NEW!] We released the code of our interactive demo, including the TVM-tuned model. It now achieves 8 FPS on the Jetson Nano GPU!

[NEW!] Added support for MUNIT, a multimodal unsupervised image-to-image translation approach! Please follow the test commands to test the pre-trained models, and the tutorial to train your own models!

We introduce GAN Compression, a general-purpose method for compressing conditional GANs. Our method reduces the computation of widely-used conditional GAN models, including pix2pix, CycleGAN, MUNIT, and GauGAN, by 9-29x while preserving the visual fidelity. Our method is effective for a wide range of generator architectures, learning objectives, and both paired and unpaired settings.

GAN Compression: Efficient Architectures for Interactive Conditional GANs
Muyang Li, Ji Lin, Yaoyao Ding, Zhijian Liu, Jun-Yan Zhu, and Song Han
MIT, Adobe Research, SJTU
In CVPR 2020.


Overview

GAN Compression framework: ① Given a pre-trained teacher generator G', we distill a smaller “once-for-all” student generator G that contains all possible channel numbers through weight sharing. We choose different channel numbers for the student generator G at each training step. ② We then extract many sub-generators from the “once-for-all” generator and evaluate their performance. No retraining is needed, which is the advantage of the “once-for-all” generator. ③ Finally, we choose the best sub-generator given the compression ratio target and performance target (FID or mIoU) using either brute-force or evolutionary search. Optionally, we perform additional fine-tuning to obtain the final compressed model.
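The weight-sharing and search steps can be illustrated with a small NumPy sketch. It assumes 8 searchable layers with width choices of 16, 24, or 32 channels (matching configuration strings such as 16_16_32_16_32_32_16_16), a sub-network that keeps the leading slice of each shared kernel, and caller-supplied `evaluate` (e.g. FID) and `macs` functions. All names are hypothetical. This is not the repository's implementation, and the distillation loss and the evolutionary-search option are not shown.

```python
import itertools
import random

import numpy as np

CHANNEL_CHOICES = (16, 24, 32)  # assumed width options per searchable layer
NUM_LAYERS = 8                  # assumed number of searchable layers


def sample_config(rng: random.Random) -> tuple:
    """Step ①: draw a random channel configuration for one training step."""
    return tuple(rng.choice(CHANNEL_CHOICES) for _ in range(NUM_LAYERS))


def config_str(config) -> str:
    """Format a configuration as e.g. '16_16_32_16_32_32_16_16'."""
    return "_".join(str(c) for c in config)


def slice_weight(full_weight: np.ndarray, out_ch: int, in_ch: int) -> np.ndarray:
    """Weight sharing: a sub-network uses the leading [out_ch, in_ch] slice of
    the largest kernel (shape [out, in, kh, kw]). Basic slicing returns a
    view, so no per-sub-network copy of the weights is made."""
    return full_weight[:out_ch, :in_ch]


def brute_force_search(evaluate, macs, budget):
    """Step ③: score every configuration within the MACs budget and return
    (score, config) with the lowest score (e.g. FID; pass -mIoU for mIoU)."""
    best = None
    for config in itertools.product(CHANNEL_CHOICES, repeat=NUM_LAYERS):
        if macs(config) > budget:
            continue  # violates the compression target
        score = evaluate(config)
        if best is None or score < best[0]:
            best = (score, config)
    return best
```

During training, `sample_config` would be called once per step. `brute_force_search` enumerates all 3^8 = 6,561 configurations. That is only affordable because candidates need no retraining.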

Performance


GAN Compression reduces the computation of pix2pix, CycleGAN, and GauGAN by 9-21x, and model size by 4.6-33x.
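The quoted ranges can be reproduced from the paper models in the Performance of Released Models table. This small check divides the original model's MACs and parameter count by those of the compressed model. The numbers are copied from that table, and the dictionary layout is only for illustration.

```python
# (parameters in M, MACs in G) for the original and the paper's compressed
# generator, copied from the "Performance of Released Models" table.
MODELS = {
    "CycleGAN horse2zebra": ((11.4, 56.8), (0.342, 2.67)),
    "pix2pix edges2shoes": ((11.4, 56.8), (0.700, 4.81)),
    "pix2pix cityscapes": ((11.4, 56.8), (0.707, 5.66)),
    "GauGAN cityscapes": ((93.0, 281.0), (20.4, 31.7)),
}


def reduction_ratios(models):
    """Return {name: (parameter_ratio, macs_ratio)}, each original / compressed."""
    return {
        name: (orig[0] / comp[0], orig[1] / comp[1])
        for name, (orig, comp) in models.items()
    }


if __name__ == "__main__":
    for name, (params_x, macs_x) in reduction_ratios(MODELS).items():
        print(f"{name}: {macs_x:.1f}x fewer MACs, {params_x:.1f}x fewer parameters")
```

The MACs reductions come out to roughly 8.9x to 21.3x, and the parameter reductions to roughly 4.6x to 33.3x. These round to the ranges stated above.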

Colab Notebook

PyTorch Colab notebook: CycleGAN and pix2pix.

Prerequisites

  • Linux
  • Python 3
  • CPU or NVIDIA GPU + CUDA CuDNN

Getting Started

Installation

  • Clone this repo:

    git clone git@github.com:mit-han-lab/gan-compression.git
    cd gan-compression
  • Install PyTorch 1.4 and other dependencies (e.g., torchvision).

    • For pip users, please type the command pip install -r requirements.txt.
    • For Conda users, we provide an installation script scripts/conda_deps.sh. Alternatively, you can create a new Conda environment using conda env create -f environment.yml.

CycleGAN

Setup

  • Download the CycleGAN dataset (e.g., horse2zebra).

    bash datasets/download_cyclegan_dataset.sh horse2zebra
  • Get the statistics of the ground-truth images of your dataset, which are needed to compute FID. We provide pre-computed real statistics for several datasets. For example,

    bash datasets/download_real_stat.sh horse2zebra A
    bash datasets/download_real_stat.sh horse2zebra B

Apply a Pre-trained Model

  • Download the pre-trained models.

    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full
    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed
  • Test the original full model.

    bash scripts/cycle_gan/horse2zebra/test_full.sh
  • Test the compressed model.

    bash scripts/cycle_gan/horse2zebra/test_compressed.sh
  • Measure the latency of the two models.

    bash scripts/cycle_gan/horse2zebra/latency_full.sh
    bash scripts/cycle_gan/horse2zebra/latency_compressed.sh
  • There may be slight differences between the results of the above models and those in the paper, since we retrained the models. We also release the compressed models from the paper. If you observe such inconsistencies, you can test our paper models with the following commands:

    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage legacy
    bash scripts/cycle_gan/horse2zebra/test_legacy.sh
    bash scripts/cycle_gan/horse2zebra/latency_legacy.sh

Pix2pix

Setup

  • Download the pix2pix dataset (e.g., edges2shoes).

    bash datasets/download_pix2pix_dataset.sh edges2shoes-r
  • Get the statistics of the ground-truth images of your dataset, which are needed to compute FID. We provide pre-computed real statistics for several datasets. For example,

    bash datasets/download_real_stat.sh edges2shoes-r B
    bash datasets/download_real_stat.sh edges2shoes-r subtrain_B

Apply a Pre-trained Model

  • Download the pre-trained models.

    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage full
    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage compressed
  • Test the original full model.

    bash scripts/pix2pix/edges2shoes-r/test_full.sh
  • Test the compressed model.

    bash scripts/pix2pix/edges2shoes-r/test_compressed.sh
  • Measure the latency of the two models.

    bash scripts/pix2pix/edges2shoes-r/latency_full.sh
    bash scripts/pix2pix/edges2shoes-r/latency_compressed.sh
  • There may be slight differences between the results of the above models and those in the paper, since we retrained the models. We also release the compressed models from the paper. If you observe such inconsistencies, you can test our paper models with the following commands:

    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage legacy
    bash scripts/pix2pix/edges2shoes-r/test_legacy.sh
    bash scripts/pix2pix/edges2shoes-r/latency_legacy.sh

GauGAN

Setup

  • Prepare the Cityscapes dataset (see the Cityscapes Dataset section for instructions).

  • Get the statistics of the ground-truth images of your dataset, which are needed to compute FID. We provide pre-computed real statistics for several datasets. For example,

    bash datasets/download_real_stat.sh cityscapes A

Apply a Pre-trained Model

  • Download the pre-trained models.

    python scripts/download_model.py --model gaugan --task cityscapes --stage full
    python scripts/download_model.py --model gaugan --task cityscapes --stage compressed
  • Test the original full model.

    bash scripts/gaugan/cityscapes/test_full.sh
  • Test the compressed model.

    bash scripts/gaugan/cityscapes/test_compressed.sh
  • Measure the latency of the two models.

    bash scripts/gaugan/cityscapes/latency_full.sh
    bash scripts/gaugan/cityscapes/latency_compressed.sh
  • There may be slight differences between the results of the above models and those in the paper, since we retrained the models. We also release the compressed models from the paper. If you observe such inconsistencies, you can test our paper models with the following commands:

    python scripts/download_model.py --model gaugan --task cityscapes --stage legacy
    bash scripts/gaugan/cityscapes/test_legacy.sh
    bash scripts/gaugan/cityscapes/latency_legacy.sh

MUNIT

Setup

  • Prepare the dataset (e.g., edges2shoes-r).

    bash datasets/download_pix2pix_dataset.sh edges2shoes-r
    python datasets/separate_A_and_B.py --input_dir database/edges2shoes-r --output_dir database/edges2shoes-r-unaligned
    python datasets/separate_A_and_B.py --input_dir database/edges2shoes-r --output_dir database/edges2shoes-r-unaligned --phase val
  • Get the statistics of the ground-truth images of your dataset, which are needed to compute FID. We provide pre-computed real statistics for several datasets. For example,

    bash datasets/download_real_stat.sh edges2shoes-r B
    bash datasets/download_real_stat.sh edges2shoes-r-unaligned subtrain_B
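Judging from its name and the -unaligned output directory, datasets/separate_A_and_B.py splits the aligned pix2pix images into separate A and B sets for unpaired training. The sketch below shows the per-image operation as a column split. It assumes the standard pix2pix layout, where each file stores A on the left half and B on the right half. The function name is hypothetical, file I/O is omitted, and this is not the script's code.

```python
import numpy as np


def split_aligned(ab: np.ndarray):
    """Split an H x 2W (x C) side-by-side pix2pix image into A and B.

    Assumes A occupies the left half and B the right half.
    """
    width = ab.shape[1]
    if width % 2:
        raise ValueError(f"expected an even width, got {width}")
    half = width // 2
    return ab[:, :half], ab[:, half:]
```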

Apply a Pre-trained Model

  • Download the pre-trained models.

    python scripts/download_model.py --model munit --task edges2shoes-r_fast --stage full
    python scripts/download_model.py --model munit --task edges2shoes-r_fast --stage compressed
  • Test the original full model.

    bash scripts/munit/edges2shoes-r_fast/test_full.sh
  • Test the compressed model.

    bash scripts/munit/edges2shoes-r_fast/test_compressed.sh
  • Measure the latency of the two models.

    bash scripts/munit/edges2shoes-r_fast/latency_full.sh
    bash scripts/munit/edges2shoes-r_fast/latency_compressed.sh

Cityscapes Dataset

We cannot redistribute the Cityscapes dataset due to license restrictions. Please download it from https://cityscapes-dataset.com and preprocess it with the script datasets/prepare_cityscapes_dataset.py. Download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip and unzip them into the same folder. For example, you can put gtFine and leftImg8bit in database/cityscapes-origin. Then prepare the dataset with the following commands:

python datasets/get_trainIds.py database/cityscapes-origin/gtFine/
python datasets/prepare_cityscapes_dataset.py \
--gtFine_dir database/cityscapes-origin/gtFine \
--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \
--output_dir database/cityscapes \
--train_table_path datasets/train_table.txt \
--val_table_path datasets/val_table.txt

You will get a preprocessed dataset in database/cityscapes and mapping tables (used to compute mIoU) in datasets/train_table.txt and datasets/val_table.txt.
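Cityscapes ground truth stores fine-grained label IDs, while evaluation uses 19 training classes. The sketch below converts label IDs to train IDs using the standard mapping from the official cityscapesScripts label table, with every other label set to 255. Treat this mapping as an assumption, since it may differ in detail from what datasets/get_trainIds.py does. The function and constant names are hypothetical.

```python
import numpy as np

# Standard Cityscapes labelId -> trainId mapping for the 19 evaluation classes
# (cityscapesScripts label table); every other label maps to the ignore index.
LABEL_TO_TRAIN_ID = {
    7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
    23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18,
}
IGNORE_INDEX = 255


def to_train_ids(label_ids: np.ndarray) -> np.ndarray:
    """Map an integer labelId array (e.g. a gtFine *_labelIds.png) to trainIds
    using a 256-entry lookup table."""
    lut = np.full(256, IGNORE_INDEX, dtype=np.uint8)
    for label_id, train_id in LABEL_TO_TRAIN_ID.items():
        lut[label_id] = train_id
    return lut[label_ids]
```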

To support mIoU computation, download the pre-trained DRN model drn-d-105_ms_cityscapes.pth from http://go.yf.io/drn-cityscapes-models. By default, we put the DRN model in the root directory of the repo. Once you have also downloaded our models, you can test the compressed models on Cityscapes.
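The usual mIoU protocol runs a segmentation network (here DRN) on the generated images and compares its predictions with the ground-truth labels. The sketch below covers only the metric. It assumes both inputs are already integer trainId arrays, with 19 classes, predictions in [0, 19), and 255 as the ignore label. Network inference is omitted, and the function names are illustrative.

```python
import numpy as np


def confusion_matrix(gt, pred, num_classes=19, ignore_index=255):
    """Accumulate a num_classes x num_classes confusion matrix (rows = ground
    truth, columns = prediction), skipping ignored ground-truth pixels.
    Assumes every prediction lies in [0, num_classes)."""
    gt = np.asarray(gt).ravel()
    pred = np.asarray(pred).ravel()
    valid = gt != ignore_index
    idx = num_classes * gt[valid].astype(np.int64) + pred[valid].astype(np.int64)
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def mean_iou(conf):
    """Per-class IoU = TP / (TP + FP + FN), averaged over classes that occur
    in the ground truth or the predictions."""
    tp = np.diag(conf).astype(np.float64)
    denom = conf.sum(axis=0) + conf.sum(axis=1) - tp
    present = denom > 0
    return float((tp[present] / denom[present]).mean())
```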

COCO-Stuff Dataset

We follow the same COCO-Stuff dataset preparation as NVlabs/spade. Specifically, you need to download train2017.zip, val2017.zip, stuffthingmaps_trainval2017.zip, and annotations_trainval2017.zip from nightrome/cocostuff. The images, labels, and instance maps should be arranged in the same directory structure as in datasets/coco_stuff. In particular, we used an instance map that combines the boundaries of both the "things" instance map and the "stuff" label map. We generated it with a simple script, datasets/coco_generate_instance_map.py.
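One way to build an instance map whose boundaries combine both sources is to give every (instance ID, stuff label) pair its own ID. The combined map then has an edge wherever either input has one. This NumPy sketch assumes non-negative integer maps of equal shape and 4-neighbour boundaries. It is an illustration only, not the logic of datasets/coco_generate_instance_map.py, and the function names are hypothetical.

```python
import numpy as np


def boundary_map(label: np.ndarray) -> np.ndarray:
    """Boolean H x W map marking pixels whose label differs from a 4-neighbour."""
    edge = np.zeros(label.shape, dtype=bool)
    diff_h = label[:, 1:] != label[:, :-1]  # left/right neighbours differ
    diff_v = label[1:, :] != label[:-1, :]  # up/down neighbours differ
    edge[:, 1:] |= diff_h
    edge[:, :-1] |= diff_h
    edge[1:, :] |= diff_v
    edge[:-1, :] |= diff_v
    return edge


def combine_maps(instance_map: np.ndarray, stuff_label_map: np.ndarray) -> np.ndarray:
    """Encode each (instance id, stuff label) pair as a unique integer, so the
    result has a boundary wherever either input map does."""
    inst = instance_map.astype(np.int64)
    stuff = stuff_label_map.astype(np.int64)
    return inst * (int(stuff.max()) + 1) + stuff
```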

To support mIoU computation, you need to download a pre-trained DeeplabV2 model deeplabv2_resnet101_msc-cocostuff164k-100000.pth and also put it in the root directory of the repo.

Performance of Released Models

Here we show the performance of all our released models:

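| Model | Dataset | Method | #Parameters | MACs | FID (lower is better) | mIoU (higher is better) |
|---|---|---|---|---|---|---|
| CycleGAN | horse→zebra | Original | 11.4M | 56.8G | 65.75 | -- |
| CycleGAN | horse→zebra | GAN Compression (Paper) | 0.342M | 2.67G | 65.33 | -- |
| CycleGAN | horse→zebra | GAN Compression (Retrained) | 0.357M | 2.55G | 65.12 | -- |
| CycleGAN | horse→zebra | Fast GAN Compression | 0.355M | 2.64G | 65.19 | -- |
| Pix2pix | edges→shoes | Original | 11.4M | 56.8G | 24.12 | -- |
| Pix2pix | edges→shoes | GAN Compression (Paper) | 0.700M | 4.81G | 26.60 | -- |
| Pix2pix | edges→shoes | GAN Compression (Retrained) | 0.822M | 4.99G | 26.70 | -- |
| Pix2pix | edges→shoes | Fast GAN Compression | 0.703M | 4.83G | 25.76 | -- |
| Pix2pix | Cityscapes | Original | 11.4M | 56.8G | -- | 42.06 |
| Pix2pix | Cityscapes | GAN Compression (Paper) | 0.707M | 5.66G | -- | 40.77 |
| Pix2pix | Cityscapes | GAN Compression (Retrained) | 0.781M | 5.59G | -- | 38.63 |
| Pix2pix | Cityscapes | Fast GAN Compression | 0.867M | 5.61G | -- | 41.71 |
| Pix2pix | map→aerial photo | Original | 11.4M | 56.8G | 47.91 | -- |
| Pix2pix | map→aerial photo | GAN Compression | 0.746M | 4.68G | 48.02 | -- |
| Pix2pix | map→aerial photo | Fast GAN Compression | 0.708M | 4.53G | 48.67 | -- |
| GauGAN | Cityscapes | Original | 93.0M | 281G | 57.60 | 61.04 |
| GauGAN | Cityscapes | GAN Compression (Paper) | 20.4M | 31.7G | 55.19 | 61.22 |
| GauGAN | Cityscapes | GAN Compression (Retrained) | 21.0M | 31.2G | 56.43 | 60.29 |
| GauGAN | Cityscapes | Fast GAN Compression | 20.2M | 31.3G | 56.25 | 61.17 |
| GauGAN | COCO-Stuff | Original | 97.5M | 191G | 21.38 | 38.78 |
| GauGAN | COCO-Stuff | Fast GAN Compression | 26.0M | 35.5G | 25.06 | 35.05 |
| MUNIT | edges→shoes | Original | 15.0M | 77.3G | 30.13 | -- |
| MUNIT | edges→shoes | Fast GAN Compression | 1.10M | 2.63G | 30.53 | -- |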

Training

Please refer to the Fast GAN Compression and GAN Compression tutorials for how to train models on our datasets and on your own.

FID Computation

To compute the FID score, you need statistics of the ground-truth images of your dataset. We provide a script, get_real_stat.py, to extract them. For example, for the edges2shoes dataset, you can run the following command:

python get_real_stat.py \
--dataroot database/edges2shoes-r \
--output_path real_stat/edges2shoes-r_B.npz \
--direction AtoB
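The .npz files hold the mean and covariance of the Inception activations of the real images. The FID code reads them under the keys mu and sigma. The sketch below covers only the statistics step. It assumes the activations have already been extracted as an N x D array (2048-D in standard FID), and its function names are hypothetical.

```python
import numpy as np


def compute_statistics(activations: np.ndarray):
    """Mean vector and (unbiased) covariance matrix of N x D activations."""
    acts = np.asarray(activations, dtype=np.float64)
    mu = acts.mean(axis=0)
    sigma = np.cov(acts, rowvar=False)
    return mu, sigma


def save_real_stat(path, activations):
    """Store the statistics under the 'mu' and 'sigma' keys."""
    mu, sigma = compute_statistics(activations)
    np.savez(path, mu=mu, sigma=sigma)
```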

For paired image-to-image translation (pix2pix and GauGAN), we calculate the FID between generated test images and real test images. For unpaired image-to-image translation (CycleGAN), we calculate the FID between generated test images and real training+test images. This allows us to use more images for a stable FID evaluation, as in previous research on unconditional GANs. The difference between the two protocols is small: the FID of our compressed CycleGAN model increases by 4 when using real test images instead of real training+test images.
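Both protocols plug the two sets of statistics into the Fréchet distance between Gaussians, FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). The sketch below follows the style of the pytorch-fid package credited in the Acknowledgements. It assumes NumPy and SciPy are installed, and the function name is illustrative.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Near-singular product: add eps to both diagonals and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    covmean = np.real(covmean)  # discard tiny imaginary parts from round-off
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
```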

To help users better understand and use our code, we briefly overview the functionality and implementation of each package and each module.

Citation

If you use this code for your research, please cite our paper.

@inproceedings{li2020gan,
  title={GAN Compression: Efficient Architectures for Interactive Conditional GANs},
  author={Li, Muyang and Lin, Ji and Ding, Yaoyao and Liu, Zhijian and Zhu, Jun-Yan and Han, Song},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2020}
}

Acknowledgements

Our code is developed based on pytorch-CycleGAN-and-pix2pix, SPADE, and MUNIT.

We also thank pytorch-fid for FID computation, drn for Cityscapes mIoU computation, and deeplabv2 for COCO-Stuff mIoU computation.

gan-compression's People

Contributors

junyanz, lmxyy, seekingdeep, songhan, tonylins


gan-compression's Issues

multi-gpu

How can I use multi-gpu when training?
I tried --gpu_ids 1,2,3,4. It runs correctly when training the mobile model, but not when distilling.

About the different between "pruning + distill" and "GAN Compression" method.

As mentioned in your paper (Appendix 6.1)

we first train a MobileNet [25] style network from scratch, and then use the network as a teacher model to distill a smaller student network.

I guess "use the network as a teacher model to distill a smaller student network" corresponds to "Pre-distillation (Optional)..." in your docs/training_tutorial.md.

And you have compared "Pruning + distill" to the "GAN Compression" method in 4.3 Figure 6.

My questions are:

  1. Is it true that, in Figure 6, the "pruning + distill" network is the student network after "Pre-distillation (Optional)...", and the "GAN compression" network is the one after NAS and fine-tuning?
  2. Is it true that the student network is smaller than the teacher network because of pruning? If so, what pruning method is applied before training the once-for-all network?
  3. What is the relationship between Figure 3 ① and the pipeline described in Appendix 6.1? It seems that no pruning method is applied in Figure 3 ①.

Looking forward to your reply.

Can the sharpness of the image be increased a little

First of all, this is a groundbreaking technical paper that makes inference on mobile devices feasible. If the image sharpness improved a little more, it could cross the threshold for practical use.

  • This is a cartoon task

  • This is an aging task

  • This is a mosaic task

Left, middle, and right: real_A, fake_B, real_B.
The cartoon results are simple but acceptable. The others may need improvement.
I'm working with pix2pixHD and partialconv, whose computation and model size are relatively large. I wonder whether the advantages of this approach could solve that problem.

  • Current inference speed on mobile phones: on an iPhone 11 Pro, a 256px image runs at about 4 per second. I expect more than a dozen per second on the iPhone 12.

Dataset mode [single] only supports direction BtoA. We will change the direction to BtoA.!

When preparing the dataset for CycleGAN, I get a warning saying that only BtoA is allowed:

(gan) home@home-lnx:~/programs/level 2/gan-compression$ python get_real_stat.py --dataroot database/horse2zebra/valB --output_path real_stat/horse2zebra_B.npz --direction AtoB --dataset_mode single
get_real_stat.py:61: UserWarning: Dataset mode [single] only supports direction BtoA. We will change the direction to BtoA.!
  warnings.warn('Dataset mode [single] only supports direction BtoA. '

get_real_stat.py error

@lmxyy
I am trying everything from scratch. When running get_real_stat.py, I get an error.
Why does it require a val folder? There are already valA and valB.

(gan) home@home-lnx:~/programs/level 2/gan-compression$ python get_real_stat.py --dataroot database/horse2zebra/ --output_path real_stat/horse2zebra_B.npz --direction AtoB
Traceback (most recent call last):
  File "get_real_stat.py", line 84, in <module>
    main(opt)
  File "get_real_stat.py", line 15, in main
    dataloader = create_dataloader(opt)
  File "/home/home/programs/level 2/gan-compression/data/__init__.py", line 45, in create_dataloader
    dataloader = CustomDatasetDataLoader(opt, verbose)
  File "/home/home/programs/level 2/gan-compression/data/__init__.py", line 97, in __init__
    self.dataset = dataset_class(opt)
  File "/home/home/programs/level 2/gan-compression/data/aligned_dataset.py", line 24, in __init__
    self.AB_paths = sorted(make_dataset(self.dir_AB))  # get image paths
  File "/home/home/programs/level 2/gan-compression/data/image_folder.py", line 45, in make_dataset
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
AssertionError: database/horse2zebra/val is not a valid directory
(gan) home@home-lnx:~/programs/level 2/gan-compression/database/horse2zebra$ tree -d
.
├── testA
├── testB
├── trainA
├── trainB
├── valA -> testA
└── valB -> testB

Do I have to use the statistics when training the teacher net?

I trained a CycleGAN model without using the statistics of the ground-truth images (using the original CycleGAN code). Is it possible to add the statistics to a model trained with the original CycleGAN code, or do I have to retrain the model with the new version?
Thanks.

gan compression on tensorflow

Hi,

I wanted to implement GAN compression on TensorFlow so I can use it on a mobile application.
I'm pretty new to this field and I'm not aware of the challenges this might involve.
Should I implement the whole thing in TensorFlow myself? Would it be as fast as in PyTorch?
Or should I convert the PyTorch model to TensorFlow with ONNX?
I would really appreciate any help

Does OFA training reduce the capacity gap between student and teacher?

I compared the distillation results with the supernet results (without fine-tuning). The results are:

#super net w/o finetune:
config_str MACs FID
32_32_32_32_32_32_32_32 4.955 55.73
32_16_32_24_32_32_24_24 3.639 60.27
16_16_32_16_32_32_16_16 2.546 65.33
16_16_16_32_32_32_16_24 1.977 134.45
16_16_16_16_16_16_16_16 1.421 223.46

#distill
config_str MACs FID
32_32_32_32_32_32_32_32 4.955 65.78
32_16_32_24_32_32_24_24 3.639 73.50
16_16_32_16_32_32_16_16 2.546 80.36
16_16_16_32_32_32_16_24 1.977 83.06
16_16_16_16_16_16_16_16 1.421 103.17

It seems that when MACs > 1.977, the supernet is better than distillation, but when MACs ≤ 1.977, distillation is better.
Also, it seems that when MACs > 1.977, OFA training reduces the capacity gap between student and teacher and therefore achieves better performance.
Do you know why?

question regarding pix2pixHD

Hi! Thank you all for the tremendous and awesome work.

I want to ask you what would be your recommendations regarding incorporating pix2pixHD?

I've done almost all the necessary steps, but I'm interested in your advice on the supernet. pix2pixHD has the same number of blocks as pix2pix; however, the dimensionality of the connecting (mapping) layers differs. How would you propose modifying resnet-9blocks for that? Or do you have more confidence in another architecture?

Pix2PixHD

@lmxyy @junyanz
Did you run any tests on compressing Pix2PixHD?
Will other types of GANs be implemented in this repo?

license

@lmxyy
This project's license is not fully clear.
Can this repository be used for commercial purposes?

How to distill a specific structure like "16_16_32_16_32_32_16_16"?

I noticed that your scripts can distill residual blocks with 16, 24, or 32 channels.
But how can I distill a specific structure like "16_16_32_16_32_32_16_16"?
Regarding the experiment in Fig. 6 of the paper: does it compare "pruning + distill" with "GAN Compression" at the same MACs and residual-block structure?

Question about distillers

Hello, the following is the "once-for-all" training strategy you mentioned in your paper:
"At each training step, we randomly sample a sub-network with a certain channel number configuration, compute the output and gradients, and update the extracted weights using our learning objective (Equation 4)"
Where can I find this strategy in your code?

can not find opt_compressed.pkl

I tried to test the Colab notebook with the default options, but I cannot find the necessary file "opt_compressed.pkl". Should I run a specific script to generate it?

two questions about once for all

Thanks for sharing your excellent work. I have two questions about once-for-all.

1. Different hardware platforms optimize different ops, and we often choose efficient ops according to the target platform. Can OFA handle the situation where different hardware platforms prefer different ops?
2. On mobile platforms, different camera sensors produce different data, so each hardware platform has different training data. When we use OFA for a generative network such as SRGAN, which platform's training data should be used?

hanlab website down?

I'm trying to download your pre-trained models, and it seems that the file server is down. Can you check it out?

How to train the compressed model?

Could you give me some instructions on how to train the compressed model myself, rather than using the pre-trained compressed model? For example, with pix2pix? I don't understand the different training modes.

Questions regarding training and distillation

Hi, after going through the code, I have a few questions regarding training and distillation.

  1. After training the full (teacher) model, do we directly train the supernet (student) from scratch using resnet_supernet.py?
  2. If not, do we first have to train the student MobileNet (with normal separable convolutions) using resnet_distiller.py and transfer the weights to the student supernet?
  3. Judging from the load_networks function in resnet_distiller.py, is it necessary to transfer the weights of the teacher network to the student network, or is it just for faster training and convergence?
  4. Lastly, how long did it take to train the supernet with distillation?

SEARCH MULTI

I successfully ran search.py and evaluated sub-models, but when using search_multi.py, I get an error that says
'RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 6 does not equal 0', which means a tensor and the model are not on the same GPU.

How can I run search_multi.py successfully?

MobileNet Teacher Model

Thank you for sharing your work! Did you define the "MobileNet Teacher Model" by yourself or from the original paper?

Is Pre-distillation a required step?

I want to know the purpose of pre-distillation in the GAN-compression pipeline. How does it improve the pruning pipeline? It is not mentioned anywhere in the paper.

Question about distillation

Take CycleGAN compression as an example.

the teacher generator: a MobileNet-based CycleGAN trained from scratch
the student generator: a MobileNet-based distillation CycleGAN model
the final generator: a fine-tuned sub network of student generator

Is my understanding correct?

Does this mean the original CycleGAN model (with normal convolutions) is not needed in the compression algorithm?

multi-gpu with forward hook when training supernet with resnet

This is a great project. However, I ran into a similar issue as this. Moreover, the issue seems somewhat random: training with 2 GPUs sometimes runs correctly and sometimes does not. The key point is that the intermediate features are obtained with forward hooks and stored in dictionaries, which cannot guarantee that they are on the same device as netA. The bug occurs here, where netA is always on cuda:0 because it is not wrapped with DataParallel correctly (https://github.com/mit-han-lab/gan-compression/blob/master/distillers/base_resnet_distiller.py#L117-L123: it should be netA = networks.init_net(netA, gpu_ids=self.gpu_ids), and there should be no to(device) in the two cases above; this line and this line should be modified accordingly to something like getattr(netA, 'module', netA)).

However, since Tact and Sact are randomly on different devices, and the dictionary loses this information, there will still be a bug. One possible solution is to include the device in the keys here, to record which device each output is on. But since netA is wrapped in DataParallel, which scatters inputs and replicates weights during its forward pass, netA(Sact) still does not work. If netA is moved to the device of Sact before calling it, the optimizerG step will cause problems, because the gradients will be on a different device from netA's weights.

I am not sure whether there is a solution to this. I noticed you changed the code structure in the SPADE network, but I wonder whether there is a simpler solution.

About channel pruning

Could you tell me which document can let me understand the process of channel pruning?

How does ‘netG_pretrained’ in MobileDistiller.py work?

I'm puzzled by self.netG_pretrained in resnet_distiller.py.
It is here: https://github.com/mit-han-lab/gan-compression/blob/master/distillers/resnet_distiller.py#L94
Why is it deleted after being loaded?

    def load_networks(self, verbose=True):
        if self.opt.restore_pretrained_G_path is not None:
            util.load_network(self.netG_pretrained, self.opt.restore_pretrained_G_path, verbose)
            load_pretrained_weight(self.opt.pretrained_netG, self.opt.student_netG,
                                   self.netG_pretrained, self.netG_student,
                                   self.opt.pretrained_ngf, self.opt.student_ngf)
            del self.netG_pretrained
        super(ResnetDistiller, self).load_networks()

NotImplementedError: Unknown module [<class 'torch.nn.modules.instancenorm.InstanceNorm2d'>]!

Following the tutorial, I trained the mobile-style CycleGAN (without changing any parameters except the dataset).

But when I run train_supernet.sh, it ends with the following error.

Traceback (most recent call last):
  File "train_supernet.py", line 4, in <module>
    trainer = Trainer('supernet')
  File "/data1/edvardzeng/myspace/gan-compression/trainer.py", line 52, in __init__
    model.setup(opt)  # regular setup: load and print networks; create schedulers
  File "/data1/edvardzeng/myspace/gan-compression/distillers/base_resnet_distiller.py", line 153, in setup
    self.load_networks(verbose)
  File "/data1/edvardzeng/myspace/gan-compression/supernets/resnet_supernet.py", line 182, in load_networks
    self.opt.student_ngf, self.opt.student_ngf)
  File "/data1/edvardzeng/myspace/gan-compression/utils/weight_transfer.py", line 167, in load_pretrained_weight
    index = transfer(m1, m2, index)
  File "/data1/edvardzeng/myspace/gan-compression/utils/weight_transfer.py", line 139, in transfer
    return transfer_MobileResnetBlock(m1, m2, input_index, output_index)
  File "/data1/edvardzeng/myspace/gan-compression/utils/weight_transfer.py", line 84, in transfer_MobileResnetBlock
    idxs = transfer(m1.conv_block[1], m2.conv_block[1], input_index=input_index)
  File "/data1/edvardzeng/myspace/gan-compression/utils/weight_transfer.py", line 145, in transfer
    raise NotImplementedError('Unknown module [%s]!' % type(m1))
NotImplementedError: Unknown module [<class 'torch.nn.modules.instancenorm.InstanceNorm2d'>]!

Computing FID is not friendly to users with limited CPU memory.

I am trying to compute statistics for my own dataset with get_real_stat.py.
But I get an error: it needs more than 13 GB of memory at tensors = util.tensor2im(tensors).astype(float).
So I tried to reduce memory usage.
First, it is simple to change it to tensors = util.tensor2im(tensors).astype(np.float32). That helps,
but not enough.
I also tried changing the function get_activations_from_ims in fid_score.py like this:

        images = images.transpose((0, 3, 1, 2))
        images = images.astype(np.float32) / 255

But memory_profiler shows that it is still not good enough:

Line #    Mem usage    Increment   Line Contents
================================================
    10   9041.7 MiB   9041.7 MiB   @profile
    11                             def get_fid(fakes, model, npz, device, batch_size=1, use_tqdm=True, bgr=False):
    12   9041.7 MiB      0.0 MiB       m1, s1 = npz['mu'], npz['sigma']
    13  15747.4 MiB   6705.7 MiB       fakes = torch.cat(fakes, dim=0)
    14  10226.1 MiB      0.0 MiB       fakes = util.tensor2im(fakes, normalize=False)  #.astype(np.float32)   # default float
    15  10226.1 MiB      0.0 MiB       m2, s2 = _compute_statistics_of_ims(fakes, model, batch_size, 2048,
    16  10338.3 MiB    112.3 MiB                                           device, use_tqdm=use_tqdm, bgr=bgr)
    17  10346.7 MiB      8.4 MiB       return float(calculate_frechet_distance(m1, s1, m2, s2))

Do you have any good ideas for reducing memory usage?
Thank you!
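One possible way to keep memory flat is to avoid concatenating all generated images (the torch.cat line in the profile above). Instead, each batch of Inception activations updates running sums. The helper below is hypothetical and not part of the repository. It assumes activations arrive as N x D batches and keeps only a D-dimensional sum and a D x D sum of outer products, in float64 to limit cancellation error.

```python
import numpy as np


class RunningStats:
    """Accumulate the mean and covariance of D-dimensional features batch by
    batch, so only one batch plus O(D^2) sums is held in memory."""

    def __init__(self, dim: int):
        self.n = 0
        self.sum = np.zeros(dim, dtype=np.float64)
        self.outer = np.zeros((dim, dim), dtype=np.float64)

    def update(self, batch: np.ndarray) -> None:
        batch = np.asarray(batch, dtype=np.float64)  # N x D
        self.n += batch.shape[0]
        self.sum += batch.sum(axis=0)
        self.outer += batch.T @ batch

    def finalize(self):
        """Return (mu, sigma), normalising by N - 1 like np.cov does by default."""
        mu = self.sum / self.n
        sigma = (self.outer - self.n * np.outer(mu, mu)) / (self.n - 1)
        return mu, sigma
```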

Why is there a normalization layer inside the separable conv?

I spotted that you have a normalization layer in your separable convolution implementation.

self.conv = nn.Sequential(
       nn.Conv2d(in_channels=in_channels, out_channels=in_channels * scale_factor, kernel_size=kernel_size,
                      stride=stride, padding=padding, groups=in_channels, bias=use_bias),
       norm_layer(in_channels),
       nn.Conv2d(in_channels=in_channels * scale_factor, out_channels=out_channels,
                      kernel_size=1, stride=1, bias=use_bias),
)

I have not seen such an implementation before. Also, why isn't it adjusted for scale_factor?

real_stats for COCO

Could you share the real_stats for the COCO data? They are not downloadable, and running get_real_stat.py on the COCO data returns an error:
python get_real_stat.py --dataroot database/coco_stuff --output_path real_stat/coco_A.npz --direction BtoA --dataset_mode coco

RuntimeError: Sizes of tensors must match except in dimension 0. Got 424 and 640 in dimension 2 (The offending index is 1)

Thanks.

No module named 'metric.cityscapes_mIoU'

Hi there, @junyanz @lmxyy
When trying to train a "once-for-all" network, I get an error:

(gan) home@home-lnx:~/programs/level 2/gan-compression$ bash scripts/cycle_gan/horse2zebra_lite/train_supernet.sh
Traceback (most recent call last):
  File "train_supernet.py", line 4, in <module>
    trainer = Trainer('supernet')
  File "/home/home/programs/level 2/gan-compression/trainer.py", line 38, in __init__
    opt = Options().parse()
  File "/home/home/programs/level 2/gan-compression/options/base_options.py", line 134, in parse
    opt = self.gather_options()
  File "/home/home/programs/level 2/gan-compression/options/supernet_options.py", line 76, in gather_options
    supernet_option_setter = supernets.get_option_setter(supernet_name)
  File "/home/home/programs/level 2/gan-compression/supernets/__init__.py", line 22, in get_option_setter
    supernet_class = find_supernet_using_name(supernet_name)
  File "/home/home/programs/level 2/gan-compression/supernets/__init__.py", line 6, in find_supernet_using_name
    modellib = importlib.import_module(supernet_filename)
  File "/home/home/anaconda3/envs/gan/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 783, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/home/programs/level 2/gan-compression/supernets/resnet_supernet.py", line 13, in <module>
    from distillers.base_resnet_distiller import BaseResnetDistiller
  File "/home/home/programs/level 2/gan-compression/distillers/base_resnet_distiller.py", line 11, in <module>
    from metric.cityscapes_mIoU import DRNSeg
ModuleNotFoundError: No module named 'metric.cityscapes_mIoU'
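When a script is launched as `python train_supernet.py`, Python puts the script's directory first on `sys.path`. So `metric.cityscapes_mIoU` should resolve to a `cityscapes_mIoU` module inside the repository's `metric` package. The quick check below is a hypothetical diagnostic (the helper name is made up) to run from the same directory. It prints where Python finds each piece. If `metric` is found but `metric.cityscapes_mIoU` is not, the file is most likely missing from the checkout.

```python
import importlib.util
import sys


def locate_module(name):
    """Describe where Python would load `name` from, or return None if it is not found."""
    try:
        spec = importlib.util.find_spec(name)
    except ModuleNotFoundError:
        return None  # a parent package (e.g. `metric`) is itself missing
    if spec is None:
        return None
    # Namespace packages (directories without __init__.py) have no single origin file.
    return spec.origin or f"namespace package at {list(spec.submodule_search_locations)}"


if __name__ == "__main__":
    print("sys.path[0]:", sys.path[0])
    print("metric ->", locate_module("metric"))
    print("metric.cityscapes_mIoU ->", locate_module("metric.cityscapes_mIoU"))
```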

other datasets

Hi, author!
How can I apply the pix2pix model to my own dataset, such as a deraining dataset?
Can I just prepare my dataset in the same format as yours?
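Assuming this repository's paired (pix2pix) data follows the "aligned" layout of pytorch-CycleGAN-and-pix2pix, each training image is the input A and the target B placed side by side, with A on the left. The `--direction` flag (as in `--direction BtoA` above) then presumably selects which half is treated as the input. It is worth confirming against the repository's dataset classes and the provided datasets. For deraining, A could be the rainy image and B the clean one. The helper name below is hypothetical.

```python
import numpy as np


def make_aligned_pair(img_a, img_b):
    """Join input A (left) and target B (right) into one H x 2W x C image."""
    img_a = np.asarray(img_a)
    img_b = np.asarray(img_b)
    if img_a.shape != img_b.shape:
        raise ValueError(f"A and B must have the same shape, got {img_a.shape} and {img_b.shape}")
    return np.concatenate([img_a, img_b], axis=1)  # axis 1 is the image width
```

With Pillow, a pair could be written as `Image.fromarray(make_aligned_pair(np.array(Image.open(rainy_path)), np.array(Image.open(clean_path)))).save(out_path)`, where the paths are placeholders.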

Searching is slow. How can I resume searching after it is interrupted?

I only have one GPU, and searching is very slow. I want to run some other code in the meantime.
How can I resume the search after that other project finishes?

65%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 37494/57600 [40:38:20<137:38:38, 24.65s/it]MACs: 4.364G Params: 1.987M
{'config_str': '48_32_32_48_40_32_24_16', 'macs': 4363780096, 'fid': 68.67807846983607}
65%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 37495/57600 [40:38:48<143:48:26, 25.75s/it]MACs: 5.282G Params: 1.987M
{'config_str': '48_32_32_48_40_32_16_64', 'macs': 5282070528, 'fid': 64.41984751671731}
65%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 37496/57600 [40:39:19<152:29:44, 27.31s/it]MACs: 4.825G Params: 1.987M

Thank you!
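The log prints a `config_str`, `macs`, and `fid` for every evaluated candidate, out of a fixed total (57600). One way to make such a long search resumable is to append each result to a JSON-lines file and skip configurations that are already on disk when restarting. This is a hypothetical sketch, not the repository's search script. `evaluate` stands in for whatever computes MACs and FID for one configuration.

```python
import json
import os


def load_results(path):
    """Read previously saved results (one JSON object per line), keyed by config_str."""
    results = {}
    if os.path.exists(path):
        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue  # a line cut off by an interruption; that config is redone
                results[record["config_str"]] = record
    return results


def resumable_search(configs, evaluate, path):
    """Evaluate each config once, appending results so a restart skips finished work."""
    results = load_results(path)
    with open(path, "a") as f:
        for config_str in configs:
            if config_str in results:
                continue  # already evaluated in an earlier run
            record = evaluate(config_str)  # expected keys: config_str, macs, fid
            f.write(json.dumps(record) + "\n")
            f.flush()  # persist immediately so an interruption loses little work
            results[config_str] = record
    return results
```

Stopping the job and later re-running it with the same results file continues from the first unevaluated configuration. The best candidate can then be picked from the returned dictionary.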

Questions about resnet_supernet in pix2pix_model.

  1. if getattr(opt, 'sort_channels', False) and opt.restore_student_G_path is not None:  # line 74 of base_resnet_distiller.py
     Regarding getattr(opt, 'sort_channels', False): I checked the definition of getattr, whose signature is getattr(object, name[, default]). As far as I can tell, setting the default to False or True doesn't affect the result, and the call simply returns the value of opt.sort_channels.

  2. Do I need to sort channels before training the OFA supernet, i.e., set sort_channels = True? I noticed that netG_student_tmp is used in supernet training to sort channels before the pretrained weights are transferred to student_netG.
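The `default` argument of `getattr` only matters when the attribute is absent. If `opt` has a `sort_channels` attribute, both forms return its value. If it does not (for example, when an options class never defines that flag), `getattr(opt, 'sort_channels')` raises `AttributeError`, while `getattr(opt, 'sort_channels', False)` returns `False`. The sketch below uses `argparse.Namespace` as a stand-in for the parsed options object, which is an assumption, and a hypothetical helper name.

```python
from argparse import Namespace


def should_sort_channels(opt):
    """Return opt.sort_channels if the option exists, otherwise fall back to False."""
    return getattr(opt, "sort_channels", False)


print(should_sort_channels(Namespace(sort_channels=True)))  # True
print(should_sort_channels(Namespace()))                    # False: attribute missing

try:
    getattr(Namespace(), "sort_channels")  # no default given
except AttributeError:
    print("AttributeError without a default")
```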
