diff --git a/src/bias_act.jl b/src/bias_act.jl
index 43eac3a00..ef7fb29d9 100644
--- a/src/bias_act.jl
+++ b/src/bias_act.jl
@@ -29,14 +29,23 @@ contains only `Ω` (the output) not `x`.
 It is intended mainly for Flux layers, in which the previous operation is
 known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
 """
-bias_act!(σ::Function, x::AbstractArray, b) = fast_act(σ, x).(x .+ b)  # fallback
-
 bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
     _fast_broadcast!(fast_act(σ, x)∘(+), x, b)  # works around a SIMD bug
 
-bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
-    (@assert !b "bias=true is not accepted; layer constructors shoud guarantee this"; x)
+function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool)
+    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
+    _fast_broadcast!(fast_act(σ, x), x)
+end
+function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)
+    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
+    x  # pass-through
+end
+
+function bias_act!(σ::Function, x::AbstractArray, b)
+    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
+    fast_act(σ, x).(x .+ b)  # fallback
+end
 
 function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
     biasgrad = if eltype(B) !== Bool
diff --git a/test/bias_act.jl b/test/bias_act.jl
index c4d682559..31dcb0487 100644
--- a/test/bias_act.jl
+++ b/test/bias_act.jl
@@ -7,9 +7,33 @@ ACTIVATION_FUNCTIONS =
 @testset "bias_act!" begin
     x = randn(3,4)
     b = randn(3)
-    @test bias_act!(identity, copy(x), b) ≈ (x .+ b)
-    @test bias_act!(relu, copy(x), b) ≈ relu.(x .+ b)
-    @test bias_act!(tanh, copy(x), b) ≈ tanh.(x .+ b)
+    @test @inferred(bias_act!(identity, x, false)) === x  # pass-through
+    @test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b)
+    @test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b)
+    @test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b)
+    @test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x)
+
+    # Check that it does overwrite:
+    x32 = rand(Float32, 3, 4)
+    x32copy = copy(x32)
+    @test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b)
+    @test x32 ≈ cbrt.(x32copy .+ b)
+    x32 = rand(Float32, 3, 4)
+    x32copy = copy(x32)
+    @test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy)
+    @test x32 ≈ tanh.(x32copy)
+
+    # Check that it doesn't try to overwrite non-float arrays:
+    xint = rand(-3:3, 3, 4)
+    bint = rand(-2:2, 3)
+    @test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint
+    @test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint)
+    @test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint)
+
+    # Reject bias===true so that Bool means one thing:
+    @test_throws Exception bias_act!(identity, rand(3), true)
+    @test_throws Exception bias_act!(cbrt, rand(3), true)
+    @test_throws Exception bias_act!(cbrt, rand(1:3, 3), true)
 
     @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
                                                   ACTIVATION_FUNCTIONS,
@@ -21,9 +45,21 @@
         @test bias_act!(fun, copy(x), false) ≈ fun.(x)
 
         gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
+        gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())
+        gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())
+        if !(gx ≈ gxplus ≈ gxminus)
+            @warn "skipping gradient tests due to discontinuity" fun x b
+            continue
+        end
         @test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]
 
         gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
+        gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .+ eps())
+        gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
+        if !(gx2 ≈ gx2plus ≈ gx2minus)
+            @warn "skipping gradient tests due to discontinuity" fun x
+            continue
+        end
         @test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]
 
         gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)