botcl's People

Contributors: wbw520
botcl's Issues

Question about Slot Attention Implementation

Hi, thanks for sharing the code.

I noticed some differences between your Slot Attention implementation and the original one [NeurIPS 2020]. Could you please help me clarify the following?

  1. On line 45 in slots.py, you normalize over the spatial dimension with dots.sum(2) and then scale by the global sum dots.sum(2).sum(1). What is the motivation for multiplying by the global sum? Could multiplying by a large number cause exploding gradients?
  2. On the same line, you use sigmoid instead of softmax. What is the motivation for this choice?
  3. Unlike the original Slot Attention, the norm layer is removed, and q_linear is a stack of three linear layers instead of a single nn.Linear. How do these changes affect performance?

I would appreciate it if you could clarify my confusion.
Best wishes.
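For context, the normalization asked about in item 1 can be paraphrased as follows. This is a numpy sketch of my reading of that line, not the repo's exact code: each slot's logits are L1-normalized over the spatial positions, then rescaled by the per-sample global L1 sum, and an elementwise sigmoid replaces the softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
dots = rng.normal(size=(2, 3, 5))  # (batch, slots, spatial positions)

# Per-slot L1 sum over the spatial dimension, then per-sample global L1 sum.
spatial = np.abs(dots).sum(axis=2, keepdims=True)       # (2, 3, 1)
global_ = np.abs(dots).sum(axis=(1, 2), keepdims=True)  # (2, 1, 1)

# Normalize each slot row, then rescale by the global sum; the permute
# calls in the PyTorch line only serve to broadcast these two factors.
scaled = dots / spatial * global_
attn = 1.0 / (1.0 + np.exp(-scaled))  # elementwise sigmoid, not softmax
```

Written this way, each slot row of `scaled` has the same L1 mass (the global sum), which is what makes the magnitude of the sigmoid inputs depend on the whole feature map.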

AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute 'next'

I get the following error when running python main_recon.py --num_classes 10 --num_cpt 20 --lr 0.001 --epoch 50 --lr_drop 30:

AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute 'next'

PyTorch versions: torch==2.0.1 and torchvision==0.15.2

The error occurs in utils/engine_recon.py on line 80 and in vis_recon.py on line 37. I fixed it with the solution from https://stackoverflow.com/questions/74289077/attributeerror-multiprocessingdataloaderiter-object-has-no-attribute-next.

There may be other areas of the codebase that also use this syntax.

Current line:

data, label = iter(loader).next()

Proposed line change:

data, label = next(iter(loader))
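The root cause is easy to demonstrate without PyTorch: Python 3 iterators implement `__next__` and have no public `.next()` method (that was Python 2 syntax), so the builtin `next(iter(...))` is the portable form. A minimal illustration, using a plain list as a stand-in for a DataLoader:

```python
# Stand-in for a DataLoader yielding (data, label) batches.
loader = [("data0", "label0"), ("data1", "label1")]

# Proposed form: the next() builtin works on any iterator.
data, label = next(iter(loader))

# Old form fails because Python 3 iterators expose __next__, not .next.
it = iter(loader)
print(hasattr(it, "next"))  # prints False
```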

Implementation of the contrastive loss

Hi, I am confused about how the contrastive loss in the paper is computed. The paper says l_ret is calculated from (t, t', y, y'), but in the code the model returns (cpt - 0.5) * 2, cls, attn, updates, and the loss seems to be computed by feeding (cpt - 0.5) * 2 directly as y and using pairwise_loss as l_ret. However, the released code calls pairwise_loss(y, y, label, label), which looks as though every pair is compared with itself and is therefore always similar, and that really confuses me.
Could you please tell me if there is anything wrong with my understanding?
I really appreciate your help!

def get_retrieval_loss(args, y, label, num_cls, device):
    b = label.shape[0]
    if args.dataset != "matplot":
        label = label.unsqueeze(-1)
        # convert class indices to one-hot vectors of length num_cls
        label = torch.zeros(b, num_cls).to(device).scatter(1, label, 1)
    # pairwise similarity loss over all pairs within the batch
    similarity_loss = pairwise_loss(y, y, label, label, sigmoid_param=10. / 32)
    # similarity_loss = pairwise_loss2(y, y, label.float(), sigmoid_param=10. / 32)
    q_loss = quantization_loss(y)
    return similarity_loss, q_loss
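One point worth noting about the `pairwise_loss(y, y, label, label)` call: losses of this family typically compare every batch element against every other batch element, so passing the same tensors twice still yields a b x b similarity matrix whose off-diagonal entries are dissimilar whenever the batch contains more than one class. The sketch below is my hedged reading of such a loss in the style of deep-hashing pairwise losses; the function name and exact formula are assumptions, not the repo's code.

```python
import numpy as np

def pairwise_loss_sketch(y1, y2, l1, l2, sigmoid_param=10. / 32):
    """Hedged sketch of a cross-entropy-style pairwise similarity loss."""
    sim = (l1 @ l2.T > 0).astype(float)   # 1 iff the pair shares a class
    dot = sigmoid_param * (y1 @ y2.T)     # pairwise code similarity
    # log(1 + exp(dot)) - sim * dot, computed stably with logaddexp
    return (np.logaddexp(0.0, dot) - sim * dot).mean()

# Two samples from different classes: the off-diagonal pairs are
# dissimilar, so (y, y, label, label) is not trivially "all similar".
y = np.array([[1., -1.], [-1., 1.]])
label = np.array([[1., 0.], [0., 1.]])    # one-hot labels
sim = (label @ label.T > 0)               # diagonal True, off-diagonal False
loss = pairwise_loss_sketch(y, y, label, label)
```

Under this reading, the self-pairs on the diagonal are indeed always similar, but the cross-batch pairs carry the contrastive signal.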

Slot attention implementation does not update the q?

Based on your implementation, why does the slot attention iterate with the same q & k?

    https://github.com/wbw520/BotCL/blob/3dde3ac20cdecd7eea8c4b7cb0e04e2bb95f639b/model/contrast/slots.py#L37
    def forward(self, inputs_pe, inputs, weight=None, things=None):
        b, n, d = inputs_pe.shape
        slots = self.initial_slots.expand(b, -1, -1)
        k, v = self.to_k(inputs_pe), inputs_pe
        for _ in range(self.iters):
            q = slots  # always taking the initial slots as q?

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            dots = torch.div(dots, torch.abs(dots).sum(2).expand_as(dots.permute([2, 0, 1])).permute([1, 2, 0])) * \
                   torch.abs(dots).sum(2).sum(1).expand_as(dots.permute([1, 2, 0])).permute([2, 0, 1])
            attn = torch.sigmoid(dots)

            attn2 = attn / (attn.sum(dim=-1, keepdim=True) + self.eps)
            updates = torch.einsum('bjd,bij->bid', inputs, attn2)

        if self.vis:
            slots_vis_raw = attn.clone()
            vis(slots_vis_raw, "vis", self.args.feature_size, weight, things)
        return updates, attn
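For comparison, in the original Slot Attention [NeurIPS 2020] the slots are refined on every iteration (through a GRU in the paper; plain reassignment below as a simplification), so q changes between iterations. The following numpy sketch shows that control flow; it is an illustration of the original algorithm's loop, not the repo's code.

```python
import numpy as np

rng = np.random.default_rng(0)
b, n_slots, n_inputs, d = 2, 3, 5, 4
inputs = rng.normal(size=(b, n_inputs, d))
slots = rng.normal(size=(b, n_slots, d))

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

for _ in range(3):
    q = slots                                    # q tracks the *updated* slots
    dots = np.einsum('bid,bjd->bij', q, inputs) / np.sqrt(d)
    attn = softmax(dots, axis=1)                 # softmax over slots, per paper
    attn = attn / (attn.sum(axis=-1, keepdims=True) + 1e-8)
    updates = np.einsum('bjd,bij->bid', inputs, attn)
    slots = updates  # key difference: write back, so next iteration's q differs
```

In the quoted forward pass the `slots = updates` write-back is absent, which is why `q` stays equal to `self.initial_slots` for all iterations.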
