diff --git a/examples/sklearnex/basic_statistics_spmd.py b/examples/sklearnex/basic_statistics_spmd.py
index 909c842cb9..e469f7d9cb 100644
--- a/examples/sklearnex/basic_statistics_spmd.py
+++ b/examples/sklearnex/basic_statistics_spmd.py
@@ -60,5 +60,5 @@ def generate_data(par, size, seed=777):
 bss = BasicStatisticsSpmd(["mean", "standard_deviation"])
 bss.fit(dpt_data, dpt_weights)
 
-print(f"Computed mean on rank {rank}:\n", bss.mean)
-print(f"Computed std on rank {rank}:\n", bss.standard_deviation)
+print(f"Computed mean on rank {rank}:\n", bss.mean_)
+print(f"Computed std on rank {rank}:\n", bss.standard_deviation_)
diff --git a/examples/sklearnex/incremental_basic_statistics.py b/examples/sklearnex/incremental_basic_statistics.py
index b2713e1657..3e4ec8aa13 100644
--- a/examples/sklearnex/incremental_basic_statistics.py
+++ b/examples/sklearnex/incremental_basic_statistics.py
@@ -30,9 +30,9 @@
 X_3 = np.array([[1, 1], [1, 2], [2, 3]])
 result = incbs.partial_fit(X_3)
 
-print(f"Mean:\n{result.mean}")
-print(f"Max:\n{result.max}")
-print(f"Sum:\n{result.sum}")
+print(f"Mean:\n{result.mean_}")
+print(f"Max:\n{result.max_}")
+print(f"Sum:\n{result.sum_}")
 
 # We put the whole data to fit method, it is split automatically and then
 # partial_fit is called for each batch.
@@ -40,6 +40,6 @@
 X = np.array([[0, 1], [0, 1], [1, 2], [1, 1], [1, 2], [2, 3]])
 result = incbs.fit(X)
 
-print(f"Mean:\n{result.mean}")
-print(f"Max:\n{result.max}")
-print(f"Sum:\n{result.sum}")
+print(f"Mean:\n{result.mean_}")
+print(f"Max:\n{result.max_}")
+print(f"Sum:\n{result.sum_}")
diff --git a/examples/sklearnex/incremental_basic_statistics_dpctl.py b/examples/sklearnex/incremental_basic_statistics_dpctl.py
index 170ba0e446..7b6a905dec 100644
--- a/examples/sklearnex/incremental_basic_statistics_dpctl.py
+++ b/examples/sklearnex/incremental_basic_statistics_dpctl.py
@@ -36,9 +36,9 @@
 X_3 = dpt.asarray([[1, 1], [1, 2], [2, 3]], sycl_queue=queue)
 result = incbs.partial_fit(X_3)
 
-print(f"Mean:\n{result.mean}")
-print(f"Max:\n{result.max}")
-print(f"Sum:\n{result.sum}")
+print(f"Mean:\n{result.mean_}")
+print(f"Max:\n{result.max_}")
+print(f"Sum:\n{result.sum_}")
 
 # We put the whole data to fit method, it is split automatically and then
 # partial_fit is called for each batch.
@@ -46,6 +46,6 @@
 X = dpt.asarray([[0, 1], [0, 1], [1, 2], [1, 1], [1, 2], [2, 3]], sycl_queue=queue)
 result = incbs.fit(X)
 
-print(f"Mean:\n{result.mean}")
-print(f"Max:\n{result.max}")
-print(f"Sum:\n{result.sum}")
+print(f"Mean:\n{result.mean_}")
+print(f"Max:\n{result.max_}")
+print(f"Sum:\n{result.sum_}")
diff --git a/onedal/basic_statistics/basic_statistics.py b/onedal/basic_statistics/basic_statistics.py
index c60d1599ac..9dc82e8757 100644
--- a/onedal/basic_statistics/basic_statistics.py
+++ b/onedal/basic_statistics/basic_statistics.py
@@ -17,17 +17,17 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 
-import numpy as np
-
 from ..common._base import BaseEstimator
 from ..datatypes import _convert_to_supported, from_table, to_table
 from ..utils import _is_csr
-from ..utils.validation import _check_array
 
 
-class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta):
-    @abstractmethod
-    def __init__(self, result_options, algorithm):
+class BasicStatistics(BaseEstimator, metaclass=ABCMeta):
+    """
+    Basic Statistics oneDAL implementation.
+    """
+
+    def __init__(self, result_options="all", algorithm="by_default"):
         self.options = result_options
         self.algorithm = algorithm
 
@@ -46,49 +46,38 @@ def get_all_result_options():
             "second_order_raw_moment",
         ]
 
-    def _get_result_options(self, options):
-        if options == "all":
-            options = self.get_all_result_options()
-        if isinstance(options, list):
-            options = "|".join(options)
-        assert isinstance(options, str)
-        return options
+    @property
+    def options(self):
+        if self._options == ["all"]:
+            return self.get_all_result_options()
+        return self._options
+
+    @options.setter
+    def options(self, opts):
+        # options should always be stored as an iterable
+        self._options = opts.split("|") if isinstance(opts, str) else opts
 
-    def _get_onedal_params(self, is_csr, dtype=np.float32):
-        options = self._get_result_options(self.options)
+    def _get_onedal_params(self, is_csr, dtype=None):
         return {
             "fptype": dtype,
             "method": "sparse" if is_csr else self.algorithm,
-            "result_option": options,
+            "result_option": "|".join(self.options),
         }
 
-
-class BasicStatistics(BaseBasicStatistics):
-    """
-    Basic Statistics oneDAL implementation.
-    """
-
-    def __init__(self, result_options="all", algorithm="by_default"):
-        super().__init__(result_options, algorithm)
-
     def fit(self, data, sample_weight=None, queue=None):
         policy = self._get_policy(queue, data, sample_weight)
 
         is_csr = _is_csr(data)
 
-        if data is not None and not is_csr:
-            data = _check_array(data, ensure_2d=False)
-        if sample_weight is not None:
-            sample_weight = _check_array(sample_weight, ensure_2d=False)
-
-        data, sample_weight = _convert_to_supported(policy, data, sample_weight)
         is_single_dim = data.ndim == 1
-        data_table, weights_table = to_table(data, sample_weight)
+        data, sample_weight = to_table(
+            *_convert_to_supported(policy, data, sample_weight)
+        )
 
-        dtype = data.dtype
-        raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
-        for opt, raw_value in raw_result.items():
-            value = from_table(raw_value).ravel()
+        result = self._compute_raw(data, sample_weight, policy, data.dtype, is_csr)
+
+        for opt in self.options:
+            value = from_table(getattr(result, opt))[0]  # two-dimensional table [1, n]
             if is_single_dim:
                 setattr(self, opt, value[0])
             else:
@@ -96,12 +85,10 @@ def fit(self, data, sample_weight=None, queue=None):
 
         return self
 
-    def _compute_raw(
-        self, data_table, weights_table, policy, dtype=np.float32, is_csr=False
-    ):
+    def _compute_raw(self, data_table, weights_table, policy, dtype=None, is_csr=False):
+        # This function is maintained for internal use by KMeans tolerance
+        # calculations, but is otherwise considered legacy code and is not
+        # to be used externally in any circumstance.
         module = self._get_backend("basic_statistics")
         params = self._get_onedal_params(is_csr, dtype)
-        result = module.compute(policy, params, data_table, weights_table)
-        options = self._get_result_options(self.options).split("|")
-
-        return {opt: getattr(result, opt) for opt in options}
+        return module.compute(policy, params, data_table, weights_table)
diff --git a/onedal/basic_statistics/incremental_basic_statistics.py b/onedal/basic_statistics/incremental_basic_statistics.py
index 5b83e7722a..14f13bd1aa 100644
--- a/onedal/basic_statistics/incremental_basic_statistics.py
+++ b/onedal/basic_statistics/incremental_basic_statistics.py
@@ -14,16 +14,11 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-
-from daal4py.sklearn._utils import get_dtype
-
 from ..datatypes import _convert_to_supported, from_table, to_table
-from ..utils import _check_array
-from .basic_statistics import BaseBasicStatistics
+from .basic_statistics import BasicStatistics
 
 
-class IncrementalBasicStatistics(BaseBasicStatistics):
+class IncrementalBasicStatistics(BasicStatistics):
     """
     Incremental estimator for basic statistics based on oneDAL implementation.
     Allows to compute basic statistics if data are splitted into batches.
@@ -65,8 +60,8 @@ class IncrementalBasicStatistics(BaseBasicStatistics):
         Second order moment of each feature over all samples.
     """
 
-    def __init__(self, result_options="all"):
-        super().__init__(result_options, algorithm="by_default")
+    def __init__(self, result_options="all", algorithm="by_default"):
+        super().__init__(result_options, algorithm)
         self._reset()
 
     def _reset(self):
@@ -85,7 +80,7 @@ def __getstate__(self):
 
         return data
 
-    def partial_fit(self, X, weights=None, queue=None):
+    def partial_fit(self, X, sample_weight=None, queue=None):
         """
         Computes partial data for basic statistics
         from data batch X and saves it to `_partial_result`.
@@ -106,24 +101,11 @@
         """
         self._queue = queue
         policy = self._get_policy(queue, X)
-        X, weights = _convert_to_supported(policy, X, weights)
-
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
-        )
-        if weights is not None:
-            weights = _check_array(
-                weights,
-                dtype=[np.float64, np.float32],
-                ensure_2d=False,
-                force_all_finite=False,
-            )
+        X, sample_weight = to_table(*_convert_to_supported(policy, X, sample_weight))
 
         if not hasattr(self, "_onedal_params"):
-            dtype = get_dtype(X)
-            self._onedal_params = self._get_onedal_params(False, dtype=dtype)
+            self._onedal_params = self._get_onedal_params(False, dtype=X.dtype)
 
-        X_table, weights_table = to_table(X, weights)
         self._partial_result = self._get_backend(
             "basic_statistics",
             None,
@@ -131,8 +113,8 @@
             policy,
             self._onedal_params,
             self._partial_result,
-            X_table,
-            weights_table,
+            X,
+            sample_weight,
         )
 
         self._need_to_finalize = True
@@ -167,9 +149,8 @@ def finalize_fit(self, queue=None):
             self._onedal_params,
             self._partial_result,
         )
-        options = self._get_result_options(self.options).split("|")
-        for opt in options:
-            setattr(self, opt, from_table(getattr(result, opt)).ravel())
+        for opt in self.options:
+            setattr(self, opt, from_table(getattr(result, opt))[0])
 
         self._need_to_finalize = False
diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py
index 93eadf8c6b..1d661f140c 100644
--- a/onedal/cluster/kmeans.py
+++ b/onedal/cluster/kmeans.py
@@ -89,7 +89,7 @@ def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
         bs = self._get_basic_statistics_backend("variance")
 
         res = bs._compute_raw(X_table, dummy, policy, dtype, is_csr)
-        mean_var = from_table(res["variance"]).mean()
+        mean_var = from_table(res.variance).mean()
 
         return mean_var * rtol
diff --git a/onedal/spmd/basic_statistics/basic_statistics.py b/onedal/spmd/basic_statistics/basic_statistics.py
index 8253aa6628..f7f273b9ea 100644
--- a/onedal/spmd/basic_statistics/basic_statistics.py
+++ b/onedal/spmd/basic_statistics/basic_statistics.py
@@ -21,10 +21,4 @@
 
 
 class BasicStatistics(BaseEstimatorSPMD, BasicStatistics_Batch):
-    @support_input_format()
-    def compute(self, data, weights=None, queue=None):
-        return super().compute(data, weights=weights, queue=queue)
-
-    @support_input_format()
-    def fit(self, data, sample_weight=None, queue=None):
-        return super().fit(data, sample_weight=sample_weight, queue=queue)
+    pass
diff --git a/onedal/spmd/basic_statistics/incremental_basic_statistics.py b/onedal/spmd/basic_statistics/incremental_basic_statistics.py
index f4d7414abc..a5d0e01333 100644
--- a/onedal/spmd/basic_statistics/incremental_basic_statistics.py
+++ b/onedal/spmd/basic_statistics/incremental_basic_statistics.py
@@ -30,7 +30,7 @@ def _reset(self):
             "basic_statistics", None, "partial_compute_result"
         )
 
-    def partial_fit(self, X, weights=None, queue=None):
+    def partial_fit(self, X, sample_weight=None, queue=None):
         """
         Computes partial data for basic statistics
         from data batch X and saves it to `_partial_result`.
@@ -51,13 +51,11 @@ def partial_fit(self, X, weights=None, queue=None):
         """
         self._queue = queue
         policy = super(base_IncrementalBasicStatistics, self)._get_policy(queue, X)
-        X, weights = _convert_to_supported(policy, X, weights)
+        X, sample_weight = to_table(*_convert_to_supported(policy, X, sample_weight))
 
         if not hasattr(self, "_onedal_params"):
-            dtype = get_dtype(X)
-            self._onedal_params = self._get_onedal_params(False, dtype=dtype)
+            self._onedal_params = self._get_onedal_params(False, dtype=X.dtype)
 
-        X_table, weights_table = to_table(X, weights)
         self._partial_result = super(base_IncrementalBasicStatistics, self)._get_backend(
             "basic_statistics",
             None,
@@ -65,8 +63,8 @@
             policy,
             self._onedal_params,
             self._partial_result,
-            X_table,
-            weights_table,
+            X,
+            sample_weight,
         )
 
         self._need_to_finalize = True
diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py
index da82e3bd82..94c6f607e2 100644
--- a/sklearnex/basic_statistics/basic_statistics.py
+++ b/sklearnex/basic_statistics/basic_statistics.py
@@ -16,10 +16,8 @@
 import warnings
 
-import numpy as np
 from sklearn.base import BaseEstimator
 from sklearn.utils import check_array
-from sklearn.utils.validation import _check_sample_weight
 
 from daal4py.sklearn._n_jobs_support import control_n_jobs
 from daal4py.sklearn._utils import sklearn_check_version
@@ -27,11 +25,8 @@
 from .._device_offload import dispatch
 from .._utils import IntelEstimator, PatchingConditionsChain
-
-if sklearn_check_version("1.6"):
-    from sklearn.utils.validation import validate_data
-else:
-    validate_data = BaseEstimator._validate_data
+from ..utils._array_api import get_namespace
+from ..utils.validation import _check_sample_weight, validate_data
 
 if sklearn_check_version("1.2"):
     from sklearn.utils._param_validation import StrOptions
@@ -130,30 +125,15 @@ def __init__(self, result_options="all"):
 
     def _save_attributes(self):
         assert hasattr(self, "_onedal_estimator")
-
-        if self.result_options == "all":
-            result_options = onedal_BasicStatistics.get_all_result_options()
-        else:
-            result_options = self.result_options
-
-        if isinstance(result_options, str):
-            setattr(
-                self,
-                result_options + "_",
-                getattr(self._onedal_estimator, result_options),
-            )
-        elif isinstance(result_options, list):
-            for option in result_options:
-                setattr(self, option + "_", getattr(self._onedal_estimator, option))
+        for option in self._onedal_estimator.options:
+            setattr(self, option + "_", getattr(self._onedal_estimator, option))
 
     def __getattr__(self, attr):
-        if self.result_options == "all":
-            result_options = onedal_BasicStatistics.get_all_result_options()
-        else:
-            result_options = self.result_options
-
         is_deprecated_attr = (
-            isinstance(result_options, str) and (attr == result_options)
-        ) or (isinstance(result_options, list) and (attr in result_options))
+            attr in self._onedal_estimator.options
+            if "_onedal_estimator" in self.__dict__
+            else False
+        )
         if is_deprecated_attr:
             warnings.warn(
                 "Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0"
@@ -179,13 +159,16 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
         if sklearn_check_version("1.2"):
             self._validate_params()
 
+        xp, _ = get_namespace(X)
         if sklearn_check_version("1.0"):
-            X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False)
+            X = validate_data(self, X, dtype=[xp.float64, xp.float32], ensure_2d=False)
         else:
-            X = check_array(X, dtype=[np.float64, np.float32])
+            X = check_array(X, dtype=[xp.float64, xp.float32])
 
         if sample_weight is not None:
-            sample_weight = _check_sample_weight(sample_weight, X)
+            sample_weight = _check_sample_weight(
+                sample_weight, X, dtype=[xp.float64, xp.float32]
+            )
 
         onedal_params = {
             "result_options": self.result_options,
diff --git a/sklearnex/basic_statistics/incremental_basic_statistics.py b/sklearnex/basic_statistics/incremental_basic_statistics.py
index d1ddcd55dc..d6c81942c2 100644
--- a/sklearnex/basic_statistics/incremental_basic_statistics.py
+++ b/sklearnex/basic_statistics/incremental_basic_statistics.py
@@ -14,10 +14,8 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
 from sklearn.base import BaseEstimator
 from sklearn.utils import check_array, gen_batches
-from sklearn.utils.validation import _check_sample_weight
 
 from daal4py.sklearn._n_jobs_support import control_n_jobs
 from daal4py.sklearn._utils import sklearn_check_version
@@ -34,10 +32,8 @@
 import numbers
 import warnings
 
-if sklearn_check_version("1.6"):
-    from sklearn.utils.validation import validate_data
-else:
-    validate_data = BaseEstimator._validate_data
+from ..utils._array_api import get_namespace
+from ..utils.validation import _check_sample_weight, validate_data
 
 
 @control_n_jobs(decorated_methods=["partial_fit", "_onedal_finalize_fit"])
@@ -160,12 +156,7 @@ class IncrementalBasicStatistics(IntelEstimator, BaseEstimator):
     }
 
     def __init__(self, result_options="all", batch_size=None):
-        if result_options == "all":
-            self.result_options = (
-                self._onedal_incremental_basic_statistics.get_all_result_options()
-            )
-        else:
-            self.result_options = result_options
+        self.result_options = result_options
         self._need_to_finalize = False
         self.batch_size = batch_size
 
@@ -178,14 +169,6 @@ def _onedal_supported(self, method_name, *data):
     _onedal_cpu_supported = _onedal_supported
     _onedal_gpu_supported = _onedal_supported
 
-    def _get_onedal_result_options(self, options):
-        if isinstance(options, list):
-            onedal_options = "|".join(self.result_options)
-        else:
-            onedal_options = options
-        assert isinstance(onedal_options, str)
-        return options
-
     def _onedal_finalize_fit(self, queue=None):
         assert hasattr(self, "_onedal_estimator")
         self._onedal_estimator.finalize_fit(queue=queue)
@@ -195,21 +178,24 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=Tru
         first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0
 
         if check_input:
+            xp, _ = get_namespace(X)
            if sklearn_check_version("1.0"):
                X = validate_data(
                    self,
                    X,
-                    dtype=[np.float64, np.float32],
+                    dtype=[xp.float64, xp.float32],
                    reset=first_pass,
                )
            else:
                X = check_array(
                    X,
-                    dtype=[np.float64, np.float32],
+                    dtype=[xp.float64, xp.float32],
                )
 
-        if sample_weight is not None:
-            sample_weight = _check_sample_weight(sample_weight, X)
+            if sample_weight is not None:
+                sample_weight = _check_sample_weight(
+                    sample_weight, X, dtype=[xp.float64, xp.float32]
+                )
 
         if first_pass:
             self.n_samples_seen_ = X.shape[0]
@@ -217,27 +203,28 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=Tru
         else:
             self.n_samples_seen_ += X.shape[0]
 
-        onedal_params = {
-            "result_options": self._get_onedal_result_options(self.result_options)
-        }
         if not hasattr(self, "_onedal_estimator"):
             self._onedal_estimator = self._onedal_incremental_basic_statistics(
-                **onedal_params
+                result_options=self.result_options
             )
-        self._onedal_estimator.partial_fit(X, weights=sample_weight, queue=queue)
+
+        self._onedal_estimator.partial_fit(X, sample_weight=sample_weight, queue=queue)
         self._need_to_finalize = True
 
     def _onedal_fit(self, X, sample_weight=None, queue=None):
         if sklearn_check_version("1.2"):
             self._validate_params()
 
+        xp, _ = get_namespace(X)
         if sklearn_check_version("1.0"):
-            X = validate_data(self, X, dtype=[np.float64, np.float32])
+            X = validate_data(self, X, dtype=[xp.float64, xp.float32])
         else:
-            X = check_array(X, dtype=[np.float64, np.float32])
+            X = check_array(X, dtype=[xp.float64, xp.float32])
 
         if sample_weight is not None:
-            sample_weight = _check_sample_weight(sample_weight, X)
+            sample_weight = _check_sample_weight(
+                sample_weight, X, dtype=[xp.float64, xp.float32]
+            )
 
         n_samples, n_features = X.shape
         if self.batch_size is None:
@@ -263,11 +250,12 @@
 
         return self
 
     def __getattr__(self, attr):
-        result_options = self.__dict__["result_options"]
         sattr = attr.removesuffix("_")
         is_statistic_attr = (
-            isinstance(result_options, str) and (sattr == result_options)
-        ) or (isinstance(result_options, list) and (sattr in result_options))
+            sattr in self._onedal_estimator.options
+            if "_onedal_estimator" in self.__dict__
+            else False
+        )
         if is_statistic_attr:
             if self._need_to_finalize:
                 self._onedal_finalize_fit()
diff --git a/sklearnex/spmd/basic_statistics/basic_statistics.py b/sklearnex/spmd/basic_statistics/basic_statistics.py
index eef0b666a8..94854bd85c 100644
--- a/sklearnex/spmd/basic_statistics/basic_statistics.py
+++ b/sklearnex/spmd/basic_statistics/basic_statistics.py
@@ -14,8 +14,11 @@
 # limitations under the License.
 # ==============================================================================
 
-from onedal.spmd.basic_statistics import BasicStatistics
+from onedal.spmd.basic_statistics import BasicStatistics as onedal_BasicStatistics
 
-# TODO:
-# Currently it uses `onedal` module interface.
-# Add sklearnex dispatching.
+from ...basic_statistics import BasicStatistics as BasicStatistics_Batch
+
+
+class BasicStatistics(BasicStatistics_Batch):
+    __doc__ = BasicStatistics_Batch.__doc__
+    _onedal_basic_statistics = staticmethod(onedal_BasicStatistics)
diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py
index aa92df1d6a..2d52a545cf 100644
--- a/sklearnex/tests/test_memory_usage.py
+++ b/sklearnex/tests/test_memory_usage.py
@@ -35,10 +35,14 @@
     get_dataframes_and_queues,
 )
 from onedal.tests.utils._device_selection import get_queues, is_dpctl_device_available
-from onedal.utils._array_api import _get_sycl_namespace
 from onedal.utils._dpep_helpers import dpctl_available, dpnp_available
 from sklearnex import config_context
-from sklearnex.tests.utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES
+from sklearnex.tests.utils import (
+    PATCHED_FUNCTIONS,
+    PATCHED_MODELS,
+    SPECIAL_INSTANCES,
+    DummyEstimator,
+)
 from sklearnex.utils._array_api import get_namespace
 
 if dpctl_available:
@@ -131,41 +135,6 @@ def gen_functions(functions):
 ORDER_DICT = {"F": np.asfortranarray, "C": np.ascontiguousarray}
 
 
-if _is_dpc_backend:
-
-    from sklearn.utils.validation import check_is_fitted
-
-    from onedal.datatypes import from_table, to_table
-
-    class DummyEstimatorWithTableConversions(BaseEstimator):
-
-        def fit(self, X, y=None):
-            sua_iface, xp, _ = _get_sycl_namespace(X)
-            X_table = to_table(X)
-            y_table = to_table(y)
-            # The presence of the fitted attributes (ending with a trailing
-            # underscore) is required for the correct check. The cleanup of
-            # the memory will occur at the estimator instance deletion.
-            self.x_attr_ = from_table(
-                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
-            )
-            self.y_attr_ = from_table(
-                y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
-            )
-            return self
-
-        def predict(self, X):
-            # Checks if the estimator is fitted by verifying the presence of
-            # fitted attributes (ending with a trailing underscore).
-            check_is_fitted(self)
-            sua_iface, xp, _ = _get_sycl_namespace(X)
-            X_table = to_table(X)
-            returned_X = from_table(
-                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
-            )
-            return returned_X
-
-
 def gen_clsf_data(n_samples, n_features, dtype=None):
     data, label = make_classification(
         n_classes=2, n_samples=n_samples, n_features=n_features, random_state=777
@@ -369,7 +338,7 @@ def test_table_conversions_memory_leaks(dataframe, queue, order, data_shape, dty
         pytest.skip("SYCL device memory leak check requires the level zero sysman")
 
     _kfold_function_template(
-        DummyEstimatorWithTableConversions,
+        DummyEstimator,
         dataframe,
         data_shape,
         queue,
diff --git a/sklearnex/tests/utils/__init__.py b/sklearnex/tests/utils/__init__.py
index 60ca67fa37..db728fe913 100644
--- a/sklearnex/tests/utils/__init__.py
+++ b/sklearnex/tests/utils/__init__.py
@@ -21,6 +21,7 @@
     SPECIAL_INSTANCES,
     UNPATCHED_FUNCTIONS,
     UNPATCHED_MODELS,
+    DummyEstimator,
     _get_processor_info,
     call_method,
     gen_dataset,
@@ -39,6 +40,7 @@
     "gen_models_info",
     "gen_dataset",
     "sklearn_clone_dict",
+    "DummyEstimator",
 ]
 
 _IS_INTEL = "GenuineIntel" in _get_processor_info()
diff --git a/sklearnex/tests/utils/base.py b/sklearnex/tests/utils/base.py
index 1949519585..706de39a91 100755
--- a/sklearnex/tests/utils/base.py
+++ b/sklearnex/tests/utils/base.py
@@ -32,8 +32,11 @@
 )
 from sklearn.datasets import load_diabetes, load_iris
 from sklearn.neighbors._base import KNeighborsMixin
+from sklearn.utils.validation import check_is_fitted
 
+from onedal.datatypes import from_table, to_table
 from onedal.tests.utils._dataframes_support import _convert_to_dataframe
+from onedal.utils._array_api import _get_sycl_namespace
 from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
 from sklearnex.basic_statistics import BasicStatistics, IncrementalBasicStatistics
 from sklearnex.linear_model import LogisticRegression
@@ -369,3 +372,41 @@ def _get_processor_info():
     )
 
     return proc
+
+
+class DummyEstimator(BaseEstimator):
+
+    def fit(self, X, y=None):
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        X_table = to_table(X)
+        y_table = to_table(y)
+        # The presence of the fitted attributes (ending with a trailing
+        # underscore) is required for the correct check. The cleanup of
+        # the memory will occur at the estimator instance deletion.
+        if sua_iface:
+            self.x_attr_ = from_table(
+                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
+            )
+            self.y_attr_ = from_table(
+                y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
+            )
+        else:
+            self.x_attr = from_table(X_table)
+            self.y_attr = from_table(y_table)
+
+        return self
+
+    def predict(self, X):
+        # Checks if the estimator is fitted by verifying the presence of
+        # fitted attributes (ending with a trailing underscore).
+        check_is_fitted(self)
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        X_table = to_table(X)
+        if sua_iface:
+            returned_X = from_table(
+                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
+            )
+        else:
+            returned_X = from_table(X_table)
+
+        return returned_X
diff --git a/sklearnex/utils/__init__.py b/sklearnex/utils/__init__.py
index 4c3fe21154..686e089adf 100755
--- a/sklearnex/utils/__init__.py
+++ b/sklearnex/utils/__init__.py
@@ -14,6 +14,6 @@
 # limitations under the License.
 # ===============================================================================
 
-from .validation import _assert_all_finite
+from .validation import assert_all_finite
 
-__all__ = ["_assert_all_finite"]
+__all__ = ["assert_all_finite"]
diff --git a/sklearnex/utils/tests/test_finite.py b/sklearnex/utils/tests/test_finite.py
deleted file mode 100644
index 7d83667699..0000000000
--- a/sklearnex/utils/tests/test_finite.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# ==============================================================================
-# Copyright 2024 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import time
-
-import numpy as np
-import numpy.random as rand
-import pytest
-from numpy.testing import assert_raises
-
-from sklearnex.utils import _assert_all_finite
-
-
-@pytest.mark.parametrize("dtype", [np.float32, np.float64])
-@pytest.mark.parametrize(
-    "shape",
-    [
-        [16, 2048],
-        [
-            2**16 + 3,
-        ],
-        [1000, 1000],
-    ],
-)
-@pytest.mark.parametrize("allow_nan", [False, True])
-def test_sum_infinite_actually_finite(dtype, shape, allow_nan):
-    X = np.empty(shape, dtype=dtype)
-    X.fill(np.finfo(dtype).max)
-    _assert_all_finite(X, allow_nan=allow_nan)
-
-
-@pytest.mark.parametrize("dtype", [np.float32, np.float64])
-@pytest.mark.parametrize(
-    "shape",
-    [
-        [16, 2048],
-        [
-            65539,  # 2**16 + 3,
-        ],
-        [1000, 1000],
-    ],
-)
-@pytest.mark.parametrize("allow_nan", [False, True])
-@pytest.mark.parametrize("check", ["inf", "NaN", None])
-@pytest.mark.parametrize("seed", [0, int(time.time())])
-def test_assert_finite_random_location(dtype, shape, allow_nan, check, seed):
-    rand.seed(seed)
-    X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)
-
-    if check:
-        loc = rand.randint(0, X.size - 1)
-        X.reshape((-1,))[loc] = float(check)
-
-    if check is None or (allow_nan and check == "NaN"):
-        _assert_all_finite(X, allow_nan=allow_nan)
-    else:
-        assert_raises(ValueError, _assert_all_finite, X, allow_nan=allow_nan)
-
-
-@pytest.mark.parametrize("dtype", [np.float32, np.float64])
-@pytest.mark.parametrize("allow_nan", [False, True])
-@pytest.mark.parametrize("check", ["inf", "NaN", None])
-@pytest.mark.parametrize("seed", [0, int(time.time())])
-def test_assert_finite_random_shape_and_location(dtype, allow_nan, check, seed):
-    lb, ub = 32768, 1048576  # lb is a patching condition, ub 2^20
-    rand.seed(seed)
-    X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype)
-
-    if check:
-        loc = rand.randint(0, X.size - 1)
-        X[loc] = float(check)
-
-    if check is None or (allow_nan and check == "NaN"):
-        _assert_all_finite(X, allow_nan=allow_nan)
-    else:
-        assert_raises(ValueError, _assert_all_finite, X, allow_nan=allow_nan)
diff --git a/sklearnex/utils/tests/test_validation.py b/sklearnex/utils/tests/test_validation.py
new file mode 100644
index 0000000000..92ba0d742a
--- /dev/null
+++ b/sklearnex/utils/tests/test_validation.py
@@ -0,0 +1,240 @@
+# ==============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import time
+
+import numpy as np
+import numpy.random as rand
+import pytest
+
+from daal4py.sklearn._utils import sklearn_check_version
+from onedal.tests.utils._dataframes_support import (
+    _convert_to_dataframe,
+    get_dataframes_and_queues,
+)
+from sklearnex import config_context
+from sklearnex.tests.utils import DummyEstimator, gen_dataset
+from sklearnex.utils.validation import _check_sample_weight, validate_data
+
+# array_api support starts in sklearn 1.2, and array_api_strict conformance starts in sklearn 1.3
+_dataframes_supported = (
+    "numpy,pandas"
+    + (",dpctl" if sklearn_check_version("1.2") else "")
+    + (",array_api" if sklearn_check_version("1.3") else "")
+)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "shape",
+    [
+        [16, 2048],
+        [
+            2**16 + 3,
+        ],
+        [1000, 1000],
+    ],
+)
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+def test_sum_infinite_actually_finite(dtype, shape, ensure_all_finite):
+    est = DummyEstimator()
+    X = np.empty(shape, dtype=dtype)
+    X.fill(np.finfo(dtype).max)
+    X = np.atleast_2d(X)
+    X_array = validate_data(est, X, ensure_all_finite=ensure_all_finite)
+    assert type(X_array) == type(X)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "shape",
+    [
+        [16, 2048],
+        [
+            2**16 + 3,
+        ],
+        [1000, 1000],
+    ],
+)
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_random_location(
+    dataframe, queue, dtype, shape, ensure_all_finite, check, seed
+):
+    est = DummyEstimator()
+    rand.seed(seed)
+    X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, X.size - 1)
+        X.reshape((-1,))[loc] = float(check)
+
+    # column-heavy pandas inputs are very slow in sklearn's check_array even without
+    # the finite check; transpose inputs to guarantee fast processing in tests
+    X = _convert_to_dataframe(
+        np.atleast_2d(X).T,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        allow_nan = ensure_all_finite == "allow-nan"
+        if check is None or (allow_nan and check == "NaN"):
+            validate_data(est, X, ensure_all_finite=ensure_all_finite)
+        else:
+            type_err = "infinity" if allow_nan else "[NaN|infinity]"
+            msg_err = f"Input X contains {type_err}"
+            with pytest.raises(ValueError, match=msg_err):
+                validate_data(est, X, ensure_all_finite=ensure_all_finite)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_random_shape_and_location(
+    dataframe, queue, dtype, ensure_all_finite, check, seed
+):
+    est = DummyEstimator()
+    lb, ub = 32768, 1048576  # lb is a patching condition, ub 2^20
+    rand.seed(seed)
+    X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, X.size - 1)
+        X[loc] = float(check)
+
+    X = _convert_to_dataframe(
+        np.atleast_2d(X).T,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        allow_nan = ensure_all_finite == "allow-nan"
+        if check is None or (allow_nan and check == "NaN"):
+            validate_data(est, X, ensure_all_finite=ensure_all_finite)
+        else:
+            type_err = "infinity" if allow_nan else "[NaN|infinity]"
+            msg_err = f"Input X contains {type_err}."
+            with pytest.raises(ValueError, match=msg_err):
+                validate_data(est, X, ensure_all_finite=ensure_all_finite)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test__check_sample_weight_random_shape_and_location(
+    dataframe, queue, dtype, check, seed
+):
+    # This testing assumes that array api inputs to validate_data will only occur
+    # with sklearn array_api support which began in sklearn 1.2. This would assume
+    # that somewhere upstream of the validate_data call, a data conversion of dpnp,
+    # dpctl, or array_api inputs to numpy inputs would have occurred.
+
+    lb, ub = 32768, 1048576  # lb is a patching condition, ub 2^20
+    rand.seed(seed)
+    shape = (rand.randint(lb, ub), 2)
+    X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)
+    sample_weight = rand.uniform(high=np.finfo(dtype).max, size=shape[0]).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, shape[0] - 1)
+        sample_weight[loc] = float(check)
+
+    X = _convert_to_dataframe(
+        X,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+    sample_weight = _convert_to_dataframe(
+        sample_weight,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        if check is None:
+            X_out = _check_sample_weight(sample_weight, X)
+            if dispatch:
+                assert type(X_out) == type(X)
+            else:
+                assert isinstance(X_out, np.ndarray)
+        else:
+            msg_err = "Input sample_weight contains [NaN|infinity]"
+            with pytest.raises(ValueError, match=msg_err):
+                X_out = _check_sample_weight(sample_weight, X)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_output(dtype, dataframe, queue):
+    # This testing assumes that array api inputs to validate_data will only occur
+    # with sklearn array_api support which began in sklearn 1.2. This would assume
+    # that somewhere upstream of the validate_data call, a data conversion of dpnp,
+    # dpctl, or array_api inputs to numpy inputs would have occurred.
+    est = DummyEstimator()
+    X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)[0]
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+        X_out, y_out = validate_data(est, X, y)
+        # check sklearn validate_data operations work underneath
+        X_array = validate_data(est, X, reset=False)
+
+    if dispatch:
+        assert type(X) == type(
+            X_array
+        ), f"validate_data converted {type(X)} to {type(X_array)}"
+        assert type(X) == type(X_out), f"from_array converted {type(X)} to {type(X_out)}"
+    else:
+        # array_api_strict from sklearn < 1.2 and pandas will convert to numpy arrays
+        assert isinstance(X_array, np.ndarray)
+        assert isinstance(X_out, np.ndarray)
diff --git a/sklearnex/utils/validation.py b/sklearnex/utils/validation.py
index b2d1898643..76470091ce 100755
--- a/sklearnex/utils/validation.py
+++ b/sklearnex/utils/validation.py
@@ -14,4 +14,142 @@
 # limitations under the License.
 # ===============================================================================
 
-from daal4py.sklearn.utils.validation import _assert_all_finite
+import numbers
+
+import scipy.sparse as sp
+from sklearn.utils.validation import _assert_all_finite as _sklearn_assert_all_finite
+from sklearn.utils.validation import _num_samples, check_array, check_non_negative
+
+from daal4py.sklearn._utils import sklearn_check_version
+from onedal.utils.validation import _assert_all_finite as _onedal_assert_all_finite
+
+from ._array_api import get_namespace
+
+if sklearn_check_version("1.6"):
+    from sklearn.utils.validation import validate_data as _sklearn_validate_data
+
+    _finite_keyword = "ensure_all_finite"
+
+else:
+    from sklearn.base import BaseEstimator
+
+    _sklearn_validate_data = BaseEstimator._validate_data
+    _finite_keyword = "force_all_finite"
+
+
+def _is_contiguous(X):
+    # array_api does not have a `strides` or `flags` attribute for testing memory
+    # order. When dlpack support is brought in for oneDAL, the dlpack python capsule
+    # can then be inspected for strides and this must be updated. _is_contiguous is
+    # therefore conservative in verifying attributes and does not support array_api.
+    # This will block onedal_assert_all_finite from being used for array_api inputs.
+    return hasattr(X, "flags") and (X.flags["C_CONTIGUOUS"] or X.flags["F_CONTIGUOUS"])
+
+
+def _sklearnex_assert_all_finite(
+    X,
+    *,
+    allow_nan=False,
+    input_name="",
+):
+    # the size check is an initial match to daal4py for performance reasons;
+    # it can be optimized later
+    xp, _ = get_namespace(X)
+    if X.size < 32768 or X.dtype not in [xp.float32, xp.float64] or not _is_contiguous(X):
+        if sklearn_check_version("1.1"):
+            _sklearn_assert_all_finite(X, allow_nan=allow_nan, input_name=input_name)
+        else:
+            _sklearn_assert_all_finite(X, allow_nan=allow_nan)
+    else:
+        _onedal_assert_all_finite(X, allow_nan=allow_nan, input_name=input_name)
+
+
+def assert_all_finite(
+    X,
+    *,
+    allow_nan=False,
+    input_name="",
+):
+    _sklearnex_assert_all_finite(
+        X.data if sp.issparse(X) else X,
+        allow_nan=allow_nan,
+        input_name=input_name,
+    )
+
+
+def validate_data(
+    _estimator,
+    /,
+    X="no_validation",
+    y="no_validation",
+    **kwargs,
+):
+    # force the finite check to not occur in sklearn; the default is True
+    # `ensure_all_finite` is the most up-to-date keyword name in sklearn
+    # _finite_keyword provides backward compatibility for `force_all_finite`
+    ensure_all_finite = kwargs.pop("ensure_all_finite", True)
+    kwargs[_finite_keyword] = False
+
+    out = _sklearn_validate_data(
+        _estimator,
+        X=X,
+        y=y,
+        **kwargs,
+    )
+    if ensure_all_finite:
+        # run local finite check
+        allow_nan = ensure_all_finite == "allow-nan"
+        arg = iter(out if isinstance(out, tuple) else (out,))
+        if not isinstance(X, str) or X != "no_validation":
+            assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X")
+        if not (y is None or isinstance(y, str) and y == "no_validation"):
+            assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y")
+    return out
+
+
+def _check_sample_weight(
+    sample_weight, X, dtype=None, copy=False, only_non_negative=False
+):
+
+    n_samples = _num_samples(X)
+    xp, _ = get_namespace(X)
+
+    if dtype is not None and dtype not in [xp.float32, xp.float64]:
+        dtype = xp.float64
+
+    if sample_weight is None:
+        sample_weight = xp.ones(n_samples, dtype=dtype)
+    elif isinstance(sample_weight, numbers.Number):
+        sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
+    else:
+        if dtype is None:
+            dtype = [xp.float64, xp.float32]
+
+        params = {
+            "accept_sparse": False,
+            "ensure_2d": False,
+            "dtype": dtype,
+            "order": "C",
+            "copy": copy,
+            _finite_keyword: False,
+        }
+        if sklearn_check_version("1.1"):
+            params["input_name"] = "sample_weight"
+
+        sample_weight = check_array(sample_weight, **params)
+        assert_all_finite(sample_weight, input_name="sample_weight")
+
+    if sample_weight.ndim != 1:
+        raise ValueError("Sample weights must be 1D array or scalar")
+
+    if sample_weight.shape != (n_samples,):
+        raise ValueError(
+            "sample_weight.shape == {}, expected {}!".format(
+                sample_weight.shape, (n_samples,)
+            )
+        )
+
+    if only_non_negative:
+        check_non_negative(sample_weight, "`sample_weight`")
+
+    return sample_weight
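
Note: the hunks in onedal/basic_statistics/basic_statistics.py replace the old _get_result_options helper with an options property pair. Below is a minimal stand-alone sketch of just that property logic; the OptionsDemo class is hypothetical, the real class also inherits BaseEstimator and carries the oneDAL backend plumbing, and the option list is reproduced from get_all_result_options in the patch.

class OptionsDemo:
    """Hypothetical stand-in mirroring the options property from the patch."""

    @staticmethod
    def get_all_result_options():
        # full result option list, matching get_all_result_options above
        return [
            "min", "max", "sum", "mean",
            "variance", "variation", "sum_squares",
            "standard_deviation", "sum_squares_centered",
            "second_order_raw_moment",
        ]

    def __init__(self, result_options="all"):
        self.options = result_options  # routes through the setter

    @property
    def options(self):
        # "all" is stored verbatim and expanded only on access
        if self._options == ["all"]:
            return self.get_all_result_options()
        return self._options

    @options.setter
    def options(self, opts):
        # options are always stored as an iterable; "a|b" strings are split
        self._options = opts.split("|") if isinstance(opts, str) else opts


assert OptionsDemo("mean|max").options == ["mean", "max"]
assert OptionsDemo(["variance"]).options == ["variance"]
assert len(OptionsDemo().options) == 10
# _get_onedal_params re-joins with "|" for the oneDAL result_option parameter
assert "|".join(OptionsDemo("mean|max").options) == "mean|max"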
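The new sklearnex/utils/validation.py splits validation in two: the sklearn-side validate_data or check_array call always runs with finite checking disabled, and assert_all_finite then re-runs the check locally, dispatching to the oneDAL kernel only for contiguous float32/float64 inputs of at least 32768 elements. A hedged usage sketch follows; the TinyEstimator name is invented for illustration, and the example assumes plain numpy on CPU with a build that includes this patch.

import numpy as np
from sklearn.base import BaseEstimator

from sklearnex.utils.validation import validate_data


class TinyEstimator(BaseEstimator):
    def fit(self, X, y=None):
        # finite checking is stripped from the sklearn call and re-applied
        # by sklearnex's assert_all_finite (oneDAL-backed when profitable)
        X = validate_data(self, X, ensure_all_finite=True)
        return self


TinyEstimator().fit(np.ones((4, 2)))  # small clean input: passes

bad = np.ones((4, 2))
bad[0, 0] = np.inf
try:
    TinyEstimator().fit(bad)
except ValueError as exc:
    print(exc)  # expected to report that input X contains infinity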
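Finally, the example updates at the top of the diff track the user-visible attribute rename: results now land on trailing-underscore attributes, while the bare names stay reachable through __getattr__ with a deprecation warning. A small demonstration, again assuming a release carrying this patch:

import numpy as np

from sklearnex.basic_statistics import BasicStatistics

bs = BasicStatistics(result_options=["mean", "max"]).fit(np.ones((8, 2)))

print(bs.mean_)  # preferred: set by _save_attributes with the trailing underscore
print(bs.mean)   # deprecated alias: warns, scheduled for removal in 2026.0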