bundles / scipy 1.17.1 / scipy / stats / _stats_py / wasserstein_distance

function

scipy.stats._stats_py:wasserstein_distance

source: /scipy/stats/_stats_py.py :9594

Signature

def wasserstein_distance(u_values, v_values, u_weights=None, v_weights=None)

Summary

Compute the Wasserstein-1 distance between two 1D discrete distributions.

Extended Summary

The Wasserstein distance, also called the Earth mover's distance or the optimal transport distance, is a measure of dissimilarity between two probability distributions [1]: it is zero when the distributions coincide and grows as they move apart. In the discrete case, the Wasserstein distance can be understood as the cost of an optimal transport plan to convert one distribution into the other. The cost of each move is the product of the amount of probability mass being moved and the distance it is moved, and the cost of a plan is the sum over all of its moves. A brief and intuitive introduction can be found at [2].
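As a rough illustration of the transport-cost description above, the following sketch builds one transport plan by hand for two equally weighted samples of the same size and compares its cost with the function's result. It relies on the fact that, in one dimension and under these conditions, pairing the sorted values gives an optimal plan. The names `mass` and `plan_cost` are illustrative and not part of SciPy.

```python
import numpy as np
from scipy.stats import wasserstein_distance

u = np.array([0.0, 1.0, 3.0])
v = np.array([5.0, 6.0, 8.0])

# Each observation carries probability mass 1/3.
mass = np.full(u.size, 1.0 / u.size)

# Plan: move the i-th smallest value of u onto the i-th smallest value of v.
# Cost of the plan = sum over moves of (mass moved) * (distance moved).
plan_cost = np.sum(mass * np.abs(np.sort(u) - np.sort(v)))

print(plan_cost, wasserstein_distance(u, v))
```

Here every unit of mass travels a distance of 5, so the plan's cost and the reported distance agree.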

Parameters

u_values : 1d array_like

A sample from a probability distribution or the support (set of all possible values) of a probability distribution. Each element is an observation or possible value.

v_values : 1d array_like

A sample from or the support of a second distribution.

u_weights, v_weights : 1d array_like, optional

Weights or counts corresponding to the sample values, or probability masses corresponding to the support values. The sum of the elements of each weight array must be positive and finite. If unspecified, each value is assigned the same weight.
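The short sketch below is meant to show how the weight arguments can be read. It assumes, consistent with the description above, that only the relative sizes of the weights matter, so counts, repeated raw observations, and probability masses describing the same distribution should give the same distance. The variable names are illustrative.

```python
from scipy.stats import wasserstein_distance

# Counts as weights: value 0 observed 3 times and value 1 once in u,
# versus 2 and 2 times in v.
d_counts = wasserstein_distance([0, 1], [0, 1], [3, 1], [2, 2])

# The same data written out as raw, equally weighted observations.
d_samples = wasserstein_distance([0, 0, 0, 1], [0, 0, 1, 1])

# The same distributions expressed as probability masses on the support.
d_probs = wasserstein_distance([0, 1], [0, 1], [0.75, 0.25], [0.5, 0.5])

print(d_counts, d_samples, d_probs)
```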

Returns

distance : float

The computed distance between the distributions.

Notes

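Given two 1D probability mass functions, $u$ and $v$, the first Wasserstein distance between the distributions is:

$$l_1(u, v) = \inf_{\pi \in \Gamma(u, v)} \int_{\mathbb{R} \times \mathbb{R}} |x - y| \, \mathrm{d}\pi(x, y)$$

where $\Gamma(u, v)$ is the set of (probability) distributions on $\mathbb{R} \times \mathbb{R}$ whose marginals are $u$ and $v$ on the first and second factors respectively. For a given value $x$, $u(x)$ gives the probability of $u$ at position $x$, and the same for $v(x)$.

If $U$ and $V$ are the respective CDFs of $u$ and $v$, this distance also equals:

$$l_1(u, v) = \int_{-\infty}^{+\infty} |U - V|$$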

See [3] for a proof of the equivalence of both definitions.
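The CDF form of the distance can be evaluated directly for discrete distributions, because both CDFs are step functions that are constant between consecutive points of the pooled support. The following is a minimal NumPy sketch of that calculation, not SciPy's own implementation. The helper names `wasserstein_1d_via_cdfs` and `cdf_at` are hypothetical.

```python
import numpy as np
from scipy.stats import wasserstein_distance


def wasserstein_1d_via_cdfs(u_values, v_values, u_weights=None, v_weights=None):
    """Integrate |U - V| over the real line for two discrete 1D distributions."""
    u_values = np.asarray(u_values, dtype=float)
    v_values = np.asarray(v_values, dtype=float)
    if u_weights is None:
        u_weights = np.ones_like(u_values)
    if v_weights is None:
        v_weights = np.ones_like(v_values)
    u_weights = np.asarray(u_weights, dtype=float)
    v_weights = np.asarray(v_weights, dtype=float)

    def cdf_at(points, values, weights):
        # Step-function CDF: normalized total weight of all values <= point.
        order = np.argsort(values)
        cumulative = np.concatenate(([0.0], np.cumsum(weights[order])))
        cumulative /= weights.sum()
        return cumulative[np.searchsorted(values[order], points, side="right")]

    # Pool and sort the supports; both CDFs are constant on each interval.
    grid = np.sort(np.concatenate((u_values, v_values)))
    widths = np.diff(grid)
    gap = np.abs(cdf_at(grid[:-1], u_values, u_weights)
                 - cdf_at(grid[:-1], v_values, v_weights))
    return float(np.sum(gap * widths))


args = ([3.4, 3.9, 7.5, 7.8], [4.5, 1.4], [1.4, 0.9, 3.1, 7.2], [3.2, 3.5])
manual = wasserstein_1d_via_cdfs(*args)
reference = wasserstein_distance(*args)
print(manual, reference)
```

The two printed values should agree up to floating-point rounding.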

The input distributions can be empirical, in which case the sample values are passed directly to the function, or they can be viewed as generalized functions, in which case they are weighted sums of Dirac delta functions located at the specified values.
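One way to see the Dirac delta interpretation is to compare a single point mass with an empirical sample. All mass of the sample has to be moved to the one location, so the distance reduces to the mean absolute deviation of the sample from that point. The sketch below checks this numerically. The location `a`, the seed, and the sample size are arbitrary choices for illustration.

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
sample = rng.normal(size=1000)  # empirical distribution, equal weights
a = 0.5                         # location of a single Dirac delta

# Distance between the point mass at `a` and the empirical distribution.
d = wasserstein_distance([a], sample)

# Every observation must be transported to `a`, so the cost is mean |x - a|.
mean_abs_dev = np.mean(np.abs(sample - a))

print(d, mean_abs_dev)
```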

Array API Standard Support

wasserstein_distance has experimental support for Python Array API Standard compatible backends in addition to NumPy. Please consider testing these features by setting an environment variable SCIPY_ARRAY_API=1 and providing CuPy, PyTorch, JAX, or Dask arrays as array arguments. The following combinations of backend and device (or other capability) are supported.

====================  ====================  ====================
Library               CPU                   GPU
====================  ====================  ====================
NumPy                 ✅                     n/a                 
CuPy                  n/a                   ⛔                   
PyTorch               ⛔                     ⛔                   
JAX                   ⛔                     ⛔                   
Dask                  ⛔                     n/a                 
====================  ====================  ====================

See dev-arrayapi for more information.

Examples

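from scipy.stats import wasserstein_distance
# Two equally weighted samples; v is u shifted by 5, so the distance is 5.0
wasserstein_distance([0, 1, 3], [5, 6, 8])
# Same support {0, 1} with weights (3, 1) and (2, 2); the distance is 0.25
wasserstein_distance([0, 1], [0, 1], [3, 1], [2, 2])
# Different supports with unnormalized weights; the distance is about 4.0781
wasserstein_distance([3.4, 3.9, 7.5, 7.8], [4.5, 1.4],
                     [1.4, 0.9, 3.1, 7.2], [3.2, 3.5])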

See also

wasserstein_distance_nd

Compute the Wasserstein-1 distance between two N-D discrete distributions.

Aliases

  • scipy.stats.wasserstein_distance
