zhaoyi-yan / shift-net_pytorch
PyTorch implementation of Shift-Net: Image Inpainting via Deep Feature Rearrangement (ECCV, 2018)
License: MIT License
Could you please upload a pre-trained model for testing?
Thanks in advance.
Hi! Thanks a lot for your work! I downloaded the two pretrained face models and noticed that the results of the model trained with square masks are much better than those of the model trained with random masks, even though the random masks are smaller than the squares. Below is an example. Why do you think this is? Also, do you think the model can be trained at higher resolutions?
When I run test_m.py:
(1, 4, 4)
torch.FloatTensor
cosine
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, nan,
1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, nan,
1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, nan,
1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, nan,
1.0000, 1.0000, 1.0000]])
index
tensor([8, 8, 8, 8])
former
tensor([[[29., 19., 5., 38.],
[49., 7., 33., 38.],
[ 2., 22., 10., 25.],
[39., 43., 33., 36.]]])
latter
tensor([[[ 7., 6., 11., 35.],
[30., 18., 14., 30.],
[14., 1., 24., 27.],
[ 0., 15., 7., 12.]]])
flag
tensor([[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]], dtype=torch.uint8)
ind_lst
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]]])
Another test
(1, 4, 4)
torch.FloatTensor
cosine
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
index
tensor([0, 0, 0, 0])
former
tensor([[[48., 27., 47., 18.],
[40., 38., 13., 32.],
[26., 16., 40., 14.],
[31., 29., 2., 36.]]])
latter
tensor([[[34., 38., 39., 29.],
[14., 49., 9., 5.],
[18., 49., 21., 10.],
[20., 26., 21., 1.]]])
flag
tensor([[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]], dtype=torch.uint8)
ind_lst
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[[[ 0., 0., 0., 0.],
[ 0., 34., 34., 0.],
[ 0., 34., 34., 0.],
[ 0., 0., 0., 0.]]]])
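A note on the two runs above: in the first run the NaN column (index 8) appears to line up with the only zero in latter (row 3, column 0, which is the ninth unmasked position), while the second run has no zero entries and no NaN. With a single channel, a zero feature has zero norm, so the cosine becomes 0/0. Below is a minimal sketch of an epsilon-guarded cosine in plain Python. It is not the repository's tensor code; the function name safe_cosine and the eps value are assumptions made for illustration.

```python
import math


def safe_cosine(u, v, eps=1e-8):
    """Cosine similarity that returns 0.0 instead of NaN for zero-norm inputs."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # Clamp the denominator so a zero-norm vector cannot produce 0/0.
    return dot / max(norm_u * norm_v, eps)


if __name__ == "__main__":
    print(safe_cosine([7.0], [29.0]))   # two non-zero single-channel features
    print(safe_cosine([0.0], [39.0]))   # zero feature: guarded, no NaN
```

The same idea applies to tensor code by clamping the norm product before dividing. Whether a zero feature should score 0 or be excluded from matching is a design choice this sketch does not settle.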
Hello, when training on the Places2 dataset, is the whole dataset used or only part of it? And how long does the training take? Thanks.
Will investigate how to change the shift operation to a batched shift; it is a bottleneck for higher GPU utilization.
I am curious about the method that you mentioned for obtaining the mask of the feature map. Why not downsample directly?
Finish discounting loss l1 in tmp_one_D: aed3638
Make layers_to_last an attribute of the shift layer, making it easier to implement multi-shift layers.
Clean shiftnet_model.py, move some code of wgan_gp. See junyanz/pytorch-CycleGAN-and-pix2pix@032e884 as a reference.
Hello, I was studying your work and during some tests I found that the current approach has issues with the size of masks: if there is no mask at all, the shift operation still tries to find a masked region in the latent space, and since it has no handling for an empty mask there, the code crashes. The same happens when the mask is small enough to vanish at some downsampling step during the compression phase.
What modifications can be made to remove this issue? Currently I am inserting an 8x8-pixel mask in an irrelevant part of the image, but that is not optimal. As I have not yet thought of a better solution, I am asking you for a better approach. I am open to developing it and making a pull request to your repo.
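One possible direction, sketched under stated assumptions: check whether the downsampled mask still contains any masked position, and pass the features through unchanged when it does not. The helper names, the strided subsampling, and the shift_fn callback are all hypothetical; the repository's own mask downsampling and shift layer may work differently.

```python
def downsample_mask(mask, factor):
    """Subsample a 2-D 0/1 mask (list of lists) by keeping every factor-th pixel."""
    return [row[::factor] for row in mask[::factor]]


def has_masked_region(mask):
    """True if at least one position in the mask is set."""
    return any(any(row) for row in mask)


def shift_or_passthrough(features, mask, factor, shift_fn):
    """Apply shift_fn only when the latent-resolution mask is non-empty."""
    small_mask = downsample_mask(mask, factor)
    if not has_masked_region(small_mask):
        # Nothing to fill at this resolution: skip the shift instead of crashing.
        return features
    return shift_fn(features, small_mask)
```

With such a guard, an image without a mask, or with a mask that vanishes after downsampling, would skip the shift step rather than require a dummy 8x8 mask.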
About 2 years ago, we found a novel training strategy that boosts the performance of face inpainting a lot. This strategy was originally intended as the second novelty of Shift-Net v2, the journal version. However, for some reason, I have not written the paper so far. Unfortunately, I do not think Shift-Net v2 will come out in the future, so I will release the code. It surpasses the normal Shift-Net by a large margin on irregular face inpainting!
Traceback (most recent call last):
File "test.py", line 9, in <module>
opt = TestOptions().parse()
File "/home/Shift_Net/options/base_options.py", line 97, in parse
opt = self.gather_options()
File "/home/Shift_Net/options/base_options.py", line 69, in gather_options
parser = self.initialize(parser)
TypeError: initialize() takes exactly 1 argument (2 given)
I missed something in test_options.py, which causes this error.
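The traceback above is what Python 2 reports when a subclass defines initialize(self) while the base class calls self.initialize(parser). A minimal sketch of the matching signatures follows. The class layout is simplified and the option defaults are assumptions; only the names TestOptions, initialize, gather_options, --name and --which_epoch come from this page.

```python
import argparse


class BaseOptions:
    def initialize(self, parser):
        # Options shared by training and testing.
        parser.add_argument("--name", default="exp")
        return parser

    def gather_options(self, argv=None):
        parser = argparse.ArgumentParser()
        # The base class passes the parser in, so every override must accept it.
        parser = self.initialize(parser)
        return parser.parse_args(argv)


class TestOptions(BaseOptions):
    def initialize(self, parser):  # takes (self, parser), not just (self)
        parser = BaseOptions.initialize(self, parser)
        parser.add_argument("--which_epoch", default="latest")
        return parser
```

If the override drops the parser argument, the call in gather_options fails with the TypeError shown above.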
This does not seem correct, because it can only handle pixel-to-pixel shifts, not patch shifts.
The mask generation is inside the model.
Could you please move the mask generation outside so that the model can be tested on a given mask?
There might be an error on line 238 of models/shiftnet_model.py: there is no object called self.ng_loss.
It is not easy to solve, in fact.
I have no idea on how to solve it.
I once wrote another version of InnerCos.py which only works on multi-GPU and is not suitable for single-GPU.
I do not know how to solve it for now.
Hello, Mr. Yan. When I read your code, I found that you created several variants on the basis of the original, which I really admire. But I don't understand the difference between them and Shift-Net. Could you briefly introduce them? Thank you.
Hello, could you explain why I get the following error when I try to execute:
python test.py --which_epoch=30 --name='paris_random_mask_20_30' --offline_loading_mask=1 --testing_mask_folder='masks' --dataroot='./datasets/celeba-256/test' --norm='instance'
AssertionError: ./datasets/Paris/train is not a valid directory.
Do I need to download the dataset to run inference? If so, where can I get it?
Could you give me the link of Paris StreetView dataset?
The original author did not give me the link yet.
Thanks a lot.
No bug, I just made a mistake.
File "/home/yaxian/image-creation/Shift_Net/Shift-Net_pytorch/data/aligned_dataset.py", line 33, in getitem
w, h = AB.size
TypeError: 'builtin_function_or_method' object is not iterable
It would be helpful if you provided a pretrained model.
Thank you.
It goes like this: set a flag offline_testing to load given masks from disk.
gt_latent = self.ng_innerCos_list[0].get_target() # then get gt_latent(should be variable require_grad=False)
self.ng_innerCos_list[0].set_target(gt_latent)
I cannot find this code in the set_gt_latent(self) function, so I am confused.
For now, the construction of the UNet is complicated and not flexible enough. In particular, when we need to add other components to the UNet, this kind of model construction becomes painful. So we need to construct the model in another way (more lines of code, yet more flexible).
Hi @Zhaoyi-Yan! I'm training your model in Google Colab. Once it starts, it floods the screen with user warnings throughout training. Like this:
It is overloading the memory. How can I remove them, or at least silence them?
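One way to quiet such output, assuming the messages are ordinary Python UserWarning instances (the warning text itself is not shown above, so this is an assumption), is the standard warnings filter. The function name is made up for this sketch.

```python
import warnings


def silence_user_warnings():
    """Ignore UserWarning messages for the rest of the process."""
    warnings.filterwarnings("ignore", category=UserWarning)
```

Calling silence_user_warnings() once before the training loop would hide every UserWarning, including ones that point at real problems. Passing a message pattern to warnings.filterwarnings narrows the filter to a single known warning.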
Hello,
Next week I am going to concentrate on reading and learning about your code.
I have several questions I would like to discuss with you.
My email address is [email protected].
Best,
Thomas Chaton.
I want to run this project on CPU only. What changes should I make to the code?
Thanks for reading!
What's your opinion on this ? @tchaton
When I pushed commit #81, it broke multi-GPU training.
I know how to solve it and will push a commit when I am free.
@tchaton
I do not know why it goes like this.
It happens when I set batchSize=8 and run acc_unet_shift.
When batchSize>1, x_latter.size() is (8, 128, 64, 64), while x.size() is (1, 128, 64, 64).
File "/home/yan/github/Shift-Net_pytorch/models/modules/shift_unet.py", line 236, in forward
return torch.cat([x_latter, x], 1) # cat in the C channel
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 8 and 1 in dimension 0 at /opt/conda/conda-bld/pytorch_1535488076166/work/aten/src/THC/generic/THCTensorMath.cu:87
Hey,
thank you for publishing the code for your project.
I am trying to run inference on images that are not 256x256 pixels, but of an arbitrary size.
It seems like if I do that the program automatically crops a 256x256 image out and uses that.
Is it possible by tweaking some code to run inference for images of arbitrary size?
Thank you and best regards
ThJOD
I have a 1080Ti GPU, and training stops at the third epoch. I tried again and it stopped at the third epoch too. Have you encountered this problem? Why does it happen? Thanks.
networks.py. It is redundant and a little bit complex for now.
init_gain option.
InnerCos and Shift layer.
Skip to InnerCos as a workaround.
The main difficulty lies in the implementation of InnerCos. The target of InnerCos is not on the same GPU as that of its input. Multi-GPU support counts for a great deal when handling enormous datasets.
After I run train.py, this error comes up:
Traceback (most recent call last):
File "train.py", line 8, in
opt = TrainOptions().parse()
File "/Users//Shift-Net_pytorch/options/base_options.py", line 99, in parse
if opt.suffix:
AttributeError: 'Namespace' object has no attribute 'suffix'
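This AttributeError means the parser never defined a suffix option, so the parsed Namespace has no such attribute. A minimal sketch of one fix is to register the option with an empty default. The function name parse_options, the default values, and the way the suffix is appended to the name are assumptions, not the repository's code.

```python
import argparse


def parse_options(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="exp")
    # Defining the option guarantees that opt.suffix exists after parsing.
    parser.add_argument("--suffix", default="", type=str)
    opt = parser.parse_args(argv)
    if opt.suffix:
        opt.name = opt.name + "_" + opt.suffix
    return opt
```

A more defensive alternative is to read the value with getattr(opt, "suffix", ""), which tolerates option files that predate the flag.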
Since acc_shift_net has been developed, we need to delete the obsolete code to keep the codebase clean.
What's your opinion on this? @tchaton
Hello, I am currently working on image inpainting and trying to train several model architectures. I saw that you used the Paris dataset in your experiments, and I have been looking for it. Could you share the dataset through a private link sent to my email address?
Thank you very much
Email:[email protected]
Hello, Mr. Yan. Can you tell me how to understand the parameter 'mask_thred'? How can I choose a suitable value?
In util.py, what is the function of mm = m.gt(mask_thred/(1.*patch_size**2 + 1e-4)).long()
Why is the value of eps selected as 1e-4?
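One reading of the quoted expression, assuming m holds the mean of the 0/1 mask over each patch_size x patch_size patch (the surrounding code is not shown here, so this is an assumption): a patch is flagged as masked when it contains at least mask_thred masked pixels. The small eps lowers the threshold slightly, so the strict "greater than" still accepts a patch with exactly mask_thred masked pixels. A plain-Python sketch with a hypothetical helper name:

```python
def patch_is_masked(patch, mask_thred, eps=1e-4):
    """patch: square list of lists holding 0/1 mask values."""
    patch_size = len(patch)
    mean = sum(sum(row) for row in patch) / (patch_size ** 2)
    # Same comparison as the quoted line, applied to a single patch.
    return mean > mask_thred / (1.0 * patch_size ** 2 + eps)


if __name__ == "__main__":
    half = [[1, 1], [0, 0]]                       # 2 of 4 pixels masked
    print(patch_is_masked(half, 2))               # with eps
    print(patch_is_masked(half, 2, eps=0.0))      # without eps
```

Under this reading the exact value of eps matters little, as long as it is small compared with the step of 1/patch_size**2 between possible patch means.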
Hello @Zhaoyi-Yan ,
I have been running speed tests.
The time spent in visualizer.display_current_results keeps growing as the number of images increases.
Therefore, I removed it from the main loop.
However, we should have a version which changes just one element and not all of them.
Best,
T.C
Hello, I just ran the PyTorch training code on Paris StreetView (30 epochs), but I found that the PSNR is much lower than that in your paper. Why?
Hello. First of all, thank you for the great idea and the code release.
I'm training Shift-Net, but I'm not sure whether it has converged, since the logs do not report the metrics (PSNR, SSIM, and Mean L2 Loss) given in the paper.
So, can you please tell me how many epochs I need to train to reach the performance reported in the paper? Or if you have a script that calculates those metrics, I'd be happy if you could share it with us.
Thanks!
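For reference, PSNR follows the standard definition 10 * log10(MAX^2 / MSE). The sketch below works on flat sequences of pixel values and is not the authors' evaluation script. The function name and the 0-255 value range are assumptions, and whether the paper evaluates the whole image or only the filled region is not stated here, so numbers from this sketch may not be directly comparable.

```python
import math


def psnr(reference, restored, max_value=255.0):
    """PSNR in dB between two equally sized flat sequences of pixel values."""
    if len(reference) != len(restored) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, restored)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Averaging this value over a held-out test set after each epoch gives a convergence signal that the GAN and L1 losses in the training log do not provide directly.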
As the Nonparametric is modified, _paste no longer means what it did before. It needs a fix once the final Nonparametric is decided.
Usually, such a UNet suffers from a small training batch size when training with a fixed mask (usually batchSize=1). Now I have found an easy way to solve it.
Hello, regarding the dataset: do I put all the data directly in opt.dataroot, and the code will automatically divide it into a training set, a validation set, and a test set? Or do I only put the training set into opt.dataroot?
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Hello there,
The training was very slow.
I started to look into the code (You can find the notebook for the optimization on my repo).
To my great surprise, it takes 0.5 s to forward with a square centered mask. I was expecting far more.
I checked with your random mask generator.
while True:
    x = random.randint(1, MAX_SIZE-fineSize)
    y = random.randint(1, MAX_SIZE-fineSize)
    mask = pattern[y:y+fineSize, x:x+fineSize] # need check
    area = mask.sum()*100./(fineSize*fineSize)
    if area>20 and area<maxPartition:
        break
    wastedIter += 1
You have a while True that sometimes never finishes. It took between 6 and 400 seconds.
I am going to remove it.
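Instead of removing the sampling loop entirely, one option is to cap the number of attempts and fall back to the last crop. This is a sketch only: the function name, the max_tries parameter, and the list-of-lists pattern are assumptions, and the real generator operates on an array.

```python
import random


def sample_mask(pattern, fine_size, min_pct=20.0, max_pct=30.0,
                max_tries=1000, rng=random):
    """Crop a fine_size x fine_size mask whose masked area lies in (min_pct, max_pct).

    Returns (crop, ok). ok is False when no crop qualified within max_tries,
    in which case the last crop tried is returned.
    """
    max_size = len(pattern)
    crop = None
    for _ in range(max_tries):
        x = rng.randint(0, max_size - fine_size)
        y = rng.randint(0, max_size - fine_size)
        crop = [row[x:x + fine_size] for row in pattern[y:y + fine_size]]
        area = sum(map(sum, crop)) * 100.0 / (fine_size * fine_size)
        if min_pct < area < max_pct:
            return crop, True
    return crop, False
```

The bounded loop keeps the per-sample cost predictable. The caller can decide whether to accept the fallback crop or resample from a different pattern.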
Can this model be trained on other mask datasets?
(epoch: 1, iters: 800, time: 0.027, data: 0.521) G_GAN: 2.870 G_L1: 31.211 D: 0.407
(epoch: 1, iters: 1600, time: 0.027, data: 0.009) G_GAN: 5.973 G_L1: 25.281 D: 0.072
(epoch: 1, iters: 2400, time: 0.029, data: 0.010) G_GAN: 6.557 G_L1: 21.723 D: 0.005
(epoch: 1, iters: 3200, time: 0.029, data: 0.010) G_GAN: 5.861 G_L1: 17.568 D: 0.006
(epoch: 1, iters: 4000, time: 0.028, data: 0.011) G_GAN: 5.291 G_L1: 16.388 D: 0.061
(epoch: 1, iters: 4800, time: 0.029, data: 0.014) G_GAN: 5.334 G_L1: 15.123 D: 0.006
(epoch: 1, iters: 5600, time: 0.028, data: 0.010) G_GAN: 5.188 G_L1: 16.427 D: 0.219
Traceback (most recent call last):
File "train.py", line 33, in
model.set_gt_latent()
File "/home/tchaton/projects/original/Shift-Net_pytorch/models/shift_net/shiftnet_model.py", line 169, in set_gt_latent
self.netG(real_B) # input ground truth
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 121, in forward
return self.module(*inputs[0], **kwargs[0])
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/projects/original/Shift-Net_pytorch/models/modules/shift_unet.py", line 59, in forward
return self.model(input)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/projects/original/Shift-Net_pytorch/models/modules/unet.py", line 83, in forward
return self.model(x)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/container.py", line 91, in forward
input = module(input)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/projects/original/Shift-Net_pytorch/models/modules/unet.py", line 85, in forward
x_latter = self.model(x)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/container.py", line 91, in forward
input = module(input)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/projects/original/Shift-Net_pytorch/models/modules/shift_unet.py", line 134, in forward
x_latter = self.model(x)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/container.py", line 91, in forward
input = module(input)
File "/home/tchaton/virtualenvs/labelbox/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/tchaton/projects/original/Shift-Net_pytorch/models/shift_net/InnerCos.py", line 39, in forward
self.loss = self.criterion(self.former_in_mask, self.target.expand_as(self.former_in_mask).type_as(self.former_in_mask))
RuntimeError: The expanded size of the tensor (8) must match the existing size (32) at non-singleton dimension 0
The batch sizes don't match.
Hi, the uploaded model for face center inpainting (named 30_net_G) is corrupted. Can you upload it again? Thx!