
Comments (16)

bpopeters commented on August 20, 2024

Hi, could you provide an example of the code that creates the problem?

Are you using entmax_bisect for attention or something else?

prajjwal1 commented on August 20, 2024

I'm using entmax_bisect in the same manner as mentioned here, simply as a replacement for softmax. The code is similar to Hugging Face's implementation of BertAttention. With nn.Softmax everything works fine; with entmax_bisect, the loss becomes NaN when training is nearly complete. I tried three times, and the loss became NaN in 2/3 runs, so it seems somewhat non-deterministic.
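
Roughly, the substitution looks like this (a minimal sketch; the shapes and variable names below are illustrative, not from the actual model):

import torch
from entmax import entmax_bisect

# hypothetical attention scores: [batch, heads, query_len, key_len]
attention_scores = torch.randn(2, 4, 8, 8)

attention_probs = torch.softmax(attention_scores, dim=-1)          # dense baseline
sparse_probs = entmax_bisect(attention_scores, alpha=1.5, dim=-1)  # drop-in replacement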

bpopeters commented on August 20, 2024

What value of alpha are you using?

bpopeters commented on August 20, 2024

The algorithm becomes unstable for large values of alpha (anything significantly larger than 2, really). If you are learning alpha, the best strategy is to do something like @goncalomcorreia suggested in #10, which ensures that the alpha value stays between 1 and 2.
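
A rough sketch of that kind of constraint (a learned parameter squashed into (1, 2) with a sigmoid; the variable names here are illustrative, not from the repo):

import torch
import torch.nn as nn
from entmax import entmax_bisect

pre_alpha = nn.Parameter(torch.zeros(1))    # unconstrained, learnable
alpha = 1 + torch.sigmoid(pre_alpha)        # always strictly between 1 and 2

scores = torch.randn(3, 5)
probs = entmax_bisect(scores, alpha=alpha.view(1, 1), dim=-1)  # one shared alpha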

prajjwal1 commented on August 20, 2024

I'm using the same AlphaChooser, which includes clamping. I tried autograd anomaly detection and it reports:
Function 'BinaryCrossEntropyWithLogitsBackward' returned nan values in its 0th output. This doesn't happen when I use softmax.
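
For reference, anomaly detection is typically switched on like this (a generic sketch, not the actual training code):

import torch
import torch.nn.functional as F

# ask autograd to pinpoint the backward op that produced a NaN
torch.autograd.set_detect_anomaly(True)

logits = torch.randn(4, 3, requires_grad=True)
loss = F.binary_cross_entropy_with_logits(logits, torch.ones(4, 3))
loss.backward()   # if any backward op returned NaN, this raises with its name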

prajjwal1 commented on August 20, 2024

What value of alpha are you using?

I'm using the AlphaChooser, which keeps the value between 1.01 and 2, as @goncalomcorreia mentioned. Since alpha is a learned parameter, I don't know its exact value.

vene commented on August 20, 2024

Please attach a minimal script reproducing the issue.

Since the PyTorch NaN detection finds no issue inside the entmax_bisect function itself, it is very hard for us to guess what could be wrong without a script to reproduce.

In the meantime, a few ideas:

  • try lowering your learning rate;
  • ensure you are using the output of entmax_alpha correctly, i.e., if you assume it sums to one over columns when it really sums to one over rows, you will have problems (see the quick check below).
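
A quick way to check the second point (the shapes here are made up):

import torch
from entmax import entmax_bisect

scores = torch.randn(2, 4, 8, 8)                  # [batch, heads, query_len, key_len]
probs = entmax_bisect(scores, alpha=1.5, dim=-1)

print(probs.sum(dim=-1))   # ~1.0 everywhere: each query normalizes over the keys
print(probs.sum(dim=-2))   # generally not 1.0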

prajjwal1 commented on August 20, 2024
  1. I'll try a lower learning rate.
  2. Regarding your second point, my code for computing the attention distribution looks like this:
if self.sparse:
    attention_probs = self.entmax_alpha(attention_scores)
else:
    attention_probs = nn.Softmax(dim=-1)(attention_scores)

I modified the EntmaxAlpha class for my experiment. For reference:

import torch
import torch.nn as nn
from entmax import entmax_bisect


class AlphaChooser(torch.nn.Module):

    def __init__(self, head_count):
        super(AlphaChooser, self).__init__()
        # one unconstrained alpha parameter per attention head
        self.pre_alpha = nn.Parameter(torch.randn(head_count))

    def forward(self):
        alpha = 1 + torch.sigmoid(self.pre_alpha)
        return torch.clamp(alpha, min=1.01, max=2)


class EntmaxAlpha(nn.Module):

    def __init__(self, head_count, dim=0):
        super(EntmaxAlpha, self).__init__()
        self.dim = dim
        self.alpha_chooser = nn.Parameter(AlphaChooser(head_count)())
        self.alpha = self.alpha_chooser

    def forward(self, att_scores):
        batch_size, head_count, query_len, key_len = att_scores.size()

        expanded_alpha = self.alpha.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)    # [1, nb_heads, 1, 1]
        expanded_alpha = expanded_alpha.expand((batch_size, -1, query_len, 1))  # [bs, nb_heads, query_len, 1]
        p_star = entmax_bisect(att_scores, expanded_alpha)

        return p_star

Does it look fine to you?

vene commented on August 20, 2024

It's hard to spot issues without a minimal reproducing example. Try to isolate a small input/output pair that causes problems, ideally separate from any other complicated model components.

prajjwal1 commented on August 20, 2024

I tried with a lower learning rate. I also saved the attention weights at every step.

if self.sparse:
    attention_probs = self.entmax_alpha(attention_scores)
    torch.save(attention_probs, '/home/user/ckpt/att.pth')

Just before my loss becomes NaN, I checked attention_probs; it returned:

tensor([[[[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],

vene commented on August 20, 2024

Are all of those values (suppressed as "...") also nan or not? Can you reproduce the issue when reducing the dimensionality such that this whole tensor can be printed on screen?

I suggest printing the inputs (attention_scores) too; this behaviour might happen if they all become -inf (or indistinguishable from it).

vene commented on August 20, 2024

Finally, you can construct a minimal reproducing example by saving the attention scores and alphas, writing a small Python file that applies entmax_bisect to them and yields the same NaNs, and sharing that file with us.
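
Such a script might look roughly like this (the file names below are placeholders for wherever the offending batch was saved):

import torch
from entmax import entmax_bisect

# placeholder paths; point these at the saved offending batch
attention_scores = torch.load("attention_scores.pth")
expanded_alpha = torch.load("expanded_alpha.pth")

p_star = entmax_bisect(attention_scores, expanded_alpha)

print("NaNs in scores: ", torch.isnan(attention_scores).any().item())
print("NaNs in alphas: ", torch.isnan(expanded_alpha).any().item())
print("NaNs in outputs:", torch.isnan(p_star).any().item())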

prajjwal1 commented on August 20, 2024

I'm having a hard time finding the problem. I obtained a dict file containing all attention scores and probabilities along with the alpha values, and there were no NaNs, so I'm not sure what's causing the problem. Saving the dict also adds a 15x overhead to training time. The error occurs at different stages of training as well (with shuffle=False). Still on it.

vene commented on August 20, 2024

Fix all of your random seeds, and save the inputs only for the batch where entmax gives you nans. (No need for dicts, just the inputs and alphas.)

This is, of course, assuming I understood correctly and your issue happens because at some point you feed non-nan inputs to entmax and nans come out.
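
Concretely, one possible way to capture just that batch (a sketch; the helper name and output paths are hypothetical, and the argument names match the snippets above):

import torch

torch.manual_seed(0)   # fix seeds so the failing batch comes around reproducibly

def dump_if_nan(attention_scores, expanded_alpha, attention_probs,
                out_dir="/home/user/ckpt"):
    # hypothetical helper: save only the inputs/alphas of the batch that yields NaNs
    if torch.isnan(attention_probs).any():
        torch.save(attention_scores, f"{out_dir}/attention_scores.pth")
        torch.save(expanded_alpha, f"{out_dir}/expanded_alpha.pth")
        raise RuntimeError("NaN in attention_probs; inputs saved for a minimal repro")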

prajjwal1 commented on August 20, 2024

I decreased my learning rate to 1e-5; maybe that is why I'm not seeing NaNs now. Is there any correlation between a relatively high LR (1e-4) and using sparsity? I didn't notice any anomaly with softmax.

vene commented on August 20, 2024

We haven't encountered any such issues to my knowledge, but neural nets are pretty mysterious and there are many moving parts that may cause NaNs. It would depend on your data, regularization, etc.

The only situation where I've ever seen entmax return NaNs is when the input is all -inf, but in that case softmax returns NaN as well, so this is in a sense expected behaviour. (Of course, different trajectories during training might make one model reach -inf scores while the other wouldn't; it's hard to tell if you're not using fixed random seeds, too.)

If you identified a scenario where a single call to entmax gives nans while softmax doesn't for the same inputs (attention_scores), then it would be a bug in our code, and a reproducing script would help catch it.
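
For what it's worth, the all -inf case can be checked in isolation (a quick sketch; per the above, both outputs are expected to come out as NaN):

import torch
from entmax import entmax_bisect

row = torch.full((1, 5), float("-inf"))   # a row of scores that is entirely -inf

print(torch.softmax(row, dim=-1))         # NaNs: 0/0 after the max-shift
print(entmax_bisect(row, alpha=1.5, dim=-1))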
