Comments (3)
@miguelsvasco
Hi, thank you for your detailed comments and I'm sorry for my late reply.
However, the original formulation of the MVAE model (in the paper Multimodal Generative Models for Scalable Weakly-Supervised Learning) does not consider such terms, only a KL divergence term between the distribution of the PoE encoder and the prior.
Yes, this loss function comes not from MVAE but from JMVAE (originally proposed in the JMVAE paper as JMVAE-kl). Although the original JMVAE paper does not use the PoE encoder, we wanted to see whether the PoE encoder works well with the JMVAE loss. Anyway, I'm sorry for the confusion.
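For context on the PoE encoder mentioned here: a product of Gaussian experts is again Gaussian, with a precision equal to the sum of the experts' precisions and a precision-weighted mean. The following is a minimal sketch in plain Python, not Pixyz code. The function name `product_of_experts` and the `include_prior` flag are hypothetical, and adding a standard-normal prior expert is an assumption based on the MVAE formulation.

```python
from typing import List, Sequence, Tuple


def product_of_experts(
    means: Sequence[Sequence[float]],
    variances: Sequence[Sequence[float]],
    include_prior: bool = True,
) -> Tuple[List[float], List[float]]:
    """Combine diagonal-Gaussian experts into one diagonal Gaussian.

    means[i] and variances[i] hold the per-dimension mean and variance
    of expert i, e.g. q(z|x) and q(z|y).
    """
    dim = len(means[0])
    means = [list(m) for m in means]
    variances = [list(v) for v in variances]
    if include_prior:
        # Assumption: the prior p(z) = N(0, I) acts as one more expert.
        means.append([0.0] * dim)
        variances.append([1.0] * dim)

    poe_mean, poe_var = [], []
    for d in range(dim):
        precisions = [1.0 / v[d] for v in variances]
        total_precision = sum(precisions)
        # Precision-weighted average of the expert means.
        weighted = sum(m[d] * p for m, p in zip(means, precisions))
        poe_mean.append(weighted / total_precision)
        poe_var.append(1.0 / total_precision)
    return poe_mean, poe_var


# Two unit-variance experts centred at 1 and 3, without the prior expert.
mean, var = product_of_experts([[1.0], [3.0]], [[1.0], [1.0]], include_prior=False)
print(mean, var)  # [2.0] [0.5]
```

Passing a single expert with `include_prior=True` gives the "unimodal" inference, which is how one encoder structure can serve both the bimodal and the missing-modality case.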
When I remove the kl_x and kl_y terms from the regularizer and train, the model seems unable to perform cross-modality inference:
This might be because the "unimodal" inferences of the PoE encoder, q(z|x) and q(z|y), are not being trained. Without that training, the z inferred from a unimodal input (especially a label or attribute) might collapse (our preprint paper refers to a similar issue as the "missing modality difficulty").
In JMVAE, these unimodal inferences are trained by pulling them close to the "bimodal" inference q(z|x,y), which corresponds to the additional KL terms you pointed out.
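As a rough illustration of those additional terms, the sketch below computes kl_x and kl_y in closed form for diagonal Gaussians. It is plain Python rather than Pixyz code; the function names `kl_diag_gaussians` and `jmvae_kl_regularizer` are hypothetical, and the direction KL(q(z|x,y) || q(z|x)) is assumed from the JMVAE-kl formulation.

```python
import math
from typing import Sequence, Tuple

Gaussian = Tuple[Sequence[float], Sequence[float]]  # (mean, variance) per dimension


def kl_diag_gaussians(p: Gaussian, q: Gaussian) -> float:
    """Closed-form KL(p || q) for two diagonal Gaussians."""
    mean_p, var_p = p
    mean_q, var_q = q
    return sum(
        0.5 * (math.log(vq / vp) + (vp + (mp - mq) ** 2) / vq - 1.0)
        for mp, vp, mq, vq in zip(mean_p, var_p, mean_q, var_q)
    )


def jmvae_kl_regularizer(joint: Gaussian, uni_x: Gaussian, uni_y: Gaussian) -> float:
    """Sum of the two terms that pull q(z|x) and q(z|y) toward q(z|x,y)."""
    kl_x = kl_diag_gaussians(joint, uni_x)  # KL(q(z|x,y) || q(z|x))
    kl_y = kl_diag_gaussians(joint, uni_y)  # KL(q(z|x,y) || q(z|y))
    return kl_x + kl_y


# q(z|x,y) = N(0, 1), q(z|x) = N(1, 1), q(z|y) = N(0, 1)
print(jmvae_kl_regularizer(([0.0], [1.0]), ([1.0], [1.0]), ([0.0], [1.0])))  # 0.5
```

Dropping both terms leaves nothing in the objective that evaluates q(z|x) or q(z|y) on their own, which is consistent with the cross-modality failure described above.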
Would that be possible to implement with the Pixyz framework?
Yes, but you should use the Model class instead of the VAE class, because the loss function becomes more complex.
The implementation of the original MVAE model with Pixyz is as follows.
Given your comments, I renamed the previous notebook from mvae_poe.ipynb to jmvae_poe.ipynb (to avoid confusion) and added a new notebook, mvae.ipynb, which includes the implementation of the original MVAE model.
Thank you!
@masa-su
Thank you for the framework.
For the MVAE implementation you provided above, how should the model be trained in the semi-supervised case? Let's say that only a share of the labels is available for the MNIST dataset. Should I create two Model objects that share the networks but have different loss functions, one for the case where 1) both the image and the label are available and one for the case where 2) only the label is available?
@sgalkina
Thank you for your comment!
I don't know what kind of loss functions you are going to implement for the supervised and unsupervised cases, but you can use the replace_var method of the Distribution class to share the same network across different losses, e.g., supervised and unsupervised losses.
For an example of its usage, please see the implementation of the M2 model, which is a well-known semi-supervised VAE model.
If you have any trouble understanding how to use it, please feel free to ask!