From 9680b9097b14090d41b681be7af1ce41af60b023 Mon Sep 17 00:00:00 2001
From: danielenricocahall
Date: Wed, 3 Feb 2021 07:29:26 -0500
Subject: [PATCH 1/3] add support for predicting probabilities in
 classification problems

---
 elephas/ml/params.py         | 15 ++++++++++++
 elephas/ml_model.py          | 39 ++++++++++++++++++------------
 elephas/utils/model_utils.py |  5 ++--
 elephas/utils/warnings.py    |  2 ++
 tests/test_ml_model.py       | 46 ++++++++++++++++++++++++++++++++++++
 5 files changed, 90 insertions(+), 17 deletions(-)
 create mode 100644 elephas/utils/warnings.py

diff --git a/elephas/ml/params.py b/elephas/ml/params.py
index 719483a..01f84bd 100644
--- a/elephas/ml/params.py
+++ b/elephas/ml/params.py
@@ -243,3 +243,18 @@ def set_custom_objects(self, custom_objects):
 
     def get_custom_objects(self):
         return self.getOrDefault(self.custom_objects)
+
+
+class HasPredictClasses(Params):
+    def __init__(self):
+        super(HasPredictClasses, self).__init__()
+        self.predict_classes = Param(self, "predict_classes", "Flag to predict class or probability in classification "
+                                                              "problems")
+        self._setDefault(predict_classes=True)
+
+    def set_predict_classes(self, predict_classes):
+        self._paramMap[self.predict_classes] = predict_classes
+        return self
+
+    def get_predict_classes(self):
+        return self.getOrDefault(self.predict_classes)
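
The new HasPredictClasses mixin follows the same getter/setter convention as the other shared params in elephas/ml/params.py. A minimal usage sketch (illustrative only, not part of the patch; the loss is set first so the estimator-level regression check introduced below has something to inspect):

    estimator = ElephasEstimator()
    estimator.set_loss("categorical_crossentropy")  # a classification loss
    estimator.get_predict_classes()                 # True by default: emit hard class labels
    estimator.set_predict_classes(False)            # request per-class probabilities instead
    assert estimator.get_predict_classes() is False
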
diff --git a/elephas/ml_model.py b/elephas/ml_model.py
index d418f01..5b95d80 100644
--- a/elephas/ml_model.py
+++ b/elephas/ml_model.py
@@ -10,7 +10,7 @@
 from pyspark import keyword_only, RDD
 from pyspark.ml import Estimator, Model
 from pyspark.sql import DataFrame
-from pyspark.sql.types import StringType, DoubleType, StructField
+from pyspark.sql.types import StringType, DoubleType, StructField, ArrayType
 from tensorflow.keras.models import model_from_yaml
 from tensorflow.keras.optimizers import get as get_optimizer
@@ -20,12 +20,13 @@
 from .utils.model_utils import LossModelTypeMapper, ModelType, determine_predict_function
 from .ml.adapter import df_to_simple_rdd
 from .ml.params import *
+from .utils.warnings import ElephasWarning
 
 
 class ElephasEstimator(Estimator, HasCategoricalLabels, HasValidationSplit, HasKerasModelConfig, HasFeaturesCol,
                        HasLabelCol, HasMode, HasEpochs, HasBatchSize, HasFrequency, HasVerbosity, HasNumberOfClasses,
                        HasNumberOfWorkers, HasOutputCol, HasLoss,
-                       HasMetrics, HasKerasOptimizerConfig, HasCustomObjects):
+                       HasMetrics, HasKerasOptimizerConfig, HasCustomObjects, HasPredictClasses):
     """
     SparkML Estimator implementation of an elephas model. This estimator takes all relevant arguments for model
     compilation and training.
@@ -108,6 +109,7 @@ def _fit(self, df: DataFrame):
                                    keras_model_config=spark_model.master_network.to_yaml(),
                                    weights=weights,
                                    custom_objects=self.get_custom_objects(),
+                                   predict_classes=self.get_predict_classes(),
                                    loss=loss)
 
     def setFeaturesCol(self, value):
@@ -125,6 +127,12 @@ def setOutputCol(self, value):
                       " ElephasEstimator(outputCol='foo')", DeprecationWarning)
         return self._set(outputCol=value)
 
+    def set_predict_classes(self, predict_classes):
+        if LossModelTypeMapper().get_model_type(self.get_loss()) == ModelType.REGRESSION:
+            warnings.warn("Setting `predict_classes` doesn't have any effect when training a regression problem.",
+                          ElephasWarning)
+        return super().set_predict_classes(predict_classes)
+
 
 def load_ml_estimator(file_name):
     f = h5py.File(file_name, mode='r')
@@ -133,7 +141,8 @@ def load_ml_estimator(file_name):
     return ElephasEstimator(**config)
 
 
-class ElephasTransformer(Model, HasKerasModelConfig, HasLabelCol, HasOutputCol, HasFeaturesCol, HasCustomObjects):
+class ElephasTransformer(Model, HasKerasModelConfig, HasLabelCol, HasOutputCol, HasFeaturesCol, HasCustomObjects,
+                         HasPredictClasses):
     """SparkML Transformer implementation. Contains a trained model,
     with which new feature data can be transformed into labels.
     """
@@ -178,21 +187,19 @@ def _transform(self, df):
         """Private transform method of a Transformer. This serves as batch-prediction method for our purposes.
         """
         output_col = self.getOutputCol()
-        label_col = self.getLabelCol()
         new_schema = copy.deepcopy(df.schema)
-        new_schema.add(StructField(output_col, StringType(), True))
 
         rdd = df.rdd
-
         weights = self.weights
 
         def extract_features_and_predict(model_yaml: str,
                                          custom_objects: dict,
                                          features_col: str,
                                          model_type: ModelType,
+                                         predict_classes: bool,
                                          data):
             model = model_from_yaml(model_yaml, custom_objects)
             model.set_weights(weights.value)
-            predict_function = determine_predict_function(model, model_type)
+            predict_function = determine_predict_function(model, model_type, predict_classes)
             return predict_function(np.stack([from_vector(x[features_col]) for x in data]))
 
         predictions = rdd.mapPartitions(
@@ -200,18 +207,20 @@ def extract_features_and_predict(model_yaml: str,
                 self.get_keras_model_config(),
                 self.get_custom_objects(),
                 self.getFeaturesCol(),
-                self.model_type))
-        if self.model_type == ModelType.CLASSIFICATION:
-            predictions = predictions.map(lambda x: tuple(str(x)))
-        else:
+                self.model_type,
+                self.get_predict_classes()))
+        if (self.model_type == ModelType.CLASSIFICATION and self.get_predict_classes()) \
+                or self.model_type == ModelType.REGRESSION:
             predictions = predictions.map(lambda x: tuple([float(x)]))
+            output_col_field = StructField(output_col, DoubleType(), True)
+        else:
+            # we're doing classification and predicting class probabilities
+            predictions = predictions.map(lambda x: tuple([x.tolist()]))
+            output_col_field = StructField(output_col, ArrayType(DoubleType()), True)
 
         results_rdd = rdd.zip(predictions).map(lambda x: x[0] + x[1])
+        new_schema.add(output_col_field)
         results_df = df.sql_ctx.createDataFrame(results_rdd, new_schema)
-        results_df = results_df.withColumn(
-            output_col, results_df[output_col].cast(DoubleType()))
-        results_df = results_df.withColumn(
-            label_col, results_df[label_col].cast(DoubleType()))
         return results_df
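
With the _transform changes above, the output column's type now depends on the flag instead of always being cast to DoubleType. Roughly, for a three-class classifier (rows below are made up for illustration; "prediction" is the output column name used in the tests):

    # predict_classes=True (default): prediction is DoubleType
    Row(label=2.0, prediction=2.0)
    # predict_classes=False: prediction is ArrayType(DoubleType())
    Row(label=2.0, prediction=[0.05, 0.15, 0.80])
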
diff --git a/elephas/utils/model_utils.py b/elephas/utils/model_utils.py
index b523462..54faba8 100644
--- a/elephas/utils/model_utils.py
+++ b/elephas/utils/model_utils.py
@@ -55,8 +55,9 @@ def register_loss(self, loss, model_type):
 
 
 def determine_predict_function(model: tensorflow.keras.models.Model,
-                               model_type: ModelType):
-    if model_type == ModelType.CLASSIFICATION:
+                               model_type: ModelType,
+                               predict_classes: bool = True):
+    if model_type == ModelType.CLASSIFICATION and predict_classes:
         if isinstance(model, tensorflow.keras.models.Sequential):
             predict_function = model.predict_classes
         else:
diff --git a/elephas/utils/warnings.py b/elephas/utils/warnings.py
new file mode 100644
index 0000000..4512668
--- /dev/null
+++ b/elephas/utils/warnings.py
@@ -0,0 +1,2 @@
+class ElephasWarning(Warning):
+    """Custom warning class for any Elephas issues"""
diff --git a/tests/test_ml_model.py b/tests/test_ml_model.py
index cb28852..c04ecc2 100644
--- a/tests/test_ml_model.py
+++ b/tests/test_ml_model.py
@@ -10,6 +10,8 @@
 from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics
 from pyspark.ml import Pipeline
 
+from elephas.utils.warnings import ElephasWarning
+
 
 def test_serialization_transformer(classification_model):
     transformer = ElephasTransformer()
@@ -242,4 +244,48 @@ def custom_activation(x):
     prediction = fitted_pipeline.transform(test_df)
 
 
+def test_predict_classes_probability(spark_context, classification_model, mnist_data):
+    batch_size = 64
+    nb_classes = 10
+    epochs = 1
+
+    x_train, y_train, x_test, y_test = mnist_data
+    x_train = x_train[:1000]
+    y_train = y_train[:1000]
+    df = to_data_frame(spark_context, x_train, y_train, categorical=True)
+    test_df = to_data_frame(spark_context, x_test, y_test, categorical=True)
+
+    sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
+    sgd_conf = optimizers.serialize(sgd)
+
+    # Initialize Spark ML Estimator
+    estimator = ElephasEstimator()
+    estimator.set_keras_model_config(classification_model.to_yaml())
+    estimator.set_optimizer_config(sgd_conf)
+    estimator.set_mode("synchronous")
+    estimator.set_loss("categorical_crossentropy")
+    estimator.set_metrics(['acc'])
+    estimator.set_predict_classes(False)
+    estimator.set_epochs(epochs)
+    estimator.set_batch_size(batch_size)
+    estimator.set_validation_split(0.1)
+    estimator.set_categorical_labels(True)
+    estimator.set_nb_classes(nb_classes)
+
+    # Fitting a model returns a Transformer
+    pipeline = Pipeline(stages=[estimator])
+    fitted_pipeline = pipeline.fit(df)
+
+    # Evaluate Spark model by evaluating the underlying model
+    prediction = fitted_pipeline.transform(test_df)
+    pnl = prediction.select("label", "prediction")
+    pnl.show(100)
+
+def test_set_predict_classes_regression_warning(spark_context, regression_model):
+    with pytest.warns(ElephasWarning):
+        estimator = ElephasEstimator()
+        estimator.set_loss("mae")
+        estimator.set_metrics(['mae'])
+        estimator.set_categorical_labels(False)
+        estimator.set_predict_classes(True)
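
Taken together, PATCH 1/3 enables a flow like the following sketch. It assumes the estimator, df, and test_df set up as in the tests above; the argmax step shows how a caller can still recover hard labels from the returned probabilities:

    import numpy as np
    from pyspark.ml import Pipeline

    estimator.set_predict_classes(False)             # probabilities, not labels
    fitted_pipeline = Pipeline(stages=[estimator]).fit(df)
    results = fitted_pipeline.transform(test_df)

    probs = results.take(1)[0].prediction            # list of nb_classes doubles
    predicted_class = int(np.argmax(probs))          # argmax recovers the hard label
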
From 9215796b7d144f82933e0d5f1090ea41c39b9afa Mon Sep 17 00:00:00 2001
From: danielenricocahall
Date: Wed, 3 Feb 2021 07:29:48 -0500
Subject: [PATCH 2/3] delete unused imports

---
 elephas/ml_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/elephas/ml_model.py b/elephas/ml_model.py
index 5b95d80..f70c347 100644
--- a/elephas/ml_model.py
+++ b/elephas/ml_model.py
@@ -7,10 +7,10 @@
 import json
 
 from pyspark.ml.param.shared import HasOutputCol, HasFeaturesCol, HasLabelCol
-from pyspark import keyword_only, RDD
+from pyspark import keyword_only
 from pyspark.ml import Estimator, Model
 from pyspark.sql import DataFrame
-from pyspark.sql.types import StringType, DoubleType, StructField, ArrayType
+from pyspark.sql.types import DoubleType, StructField, ArrayType
 from tensorflow.keras.models import model_from_yaml
 from tensorflow.keras.optimizers import get as get_optimizer

From 2ea83042b482ec6ea4eb0b596a9761e4acd4babf Mon Sep 17 00:00:00 2001
From: danielenricocahall
Date: Wed, 3 Feb 2021 07:43:41 -0500
Subject: [PATCH 3/3] update test

---
 tests/test_ml_model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_ml_model.py b/tests/test_ml_model.py
index c04ecc2..ca5dc00 100644
--- a/tests/test_ml_model.py
+++ b/tests/test_ml_model.py
@@ -276,10 +276,10 @@ def test_predict_classes_probability(spark_context, classification_model, mnist_
     pipeline = Pipeline(stages=[estimator])
     fitted_pipeline = pipeline.fit(df)
 
-    # Evaluate Spark model by evaluating the underlying model
-    prediction = fitted_pipeline.transform(test_df)
-    pnl = prediction.select("label", "prediction")
-    pnl.show(100)
+    results = fitted_pipeline.transform(test_df)
+    # we should have an array of 10 elements in the prediction column, since we have 10 classes
+    # and therefore 10 probabilities
+    assert len(results.take(1)[0].prediction) == 10
 
 def test_set_predict_classes_regression_warning(spark_context, regression_model):
     with pytest.warns(ElephasWarning):
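
For reference, after this series the dispatch in determine_predict_function reduces to roughly the sketch below. The functional-model branch is cut off in the hunk above, so the argmax fallback here is an assumption about how non-Sequential classifiers resolve hard labels, not something the patch confirms:

    def determine_predict_function(model, model_type, predict_classes=True):
        if model_type == ModelType.CLASSIFICATION and predict_classes:
            if isinstance(model, tensorflow.keras.models.Sequential):
                return model.predict_classes  # Sequential-only API, removed in later TF releases
            return lambda x: model.predict(x).argmax(axis=-1)  # assumed fallback
        # regression, or classification with predict_classes=False: raw model output
        return model.predict
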