
Comments (4)

astrocyted commented on August 11, 2024

Hi Yuan,

I would like to double down on this issue, because I don't think it is about whether you use mean(dim=0) or sum(dim=0) when you aggregate the output of the attention heads. The issue is that self.head_weight is an unbounded parameter:

self.head_weight = nn.Parameter(torch.tensor([1.0/self.head_num] * self.head_num))

and it could end up at any value, since you're not constraining it either explicitly (e.g., by normalizing) or implicitly (through regularization terms).
I was therefore really surprised to see that the weights of all 4 heads are less than 1 in your pretrained model release.
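For illustration, here is a minimal sketch of one explicit constraint, reparameterizing the head weights through a softmax so they always sum to 1 (my own naming and shapes, not the PSLA code):

    import torch
    import torch.nn as nn

    class ConstrainedHeadWeights(nn.Module):
        """Hypothetical alternative: head weights that always sum to 1."""

        def __init__(self, head_num):
            super().__init__()
            # Unconstrained logits; zeros give uniform initial weights
            # (1 / head_num each), matching the original initialization.
            self.head_logits = nn.Parameter(torch.zeros(head_num))

        def forward(self, head_outputs):
            # head_outputs: (head_num, batch, n_class), assumed layout
            w = torch.softmax(self.head_logits, dim=0)  # non-negative, sums to 1
            return (w[:, None, None] * head_outputs).sum(dim=0)

Because the combination is convex, the aggregated output stays in [0, 1] whenever each head's output does, which would remove the need for the clamp discussed below.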

That said, you do clamp the output of the network to [ε, 1 − ε] before passing it to BCELoss:

audio_output = torch.clamp(audio_output, epsilon, 1. - epsilon)

So rather than using a smooth, squashing activation function like sigmoid at the very end of the model, you are (whether intended or not) using a troublesome piece-wise linear function:

[image: plot of the piece-wise linear clamp function]

This means that unless you have super carefully initialized your model's parameters and use a very small learning rate, training would stall whenever the output leaves the [ε, 1 − ε] range, since the gradient of the clamp is zero there.
I've not tried to train your model from scratch, but it must have been quite tricky, if not very difficult.
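A quick way to see the stall (a standalone sketch, not the repo's training code):

    import torch

    epsilon = 1e-7

    # An output that has drifted above the clamp range.
    x = torch.tensor([2.0], requires_grad=True)
    torch.clamp(x, epsilon, 1. - epsilon).backward()
    print(x.grad)  # tensor([0.]) -- clamp is flat here, so learning stops

    # A sigmoid in the same situation still provides a gradient.
    y = torch.tensor([2.0], requires_grad=True)
    torch.sigmoid(y).backward()
    print(y.grad)  # ~0.105, small but non-zero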

So, do you have any explanation for this particular design choice, i.e., clamping rather than a smooth activation function, or avoiding the need for any final activation altogether by enforcing a constraint on the head weights?


YuanGongND commented on August 11, 2024

Hi Haohe,

Thanks for reaching out.

It has been a while since I coded the model, so I might be wrong.

In the PSLA paper, Figure 2 caption, we said "We multiply the output of each branch element-wise and apply a temporal mean pooling (implemented by summation)", which is reflected in:

x = (torch.stack(x_out, dim=0)).sum(dim=0)

I guess if you change it to x = (torch.stack(x_out, dim=0)).mean(dim=0), the range should be smaller than 1. If you just take a pretrained model and change this line of code at inference, it should not change the result (mAP). But if you change this line for training, you might not get the same results as ours, as it scales the output and the loss.
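For what it's worth, the scaling argument is easy to check in isolation (a sketch with assumed shapes, not the repo's code):

    import torch

    head_num = 4
    # Fake per-head outputs: a list of (batch, n_class) tensors
    x_out = [torch.rand(2, 10) for _ in range(head_num)]

    by_sum = torch.stack(x_out, dim=0).sum(dim=0)
    by_mean = torch.stack(x_out, dim=0).mean(dim=0)

    # mean is just sum / head_num, a positive rescaling ...
    assert torch.allclose(by_mean, by_sum / head_num)
    # ... which preserves per-example rankings, hence mAP at inference.
    assert torch.equal(by_sum.argsort(dim=1), by_mean.argsort(dim=1))

During training, though, the rescaled output changes the BCE loss surface, which is why retraining with mean may not reproduce the reported numbers.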

Please let me know what you think.

-Yuan


astrocyted commented on August 11, 2024

On a different note, I see you normalize the attention values across the temporal axis:

norm_att = att / torch.sum(att, dim=2)[:, :, None]

This would seemingly encourage the model to attend to a single temporal unit (in the output layer) at the expense of not attending to the other temporal slices. Given that many events are dynamic and extend over more than a single unit of time, especially in event-dense AudioSet recordings, what is the inductive bias behind this choice?

Furthermore, in order to obtain these normalized attention values for each head, you first pass them through a sigmoid function and then normalize them by dividing by the sum:

norm_att = att / torch.sum(att, dim=2)[:, :, None]

Is there any particular reason for this choice of "sigmoid + normalization by sum" versus the more mainstream approach of applying a softmax to the attention values directly? They are of course not equivalent: softmax depends exclusively on the differences between values, i.e. $(X_i - X_j)$, whereas your version also depends on the absolute values of the $X_i$.
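To make the difference concrete, a standalone sketch (the (batch, n_class, time) layout is my assumption):

    import torch

    logits = torch.randn(1, 3, 5)  # (batch, n_class, time)

    # PSLA-style: sigmoid, then divide by the temporal sum.
    att = torch.sigmoid(logits)
    norm_att = att / torch.sum(att, dim=2)[:, :, None]

    # Mainstream alternative: softmax over the temporal axis.
    soft_att = torch.softmax(logits, dim=2)

    # Both sum to 1 over time, but the values generally differ.
    print(torch.allclose(norm_att, soft_att))  # False in general

    # Softmax is shift-invariant: it depends only on differences X_i - X_j.
    shifted = logits + 3.0
    print(torch.allclose(torch.softmax(shifted, dim=2), soft_att))  # True

    # The sigmoid + sum version is not: it also depends on absolute values.
    att_s = torch.sigmoid(shifted)
    print(torch.allclose(att_s / att_s.sum(dim=2, keepdim=True), norm_att))  # False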


YuanGongND commented on August 11, 2024

Hi there,

Thanks so much for your questions. I need some time to think about them. The main model architecture is from a previous paper: http://groups.csail.mit.edu/sls/archives/root/publications/2019/LoganFord_Interspeech-2019.PDF.

> This means that unless you have super carefully initialized your model's parameters and use a very small learning rate, training would stall whenever the output leaves the [ε, 1 − ε] range, since the gradient of the clamp is zero there. I've not tried to train your model from scratch, but it must have been quite tricky, if not very difficult.

But before that, I want to clarify that we do not pick the random seeds or cherry-pick successful runs at all. All experiments are run 3 times and we report the mean, which should be reproducible with the provided code. In the paper, we show the variance is pretty small. Your proposed "more reasonable" solution might lead to more stable optimization and probably better results. Have you tried that?

-Yuan
