zalandoresearch / pytorch-vq-vae

PyTorch implementation of VQ-VAE by Aäron van den Oord et al.

License: MIT License
Hi and thanks for providing a nice and clean implementation of VQ-VAEs :)
While playing around with your code, I noticed that in VectorQuantizerEMA you first perform the EMA update of the codebook counts and embeddings, and then use the updated codebook embeddings as the quantized vectors (and for computing e_latent_loss).
In particular, the order in which you perform the operations is:

1. EMA update of the codebook counts
2. EMA update of the codebook embeddings
3. Quantization of the encoder outputs
4. e_latent_loss computation

Is there a reason why you do the EMA updates before steps 3 and 4? My intuition says that the order should be:

1. Quantization of the encoder outputs
2. e_latent_loss computation
3. EMA update of the codebook counts
4. EMA update of the codebook embeddings

Looking forward to hearing your thoughts!
Many thanks,
Stefanos
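For concreteness, here is a minimal self-contained sketch of the EMA codebook update in the order proposed above (all tensor names and sizes are illustrative, following the update rules from the paper's appendix, not the notebook's exact code):

```python
import torch

N, K, D, decay, eps = 1024, 8, 16, 0.99, 1e-5
flat_input = torch.randn(N, D)          # encoder outputs, flattened
embedding = torch.randn(K, D)           # codebook
ema_cluster_size = torch.zeros(K)
ema_w = embedding.clone()

# 1.-2. quantize and compute e_latent_loss against the *old* codebook
distances = torch.cdist(flat_input, embedding)
encodings = torch.nn.functional.one_hot(distances.argmin(1), K).float()
quantized = encodings @ embedding
e_latent_loss = torch.mean((quantized - flat_input) ** 2)

# 3.-4. EMA updates of the counts and embedding sums, then normalize
ema_cluster_size = ema_cluster_size * decay + (1 - decay) * encodings.sum(0)
ema_w = ema_w * decay + (1 - decay) * (encodings.t() @ flat_input)
n = ema_cluster_size.sum()
cluster_size = (ema_cluster_size + eps) / (n + K * eps) * n  # Laplace smoothing
embedding = ema_w / cluster_size.unsqueeze(1)
```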
Thanks for your implementation of VQ-VAE, but I have a question. In the original paper, the gradients at the input of the decoder are copied to the encoder's output, because the index-selection op is non-differentiable, but I couldn't find the corresponding implementation in your code. I'm new to PyTorch and not familiar with the autograd system, so a little explanation of this would be appreciated. Thanks!
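For what it's worth, the standard PyTorch idiom for this gradient copy is the straight-through estimator, a one-liner built from `.detach()`. A minimal self-contained sketch (the quantizer here is a stand-in, not the repo's code):

```python
import torch

inputs = torch.randn(4, 8, requires_grad=True)  # stand-in for the encoder output
quantized = torch.round(inputs)                 # stand-in for the non-differentiable lookup

# Straight-through estimator: the forward value equals `quantized`, but
# `.detach()` blocks gradients through the quantization, so the backward
# pass copies the decoder's gradient directly onto `inputs`.
quantized_st = inputs + (quantized - inputs).detach()
quantized_st.sum().backward()
print(inputs.grad)  # all ones: the gradient passed straight through
```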
The list passed to nn.ModuleList in the ResidualStack class constructor in vae.ipynb#L324 duplicates a reference to a single Residual object instance. Was this done intentionally?
```python
self._layers = nn.ModuleList(
    [Residual(in_channels, num_hiddens, num_residual_hiddens)] * self._num_residual_layers)
```
To create a new object for each layer, the code could be changed to:
```python
self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                              for _ in range(self._num_residual_layers)])
```
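For illustration, a small self-contained sketch (module names are arbitrary) of why this matters: list multiplication registers the same module object several times, so all entries share one set of parameters:

```python
import torch.nn as nn

layer = nn.Linear(4, 4)
shared = nn.ModuleList([layer] * 3)                          # three references to ONE module
fresh = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])   # three distinct modules

print(len(list(shared.parameters())))  # 2: one weight and one bias, shared by all entries
print(len(list(fresh.parameters())))   # 6: each layer has its own weight and bias
```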
By updating all embeddings regardless of whether they are being used, you are decaying them towards 0. Is this intended?
I tried removing the decay, but doing so seems to decrease perplexity.
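A small numerical sketch of that effect (values are illustrative): under the EMA update, the counts of codes that are never selected shrink geometrically toward zero, while used codes converge to their batch usage:

```python
import torch

decay = 0.99
ema_cluster_size = torch.full((8,), 10.0)  # hypothetical counts for 8 codes
usage = torch.zeros(8)
usage[0] = 64.0                            # only code 0 is ever selected

for _ in range(500):
    ema_cluster_size = ema_cluster_size * decay + (1 - decay) * usage

print(ema_cluster_size)  # code 0 approaches 64; the unused codes decay toward 0
```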
Hi! Thank you for the great upload. How exactly can I extract the latent code of an image? By that I mean the discrete code of size [1, 8, 8], not the [128, 8, 8] feature map.
Thanks!
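In case it helps, a hedged sketch of one way to do this, assuming the notebook's model exposes `_encoder`, `_pre_vq_conv`, and a quantizer `_vq_vae` with an `_embedding` table (attribute names may differ), and that `model` and `x` are the trained model and an input image:

```python
import torch

with torch.no_grad():
    z = model._pre_vq_conv(model._encoder(x))          # [1, D, 8, 8]
    B, D, H, W = z.shape
    flat = z.permute(0, 2, 3, 1).reshape(-1, D)        # [B*H*W, D]
    codebook = model._vq_vae._embedding.weight         # [K, D]
    indices = torch.cdist(flat, codebook).argmin(1)    # nearest code per position
    latent_code = indices.view(B, H, W)                # discrete code, [1, 8, 8]
```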
Hi,
thanks for your clean implementation! I was wondering, have you ever tried to calculate the bits/dimension metric (as in the original paper)? I have tried to do so using the provided code, and I am still quite far from the results reported in the paper. I was hoping you might have some insight to share as to why that is.
Thanks!
Lucas
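For reference, a minimal sketch of the usual bits/dimension conversion, assuming you already have the model's negative log-likelihood per image in nats (how exactly the discrete prior enters the VQ-VAE likelihood is a separate question, which may explain part of the gap):

```python
import math

def bits_per_dim(nll_nats: float, num_dims: int) -> float:
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2))

# e.g. for CIFAR-10 images: num_dims = 32 * 32 * 3 (numbers here are illustrative)
print(bits_per_dim(nll_nats=9500.0, num_dims=32 * 32 * 3))
```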
Hi, I can't figure out why we need to change from BCHW to BHWC before we flatten.
I would be happy if you could explain this step.
Thank you!
```python
inputs = inputs.permute(0, 2, 3, 1).contiguous()
input_shape = inputs.shape

# Flatten input
flat_input = inputs.view(-1, self._embedding_dim)
```
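One way to see it: the channel dimension C is the embedding dimension, and `view(-1, self._embedding_dim)` groups contiguous memory into rows, so the channel axis has to come last. A small shape sketch (sizes are illustrative):

```python
import torch

B, C, H, W = 2, 64, 8, 8  # C plays the role of embedding_dim
x = torch.randn(B, C, H, W)

flat = x.permute(0, 2, 3, 1).contiguous().view(-1, C)  # [B*H*W, C]
# Each row is now the C-dimensional vector at one spatial position, ready to
# be matched against the codebook. Flattening BCHW directly would mix values
# from different spatial positions into each row.
print(flat.shape)  # torch.Size([128, 64])
```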
Hi everybody,
looking at the VectorQuantizerEMA nn.Module in the code, I was not able to understand how the codebook vectors are updated after initialization. Is there a way to force the use of the whole codebook?
Lastly, how should I interpret the perplexity value?
Thank you!
Giorgio
```python
def forward(self, inputs):
    # convert inputs from BCHW -> BHWC
    inputs = inputs.permute(0, 2, 3, 1).contiguous()
    input_shape = inputs.shape

    # Flatten input
    flat_input = inputs.view(-1, self._embedding_dim)
```
My understanding is that the dimension of flat_input should be [B*H*W*C, embedding_dim], so one dimension seems to be missing? Or are you saying that the number of channels is equal to embedding_dim?
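Regarding the perplexity question above: perplexity measures how many codebook entries are effectively in use. It is the exponential of the entropy of the average code-assignment distribution, so it ranges from 1 (a single code used) to K (all K codes used uniformly). A minimal sketch, assuming `encodings` is the one-hot assignment matrix produced during quantization:

```python
import torch

def perplexity(encodings: torch.Tensor) -> torch.Tensor:
    # encodings: [N, K] one-hot rows; average them to get code usage probabilities
    avg_probs = encodings.float().mean(dim=0)
    return torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

one_hot = torch.eye(8)[torch.randint(0, 8, (1024,))]  # random assignments over K=8 codes
print(perplexity(one_hot))  # close to 8 when usage is roughly uniform
```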