bryandlee / stylegan2-encoder-pytorch
PyTorch Implementation of In-Domain GAN Inversion for StyleGAN2
License: MIT License
Hi,
I have a question about the encoder + model interpolation. I downloaded the Naver Webtoon generator model and replaced its high-level layers with those of the FFHQ model (tested from 3-11 till 10-11, truncation=0.6), but I can't get the same result as yours. Could you give me some advice? Or are there other hyperparameters that differ?
Hoping for your reply!
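For readers unsure what the truncation value in this question controls: in StyleGAN-style generators, truncation usually means pulling each latent toward the mean latent by a fixed factor. The following is a minimal sketch of that interpolation on plain Python lists. The function name `truncate` and the toy values are assumptions made for illustration and are not taken from this repo.

```python
def truncate(w, w_mean, truncation):
    """Pull latent w toward the mean latent w_mean.

    truncation=1.0 leaves w unchanged; truncation=0.0 returns w_mean.
    (Hypothetical helper, written for illustration only.)
    """
    return [m + truncation * (x - m) for x, m in zip(w, w_mean)]


# Toy 2-dimensional example; real latents would have 512 entries.
w_mean = [0.0, 1.0]
w = [2.0, 3.0]
print(truncate(w, w_mean, 0.6))
```

A lower truncation value trades sample diversity for samples closer to the "average" output, so two runs that differ only in this value can look noticeably different.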
What exactly does this repo do in practice? Sorry, I am new to this. I have tried GANs of all kinds, but this is very new to me.
Are you trying to do what StarGAN does?
Thanks,
Steve
Hi Bryandlee,
Thanks for your great work!
I tried your Domain-Guided Encoder and got a latent (z0) with shape [1,14,512],
but when I tried project.py from the official stylegan2-pytorch, I got a latent with shape [1,512].
So I am really confused. From my understanding, the latent shape is always [1,512]. Where does this 14 come from in your encoder? Thanks!
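A possible reading of the 14, offered as an assumption rather than a confirmed answer from the author: StyleGAN2 generators consume one style vector per layer, and encoders of this kind often predict all of them at once (the so-called W+ space) instead of a single [1,512] vector. In the common PyTorch StyleGAN2 implementation the number of style inputs is 2 * log2(image_size) - 2, which a short sketch can check. The function name `num_style_layers` is hypothetical.

```python
import math


def num_style_layers(image_size):
    """Number of per-layer style vectors for a square power-of-two image size.

    Assumes the 2 * log2(size) - 2 rule used by common StyleGAN2 ports.
    """
    log_size = int(math.log2(image_size))
    return 2 * log_size - 2


for size in (256, 1024):
    print(size, num_style_layers(size))
```

Under this assumption a 256x256 generator takes 14 style vectors of length 512, which would match the [1,14,512] shape, while a single [1,512] latent is simply repeated across all layers.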
import os
import random
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from model import Generator, Encoder
from train_encoder import VGGLoss
import matplotlib.pyplot as plt
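# Convert an HxWxC image array with values in [0, 255] to a 1xCxHxW tensor in [-1, 1].
def image2tensor(image):
    image = torch.FloatTensor(image).permute(2,0,1).unsqueeze(0)/255.
    return (image-0.5)/0.5

# Convert a single CxHxW image tensor in [-1, 1] back to an HxWxC array in [0, 1].
def tensor2image(tensor):
    tensor = tensor.clamp(-1., 1.).detach().squeeze().permute(1,2,0).cpu().numpy()
    return tensor*0.5 + 0.5

# Display an image without axes.
def imshow(img, size=5, cmap='jet'):
    plt.figure(figsize=(size,size))
    plt.imshow(img, cmap=cmap)
    plt.axis('off')
    plt.show()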
device = 'cuda'
image_size=256
g_model_path = '/content/generator_ffhq.pt'
g_ckpt = torch.load(g_model_path, map_location=device)
latent_dim = g_ckpt['args'].latent
generator = Generator(image_size, latent_dim, 8).to(device)
generator.load_state_dict(g_ckpt["g_ema"], strict=False)
generator.eval()
print('[generator loaded]')
e_model_path = '/content/encoder_ffhq.pt'
e_ckpt = torch.load(e_model_path, map_location=device)
encoder = Encoder(image_size, latent_dim).to(device)
encoder.load_state_dict(e_ckpt['e'])
encoder.eval()
print('[encoder loaded]')
truncation = 0.7
# Mean latent used as the truncation target.
trunc = generator.mean_latent(4096).detach().clone()

with torch.no_grad():
    latent = generator.get_latent(torch.randn(4*6, latent_dim, device=device))
    imgs_gen, _ = generator([latent],
                            truncation=truncation,
                            truncation_latent=trunc,
                            input_is_latent=True,
                            randomize_noise=True)

# Tile the 24 samples into a grid of 4 rows by 6 columns.
result = []
for row in imgs_gen.chunk(4, dim=0):
    result.append(torch.cat([img for img in row], dim=2))
result = torch.cat(result, dim=1)

print('generated samples:')
imshow(tensor2image(result), size=15)
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
<ipython-input-254-b06be6808604> in <module>()
10 from torchvision import datasets, transforms
11
---> 12 from model import Generator, Encoder
13 from train_encoder import VGGLoss
14
6 frames
/usr/lib/python3.6/imp.py in find_module(name, path)
295 break # Break out of outer loop when breaking out of inner loop.
296 else:
--> 297 raise ImportError(_ERR_MSG.format(name), name=name)
298
299 encoding = None
ImportError: No module named 'fused'
---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.
To view examples of installing some common dependencies, click the
"Open Examples" button below.
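One thing that may be worth checking for this error, stated as an assumption because the issue does not confirm the cause: in StyleGAN2 PyTorch ports, `fused` is typically not a pip package but a CUDA extension compiled at import time through `torch.utils.cpp_extension.load`, which needs the Ninja build tool. The sketch below only inspects the environment using the standard library. The helper name `check_build_prereqs` is hypothetical.

```python
import importlib.util
import shutil


def check_build_prereqs():
    """Report whether two common prerequisites for JIT-built extensions are present.

    This is a diagnostic sketch only; it does not build or import anything.
    """
    return {
        # Ninja is the build tool PyTorch uses for JIT-compiled C++/CUDA extensions.
        "ninja_on_path": shutil.which("ninja") is not None,
        # True if a top-level package named torch can be found.
        "torch_installed": importlib.util.find_spec("torch") is not None,
    }


print(check_build_prereqs())
```

If `ninja_on_path` is False, installing Ninja and restarting the runtime is a reasonable first step. A runtime without a GPU would be another candidate cause under the same assumption.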
Could you share the cartoon dataset of the last result in Readme.md? Thank you very much!
Hi! What training data was used to train the encoder? The original implementation used the 70k FFHQ images; what about yours?
Thanks for your work.
By the way, interpolate.ipynb tries to load the checkpoint './checkpoint/generator_ffhq.pt', but it is missing.
Can you please let me know how to get it?
Best regards,
Ron
Is it possible to provide the convert_weight code? It does not appear in Seonghyeon Kim's PyTorch implementation of StyleGAN2.
For example, g_args = g_ckpt['args'], g_ckpt["d_optim"].
Hi. The inversion results do not seem as good as those in the paper. Do you know what the problem might be?
Dear bryandlee,
Thank you for your great work, it is a very nice project.
About the default checkpoint (encoder_ffhq.pt) on the main page: was it trained with a 256x256 generator? If so, it is not the official StyleGAN2 FFHQ model at 1024x1024. Would you mind providing the generator weights for this encoder?
Thank you very much for your help.
Best Wishes,
Alex