
Comments (8)

cantabile-kwok commented on May 17, 2024

@adefossez @npuichigo Could you please explain in more detail why "this won't impact the Straight-Through-Estimator gradient for the Encoder"? I think that if the residual is computed in a way that doesn't pass its real gradients, the gradient estimator may also be affected. The following code snippet illustrates this:

import torch
def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)  # quantize x with first codebook
q1 = x + (q1 - x).detach()  # transplant q1's gradient to x
residual = x - q1  # detach q1 or not may make a difference. Compute residual for next level quantizing
q2 = quantize(residual, codebook2)  # quantize residual with second codebook
q2 = residual + (q2 - residual).detach()  # transplant q2's gradient to residual

loss = 0*q1.sum() + 1*q2.sum()  # loss is a function of q1 and q2, now it is independent of q1.
loss.backward()
print(x.grad)

The printed gradient is all zeros, but if we replace residual = x - q1 with residual = x - q1.detach(), the gradient becomes non-zero.


npuichigo commented on May 17, 2024

@adefossez


adefossez commented on May 17, 2024

Thanks for bringing that up!

It seems like this won't impact the Straight-Through-Estimator gradient for the Encoder, but it will kill the commitment loss for all residual VQ levels but the first one, right?
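A minimal sketch of why the later commitment losses die, using random tensors as stand-ins for the codebook lookups (hypothetical names, not the actual encodec code): since q1 = x + (q1 - x).detach(), the residual x - q1 is constant with respect to x, so a commitment loss computed on it sends no gradient back to the encoder.

import torch

x = torch.randn(1, 5, requires_grad=True)  # stand-in for an encoder output
q1_real = torch.randn(1, 5)                # stand-in for the first codebook lookup
q2_real = torch.randn(1, 5)                # stand-in for the second codebook lookup

q1 = x + (q1_real - x).detach()            # straight-through estimator
residual = x - q1                          # equals (x - q1_real).detach() in value: constant w.r.t. x

commit_loss = (residual - q2_real.detach()).pow(2).sum()  # second-level commitment loss
commit_loss.backward()
print(x.grad)                              # all zeros: this loss cannot train the encoder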


npuichigo commented on May 17, 2024

It seems so. But I'm not sure how much it affects the final result.


adefossez commented on May 17, 2024

I'm a bit reluctant to introduce a change we haven't tested in this codebase, as it could change the best hyperparameters etc. I can, however, add a warning pointing to this issue when the model is used in training mode.


adefossez commented on May 17, 2024

Why did you put 0 * q1.sum()? That is what is breaking the STE gradient. With the current code, d q1 / d x = Id and d q_i / d x = 0 for all i > 1, which is okay, as the overall gradient d (sum q_i) / d x = Id, which is what we want. The only thing that is impacted is the commitment loss.
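A quick check of that claim, again using random stand-ins for the codebook lookups (hypothetical names, not the actual encodec code):

import torch

x = torch.randn(1, 5, requires_grad=True)
q1_real = torch.randn(1, 5)                    # stand-in for the first codebook lookup
q2_real = torch.randn(1, 5)                    # stand-in for the second codebook lookup

q1 = x + (q1_real - x).detach()                # d q1 / d x = Id
residual = x - q1                              # constant w.r.t. x
q2 = residual + (q2_real - residual).detach()  # d q2 / d x = 0

(q1 + q2).sum().backward()
print(x.grad)                                  # all ones: d (q1 + q2) / d x = Id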


cantabile-kwok commented on May 17, 2024

Oh, I think I over-complicated the problem here. In the model, all the quantization outputs q_i are simply summed and fed to the decoder, so the relation d (sum q_i) / d x = Id keeps the STE working. In my code snippet, I assumed the loss could be an arbitrary function of q1 and q2; in that case, the gradient from q2 never reaches the preceding networks, which may not be good.

Still, if we replace residual = x - q1 with residual = x - q1.detach(), it seems that d (sum q_i) / d x = n*Id, so the scale of the losses may be affected. Thanks for the clarification!
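The same check with the detached residual (again with hypothetical stand-ins for the codebook lookups) shows the gradient scaling with the number of quantizers:

import torch

x = torch.randn(1, 5, requires_grad=True)
q1_real = torch.randn(1, 5)                    # stand-in for the first codebook lookup
q2_real = torch.randn(1, 5)                    # stand-in for the second codebook lookup

q1 = x + (q1_real - x).detach()                # d q1 / d x = Id
residual = x - q1.detach()                     # now d residual / d x = Id
q2 = residual + (q2_real - residual).detach()  # so d q2 / d x = Id as well

(q1 + q2).sum().backward()
print(x.grad)                                  # all twos: n * Id for n = 2 quantizers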


DingWeiPeng commented on May 17, 2024

@adefossez @cantabile-kwok

If residual = residual - quantized, then the second codebook can update with its gradient, but it cannot affect the first codebook.
If residual = residual - quantized.detach(), then the second codebook's gradient will affect the first codebook.

In core_vq.py, there is the following code in the VectorQuantization class:
[screenshot of the VectorQuantization code]

And there is the following code in the ResidualVectorQuantization class:
[screenshot of the ResidualVectorQuantization code]

So this problem is equivalent to the following one, which the code snippet below illustrates:

import torch
def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)  # quantize x with first codebook
q1 = x + (q1 - x).detach()  # transplant q1's gradient to x
residual = x - q1.detach()  # toggling this .detach() changes the result below
q2 = quantize(residual, codebook2)  # quantize residual with second codebook
q2 = residual + (q2 - residual).detach()  # transplant q2's gradient to residual

loss = 1*q2.sum()  # loss depends only on q2
loss.backward()
print(x.grad)

If residual = x - q1, then x.grad is all zeros;
if residual = x - q1.detach(), then x.grad = tensor([[1., 1., 1., 1., 1.]]).

