Skip to content

Commit

Permalink
API simplification.
Browse files Browse the repository at this point in the history
  • Loading branch information
rofinn committed Jul 5, 2019
1 parent 9f7cce7 commit 2066b5b
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 207 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
IterTools = "1.2"
Tables = "0.2"

[extras]
Expand Down
129 changes: 23 additions & 106 deletions src/Impute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Statistics
using StatsBase
using Tables: Tables, materializer, istable

import Base.Iterators
import Base.Iterators: drop

export impute, impute!, chain, chain!, drop, drop!, interp, interp!, ImputeError

Expand All @@ -18,6 +18,13 @@ function __init__()
Please qualify your calls with `Impute.<method>(...)` or explicitly import the symbol.
"""
)

@warn(
"""
The default limit for all impute functions will be 1.0 going forward.
If you depend on a specific threshold please pass in an appropriate `AbstractContext`.
"""
)
end

"""
Expand Down Expand Up @@ -45,112 +52,22 @@ const global imputation_methods = Dict{Symbol, Type}(
:nocb => NOCB,
)

"""
impute!(data, method::Symbol=:interp, args...; limit::Float64=0.1)
Looks up the `Imputor` type for the `method`, creates it and calls
`impute!(imputor::Imputor, data, limit::Float64)` with it.
# Arguments
* `data`: the datset containing missing elements we should impute.
* `method::Symbol`: the imputation method to use
(options: [`:drop`, `:fill`, `:interp`, `:locf`, `:nocb`])
* `args::Any...`: any arguments you should pass to the `Imputor` constructor.
* `limit::Float64`: missing data ratio limit/threshold (default: 0.1)
"""
function impute!(data, method::Symbol, args...; limit::Float64=0.1)
imputor_type = imputation_methods[method]
imputor = length(args) > 0 ? imputor_type(args...) : imputor_type()
return impute!(imputor, data, limit)
end

"""
impute!(data, missing::Function, method::Symbol=:interp, args...; limit::Float64=0.1)
Creates the appropriate `Imputor` type and `Context` (using `missing` function) in order to call
`impute!(imputor::Imputor, ctx::Context, data)` with them.
# Arguments
* `data`: the datset containing missing elements we should impute.
* `missing::Function`: the missing data function to use
* `method::Symbol`: the imputation method to use
(options: [`:drop`, `:fill`, `:interp`, `:locf`, `:nocb`])
* `args::Any...`: any arguments you should pass to the `Imputor` constructor.
* `limit::Float64`: missing data ratio limit/threshold (default: 0.1)
"""
function impute!(data, missing::Function, method::Symbol, args...; limit::Float64=0.1)
imputor_type = imputation_methods[method]
imputor = length(args) > 0 ? imputor_type(args...) : imputor_type()
return Context(; limit=limit, is_missing=missing)() do ctx
impute!(imputor, ctx, data)
include("deprecated.jl")

let
for (k, v) in imputation_methods
local typename = nameof(v)
local f = k
local f! = Symbol(k, :!)

# NOTE: The
@eval begin
$f(data; kwargs...) = impute($typename(; context=Context(Dict(kwargs...))), data)
$f!(data; kwargs...) = impute!($typename(; context=Context(Dict(kwargs...))), data)
$f(; kwargs...) = data -> impute($typename(; context=Context(Dict(kwargs...))), data)
$f!(; kwargs...) = data -> impute!($typename(; context=Context(Dict(kwargs...))), data)
end
end
end

"""
impute(data, args...; kwargs...)
Copies the `data` before calling `impute!(new_data, args...; kwargs...)`
"""
function impute(data, args...; kwargs...)
return impute!(deepcopy(data), args...; kwargs...)
end

"""
chain!(data, missing::Function, imputors::Imputor...; kwargs...)
Creates a `Chain` with `imputors` and calls `impute!(imputor, missing, data; kwargs...)`
"""
function chain!(data, missing::Function, imputors::Imputor...; kwargs...)
imputor = Chain(imputors...)
return impute!(imputor, missing, data; kwargs...)
end

"""
chain!(data, imputors::Imputor...; kwargs...)
Creates a `Chain` with `imputors` and calls `impute!(imputor, data; kwargs...)`
"""
function chain!(data, imputors::Imputor...; kwargs...)
imputor = Chain(imputors...)
return impute!(imputor, data; kwargs...)
end

"""
chain(data, args...; kwargs...)
Copies the `data` before calling `chain!(data, args...; kwargs...)`
"""
function chain(data, args...; kwargs...)
result = deepcopy(data)
return chain!(data, args...; kwargs...)
end

"""
drop!(data; limit=1.0)
Utility method for `impute!(data, :drop; limit=limit)`
"""
drop!(data; limit=1.0) = impute!(data, :drop; limit=limit)

"""
drop(data; limit=1.0)
Utility method for `impute(data, :drop; limit=limit)`
"""
Iterators.drop(data; limit=1.0) = impute(data, :drop; limit=limit)

"""
interp!(data; limit=1.0)
Utility method for `impute!(data, :interp; limit=limit)`
"""
interp!(data; limit=1.0) = impute!(data, :interp; limit=limit)

"""
interp(data; limit=1.0)
Utility method for `impute(data, :interp; limit=limit)`
"""
interp(data; limit=1.0) = impute(data, :interp; limit=limit)

end # module
18 changes: 17 additions & 1 deletion src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,29 @@ mutable struct Context <: AbstractContext
end

function Context(;
limit::Float64=1.0,
limit::Float64=0.1,
is_missing::Function=ismissing,
on_complete::Function=complete
)
Context(0, 0, limit, is_missing, on_complete)
end

# The constructor only exists for legacy reasons
# We should drop this when we're ready to stop accepting limit in
# arbitrary impute functions.
function Context(d::Dict)
if haskey(d, :context)
return d[:context]
else haskey(d, :limit)
return Context(;
# We using a different default limit value here for legacy reason.
limit=get(d, :limit, 1.0),
is_missing=get(d, :is_missing, ismissing),
on_complete=get(d, :on_complete, complete),
)
end
end

function (ctx::Context)(f::Function)
_ctx = copy(ctx)
_ctx.num = 0
Expand Down
43 changes: 9 additions & 34 deletions src/imputors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,80 +11,55 @@ abstract type Imputor end


"""
impute(imp::Imputor, data, limit=0.1)
impute(imp::Imputor, ctx, data)
impute(imp::Imputor, data)
Copies the `data` before calling the corresponding `impute!(imp, ...)` call.
"""
impute(imp::Imputor, data) = impute!(imp, deepcopy(data))
impute(imp::Imputor, ctx::AbstractContext, data) = impute!(imp, ctx, deepcopy(data))

"""
impute!(imp::Imputor, data, limit::Float64=0.1)
Creates a `Context` using information about `data`. These include
1. missing data function which defaults to `missing`
2. number of elements: `*(size(data)...)`
# Arguments
* `imp::Imputor`: the Imputor method to use
* `data`: the data to impute
* `limit::Float64: missing data ratio limit/threshold (default: 0.1)`
# Return
* the input `data` with values imputed.
"""
function impute!(imp::Imputor, data, limit::Float64=0.1)
Context(; limit=limit)() do ctx
return impute!(imp, ctx, data)
end
function impute(imp::Imputor, data)
impute!(imp, deepcopy(data))
end

"""
impute!(imp::Imputor, ctx::AbstractContext, data::AbstractMatrix)
impute!(imp::Imputor, data::AbstractMatrix)
Imputes the data in a matrix by imputing the values 1 column at a time;
if this is not the desired behaviour custom imputor methods should overload this method.
# Arguments
* `imp::Imputor`: the Imputor method to use
* `ctx::AbstractContext`: the contextual information for missing data
* `data::AbstractMatrix`: the data to impute
# Returns
* `AbstractMatrix`: the input `data` with values imputed
"""
function impute!(imp::Imputor, ctx::AbstractContext, data::AbstractMatrix)
function impute!(imp::Imputor, data::AbstractMatrix)
for i in 1:size(data, 2)
impute!(imp, ctx, view(data, :, i))
impute!(imp, view(data, :, i))
end
return data
end

"""
impute!(imp::Imputor, ctx::AbstractContext, table)
impute!(imp::Imputor, table)
Imputes the data in a table by imputing the values 1 column at a time;
if this is not the desired behaviour custom imputor methods should overload this method.
# Arguments
* `imp::Imputor`: the Imputor method to use
* `ctx::AbstractContext`: the contextual information for missing data
* `table`: the data to impute
# Returns
* the input `data` with values imputed
"""
function impute!(imp::Imputor, ctx::AbstractContext, table)
function impute!(imp::Imputor, table)
@assert istable(table)
# Extract a columns iterate that we should be able to use to mutate the data.
# NOTE: Mutation is not guaranteed for all table types, but it avoid copying the data
columntable = Tables.columns(table)

for cname in propertynames(columntable)
impute!(imp, ctx, getproperty(columntable, cname))
impute!(imp, getproperty(columntable, cname))
end

return table
Expand Down
31 changes: 4 additions & 27 deletions src/imputors/chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,44 +18,21 @@ Creates a Chain using the `Imputor`s provided (ordering matters).
Chain(imputors::Imputor...) = Chain(collect(imputors))

"""
impute!(imp::Chain, missing::Function, data; limit::Float64=0.1)
impute!(imp::Chain, data)
Creates a `Context` and runs the `Imputor`s on the supplied data.
Runs the `Imputor`s on the supplied data.
# Arguments
* `imp::Chain`: the chain to run
* `missing::Function`: the missing function to use in the `Context` to pass to the `Imputor`s
* `data`: our data to impute
* `limit::Float64`: the missing data ration limit/threshold
# Returns
* our imputed data
"""
function impute!(imp::Chain, missing::Function, data; limit::Float64=0.1)
context = Context(; limit=limit, is_missing=missing)

function impute!(imp::Chain, data)
for imputor in imp.imputors
data = impute!(imputor, context, data)
data = impute!(imputor, data)
end

return data
end

"""
impute!(imp::Chain, data; limit::Float64=0.1)
Infers the missing data function from the `data` and passes that to
`impute!(imp::Chain, missing::Function, data; limit::Float64=0.1)`.
# Arguments
* `imp::Chain`: the chain to run
* `data`: our data to impute
* `limit::Float64`: the missing data ration limit/threshold
# Returns
* our imputed data
"""
function impute!(imp::Chain, data; limit::Float64=0.1)
impute!(imp, ismissing, data; limit=limit)
end
Loading

0 comments on commit 2066b5b

Please sign in to comment.