strategy or example for doing a stratified k-fold #950

Closed
lazarusA opened this issue Jan 5, 2024 · 2 comments

lazarusA commented Jan 5, 2024

I see that the function partition can take stratify as an argument, and that https://juliaml.github.io/MLUtils.jl/dev/#Examples shows how to do k-folds. Is there a workflow out there for combining both?

ablaom commented Jan 5, 2024

Well, MLJ's evaluate and TunedModel have the option resampling=StratifiedCV().

So for example:

using MLJ
import Imbalance
using Random

# generate unbalanced synthetic data:
class_probs = [0.1, 0.9]
num_rows, num_features = 1000, 3
X, y = Imbalance.generate_imbalanced_data(
    num_rows,
    num_features;
    class_probs,
    rng=42
)

# stratified split into train/test:
rng = Random.Xoshiro(123)
(Xtrain, Xtest), (ytrain, ytest) = partition((X, y), 0.6; stratify=y, multi=true, rng)

# instantiate a random forest:
RandomForestClassifier = @iload RandomForestClassifier pkg=DecisionTree
forest = RandomForestClassifier()

# evaluation of default forest on training set using stratified cv:
evaluate(forest, Xtrain, ytrain; resampling=StratifiedCV(; nfolds=6, rng), measure=log_loss)

# PerformanceEvaluation object with these fields:
#   model, measure, operation, measurement, per_fold,
#   per_observation, fitted_params_per_fold,
#   report_per_fold, train_test_rows, resampling, repeats
# Extract:
# ┌──────────────────────┬───────────┬─────────────┬─────────┬─────────────────────────────────────────────────────────┐
# │ measure              │ operation │ measurement │ 1.96*SE │ per_fold                                                │
# ├──────────────────────┼───────────┼─────────────┼─────────┼─────────────────────────────────────────────────────────┤
# │ LogLoss(             │ predict   │ 0.00199     │ 0.00135 │ [0.00213, 0.00478, 0.00174, 0.000302, 0.00091, 0.00206] │
# │   tol = 2.22045e-16) │           │             │         │                                                         │
# └──────────────────────┴───────────┴─────────────┴─────────┴─────────────────────────────────────────────────────────┘

# tune the model using the training set and stratified-cv:
r = range(forest, :n_subfeatures, lower=1, upper=3)
tuned_forest = TunedModel(
    forest;
    tuning=Grid(),
    range=r,
    resampling=StratifiedCV(; nfolds=6, rng),
    measure=log_loss,
)
mach = machine(tuned_forest, Xtrain, ytrain) |> fit!

# evaluate the optimal model on the test set, using stratified-cv
best_forest = report(mach).best_model
evaluate(best_forest, Xtest, ytest;
         resampling=StratifiedCV(; nfolds=6, rng), measure=log_loss)

# PerformanceEvaluation object with these fields:
#   model, measure, operation, measurement, per_fold,
#   per_observation, fitted_params_per_fold,
#   report_per_fold, train_test_rows, resampling, repeats
# Extract:
# ┌──────────────────────┬───────────┬─────────────┬─────────┬──────────────────────────────────────────────────────────────┐
# │ measure              │ operation │ measurement │ 1.96*SE │ per_fold                                                     │
# ├──────────────────────┼───────────┼─────────────┼─────────┼──────────────────────────────────────────────────────────────┤
# │ LogLoss(             │ predict   │ 2.22e-16    │ 0.0     │ [2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16] │
# │   tol = 2.22045e-16) │           │             │         │                                                              │
# └──────────────────────┴───────────┴─────────────┴─────────┴──────────────────────────────────────────────────────────────┘
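
If what you are after is the stratified fold indices themselves, say to drive an MLUtils-style kfolds loop as in the page you linked, here is a minimal sketch, assuming the train_test_pairs method that MLJBase defines for its resampling strategies (including the StratifiedCV used above) is what you want:

# minimal sketch: extract stratified (train, test) row indices directly
import MLJBase

# a vector of (train_rows, test_rows) pairs; the class proportions of ytrain
# are approximately preserved in every test fold
folds = MLJBase.train_test_pairs(StratifiedCV(; nfolds=6, rng), 1:length(ytrain), ytrain)

for (train_rows, test_rows) in folds
    # index your data with these rows as needed, e.g.
    # selectrows(Xtrain, train_rows) and ytrain[train_rows]
end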

Does this answer your question?

ablaom commented Jan 17, 2024

Closing, as there has been no response. Feel free to re-open.

ablaom closed this as completed Jan 17, 2024