naoto0804 / pytorch-adain
Unofficial PyTorch implementation of 'Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization' [Huang+, ICCV2017]
License: MIT License
"Download vgg_normalized.pth/decoder.pth and put them under models/." Neither link works.
Hello there. ^_^ I am trying to run your code on my laptop, which runs Windows 10. However, it cannot load the pre-trained model "vgg_normalised.pth". I guess this is because your pre-trained model was produced under Linux. Could you please provide a version that can be used under Windows 10, or give me some instructions on how to train "vgg_normalised.pth" under Windows?
I'm sorry to trouble you. I implemented the spatial control in PyTorch, but I have a question.
I think the contentFeatureBG and targetFeature dimensions differ from the contentFeature dimension, but the code has: targetFeature = targetFeature:viewAs(contentFeature), and I don't understand why.
The following is my PyTorch code:
if mask_path is not None:
    _, C, H, W = content_f.size()
    mask_img = Image.open(mask_path).resize((W, H))
    mask = content_tf(mask_img)  # W,H
    maskView = mask.view(-1)  # HW
    idx = np.where(maskView.numpy() == 0)  # 0->HW
    bgmask = torch.LongTensor(torch.from_numpy(idx[0]))  # 0->H*W
    contenF = content_f.view(C, -1)  # C,H*W
    contentFBG = contenF[:, bgmask].view(C, bgmask.size(0), 1)
    target_bg = adaptive_instance_normalization(contentFBG, style_f).squeeze()
    feat = Variable(torch.FloatTensor(C, H * W).zero_().cuda(), volatile=True)
    feat = target_bg[:, bgmask]
    feat = feat.expand(content_f.size())
Hello, thank you for the code. I trained the model on my own dataset of fonts, which are black and white, so the images have 1 channel. Where can I set the channel size? Can I use the same VGG encoder, or do I have to change the encoder to match my channel size? I did not get an error while training on these images, but I got the error below while testing. I am using the decoder saved under experiments after training.
(test1) D:\Coding\AdaIN_Code\AdaIN_exp1>python test.py --content_dir input/content4 --style_dir input/style4
Traceback (most recent call last):
File "D:\Coding\AdaIN_Code\AdaIN_exp1\test.py", line 155, in <module>
output = style_transfer(vgg, decoder, content, style,
File "D:\Coding\AdaIN_Code\AdaIN_exp1\test.py", line 28, in style_transfer
content_f = vgg(content)
File "C:\Users\ak874\anaconda3\envs\test1\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\ak874\anaconda3\envs\test1\lib\site-packages\torch\nn\modules\container.py", line 139, in forward
input = module(input)
File "C:\Users\ak874\anaconda3\envs\test1\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\ak874\anaconda3\envs\test1\lib\site-packages\torch\nn\modules\conv.py", line 457, in forward
return self._conv_forward(input, self.weight, self.bias)
File "C:\Users\ak874\anaconda3\envs\test1\lib\site-packages\torch\nn\modules\conv.py", line 453, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [3, 3, 1, 1], expected input[1, 1, 128, 128] to have 3 channels, but got 1 channels instead
Hello, did you train according to your train.py file? The total number of iterations is 100,000. I trained according to your train.py file, and the final test results fail to match those in the paper.
How were your test results? Thanks!
Hi,
Can you explain what the provided initial decoder and VGG weights are? I am referring to the weights that the instructions say to place in ./models before training.
Thanks.
Thanks for your work!
I have a question about normalization.
VGG networks normally normalize inputs with the ImageNet RGB mean and std values,
but I can't find any code for VGG normalization.
Is there a reason not to use VGG normalization?
running the commands:
python torch_to_pytorch.py --model models/vgg_normalised.t7
python torch_to_pytorch.py --model models/decoder.t7
produces the following error:
Traceback (most recent call last):
File "torch_to_pytorch.py", line 9, in <module>
from torch.utils.serialization import load_lua
ModuleNotFoundError: No module named 'torch.utils.serialization'
Please refer to this issue: pytorch/pytorch#14630
I used the following to initialize the net model and save it as save.pth:
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
decoder = net.decoder
vgg = net.vgg
decoder.eval()
vgg.eval()
decoder.load_state_dict(torch.load("models/decoder.pth"))
vgg.load_state_dict(torch.load('models/vgg_normalised.pth'))
vgg = nn.Sequential(*list(vgg.children())[:31])
vgg.to(device)
decoder.to(device)
model = net.Net(vgg, decoder)
# model.load_state_dict(torch.load("models/decoder.pth"), strict=False)
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
torch.save(model.state_dict(), "save.pth")
and I load the model:
model_new.load_state_dict(torch.load("save.pth"))
model_new.eval()
batch_size=5
img=cv2.imread('input/content/avril.jpg')
input_height=img.shape[0]
input_width=img.shape[1]
input_channels=img.shape[2]
output_channels=3
output_height=512
output_width=512
dummy_input = torch.randn(1, 3, 512, 512)
input_name = 'input'
output_name = 'output'
torch.onnx.export(model_new,
dummy_input,
'AdaIN_style_transfer.onnx',
opset_version=11,
verbose = True,
input_names=[input_name],
output_names=[output_name],
dynamic_axes={
input_name: {0: 'batch_size', 1: 'input_channels', 2: 'input_height', 3: 'input_width'},
output_name: {0: 'batch_size', 1: 'output_channels', 2: 'output_height', 3: 'output_width'}})
but I got the following error:
TypeError: forward() missing 1 required positional argument: 'style'
How do I solve the problem?
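The TypeError is consistent with the exported module's forward taking two tensors (content and style) while a single dummy_input is passed. torch.onnx.export accepts a tuple of positional inputs as its second argument. A minimal sketch, assuming a hypothetical stand-in module TwoInputNet in place of net.Net; if the real forward returns losses rather than an image, a small wrapper that returns the stylized image would probably be needed as well:

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for a model whose forward needs content and style."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=1)

    def forward(self, content, style):
        # Shift the content features by the per-channel mean of the style features.
        style_mean = self.conv(style).mean(dim=(2, 3), keepdim=True)
        return self.conv(content) + style_mean


def export_two_input(model, path):
    """Export a two-input model by passing both dummy tensors as one tuple."""
    content = torch.randn(1, 3, 64, 64)
    style = torch.randn(1, 3, 64, 64)
    torch.onnx.export(
        model,
        (content, style),            # tuple -> forward(content, style)
        path,
        opset_version=11,
        input_names=["content", "style"],
        output_names=["output"],
    )


model = TwoInputNet().eval()
dummy_inputs = (torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64))
with torch.no_grad():
    output = model(*dummy_inputs)    # same unpacking the exporter performs
```

Calling export_two_input(model, "two_input.onnx") would then write the file; the file name is only an example.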
@naoto0804 Hi, I want to train the model, but I'm confused about conv0. The paper says that the encoder is a pretrained encoder, so why is there a conv0 in the code?
Thanks for the PyTorch implementation, very helpful!
If I am not mistaken, it is currently not possible to process a whole batch of images. It would be great to add support for that.
When I input a content image of size 414 * 274, the output size turns out to be 776 * 512, and it becomes very blurry. I guess the upsampling module may have some problems.
Thank you a lot for your code!
It seems that the output shape is always a multiple of 4? I want the output to have the same shape as the input. Is that possible?
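The sizes reported in the two questions above are consistent with the following arithmetic. It assumes, without confirmation from the text here, that the content image is first resized so its shorter side is 512, that the encoder halves the resolution three times with ceiling rounding, and that the decoder doubles it three times. The function names are illustrative.

```python
import math


def resize_shorter_side(height, width, size=512):
    """Scale so the shorter side equals `size`, truncating the longer side."""
    if height <= width:
        return size, int(size * width / height)
    return int(size * height / width), size


def decoder_output_size(height, width, num_stages=3):
    """Halve with ceiling `num_stages` times, then double the same number of times."""
    for _ in range(num_stages):
        height, width = math.ceil(height / 2), math.ceil(width / 2)
    factor = 2 ** num_stages
    return height * factor, width * factor


resized = resize_shorter_side(274, 414)      # the 414 x 274 content image
stylized = decoder_output_size(*resized)     # (height, width) after decoding
```

Under these assumptions the output sides are multiples of 8, and they equal the input sides only when those are already multiples of 8. Cropping the output back to the input size is one way to restore the shape, and the blur may partly come from upscaling the 274-pixel side to 512.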
Hi thanks for your great effort.
I am curious where the weights of the pre-trained VGG encoder come from. I trained the whole model with your pre-trained VGG encoder, and the style-transferred results seem to have a portrait-like effect even though the training set (both style and content sources) does not consist of portraits.
So I suspect that the pre-trained VGG encoder weights were trained on portraits, such as those in WikiArt.
Could you clarify how the VGG encoder was pre-trained?
Thank you
Hi,
I find that at training time the encoder is the first few layers (up to relu4_1) of a pre-trained VGG-19, as the original paper says.
However, during the testing time, the encoder seems to be the whole vgg-19. I'm not sure if I am right, could you please help me to check it out?
Thanks,
azshue
Hi.
I wonder what the difference is between the 'vgg_normalized.pth' you uploaded and the pretrained vgg19 from torchvision.
I tried many times to train AdaIN with the pretrained vgg19 from torchvision, but failed.
(I succeeded with 'vgg_normalized.pth'.)
Samples from my AdaIN model trained with the torchvision vgg19 for many epochs look very bad...
Thank you : )
Thanks for your code. What do you think of adding this layer, as the original author did? Was it based on experimental results?
When I type this command,
python torch_to_pytorch.py --model models/vgg_normalised.t7
It produces errors as below.
Traceback (most recent call last):
File "torch_to_pytorch.py", line 321, in <module>
torch_to_pytorch(args.model, args.output)
File "torch_to_pytorch.py", line 266, in torch_to_pytorch
model = load_lua(t7_filename, unknown_classes=True)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 608, in load_lua
return reader.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 593, in read
return self.read_object()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 546, in read_object
return reader_registry[cls_name](self, version)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 243, in read_nn_class
attributes = reader.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 571, in read_table
k = self.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 572, in read_table
v = self.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 593, in read
return self.read_object()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 546, in read_object
return reader_registry[cls_name](self, version)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 318, in wrapper
obj = build_fn(reader, version)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 318, in wrapper
obj = build_fn(reader, version)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 243, in read_nn_class
attributes = reader.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 573, in read_table
table[k] = v
TypeError: unhashable type: 'list'
Also, if I type the next line,
python torch_to_pytorch.py --model models/decoder.t7
it shows this:
Traceback (most recent call last):
File "torch_to_pytorch.py", line 321, in <module>
torch_to_pytorch(args.model, args.output)
File "torch_to_pytorch.py", line 266, in torch_to_pytorch
model = load_lua(t7_filename, unknown_classes=True)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 608, in load_lua
return reader.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 593, in read
return self.read_object()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 546, in read_object
return reader_registry[cls_name](self, version)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 243, in read_nn_class
attributes = reader.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 571, in read_table
k = self.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 572, in read_table
v = self.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 593, in read
return self.read_object()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 546, in read_object
return reader_registry[cls_name](self, version)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 318, in wrapper
obj = build_fn(reader, version)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 318, in wrapper
obj = build_fn(reader, version)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 243, in read_nn_class
attributes = reader.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 572, in read_table
v = self.read()
File "C:\Users\a\Anaconda3\envs\AdaIN_env\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 598, in read
"corrupted.".format(typeidx))
torch.utils.serialization.read_lua_file.T7ReaderException: unknown type id -1050704824. The file may be corrupted.
I would be very grateful if you could reply to this issue. :)
Is there a way to apply the model on grayscale images? I have tried multiple ways:
Attempt 1. copy the grayscale image three times and concatenate them:
def train_transform():
    transform_list = [
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.cat([x, x, x], dim=0)),
        # or transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    ]
    return transforms.Compose(transform_list)
Attempt 2. copy the grayscale image three times and concatenate them before sending to the network:
for i in tqdm(range(args.max_iter)):
    adjust_learning_rate(optimizer, iteration_count=i)
    content_images = next(content_iter)
    content_images = torch.cat([content_images, content_images, content_images], dim=1).to(device)
    style_images = next(style_iter)
    style_images = torch.cat([style_images, style_images, style_images], dim=1).to(device)
    loss_c, loss_s = network(content_images, style_images)
    loss_c = args.content_weight * loss_c
    loss_s = args.style_weight * loss_s
    loss = loss_c + loss_s
    ......
Using only one method above will produce the error:
File "train.py", line 147, in <module>
loss_c, loss_s = network(content_images, style_images)
File "/opt/anaconda/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/domainadaptation/AdaIn/pytorch2/grayscale_pytorch-AdaIN/net.py", line 146, in forward
g_t_feats = self.encode_with_intermediate(g_t)
File "/domainadaptation/AdaIn/pytorch2/grayscale_pytorch-AdaIN/net.py", line 116, in encode_with_intermediate
results.append(func(results[-1]))
File "/opt/anaconda/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/opt/anaconda/lib/python2.7/site-packages/torch/nn/modules/container.py", line 91, in forward
input = module(input)
File "/opt/anaconda/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/opt/anaconda/lib/python2.7/site-packages/torch/nn/modules/conv.py", line 301, in forward
self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight[3, 3, 1, 1], so expected input[8, 1, 256, 256] to have 3 channels, but got 1 channels instead
However, when combining them I get something like this:
RuntimeError: Given groups=1, weight[3, 3, 1, 1], so expected input[8, 9, 256, 256] to have 3 channels, but got 9 channels instead
Attempt 3. Using transforms.Grayscale(3):
def train_transform():
    transform_list = [
        transforms.Grayscale(3),
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.cat([x, x, x], dim=0)),
        # or transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    ]
    return transforms.Compose(transform_list)
Still got error:
RuntimeError: Given groups=1, weight[3, 3, 1, 1], so expected input[8, 1, 256, 256] to have 3 channels, but got 1 channel instead
In the class FlatFolderDataset, I have also tried changing:
img = Image.open(os.path.join(self.root, path)).convert('RGB')
to
img = Image.open(os.path.join(self.root, path)).convert('L')
None of those worked for me. Is there a way to solve this issue? Thanks so much!
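In the traceback above the failing call is encode_with_intermediate(g_t), so the 1-channel tensor appears to be the decoder output g_t rather than the input batch, and the 9-channel error is what two stacked triplications give. A minimal sketch of an idempotent helper, assuming batches shaped (N, C, H, W); the name to_three_channels is hypothetical and not part of the repository:

```python
import torch


def to_three_channels(batch):
    """Return an (N, 3, H, W) tensor from a 1- or 3-channel batch."""
    if batch.dim() != 4:
        raise ValueError("expected a tensor shaped (N, C, H, W)")
    channels = batch.size(1)
    if channels == 3:
        return batch                      # already 3 channels, leave untouched
    if channels == 1:
        return batch.repeat(1, 3, 1, 1)   # copy the single channel three times
    raise ValueError(f"cannot convert {channels} channels to 3")


gray = torch.rand(8, 1, 256, 256)
rgb = to_three_channels(gray)
rgb_again = to_three_channels(rgb)        # second call is a no-op, so no 9 channels
```

Because the helper checks the channel count first, it can be applied at more than one place (for example to the inputs and to a 1-channel decoder output) without multiplying channels.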
Hi.
I wonder how to train a new style.
I put one image in ./input/content and another in ./input/style, using "vgg_normalized.pth".
CMD: python train.py --content_dir input/content --style_dir input/style --batch_size 4 --max_iter 160
Although I got an xx.pth.tar in ./experiments, when I use it with "test.py" I get a gray image as the result.
The link for the model downloads is broken.
vgg_normalized.pth/decoder.pth
Can someone help?
Hello, thank you for your contribution. Can AdaIN be applied to other network structures? Can AdaIN be used in a deblurring network without an encoder-decoder structure?
Could you suggest a way to incorporate adaptive instance normalization into other models such as UNet in the form of a layer, like batch or instance normalization?
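For reference, AdaIN as defined in the paper is AdaIN(x, y) = σ(y) (x − μ(x)) / σ(x) + μ(y), with statistics taken per sample and per channel over the spatial dimensions. A minimal sketch of that formula as a parameter-free layer follows; the class is illustrative rather than the repository's code, and unlike batch or instance normalization it needs a second input carrying style features with the same channel count:

```python
import torch
import torch.nn as nn


class AdaIN(nn.Module):
    """Sketch of adaptive instance normalization as a parameter-free layer."""

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def _stats(self, feat):
        # Per-sample, per-channel mean and std over the spatial dimensions.
        n, c = feat.shape[:2]
        flat = feat.reshape(n, c, -1)
        mean = flat.mean(dim=2).view(n, c, 1, 1)
        std = (flat.var(dim=2) + self.eps).sqrt().view(n, c, 1, 1)
        return mean, std

    def forward(self, content_feat, style_feat):
        c_mean, c_std = self._stats(content_feat)
        s_mean, s_std = self._stats(style_feat)
        normalized = (content_feat - c_mean) / c_std
        return normalized * s_std + s_mean


layer = AdaIN()
content_feat = torch.randn(4, 8, 16, 16)
style_feat = torch.randn(4, 8, 10, 12)   # spatial size may differ from content
out = layer(content_feat, style_feat)
```

In a UNet-like model the open design question is where the style features come from; this sketch only shows the normalization step itself.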
Hi, I found that, unlike the official VGG-19, this project adds a conv layer of size 3*3*1*1 to the input of VGG-19. Could you please tell me the reason? Also, how did you train this new version of VGG-19 to obtain the pretrained feature extractor?
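As background, a convolution whose weight has shape [3, 3, 1, 1] acts on each pixel independently as a 3x3 matrix plus a bias, so it can encode fixed colour preprocessing such as channel reordering, rescaling and mean subtraction. Whether the shipped weights actually do this is not stated here; the numbers below are made up for illustration.

```python
def conv1x1_pixel(pixel, weight, bias):
    """Apply a 1x1 convolution to one pixel: out[o] = sum_i weight[o][i] * pixel[i] + bias[o]."""
    return [
        sum(w * x for w, x in zip(row, pixel)) + b
        for row, b in zip(weight, bias)
    ]


# Illustrative values only: swap the first and last channels, scale by 2,
# and subtract a per-channel constant.
weight = [
    [0.0, 0.0, 2.0],
    [0.0, 2.0, 0.0],
    [2.0, 0.0, 0.0],
]
bias = [-1.0, -2.0, -3.0]

rgb_pixel = [0.5, 1.0, 1.5]
out_pixel = conv1x1_pixel(rgb_pixel, weight, bias)
```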
It seems that the uploaded pretrained model is broken. Could you upload it again?
python torch_to_pytorch.py --model models/vgg_normalised.t7
Traceback (most recent call last):
File "torch_to_pytorch.py", line 9, in <module>
from torch.utils.serialization import load_lua
ModuleNotFoundError: No module named 'torch.utils.serialization'
It seems the function was removed from PyTorch long ago: https://stackoverflow.com/questions/54107156/modulenotfounderror-no-module-named-torch-utils-serialization
We should use https://github.com/bshillingford/python-torchfile instead.
On running test.py with proper setup an error is thrown from line 121:
RuntimeError: Error(s) in loading state_dict for Sequential: Missing key(s) in state_dict: "1.weight", "1.bias", "8.weight", "8.bias", "11.weight", "11.bias", "14.weight", "14.bias", "18.weight", "18.bias", "21.weight", "21.bias", "28.weight", "28.bias". Unexpected key(s) in state_dict: "29.weight", "29.bias", "32.weight", "32.bias", "35.weight", "35.bias", "38.weight", "38.bias", "42.weight", "42.bias", "45.weight", "45.bias", "48.weight", "48.bias", "51.weight",
After changing line 121 from:
decoder.load_state_dict(torch.load(args.decoder))
to
decoder.load_state_dict(torch.load(args.decoder), strict=False)
The error seems resolved.
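With strict=False the mismatched keys are skipped rather than loaded, so the error disappearing does not by itself show that the decoder received its weights; the key lists in the message suggest the checkpoint and the module have different layouts. A minimal sketch of that comparison on plain key lists (the helper name diff_state_dict_keys and the sample keys are illustrative):

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) keys, mirroring what a strict load reports."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)       # expected by the module, absent from the file
    unexpected = sorted(ckpt_set - model_set)    # present in the file, unknown to the module
    return missing, unexpected


model_keys = ["1.weight", "1.bias", "8.weight", "8.bias"]
checkpoint_keys = ["29.weight", "29.bias", "1.weight", "1.bias"]
missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
```

In PyTorch, load_state_dict returns an object with missing_keys and unexpected_keys fields, so the result of a strict=False load can be inspected in the same way before trusting the output.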
Thanks for this implementation, very cool!
I'm using it for my research, and I was wondering whether you could indicate which license applies to the repo?
I would like to make my code openly available on github, and since it is based on your implementation, I would obviously give you credit in the README but I'm not sure whether any license (e.g. MIT) applies to your code.
I trained on my own dataset to get the decoder, however, I got this error when I attempted to unpack the trained decoder with the command tar -xvzf decoder_iter_160000.pth.tar:
gzip: stdin: not in gzip format
tar: Child returned status 1
tar: Error is not recoverable: exiting now
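The gzip error suggests the file is not an archive at all: a checkpoint written with torch.save is not a tar archive regardless of a .pth.tar suffix, and would typically be passed to torch.load rather than unpacked (that train.py writes the file this way is an assumption here). A quick check using only the standard library, with a dummy file standing in for the checkpoint:

```python
import os
import tarfile
import tempfile


def looks_like_tar(path):
    """Return True only if `path` can be opened as a tar archive."""
    return tarfile.is_tarfile(path)


# Dummy stand-in for a checkpoint: arbitrary bytes saved under a .pth.tar name.
with tempfile.TemporaryDirectory() as tmp:
    fake_checkpoint = os.path.join(tmp, "decoder_iter_demo.pth.tar")
    with open(fake_checkpoint, "wb") as f:
        f.write(b"not a tar archive" * 64)
    is_archive = looks_like_tar(fake_checkpoint)
```

If the check is False for the trained decoder, passing the file path directly to test.py as the decoder weights is the more likely route than extracting it.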