Comments (3)
Hi @Yura52 , @NightWinkle ,
Thanks a lot for your interest in the library and relevant questions!
As described by @NightWinkle and in Eqs. (3.226-3.227) of my PhD thesis, the current implementation allows you to backpropagate efficiently through the computation of the Sinkhorn loss: if you're interested in optimizing a (smooth) Wasserstein distance with respect to the weights alpha, beta and the sample locations x, y, you're good to go. This seems to be what you have in mind, which is good news.
Note, however, that you may encounter problems if you plan to do something a bit more exotic. For instance, some authors use the Sinkhorn algorithm to compute an optimal transport plan (as in e.g. this tutorial), and then backprop through a loss that is not a regularized Wasserstein distance. In this situation, the simplifications that I hardcoded into GeomLoss no longer hold: to retrieve the correct gradients, you should indeed backprop through the iterations of the Sinkhorn loop. In other words, comment out the torch.autograd.set_grad_enabled(False) line, or add some extra iterations at the end of the loop, in the final "extrapolation" step.
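For intuition, here is a hedged sketch of that trade-off in plain PyTorch — it is not GeomLoss's actual implementation: a log-domain Sinkhorn loop where a hypothetical detach_loop flag plays the role of the set_grad_enabled(False) line, followed by one differentiable update that mimics the final "extrapolation" step:

```python
import contextlib
import torch

def sinkhorn_dual_loss(x, y, log_a, log_b, blur=1.0, n_iter=300, detach_loop=True):
    """Entropic OT dual objective; illustrative only, not the GeomLoss code."""
    eps = blur ** 2                                      # entropic regularization
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)   # (n, m) squared distances
    f, g = x.new_zeros(x.shape[0]), y.new_zeros(y.shape[0])
    # Fixed-point loop: optionally run it without building the autograd graph,
    # as GeomLoss does with torch.autograd.set_grad_enabled(False).
    ctx = torch.no_grad() if detach_loop else contextlib.nullcontext()
    with ctx:
        for _ in range(n_iter):
            f = -eps * torch.logsumexp(log_b[None, :] + (g[None, :] - C) / eps, dim=1)
            g = -eps * torch.logsumexp(log_a[:, None] + (f[:, None] - C) / eps, dim=0)
    # One last differentiable update (the "extrapolation" step): at convergence,
    # this already yields the correct gradients w.r.t. weights and positions.
    f = -eps * torch.logsumexp(log_b[None, :] + (g[None, :] - C) / eps, dim=1)
    g = -eps * torch.logsumexp(log_a[:, None] + (f[:, None] - C) / eps, dim=0)
    return (log_a.exp() * f).sum() + (log_b.exp() * g).sum()
```

Setting detach_loop=False backprops through every iteration — the correct (but slower) option for the exotic use cases above, where the downstream loss is not the regularized Wasserstein distance itself.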
All these points are discussed in recent works by Pierre Ablin, such as this paper. Going forward, I will certainly add a "switch" for this behaviour as an optional argument. Right now, I am mostly working on improving the low-level KeOps routines of GeomLoss and finalizing theoretical papers, but I will really push for a stable v1.0 release over the next few months.
I hope that this answers your question: feel free to re-open the issue if needed :-)
Best regards,
Jean
from geomloss.
It can differentiate through the solution of the Optimal Transport problem.
As you can read in the literature, for instance in Interpolating between Optimal Transport and MMD using Sinkhorn Divergences, the gradient of the Sinkhorn loss is actually equal to the gradient of a single Sinkhorn iteration.
For this reason, it is more efficient to compute the gradient using only one of these Sinkhorn iterations.
The line you are mentioning simply disables the construction of the autodifferentiation graph through the steps that are not needed to compute the gradient, which would otherwise make the backward pass quite slow.
As you can see, though, gradients are reactivated before the last iteration, allowing the gradient to be computed.
Oh, I see, thank you for the answer!