Skip to content

Commit

Permalink
upgrade tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Sep 2, 2023
1 parent dbf39d4 commit 791531a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 7 deletions.
17 changes: 13 additions & 4 deletions src/bias_act.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,23 @@ contains only `Ω` (the output) not `x`.
It is intended mainly for Flux layers, in which the previous operation is
known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
"""
# Out-of-place fallback for arrays that cannot safely be overwritten:
# broadcast the bias add and the activation, allocating a fresh array.
# (`fast_act` presumably substitutes a faster equivalent activation — confirm
# against its definition elsewhere in the package.)
# Restored the `(σ` that was lost in transcription; without it this line is
# not valid Julia.
bias_act!(σ::Function, x::AbstractArray, b) = fast_act(σ, x).(x .+ b)  # fallback

# In-place fast path for strided float arrays with an array bias: fuse the
# bias addition and the activation into one `_fast_broadcast!` call by
# composing `σ ∘ +`. Restored the `(σ` and the `∘` lost in transcription —
# `fast_act(σ, x)(+)` would *call* the activation on `+` instead of composing.
bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
    _fast_broadcast!(fast_act(σ, x)∘(+), x, b)  # works around a SIMD bug

# Identity activation with a Bool bias: nothing to compute, return `x` as-is.
# `bias=true` is rejected with an explicit `error` rather than `@assert`,
# because `@assert` is for internal invariants and may be disabled at higher
# optimization levels — it must not be relied on for input validation. This
# also matches the error style of the sibling `bias_act!` methods.
function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    return x  # pass-through
end
# In-place fast path for strided float arrays with no bias (`b::Bool`):
# apply the activation alone via `_fast_broadcast!`, overwriting `x`.
# `bias=true` is rejected so that a Bool bias always means "no bias".
# Restored the `(σ` lost in transcription and fixed the "shoud" typo in the
# error message.
function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool)
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    return _fast_broadcast!(fast_act(σ, x), x)
end

# Identity activation with a Bool bias: nothing to compute, return `x`
# unchanged (no copy, no broadcast). `bias=true` is rejected with an explicit
# error so that a Bool bias can only ever mean "no bias".
# Fixed the "shoud" typo in the error message.
function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    return x  # pass-through
end

# General out-of-place fallback for any array type and any bias `b`: reject
# `bias=true`, then broadcast the bias add and activation, allocating a new
# array (used e.g. for integer arrays, which must not be overwritten).
# Restored the `(σ` lost in transcription and fixed the "shoud" typo in the
# error message.
function bias_act!(σ::Function, x::AbstractArray, b)
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    return fast_act(σ, x).(x .+ b)  # fallback
end

function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
biasgrad = if eltype(B) !== Bool
Expand Down
42 changes: 39 additions & 3 deletions test/bias_act.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,33 @@ ACTIVATION_FUNCTIONS =
@testset "bias_act!" begin
x = randn(3,4)
b = randn(3)
@test bias_act!(identity, copy(x), b) (x .+ b)
@test bias_act!(relu, copy(x), b) relu.(x .+ b)
@test bias_act!(tanh, copy(x), b) tanh.(x .+ b)
@test @inferred(bias_act!(identity, x, false)) === x # pass-through
@test @inferred(bias_act!(identity, copy(x), b)) (x .+ b)
@test @inferred(bias_act!(relu, copy(x), b)) relu.(x .+ b)
@test @inferred(bias_act!(tanh, copy(x), b)) tanh.(x .+ b)
@test @inferred(bias_act!(tanh, copy(x), false)) tanh.(x)

# Check that it does overwrite:
x32 = rand(Float32, 3, 4)
x32copy = copy(x32)
@test @inferred(bias_act!(cbrt, x32, b)) cbrt.(x32copy .+ b)
@test x32 cbrt.(x32copy .+ b)
x32 = rand(Float32, 3, 4)
x32copy = copy(x32)
@test @inferred(bias_act!(tanh, x32, false)) tanh.(x32copy)
@test x32 tanh.(x32copy)

# Check that it doesn't try to overwrite non-float arrays:
xint = rand(-3:3, 3, 4)
bint = rand(-2:2, 3)
@test bias_act!(identity, copy(xint), bint) xint .+ bint
@test bias_act!(tanh, copy(xint), bint) tanh.(xint .+ bint)
@test bias_act!(tanh, copy(xint), false) tanh.(xint)

# Reject bias===true so that Bool means one thing:
@test_throws Exception bias_act!(identity, rand(3), true)
@test_throws Exception bias_act!(cbrt, rand(3), true)
@test_throws Exception bias_act!(cbrt, rand(1:3, 3), true)

@testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
ACTIVATION_FUNCTIONS,
Expand All @@ -21,9 +45,21 @@ ACTIVATION_FUNCTIONS =
@test bias_act!(fun, copy(x), false) fun.(x)

gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())
gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())
if !(gx gxplus gxminus)
@warn "skipping gradient tests due to discontinuity" fun x b
continue
end
@test gx Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]

gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
if !(gx2 gx2plus gx2minus)
@warn "skipping gradient tests due to discontinuity" fun x
continue
end
@test gx2 Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]

gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
Expand Down

0 comments on commit 791531a

Please sign in to comment.