Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Pipeline and TransformedTargetModel to support feature importances #963

Merged
merged 5 commits into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "1.1.3"
version = "1.2.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
36 changes: 26 additions & 10 deletions src/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,13 @@ end

# ## Methods to extend a pipeline learning network

# The "front" of a pipeline network, as we grow it, consists of a
# "predict" and a "transform" node. Once the pipeline is complete
# (after a series of `extend` operations - see below) the "transform"
# node is what is used to deliver the output of `transform(pipe)` in
# the exported model, and the "predict" node is what will be used to
# deliver the output of `predict(pipe). Both nodes can be changed by
# `extend` but only the "active" node is propagated. Initially
# "transform" is active and "predict" only becomes active when a
# supervised model is encountered; this change is permanent.
# The "front" of a pipeline network, as we grow it, consists of a "predict" and a
# "transform" node. Once the pipeline is complete (after a series of `extend` operations -
# see below) the "transform" node is what is used to deliver the output of
# `transform(pipe, ...)` in the exported model, and the "predict" node is what will be
# used to deliver the output of `predict(pipe, ...). Both nodes can be changed by `extend`
# but only the "active" node is propagated. Initially "transform" is active and "predict"
# only becomes active when a supervised model is encountered; this change is permanent.
# https://github.com/JuliaAI/MLJClusteringInterface.jl/issues/10

abstract type ActiveNodeOperation end
Expand Down Expand Up @@ -587,7 +585,10 @@ end
# component, only its `abstract_type`. See comment at top of page.

MMI.supports_training_losses(pipe::SupervisedPipeline) =
MMI.supports_training_losses(getproperty(pipe, supervised_component_name(pipe)))
MMI.supports_training_losses(supervised_component(pipe))

MMI.reports_feature_importances(pipe::SupervisedPipeline) =
MMI.reports_feature_importances(supervised_component(pipe))

# This trait cannot be defined at the level of types (see previous comment):
function MMI.iteration_parameter(pipe::SupervisedPipeline)
Expand Down Expand Up @@ -618,3 +619,18 @@ function MMI.training_losses(pipe::SupervisedPipeline, pipe_report)
report = getproperty(pipe_report, supervised_name)
return training_losses(supervised, report)
end


# ## Feature importances

function feature_importances(pipe::SupervisedPipeline, fitresult, report)
# locate the machine associated with the supervised component:
supervised_name = MLJBase.supervised_component_name(pipe)
predict_node = fitresult.interface.predict
mach = only(MLJBase.machines_given_model(predict_node)[supervised_name])

# To extract the feature_importances, we can't do `feature_importances(mach)` because
# `mach.model` is just a symbol; instead we do:
supervised = MLJBase.supervised_component(pipe)
return feature_importances(supervised, mach.fitresult, mach.report[:fit])
end
25 changes: 20 additions & 5 deletions src/composition/models/transformed_target_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,28 +237,41 @@ const ERR_TT_MISSING_REPORT =
"Cannot find report for `TransformedTargetModel` atomic model, from which "*
"to extract training losses. "

function training_losses(composite::SomeTT, tt_report)
function MMI.training_losses(composite::SomeTT, tt_report)
hasproperty(tt_report, :model) || throw(ERR_TT_MISSING_REPORT)
atomic_report = getproperty(tt_report, :model)
return training_losses(composite.model, atomic_report)
end


# # FEATURE IMPORTANCES

function MMI.feature_importances(composite::SomeTT, fitresult, report)
# locate the machine associated with the supervised component:
predict_node = fitresult.interface.predict
mach = only(MLJBase.machines_given_model(predict_node)[:model])

# To extract the feature_importances, we can't do `feature_importances(mach)` because
# `mach.model` is just a symbol; instead we do:
return feature_importances(composite.model, mach.fitresult, mach.report[:fit])
end


## MODEL TRAITS

MMI.package_name(::Type{<:SomeTT}) = "MLJBase"
MMI.package_license(::Type{<:SomeTT}) = "MIT"
MMI.package_uuid(::Type{<:SomeTT}) = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MMI.is_wrapper(::Type{<:SomeTT}) = true
MMI.package_url(::Type{<:SomeTT}) =
"https://github.com/JuliaAI/MLJBase.jl"
MMI.package_url(::Type{<:SomeTT}) = "https://github.com/JuliaAI/MLJBase.jl"

for New in TT_TYPE_EXS
quote
MMI.iteration_parameter(::Type{<:$New{M}}) where M =
MLJBase.prepend(:model, iteration_parameter(M))
end |> eval
for trait in [:input_scitype,
for trait in [
:input_scitype,
:output_scitype,
:target_scitype,
:fit_data_scitype,
Expand All @@ -270,8 +283,10 @@ for New in TT_TYPE_EXS
:supports_class_weights,
:supports_online,
:supports_training_losses,
:reports_feature_importances,
:is_supervised,
:prediction_type]
:prediction_type
]
quote
MMI.$trait(::Type{<:$New{M}}) where M = MMI.$trait(M)
end |> eval
Expand Down
8 changes: 0 additions & 8 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -808,10 +808,6 @@ julia> fitted_params(mach).logistic_classifier
intercept = 0.0883301599726305,)
```

Additional keys, `machines` and `fitted_params_given_machine`, give a
list of *all* machines in the underlying network, and a dictionary of
fitted parameters keyed on those machines.

See also [`report`](@ref)

"""
Expand Down Expand Up @@ -852,10 +848,6 @@ julia> report(mach).linear_binary_classifier

```

Additional keys, `machines` and `report_given_machine`, give a
list of *all* machines in the underlying network, and a dictionary of
reports keyed on those machines.

See also [`fitted_params`](@ref)

"""
Expand Down
20 changes: 17 additions & 3 deletions test/_models/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor

import MLJBase
import MLJBase: @mlj_model, metadata_pkg, metadata_model

import MLJBase.Tables
using ScientificTypes

using CategoricalArrays
Expand Down Expand Up @@ -98,8 +98,11 @@ function MLJBase.fit(model::DecisionTreeClassifier, verbosity::Int, X, y)
#> empty values):

cache = nothing
report = (classes_seen=classes_seen,
print_tree=TreePrinter(tree))
report = (
classes_seen=classes_seen,
print_tree=TreePrinter(tree),
features=Tables.columnnames(Tables.columns(X)) |> collect,
)

return fitresult, cache, report
end
Expand Down Expand Up @@ -137,6 +140,17 @@ function MLJBase.predict(model::DecisionTreeClassifier
for i in 1:size(y_probabilities, 1)]
end

MLJBase.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true

function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report)
features = report.features
fi = DecisionTree.impurity_importance(first(fitresult), normalize=true)
fi_pairs = Pair.(features, fi)
# sort descending
sort!(fi_pairs, by= x->-x[2])

return fi_pairs
end

## REGRESSOR

Expand Down
8 changes: 4 additions & 4 deletions test/composition/models/network_composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end
MLJBase.reporting_operations(::Type{<:ReportingScaler}) = (:transform, )

MLJBase.transform(model::ReportingScaler, _, X) = (
model.alpha*Tables.matrix(X),
Tables.table(model.alpha*Tables.matrix(X)),
(; nrows = size(MLJBase.matrix(X))[1]),
)

Expand Down Expand Up @@ -143,7 +143,7 @@ composite = WatermelonComposite(
Set([:scaler, :clusterer, :classifier1, :training_loss, :len])
@test fitr.scaler == (nrows=10,)
@test fitr.clusterer == (labels=['A', 'B', 'C'],)
@test Set(keys(fitr.classifier1)) == Set([:classes_seen, :print_tree])
@test Set(keys(fitr.classifier1)) == Set([:classes_seen, :print_tree, :features])
@test fitr.training_loss isa Real
@test fitr.len == 10

Expand All @@ -164,7 +164,7 @@ composite = WatermelonComposite(
Set([:scaler, :clusterer, :classifier1, :finalizer])
@test predictr.scaler == (nrows=5,)
@test predictr.clusterer == (labels=['A', 'B', 'C'],)
@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree])
@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree, :features])
@test predictr.finalizer == (nrows=5,)

o, predictr = predict(composite, f, selectrows(X, 1:2))
Expand All @@ -174,7 +174,7 @@ composite = WatermelonComposite(
Set([:scaler, :clusterer, :classifier1, :finalizer])
@test predictr.scaler == (nrows=2,) # <----------- different
@test predictr.clusterer == (labels=['A', 'B', 'C'],)
@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree])
@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree, :features])
@test predictr.finalizer == (nrows=2,) # <---------- different

r = MMI.report(composite, Dict(:fit => fitr, :predict=> predictr))
Expand Down
10 changes: 10 additions & 0 deletions test/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,16 @@ end
rm(filename)
end

@testset "feature importances" begin
# the DecisionTreeClassifier in /test/_models/ supports feature importances.
pipe = Standardizer |> DecisionTreeClassifier()
@test reports_feature_importances(pipe)
X, y = @load_iris
fitresult, _, report = MLJBase.fit(pipe, 0, X, y)
features = first.(feature_importances(pipe, fitresult, report))
@test Set(features) == Set(keys(X))
end

end # module

true
9 changes: 9 additions & 0 deletions test/composition/models/transformed_target_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,14 @@ y = rand(5)
@test training_losses(mach) == ones(5)
end

@testset "feature_importances" begin
X, y = @load_iris
atom = DecisionTreeClassifier()
model = TransformedTargetModel(atom, transformer=identity, inverse=identity)
@test reports_feature_importances(model)
fitresult, _, rpt = MMI.fit(model, 0, X, y)
@test Set(first.(feature_importances(model, fitresult, rpt))) == Set(keys(X))
end

end
true
Loading