From 7a58e563f926c22f5ea5760982dd6cb00cd5fd43 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 7 Jan 2023 12:48:46 -0500 Subject: [PATCH] more tests --- test/bias_act.jl | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/bias_act.jl b/test/bias_act.jl index 31dcb0487..110f1a24e 100644 --- a/test/bias_act.jl +++ b/test/bias_act.jl @@ -1,4 +1,4 @@ -using NNlib, Zygote, Test +using NNlib, Zygote, ChainRulesCore, Test using Zygote: ForwardDiff ACTIVATION_FUNCTIONS = @@ -14,14 +14,21 @@ ACTIVATION_FUNCTIONS = @test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x) # Check that it does overwrite: - x32 = rand(Float32, 3, 4) - x32copy = copy(x32) + x32 = rand(Float32, 3, 4); x32copy = copy(x32) @test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b) - @test x32 ≈ cbrt.(x32copy .+ b) - x32 = rand(Float32, 3, 4) - x32copy = copy(x32) + @test x32 ≈ cbrt.(x32copy .+ b) + + x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias @test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy) - @test x32 ≈ tanh.(x32copy) + @test x32 ≈ tanh.(x32copy) + + x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule + y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b) + @test y ≈ x32 ≈ relu.(x32copy .+ b) + + x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias + y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false) + @test y ≈ x32 ≈ relu.(x32copy) # Check that it doesn't try to overwrite non-float arrays: xint = rand(-3:3, 3, 4) @@ -78,7 +85,7 @@ ACTIVATION_FUNCTIONS = g2 = ForwardDiff.gradient(x) do x sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) end - @test_broken gx ≈ Zygote.gradient(x) do x + @test_skip gx ≈ Zygote.gradient(x) do x # Here global variable b causes an error sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) end # Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).