
Comments (3)

jeanfeydy commented on July 17, 2024

Hi @Yura52 , @NightWinkle ,

Thanks a lot for your interest in the library and relevant questions!

As described by @NightWinkle and in Eqs. (3.226-3.227) of my PhD thesis, the current implementation allows you to backpropagate efficiently through the computation of the Sinkhorn loss: if you're interested in optimizing a (smooth) Wasserstein distance with respect to the weights alpha, beta and the sample locations x, y, you're good to go. This seems to be what you have in mind, which is good news.

Note, however, that you may run into problems if you plan to do something a bit more exotic. For instance, some authors use the Sinkhorn algorithm to compute an optimal transport plan (as in e.g. this tutorial), and then backprop through a loss that is not a regularized Wasserstein distance. In this situation, the simplifications that I hardcoded into GeomLoss no longer hold: to retrieve the correct gradients, you should indeed backprop through the iterations of the Sinkhorn loop. In other words, comment out the `torch.autograd.set_grad_enabled(False)` line, or add some extra differentiable iterations at the end of the loop, in the final "extrapolation" step.
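To make the structure concrete, here is a minimal NumPy sketch of such a Sinkhorn loop (illustrative only — the function name and details are mine, not GeomLoss's internals): the main loop corresponds to the part that GeomLoss runs with gradients disabled, and the final "extrapolation" step to the only part that needs to be differentiated for the standard Sinkhorn loss.

```python
import numpy as np

def sinkhorn_plan(a, b, C, eps=0.1, n_iters=200):
    """Illustrative Sinkhorn loop (NumPy stand-in for the PyTorch version).

    In an autograd framework, the main loop below would run under
    `torch.autograd.set_grad_enabled(False)`, and only the final
    "extrapolation" step would be tracked by autograd.  To backprop
    through a non-standard loss of the plan P, you would instead keep
    gradients enabled for the iterations themselves.
    """
    K = np.exp(-C / eps)               # Gibbs kernel
    u = np.ones_like(a)
    # --- main loop: run without gradients for the standard Sinkhorn loss ---
    for _ in range(n_iters):
        v = b / (K.T @ u)
        u = a / (K @ v)
    # --- final "extrapolation" step: the only part that autograd
    #     needs to see when differentiating the Sinkhorn loss itself ---
    v = b / (K.T @ u)
    u = a / (K @ v)
    P = u[:, None] * K * v[None, :]    # transport plan
    return P
```

At convergence, the returned plan has the prescribed marginals a and b, which is a quick sanity check for the loop.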

All these points are discussed in recent works by Pierre Ablin, such as this paper: going forward, I will certainly add a "switch" for this behaviour as an optional argument. Right now, I am mostly working on improving the low-level KeOps routines of GeomLoss and finalizing theoretical papers, but I will really push for a stable v1.0 release over the next few months.

I hope that this answers your question: feel free to re-open the issue if needed :-)
Best regards,
Jean


NightWinkle commented on July 17, 2024

It can differentiate through the solution of the Optimal Transport problem.

As you can read in the literature, for instance in Interpolating between Optimal Transport and MMD using Sinkhorn Divergences, the gradient of the Sinkhorn loss is actually equal to the gradient of a single Sinkhorn iteration, evaluated at the converged potentials.
For this reason, it is more efficient to compute the gradient by backpropagating through only that last Sinkhorn iteration.

The line you mention simply disables the construction of the autodifferentiation graph through the steps that are not needed to compute the gradient, and that would otherwise make the backward pass quite slow.

Note, though, that gradients are re-enabled before the last iteration, so that the gradient can still be computed.
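The envelope-theorem argument behind this can be checked numerically. The sketch below (plain NumPy, not GeomLoss code; it uses the formulation min_P ⟨C, P⟩ − ε H(P)) verifies that the gradient of the converged Sinkhorn value with respect to the cost matrix is simply the transport plan itself — so no backprop through the fixed-point iterations is needed:

```python
import numpy as np

def sinkhorn_value(a, b, C, eps=0.5, n_iters=2000):
    """Entropic OT value  min_P <C, P> - eps * H(P)  via Sinkhorn.
    Illustrative NumPy sketch, not GeomLoss code."""
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    value = (P * C).sum() + eps * (P * np.log(P)).sum()
    return value, P

rng = np.random.default_rng(0)
a = np.full(3, 1 / 3)
b = np.full(3, 1 / 3)
C = rng.random((3, 3))

_, P = sinkhorn_value(a, b, C)

# Envelope theorem: d(value)/dC_ij = P_ij, even though P itself
# depends on C -- differentiating through the iterations changes nothing.
h = 1e-5
E = np.zeros_like(C)
E[1, 2] = 1.0
fd = (sinkhorn_value(a, b, C + h * E)[0]
      - sinkhorn_value(a, b, C - h * E)[0]) / (2 * h)
# fd and P[1, 2] agree up to finite-difference error
```

The same reasoning carries over to gradients with respect to the weights and sample positions, which is why one differentiable iteration at the end of the loop suffices.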


Yura52 commented on July 17, 2024

Oh, I see, thank you for the answer!

