Comments (8)

vene commented on August 20, 2024

Hi @mtreviso,

Interesting catch. Does it also happen with entmax_bisect with alpha=2?

I think we have unit tests that should cover the correctness of gradients for multi-dimensional inputs. If you have time and willingness, could you check which test should be catching this and why it isn't?
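For context, here is a sketch of the kind of gradient check such a test might use (hypothetical, not necessarily one of the existing tests):

import torch
from torch.autograd import gradcheck
from entmax import sparsemax

# double precision is needed for gradcheck's finite differences; the shape
# and dim are arbitrary choices, and the check can be sensitive if a
# perturbation crosses a point where the support changes
x = torch.randn(3, 2, 2, dtype=torch.double, requires_grad=True)
assert gradcheck(lambda t: sparsemax(t, dim=-1), (x,))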

mtreviso commented on August 20, 2024

With entmax_bisect and alpha=2 it doesn't happen! I also tested sparsemax_bisect and it is fine too.
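A quick way to double-check that (a sketch; the shapes are chosen arbitrarily, and entmax with alpha=2 corresponds to sparsemax):

import torch
from entmax import entmax_bisect, sparsemax_bisect

x = torch.randn(3, 2, 2, requires_grad=True)

# entmax with alpha=2 is sparsemax, here solved by bisection
torch.sum(entmax_bisect(x, alpha=2) ** 2).backward()
print(x.grad.shape)    # torch.Size([3, 2, 2])

x.grad = None
torch.sum(sparsemax_bisect(x) ** 2).backward()
print(x.grad.shape)    # torch.Size([3, 2, 2])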

mtreviso commented on August 20, 2024

It turns out that if you create a tensor of shape (2, 2, 3), the grad is actually correct. I guess the problem is in this line: we should squeeze the dimension we are working on.

I tested with
v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim)
and it seems to work fine.
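To see why the squeeze matters, here is a standalone shape sketch (supp_size is built by hand here just to mimic the keepdim/unsqueeze in the forward pass, it is not the actual saved tensor):

import torch

dim = -1
x = torch.randn(3, 2, 2)
output = torch.relu(x)                                # stand-in for a sparse output
supp_size = (output > 0).sum(dim=dim, keepdim=True)   # keepdim mimics the unsqueeze in the forward
grad_input = torch.ones_like(x)

print(grad_input.sum(dim=dim).shape)    # torch.Size([3, 2])
print(supp_size.shape)                  # torch.Size([3, 2, 1]) -- clashes with [3, 2] at non-singleton dim 1
print(supp_size.squeeze(dim).shape)     # torch.Size([3, 2]) -- now divides elementwise as intended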

I think the tests in test_root_finding.py, namely test_sparsemax_grad and test_entmax_grad, didn't catch this because they use tensors with two non-unitary dims, and because they only cover the bisection implementations of sparsemax and entmax.

vene commented on August 20, 2024

test_root_finding isn't supposed to test this implementation, only the bisection-based one (this implementation is not based on root finding).

I think the correct line you want is the one from the entmax15 backward pass:
https://github.com/deep-spin/entmax/blob/master/entmax/activations.py#L189
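Roughly, the shape handling in that backward pass looks like this (a paraphrased, runnable sketch, not verbatim repo code): both the numerator and the denominator of the correction term are reduced over dim, so their shapes already agree without any squeeze:

import torch

dim = -1
Y = torch.rand(3, 2, 2)          # stand-in for the saved entmax15 output
dY = torch.randn(3, 2, 2)        # incoming gradient
gppr = Y.sqrt()                  # per-element weight with the same shape as Y

dX = dY * gppr
q = dX.sum(dim) / gppr.sum(dim)  # both sums drop `dim`, so the shapes match
dX = dX - q.unsqueeze(dim) * gppr
print(dX.shape)                  # torch.Size([3, 2, 2])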

mtreviso commented on August 20, 2024

Got it. So, should I create a separate test file just for the non-root-finding activations and open a PR?

Yeah! The entmax backward pass looks ok to me (in terms of tensor shapes). I think it is just a matter of calling squeeze with dim=ctx.dim in the sparsemax backward pass, right?

vene commented on August 20, 2024

So since the entmax15 backward pass does not call squeeze and acts correctly, I suspect squeeze is not needed for sparsemax either. This is why I linked you to the entmax15 backward pass.

The tests for the topk-based exact solvers are in https://github.com/deep-spin/entmax/blob/master/entmax/test_topk.py and it seems none of them are checking the gradient right now.

mtreviso commented on August 20, 2024

I think we actually have to squeeze supp_size in the correct dimension, since it is unsqueezed earlier, in this line:
https://github.com/deep-spin/entmax/blob/master/entmax/activations.py#L71

I tried your suggestion: I removed .squeeze() and left supp_size as it is. If you create a tensor of shape (3,2,2) and try to apply sparsemax with dim=-1, you get a RuntimeError.

Using entmax15:

>>> x = torch.randn(3,2,2, requires_grad=True)
>>> z = torch.sum(torch.pow(entmax.entmax15(x), 2))
>>> z.backward()
>>> x.grad.shape
torch.Size([3, 2, 2])

sparsemax:

>>> x = torch.randn(3,2,2, requires_grad=True)
>>> z = torch.sum(torch.pow(entmax.sparsemax(x, dim=-1), 2))
>>> z.backward()
  File "/home/mtreviso/Documents/entmax/entmax/activations.py", line 164, in backward
    v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype)
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

Inspecting grad_input.sum(dim=dim).shape and supp_size.shape we see that they are not broadcastable:

>>> grad_input.sum(dim=dim).shape
torch.Size([3, 2])
>>> supp_size.shape
torch.Size([3, 2, 1])

Maybe we should write test_sparsemax_grad() and test_entmax_grad() in test_topk.py?
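Something along these lines could be a starting point (a hypothetical sketch, not necessarily what should land in the PR; it only checks that backward runs and that the gradient has the right shape for each dim):

import pytest
import torch
from entmax import sparsemax, entmax15

@pytest.mark.parametrize("dim", (0, 1, 2))
@pytest.mark.parametrize("func", (sparsemax, entmax15))
def test_grad_shape(func, dim):
    # regression check for multi-dimensional inputs: backward must not raise
    # and the gradient must have the same shape as the input
    x = torch.randn(3, 2, 2, requires_grad=True)
    y = func(x, dim=dim)
    torch.sum(y ** 2).backward()
    assert x.grad.shape == x.shape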

mtreviso commented on August 20, 2024

Closing. Solved by #18.
