Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Sep 2, 2023
1 parent 791531a commit 7a58e56
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions test/bias_act.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NNlib, Zygote, Test
using NNlib, Zygote, ChainRulesCore, Test
using Zygote: ForwardDiff

ACTIVATION_FUNCTIONS =
Expand All @@ -14,14 +14,21 @@ ACTIVATION_FUNCTIONS =
@test @inferred(bias_act!(tanh, copy(x), false)) tanh.(x)

# Check that it does overwrite:
x32 = rand(Float32, 3, 4)
x32copy = copy(x32)
x32 = rand(Float32, 3, 4); x32copy = copy(x32)
@test @inferred(bias_act!(cbrt, x32, b)) cbrt.(x32copy .+ b)
@test x32 cbrt.(x32copy .+ b)
x32 = rand(Float32, 3, 4)
x32copy = copy(x32)
@test x32 cbrt.(x32copy .+ b)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
@test @inferred(bias_act!(tanh, x32, false)) tanh.(x32copy)
@test x32 tanh.(x32copy)
@test x32 tanh.(x32copy)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b)
@test y x32 relu.(x32copy .+ b)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false)
@test y x32 relu.(x32copy)

# Check that it doesn't try to overwrite non-float arrays:
xint = rand(-3:3, 3, 4)
Expand Down Expand Up @@ -78,7 +85,7 @@ ACTIVATION_FUNCTIONS =
g2 = ForwardDiff.gradient(x) do x
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
end
@test_broken gx Zygote.gradient(x) do x
@test_skip gx Zygote.gradient(x) do x # Here global variable b causes an error
sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
end
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
Expand Down

0 comments on commit 7a58e56

Please sign in to comment.