diff --git a/dask_sql/physical/rel/custom/create_experiment.py b/dask_sql/physical/rel/custom/create_experiment.py index 3d510ac18..ddec9fccf 100644 --- a/dask_sql/physical/rel/custom/create_experiment.py +++ b/dask_sql/physical/rel/custom/create_experiment.py @@ -30,17 +30,9 @@ class CreateExperimentPlugin(BaseRelPlugin): * model_class: Full path to the class of the model which has to be tuned. Any model class with sklearn interface is valid, but might or might not work well with Dask dataframes. - Have a look into the - [dask-ml documentation](https://ml.dask.org/index.html) - for more information on which models work best. You might need to install necessary packages to use the models. * experiment_class : Full path of the Hyperparameter tuner - from dask_ml, choose dask tuner class carefully based on what you - exactly need (memory vs compute constrains), refer: - [dask-ml documentation](https://ml.dask.org/hyper-parameter-search.html) - (for tuning hyperparameter of the models both model_class and experiment class are - required parameters.) * tune_parameters: Key-value of pairs of Hyperparameters to tune, i.e Search Space for particular model to tune @@ -64,7 +56,7 @@ class CreateExperimentPlugin(BaseRelPlugin): CREATE EXPERIMENT my_exp WITH( model_class = 'sklearn.ensemble.GradientBoostingClassifier', - experiment_class = 'dask_ml.model_selection.GridSearchCV', + experiment_class = 'sklearn.model_selection.GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2], learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10] @@ -174,7 +166,11 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai search = ExperimentClass(model, {**parameters}, **experiment_kwargs) logger.info(tune_fit_kwargs) - search.fit(X, y, **tune_fit_kwargs) + search.fit( + X.to_dask_array(lengths=True), + y.to_dask_array(lengths=True), + **tune_fit_kwargs, + ) df = pd.DataFrame(search.cv_results_) df["model_class"] = model_class diff --git a/dask_sql/physical/rel/custom/create_model.py b/dask_sql/physical/rel/custom/create_model.py index 179dd7971..726568c5d 100644 --- a/dask_sql/physical/rel/custom/create_model.py +++ b/dask_sql/physical/rel/custom/create_model.py @@ -32,9 +32,6 @@ class CreateModelPlugin(BaseRelPlugin): * model_class: Full path to the class of the model to train. Any model class with sklearn interface is valid, but might or might not work well with Dask dataframes. - Have a look into the - [dask-ml documentation](https://ml.dask.org/index.html) - for more information on which models work best. You might need to install necessary packages to use the models. * target_column: Which column from the data to use as target. @@ -45,16 +42,12 @@ class CreateModelPlugin(BaseRelPlugin): want to set this parameter. * wrap_predict: Boolean flag, whether to wrap the selected model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`. - Have a look into the - [dask-ml docu](https://ml.dask.org/meta-estimators.html#parallel-prediction-and-transformation) - to learn more about it. Defaults to false. Typically you set - it to true for sklearn models if predicting on big data. + Defaults to false. Typically you set it to true for + sklearn models if predicting on big data. * wrap_fit: Boolean flag, whether to wrap the selected - model with a :class:`dask_ml.wrappers.Incremental`. - Have a look into the - [dask-ml docu](https://ml.dask.org/incremental.html) - to learn more about it. Defaults to false. 
Typically you set - it to true for sklearn models if training on big data. + model with a :class:`dask_sql.physical.rel.custom.wrappers.Incremental`. + Defaults to false. Typically you set it to true for + sklearn models if training on big data. * fit_kwargs: keyword arguments sent to the call to fit(). All other arguments are passed to the constructor of the @@ -76,7 +69,7 @@ class CreateModelPlugin(BaseRelPlugin): Examples: CREATE MODEL my_model WITH ( - model_class = 'dask_ml.xgboost.XGBClassifier', + model_class = 'xgboost.XGBClassifier', target_column = 'target' ) AS ( SELECT x, y, target @@ -95,11 +88,10 @@ class CreateModelPlugin(BaseRelPlugin): dask dataframes. * if you are training on relatively small amounts - of data but predicting on large data samples - (and you are not using a model build for usage with dask - from the dask-ml package), you might want to set - `wrap_predict` to True. With this option, - model interference will be parallelized/distributed. + of data but predicting on large data samples, + you might want to set `wrap_predict` to True. + With this option, model interference will be + parallelized/distributed. * If you are training on large amounts of data, you can try setting wrap_fit to True. This will do the same on the training step, but works only on @@ -158,10 +150,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai model = ModelClass(**kwargs) if wrap_fit: - try: - from dask_ml.wrappers import Incremental - except ImportError: # pragma: no cover - raise ValueError("Wrapping requires dask-ml to be installed.") + from dask_sql.physical.rel.custom.wrappers import Incremental model = Incremental(estimator=model) diff --git a/dask_sql/physical/rel/custom/metrics.py b/dask_sql/physical/rel/custom/metrics.py new file mode 100644 index 000000000..4b898d1a9 --- /dev/null +++ b/dask_sql/physical/rel/custom/metrics.py @@ -0,0 +1,206 @@ +# Copyright 2017, Dask developers +# Dask-ML project - https://github.com/dask/dask-ml +from typing import Optional, TypeVar + +import dask +import dask.array as da +import numpy as np +import sklearn.metrics +import sklearn.utils.multiclass +from dask.array import Array +from dask.utils import derived_from + +ArrayLike = TypeVar("ArrayLike", Array, np.ndarray) + + +def accuracy_score( + y_true: ArrayLike, + y_pred: ArrayLike, + normalize: bool = True, + sample_weight: Optional[ArrayLike] = None, + compute: bool = True, +) -> ArrayLike: + """Accuracy classification score. + In multilabel classification, this function computes subset accuracy: + the set of labels predicted for a sample must *exactly* match the + corresponding set of labels in y_true. + Read more in the :ref:`User Guide `. + Parameters + ---------- + y_true : 1d array-like, or label indicator array + Ground truth (correct) labels. + y_pred : 1d array-like, or label indicator array + Predicted labels, as returned by a classifier. + normalize : bool, optional (default=True) + If ``False``, return the number of correctly classified samples. + Otherwise, return the fraction of correctly classified samples. + sample_weight : 1d array-like, optional + Sample weights. + .. versionadded:: 0.7.0 + Returns + ------- + score : scalar dask Array + If ``normalize == True``, return the correctly classified samples + (float), else it returns the number of correctly classified samples + (int). + The best performance is 1 with ``normalize == True`` and the number + of samples with ``normalize == False``. 
+ Notes + ----- + In binary and multiclass classification, this function is equal + to the ``jaccard_similarity_score`` function. + + """ + + if y_true.ndim > 1: + differing_labels = ((y_true - y_pred) == 0).all(1) + score = differing_labels != 0 + else: + score = y_true == y_pred + + if normalize: + score = da.average(score, weights=sample_weight) + elif sample_weight is not None: + score = da.dot(score, sample_weight) + else: + score = score.sum() + + if compute: + score = score.compute() + return score + + +def _log_loss_inner( + x: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike], **kwargs +): + # da.map_blocks wasn't able to concatenate together the results + # when we reduce down to a scalar per block. So we make an + # array with 1 element. + if sample_weight is not None: + sample_weight = sample_weight.ravel() + return np.array( + [sklearn.metrics.log_loss(x, y, sample_weight=sample_weight, **kwargs)] + ) + + +def log_loss( + y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None, labels=None +): + if not (dask.is_dask_collection(y_true) and dask.is_dask_collection(y_pred)): + return sklearn.metrics.log_loss( + y_true, + y_pred, + eps=eps, + normalize=normalize, + sample_weight=sample_weight, + labels=labels, + ) + + if y_pred.ndim > 1 and y_true.ndim == 1: + y_true = y_true.reshape(-1, 1) + drop_axis: Optional[int] = 1 + if sample_weight is not None: + sample_weight = sample_weight.reshape(-1, 1) + else: + drop_axis = None + + result = da.map_blocks( + _log_loss_inner, + y_true, + y_pred, + sample_weight, + chunks=(1,), + drop_axis=drop_axis, + dtype="f8", + eps=eps, + normalize=normalize, + labels=labels, + ) + if normalize and sample_weight is not None: + sample_weight = sample_weight.ravel() + block_weights = sample_weight.map_blocks(np.sum, chunks=(1,), keepdims=True) + return da.average(result, 0, weights=block_weights) + elif normalize: + return result.mean() + else: + return result.sum() + + +def _check_sample_weight(sample_weight: Optional[ArrayLike]): + if sample_weight is not None: + raise ValueError("'sample_weight' is not supported.") + + +@derived_from(sklearn.metrics) +def mean_squared_error( + y_true: ArrayLike, + y_pred: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + multioutput: Optional[str] = "uniform_average", + squared: bool = True, + compute: bool = True, +) -> ArrayLike: + _check_sample_weight(sample_weight) + output_errors = ((y_pred - y_true) ** 2).mean(axis=0) + + if isinstance(multioutput, str) or multioutput is None: + if multioutput == "raw_values": + if compute: + return output_errors.compute() + else: + return output_errors + else: + raise ValueError("Weighted 'multioutput' not supported.") + result = output_errors.mean() + if not squared: + result = da.sqrt(result) + if compute: + result = result.compute() + return result + + +def _check_reg_targets( + y_true: ArrayLike, y_pred: ArrayLike, multioutput: Optional[str] +): + if multioutput is not None and multioutput != "uniform_average": + raise NotImplementedError("'multioutput' must be 'uniform_average'") + + if y_true.ndim == 1: + y_true = y_true.reshape((-1, 1)) + if y_pred.ndim == 1: + y_pred = y_pred.reshape((-1, 1)) + + # TODO: y_type, multioutput + return None, y_true, y_pred, multioutput + + +@derived_from(sklearn.metrics) +def r2_score( + y_true: ArrayLike, + y_pred: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + multioutput: Optional[str] = "uniform_average", + compute: bool = True, +) -> ArrayLike: + _check_sample_weight(sample_weight) + _, y_true, 
y_pred, _ = _check_reg_targets(y_true, y_pred, multioutput) + weight = 1.0 + + numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8") + denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype="f8") + + nonzero_denominator = denominator != 0 + nonzero_numerator = numerator != 0 + valid_score = nonzero_denominator & nonzero_numerator + output_chunks = getattr(y_true, "chunks", [None, None])[1] + output_scores = da.ones([y_true.shape[1]], chunks=output_chunks) + with np.errstate(all="ignore"): + output_scores[valid_score] = 1 - ( + numerator[valid_score] / denominator[valid_score] + ) + output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0 + + result = output_scores.mean(axis=0) + if compute: + result = result.compute() + return result diff --git a/dask_sql/physical/rel/custom/wrappers.py b/dask_sql/physical/rel/custom/wrappers.py index 7ed0d0dea..c6432497b 100644 --- a/dask_sql/physical/rel/custom/wrappers.py +++ b/dask_sql/physical/rel/custom/wrappers.py @@ -3,11 +3,19 @@ """Meta-estimators for parallelizing estimators using the scikit-learn API.""" import logging import warnings +from typing import Any, Callable, Tuple, Union import dask.array as da import dask.dataframe as dd import dask.delayed import numpy as np +import sklearn.base +import sklearn.metrics +from dask.delayed import Delayed +from dask.highlevelgraph import HighLevelGraph +from sklearn.metrics import check_scoring as sklearn_check_scoring +from sklearn.metrics import make_scorer +from sklearn.utils.validation import check_is_fitted try: import sklearn.base @@ -15,9 +23,31 @@ except ImportError: # pragma: no cover raise ImportError("sklearn must be installed") +from dask_sql.physical.rel.custom.metrics import ( + accuracy_score, + log_loss, + mean_squared_error, + r2_score, +) + logger = logging.getLogger(__name__) +# Scorers +accuracy_scorer: Tuple[Any, Any] = (accuracy_score, {}) +neg_mean_squared_error_scorer = (mean_squared_error, dict(greater_is_better=False)) +r2_scorer: Tuple[Any, Any] = (r2_score, {}) +neg_log_loss_scorer = (log_loss, dict(greater_is_better=False, needs_proba=True)) + + +SCORERS = dict( + accuracy=accuracy_scorer, + neg_mean_squared_error=neg_mean_squared_error_scorer, + r2=r2_scorer, + neg_log_loss=neg_log_loss_scorer, +) + + class ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixin): """Meta-estimator for parallel predict and transform. @@ -231,9 +261,7 @@ def score(self, X, y, compute=True): if not dask.is_dask_collection(X) and not dask.is_dask_collection(y): scorer = sklearn.metrics.get_scorer(scoring) else: - # TODO: implement Dask-ML's get_scorer() function - # scorer = get_scorer(scoring, compute=compute) - raise NotImplementedError("get_scorer function not implemented") + scorer = get_scorer(scoring, compute=compute) return scorer(self, X, y) else: return self._postfit_estimator.score(X, y) @@ -386,6 +414,145 @@ def _check_method(self, method): return getattr(estimator, method) +class Incremental(ParallelPostFit): + """Metaestimator for feeding Dask Arrays to an estimator blockwise. + This wrapper provides a bridge between Dask objects and estimators + implementing the ``partial_fit`` API. These *incremental learners* can + train on batches of data. This fits well with Dask's blocked data + structures. + .. note:: + This meta-estimator is not appropriate for hyperparameter optimization + on larger-than-memory datasets. 
+ See the `list of incremental learners`_ in the scikit-learn documentation + for a list of estimators that implement the ``partial_fit`` API. Note that + `Incremental` is not limited to just these classes, it will work on any + estimator implementing ``partial_fit``, including those defined outside of + scikit-learn itself. + Calling :meth:`Incremental.fit` with a Dask Array will pass each block of + the Dask array or arrays to ``estimator.partial_fit`` *sequentially*. + Like :class:`ParallelPostFit`, the methods available after fitting (e.g. + :meth:`Incremental.predict`, etc.) are all parallel and delayed. + The ``estimator_`` attribute is a clone of `estimator` that was actually + used during the call to ``fit``. All attributes learned during training + are available on ``Incremental`` directly. + .. _list of incremental learners: https://scikit-learn.org/stable/modules/computing.html#incremental-learning # noqa + Parameters + ---------- + estimator : Estimator + Any object supporting the scikit-learn ``partial_fit`` API. + scoring : string or callable, optional + A single string (see :ref:`scoring_parameter`) or a callable + (see :ref:`scoring`) to evaluate the predictions on the test set. + For evaluating multiple metrics, either give a list of (unique) + strings or a dict with names as keys and callables as values. + NOTE that when using custom scorers, each scorer should return a + single value. Metric functions returning a list/array of values + can be wrapped into multiple scorers that return one value each. + See :ref:`multimetric_grid_search` for an example. + .. warning:: + If None, the estimator's default scorer (if available) is used. + Most scikit-learn estimators will convert large Dask arrays to + a single NumPy array, which may exhaust the memory of your worker. + You probably want to always specify `scoring`. + random_state : int or numpy.random.RandomState, optional + Random object that determines how to shuffle blocks. + shuffle_blocks : bool, default True + Determines whether to call ``partial_fit`` on a randomly selected chunk + of the Dask arrays (default), or to fit in sequential order. This does + not control shuffle between blocks or shuffling each block. + predict_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``predict`` call. + This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + predict_proba_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``predict_proba`` call. + This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + transform_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``transform`` call. + This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + Attributes + ---------- + estimator_ : Estimator + A clone of `estimator` that was actually fit during the ``.fit`` call. 
+ + """ + + def __init__( + self, + estimator=None, + scoring=None, + shuffle_blocks=True, + random_state=None, + assume_equal_chunks=True, + predict_meta=None, + predict_proba_meta=None, + transform_meta=None, + ): + self.shuffle_blocks = shuffle_blocks + self.random_state = random_state + self.assume_equal_chunks = assume_equal_chunks + super(Incremental, self).__init__( + estimator=estimator, + scoring=scoring, + predict_meta=predict_meta, + predict_proba_meta=predict_proba_meta, + transform_meta=transform_meta, + ) + + @property + def _postfit_estimator(self): + check_is_fitted(self, "estimator_") + return self.estimator_ + + def _fit_for_estimator(self, estimator, X, y, **fit_kwargs): + check_scoring(estimator, self.scoring) + if not dask.is_dask_collection(X) and not dask.is_dask_collection(y): + result = estimator.partial_fit(X=X, y=y, **fit_kwargs) + else: + result = fit( + estimator, + X, + y, + random_state=self.random_state, + shuffle_blocks=self.shuffle_blocks, + assume_equal_chunks=self.assume_equal_chunks, + **fit_kwargs, + ) + + copy_learned_attributes(result, self) + self.estimator_ = result + return self + + def fit(self, X, y=None, **fit_kwargs): + estimator = sklearn.base.clone(self.estimator) + self._fit_for_estimator(estimator, X, y, **fit_kwargs) + return self + + def partial_fit(self, X, y=None, **fit_kwargs): + """Fit the underlying estimator. + If this estimator has not been previously fit, this is identical to + :meth:`Incremental.fit`. If it has been previously fit, + ``self.estimator_`` is used as the starting point. + Parameters + ---------- + X, y : array-like + **kwargs + Additional fit-kwargs for the underlying estimator. + Returns + ------- + self : object + """ + estimator = getattr(self, "estimator_", None) + if estimator is None: + estimator = sklearn.base.clone(self.estimator) + return self._fit_for_estimator(estimator, X, y, **fit_kwargs) + + def _predict(part, estimator, output_meta=None): if part.shape[0] == 0 and output_meta is not None: empty_output = handle_empty_partitions(output_meta) @@ -495,3 +662,147 @@ def copy_learned_attributes(from_estimator, to_estimator): for k, v in attrs.items(): setattr(to_estimator, k, v) + + +def get_scorer(scoring: Union[str, Callable], compute: bool = True) -> Callable: + """Get a scorer from string + Parameters + ---------- + scoring : str | callable + scoring method as string. If callable it is returned as is. + Returns + ------- + scorer : callable + The scorer. + """ + # This is the same as sklearns, only we use our SCORERS dict, + # and don't have back-compat code + if isinstance(scoring, str): + try: + scorer, kwargs = SCORERS[scoring] + except KeyError: + raise ValueError( + "{} is not a valid scoring value. " + "Valid options are {}".format(scoring, sorted(SCORERS)) + ) + else: + scorer = scoring + kwargs = {} + + kwargs["compute"] = compute + + return make_scorer(scorer, **kwargs) + + +def check_scoring(estimator, scoring=None, **kwargs): + res = sklearn_check_scoring(estimator, scoring=scoring, **kwargs) + if scoring in SCORERS.keys(): + func, kwargs = SCORERS[scoring] + return make_scorer(func, **kwargs) + return res + + +def fit( + model, + x, + y, + compute=True, + shuffle_blocks=True, + random_state=None, + assume_equal_chunks=False, + **kwargs, +): + """Fit scikit learn model against dask arrays + Model must support the ``partial_fit`` interface for online or batch + learning. + Ideally your rows are independent and identically distributed. 
By default, + this function will step through chunks of the arrays in random order. + Parameters + ---------- + model: sklearn model + Any model supporting partial_fit interface + x: dask Array + Two dimensional array, likely tall and skinny + y: dask Array + One dimensional array with same chunks as x's rows + compute : bool + Whether to compute this result + shuffle_blocks : bool + Whether to shuffle the blocks with ``random_state`` or not + random_state : int or numpy.random.RandomState + Random state to use when shuffling blocks + kwargs: + options to pass to partial_fit + """ + + nblocks, x_name = _blocks_and_name(x) + if y is not None: + y_nblocks, y_name = _blocks_and_name(y) + assert y_nblocks == nblocks + else: + y_name = "" + + if not hasattr(model, "partial_fit"): + msg = "The class '{}' does not implement 'partial_fit'." + raise ValueError(msg.format(type(model))) + + order = list(range(nblocks)) + if shuffle_blocks: + rng = sklearn.utils.check_random_state(random_state) + rng.shuffle(order) + + name = "fit-" + dask.base.tokenize(model, x, y, kwargs, order) + + if hasattr(x, "chunks") and x.ndim > 1: + x_extra = (0,) + else: + x_extra = () + + dsk = {(name, -1): model} + dsk.update( + { + (name, i): ( + _partial_fit, + (name, i - 1), + (x_name, order[i]) + x_extra, + (y_name, order[i]), + kwargs, + ) + for i in range(nblocks) + } + ) + + dependencies = [x] + if y is not None: + dependencies.append(y) + new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies) + value = Delayed((name, nblocks - 1), new_dsk, layer=name) + + if compute: + return value.compute() + else: + return value + + +def _blocks_and_name(obj): + if hasattr(obj, "chunks"): + nblocks = len(obj.chunks[0]) + name = obj.name + + elif hasattr(obj, "npartitions"): + # dataframe, bag + nblocks = obj.npartitions + if hasattr(obj, "_name"): + # dataframe + name = obj._name + else: + # bag + name = obj.name + + return nblocks, name + + +def _partial_fit(model, x, y, kwargs=None): + kwargs = kwargs or dict() + model.partial_fit(x, y, **kwargs) + return model diff --git a/docs/source/sql/ml.rst b/docs/source/sql/ml.rst index 5c3a3b9d1..7c388d1e7 100644 --- a/docs/source/sql/ml.rst +++ b/docs/source/sql/ml.rst @@ -48,9 +48,6 @@ The key-value parameters control, how and which model is trained: It is the full python module path to the class of the model to train. Any model class with sklearn interface is valid, but might or might not work well with Dask dataframes. - Have a look into the - `dask-ml documentation `_ - for more information on which models work best. You might need to install necessary packages to use the models. * ``target_column``: @@ -63,17 +60,13 @@ The key-value parameters control, how and which model is trained: * ``wrap_predict``: Boolean flag, whether to wrap the selected model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`. - Have a look into the - `dask-ml docu on ParallelPostFit `_ - to learn more about it. Defaults to false. Typically you set - it to true for sklearn models if predicting on big data. + Defaults to false. Typically you set it to true for + sklearn models if predicting on big data. * ``wrap_fit``: Boolean flag, whether to wrap the selected - model with a :class:`dask_ml.wrappers.Incremental`. - Have a look into the - `dask-ml docu on Incremental `_ - to learn more about it. Defaults to false. Typically you set - it to true for sklearn models if training on big data. + model with a :class:`dask_sql.physical.rel.custom.wrappers.Incremental`. 
+ Defaults to false. Typically you set it to true for + sklearn models if training on big data. * ``fit_kwargs``: keyword arguments sent to the call to ``fit()``. @@ -85,7 +78,7 @@ Example: .. raw:: html
CREATE MODEL my_model WITH (
-        model_class = 'dask_ml.xgboost.XGBClassifier',
+        model_class = 'xgboost.XGBClassifier',
         target_column = 'target'
     ) AS (
         SELECT x, y, target
@@ -104,11 +97,10 @@ prediction, depending if your model can cope with
 dask dataframes.
 
     * if you are training on relatively small amounts
-      of data but predicting on large data samples
-      (and you are not using a model build for usage with dask
-      from the dask-ml package), you might want to set
-      ``wrap_predict`` to True. With this option,
-      model interference will be parallelized/distributed.
+      of data but predicting on large data samples,
+      you might want to set ``wrap_predict`` to True.
+      With this option, model inference will be
+      parallelized/distributed.
     * If you are training on large amounts of data,
       you can try setting wrap_fit to True. This will
       do the same on the training step, but works only on
diff --git a/tests/unit/test_ml_wrappers.py b/tests/unit/test_ml_wrappers.py
new file mode 100644
index 000000000..97277c1ad
--- /dev/null
+++ b/tests/unit/test_ml_wrappers.py
@@ -0,0 +1,250 @@
+# Copyright 2017, Dask developers
+# Dask-ML project - https://github.com/dask/dask-ml
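+#
+# These tests are adapted from Dask-ML's wrapper tests to exercise the
+# ParallelPostFit and Incremental meta-estimators vendored into
+# dask_sql.physical.rel.custom.wrappers.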
+from collections.abc import Sequence
+
+import dask
+import dask.array as da
+import dask.dataframe as dd
+import numpy as np
+import pandas as pd
+import pytest
+from dask.array.utils import assert_eq as assert_eq_ar
+from dask.dataframe.utils import assert_eq as assert_eq_df
+from sklearn.base import clone
+from sklearn.decomposition import PCA
+from sklearn.ensemble import GradientBoostingClassifier
+from sklearn.linear_model import LogisticRegression, SGDClassifier
+
+from dask_sql.physical.rel.custom.wrappers import Incremental, ParallelPostFit
+
+
+def _check_axis_partitioning(chunks, n_features):
+    c = chunks[1][0]
+    if c != n_features:
+        msg = (
+            "Can only generate arrays partitioned along the "
+            "first axis. Specifying a larger chunksize for "
+            "first axis. Specify a larger chunksize for "
+            "\tn_features: {}".format(c, n_features)
+        )
+        raise ValueError(msg)
+
+
+def check_random_state(random_state):
+    if random_state is None:
+        return da.random.RandomState()
+    elif isinstance(random_state, int):
+        return da.random.RandomState(random_state)
+    elif isinstance(random_state, np.random.RandomState):
+        return da.random.RandomState(random_state.randint(np.iinfo(np.int32).max))
+    elif isinstance(random_state, da.random.RandomState):
+        return random_state
+    else:
+        raise TypeError("Unexpected type '{}'".format(type(random_state)))
+
+
+def make_classification(
+    n_samples=100,
+    n_features=20,
+    n_informative=2,
+    n_classes=2,
+    scale=1.0,
+    random_state=None,
+    chunks=None,
+):
+    chunks = da.core.normalize_chunks(chunks, (n_samples, n_features))
+    _check_axis_partitioning(chunks, n_features)
+
+    if n_classes != 2:
+        raise NotImplementedError("n_classes != 2 is not yet supported.")
+
+    rng = check_random_state(random_state)
+
+    X = rng.normal(0, 1, size=(n_samples, n_features), chunks=chunks)
+    informative_idx = rng.choice(n_features, n_informative, chunks=n_informative)
+    beta = (rng.random(n_features, chunks=n_features) - 1) * scale
+
+    informative_idx, beta = dask.compute(
+        informative_idx, beta, scheduler="single-threaded"
+    )
+
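+    # z0 is the linear score of the informative features; labels are drawn by
+    # comparing a uniform draw with the logistic probability 1 / (1 + exp(-z0)).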
+    z0 = X[:, informative_idx].dot(beta[informative_idx])
+    y = rng.random(z0.shape, chunks=chunks[0]) < 1 / (1 + da.exp(-z0))
+    y = y.astype(int)
+
+    return X, y
+
+
+def _assert_eq(l, r, name=None, **kwargs):
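+    # Dispatch on type: (dask) arrays, (dask) dataframes, sequences of either,
+    # NaN-aware scalars, then plain equality as a fallback.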
+    array_types = (np.ndarray, da.Array)
+    frame_types = (pd.core.generic.NDFrame, dd._Frame)
+    if isinstance(l, array_types):
+        assert_eq_ar(l, r, **kwargs)
+    elif isinstance(l, frame_types):
+        assert_eq_df(l, r, **kwargs)
+    elif isinstance(l, Sequence) and any(
+        isinstance(x, array_types + frame_types) for x in l
+    ):
+        for a, b in zip(l, r):
+            _assert_eq(a, b, **kwargs)
+    elif np.isscalar(r) and np.isnan(r):
+        assert np.isnan(l), (name, l, r)
+    else:
+        assert l == r, (name, l, r)
+
+
+def assert_estimator_equal(left, right, exclude=None, **kwargs):
+    """Check that two Estimators are equal
+    Parameters
+    ----------
+    left, right : Estimators
+    exclude : str or sequence of str
+        attributes to skip in the check
+    kwargs : dict
+        Passed through to the dask `assert_eq` method.
+    """
+    left_attrs = [x for x in dir(left) if x.endswith("_") and not x.startswith("_")]
+    right_attrs = [x for x in dir(right) if x.endswith("_") and not x.startswith("_")]
+    if exclude is None:
+        exclude = set()
+    elif isinstance(exclude, str):
+        exclude = {exclude}
+    else:
+        exclude = set(exclude)
+
+    left_attrs2 = set(left_attrs) - exclude
+    right_attrs2 = set(right_attrs) - exclude
+
+    assert left_attrs2 == right_attrs2, left_attrs2 ^ right_attrs2
+
+    for attr in left_attrs2:
+        l = getattr(left, attr)
+        r = getattr(right, attr)
+        _assert_eq(l, r, name=attr, **kwargs)
+
+
+def test_parallelpostfit_basic():
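+    # Fit on in-memory data; predict/predict_proba on the dask array should
+    # stay lazy (return dask Arrays) and score should match the bare estimator.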
+    clf = ParallelPostFit(GradientBoostingClassifier())
+
+    X, y = make_classification(n_samples=1000, chunks=100)
+    X_, y_ = dask.compute(X, y)
+    clf.fit(X_, y_)
+
+    assert isinstance(clf.predict(X), da.Array)
+    assert isinstance(clf.predict_proba(X), da.Array)
+
+    result = clf.score(X, y)
+    expected = clf.estimator.score(X_, y_)
+    assert result == expected
+
+
+@pytest.mark.parametrize("kind", ["numpy", "dask.dataframe", "dask.array"])
+def test_predict(kind):
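+    # predict/predict_proba/predict_log_proba from the wrapper should match
+    # the bare LogisticRegression for numpy, dask.array and dask.dataframe inputs.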
+    X, y = make_classification(chunks=100)
+
+    if kind == "numpy":
+        X, y = dask.compute(X, y)
+    elif kind == "dask.dataframe":
+        X = dd.from_dask_array(X)
+        y = dd.from_dask_array(y)
+
+    base = LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs")
+    wrap = ParallelPostFit(
+        LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs"),
+    )
+
+    base.fit(*dask.compute(X, y))
+    wrap.fit(*dask.compute(X, y))
+
+    assert_estimator_equal(wrap.estimator, base)
+
+    result = wrap.predict(X)
+    expected = base.predict(X)
+    assert_eq_ar(result, expected)
+
+    result = wrap.predict_proba(X)
+    expected = base.predict_proba(X)
+    assert_eq_ar(result, expected)
+
+    result = wrap.predict_log_proba(X)
+    expected = base.predict_log_proba(X)
+    assert_eq_ar(result, expected)
+
+
+@pytest.mark.parametrize("kind", ["numpy", "dask.dataframe", "dask.array"])
+def test_transform(kind):
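+    # The wrapped transform should match the bare PCA for every input kind.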
+    X, y = make_classification(chunks=100)
+
+    if kind == "numpy":
+        X, y = dask.compute(X, y)
+    elif kind == "dask.dataframe":
+        X = dd.from_dask_array(X)
+        y = dd.from_dask_array(y)
+
+    base = PCA(random_state=0)
+    wrap = ParallelPostFit(PCA(random_state=0))
+
+    base.fit(*dask.compute(X, y))
+    wrap.fit(*dask.compute(X, y))
+
+    assert_estimator_equal(wrap.estimator, base)
+
+    result = wrap.transform(X)
+    expected = base.transform(*dask.compute(X))
+    assert_eq_ar(result, expected)
+
+
+@pytest.mark.parametrize("dataframes", [False, True])
+def test_incremental_basic(dataframes):
+    # Create observations that we know linear models can recover
+    n, d = 100, 3
+    rng = da.random.RandomState(42)
+    X = rng.normal(size=(n, d), chunks=30)
+    coef_star = rng.uniform(size=d, chunks=d)
+    y = da.sign(X.dot(coef_star))
+    y = (y + 1) / 2
+    if dataframes:
+        X = dd.from_array(X)
+        y = dd.from_array(y)
+
+    est1 = SGDClassifier(random_state=0, tol=1e-3, average=True)
+    est2 = clone(est1)
+
+    clf = Incremental(est1, random_state=0)
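+    # random_state only controls the order in which blocks reach partial_fit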
+    result = clf.fit(X, y, classes=[0, 1])
+    assert result is clf
+
+    # Fit est2 block by block with plain sklearn as a reference benchmark
+    if dataframes:
+        X = X.to_dask_array(lengths=True)
+        y = y.to_dask_array(lengths=True)
+
+    for slice_ in da.core.slices_from_chunks(X.chunks):
+        est2.partial_fit(X[slice_].compute(), y[slice_[0]].compute(), classes=[0, 1])
+
+    assert isinstance(result.estimator_.coef_, np.ndarray)
+    rel_error = np.linalg.norm(clf.coef_ - est2.coef_)
+    rel_error /= np.linalg.norm(clf.coef_)
+    assert rel_error < 0.9
+
+    assert set(dir(clf.estimator_)) == set(dir(est2))
+
+    #  Predict
+    result = clf.predict(X)
+    expected = est2.predict(X)
+    assert isinstance(result, da.Array)
+    if dataframes:
+        # Compute is needed because chunk sizes of this array are unknown
+        result = result.compute()
+    rel_error = np.linalg.norm(result - expected)
+    rel_error /= np.linalg.norm(expected)
+    assert rel_error < 0.3
+
+    # score
+    result = clf.score(X, y)
+    expected = est2.score(*dask.compute(X, y))
+    assert abs(result - expected) < 0.1
+
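+    # partial_fit on a freshly constructed Incremental behaves like fit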
+    clf = Incremental(SGDClassifier(random_state=0, tol=1e-3, average=True))
+    clf.partial_fit(X, y, classes=[0, 1])
+    assert set(dir(clf.estimator_)) == set(dir(est2))