Hi @johannbrehmer ,
Thanks a lot for your detailed report!
A quick answer would be to tell you that this (unfortunate) negative output of the SamplesLoss layer is due to the default value of the scaling parameter, which I aggressively set to 0.5. As discussed in this tutorial, using a more conservative value of 0.7 or 0.9 should allow you to spend more time in the multiscale Sinkhorn loop (an annealing descent), and thus to converge closer to the actual value of the (de-biased) Sinkhorn loss: in your case, something like 0.038 > 0.
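To make the role of the scaling parameter more concrete, here is a rough, pure-Python sketch of an annealed (ε-scaling) log-domain Sinkhorn loop and of the de-biased Sinkhorn divergence S(α, β) = OT_ε(α, β) - ½ OT_ε(α, α) - ½ OT_ε(β, β). This is not GeomLoss's implementation: the helper names (eps_schedule, softmin, ot_eps, sinkhorn_divergence), the fixed number of updates per temperature and the restriction to 1D point clouds with uniform weights are assumptions made for illustration, loosely following the convention ε = blur**p with the cost |x - y|**2 / 2 for p = 2. The point it illustrates is that a smaller scaling means fewer temperature steps, i.e. less time spent in the annealing loop.

```python
import math


def eps_schedule(diameter, blur, scaling, p=2):
    """Hypothetical epsilon-scaling schedule (illustration, not GeomLoss code).

    The temperature eps = sigma**p starts at diameter**p and is multiplied by
    scaling**p at each step until it falls below the target blur**p, which is
    always used as the final temperature.
    """
    if not 0.0 < scaling < 1.0:
        raise ValueError("scaling must lie in (0, 1)")
    eps, target, schedule = diameter ** p, blur ** p, []
    while eps > target:
        schedule.append(eps)
        eps *= scaling ** p
    return schedule + [target]


def softmin(eps, cost_row, log_w, pot):
    """Stable -eps * log(sum_j w_j * exp((pot_j - cost_row[j]) / eps))."""
    vals = [lw + (u - c) / eps for lw, u, c in zip(log_w, pot, cost_row)]
    m = max(vals)
    return -eps * (m + math.log(sum(math.exp(v - m) for v in vals)))


def ot_eps(x, y, schedule, iters=100):
    """Entropic OT between uniform 1D point clouds x and y, cost |x - y|**2 / 2.

    Log-domain Sinkhorn: the dual potentials f, g are warm-started from one
    temperature to the next, with a fixed number of updates per temperature.
    """
    log_a = [-math.log(len(x))] * len(x)
    log_b = [-math.log(len(y))] * len(y)
    C = [[0.5 * (xi - yj) ** 2 for yj in y] for xi in x]
    C_t = [list(col) for col in zip(*C)]  # transposed cost matrix
    f, g = [0.0] * len(x), [0.0] * len(y)
    for eps in schedule:
        for _ in range(iters):
            f = [softmin(eps, row, log_b, g) for row in C]
            g = [softmin(eps, col, log_a, f) for col in C_t]
    # Once the marginal constraints hold, the dual value is <a, f> + <b, g>.
    return sum(f) / len(x) + sum(g) / len(y)


def sinkhorn_divergence(x, y, blur=0.1, scaling=0.5, iters=100):
    """De-biased divergence S = OT(a, b) - OT(a, a) / 2 - OT(b, b) / 2."""
    diameter = max(x + y) - min(x + y)
    schedule = eps_schedule(diameter, blur, scaling)
    return (
        ot_eps(x, y, schedule, iters)
        - 0.5 * ot_eps(x, x, schedule, iters)
        - 0.5 * ot_eps(y, y, schedule, iters)
    )


x = [0.0, 0.5, 1.0]
y = [xi + 1.0 for xi in x]  # the same cloud, shifted by 1
print(sinkhorn_divergence(x, y))  # close to 0.5 = |shift|**2 / 2
for s in (0.5, 0.7, 0.9):
    print(s, len(eps_schedule(2.0, 0.05, s)))  # more steps as scaling -> 1
```

For a pure translation, the quadratic cost only gains a term that is separable in x and y, so the converged divergence equals half the squared shift; once converged, S stays non-negative. With diameter 2.0 and blur 0.05, scaling = 0.5 gives a 7-step schedule, while 0.7 and 0.9 give noticeably longer ones.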
Nevertheless, since your example is one of the first I have encountered where changing the scaling parameter actually has a meaningful impact, I played around with it for a bit. Surprise: when we display your data in 3D, we observe a few outliers lying far away from the unit sphere.
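If we remove these extreme points with code along the lines of:
```python
# x0, x1: the two (N, 3) point clouds from your report,
# sinkhorn: the SamplesLoss layer used there (default scaling=0.5).
x2 = x1.clone()
outliers = (x1 ** 2).sum(-1) > 1.5  # points far away from the unit sphere
print(x2[outliers, :])              # display the extreme points
x2[outliers, :] = 0                 # move them to the origin
print(sinkhorn(x0, x2).item())
```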
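Note that the snippet does not literally delete the outliers: it moves them to the origin, so the number of samples (and hence, with uniform weights, the weight 1/N carried by each point) is unchanged. If you would rather drop them altogether, a minimal pure-Python sketch follows; the helper split_outliers is hypothetical and simply reuses the same squared-norm threshold of 1.5.

```python
def split_outliers(points, threshold=1.5):
    """Hypothetical helper: split points into (inliers, outliers).

    A point p is flagged as an outlier when its squared Euclidean norm
    sum(c**2 for c in p) exceeds `threshold`.
    """
    inliers, outliers = [], []
    for p in points:
        target = outliers if sum(c * c for c in p) > threshold else inliers
        target.append(p)
    return inliers, outliers


cloud = [(1.0, 0.0, 0.0), (0.0, 0.6, 0.8), (3.0, 0.0, 0.0)]
kept, dropped = split_outliers(cloud)
print(dropped)  # [(3.0, 0.0, 0.0)]
```

With torch tensors, the inliers can be kept directly through boolean indexing, e.g. x1[(x1 ** 2).sum(-1) <= 1.5].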
We immediately retrieve a positive value, even when scaling = 0.5. As far as I can tell, it thus looks like your example is a perfect worst-case scenario for the multiscale Sinkhorn loop implemented by GeomLoss. This is in line with the theoretical understanding of the algorithm that we're trying to develop with @bernhard-schmitzer: your configuration is pretty close to the "lone wolf" scenario discussed in Example 1, page 17 and Remark 9, page 21 of this important reference paper.
I'm pretty happy to see these theoretical considerations come to life in a real use-case :-)
If you have any other question, feel free to ask; otherwise, I'll let you close this issue.
Best regards,
Jean
P.S.: Out of curiosity, may I ask why you're interested in OT theory? You seem to work mostly on high-energy physics, and I'd be delighted to know how our codes may be of any use to you!
Hi Jean,
Thanks a lot for the fast and detailed answer! I'll have a closer look at the references you gave, but this makes sense to me.
I didn't arrive at this worst-case scenario randomly, but by training a generative model to minimize the Sinkhorn divergence between generated samples and some training data. I had some anomalous results, and I now hope that changing the scaling parameter will stabilize my training.
Thanks again,
Johann
PS: Indeed, my background is in high-energy physics, but I am now interested in probabilistic and generative ML models... hence OT. But I don't know much about the theory at all.
Related Issues (20)
- ValueError: Maximum allowed size exceeded in degenerate case of Sinkhorn loss
- name 'generic_logsumexp' is not defined
- ValueError: not enough values to unpack (expected 3, got 2)
- Best way to use scikit-learn distance functions for cost
- Error while running transfer_labels.py
- Custom cost function replicating p=2 doesn't match inbuilt?
- Sinkhorn loss always renables gradient tracking
- Dual and primal loss don't align for small blur values
- Support for half/single floating point numbers
- Optionally return transport plans for the Sinkhorn loss
- Wasserstein distance for p not in {1,2}
- sinkhorn_divergence for 1D images not workin
- Hausdorff Distance
- Has ImagesLoss ever been finished? Or is it still a WIP?
- CUDA_ERROR_INVALID_SOURCE error when running geomloss on some GPUs
- Can this library be used with torch.amp?
- generic_logsumexp with larger point clouds
- Installing geomloss fails if torch is being installed at the same time
- Error when using the hausdorff distance
- Very different results for Wasserstein distance compared to Gudhi