Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use BroadcastThunk #705

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
Adapt = "3.4.0"
ChainRulesCore = "1.15.3"
ChainRulesCore = "1.16.0"
ChainRulesTestUtils = "1.5"
Compat = "3.46, 4.2"
FiniteDifferences = "0.12.20"
Expand Down
27 changes: 19 additions & 8 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,12 @@ end
function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
Y = A/b
function slash_pullback_scalar(ȳ)
Ȳ = unthunk(ȳ)
Athunk = InplaceableThunk(
dA -> dA .+= Ȳ ./ conj(b),
@thunk(Ȳ / conj(b)),
)
Ȳ = unthunk_bc(ȳ)
# Athunk = InplaceableThunk(
# dA -> dA .+= Ȳ ./ conj(b),
# @thunk(Ȳ / conj(b)),
# )
Athunk = @bc_thunk Ȳ / conj(b)
bthunk = @thunk(-dot(A,Ȳ) / conj(b^2))
return (NoTangent(), Athunk, bthunk)
end
Expand All @@ -400,7 +401,9 @@ frule((_, ΔA), ::typeof(-), A::AbstractArray) = -A, -ΔA

# Reverse rule for unary minus on an array: the primal result is `-x`, and the pullback
# returns `NoTangent()` for the function plus a negated cotangent for `x`.
# NOTE(review): this text is a rendered diff without +/- markers, so the first `return`
# below and the statements after it look like the removed and added forms of the same
# code — confirm against the real file which one is live.
function rrule(::typeof(-), x::AbstractArray)
function negation_pullback(ȳ)
# Lazy negation of the cotangent; the in-place closure subtracts `ȳ` from an accumulator `ā`.
return NoTangent(), InplaceableThunk(ā -> ā .-= ȳ, @thunk(-ȳ))
# NOTE(review): `unthunk_bc` is neither defined nor imported in the visible text, while the
# broadcast rules import `unthunk_or_bc` — confirm the intended name.
Ȳ = unthunk_bc(ȳ)
# return NoTangent(), InplaceableThunk(ā -> ā .-= ȳ, @thunk(-ȳ))
return (NoTangent(), @bc_thunk -Ȳ)
end
return -x, negation_pullback
end
Expand All @@ -415,9 +418,17 @@ frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...)
# Reverse rule for adding any number of arrays. The primal is `+(arrs...)`; the axes of every
# input are captured so the pullback can hand back a cotangent shaped like each input.
# NOTE(review): this text is a rendered diff without +/- markers. The two consecutive
# `function add_pullback…` headers and the early `return y, add_pullback` look like removed
# lines sitting next to their replacements — confirm against the real file.
function rrule(::typeof(+), arrs::AbstractArray...)
y = +(arrs...)
arr_axs = map(axes, arrs)
function add_pullback(dy_raw)
# General pullback: unthunk once, then reshape the cotangent to each input's axes.
function add_pullback_2(dy_raw)
dy = unthunk(dy_raw) # reshape will otherwise unthunk N times
return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...)
end
return y, add_pullback
# Fast path: taken only when every input's axes are `===` to the first input's axes.
if all(ax -> ax===arr_axs[1], arr_axs)
# Here no reshape is needed,
# so array and `BroadcastThunk` cotangents are returned as-is, once per input.
add_pullback_1(dy::AbstractArray) = (NoTangent(), map(Returns(dy), arr_axs)...)
add_pullback_1(dy::BroadcastThunk) = (NoTangent(), map(Returns(dy), arr_axs)...)
# Any other cotangent type falls back to the unthunk-and-reshape pullback.
add_pullback_1(dy) = add_pullback_2(dy)
return y, add_pullback_1
else
return y, add_pullback_2
end
end
55 changes: 37 additions & 18 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,42 +156,50 @@ end

# Argument types accepted by the broadcast rules in this section: a number, an array of
# numbers, a tuple of numbers, or a lazy `Broadcast.Broadcasted` object.
const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted}

# Experimentally, also allow fused reverse passes, via `dx::BroadcastThunk`. This should similarly
# be used only on functions cheap enough that running twice is better than materialising the array.

using ChainRulesCore: BroadcastThunk, unthunk_or_bc, @bc_thunk

##### Arithmetic: +, -, *, ^2, /

# Reverse rule for a lazy broadcast of `+` over any number of arguments. Returns the
# still-lazy `broadcasted(+, xs...)` together with its pullback.
function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
@debug("broadcasting: plus", length(xs))
function bc_plus_back(dy_raw)
# NOTE(review): this text is a rendered diff without +/- markers; the two consecutive
# assignments to `dy` look like the removed and added versions of one line — confirm.
dy = unthunk(dy_raw)
dy = unthunk_or_bc(dy_raw) # this allows BroadcastThunk through, which unbroadcast understands
# Two leading `NoTangent()`s, then one `unbroadcast(x, dy)` result per argument in `xs`.
return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...) # no copies, this may return dx2 === dx3
end
return broadcasted(+, xs...), bc_plus_back
end

# Reverse rule for a lazy broadcast of two-argument `-`. Returns the still-lazy
# `broadcasted(-, x, y)` together with its pullback.
function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
# NOTE(review): this text is a rendered diff without +/- markers; the two `@debug` lines
# look like the removed and added versions of one line — confirm.
@debug("broadcasting: minus 2")
@debug("broadcasting: 2-arg minus")
# Generic cotangent: unthunk once, then defer each `unbroadcast` inside a `@thunk`;
# the `y` tangent is the negated reduction.
function bc_minus_back(dz_raw)
dz = unthunk(dz_raw)
return (NoTangent(), NoTangent(), @thunk(unbroadcast(x, dz)), @thunk(-unbroadcast(y, dz)))
end
# `BroadcastThunk` cotangent: `unbroadcast` is called directly (no `@thunk` wrapper), and
# for `y` the negation is applied to `dz` before reducing.
function bc_minus_back(dz::BroadcastThunk) # ?? confused about double-thunking here
return (NoTangent(), NoTangent(), unbroadcast(x, dz), unbroadcast(y, -dz))
end
return broadcasted(-, x, y), bc_minus_back
end

# Reverse rule for a lazy broadcast of unary `-`. Returns the still-lazy `broadcasted(-, x)`
# together with a pullback whose only non-trivial tangent is the negated cotangent.
# NOTE(review): this text is a rendered diff without +/- markers; the paired `@debug` lines
# and the two definitions of `bc_minus_back` look like removed and added versions — confirm
# which is live. As written, the second definition replaces the first.
function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast)
@debug("broadcasting: minus 1")
# Variant that unthunks `dy` and defers the negation inside a `@thunk`.
bc_minus_back(dy) = (NoTangent(), NoTangent(), @thunk -unthunk(dy))
@debug("broadcasting: 1-arg minus")
# Variant that negates `dy` directly, whatever type it has.
bc_minus_back(dy) = (NoTangent(), NoTangent(), -dy)
return broadcasted(-, x), bc_minus_back
end

# Reverse rule for a lazy broadcast of two-argument `*`. Returns the still-lazy
# `broadcasted(*, x, y)` together with its pullback.
function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
@debug("broadcasting: times")
function bc_times_back(Δraw)
# NOTE(review): this text is a rendered diff without +/- markers; the two consecutive
# assignments to `Δ` look like the removed and added versions of one line — confirm.
Δ = unthunk(Δraw)
Δ = unthunk_or_bc(Δraw)
# Each argument's tangent comes from `_back_star`, called with the argument itself first
# and the other factor second.
return (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ))
end
return broadcasted(*, x, y), bc_times_back
end
# `_back_star(x, y, Δ)` computes the tangent for the factor `x` of `x .* y`, given the other
# factor `y` and the cotangent `Δ`; methods are selected by the type of `x`.
# NOTE(review): this text is a rendered diff without +/- markers; the first two definitions
# (wrapped in `@thunk`) and the next two (not wrapped) have identical signatures and look
# like removed and added versions — confirm which pair is live.
_back_star(x, y, Δ) = @thunk unbroadcast(x, Δ .* conj.(y)) # this case probably isn't better than generic
_back_star(x::Number, y, Δ) = @thunk LinearAlgebra.dot(y, Δ) # ... but this is why the rule exists
_back_star(x, y, Δ) = unbroadcast(x, @bc_thunk Δ * conj(y)) # ?? confused about double thunking
_back_star(x::Number, y, Δ) = LinearAlgebra.dot(y, Δ)
# Boolean factors (real or complex) get no tangent.
_back_star(x::Bool, y, Δ) = NoTangent()
_back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x)

Expand All @@ -210,8 +218,8 @@ end

function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2})
@debug("broadcasting: square")
function bc_square_back(dy_raw)
dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x))
function bc_square_back(dy)
dx = @thunk ProjectTo(x)(@bc_thunk 2 * dy * conj(x))
return (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
end
return broadcasted(Base.literal_pow, ^, x, Val(2)), bc_square_back
Expand All @@ -222,10 +230,10 @@ function rrule(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Num
# z = broadcast(/, x, y)
z = broadcasted(/, x, y)
function bc_divide_back(dz_raw)
dz = unthunk(dz_raw)
dx = @thunk unbroadcast(x, dz ./ conj.(y))
dz = unthunk(dz_raw) # ??
dx = @thunk unbroadcast(x, @bc_thunk dz / conj(y))
# dy = @thunk -LinearAlgebra.dot(z, dz) / conj(y) # the reason to be eager is to allow dot here
dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast
dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast ??
return (NoTangent(), NoTangent(), dx, dy)
end
return z, bc_divide_back
Expand Down Expand Up @@ -256,14 +264,14 @@ rrule(::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x)

# Reverse rule for broadcasting a numeric type constructor `T` over `x`. Returns the
# still-lazy `broadcasted(T, x)`; the pullback lazily reduces the cotangent to `x` via
# `unbroadcast`.
function rrule(::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number}
@debug("broadcasting: type", T)
# NOTE(review): this text is a rendered diff without +/- markers; the two definitions of
# `bc_type_back` (differing only in `unthunk` vs `unthunk_or_bc`) look like removed and
# added versions of one line — confirm.
bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk_or_bc(dz))))
return broadcasted(T, x), bc_type_back
end
# Scalar input: reuse the non-broadcast rule `rrule(T, x)` and pipe it through
# `_prepend_zero` (defined elsewhere; presumably adds the extra leading tangent — confirm).
rrule(::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero

# Reverse rule for a lazy broadcast of `float`. Returns the still-lazy
# `broadcasted(float, x)`; the pullback lazily reduces the cotangent to `x` via `unbroadcast`.
function rrule(::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast)
@debug("broadcasting: float")
# NOTE(review): this text is a rendered diff without +/- markers; the two definitions of
# `bc_float_back` (differing only in `unthunk` vs `unthunk_or_bc`) look like removed and
# added versions of one line — confirm.
bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk_or_bc(dz))))
return broadcasted(float, x), bc_float_back
end
# Scalar input: reuse the non-broadcast rule `rrule(float, x)` and pipe it through
# `_prepend_zero` (defined elsewhere; presumably adds the extra leading tangent — confirm).
rrule(::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _prepend_zero
Expand All @@ -273,7 +281,7 @@ rrule(::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _p
for conj in [:conj, :adjoint] # identical as we know eltype <: Number
@eval begin
function rrule(::typeof(broadcasted), ::typeof($conj), x::NumericOrBroadcast)
bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx)))
bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk_or_bc(dx)))
return broadcasted($conj, x), bc_conj_back
end
rrule(::typeof(broadcasted), ::typeof($conj), x::Number) = rrule($conj, x) |> _prepend_zero
Expand All @@ -285,15 +293,15 @@ end

# Reverse rule for a lazy broadcast of `real`. Returns the still-lazy `broadcasted(real, x)`;
# the pullback lazily takes `real` of the cotangent.
function rrule(::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast)
@debug("broadcasting: real")
# NOTE(review): this text is a rendered diff without +/- markers; the two definitions of
# `bc_real_back` (differing only in `unthunk` vs `unthunk_or_bc`) look like removed and
# added versions of one line — confirm.
bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz))))
bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk_or_bc(dz))))
return broadcasted(real, x), bc_real_back
end
# Scalar input: reuse the non-broadcast rule `rrule(real, x)`, piped through `_prepend_zero`
# (defined elsewhere; presumably adds the extra leading tangent — confirm).
rrule(::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero
# Arrays that are already real: handled with the rule for `identity`.
rrule(::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero

# Reverse rule for a lazy broadcast of `imag`. Returns the still-lazy `broadcasted(imag, x)`;
# the pullback lazily computes `im` times the real part of the cotangent.
function rrule(::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast)
@debug("broadcasting: imag")
# NOTE(review): this text is a rendered diff without +/- markers; the two definitions of
# `bc_imag_back` (differing only in `unthunk` vs `unthunk_or_bc`) look like removed and
# added versions of one line — confirm.
bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz))))
bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk_or_bc(dz))))
return broadcasted(imag, x), bc_imag_back
end
# Scalar input: reuse the non-broadcast rule `rrule(imag, x)`, piped through `_prepend_zero`
# (defined elsewhere; presumably adds the extra leading tangent — confirm).
rrule(::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero
Expand All @@ -305,7 +313,7 @@ end

# Reverse rule for a lazy broadcast of `complex`. Returns the still-lazy
# `broadcasted(complex, x)`; the pullback lazily reduces the cotangent to `x` via `unbroadcast`.
function rrule(::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast)
@debug("broadcasting: complex")
# NOTE(review): this text is a rendered diff without +/- markers; the two definitions of
# `bc_complex_back` (differing in `unthunk` vs `unthunk_or_bc` and in `@thunk` call syntax)
# look like removed and added versions of one line — confirm.
bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk unbroadcast(x, unthunk_or_bc(dz)))
return broadcasted(complex, x), bc_complex_back
end
# Scalar input: reuse the non-broadcast rule `rrule(complex, x)`, piped through
# `_prepend_zero` (defined elsewhere; presumably adds the extra leading tangent — confirm).
rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |> _prepend_zero
Expand All @@ -327,6 +335,17 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx_raw)
ProjectTo(x)(sum(dx; dims))
end
end
# Reduce a lazy cotangent `dx` (a `BroadcastThunk` wrapping the broadcast `dx.bc`) to match
# the broadcast argument `x`, then project it with `ProjectTo(x)`.
#
# - Equal lengths: `dx` itself is projected, without touching `dx.bc`.
# - Otherwise: `dx.bc` is summed over the dimensions in which `x` has size 1 (or which `x`
#   lacks), and the sum is projected.
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::BroadcastThunk)
    lazy = dx.bc
    if length(x) == length(lazy)
        return ProjectTo(x)(dx)  # this may return a BroadcastThunk
    end

    rank = ndims(lazy)
    # hack around sum(::Broadcasted; dims) not working: pass an explicit `init`
    elt = Base.@default_eltype lazy
    start = zero(elt)
    # Dimension `d` is reduced when `x` has size 1 there; dimensions to keep are mapped
    # to `rank + 1`, which lies beyond the dimensions of `lazy`.
    reduce_dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : rank + 1, rank)
    return ProjectTo(x)(sum(lazy; dims=reduce_dims, init=start))
end
# A zero tangent is returned unchanged: no summation or projection is applied.
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx

_ndims(x) = ndims(x)
Expand Down