
Does ot.emd2 support PyTorch autograd backward propagation? #776

@WANG-SUI

Description

Hi, and thanks for maintaining this great library!

I'm currently using POT (with the PyTorch backend) to compute OT-based losses.
I had assumed that ot.emd2 is not differentiable in the usual PyTorch sense, since it solves an exact linear program,
but when I inspect the .grad field of the cost matrix M after calling loss.backward(), I still get non-zero gradients.

So I would like to confirm:
Does ot.emd2 actually support autograd backward propagation through M and the input distributions a and b?
Or are the observed gradients just residual tensors from detached operations (i.e., not truly backpropagated through the OT plan solver)?
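As a first probe (using the same a, b, M as in the example below), I checked whether the returned loss is attached to the autograd graph at all:

# A non-None grad_fn would mean emd2's output is connected to the graph
# through some backward, rather than being a detached tensor.
loss = ot.emd2(a, b, M)
print(loss.requires_grad, loss.grad_fn)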

Minimal reproducible example

import torch
import ot

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Two random histograms normalized onto the simplex
n = 10
a = torch.rand(n, device=device)
a = a / a.sum()
b = torch.rand(n, device=device)
b = b / b.sum()

# Cost matrix with gradient tracking enabled
M = torch.randn(n, n, device=device, requires_grad=True)

# Exact OT loss (linear program)
loss_emd2 = ot.emd2(a, b, M)
loss_emd2.backward()
grad_emd2 = M.grad.clone()
print("EMD2 loss:", loss_emd2.item())
print("EMD2 grad:\n", grad_emd2)

# Entropically regularized OT loss for comparison
M.grad.zero_()
reg = 0.1
loss_sinkhorn = ot.sinkhorn2(a, b, M, reg)
loss_sinkhorn.backward()
grad_sinkhorn = M.grad.clone()
print("\nSinkhorn loss:", loss_sinkhorn.item())
print("Sinkhorn grad:\n", grad_sinkhorn)

OUTPUT:
EMD2 loss: -1.5605539083480835
EMD2 grad:
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0427],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0095, 0.1873, 0.0095,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0252, 0.0000, 0.0233, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0925, 0.0000, 0.0000, 0.0000, 0.0000, 0.0061,
0.0987],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0481, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0552, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0126, 0.0000, 0.0000, 0.0000, 0.0000, 0.0532, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0765, 0.0000, 0.0000, 0.0548, 0.0000, 0.0000, 0.0000,
0.0000],
[0.1010, 0.0000, 0.0175, 0.0000, 0.0020, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0844, 0.0000, 0.0000, 0.0000,
0.0000]], device='cuda:0')

Sinkhorn loss: -1.5583062171936035
Sinkhorn grad:
tensor([[-1.0759e-15, -4.5089e-05, -2.3012e-06, -4.7454e-05, -3.1261e-15,
-8.1661e-06, -8.5109e-06, -2.6828e-12, -5.4352e-05, 4.2903e-02],
[-1.3629e-10, -5.7191e-13, -5.8734e-08, -1.4337e-03, -2.9875e-06,
-5.2210e-04, 1.2639e-02, 1.8735e-01, 8.2132e-03, -2.8584e-06],
[-1.3665e-12, -5.0650e-08, -4.7080e-04, -9.4670e-06, 2.7954e-02,
-4.7037e-09, 2.1273e-02, -9.7631e-12, -2.6613e-04, -1.6820e-12],
[-7.8764e-13, -1.0603e-12, -1.3319e-03, 9.4009e-02, -1.1775e-03,
-4.5318e-04, -6.8939e-07, -1.8201e-14, 7.6663e-03, 9.8547e-02],
[-6.4536e-11, -5.2054e-08, -3.9438e-09, -5.9460e-09, -1.3147e-07,
-1.8057e-07, 4.8088e-02, -3.1984e-11, -1.9308e-09, -2.1640e-10],
[-4.6134e-17, -7.1095e-07, 5.5260e-02, -3.3341e-17, -4.1714e-20,
-5.8153e-07, -2.3840e-05, -2.2943e-13, -7.5624e-06, -8.0583e-10],
[-1.0100e-13, 1.3838e-02, -1.4756e-08, -1.8981e-10, -2.9864e-12,
-8.6449e-05, 5.2120e-02, -1.4227e-16, -3.6979e-08, -2.9439e-11],
[-2.2352e-07, -1.1892e-05, 7.6544e-02, -2.5317e-08, -1.2087e-06,
5.4697e-02, -9.0246e-08, -8.5752e-18, -3.6054e-07, -1.7776e-11],
[ 1.0100e-01, -1.0979e-05, 1.9160e-02, -6.2655e-06, 4.4209e-04,
-1.1139e-07, -2.8346e-07, -3.8414e-05, -3.5571e-07, -3.9045e-05],
[-1.1940e-24, -1.1416e-03, -6.1396e-10, -7.2483e-15, -9.0970e-12,
8.5502e-02, -9.0134e-09, -6.0799e-18, -1.8524e-11, -1.1368e-18]],
device='cuda:0')
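What strikes me is that the EMD2 gradient above is sparse, with the same support pattern as an optimal transport plan. If the envelope theorem for linear programs applies here, the gradient of the optimal value with respect to M would be the optimal plan G* itself. Here is the sanity check I ran on the same a, b, M (my own sketch, not something from the POT docs):

# Sanity check (my assumption): if d(emd2)/dM equals the optimal plan G*,
# then M.grad should match ot.emd(a, b, M) entrywise.
with torch.no_grad():
    G = ot.emd(a, b, M)  # optimal transport plan from the exact solver
print("max |M.grad - G|:", (grad_emd2 - G).abs().max().item())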

Please clarify whether:

ot.emd2 is intentionally non-differentiable (since it solves a linear program);
there is a plan to add a differentiable variant (e.g., a differentiable EMD via the implicit function theorem or an entropic relaxation);
or the observed gradients are accidental numerical artifacts (see the finite-difference sketch below).
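For the third point, this is the finite-difference probe I would use to tell a true (sub)gradient from an artifact (a minimal sketch; eps and the probed entry [1, 7], which is non-zero in the output above, are arbitrary choices of mine):

# Finite-difference probe of d(emd2)/dM[i, j]
def fd_check(a, b, M, i, j, eps=1e-4):
    with torch.no_grad():
        Mp = M.clone(); Mp[i, j] += eps
        Mm = M.clone(); Mm[i, j] -= eps
        return (ot.emd2(a, b, Mp) - ot.emd2(a, b, Mm)) / (2 * eps)

i, j = 1, 7  # entry with gradient 0.1873 in the EMD2 output above
print("finite diff:", fd_check(a, b, M, i, j).item())
print("autograd   :", grad_emd2[i, j].item())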
Thanks a lot for your time and for maintaining this library!
