sklearn-compat 0.1.5
Ease support for compatible scikit-learn estimators across versions
pip install sklearn-compat
Requires Python >=3.8
Dependencies
- pytest ; extra == "tests"
- scikit-learn <1.9,>=1.2
- ipython ; extra == "dev"
- mkdocs ; extra == "docs"
- mkdocs-material ; extra == "docs"
- pre-commit ; extra == "lint"
- pandas ; extra == "tests"
- polars ; extra == "tests"
- pyarrow ; extra == "tests"
- pytest-cov ; extra == "tests"
- pytest-xdist ; extra == "tests"
- pytz ; extra == "tests"
Ease multi-version support for scikit-learn compatible libraries
sklearn-compat is a small Python package that helps developers of scikit-learn
compatible estimators support multiple scikit-learn versions. Note that we provide
a vendorable version of this package in the src/sklearn_compat/_sklearn_compat.py
file if you do not want to depend on sklearn-compat as a package. The package is
available on PyPI and on conda-forge.
As maintainers of third-party libraries depending on scikit-learn, such as
imbalanced-learn, skrub, or skops, we regularly run into small breaking changes in
the "private" developer utilities of scikit-learn. Each of these third-party
libraries ends up writing the exact same utilities to support multiple scikit-learn
versions. We therefore decided to factor these utilities out into a dedicated
package that we update at each scikit-learn release.
When it comes to supporting multiple scikit-learn versions, the initial plan as of
December 2024 is to follow the SPEC0 recommendations. This means that this utility
supports at least the scikit-learn versions released within the last 2 years, or
about 4 releases. The current version of sklearn-compat supports scikit-learn >= 1.2.
How to adapt your scikit-learn code
In this section, we briefly describe the changes you need to make to your code to
support multiple scikit-learn versions using sklearn-compat as a package. If you
use the vendored version of sklearn-compat, all imports change from:
from sklearn_compat.any_submodule import any_function
to
from path.to._sklearn_compat import any_function
where _sklearn_compat is the vendored version of sklearn-compat in your project.
Upgrading to scikit-learn 1.8
DataFrame related utility functions
The functions is_df_or_series, is_pandas_df, is_pandas_df_or_series,
is_polars_df, is_polars_df_or_series, and is_pyarrow_data have been added in
scikit-learn 1.8. We backport them so that you have access to them in
scikit-learn 1.2+. The pattern is the following:
from sklearn_compat.utils._dataframe import (
is_df_or_series,
is_pandas_df,
is_pandas_df_or_series,
is_polars_df,
is_polars_df_or_series,
is_pyarrow_data,
)
is_df_or_series(X)
is_pandas_df(X)
is_pandas_df_or_series(X)
is_polars_df(X)
is_polars_df_or_series(X)
is_pyarrow_data(X)
Previously, these functions were named with a leading underscore and were
available in the sklearn.utils.validation module.
_check_targets function
In scikit-learn 1.8, _check_targets from sklearn.metrics._classification now
returns 4 values (y_type, y_true, y_pred, sample_weight) instead of 3. For backward
compatibility with scikit-learn < 1.8, we provide a wrapper that ensures the function
always returns 4 values. You can import it as:
from sklearn_compat.metrics._classification import _check_targets
y_type, y_true, y_pred, sample_weight = _check_targets(
y_true, y_pred, sample_weight=None
)
Upgrading to scikit-learn 1.7
There is no known breaking change for scikit-learn 1.7.
Upgrading to scikit-learn 1.6
is_clusterer function
The function is_clusterer has been added in scikit-learn 1.6. We backport it
so that you have access to it in scikit-learn 1.2+. The pattern is the following:
from sklearn.cluster import KMeans
from sklearn_compat.base import is_clusterer
is_clusterer(KMeans())
validate_data function
Your previous code could have looked like this:
class MyEstimator(BaseEstimator):
def fit(self, X, y=None):
X = self._validate_data(X, force_all_finite=True)
return self
There are two major changes in scikit-learn 1.6:
- validate_data has been moved to sklearn.utils.validation.
- force_all_finite is deprecated in favor of the ensure_all_finite parameter.
You can now use the following code for backward compatibility:
from sklearn.base import BaseEstimator
from sklearn_compat.utils.validation import validate_data
class MyEstimator(BaseEstimator):
def fit(self, X, y=None):
X = validate_data(self, X=X, ensure_all_finite=True)
return self
check_array and check_X_y functions
The parameter force_all_finite has been deprecated in favor of the ensure_all_finite
parameter. You need to modify the function calls to use the new parameter. The change
is therefore the same as for validate_data and looks like this:
from sklearn.utils.validation import check_array, check_X_y
check_array(X, force_all_finite=True)
check_X_y(X, y, force_all_finite=True)
to:
from sklearn_compat.utils.validation import check_array, check_X_y
check_array(X, ensure_all_finite=True)
check_X_y(X, y, ensure_all_finite=True)
_check_n_features and _check_feature_names functions
Similarly to validate_data, these two functions have been moved to
sklearn.utils.validation instead of being methods of the estimators. So the following
code:
class MyEstimator(BaseEstimator):
def fit(self, X, y=None):
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
return self
becomes:
from sklearn.base import BaseEstimator
from sklearn_compat.utils.validation import _check_n_features, _check_feature_names
class MyEstimator(BaseEstimator):
def fit(self, X, y=None):
_check_n_features(self, X, reset=True)
_check_feature_names(self, X, reset=True)
return self
Note that it is best to call validate_data with skip_check_array=True instead of
calling these private functions. See the section above regarding validate_data.
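For instance, a minimal sketch of this recommendation, assuming the backported validate_data forwards the skip_check_array parameter on all supported versions:
from sklearn.base import BaseEstimator
from sklearn_compat.utils.validation import validate_data

class MyEstimator(BaseEstimator):
    def fit(self, X, y=None):
        # record n_features_in_ and feature_names_in_ without converting
        # or validating the array itself
        X = validate_data(self, X=X, skip_check_array=True)
        return self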
Tags, __sklearn_tags__ and estimator tags
The estimator tags infrastructure changed in scikit-learn 1.6. In order to be
compatible with multiple scikit-learn versions, your estimator should implement both
_more_tags and __sklearn_tags__:
class MyEstimator(BaseEstimator):
def _more_tags(self):
return {"non_deterministic": True, "poor_score": True}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.non_deterministic = True
tags.regressor_tags.poor_score = True
return tags
Notice however that some tags have different names in scikit-learn 1.6. For instance,
to indicate that an estimator only handles binary classification, it previously needed
the tag binary_only set to True, whereas in scikit-learn 1.6,
classifier_tags.multi_class needs to be set to False.
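For example, an estimator restricted to binary classification could declare both spellings side by side; this is a sketch following the dual _more_tags/__sklearn_tags__ pattern shown above:
from sklearn.base import BaseEstimator, ClassifierMixin

class MyBinaryClassifier(ClassifierMixin, BaseEstimator):
    def _more_tags(self):
        # scikit-learn < 1.6 spelling of the tag
        return {"binary_only": True}

    def __sklearn_tags__(self):
        # scikit-learn >= 1.6 spelling of the same information
        tags = super().__sklearn_tags__()
        tags.classifier_tags.multi_class = False
        return tags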
In order to get the tags of a given estimator, you can use the get_tags function:
from sklearn_compat.utils import get_tags
tags = get_tags(MyEstimator())
It uses sklearn.utils.get_tags under the hood with scikit-learn 1.6+.
In case you want to extend the tags, you can inherit from the available tags:
from dataclasses import dataclass
from sklearn.base import BaseEstimator
from sklearn_compat.utils._tags import Tags, InputTags

@dataclass
class MyInputTags(InputTags):
    dataframe: bool = False
class MyEstimator(BaseEstimator):
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags = MyInputTags(
one_d_array=tags.input_tags.one_d_array,
two_d_array=tags.input_tags.two_d_array,
sparse=tags.input_tags.sparse,
categorical=True,
dataframe=True,
string=tags.input_tags.string,
dict=tags.input_tags.dict,
positive_only=tags.input_tags.positive_only,
allow_nan=tags.input_tags.allow_nan,
pairwise=tags.input_tags.pairwise,
)
return tags
check_estimator and parametrize_with_checks functions
The new tags don't include an _xfail_checks tag; instead, the checks that are
expected to fail are passed directly to the check_estimator and
parametrize_with_checks functions. The two functions available in this package are
compatible with the new signature and, on older scikit-learn versions, patch the
estimator to include the expected failed checks in its tags, so that you don't need
to list them both in your tests and in the old _xfail_checks tag.
from sklearn_compat.utils.testing import parametrize_with_checks
from mypackage.myestimator import MyEstimator1, MyEstimator2
EXPECTED_FAILED_CHECKS = {
"MyEstimator1": {"check_name1": "reason1", "check_name2": "reason2"},
"MyEstimator2": {"check_name3": "reason3"},
}
@parametrize_with_checks([MyEstimator1(), MyEstimator2()],
expected_failed_checks=lambda est: EXPECTED_FAILED_CHECKS.get(
est.__class__.__name__, {}
)
)
def test_my_estimator(estimator, check):
check(estimator)
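check_estimator can be used in the same way; here is a brief sketch, assuming it is exposed from the same sklearn_compat.utils.testing module:
from sklearn_compat.utils.testing import check_estimator
from mypackage.myestimator import MyEstimator1

check_estimator(
    MyEstimator1(),
    # checks expected to fail, mapped to the reason for the failure
    expected_failed_checks={"check_name1": "reason1"},
)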
Upgrading to scikit-learn 1.5
In scikit-learn 1.5, many developer utilities were moved to dedicated modules. We provide a compatibility layer so that you don't have to check the scikit-learn version or try to import the utilities from different modules.
In the future, once you only support scikit-learn 1.6+, you will simply have to change the import from:
from sklearn_compat.utils._indexing import _safe_indexing
to
from sklearn.utils._indexing import _safe_indexing
The module path will then already be correct. Below, we go into the details for each impacted module and function.
extmath module
The functions safe_sqr and _approximate_mode have been moved from sklearn.utils to
sklearn.utils.extmath.
So some code looking like this:
import numpy as np
from sklearn.utils import safe_sqr, _approximate_mode
safe_sqr(np.array([1, 2, 3]))
_approximate_mode(class_counts=np.array([4, 2]), n_draws=3, rng=0)
becomes:
import numpy as np
from sklearn_compat.utils.extmath import safe_sqr, _approximate_mode
safe_sqr(np.array([1, 2, 3]))
_approximate_mode(class_counts=np.array([4, 2]), n_draws=3, rng=0)
type_of_target function
The function type_of_target accepts a new parameter raise_unknown. This parameter is
available in the sklearn_compat.utils.multiclass.type_of_target function.
from sklearn_compat.utils.multiclass import type_of_target
y = []
# raise an error with unknown target type
type_of_target(y, raise_unknown=True)
fixes module
The function _in_unstable_openblas_configuration and the constants _IS_32BIT and
_IS_WASM have been moved from sklearn.utils to sklearn.utils.fixes.
So the following code:
from sklearn.utils import (
_in_unstable_openblas_configuration,
_IS_32BIT,
_IS_WASM,
)
_in_unstable_openblas_configuration()
print(_IS_32BIT)
print(_IS_WASM)
becomes:
from sklearn_compat.utils.fixes import (
_in_unstable_openblas_configuration,
_IS_32BIT,
_IS_WASM,
)
_in_unstable_openblas_configuration()
print(_IS_32BIT)
print(_IS_WASM)
validation module
The function _to_object_array has been moved from sklearn.utils to
sklearn.utils.validation.
So the following code:
import numpy as np
from sklearn.utils import _to_object_array
_to_object_array([np.array([0]), np.array([1])])
becomes:
import numpy as np
from sklearn_compat.utils.validation import _to_object_array
_to_object_array([np.array([0]), np.array([1])])
_chunking module
The functions gen_batches, gen_even_slices and get_chunk_n_rows have been moved
from sklearn.utils to sklearn.utils._chunking. The function chunk_generator has
been moved to sklearn.utils._chunking as well but was renamed from _chunk_generator
to chunk_generator.
So the following code:
from sklearn.utils import (
_chunk_generator as chunk_generator,
gen_batches,
gen_even_slices,
get_chunk_n_rows,
)
chunk_generator(range(10), 3)
gen_batches(7, 3)
gen_even_slices(10, 1)
get_chunk_n_rows(10)
becomes:
from sklearn_compat.utils._chunking import (
chunk_generator, gen_batches, gen_even_slices, get_chunk_n_rows,
)
chunk_generator(range(10), 3)
gen_batches(7, 3)
gen_even_slices(10, 1)
get_chunk_n_rows(10)
_indexing module
The utility functions _determine_key_type, _safe_indexing, _safe_assign,
_get_column_indices, resample and shuffle have been moved from sklearn.utils to
sklearn.utils._indexing.
So the following code:
import numpy as np
import pandas as pd
from sklearn.utils import (
    _determine_key_type,
    _get_column_indices,
    _safe_indexing,
    _safe_assign,
    resample,
    shuffle,
)
_determine_key_type(np.arange(10))
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
_get_column_indices(df, key="b")
_safe_indexing(df, 1, axis=1)
_safe_assign(df, np.array([7, 8, 9]), column_indexer=1)
array = np.arange(10)
resample(array, n_samples=20, replace=True, random_state=0)
shuffle(array, random_state=0)
becomes:
import numpy as np
import pandas as pd
from sklearn_compat.utils._indexing import (
_determine_key_type,
_safe_indexing,
_safe_assign,
_get_column_indices,
resample,
shuffle,
)
_determine_key_type(np.arange(10))
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
_get_column_indices(df, key="b")
_safe_indexing(df, 1, axis=1)
_safe_assign(df, np.array([7, 8, 9]), column_indexer=1)
array = np.arange(10)
resample(array, n_samples=20, replace=True, random_state=0)
shuffle(array, random_state=0)
_mask module
The functions safe_mask, axis0_safe_slice and indices_to_mask have been moved from
sklearn.utils to sklearn.utils._mask.
So the following code:
from sklearn.utils import safe_mask, axis0_safe_slice, indices_to_mask
safe_mask(data, condition)
axis0_safe_slice(X, mask, X.shape[0])
indices_to_mask(indices, 5)
becomes:
from sklearn_compat.utils._mask import safe_mask, axis0_safe_slice, indices_to_mask
safe_mask(data, condition)
axis0_safe_slice(X, mask, X.shape[0])
indices_to_mask(indices, 5)
_missing module
The function is_scalar_nan has been moved from sklearn.utils to
sklearn.utils._missing. The function _is_pandas_na has also been moved to
sklearn.utils._missing and renamed to is_pandas_na.
So the following code:
from sklearn.utils import is_scalar_nan, _is_pandas_na
is_scalar_nan(float("nan"))
_is_pandas_na(float("nan"))
becomes:
from sklearn_compat.utils._missing import is_scalar_nan, is_pandas_na
is_scalar_nan(float("nan"))
is_pandas_na(float("nan"))
_user_interface module
The function _print_elapsed_time has been moved from sklearn.utils to
sklearn.utils._user_interface.
So the following code:
import time
from sklearn.utils import _print_elapsed_time
with _print_elapsed_time("sklearn_compat", "testing"):
time.sleep(0.1)
becomes:
import time
from sklearn_compat.utils._user_interface import _print_elapsed_time
with _print_elapsed_time("sklearn_compat", "testing"):
time.sleep(0.1)
_optional_dependencies module
The functions check_matplotlib_support and check_pandas_support have been moved from
sklearn.utils to sklearn.utils._optional_dependencies.
So the following code:
from sklearn.utils import check_matplotlib_support, check_pandas_support
check_matplotlib_support("sklearn_compat")
check_pandas_support("sklearn_compat")
becomes:
from sklearn_compat.utils._optional_dependencies import (
check_matplotlib_support, check_pandas_support
)
check_matplotlib_support("sklearn_compat")
check_pandas_support("sklearn_compat")
Upgrading to scikit-learn 1.4
process_routing and _raise_for_params functions
The signature of the process_routing function changed in scikit-learn 1.4. You can
import the function from sklearn_compat.utils.metadata_routing. So the following code:
from sklearn.utils.metadata_routing import process_routing
class MetaEstimator(BaseEstimator):
def fit(self, X, y, sample_weight=None, **fit_params):
params = process_routing(self, "fit", fit_params, sample_weight=sample_weight)
return self
becomes:
from sklearn.base import BaseEstimator
from sklearn_compat.utils.metadata_routing import process_routing
class MetaEstimator(BaseEstimator):
def fit(self, X, y, sample_weight=None, **fit_params):
params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
return self
The _raise_for_params function was also introduced in scikit-learn 1.4. You can import
it from sklearn_compat.utils.metadata_routing.
from sklearn_compat.utils.metadata_routing import _raise_for_params
_raise_for_params(params, self, "fit")
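As a sketch, the call typically sits at the top of a routed method of a meta-estimator:
from sklearn.base import BaseEstimator
from sklearn_compat.utils.metadata_routing import _raise_for_params

class MetaEstimator(BaseEstimator):
    def fit(self, X, y, **params):
        # raise an informative error if metadata is passed while
        # metadata routing is not enabled
        _raise_for_params(params, self, "fit")
        return self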
Upgrading to scikit-learn 1.2
Parameter validation
scikit-learn introduced a new way to validate parameters at fit time. The recommended
way to support this feature in scikit-learn 1.2+ is to inherit from
sklearn.base.BaseEstimator and decorate the fit method with the
sklearn.base._fit_context decorator. For functions, the decorator to use is
sklearn.utils._param_validation.validate_params.
We provide the function sklearn_compat.base._fit_context so that you can always
decorate the fit method of your estimator. Equivalently, you can use the function
sklearn_compat.utils._param_validation.validate_params to validate the parameters
of your functions.
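As a minimal sketch of both decorators (the constraint syntax with the _parameter_constraints class attribute and Interval follows the usual scikit-learn convention and is imported here from scikit-learn itself; whether the prefer_skip_nested_validation argument is accepted on every supported version is an assumption about the compatibility wrappers):
from numbers import Integral

from sklearn.base import BaseEstimator
from sklearn.utils._param_validation import Interval
from sklearn_compat.base import _fit_context
from sklearn_compat.utils._param_validation import validate_params

class MyEstimator(BaseEstimator):
    # constraints are checked automatically when fit is called
    _parameter_constraints = {
        "n_iter": [Interval(Integral, 1, None, closed="left")],
    }

    def __init__(self, n_iter=10):
        self.n_iter = n_iter

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y=None):
        return self

@validate_params(
    {"n_samples": [Interval(Integral, 1, None, closed="left")]},
    prefer_skip_nested_validation=True,
)
def make_dataset(n_samples=100):
    ...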
Contributing
You can contribute to this package by:
- reporting an incompatibility with a scikit-learn version on the issue tracker. We will do our best to provide a compatibility layer.
- opening a pull-request to add a compatibility layer for an issue you encountered when writing your scikit-learn compatible estimator.
Be aware that, to be able to provide sklearn-compat both as a vendorable file and as a
dependency, all the changes are implemented in
src/sklearn_compat/_sklearn_compat.py (admittedly not the nicest development
experience). The submodules then import the changes made in this file so that
sklearn-compat can be used as a dependency.