diff --git a/Project.toml b/Project.toml index 42f0260..53103c1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PythonOT" uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef" authors = ["David Widmann"] -version = "0.1.3" +version = "0.1.4" [deps] PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" diff --git a/docs/src/api.md b/docs/src/api.md index feaa2d7..2730efe 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -14,6 +14,7 @@ emd2_1d ```@docs sinkhorn sinkhorn2 +empirical_sinkhorn_divergence barycenter ``` diff --git a/src/PythonOT.jl b/src/PythonOT.jl index 9b3c534..1f67513 100644 --- a/src/PythonOT.jl +++ b/src/PythonOT.jl @@ -11,7 +11,8 @@ export emd, barycenter, barycenter_unbalanced, sinkhorn_unbalanced, - sinkhorn_unbalanced2 + sinkhorn_unbalanced2, + empirical_sinkhorn_divergence const pot = PyCall.PyNULL() diff --git a/src/lib.jl b/src/lib.jl index 219820f..6e50f7e 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -244,6 +244,43 @@ function sinkhorn2(μ, ν, C, ε; kwargs...) return pot.sinkhorn2(μ, ν, PyCall.PyReverseDims(permutedims(C)), ε; kwargs...) end +""" + empirical_sinkhorn_divergence(xsource, xtarget, ε; kwargs...) + +Compute the Sinkhorn divergence from empirical data, where `xsource` and `xtarget` are +arrays representing samples in the source domain and target domain, respectively, and `ε` +is the regularization term. + +This function is a wrapper of the function +[`ot.bregman.empirical_sinkhorn_divergence`](https://pythonot.github.io/gen_modules/ot.bregman.html#ot.bregman.empirical_sinkhorn_divergence) +in the Python Optimal Transport package. Keyword arguments are listed in the documentation of the Python function. + +# Examples + +```jldoctest +julia> xsource = [1]; + +julia> xtarget = [2, 3]; + +julia> ε = 0.01; + +julia> empirical_sinkhorn_divergence(xsource, xtarget, ε) ≈ + sinkhorn2([1], [0.5, 0.5], [1 4], ε) - + ( + sinkhorn2([1], [1], zeros(1, 1), ε) + + sinkhorn2([0.5, 0.5], [0.5, 0.5], [0 1; 1 0], ε) + ) / 2 +true +``` + +See also: [`sinkhorn2`](@ref) +""" +function empirical_sinkhorn_divergence(xsource, xtarget, ε; kwargs...) + return pot.bregman.empirical_sinkhorn_divergence( + reshape(xsource, Val(2)), reshape(xtarget, Val(2)), ε; kwargs... + ) +end + """ sinkhorn_unbalanced(μ, ν, C, ε, λ; kwargs...)