
Comments (14)

Zhaoyi-Yan commented on September 13, 2024

Hi, can you tell me which line your snippet refers to?

from shift-net_pytorch.

KumapowerLIU commented on September 13, 2024

Thank you for your reply.
I have two questions.
In the original code of Innercos.py in branch 0.3:

import torch.nn as nn
import torch
from torch.autograd import Variable
import util.util as util

class InnerCos(nn.Module):
    def __init__(self, crit='MSE', strength=1, skip=0):
        super(InnerCos, self).__init__()
        self.crit = crit
        self.criterion = torch.nn.MSELoss() if self.crit == 'MSE' else torch.nn.L1Loss()

        self.strength = strength
        self.target = None
        # To define whether this layer is skipped.
        self.skip = skip

    def set_mask(self, mask_global, opt):
        mask = util.cal_feat_mask(mask_global, 3, opt.threshold)
        self.mask = mask.squeeze()
        if torch.cuda.is_available():
            self.mask = self.mask.float().cuda()
        self.mask = Variable(self.mask, requires_grad=False)

    def set_target(self, targetIn):
        self.target = targetIn

    def get_target(self):
        return self.target

    def forward(self, in_data):
        if not self.skip:
            self.bs, self.c, _, _ = in_data.size()
            self.former = in_data.narrow(1, 0, self.c//2)
            self.former_in_mask = torch.mul(self.former, self.mask)
            current_gpu_id = in_data.get_device()

            if self.target.size() != self.former_in_mask.size():
                self.target = self.target.narrow(0, current_gpu_id * self.bs, self.bs)

            self.loss = self.criterion(self.former_in_mask * self.strength, self.target)

            # I have to put it here!
            # when input is image with mask(the second pass), we
            # Mention only when input is the groundtruth, the target makes sense.
            self.target = in_data.narrow(1, self.c // 2, self.c // 2).clone() # the latter part
            self.target = self.target * self.strength
            self.target = self.target.detach()

            self.output = in_data
        else:
            self.loss = 0
            self.output = in_data
        return self.output


    def backward(self, retain_graph=True):
        if not self.skip:
            self.loss.backward(retain_graph=retain_graph)
        return self.loss

    def __repr__(self):
        skip_str = 'True' if self.skip else 'False'
        return self.__class__.__name__+ '(' \
              + 'skip: ' + skip_str \
              + ' ,strength: ' + str(self.strength) + ')'

When model.set_gt_latent() runs, that is, when we set the ground-truth feature as the label, target is still None from initialization. The comparison self.target.size() != self.former_in_mask.size() then raises an error, because None has no .size() attribute.

In version 0.4, I found that you run model.set_gt_latent() and then model.optimize_parameters(). By that point the label has changed and is no longer the ground-truth feature.

Therefore, I made a small modification to your version 0.3 code. I am not sure whether it is appropriate.
Any advice is appreciated.

By contrast, the version 0.3 code follows what your paper describes, so my change is to innercos.py in version 0.3:

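# Initialize target with zeros (shape hard-coded to 1 x 256 x 32 x 32) instead of None,
# so that self.target.size() is valid on the first forward pass.
self.target = torch.FloatTensor(1, 256, 32, 32)
self.target.zero_()
self.target = Variable(self.target)
self.target = self.target.cuda()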

Initializing target this way in innerCos.py avoids the error and is consistent with the idea in the paper.


Zhaoyi-Yan commented on September 13, 2024

Thanks for pointing this out. I will verify it in the near future.

KumapowerLIU commented on September 13, 2024

Thank you, and sorry for the trouble. Please let me know once you have verified it.

tchaton commented on September 13, 2024

@KumapowerLIU @Zhaoyi-Yan Please write in English :)

Zhaoyi-Yan commented on September 13, 2024

@KumapowerLIU Could you try the master version? The shift operation has been validated there, and you can also enjoy the faster, optimized shift. Please check whether InnerCos in master behaves as you expect.

tchaton commented on September 13, 2024

[Image: training loss curve]

Here is a training loss curve. We can certainly do something to help prevent GAN collapse.

Best,
T.C

Zhaoyi-Yan commented on September 13, 2024

A vanilla GAN usually suffers from mode collapse. I haven't tried other GANs.
Does a relativistic GAN or spectral normalization help avoid collapse?

tchaton commented on September 13, 2024

I haven't really tried yet.

tchaton commented on September 13, 2024

Hello Yann,

I am trying a new L1 loss. It is a weighted sum of the L1 loss inside and outside the mask. The better the result outside the mask becomes, the more the loss focuses on the inside.

I also want to try an inverse signed distance function applied over the mask.
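
A minimal sketch of such a weighted L1 loss, written in plain Python on flat lists so it runs without PyTorch. The weighting rule `w_inside = 1 / (1 + outside_l1)` and the function names are assumptions for illustration only; the exact scheme used in the experiment is not stated in this thread.

```python
def masked_l1(pred, target, mask):
    """Return (inside_l1, outside_l1); mask is 1 inside the hole and 0 outside."""
    inside = [abs(p - t) for p, t, m in zip(pred, target, mask) if m == 1]
    outside = [abs(p - t) for p, t, m in zip(pred, target, mask) if m == 0]
    inside_l1 = sum(inside) / len(inside) if inside else 0.0
    outside_l1 = sum(outside) / len(outside) if outside else 0.0
    return inside_l1, outside_l1


def weighted_l1(pred, target, mask):
    """Weighted sum of the inside and outside L1 terms."""
    inside_l1, outside_l1 = masked_l1(pred, target, mask)
    # Hypothetical schedule: the inside weight approaches 1
    # as the error outside the mask approaches 0.
    w_inside = 1.0 / (1.0 + outside_l1)
    w_outside = 1.0 - w_inside
    return w_inside * inside_l1 + w_outside * outside_l1


# Outside error is zero, so the loss is entirely the inside term.
loss_clean_outside = weighted_l1([1.0, 1.0, 0.0, 0.0], [0.0] * 4, [1, 1, 0, 0])
# Outside error is 1.0, so both terms get weight 0.5.
loss_noisy_outside = weighted_l1([2.0, 2.0, 1.0, 1.0], [0.0] * 4, [1, 1, 0, 0])
```

In a real training loop the weight would presumably be computed from a detached value, so that it acts as a constant when gradients are taken.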

Best,
T.C

tchaton commented on September 13, 2024

I have noticed that it was causing a consistent collapse, even when I added a different starter.

Because the generator messes up the region outside the mask, the discriminator starts to give back bad gradients and the GAN collapses.

Best,
T.C

Zhaoyi-Yan commented on September 13, 2024

Training with a fixed square mask makes the model suitable only for the same square mask at test time. To avoid this, we can probably place the square mask at a random position during training. In that case, a local discriminator (local D) is expected to distinguish the masked region, similar to the settings in the global-and-local and generative inpainting approaches.
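
To make the random placement concrete, here is a small framework-free sketch. The function name `random_square_mask` and the list-of-lists mask format are hypothetical, not this repository's API. The returned `(top, left)` offset is what a local discriminator could use to crop the hole region.

```python
import random


def random_square_mask(height, width, hole, rng=random):
    """Return (mask, (top, left)): a height x width grid of 0s
    with a hole x hole square of 1s at a random position."""
    if hole > height or hole > width:
        raise ValueError("hole must fit inside the image")
    # randint is inclusive on both ends, so the square always fits.
    top = rng.randint(0, height - hole)
    left = rng.randint(0, width - hole)
    mask = [[0] * width for _ in range(height)]
    for r in range(top, top + hole):
        for c in range(left, left + hole):
            mask[r][c] = 1
    return mask, (top, left)


# A seeded generator keeps the example reproducible.
rng = random.Random(0)
mask, (top, left) = random_square_mask(8, 8, 4, rng)
```

Sampling a new offset for every training batch is one way to keep the generator from overfitting to a single hole location.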

Zhaoyi-Yan commented on September 13, 2024

@KumapowerLIU I am sorry for the delay. You are totally right. It is a bug: I should not have put the line self.loss = self.criterion(self.former_in_mask * self.strength, Variable(self.target)) at the end of InnerCos forward. In a very early version, the code was:

self.loss = self.criterion(self.former_in_mask * self.strength, self.target)
# I have to put it here!
# when input is image with mask(the second pass), we
# Mention only when input is the groundtruth, the target makes sense.
self.target = in_data.narrow(1, self.c // 2, self.c // 2).clone() # the latter part
self.target = self.target * self.strength
self.target = self.target.detach()

Here self.loss is computed before self.target is set. In this case, the guidance loss truly makes sense. Thank you for reporting this.
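
The difference between the two orderings can be reproduced with a toy stand-in for InnerCos. This is only a sketch: features are flat Python lists, the class name `GuidanceLossSketch` and the `loss_before_update` flag are made up for illustration, and the `None` check is one possible guard for the first pass rather than the fix that was merged.

```python
class GuidanceLossSketch:
    """Toy stand-in for InnerCos; features are flat lists of floats."""

    def __init__(self, loss_before_update=True):
        self.target = None  # no target until a ground-truth pass has run
        self.loss = 0.0
        self.loss_before_update = loss_before_update

    @staticmethod
    def _l1(a, b):
        return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

    def forward(self, features):
        half = len(features) // 2
        former, latter = features[:half], features[half:]
        if self.loss_before_update:
            # Compare against the target stored by the previous pass.
            if self.target is not None:
                self.loss = self._l1(former, self.target)
            self.target = list(latter)
        else:
            # Problematic order: the target is overwritten first,
            # so the loss only sees the current pass.
            self.target = list(latter)
            self.loss = self._l1(former, self.target)
        return features


gt_features = [1.0, 2.0, 5.0, 6.0]      # ground-truth pass
masked_features = [0.0, 0.0, 3.0, 3.0]  # masked-image pass

good = GuidanceLossSketch(loss_before_update=True)
good.forward(gt_features)      # stores target [5.0, 6.0]
good.forward(masked_features)  # loss against the ground-truth target
good_loss = good.loss

bad = GuidanceLossSketch(loss_before_update=False)
bad.forward(gt_features)
bad.forward(masked_features)   # loss against the masked pass's own latter half
bad_loss = bad.loss
```

With the loss computed first, the masked pass is compared against the ground-truth features (`good_loss`); with the target updated first, it is compared against its own features (`bad_loss`), which is the behaviour reported above.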

Zhaoyi-Yan commented on September 13, 2024

Fixed in #53. Please check, thank you!
