Skip to content

Commit

Permalink
Make non-accumulators local
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Mar 4, 2022
1 parent b498a7f commit 513df39
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 2 deletions.
120 changes: 119 additions & 1 deletion src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,125 @@ end
# https://github.com/tkf/ThreadsX.jl/pull/106. But this should be
# done automatically in Transducers.jl.

"""
localize(body::Expr) -> body′::Expr
Add `local`s to make non-accumulators local in the loop `body`.
It adds `local v` to the inner most scope-creating lexical block that contain
the outer most lexical block that has an assignment to `v`.
"""
localize(body::Expr) = localize!(body)

struct VariableEnv
current::Vector{Symbol}
outer::Union{VariableEnv,Nothing}
end

function Base.in(v::Symbol, env::VariableEnv)
v in env.current && return true
outer = env.outer
return outer !== nothing && v in outer
end

function Base.push!(env::VariableEnv, v::Symbol)
v in env || push!(env.current, v)
return env
end

Base.append!(env::VariableEnv, variables) = foldl(push!, variables; init = env)

localize!(@nospecialize(body), ::VariableEnv) = body
function localize!(body::Expr, env::VariableEnv = VariableEnv(Symbol[], nothing))
add_new_assignments!(env, body)
body = localize_nested(body, env)
if isempty(env.current)
return body
else
return Expr(:block, Expr(:local, env.current...), body)
end
end

# Not that this doesn't handle `local lhs = rhs` but it's OK since it doesn't
# matter exactly which scope the variable comes from.
function add_new_assignments!(env::VariableEnv, ex::Expr)
@match ex begin
Expr(:meta, _...) => nothing
Expr(:loopinfo, _...) => nothing

Expr(:function, Expr(:call, f::Symbol, _...), _...) => push!(env, f)
Expr(:(=), lhs, rhs) => begin
@match lhs begin
Expr(:call, f::Symbol, _...) => push!(env, f)
# TODO: handle where
_ => begin
if rhs isa Expr
add_new_assignments!(env, rhs)
end
append!(env, vars_in(lhs))
end
end
end

# Scope-creating
Expr(:let, _...) => nothing
Expr(:function, _...) => nothing
Expr(:->, _...) => nothing

Expr(_, args...) => begin
for x in args
if x isa Expr
add_new_assignments!(env, x)
end
end
end
end
end

localize_nested(@nospecialize(body), ::VariableEnv) = body
function localize_nested(body::Expr, env::VariableEnv)
@match body begin
Expr(:let, let_bindings_, let_body) => begin
let_bindings = @match let_bindings_ begin
Expr(:block, args...) => collect(args)
b => [b]
end
let_vars = mapfoldl(vars_in, append!, let_bindings; init = Symbol[])
localize!(let_body, VariableEnv(let_vars, env))
end
Expr(:(=), lhs, rhs) => begin
if isexpr(lhs, :call)
Expr(:(=), lhs, localize!(rhs, VariableEnv(Symbol[], env)))
else
Expr(:(=), lhs, localize_nested(rhs, env))
end
end
Expr(:function, call, def) => Expr(
:function,
localize_nested(call, env),
localize!(def, VariableEnv(Symbol[], env)),
)
Expr(:->, lhs, rhs) =>
Expr(:->, localize_nested(lhs, env), localize!(rhs, VariableEnv(Symbol[], env)))

Expr(:meta, _...) => body
Expr(:loopinfo, _...) => body

Expr(head, args...) => begin
args = mapfoldl(push!, args; init = []) do x
if x isa Expr
localize_nested(x, env)
else
x
end
end
Expr(head, args...)
end
end
end

function transform_loop_body(body, state_vars)
# body = localize(body) # TODO: enable this for sequential case as well
external_labels::Vector{Symbol} = setdiff(gotos_in(body), labels_in(body))
# state_vars = extract_state_vars(body)
pack_state = :(($(state_vars...),))
Expand Down Expand Up @@ -236,7 +354,7 @@ end
vars_in(x::Symbol) = [x]
function vars_in(ex)
@match ex begin
Expr(:tuple, vars...) => vars
Expr(:tuple, vars...) => mapfoldl(vars_in, append!, vars; init = Symbol[])
_ => Symbol[]
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ function as_parallel_loop(ctx::MacroContext, rf_arg, coll, body0::Expr, simd, ex
end
check_invariance()

body2, info = transform_loop_body(body1, accs_symbols)
body2, info = transform_loop_body(localize(body1), accs_symbols)

@gensym oninit_function reducing_function combine_function result context_function
if ctx.module_ === Main
Expand Down

0 comments on commit 513df39

Please sign in to comment.