diff --git a/src/core.jl b/src/core.jl index b58120a8..089860f1 100644 --- a/src/core.jl +++ b/src/core.jl @@ -26,15 +26,15 @@ function show(io::IO, t::Tape) end end end -@inline getindex(t::Tape, n::Int) = getindex(tape(t), n) -@inline getindex(t::Tape, node::Node) = getindex(t, pos(node)) -@inline lastindex(t::Tape) = length(t) -@inline setindex!(t::Tape, x, n::Int) = (tape(t)[n] = x; t) -@inline eachindex(t::Tape) = eachindex(tape(t)) -@inline length(t::Tape) = length(tape(t)) -@inline push!(t::Tape, node::Node) = (push!(tape(t), node); t) -@inline isassigned(t::Tape, n::Int) = isassigned(tape(t), n) -@inline isassigned(t::Tape, node::Node) = isassigned(t, pos(node)) +getindex(t::Tape, n::Int) = getindex(tape(t), n) +getindex(t::Tape, node::Node) = getindex(t, pos(node)) +lastindex(t::Tape) = length(t) +setindex!(t::Tape, x, n::Int) = (tape(t)[n] = x; t) +eachindex(t::Tape) = eachindex(tape(t)) +length(t::Tape) = length(tape(t)) +push!(t::Tape, node::Node) = (push!(tape(t), node); t) +isassigned(t::Tape, n::Int) = isassigned(tape(t), n) +isassigned(t::Tape, node::Node) = isassigned(t, pos(node)) # Make `Tape`s broadcast as scalars without a warning on 0.7 Base.Broadcast.broadcastable(tape::Tape) = Ref(tape) @@ -122,7 +122,7 @@ zero(n::Node) = zero(unbox(n)) one(n::Node) = one(unbox(n)) # Leafs do nothing, Branches compute their own sensitivities and update others. -@inline propagate(y::Leaf, rvs_tape::Tape) = nothing +propagate(y::Leaf, rvs_tape::Tape) = nothing function propagate(y::Branch, rvs_tape::Tape) tape = Nabla.tape(rvs_tape) ȳ, f = tape[pos(y)], getfield(y, :f) @@ -177,11 +177,11 @@ is the output of `preprocess`. `x1`, `x2`,... are the inputs to the function, `y output and `ȳ` the reverse-mode sensitivity of `y`. """ ∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ)) -@inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y))) +∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y))) # This is a fallback method where we don't necessarily know what we'll be adding and whether # we can update the value in-place, so we'll try to be clever and dispatch. -@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...)) +∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...)) # Update regular arrays in-place. Structured array types should not be updated in-place, # even though it technically "works" (https://github.com/JuliaLang/julia/issues/31674), @@ -242,14 +242,14 @@ for (f_name, scalar_init, array_init) in (:zero, :one, nothing), (:zeros, :ones, nothing)) if scalar_init !== nothing - @eval @inline $f_name(x::Number) = $scalar_init(x) + @eval $f_name(x::Number) = $scalar_init(x) end if array_init !== nothing - @eval @inline $f_name(x::AbstractArray{<:Real}) = $array_init(eltype(x), size(x)) + @eval $f_name(x::AbstractArray{<:Real}) = $array_init(eltype(x), size(x)) end eval(quote - @inline $f_name(x::Tuple) = map($f_name, x) - @inline function $f_name(x) + $f_name(x::Tuple) = map($f_name, x) + function $f_name(x) y = Base.copy(x) for n in eachindex(y) @inbounds y[n] = $f_name(y[n]) @@ -258,10 +258,10 @@ for (f_name, scalar_init, array_init) in end end) end -@inline randned_container(x::Number) = randn(typeof(x)) -@inline randned_container(x::AbstractArray{<:Real}) = randn(eltype(x), size(x)...) +randned_container(x::Number) = randn(typeof(x)) +randned_container(x::AbstractArray{<:Real}) = randn(eltype(x), size(x)...) for T in (:Diagonal, :UpperTriangular, :LowerTriangular) - @eval @inline randned_container(x::$T{<:Real}) = $T(randn(eltype(x), size(x)...)) + @eval randned_container(x::$T{<:Real}) = $T(randn(eltype(x), size(x)...)) end # Bare-bones FMAD implementation based on DualNumbers. Accepts a Tuple of args and returns diff --git a/src/sensitivities/functional/functional.jl b/src/sensitivities/functional/functional.jl index f72e75db..d76140a1 100644 --- a/src/sensitivities/functional/functional.jl +++ b/src/sensitivities/functional/functional.jl @@ -101,39 +101,39 @@ _∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N = # Addition. import Base: + @eval @explicit_intercepts $(Symbol("+")) Tuple{∇Array, ∇Array} -@inline ∇(::typeof(+), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) = +∇(::typeof(+), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) = ∇(broadcast, Arg{2}, p, z, z̄, +, x, y) -@inline ∇(::typeof(+), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) = +∇(::typeof(+), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) = ∇(broadcast, Arg{3}, p, z, z̄, +, x, y) # Multiplication. import Base: * @eval @explicit_intercepts $(Symbol("*")) Tuple{∇ArrayOrScalar, ∇ArrayOrScalar} -@inline ∇(::typeof(*), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = +∇(::typeof(*), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = ∇(broadcast, Arg{2}, p, z, z̄, *, x, y) -@inline ∇(::typeof(*), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = +∇(::typeof(*), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = ∇(broadcast, Arg{3}, p, z, z̄, *, x, y) # Subtraction. import Base: - @eval @explicit_intercepts $(Symbol("-")) Tuple{∇Array, ∇Array} -@inline ∇(::typeof(-), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) = +∇(::typeof(-), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) = ∇(broadcast, Arg{2}, p, z, z̄, -, x, y) -@inline ∇(::typeof(-), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) = +∇(::typeof(-), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) = ∇(broadcast, Arg{3}, p, z, z̄, -, x, y) # Division from the right by a scalar. import Base: / @eval @explicit_intercepts $(Symbol("/")) Tuple{∇Array, ∇Scalar} -@inline ∇(::typeof(/), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = +∇(::typeof(/), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = ∇(broadcast, Arg{2}, p, z, z̄, /, x, y) -@inline ∇(::typeof(/), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = +∇(::typeof(/), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = ∇(broadcast, Arg{3}, p, z, z̄, /, x, y) # Division from the left by a scalar. import Base: \ @eval @explicit_intercepts $(Symbol("\\")) Tuple{∇Scalar, ∇Array} -@inline ∇(::typeof(\), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = +∇(::typeof(\), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = ∇(broadcast, Arg{2}, p, z, z̄, \, x, y) -@inline ∇(::typeof(\), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = +∇(::typeof(\), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = ∇(broadcast, Arg{3}, p, z, z̄, \, x, y) diff --git a/src/sensitivities/scalar.jl b/src/sensitivities/scalar.jl index 531854b6..870d8295 100644 --- a/src/sensitivities/scalar.jl +++ b/src/sensitivities/scalar.jl @@ -6,8 +6,8 @@ using DiffRules: DiffRules, @define_diffrule, diffrule, diffrules, hasdiffrule # gradient implemented for use in higher-order functions. import Base.identity @explicit_intercepts identity Tuple{Any} -@inline ∇(::typeof(identity), ::Type{Arg{1}}, p, y, ȳ, x) = ȳ -@inline ∇(::typeof(identity), ::Type{Arg{1}}, x::Real) = one(x) +∇(::typeof(identity), ::Type{Arg{1}}, p, y, ȳ, x) = ȳ +∇(::typeof(identity), ::Type{Arg{1}}, x::Real) = one(x) # Ignore functions that have complex ranges. This may change when Nabla supports complex # numbers. @@ -29,8 +29,8 @@ for (package, f, arity) in diffrules() push!(unary_sensitivities, (package, f)) ∂f∂x = diffrule(package, f, :x) @eval @explicit_intercepts $f Tuple{∇Scalar} - @eval @inline ∇(::typeof($f), ::Type{Arg{1}}, p, y, ȳ, x::∇Scalar) = ȳ * $∂f∂x - @eval @inline ∇(::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x + @eval ∇(::typeof($f), ::Type{Arg{1}}, p, y, ȳ, x::∇Scalar) = ȳ * $∂f∂x + @eval ∇(::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x elseif arity == 2 push!(binary_sensitivities, (package, f)) ∂f∂x, ∂f∂y = diffrule(package, f, :x, :y) diff --git a/src/sensitivity.jl b/src/sensitivity.jl index e6590ccc..f2a4f348 100644 --- a/src/sensitivity.jl +++ b/src/sensitivity.jl @@ -277,4 +277,4 @@ Default implementation of preprocess returns an empty Tuple. Individual sensitiv implementations should add methods specific to their use case. The output is passed in to `∇` as the 3rd or 4th argument in the new-x̄ and update-x̄ cases respectively. """ -@inline preprocess(::Any, args...) = () +preprocess(::Any, args...) = ()