Skip to content

Commit

Permalink
use NNlib.bias_act
Browse files Browse the repository at this point in the history
rm comments
  • Loading branch information
mcabbott committed Mar 30, 2024
1 parent 8654721 commit 4ab8343
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Functors = "0.4"
MLUtils = "0.4"
MacroTools = "0.5"
Metal = "0.5, 1"
NNlib = "0.9.1"
NNlib = "0.9.5"
OneHotArrays = "0.2.4"
Optimisers = "0.3.2"
Preferences = "1"
Expand Down
5 changes: 2 additions & 3 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,8 @@ end

function (a::Dense)(x::AbstractVecOrMat)
_size_check(a, x, 1 => size(a.weight, 2))
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
xT = _match_eltype(a, x) # fixes Float64 input, etc.
return σ.(a.weight * xT .+ a.bias)
NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths
end

function (a::Dense)(x::AbstractArray)
Expand Down Expand Up @@ -446,7 +445,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
Z = reshape(Wyx, (d_z, :))

# @einsum out[o,s] := σ(Z[o,i] + b[o])
σ.(Z .+ b)
NNlib.bias_act!(σ, Z, b) # σ.(Z .+ b)
end

(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
Expand Down
9 changes: 3 additions & 6 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,9 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)

function (c::Conv)(x::AbstractArray)
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_dims(c, x)
xT = _match_eltype(c, x)
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
end

_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
Expand Down Expand Up @@ -332,10 +331,9 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

function (c::ConvTranspose)(x::AbstractArray)
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
xT = _match_eltype(c, x)
σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
end

function Base.show(io::IO, l::ConvTranspose)
Expand Down Expand Up @@ -474,10 +472,9 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)

function (c::CrossCor)(x::AbstractArray)
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = crosscor_dims(c, x)
xT = _match_eltype(c, x)
σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
end

function Base.show(io::IO, l::CrossCor)
Expand Down
2 changes: 1 addition & 1 deletion src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ function _norm_layer_forward(
β = reshape(l.β, affine_shape)

scale = γ ./ sqrt.(σ² .+ eps)
bias = -scale .* μ .+ β
bias = .-scale .* μ .+ β
l.λ.(scale .* x .+ bias)
end

Expand Down

0 comments on commit 4ab8343

Please sign in to comment.