
Add bias_act! #457

Merged: 11 commits, merged from bias_act_23 into FluxML:master on Sep 4, 2023

Conversation

mcabbott (Member) commented Jan 5, 2023

This was part of #346, but the conv part got complicated.

Intended as a better alternative to part of FluxML/Flux.jl#2137: using this in layers will remove all the identity.(x .+ false) broadcasts, with less repetition of the idea.
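
For reference, the core idea is just a fused in-place broadcast. A minimal sketch (bias_act_sketch! is a hypothetical name for illustration, not this PR's implementation, and it skips the care the real function needs around immutable or aliased arrays):

using NNlib  # for relu

bias_act_sketch!(σ, x::AbstractArray, b) = (x .= σ.(x .+ b); x)   # b can be a vector, or `false` for "no bias"

w, b = rand(Float32, 100, 10000), rand(Float32, 100);
y = relu.(w .+ b);                 # the usual version: allocates a fresh array
bias_act_sketch!(relu, w, b) ≈ y   # same values, but written into w with no extra allocation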

I'm dismayed by how long the rrule code is here. I couldn't see what's wrong with the second case (it fails on swish), so I've commented it out for now. There's room to improve this once JuliaDiff/ChainRulesCore.jl#592 works.
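
For context, the rough shape of such a rule, written here only for relu (whose derivative can be recomputed from the saved output): a hedged sketch built on the hypothetical bias_act_sketch! above, not this PR's rrule, which has to cover every activation and the b = false case and is presumably longer for that reason.

using ChainRulesCore, NNlib

bias_act_sketch!(σ, x::AbstractArray, b) = (x .= σ.(x .+ b); x)    # as sketched above (hypothetical)
relu_grad_from_output(y) = y > 0 ? one(y) : zero(y)                # relu' expressed via y = relu(x)

function ChainRulesCore.rrule(::typeof(bias_act_sketch!), ::typeof(relu), x::AbstractMatrix, b::AbstractVector)
    y = bias_act_sketch!(relu, x, b)        # forward pass overwrites x; only y needs to be kept
    function bias_act_sketch_pullback(dy)
        dx = unthunk(dy) .* relu_grad_from_output.(y)
        db = vec(sum(dx; dims=2))           # the bias was broadcast along columns, so sum them back
        return (NoTangent(), NoTangent(), dx, db)
    end
    return y, bias_act_sketch_pullback
end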

Benchmarks

Some min times are slower. But mean times show the effect of saving allocations.
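
For anyone re-running these: the benchmarks below assume roughly this setup (this PR's branch of NNlib, plus BenchmarkTools and Zygote). Note that stock BenchmarkTools' @btime prints only the minimum time, so the "min …, mean …" lines presumably come from a patched or wrapped timing macro; the allocation figures are the same either way.

using NNlib: bias_act!, relu, tanh_fast
using BenchmarkTools, Zygote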

## M1 mac, 1.10

julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);

julia> @btime bias_act!(relu, $w, $b);
  min 19.500 μs, mean 21.375 μs (0 allocations)

julia> @btime relu.($w .+ $b);
  min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)

julia> @btime bias_act!(tanh, $w, $b);
  min 63.792 μs, mean 65.052 μs (0 allocations)

julia> @btime tanh_fast.($w .+ $b);
  min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)

julia> using Zygote

julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
  min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)

julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
  min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)

julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
  min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)

julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
  min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)



## Cyclops

julia> using CUDA  # 10x bigger. Note that I'm not measuring GPU allocations.

julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);

julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
  22.546 μs (27 allocations: 1.45 KiB)

julia> @btime CUDA.@sync relu.($cw .+ $cb);
  31.282 μs (38 allocations: 1.81 KiB)

julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
  27.030 μs (27 allocations: 1.45 KiB)

julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
  36.421 μs (38 allocations: 1.81 KiB)

julia> using Zygote

julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
  204.507 μs (382 allocations: 18.15 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
  204.458 μs (409 allocations: 19.19 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
  224.545 μs (382 allocations: 18.15 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
  204.793 μs (411 allocations: 19.30 KiB)

Flux:

julia> using Flux

julia> model = Chain(Dense(784=>512, relu), Dense(512=>256, relu), Dense(256, 10));

julia> x = randn(Float32, 784, 256);

julia> @btime $model($x);
  min 247.333 μs, mean 354.454 μs (10 allocations, 1.52 MiB)   # before
  min 235.708 μs, mean 292.189 μs (5 allocations, 778.22 KiB)  # after
  
julia> @btime gradient(m -> sum(abs2, m($x)), $model);
  min 859.292 μs, mean 1.454 ms (72 allocations, 6.60 MiB)  # before
  min 824.833 μs, mean 1.328 ms (98 allocations, 5.09 MiB)  # after

So about a 50% memory saving on the forward pass, as you'd expect: each Dense layer now allocates only the matmul output, rather than that plus a second array for the activated broadcast.

If I'm thinking right, then JuliaDiff/ChainRulesCore.jl#592 should get the gradient down to 4.35 MB, saving about 1/3.
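
For concreteness, this is roughly how a Dense layer's forward pass can use it (a hedged sketch, assuming Flux's Dense field names weight, bias and σ, and the bias_act!(σ, x, b) signature used in the benchmarks above; it is not the actual Flux change in FluxML/Flux.jl#2137):

using Flux, NNlib

d = Dense(784 => 512, relu)
x = randn(Float32, 784, 256)

y_before = d.σ.(d.weight * x .+ d.bias)                 # allocates the matmul result plus the broadcast result
y_after  = NNlib.bias_act!(d.σ, d.weight * x, d.bias)   # activation fused in place on the matmul result
y_before ≈ y_after                                      # true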

PR Checklist

  • Tests are added
  • Documentation, if applicable

ToucheSir (Member) commented Jan 5, 2023

If you have a stacktrace for the swish failure I can look into it.

Curious that only CI with threads > 1 is unhappy about the latest commit. There's no way we'd be secretly multithreading anywhere with this, right? Edit: Windows CI is unhappy too, and that is single-threaded. Some sort of heisenbug, perhaps.

mcabbott (Member, Author) commented Jan 5, 2023

It's very odd; Buildkite also fails, but I'm not sure whether that run is multi-threaded.

For swish I got wrong answers, but only in the cases with nonzero bias.

ToucheSir (Member) commented Jan 5, 2023

Buildkite is the other multithreaded CI config. Do you recall which tests were failing?

mcabbott (Member, Author) commented Jan 7, 2023

On the latest commit, c1d834f, the failures look like this:

...
  gradient with elu            |    7                    7     6.0s
  gradient with gelu           |    7                    7     6.1s
  gradient with swish          |    5     2              7     8.3s
  gradient with hardswish      |    7                    7     6.3s
  gradient with selu           |    7                    7     6.8s
  gradient with celu           |    7                    7     5.6s
  gradient with softplus       |    7                    7     6.2s
  gradient with softsign       |    7                    7     6.3s
  gradient with logσ           |    7                    7     5.7s
  gradient with logcosh        |    7                    7     6.9s
  gradient with mish           |    7                    7     6.3s
  gradient with tanhshrink     |    5     2              7     6.6s
  gradient with softshrink     |    7                    7     6.4s
  gradient with trelu          |    7                    7     6.6s
...
  gradient for fast_broadcast! |    4             1      5    45.9s
ERROR: Some tests did not pass: 206 passed, 4 failed, 0 errored, 1 broken.

julia> x = randn(3,4)
3×4 Matrix{Float64}:
 -1.22369   0.0921121  -0.941871  -1.19349
  1.31897   1.07247    -0.981244  -0.363552
 -0.707036  0.328161    0.252119   0.0805549

julia> b = randn(3)
3-element Vector{Float64}:
 -0.8524132980503979
  1.4314247570126006
 -0.2652038000170137

julia> fun = swish
swish (generic function with 2 methods)

julia> gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
3×4 Matrix{Float64}:
 -0.094139   0.153529  -0.076764  -0.0929148
  1.09521    1.09937    0.717713   0.947483
  0.0808417  0.531458   0.493458   0.408198

julia> Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]
σ = NNlib.swish
b = [-0.8524132980503979, 1.4314247570126006, -0.2652038000170137]
3×4 Matrix{Float64}:
 -0.893114   -0.229442  -0.814908  -0.886241
  1.08842     1.09301    1.06381    1.07222
 -0.0874094   0.479686   0.434657   0.332037

julia> gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
3-element Vector{Float64}:
 -0.11028850033542084
  3.8597769560080724
  1.5139547892071894

julia> Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1]
σ = NNlib.swish
b = [-0.8524132980503979, 1.4314247570126006, -0.2652038000170137]
3-element Vector{Float64}:
 -2.8237039339595764
  4.317457358561092
  1.158969872333099

Review comment on test/bias_act.jl, lines +63 to +68:
gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .+ eps())
gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
if !(gx2 ≈ gx2plus ≈ gx2minus)
    @warn "skipping gradient tests due to discontinuity" fun x
    continue
mcabbott (Member, Author) commented Jan 7, 2023

This slightly elaborate check is trying to avoid my best guess at why there were failures on CI: hardsigmoid's derivative has discontinuities (kinks in the function), and if x lands on one of them, the two gradients may not agree.

But it doesn't seem to work:

  gradient with hardσ: Test Failed at /home/runner/work/NNlib.jl/NNlib.jl/test/bias_act.jl:73
  Expression: gb ≈ (Zygote.gradient((b->(sum(bias_act!(fun, copy(x), b));)), b))[1]
   Evaluated: [0.5, 0.6666666666666666, 0.6666666666666666] ≈ [1.5000000000000002, 0.6666666666666666, 0.6666666666666666]
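
For what it's worth, the kink itself is easy to see (a hedged illustration; it assumes NNlib's hardσ is piecewise linear with corners at x = ±3, so adjust the probe points if the definition differs):

using NNlib, ForwardDiff

ForwardDiff.derivative(hardσ, 2.9)   # ≈ 1/6: just below the corner, hardσ is still on its linear ramp
ForwardDiff.derivative(hardσ, 3.1)   # 0.0:  just above the corner, hardσ is constant at 1

If a test input (after adding the bias) lands exactly on such a corner, forward-mode and reverse-mode rules may legitimately pick different one-sided slopes, which is what the gx2plus/gx2minus comparison above tries to detect.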

mcabbott (Member, Author) commented Sep 2, 2023

Seems to pass after rebasing. The failure is now only on Windows; perhaps that means it depends on the random seed. Maybe we should just assume it's an effect of the kinks in hardσ?

The benchmark above, re-run on the same computer, gives much slower times, and a much larger speedup.

julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);

julia> @btime bias_act!(relu, $w, $b);
  min 141.250 μs, mean 145.076 μs (0 allocations)

julia> @btime relu.($w .+ $b);
  min 107.667 μs, mean 443.560 μs (2 allocations, 3.81 MiB)

julia> @btime bias_act!(tanh, $w, $b);
  min 418.125 μs, mean 425.345 μs (0 allocations)

julia> @btime tanh_fast.($w .+ $b);
  min 404.042 μs, mean 772.522 μs (2 allocations, 3.81 MiB)

julia> using Zygote

julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
  min 424.875 μs, mean 818.428 μs (28 allocations, 3.82 MiB)

julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
  min 969.541 μs, mean 1.591 ms (32 allocations, 11.45 MiB)

julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
  min 700.292 μs, mean 1.037 ms (28 allocations, 3.82 MiB)

julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
  min 1.217 ms, mean 1.898 ms (32 allocations, 11.45 MiB)
  
julia> versioninfo()  # results look similar on 1.10 + 1.11
Julia Version 1.9.2
Commit e4ee485e909 (2023-07-05 09:39 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M1

ToucheSir (Member) commented:

Some Buildkite jobs are not happy either. Can we constrain the inputs for the hardσ subset of tests somehow to avoid test flakiness and call it a day?
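
One way to constrain them might be a nudge like this (only a sketch of the idea, not the change that was merged; it again assumes the corners sit at pre-activation values of ±3):

using NNlib

fun = hardσ
x, b = randn(3, 4), randn(3)
if fun === hardσ
    # move any pre-activation x .+ b that sits within 0.1 of a corner safely past it
    x = @. ifelse(abs(abs(x + b) - 3) < 0.1, x + 0.2, x)
end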

test/bias_act.jl: review thread resolved (outdated)
mcabbott merged commit 90a0043 into FluxML:master on Sep 4, 2023 (10 of 13 checks passed).
mcabbott deleted the bias_act_23 branch on September 4, 2023.