The gradient with respect to the slots_mu and slots_sigma variables is zero. To learn the initialization of the slots, you could change line 40 of your model.py to slots = torch.distributions.Normal(mu, sigma).rsample(); with rsample() the gradients will flow through these variables.
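For reference, here is a minimal sketch of what that change could look like, assuming slots_mu and slots_log_sigma are learnable nn.Parameter tensors (the names, shapes, and module structure here are assumptions and may differ from the actual model.py):

```python
import torch
import torch.nn as nn

class SlotInit(nn.Module):
    """Hypothetical module illustrating differentiable slot initialization."""

    def __init__(self, num_slots, slot_dim):
        super().__init__()
        # Learnable initialization statistics for the slots.
        self.slots_mu = nn.Parameter(torch.zeros(1, num_slots, slot_dim))
        self.slots_log_sigma = nn.Parameter(torch.zeros(1, num_slots, slot_dim))

    def forward(self, batch_size):
        mu = self.slots_mu.expand(batch_size, -1, -1)
        sigma = self.slots_log_sigma.exp().expand(batch_size, -1, -1)
        # rsample() applies the reparameterization trick (mu + sigma * eps,
        # with eps ~ N(0, 1)), so gradients flow back into slots_mu and
        # slots_log_sigma, unlike a plain sample() or torch.normal() call.
        return torch.distributions.Normal(mu, sigma).rsample()
```

Parameterizing the scale as exp(slots_log_sigma) also keeps sigma strictly positive throughout training.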
Thanks for your implementation of the Slot Attention module. However, I found that the sampling operation (line 40 of model.py) prevents gradients from flowing during back-propagation. During training, the gradients of slot_mu and slot_sigma are zero, which means these two variables will never be updated. I think the reparameterization trick is needed to make the sampling operation differentiable.
Hello and thank you for the great work.
I think the eval script is missing the pre-trained model checkpoint ('./tmp/model3.ckpt'), as indicated in the error I received. Could you please show me how I can obtain it?
Thank you in advance.