Skip to content

Commit

Permalink
Various bug fixes (#222)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBavenstrand committed Jun 8, 2024
2 parents c62a67b + 58004d6 commit 5953bd2
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .safety-policy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ security:
ignore-vulnerabilities:
66947:
reason: Not used
70612:
reason: Only used during documentation generation
continue-on-vulnerability-error: False
7 changes: 6 additions & 1 deletion mleko/dataset/transform/label_encoder_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,12 @@ def _fit(
logger.info(f"Fitting label encoder transformer ({len(self._features)}): {self._features}.")
for feature in self._features:
self._ensure_valid_feature_type(feature, data_schema, dataframe)
labels: list[str] = get_column(dataframe, feature).unique(dropna=True) # type: ignore
labels: list[str] = [
label
for label in get_column(dataframe, feature).to_arrow().unique().to_pylist() # type: ignore
if label is not None
]

if not self._fit_using_label_dict(feature, labels):
logger.info(f"Assigning mappings for feature {feature!r}: {labels}.")
self._transformer[feature] = {label: i for i, label in enumerate(labels)}
Expand Down
14 changes: 13 additions & 1 deletion mleko/model/lgbm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pandas as pd
import vaex
from lightgbm.sklearn import _LGBM_ScikitEvalMetricType
from sklearn.utils.validation import NotFittedError, check_is_fitted

from mleko.dataset.data_schema import DataSchema
from mleko.utils.custom_logger import CustomLogger
Expand Down Expand Up @@ -246,7 +247,18 @@ def _fingerprint(self) -> Hashable:
Returns:
The fingerprint of the model.
"""
return (super()._fingerprint(), self._target, self._model.__class__.__qualname__)
is_fitted = True
try:
check_is_fitted(self._model) # type: ignore
except NotFittedError:
is_fitted = False

return (
super()._fingerprint(),
self._target,
self._model.__class__.__qualname__,
self._model.booster_.model_to_string() if is_fitted else None,
)

def _default_features(self, data_schema: DataSchema) -> tuple[str, ...]:
"""The default set of features to use for training.
Expand Down
1 change: 1 addition & 0 deletions mleko/model/tune/optuna_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def _fingerprint(self) -> Hashable:
CallableSourceFingerprinter().fingerprint(self._objective_function),
self._direction,
self._num_trials,
JsonFingerprinter().fingerprint(self._enqueue_trials) if self._enqueue_trials is not None else None,
CallableSourceFingerprinter().fingerprint(self._cv_folds) if self._cv_folds is not None else None,
OptunaSamplerFingerprinter().fingerprint(self._sampler),
OptunaPrunerFingerprinter().fingerprint(self._pruner),
Expand Down
15 changes: 2 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions tests/model/test_lgbm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,24 @@ def test_cache_fit_transform_no_validation(
).fit_transform(example_data_schema, example_vaex_dataframe_train, None, {})
mocked_fit_transform.assert_not_called()

# NOTE(review): "chache" in the test name is a typo for "cache"; renaming would
# be cosmetic (pytest discovers it either way) but worth cleaning up.
def test_chache_fit_transform_with_refit(
self,
temporary_directory: Path,
example_data_schema: DataSchema,
example_vaex_dataframe_train: vaex.DataFrame,
):
"""Should recompute the transform (not serve the cached result) after the model is refit with new hyperparameters."""
lgbm_model = LGBMModel(
cache_directory=temporary_directory, target="target", model=lgb.LGBMClassifier(objective="binary")
)
# Prime both the fit cache and the transform cache with an initial run.
lgbm_model.fit(example_data_schema, example_vaex_dataframe_train.copy())
lgbm_model.transform(example_data_schema, example_vaex_dataframe_train.copy())
# Refit with different hyperparameters — presumably this changes the fitted
# model's fingerprint, invalidating the earlier transform cache entry
# (TODO confirm against LGBMModel._fingerprint).
lgbm_model.fit(example_data_schema, example_vaex_dataframe_train.copy(), hyperparameters={"num_leaves": 10})

with patch.object(LGBMModel, "_transform") as mocked_transform:
lgbm_model.transform(example_data_schema, example_vaex_dataframe_train.copy())
# _transform being invoked proves the stale cached result was not reused.
mocked_transform.assert_called()

def test_cache_fit_and_transform(
self,
temporary_directory: Path,
Expand Down

0 comments on commit 5953bd2

Please sign in to comment.