Skip to content

Commit

Permalink
adressing PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vnherdeiro committed Sep 16, 2024
1 parent 6d20ef8 commit d715311
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,14 +674,16 @@ def _more_tags(self) -> Dict[str, Any]:
}

def __sklearn_tags__(self):
# scikit-learn 1.6 introduced an __sklearn__tags() method intended to replace _more_tags().
# _more_tags() can be removed whenever lightgbm's minimum supported scikit-learn version
# is >=1.6.
# ref: https://github.com/microsoft/LightGBM/pull/6651
tags = super().__sklearn_tags__()
more_tags = self._more_tags()
tags.input_tags.allow_nan = more_tags["allow_nan"]
tags.input_tags.sparse = "sparse" in more_tags["X_types"]
tags.target_tags.one_d_labels = "1dlabels" in more_tags["X_types"]
tags._xfail_checks = more_tags["_xfail_checks"]
if more_tags or set(tagged_input_types).difference({"2darray", "sparse", "1dlabels"}):
_log_warning(f"Some tags sklearn tag values are missing from __sklearn_tags__: `{more_tags}`")
return tags

def __sklearn_is_fitted__(self) -> bool:
Expand Down

0 comments on commit d715311

Please sign in to comment.