
sumlog #48 (Open)

wants to merge 40 commits into master

Conversation

@cscherrer commented May 2, 2022

This PR adds sumlog, a more efficient way to compute sum(log, x). There's more discussion of this on Discourse:
https://discourse.julialang.org/t/sum-of-logs/80370
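
(The underlying idea, for readers of the thread: sum(log, x) equals log(prod(x)); splitting each element into significand and power-of-two exponent lets us accumulate the significand product and the exponent sum cheaply and take a single log at the end. A minimal illustrative sketch of that idea, with names and details that are not the PR's implementation:)

# Illustrative sketch only, assuming positive finite inputs.
function sumlog_sketch(x::AbstractVector{Float64})
    sig = 1.0   # running product of significands, kept in a safe range
    ex = 0      # running sum of base-2 exponents
    for xi in x
        s, e = frexp(xi)      # xi == s * 2^e with s in [0.5, 1)
        sig *= s
        ex += e
        if sig < 0.5          # renormalize so sig stays in [0.5, 1)
            sig *= 2
            ex -= 1
        end
    end
    return log(sig) + ex * log(2)   # the only call to log
end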

EDIT:
I think we have a good enough understanding of what's possible to lay out some design criteria. That ought to be more productive than debating each line of code in isolation.

As a starting point, I suggest

  1. Whenever sum(log ∘ f, x) is defined, sumlog(f, x) should give the same result (within some tolerance, etc)
  2. sumlog(x) == sumlog(identity, x)
  3. sumlog(f, x) should support a dims keyword argument whenever sum(log ∘ f, x) does (i.e., when x is an AbstractArray)
  4. sumlog should be type-stable and compiler-friendly when possible
  5. sumlog should use the optimized method requiring a single log application, whenever that's possible.

@devmotion @mcabbott thoughts on these?

@devmotion (Member) left a comment

Nice, thanks!

I think there are some problems with the current implementation but I made some suggestions that hopefully can fix most of them.

Can you also update the version number and add it to the docs?

src/sumlog.jl: four outdated review comments (resolved)
src/sumlog.jl Outdated

Since `log(2)` is constant, `sumlog` only requires a single `log` evaluation.
"""
function sumlog(x::AbstractArray{T}) where {T}
Member:

It seems this will fail for non-floating-point types T such as Int or BigFloat, and for complex numbers?

src/sumlog.jl, test/sumlog.jl: further outdated review comments (resolved)
cscherrer and others added 7 commits May 2, 2022 10:40
Co-authored-by: David Widmann <[email protected]>
@cscherrer (Author)

@devmotion I guess we also need some ChainRules methods...

@devmotion (Member)

> @devmotion I guess we also need some ChainRules methods...

We could, but it's not necessary to do in this PR IMO.

Apart from the element type (see discussion above), I think the main remaining problem is that the code is likely problematic for GPU arrays. Other array implementations in LogExpFunctions are written in a GPU-friendly way and should work with all array types.

@cscherrer (Author)

> I think the main remaining problem is that the code is likely problematic for GPU arrays. Other array implementations in LogExpFunctions are written in a GPU-friendly way and should work with all array types.

Do you see a nice way of doing this?

I see two other potential things to add:

  1. Support for Tuples and NamedTuples
  2. Support for calling sumlog(f, x)

Both are less critical, so we can come back to them later.

@cscherrer (Author) commented May 2, 2022

One of these updates along the way killed all of the performance. Are you seeing this too? I need to backtrack a bit, and maybe split the preprocessing into a separate function.

Got it

@devmotion (Member) left a comment

Maybe we can use mapreduce to support more general and in particular GPU arrays (similar to how we use reduce in logsumexp).

The function has to be added to the docs for the tests to pass.

src/sumlog.jl: four outdated review comments (resolved)
@cscherrer (Author)

It got faster! Check it out:

julia> x = rand(1000);

julia> @btime sum(log, $x)
  6.362 μs (0 allocations: 0 bytes)
-1027.6

julia> @btime sumlog($x)
  986.857 ns (0 allocations: 0 bytes)
-1027.6

@cscherrer (Author) commented May 3, 2022

> The function has to be added to the docs for the tests to pass.

I don't understand what you mean by this. Docstrings are usually added automatically; what's left to do?

@cscherrer (Author)

I changed it to

function sumlog(x)
    T = float(eltype(x))
    _sumlog(T, values(x))
end

There's no need to restrict the type of x, and this allows it to be a Tuple or NamedTuple. For NamedTuples, calling values(x) makes it much faster, and this doesn't affect other types.
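
(A quick illustration of the intended behavior. Hypothetical usage only; it assumes the _sumlog helper from the rest of the PR, which is not shown in this thread:)

sumlog([1.0, 2.0, 3.0])      # ≈ log(6.0), same as sum(log, [1.0, 2.0, 3.0])
sumlog((1.0, 2.0, 3.0))      # Tuples work, since eltype is defined for them
sumlog((a = 1.0, b = 2.0))   # NamedTuple: values(x) strips the names first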

@devmotion (Member)

> What's left to do?

You have to add the function to docs/src/index.md.

test/sumlog.jl Outdated
@@ -0,0 +1,7 @@
@testset "sumlog" begin
for T in [Int, Float16, Float32, Float64, BigFloat]
Member:

I noticed that you removed the type restriction. We should therefore extend the tests and e.g. check more general iterables (also with different types, abstract eltype, etc., since sum(log, x) would work for them) and also complex numbers.

Author:

eltype doesn't work well for Base.Generators. Usually this is when I'd turn to something like

julia> Core.Compiler.return_type(gen.f, Tuple{eltype(gen.iter)})
Float64

We could instead have it fall back on the default, but I'd guess that will sacrifice performance.

Reviewer:

I bet you could write an equally fast version which explicitly calls iterate, and widens if the type changes. (But usually the compiler will prove that it won't.)

One reason to keep mapreduce for arrays is that you can give it dims.

Reviewer:

Should check more carefully, but this appears to work and is as fast as the current version:

function sumlog(x)
    iter = iterate(x)
    if isnothing(iter)
        return eltype(x) <: Number ? zero(float(eltype(x))) : 0.0
    end
    x1 = float(iter[1])
    x1 isa AbstractFloat || return sum(log, x)
    sig, ex = significand(x1), exponent(x1)
    iter = iterate(x, iter[2])
    while iter !== nothing
        xj = float(iter[1])
        xj isa AbstractFloat || return sum(log, x)  # maybe not ideal, re-starts the iterator
        sig, ex = _sumlog_op((sig, ex), (significand(xj), exponent(xj)))
        iter = iterate(x, iter[2])
    end
    return log(sig) + IrrationalConstants.logtwo * ex
end

sumlog(f, x) = sumlog(Iterators.map(f, x))
sumlog(f, x, ys...) = sumlog(f(xy...) for xy in zip(x, ys...))

And for dims:

sumlog(x::AbstractArray{T}; dims=:) where T = _sumlog(float(T), dims, x)

function _sumlog(::Type{T}, ::Colon, x) where {T<:AbstractFloat}
    sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj
        float_xj = float(xj)
        significand(float_xj), exponent(float_xj) 
    end
    return log(sig) + IrrationalConstants.logtwo * ex
end

function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat}
    sig_ex = mapreduce(_sumlog_op, x; dims=dims, init=(one(T), zero(exponent(one(T))))) do xj
        float_xj = float(xj)
        significand(float_xj), exponent(float_xj) 
    end
    map(sig_ex) do (sig, ex)
        log(sig) + IrrationalConstants.logtwo * ex
    end
end
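
(For context: _sumlog_op is the pairwise combining operator defined elsewhere in the PR and not quoted in this thread. Roughly, it multiplies the significands, adds the exponents, and renormalizes. A simplified sketch of what such an operator could look like, not the PR's actual definition:)

function _sumlog_op((sig1, ex1), (sig2, ex2))
    sig = sig1 * sig2          # significands from `significand` lie in [1, 2)
    ex = ex1 + ex2
    if sig >= 2                # renormalize so the running significand stays in [1, 2)
        sig /= 2
        ex += 1
    end
    return (sig, ex)
end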

Should I make a PR to the PR?

Reviewer:

cscherrer#1 is a tidier version of the above.

@devmotion (Member)

I think it might be easier to focus on goals 1 (accuracy and consistency with sum(log, x)), 4 (type stability), 5 (performance), and GPU compatibility first. Supporting sumlog(f, x) and optional dims arguments seems less relevant initially.

Generally, I can see that the function can be useful in some cases, but I would like to avoid increasing code complexity too much in this package, so I think a simple implementation should be another main goal. IMO the code for logsumexp is already quite complex and hence difficult to maintain, but that is probably justified by the popularity of this particular function.

@cscherrer (Author)

  • Does goal 1 imply that 0.0 and NaN etc. should propagate as usual?

Ideally, yes. But none of these are requirements in any way. The idea is more that if we start with an idealized wish list, it might be easier to talk about the design space and decide together where to make compromises.

Maybe this is just me, but after ten or so updates I find it too easy to get lost in the weeds. Maybe this can help keep us from going in circles in the discussion.

  • I don't see how to do sum(log ∘ f, x; dims) without type-inference

I think to start we should focus on the real case. We can come back to complex numbers - maybe this could be a kwarg or optional type parameter, or even a separate function.

> Supporting sumlog(f, x) and optional dims arguments seems less relevant initially.

If we use mapreduce, I think we get dims support almost for free, is that right?

> Generally, I can see that the function can be useful in some cases, but I would like to avoid increasing code complexity too much in this package, so I think a simple implementation should be another main goal.

I like "as simple as possible, but no simpler". I can understand wanting to avoid Base._return_type, and to lean toward higher-order functions to help with AD. But one concern with simplicity is the potential for others to re-implement it to work around any shortcomings. IMO some degree of complexity is better than simple code no one uses.


Also... We could consider changing this function name to logprod. Note that sum(log, x) ≈ log(prod(x)), with sum(log, x) having a more restricted domain (no negative reals allowed), and log(prod(x)) being faster but much more likely to overflow or underflow.

So if it's easy to set it up so "double negatives cancel", logprod might be a better name.
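
(A quick illustration of the overflow trade-off, using made-up numbers:)

x = fill(1e300, 10)
log(prod(x))    # Inf: the product overflows before log is applied
sum(log, x)     # ≈ 6907.76, but needs one log call per element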

@cscherrer (Author)

This also points to another application: maybe you really just want to compute a product, but you'd like to avoid underflow and overflow. So, for example, @cjdoris's LogarithmicNumbers.jl would seem to benefit from adding

Base.prod(::Type{ULogarithmic}, x) = exp(ULogarithmic, logprod(x))

@devmotion (Member)

> avoid Base._return_type

Yes, this is part of the code complexity goal, but it will also improve the stability of LogExpFunctions. All such internal functions and "hacks" should be removed from the PR, in particular since it seems they can be avoided easily. Even standard libraries such as Statistics don't use _return_type to handle empty iterators; see e.g. https://github.com/JuliaLang/Statistics.jl/blob/cdd95fea3ce7bf31c68e01412548688fbd505903/src/Statistics.jl#L204 and https://github.com/JuliaLang/Statistics.jl/blob/cdd95fea3ce7bf31c68e01412548688fbd505903/src/Statistics.jl#L170.

@tpapp (Collaborator) commented May 10, 2022

@cscherrer: Regarding stepping back and agreement: I always think in terms of costs and benefits (code complexity and maintainability vs how useful the code is), and personally I would just go to logs as soon as possible, even at a slight performance cost. But if you really need this, I am fine with including it.

Regarding the goals:

  1. I would rename it to logprod, which conveys the underlying algorithm and precision trade-offs better,

  2. I think that dims and foo(f, x) are unnecessary, and it is silly that each function replicates the boilerplate for this, given that Julia has much nicer mechanisms now for these, but I understand that whenever we leave these out someone will complain

  3. I would be happy with a robust and reasonably accurate logprod that is approximately sum(log, ...), with the understanding that sometimes one is better than the other. All algorithms have trade-offs and that's fine. Maybe we should document them, though.

@cscherrer (Author) commented May 10, 2022

I've pushed another version. This time

  1. It's logprod instead of sumlog
  2. sumlog is still there for now, for easy comparison
  3. I remembered frexp. @mcabbott, this behaves better for subnormals.
  4. I dropped dims, etc. I think we have a better understanding now of each other's priorities. I'm still in favor of more functionality, but we can start simple and get tests etc going. That will make it easier to weigh any drawbacks of adding more functionality.
  5. I think this is a good candidate to go in Base. If we make it logabsprod it can be a big help to speed up logabsdet:
julia> x = LowerTriangular(randn(1000,1000));

julia> using LinearAlgebra

julia> using BenchmarkTools

julia> @btime logabsdet($x)
  8.687 μs (0 allocations: 0 bytes)
(-631.836, -1.0)

julia> d = diag(x);

julia> @btime logabsprod($d)
  1.202 μs (0 allocations: 0 bytes)
(-631.836, 1.0)

I cheated here a little, since we don't have (that I know of) a lazy diag in Base.
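
(The connection exploited above: for a triangular matrix the determinant is the product of the diagonal entries, so a hypothetical helper along these lines could back logabsdet, assuming the tuple-returning logabsprod proposed in this PR:)

using LinearAlgebra

# Hypothetical, illustration only: det(x) == prod(diag(x)) for triangular x,
# so logabsdet reduces to logabsprod of the diagonal. diag(x) allocates a copy;
# a lazy diagonal view would avoid that.
logabsdet_via_logabsprod(x::Union{LowerTriangular,UpperTriangular}) = logabsprod(diag(x))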

docs/src/index.md, src/LogExpFunctions.jl: outdated review comments (resolved)
@@ -0,0 +1,78 @@
"""
logprod(X::AbstractArray{T}; dims)
Member:

Suggested change
logprod(X::AbstractArray{T}; dims)
logprod(x)

src/logprod.jl: outdated review comment (resolved)
x1 = float(iter[1])
x1 isa AbstractFloat || return sum(log, x)
x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1)
sig, ex = significand(x1), _exponent(x1)
Member:

Suggested change
sig, ex = significand(x1), _exponent(x1)
sig, ex = frexp(x1)

x1 isa AbstractFloat || return sum(log, x)
x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1)
sig, ex = significand(x1), _exponent(x1)
nonfloat = zero(x1)
Member:

Suggested change
nonfloat = zero(x1)

while iter !== nothing
xj = float(iter[1])
if xj isa AbstractFloat
sig, ex = _logprod_op((sig, ex), (significand(xj), _exponent(xj)))
Member:

Suggested change
sig, ex = _logprod_op((sig, ex), (significand(xj), _exponent(xj)))
sig, ex = _logabsprod_op((sig, ex), frexp(xj))

if xj isa AbstractFloat
sig, ex = _logprod_op((sig, ex), (significand(xj), _exponent(xj)))
else
nonfloat += log(xj)
Member:

Suggested change
nonfloat += log(xj)
y = prod(x)
return log(abs(y)), sign(y)

end
iter = iterate(x, iter[2])
end
return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
Member:

Suggested change
return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
return (log(abs(sig)) + IrrationalConstants.logtwo * oftype(sig, ex), sign(sig))

cscherrer and others added 4 commits May 10, 2022 07:08
Co-authored-by: David Widmann <[email protected]>
@mcabbott commented May 10, 2022

logprod is a neat idea to avoid checks.

frexp is also clearly what we were looking for. Does this have any effect on speed?

I am lost in all the noise on minor details, but checking sig > floatmax(typeof(sig)) / 2 is now the wrong thing: with frexp the significand is in [0.5, 1), so the running product underflows towards zero instead.

logprod should really have a case which tries log(prod(x)) first on small enough arrays, as this is much faster. (And it should advertise itself as being less prone to overflow than log(prod(x)), rather than as being faster.)
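
(A sketch of that small-array fast path, illustrative only, using the logprod name from this PR; the cutoff is arbitrary and would need benchmarking:)

# Illustrative only: assumes positive inputs, falls back to the careful
# significand/exponent path whenever the plain product over- or underflows.
function logprod_with_fast_path(x::AbstractArray{Float64})
    if length(x) <= 32               # arbitrary cutoff for this sketch
        p = prod(x)
        isfinite(p) && p > 0 && return log(p)
    end
    return logprod(x)                # careful path proposed in this PR
end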

Co-authored-by: David Widmann <[email protected]>
@cscherrer (Author)

> logprod is a neat idea to avoid checks.

Thanks! I think the name is more natural too :)

> frexp is also clearly what we were looking for. Does this have any effect on speed?

I had assumed if anything it might be more efficient, but maybe not. Here's a quick check:

julia> @benchmark frexp(x) setup=(x=100 * rand())
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.063 ns (0.00% GC)
  median time:      2.084 ns (0.00% GC)
  mean time:        2.085 ns (0.00% GC)
  maximum time:     6.112 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark (significand(x), exponent(x)) setup=(x=100 * rand())
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.613 ns (0.00% GC)
  median time:      1.633 ns (0.00% GC)
  mean time:        1.640 ns (0.00% GC)
  maximum time:     10.330 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

The difference may not be a real effect, since it is sub-nanosecond.

> I am lost in all the noise on minor details, but checking sig > floatmax(typeof(sig)) / 2 is now the wrong thing: with frexp the significand is in [0.5, 1), so the running product underflows towards zero instead.

I had missed this, but @devmotion caught it too. Very surprising it would be so different.

I'm also wondering: was this wrong in the first place? A factor of 2 should work if we assume the reduction is sequential. But some implementations might exploit associativity, and in that case I think we need sqrt, or better, to look at exponent(sig). But with that we end up in the renormalization branch more often, so we're only about twice as fast as sum(log, x).
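
(To make that concern concrete: under a pairwise or otherwise reordered reduction, both operands of the combining operator can already be accumulated products, so a fixed factor-of-2 threshold is not enough. One option, sketched below purely for illustration, renormalizes via exponent(sig) as suggested above:)

function _op_renorm((sig1, ex1), (sig2, ex2))
    sig = sig1 * sig2
    ex = ex1 + ex2
    # pull any accumulated power of two out of sig and into ex,
    # so sig returns to [1, 2) regardless of the reduction order
    ex += exponent(sig)
    sig = significand(sig)
    return (sig, ex)
end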

> logprod should really have a case which tries log(prod(x)) first on small enough arrays, as this is much faster. (And it should advertise itself as being less prone to overflow than log(prod(x)), rather than as being faster.)

Both great points.

@oscardssmith (Contributor)

Do we want to merge this?

@tpapp (Collaborator) commented Jul 21, 2022

I lost track of the various changes, but I am fine with merging. We can always micro-optimize things later; this is already more efficient than sum(log, x).

@cscherrer (Author)

Same. I expected this would be relatively simple, and was surprised by the number of cases to worry about. I agree that merging and then handling various cases as they come up is reasonable. It looks like the tests need some updates, though; they're currently failing because they use sumlog (now undefined) instead of the updated name logprod.

@devmotion (Member)

I don't remember the details either. It seems there are unaddressed comments and the tests don't pass yet, so I guess the PR needs some updates and another round of review before merging.

@tpapp (Collaborator) commented Nov 23, 2022

Friendly ping: this is a great contribution and it would be unfortunate to leave it dormant. @cscherrer, if you have the time to address the May 10 review comments by @devmotion and fix CI, I would be happy to merge.
