Skip to content

Commit

Permalink
port float encoding tests from hypothesis
Browse files Browse the repository at this point in the history
This just ports the tests that check the ordering properties of the
encoding and fixes the bugs that were found by the tests.
  • Loading branch information
raineszm committed Jul 17, 2024
1 parent 1df83c0 commit 30533a7
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ module Data

using Supposition
using Supposition: smootherstep, lerp, TestCase, choice!, weighted!, forced_choice!, reject
using Supposition.FloatEncoding: lexographical_float
using Supposition.FloatEncoding: lex_to_float
using RequiredInterfaces: @required
using StyledStrings: @styled_str
using Printf: format, @format_str
Expand Down Expand Up @@ -1458,7 +1458,7 @@ function produce!(tc::TestCase, f::Floats{T}) where {T}

is_negative = produce!(tc, Booleans())

res = lexographical_float(T, bits)
res = lex_to_float(T, bits)
if is_negative
res = -res
end
Expand Down
8 changes: 4 additions & 4 deletions src/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end


"""
lexographical_float(T, bits)
lex_to_float(T, bits)
Reinterpret the bits of a floating point number using an encoding with better shrinking
properties.
Expand All @@ -100,7 +100,7 @@ If the sign bit is not set:
- the float is reassembled using `assemble`
"""
function lexographical_float(::Type{T}, bits::I)::T where {I,T<:Base.IEEEFloat}
function lex_to_float(::Type{T}, bits::I)::T where {I,T<:Base.IEEEFloat}
sizeof(T) == sizeof(I) || throw(ArgumentError("The bitwidth of `$T` needs to match the bidwidth of `I`!"))
iT = uint(T)
sign, exponent, mantissa = tear(reinterpret(T, bits))
Expand Down Expand Up @@ -128,7 +128,7 @@ function is_simple_float(f::T) where {T<:Base.IEEEFloat}
if trunc(f) != f
return false
end
Base.top_set_bit(uint(f)) <= 8 * (sizeof(T) - 1)
Base.top_set_bit(reinterpret(uint(T), f)) <= 8 * (sizeof(T) - 1)
catch e
if isa(e, InexactError)
return false
Expand All @@ -142,6 +142,6 @@ function base_float_to_lex(f::T) where {T<:Base.IEEEFloat}
mantissa = update_mantissa(T, exponent, mantissa)
exponent = decode_exponent(exponent)

reinterpret(uint(T), assemble(T, sign, exponent, mantissa))
reinterpret(uint(T), assemble(T, one(uint(T)), exponent, mantissa))
end
end
58 changes: 57 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,11 @@ const verb = VERSION.major == 1 && VERSION.minor < 11
# Tests the properties of the enocding used to represent floating point numbers
@testset "Floating point encoding" begin
@testset for T in (Float16, Float32, Float64)

iT = Supposition.uint(T)
# These invariants are ported from Hypothesis
@testset "Exponent encoding" begin
exponents = zero(Supposition.uint(T)):Supposition.max_exponent(T)
exponents = zero(iT):Supposition.max_exponent(T)

# Round tripping
@test all(exponents) do e
Expand All @@ -529,6 +531,60 @@ const verb = VERSION.major == 1 && VERSION.minor < 11
FloatEncoding.encode_exponent(FloatEncoding.decode_exponent(e)) == e
end
end

function roundtrip_encoding(f)
assume!(!signbit(f))
encoded = FloatEncoding.float_to_lex(f)
decoded = FloatEncoding.lex_to_float(T, encoded)
reinterpret(iT, decoded) == reinterpret(iT, f)
end

roundtrip_examples = map(Data.Just,
T[
0.0,
2.5,
8.000000000000007,
3.0,
2.0,
1.9999999999999998,
1.0
])
@check roundtrip_encoding(Data.OneOf(roundtrip_examples...))
@check roundtrip_encoding(Data.Floats{T}(; minimum=zero(T)))

@testset "Ordering" begin
function order_integral_part(n::T, g::T)
f = n + g
assume!(trunc(f) != f)
assume!(trunc(f) != 0)
i = FloatEncoding.float_to_lex(f)
g = trunc(f)
FloatEncoding.float_to_lex(g) < i
end

@check order_integral_part(Data.Just(1.0), Data.Just(0.5))
@check order_integral_part(
Data.Floats{T}(;
minimum=one(T),
maximum=T(2^(Supposition.fracsize(T) + 1)),
nans=false),
filter(x -> !(x in T[0, 1]),
Data.Floats{T}(; minimum=zero(T), maximum=one(T), nans=false)))

integral_float_gen = map(abs trunc,
Data.Floats{T}(; minimum=zero(T), infs=false, nans=false))

@check function integral_floats_order_as_integers(x=integral_float_gen,
y=integral_float_gen)
(x < y) == (FloatEncoding.float_to_lex(x) < FloatEncoding.float_to_lex(y))
end

@check function fractional_floats_greater_than_1(
f=Data.Floats{T}(; minimum=zero(T), maximum=one(T), nans=false))
assume!(0 < f < 1)
FloatEncoding.float_to_lex(f) > FloatEncoding.float_to_lex(one(T))
end
end
end
end

Expand Down

0 comments on commit 30533a7

Please sign in to comment.