pti's Introduction

PTI: Pivotal Tuning for Latent-based editing of Real Images (ACM TOG 2022)


Pivotal Tuning Inversion (PTI) enables employing off-the-shelf latent-based semantic editing techniques on real images using StyleGAN. PTI excels at identity-preserving edits, portrayed through recognizable figures such as Serena Williams and Robert Downey Jr. (top), and at handling faces that are clearly out-of-domain, e.g., due to heavy makeup (bottom).

Description

Official implementation of our PTI paper, plus code for the evaluation metrics. PTI introduces an optimization mechanism for solving the StyleGAN inversion task, providing near-perfect reconstruction results while maintaining the high editing abilities of the native StyleGAN latent space W. For more details, see the paper.

Recent Updates

2021.07.01: Fixed the file download phase in the inference notebook, which might have caused the notebook not to run smoothly.

2021.06.29: Added support for CPU. In order to run PTI on CPU, change the device parameter in configs/global_config.py from "cuda" to "cpu".

2021.06.25: Added a mohawk edit using StyleCLIP+PTI to the inference notebook. Updated the documentation in the inference notebook because the Google Drive rate limit was reached. Currently, Google Drive does not allow the pretrained models to be downloaded automatically from Colab; manual intervention might be needed.

Getting Started

Prerequisites

  • Linux or macOS
  • NVIDIA GPU + CUDA CuDNN (not mandatory but recommended)
  • Python 3

Installation

  • Dependencies:
    1. lpips
    2. wandb
    3. pytorch
    4. torchvision
    5. matplotlib
    6. dlib
  • All dependencies can be installed with pip install followed by the package name, for example:
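
For example, all of the above can be installed in one command. Note that the PyTorch packages are published on pip as torch and torchvision; depending on your CUDA setup you may prefer to install those two by following the official PyTorch installation instructions instead:

pip install lpips wandb torch torchvision matplotlib dlib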

Pretrained Models

Please download the pretrained models listed below.

Auxiliary Models

We provide various auxiliary models needed for the PTI inversion task.
This includes the StyleGAN generator and pre-trained models used for loss computation.

Path Description
FFHQ StyleGAN: StyleGAN2-ada model trained on FFHQ with 1024x1024 output resolution.
Dlib alignment: dlib alignment model used for image preprocessing.
FFHQ e4e encoder: pretrained e4e encoder, used for StyleCLIP editing.

Note: The StyleGAN model is used directly from the official stylegan2-ada-pytorch implementation. For StyleCLIP pretrained mappers, please see the official StyleCLIP repository.

By default, we assume that all auxiliary models are downloaded and saved to the directory pretrained_models. However, you may use your own paths by changing the necessary values in configs/paths_config.py.
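
As a rough sketch, the relevant entries in configs/paths_config.py point at those files; the variable names stylegan2_ada_ffhq and checkpoints_dir appear in the code excerpts quoted later on this page, while the exact file names and the dlib entry are assumptions based on the defaults (ffhq.pkl, align.dat):

stylegan2_ada_ffhq = 'pretrained_models/ffhq.pkl'  # FFHQ StyleGAN2-ada generator
dlib = 'pretrained_models/align.dat'               # dlib alignment model (variable name assumed)
checkpoints_dir = './checkpoints'                  # tuned generators are written here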

Inversion

Preparing your Data

In order to invert a real image and edit it, you should first align and crop it to the correct size. To do so, perform one of the following steps (a short example follows the list):

  1. Run notebooks/align_data.ipynb and change the "images_path" variable to the raw images path
  2. Run utils/align_data.py and change the "images_path" variable to the raw images path
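
For example, with the script-based option, set the variable inside the script and then run it from the repository root (the no-argument invocation is an assumption):

# in utils/align_data.py
images_path = '/path/to/raw_images'  # hypothetical location of your unaligned images

# then, from the repository root:
python utils/align_data.py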

Weights And Biases

The project supports the Weights & Biases framework for experiment tracking. For the inversion task, it enables visualization of the loss progression and of the generator's intermediate results during the initial inversion and the Pivotal Tuning (PT) procedure.

The log frequency can be adjusted using the parameters defined in configs/global_config.py under the "Logs" subsection.

There is no need to have an account to run the code. However, in order to use the features provided by Weights & Biases, you first have to register on their site.
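
For example, after registering and running wandb login once from the shell, tracking can be enabled by passing use_wandb=True to the training entry point (the import path below is an assumption; run_PTI itself is defined in scripts/run_pti.py):

from scripts.run_pti import run_PTI  # assumed import path

# logs loss curves and intermediate reconstructions to Weights & Biases
model_id = run_PTI(use_wandb=True, use_multi_id_training=False)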

Running PTI

The main training script is scripts/run_pti.py. The script receives aligned and cropped images from the paths configured in the "Input info" subsection of configs/paths_config.py. Results are saved to the directories listed under "Dirs for output files" in configs/paths_config.py; this includes the inversion latent codes and the tuned generators. The hyperparameters for the inversion task can be found in configs/hyperparameters.py and are initialized to the default values used in the paper.
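
Once the paths and hyperparameters are configured, a minimal single-image run is simply the following (this mirrors the usage shown in the issues further down; note that the working directory must be the repository root so that torch_utils and dnnlib can be found):

python scripts/run_pti.py

To fine-tune one generator on several images at once, call run_PTI with use_multi_id_training=True instead of the default single-image mode.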

Editing

By default, we assume that all auxiliary edit directions are downloaded and saved to the directory editings. However, you may use your own paths by changing the necessary values in configs/paths_config.py under the "Edit directions" subsection.

An example of editing code can be found in scripts/latent_editor_wrapper.py.
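
For orientation, latent-based editing after PTI amounts to moving the inverted pivot code along a semantic direction and decoding it with the tuned generator. The sketch below is not the API of scripts/latent_editor_wrapper.py; it assumes new_G is the tuned generator, w_pivot is the inverted latent code, and direction is a pre-computed InterfaceGAN edit direction broadcastable to w_pivot's shape:

import torch

alpha = 3.0                              # edit strength (illustrative value)
w_edit = w_pivot + alpha * direction     # move the pivot along the semantic direction
with torch.no_grad():
    edited_image = new_G.synthesis(w_edit, noise_mode='const')  # decode with the tuned generator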

Inference Notebooks

To help visualize the results of PTI, we provide a Jupyter notebook, found in notebooks/inference_playground.ipynb.
The notebook downloads the pretrained models and runs inference on a sample image found online or on images of your choosing. It is recommended to run it in Google Colab.

The notebook demonstrates how to:

  • Invert an image using PTI
  • Visualise the inversion and use the PTI output
  • Edit the image after PTI using InterfaceGAN and StyleCLIP
  • Compare to other inversion methods
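
Outside the notebook, the artifacts that PTI writes out can be loaded directly. The sketch below follows the loading code quoted in the issues further down; it assumes paths_config is imported from configs/paths_config.py, model_id is the run id returned by run_PTI, and image_name matches the aligned input file name:

import pickle
import torch
from configs import paths_config  # assumed import path

# original FFHQ StyleGAN2-ada generator
with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
    old_G = pickle.load(f)['G_ema'].cuda()

# generator fine-tuned by PTI for a specific image
with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
    new_G = torch.load(f_new).cuda()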

Evaluation

Currently the repository supports qualitative evaluation of reconstruction for PTI, SG2 (W space), e4e, and SG2Plus (W+ space), as well as editing using InterfaceGAN and GANSpace for the same inversion methods. To run the evaluation, please see evaluation/qualitative_edit_comparison.py. Examples of the evaluation results are:


Reconstruction comparison between the different methods. The image order is: original image, W+ inversion, e4e inversion, W inversion, PTI inversion


InterfaceGAN pose edit comparison between the different methods. The image order is: original, W+, e4e, W, PTI


Image per edit or several edits without comparison
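
A hedged way to reproduce these comparisons, assuming the script runs without additional arguments once the paths above are configured:

python evaluation/qualitative_edit_comparison.py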

Coming Soon - Quantitative evaluation and StyleCLIP qualitative evaluation

Repository structure

Path Description
├  configs Folder containing configs defining Hyperparameters, paths and logging
├  criteria Folder containing various loss and regularization criteria for the optimization
├  dnnlib Folder containing internal utils for StyleGAN2-ada
├  docs Folder containing images displayed in the README
├  editings Folder containing the latent space edit directions
├  environment Folder containing Anaconda environment used in our experiments
├  licenses Folder containing licenses of the open source projects used in this repository
├  models Folder containing models used in different editing techniques and first phase inversion
├  notebooks Folder with jupyter notebooks to demonstrate the usage of PTI end-to-end
├  scripts Folder with running scripts for inversion, editing and metric computations
├  torch_utils Folder containing internal utils for StyleGAN2-ada
├  training Folder containing the core training logic of PTI
├  utils Folder with various utility functions

Credits

StyleGAN2-ada model and implementation:
https://github.com/NVlabs/stylegan2-ada-pytorch Copyright © 2021, NVIDIA Corporation.
Nvidia Source Code License https://nvlabs.github.io/stylegan2-ada-pytorch/license.html

LPIPS model and implementation:
https://github.com/richzhang/PerceptualSimilarity
Copyright (c) 2020, Sou Uchida
License (BSD 2-Clause) https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE

e4e model and implementation:
https://github.com/omertov/encoder4editing Copyright (c) 2021 omertov
License (MIT) https://github.com/omertov/encoder4editing/blob/main/LICENSE

StyleCLIP model and implementation:
https://github.com/orpatashnik/StyleCLIP Copyright (c) 2021 orpatashnik
License (MIT) https://github.com/orpatashnik/StyleCLIP/blob/main/LICENSE

InterfaceGAN implementation:
https://github.com/genforce/interfacegan Copyright (c) 2020 genforce
License (MIT) https://github.com/genforce/interfacegan/blob/master/LICENSE

GANSpace implementation:
https://github.com/harskish/ganspace Copyright (c) 2020 harskish
License (Apache License 2.0) https://github.com/harskish/ganspace/blob/master/LICENSE

Acknowledgments

This repository structure is based on the encoder4editing and ReStyle repositories.

Contact

For any inquiry please contact us at our email addresses: [email protected] or [email protected]

Citation

If you use this code for your research, please cite:

@article{roich2021pivotal,
  title={Pivotal Tuning for Latent-based Editing of Real Images},
  author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel},
  publisher = {Association for Computing Machinery},
  journal={ACM Trans. Graph.},
  year={2021}
}

pti's People

Contributors

danielroich, seemirra

pti's Issues

[question] What kind of pre-processing would a model that doesn't generate faces require?

Hi there!

I’ve been trying to invert some pictures using pre-trained models that don’t generate faces. For obvious reasons, I’ve been skipping most of the pre-processing, such as dlib face alignment; resizing is the only part that I kept.

However, both the final embedding and the fine-tuned model are of poor quality, either being distorted or blurred. It seems the repository is specifically designed for faces, so I was wondering if you could tell us any best practices or advice about pre-processing pictures that aren’t necessarily faces.

Thanks for the good work!

Regards

interpolate faces

amazing work
It would be amazing if you can add a script to interpolate 2 faces.

SG and SG2 issue

Hi, thanks for the impressive work, which has raised a great impact.
In the original paper, I wondered about the difference between SG, SG W+, and the first step of PTI.
Both SG and PTI optimize in the original W space, while SG W+ uses the W+ space.
SG and SG W+ take more steps to optimize, and all three methods employ noise regularization.
Is there any other difference that I missed?

checkpoint

hi,
Are fine-tuned weights provided for testing?

use_multi_id_training=True error

I have two images aligned in /content/PTI/image_processed.
I want to train a model for multiple images using use_multi_id_training = True,
but it gives this error:

100%|██████████| 2/2 [00:01<00:00,  1.39it/s]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%
233M/233M [00:04<00:00, 50.4MB/s]

Loading model from: /usr/local/lib/python3.7/dist-packages/lpips/weights/v0.1/alex.pth
Downloading https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt ... done
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py:1051: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return forward_call(*input, **kwargs)
100%|██████████| 450/450 [01:21<00:00,  5.52it/s]
100%|██████████| 450/450 [01:21<00:00,  5.52it/s]
  0%|          | 0/350 [00:00<?, ?it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-33-b69631c99070> in <module>()
     34 ## In order to run PTI and use StyleGAN2-ada, the cwd should the parent of 'torch_utils' and 'dnnlib'
     35 os.chdir('/content/PTI')
---> 36 model_id = run_PTI(use_wandb=False, use_multi_id_training=True)

1 frames
/content/PTI/training/coaches/multi_id_coach.py in train(self)
     58 
     59                 self.optimizer.zero_grad()
---> 60                 loss.backward()
     61                 self.optimizer.step()
     62 

AttributeError: 'tuple' object has no attribute 'backward'

The input parameters of the Generator

Hi, thanks for your great work!

I am curious about all the input parameters, such as 'noise_mode' and 'force_fp32', of self.G.synthesis(w, noise_mode='const', force_fp32=True).
I would also like to know how to return the features of each layer in self.G.

Thanks

Noob - question missing something

I got all the models. Environment is fine.
Code runs. I check the config…. Seems ok. But I don’t get anywhere using 2 scripts

scripts/run_pti.py

and evaluation python script.

[screenshot]

I don’t see any output.
I check the output directories, but I don’t get any files spat out. I have an aligned folder with a sequence of face images…

[screenshot]

Run PTI on 256*256 images

Dear Daniel,
Thank you very much for the great work.
I am trying to apply PTI on my trained model which has 256 * 256 resolution.
Could you give me any tips on which part of the code I should fix?
Many thanks,

How many images have been tried at most to train together?

Thanks for your contribution. How many images have been tried at most to train together? Ideally, each picture corresponds to one pivotally fine-tuned StyleGAN. I have over 20,000 images with different IDs, and the image quality varies. What if I train them all together to get one StyleGAN?

edit an image with text?

The code in Colab only allows using a mapper.
What I want is to write a text prompt to modify the image, as StyleCLIP does.

run align_data error


RuntimeError Traceback (most recent call last)
in
----> 1 predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)

RuntimeError: Error deserializing object of type int64
while deserializing a floating point number.
while deserializing a dlib::matrix
while deserializing object of type std::vector
while deserializing object of type std::vector
while deserializing object of type std::vector

how to get more editing directions

Hi, thanks for your great work.
I notice that you said the editing directions you uploaded are trained on the pretrained StyleGAN. If I want more editing directions, what should I do?
Thank you.

(Colab) Files downloaded with get_download_model_command() subject to GDrive rate limit

Pretrained models downloaded with the get_download_model_command() function (align.dat, ffhq.pkl, etc.) can silently fail, instead downloading a file that contains a rate limit error message (see below) and causing opaque errors later in execution. From googling, this rate limit seems to be related to number of downloads from the host.

Some potential resolutions:

  • use a new host
  • have it detect the failure and give a more helpful error message
  • do nothing and the problem will naturally go away as engagement decreases

I've managed to get through the PTI step at least, and the inversion quality is quite incredible. 🤯

[Google Drive "Quota exceeded" HTML page, reduced to its error text:]
Sorry, you can't view or download this file at this time.
Too many users have viewed or downloaded this file recently. Please try accessing the file again later. If the file you are trying to access is particularly large or is shared with many people, it may take up to 24 hours to be able to view or download the file. If you still can't access a file after 24 hours, contact your domain administrator.

run_pti: assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

Hi there 👋

Thanks a lot for the project. I'm trying to use run_pti with an image but I got this error:

(stylegan3) ➜  PTI git:(main) ✗ python scripts/run_pti.py
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
/home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Loading model from: /home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth
  0%|                                                                             | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/zuppif/Documents/DragGAN/PTI/scripts/run_pti.py", line 51, in <module>
    run_PTI(run_name='', use_wandb=False, use_multi_id_training=False)
  File "/home/zuppif/Documents/DragGAN/PTI/scripts/run_pti.py", line 45, in run_PTI
    coach.train()
  File "/home/zuppif/Documents/DragGAN/PTI/training/coaches/single_id_coach.py", line 39, in train
    w_pivot = self.calc_inversions(image, image_name)
  File "/home/zuppif/Documents/DragGAN/PTI/training/coaches/base_coach.py", line 93, in calc_inversions
    w = w_projector.project(self.G, id_image, device=torch.device(global_config.device), w_avg_samples=600,
  File "/home/zuppif/Documents/DragGAN/PTI/training/projectors/w_projector.py", line 41, in project
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
AssertionError

I first tried resizing the image to 1024x512, then to 1024x1024, but the error persists.

Thanks a lot

Fra

Edit direction

Hi, Thanks for your great work!

Here I get one question, when you fine-tune the parameters of the pre-trained StyleGAN model and use InterFaceGAN methods to edit the image, do we need to find a new semantic direction based on the current parameter model? Intuitively, changing the model parameters will change the semantic direction found by InterFaceGAN.

In general, after you fine-tune the pre-trained StyleGAN model, do you retrain the semantic direction found by InterFaceGAN?

Thanks

How did you find the directions?

hello,
I used stylegan2 to find some directions, but the effect is not good.
Do you use stylegan2-ada to find directions through InterfaceGAN?
Thank you!

Effect of lpips type

Hello,

I find that the option lpips_type = 'alex' tends to affect my inversion results:
with lpips_type = 'alex', it introduces undesirable checkerboard/noise-like artifacts which give me a sense of overfitting (input left and reconstruction right):
[images]
with lpips_type = 'vgg', results are smoother, which gives me a sense of underfitting:
[images]
you may need to zoom-in a bit to tell the difference.
Any suggestions in this case? Do I need to tune options like LPIPS_value_threshold = 0.06 to find a sweet spot in this trade-off?

errors in colab

Thanks for sharing your great work! When I run the code in Colab, some errors come up in the part that downloads the pre-trained models. It shows "NameError: name 'downloader' is not defined". Can you give me any suggestions?

KeyError: 'FullyConnectedLayer'

hi, i have run the file run_PTI, but get the error: KeyError: 'FullyConnectedLayer'
detail:
KeyError Traceback (most recent call last)
in
----> 1 model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)

~/ali_repo/PTI/scripts/run_pti.py in run_PTI(run_name, use_wandb, use_multi_id_training)
41 coach = MultiIDCoach(dataloader, use_wandb)
42 else:
---> 43 coach = SingleIDCoach(dataloader, use_wandb)
44
45 coach.train()

~/ali_repo/PTI/training/coaches/single_id_coach.py in __init__(self, data_loader, use_wandb)
10
11 def __init__(self, data_loader, use_wandb):
---> 12 super().__init__(data_loader, use_wandb)
13
14 def train(self):

~/ali_repo/PTI/training/coaches/base_coach.py in __init__(self, data_loader, use_wandb)
37 self.lpips_loss = LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()
38
---> 39 self.restart_training()
40
41 # Initialize checkpoint dir

~/ali_repo/PTI/training/coaches/base_coach.py in restart_training(self)
46
47 # Initialize networks
---> 48 self.G = load_old_G()
49 toogle_grad(self.G, True)
50

~/ali_repo/PTI/utils/models_utils.py in load_old_G()
21 def load_old_G():
22 with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
---> 23 old_G = pickle.load(f)['G_ema'].to(global_config.device).eval()
24 old_G = old_G.float()
25 return old_G

~/ali_repo/PTI/torch_utils/persistence.py in _reconstruct_persistent_obj(meta)
191
192 assert meta.type == 'class'
--> 193 orig_class = module.__dict__[meta.class_name]
194 decorator_class = persistent_class(orig_class)
195 obj = decorator_class.__new__(decorator_class)

KeyError: 'FullyConnectedLayer'

one image one checkpoint?

Hello,
I have a question. When we fine-tune the generator, do we need to save a corresponding model parameter for each image?

Why don't you use e4e in the first stage Inversion?

Great work!
I notice that there is a "first_inv_type" option in hyperparameters.py. You must have tried using e4e in the first stage. Can you tell us why you choose to use the original projector in StyleGAN2 instead of e4e? Thanks!

lpips seems broken -

the latest version 0.1.4 seems to break things with missing layer

Traceback (most recent call last):
File "scripts/notebook.py", line 63, in
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)
File "/home/jp/Documents/gitWorkspace/PTI/scripts/run_pti.py", line 40, in run_PTI
coach = SingleIDCoach(dataloader, use_wandb)
File "/home/jp/Documents/gitWorkspace/PTI/training/coaches/single_id_coach.py", line 13, in init
super().init(data_loader, use_wandb)
File "/home/jp/Documents/gitWorkspace/PTI/training/coaches/base_coach.py", line 36, in init
self.lpips_loss = LPIPS(net=hyperparameters.lpips_type, lpips_layers=hyperparameters.pt_lpips_layers).to(global_config.device).eval()
TypeError: init() got an unexpected keyword argument 'lpips_layers'

I'd submit a PR - but my branch has drifted considerably.

Style Clip editing on Hyperstyle + PTI output using global directions

Hey, so Hyperstyle saves weights when executed; afterward I tune the inversion using PTI and then use StyleGAN for output. The main issue I am facing is that StyleGAN loads the saved weights from Hyperstyle, so the editing is done on the Hyperstyle inversion and not on the Hyperstyle + PTI tuned inversion. Is there a way to use global directions only and save the weights after the tuning through PTI has been performed?

inversion on whole body images

I'm trying to perform inversion on whole-body images as opposed to faces. Looking at the inference notebook you shared, I'm guessing the preprocessing function which receives input from align_faces will need a new function, called align_body for example, to provide input for a body image.

Or would the best solution be to skip the preprocessing step altogether?

generate random faces

amazing work.
I would like to know if it is possible to generate random faces from a seed, like NVIDIA's StyleGAN.
I tried to do this, but the generated images are full of artifacts:
new_G.synthesis(torch.from_numpy(np.random.rand(1,18,512)).float().to("cuda"), noise_mode='const')

Colab Unpickling Error

Running the colab notebook throws an unpickling error at "Use PTI with e4e backbone for StyleCLIP"

---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
<ipython-input-38-29c3e7342ea3> in <module>()
      1 hyperparameters.first_inv_type = 'w+'
      2 os.chdir('/content/PTI')
----> 3 model_id = run_PTI(use_wandb=False, use_multi_id_training=False)

5 frames
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    775             "functionality.")
    776 
--> 777     magic_number = pickle_module.load(f, **pickle_load_args)
    778     if magic_number != MAGIC_NUMBER:
    779         raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: invalid load key, '<'.

opt multi image one time?

Thanks for your great work. I find that in multi_id_coach.py you optimize one image per step. Why can't you optimize several images in a batch per step? Another question: if I have one person with several pose images, how do I get the best inversion result, one image per checkpoint or one checkpoint for all the images? Thank you in advance.

every image need 127M checkpoint

Well done, your work is amazing; the performance is the best as far as I know. But I found that inference is slow, and every image requires saving a 127M .pt file, which is not practical. Any advice is welcome.

How to take the generated model pt -> and mash it back into stylegan2-ada pkl format for use in other apps. [DRAFTED]

So I get the npz file (thanks for your help on the other ticket #24) and I see the new generator saved:
model_MWVZTEZFDDJB_1.pt

I did some inspecting and see the new generator
https://gist.github.com/johndpope/c5b77f8cc7d7d008be7f15079a9378bf

I want to spit out an updated ffhq pkl file in the correct shape and format so I can run the new generator in different use cases with other repos.

  with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
    old_G = pickle.load(f)['G_ema'].cuda()  # this grabs the pickle for the ffhq file

  with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
    new_G = torch.load(f_new).cuda()  # and this grabs the updated model_MWVZTEZFDDJB_1.pt

UPDATE 1 - thus far I have this hack which saves out a pkl

UPDATE 2 -
I actually load the new file into stylegan2-ada-pytorch and run approach.py in conjunction with projected_w.npz,
but it's working badly. I wonder if it's because this pickle would need a new discriminator too?

UPDATE 3 -
I think I know how to solve - I need to load the final pt which is spat out and do the hot wiring - should be fine.

def export_updated_pickle(new_G,model_id):
  print("Exporting large updated pickle based off new generator and ffhq.pkl")
  with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
    d = pickle.load(f)
    old_G = d['G_ema'].cuda() ## tensor
    old_D = d['D'].eval().requires_grad_(False).cpu()

  tmp = {}
  tmp['G_ema'] = old_G.eval().requires_grad_(False).cpu()# copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
  tmp['G'] = new_G.eval().requires_grad_(False).cpu() # copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
  tmp['D'] = old_D
  tmp['training_set_kwargs'] = None
  tmp['augment_pipe'] = None


  with open(f'{paths_config.checkpoints_dir}/model_{model_id}.pkl', 'wb') as f:
      pickle.dump(tmp, f)

....
at bottom of notebook
print(f'Displaying PTI inversion')
plot_image_from_w(w_pivot, new_G)
np.savez(f'projected_w.npz', w=w_pivot.cpu().detach().numpy())
export_updated_pickle(new_G,model_id)

[Result images: original, image_from_w, 1_afro, 1_angry, 1_bobcut, 1_bowlcut, 1_mohawk, 1_surprised, 1_trump]

https://drive.google.com/drive/folders/1l6Xvs6EPVyyw0sFowIpN1pd1lJbm56hD?usp=sharing

I get the new pkl / npz files.

I cherry-pick this file into the original stylegan2-ada-pytorch repo:
https://github.com/l4rz/stylegan2-clip-approach

I rename the pkl file to ffhq-pti.pkl and run:
(torch) ➜ stylegan2-ada-pytorch git:(main) ✗ python approach.py --network ffhq-pti.pkl --w projected_w.npz --outdir ffhq-pti --num-steps 100 --text 'squint'

GPU error

Hello, I was trying to implement PTI into eg3d's loss.py (NVlabs/eg3d), but I ran into some problems when calling the projectors from PTI/training/projectors (danielroich/PTI).

So here is how I call w/w_plus projector (search for function pti_projector):

eg3d code
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Loss functions."""

import numpy as np
import torch
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d
from training.dual_discriminator import filtered_resizing

#----------------------------------------------------------------------------

class Loss:
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass
        raise NotImplementedError()

#----------------------------------------------------------------------------

# ---------------------- project image into latent space --------------------- #
# modified code from https://github.com/oneThousand1000/EG3D-projector/tree/master/eg3d/projector
from training.projector import w_plus_projector, w_projector
from torchvision import transforms
import copy

def pti_projector(cur_G, cur_c, cur_image, device, latent_type='w_plus'):
    # # put image back to cpu for transforms
    # image = cur_image.cpu()
    # # normalize image
    # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                     std=[0.5, 0.5, 0.5])
    # id_image = normalize(image)
    # id_image = torch.squeeze((id_image + 1) / 2, 0)
    
    id_image = cur_image.to(device)
    # c = c.to(device)
    c = torch.reshape(cur_c, (1, 25)).to(device) # 25 is the camera pose dimension 16 + 9
    G = cur_G

    if latent_type == 'w_plus':
        w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
    else:
        w = w_projector.project(G, c, id_image, device=device, w_avg_samples=600)
    print('w shape: ', w.shape)
    return w
    

class StyleGAN2Loss(Loss):
    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0, gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased'):
        super().__init__()
        self.device             = device
        self.G                  = G
        self.D                  = D
        self.augment_pipe       = augment_pipe
        self.r1_gamma           = r1_gamma
        self.style_mixing_prob  = style_mixing_prob
        self.pl_weight          = pl_weight
        self.pl_batch_shrink    = pl_batch_shrink
        self.pl_decay           = pl_decay
        self.pl_no_weight_grad  = pl_no_weight_grad
        self.pl_mean            = torch.zeros([], device=device)
        self.blur_init_sigma    = blur_init_sigma
        self.blur_fade_kimg     = blur_fade_kimg
        self.r1_gamma_init      = r1_gamma_init
        self.r1_gamma_fade_kimg = r1_gamma_fade_kimg
        self.neural_rendering_resolution_initial = neural_rendering_resolution_initial
        self.neural_rendering_resolution_final = neural_rendering_resolution_final
        self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg
        self.gpc_reg_fade_kimg = gpc_reg_fade_kimg
        self.gpc_reg_prob = gpc_reg_prob
        self.dual_discrimination = dual_discrimination
        self.filter_mode = filter_mode
        self.resample_filter = upfirdn2d.setup_filter([1,3,3,1], device=device)
        self.blur_raw_target = True
        assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)

    def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False):
        if swapping_prob is not None:
            c_swapped = torch.roll(c.clone(), 1, 0)
            c_gen_conditioning = torch.where(torch.rand((c.shape[0], 1), device=c.device) < swapping_prob, c_swapped, c)
        else:
            c_gen_conditioning = torch.zeros_like(c)

        ws = self.G.mapping(z, c_gen_conditioning, update_emas=update_emas)
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
        gen_output = self.G.synthesis(ws, c, neural_rendering_resolution=neural_rendering_resolution, update_emas=update_emas)
        return gen_output, ws

    def run_D(self, img, c, blur_sigma=0, blur_sigma_raw=0, update_emas=False):
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            with torch.autograd.profiler.record_function('blur'):
                f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(blur_sigma).square().neg().exp2()
                img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())

        if self.augment_pipe is not None:
            augmented_pair = self.augment_pipe(torch.cat([img['image'],
                                                    torch.nn.functional.interpolate(img['image_raw'], size=img['image'].shape[2:], mode='bilinear', antialias=True)],
                                                    dim=1))
            img['image'] = augmented_pair[:, :img['image'].shape[1]]
            img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], size=img['image_raw'].shape[2:], mode='bilinear', antialias=True)

        logits = self.D(img, c, update_emas=update_emas)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        if self.G.rendering_kwargs.get('density_reg', 0) == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
        blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
        r1_gamma = self.r1_gamma

        alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1
        swapping_prob = (1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None

        if self.neural_rendering_resolution_final is not None:
            alpha = min(cur_nimg / (self.neural_rendering_resolution_fade_kimg * 1e3), 1)
            neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * (1 - alpha) + self.neural_rendering_resolution_final * alpha))
        else:
            neural_rendering_resolution = self.neural_rendering_resolution_initial

        real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, filter_mode=self.filter_mode)

        if self.blur_raw_target:
            blur_size = np.floor(blur_sigma * 3)
            if blur_size > 0:
                f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(blur_sigma).square().neg().exp2()
                real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())

        real_img = {'image': real_img, 'image_raw': real_img_raw}

        # run PTI to get w/w_plus latent codes for real images
        # print(real_img.shape, real_c.shape, gen_z.shape, gen_c.shape)
        # torch.Size([8, 3, 512, 512]) torch.Size([8, 25]) torch.Size([8, 512]) torch.Size([8, 25])
        # convert gen_z to real_z

        batch_size = real_img['image'].shape[0]
        real_z = []
        for i in range(batch_size):
            cur_img = real_img['image'][i]
            cur_c = real_c[i]
            cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
            real_z.append(cur_z)
        real_z = torch.stack(real_z)
        print('real_z', real_z.shape)
        

        # Gmain: Maximize logits for generated images.
        if phase in ['Gmain', 'Gboth']:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits)
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Density Regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'l1':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Alternative density regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-detach':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)

            initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front

            perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            monotonic_loss = torch.relu(sigma_initial.detach() - sigma_perturbed).mean() * 10
            monotonic_loss.mul(gain).backward()


            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Alternative density regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-fixed':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)

            initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front

            perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10
            monotonic_loss.mul(gain).backward()


            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if phase in ['Dmain', 'Dboth']:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution, update_emas=True)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits)
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if phase in ['Dmain', 'Dreg', 'Dboth']:
            name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw}

                real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if phase in ['Dmain', 'Dboth']:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits)
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if phase in ['Dreg', 'Dboth']:
                    if self.dual_discrimination:
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_raw']], create_graph=True, only_inputs=True)
                            r1_grads_image = r1_grads[0]
                            r1_grads_image_raw = r1_grads[1]
                        r1_penalty = r1_grads_image.square().sum([1,2,3]) + r1_grads_image_raw.square().sum([1,2,3])
                    else: # single discrimination
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image']], create_graph=True, only_inputs=True)
                            r1_grads_image = r1_grads[0]
                        r1_penalty = r1_grads_image.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()

#----------------------------------------------------------------------------

and here is the modified projector script:

w_plus_projector.py
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project given image to the latent space of pretrained network pickle."""

import copy
import os
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import dnnlib
import PIL
from camera_utils import LookAtPoseSampler

def project(
        G,
        c,
        # outdir,
        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
        *,
        num_steps=1000,
        w_avg_samples=10000,
        initial_learning_rate=0.01,
        initial_noise_factor=0.05,
        lr_rampdown_length=0.25,
        lr_rampup_length=0.05,
        noise_ramp_length=0.75,
        regularize_noise_weight=1e5,
        verbose=False,
        device: torch.device,
        initial_w=None,
        image_log_step=100,
        # w_name: str
):
    # os.makedirs(f'{outdir}/{w_name}_w_plus', exist_ok=True)
    # outdir = f'{outdir}/{w_name}_w_plus'
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore

    # Compute w stats.
    w_avg_path = './w_avg.npy'
    w_std_path = './w_std.npy'
    if (not os.path.exists(w_avg_path)) or (not os.path.exists(w_std_path)):
        print(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
        z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
        # c_samples = c.repeat(w_avg_samples, 1)

        # use avg look at point

        camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
        cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point,
                                                  radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647  # FFHQ's FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        c_samples = c_samples.repeat(w_avg_samples, 1)

        w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
        w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
        # print('save w_avg  to ./w_avg.npy')
        # np.save('./w_avg.npy',w_avg)
        w_avg_tensor = torch.from_numpy(w_avg).cuda()
        w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

        # np.save(w_avg_path, w_avg)
        # np.save(w_std_path, w_std)
    else:
        # w_avg = np.load(w_avg_path)
        # w_std = np.load(w_std_path)
        raise Exception(' ')

    # z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    # c_samples = c.repeat(w_avg_samples, 1)
    # w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
    # w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    # w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    # w_avg_tensor = torch.from_numpy(w_avg).cuda()
    # w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    start_w = initial_w if initial_w is not None else w_avg

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    # url = './networks/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f, map_location=device).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1)
    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
                         requires_grad=True)  # pylint: disable=not-callable

    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
                                 lr=0.1)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in tqdm(range(num_steps)):

        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)
        synth_images = G.synthesis(ws,c, noise_mode='const')['image']

        # if step % image_log_step == 0:
        #     with torch.no_grad():
        #         vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

        #         PIL.Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}.png')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # if step % 10 == 0:
        #     with torch.no_grad():
        #         print({f'step {step}, first projection _{w_name}': loss.detach().cpu()})

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    del G
    return w_opt

I got the errors shown below:

Computing W midpoint and stddev using 600 samples...
  0%|          | 0/1000 [00:00<?, ?it/s]/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/nn/modules/module.py:1488: UserWarning: operator() profile_node %106 : int = prim::profile_ivalue(%104)
 does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1674202356920/work/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
  return forward_call(*args, **kwargs)
  0%|          | 0/1000 [00:07<?, ?it/s]
Traceback (most recent call last):
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 396, in <module>
    main() # pylint: disable=no-value-for-parameter
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 391, in main
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 101, in launch_training
    subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 52, in subprocess_fn
    training_loop.training_loop(rank=rank, **c)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/training_loop.py", line 286, in training_loop
    loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 156, in accumulate_gradients
    cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 49, in pti_projector
    w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/projector/w_plus_projector.py", line 171, in project
    loss.backward()
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 275, in apply
    return user_fn(self, *args)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 146, in backward
    grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 508, in apply
    return super().apply(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 173, in forward
    return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_ops.py", line 499, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__convolution_backward)

It seems the error happens at loss.backward(). I checked most of the variables, the loss, and the model to make sure they are on cuda:0, but I still have had no luck solving this. Do you know how to make the loss backpropagate properly?
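
The traceback points at the weight argument of the convolution backward, which suggests that a parameter or buffer of one of the networks involved, rather than a tensor created inside the optimization loop, is still on the CPU. Below is a minimal, hypothetical debugging helper (not part of the EG3D or PTI code base) that lists any parameters or buffers of a module that are not on the expected device; running it over G and vgg16 before the loop may reveal the offending module.

# Hypothetical helper, not part of EG3D/PTI: it only reports which
# parameters/buffers of a module are not on the expected device.
import torch

def report_off_device_tensors(module: torch.nn.Module, expected: torch.device):
    for name, p in module.named_parameters():
        if p.device != expected:
            print(f'parameter {name} is on {p.device}, expected {expected}')
    for name, b in module.named_buffers():
        if b.device != expected:
            print(f'buffer {name} is on {b.device}, expected {expected}')

# Example usage inside project(), before the optimization loop (names assumed):
# report_off_device_tensors(G, torch.device(device))
# report_off_device_tensors(vgg16, torch.device(device))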

Unexpected size of W features in generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True)

I am using a model that I trained myself with the stylegan2-ada-pytorch repository.
When using use_multi_id_training=True, I get a size mismatch error in the forward call of G.synthesis.

The full trace is shown below:

Traceback (most recent call last):
  File "test.py", line 6, in <module>
    run_PTI(use_multi_id_training=True)
  File "/home/usman/Documents/Work/PTI/scripts/run_pti.py", line 44, in run_PTI
    coach.train()
  File "/home/usman/Documents/Work/PTI/training/coaches/multi_id_coach.py", line 58, in train
    generated_images = self.forward(w_pivot)
  File "/home/usman/Documents/Work/PTI/training/coaches/base_coach.py", line 130, in forward
    generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True)
  File "/home/usman/anaconda3/envs/eg3d/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "<string>", line 460, in forward
  File "/home/usman/Documents/Work/PTI/torch_utils/misc.py", line 93, in assert_shape
    raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
AssertionError: Wrong size for dimension 1: got 18, expected 16
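
For context (this is a reading of the assertion, not a confirmed fix): the number of w vectors a StyleGAN2 synthesis network expects depends on its output resolution, 18 layers for a 1024x1024 generator versus 16 for a 512x512 one, while the FFHQ e4e encoder produces 18-layer codes for the 1024x1024 model. A minimal sketch of adapting a w-plus code to the generator's expected layer count is shown below (variable names are illustrative); obtaining latents for the same resolution as your generator is the cleaner solution.

# Minimal sketch (illustrative names): truncate or pad a w-plus code so its
# layer count matches what the generator expects (G.num_ws in stylegan2-ada-pytorch).
import torch

def match_num_ws(w_plus: torch.Tensor, num_ws: int) -> torch.Tensor:
    # w_plus: [batch, n_layers, w_dim]
    n = w_plus.shape[1]
    if n == num_ws:
        return w_plus
    if n > num_ws:
        return w_plus[:, :num_ws, :]                # drop the finest layers
    pad = w_plus[:, -1:, :].repeat(1, num_ws - n, 1)
    return torch.cat([w_plus, pad], dim=1)          # repeat the last layer to pad

# w_pivot = match_num_ws(w_pivot, G.num_ws)  # before calling G.synthesis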

About the initialization of the Inversion

Thanks for your great work. Regarding the inversion step in the paper, is W initialized in any particular way before optimizing for the pivot w_p, e.g., with the average W, or is it just a random vector?
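
For reference, the projector code quoted earlier in this thread initializes the optimization from the average W computed over random mapping samples unless an explicit initial_w is passed (see the start_w = initial_w if initial_w is not None else w_avg line). A condensed sketch of that initialization for an unconditional StyleGAN2 generator (G, device, and the sample count are assumed to be in scope) looks like this:

# Condensed sketch of W-average initialization (unconditional generator).
import numpy as np
import torch

w_avg_samples = 10000
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)     # [N, 1, C]
w_avg = np.mean(w_samples, axis=0, keepdims=True)                    # [1, 1, C]
start_w = w_avg  # the search starts from the mean latent, not a random vector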

Pivotal tuning for grey-scale (sketch) images

Hi Daniel. Thank you for the great work.
I have trained a StyleGAN2-ada model on sketch data, and it generates sketches quite well.
After that, to manipulate real images, I tested PTI, but the quality was not good.
When I remove the LPIPS loss (using only the L2 loss) while running the pivotal tuning, the reconstruction goes well.
However, the manipulation still does not work very well.
Could you please provide any tips on this?
Should I train LPIPS on a different dataset, or is there another loss function you can recommend?
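
Not an official answer, but one way to probe the trade-off this question is about is to expose the LPIPS contribution as an explicit weight in the reconstruction loss and sweep it on a few sketch images; the sketch below is generic and does not use the repository's actual hyperparameter names.

# Generic sketch of a tunable reconstruction loss for pivotal tuning; weights
# and criterion choices are illustrative, not the repository's configuration.
# Images are assumed to be tensors in [-1, 1].
import torch
import lpips

device = 'cuda'
l2_criterion = torch.nn.MSELoss()
lpips_criterion = lpips.LPIPS(net='alex').to(device)

def reconstruction_loss(generated, target, l2_weight=1.0, lpips_weight=1.0):
    loss = l2_weight * l2_criterion(generated, target)
    if lpips_weight > 0:
        loss = loss + lpips_weight * lpips_criterion(generated, target).mean()
    return loss

Sweeping lpips_weight from 0 toward 1 should show whether the perceptual term, which was trained on natural images, is what hurts reconstruction on sketches.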

Error in Colab

Hi,

I'm running the Colab as is and I always get this error:

pre_process_images(f'/content/PTI/{image_dir_name}_original')

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-4cd7c7dc57be> in <module>()
----> 1 pre_process_images(f'/content/PTI/{image_dir_name}_original')

/content/PTI/utils/align_data.py in pre_process_images(raw_images_path)
     12 
     13     IMAGE_SIZE = 1024
---> 14     predictor = dlib.shape_predictor(paths_config.dlib)
     15     os.chdir(raw_images_path)
     16     images_names = glob.glob(f'*')
RuntimeError: Error deserializing object of type int

I did not make any changes; is there perhaps something extra I need to configure?
Just letting you know.

Thanks
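
In case it helps with triage: that dlib exception typically means the shape-predictor file on disk is not a valid model, for example when the download is rate-limited or interrupted and an HTML error page is saved instead of the weights. A quick, assumed sanity check from inside the Colab (paths follow configs/paths_config.py) is shown below.

# Assumed sanity check: verify that the dlib landmark model on disk is a real
# model file rather than a failed or partial download.
import os
import dlib
from configs import paths_config

path = paths_config.dlib
print('size on disk (bytes):', os.path.getsize(path))  # a valid predictor is tens of MB
with open(path, 'rb') as f:
    head = f.read(16)
print('first bytes:', head)  # an HTML error page would start with b'<'

predictor = dlib.shape_predictor(path)  # re-raises the deserialization error if the file is corrupt

If the file turns out to be tiny or to start with HTML, re-downloading the alignment model manually and placing it at the path configured in configs/paths_config.py should get the cell past this error.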
