Merge pull request #177 from maxpumperla/add-support-for-predicting-class-probs

Add support for predicting class probabilities
danielenricocahall authored Feb 6, 2021
2 parents cd87574 + 2ea8304 commit e587406
Showing 5 changed files with 91 additions and 18 deletions.
15 changes: 15 additions & 0 deletions elephas/ml/params.py
@@ -243,3 +243,18 @@ def set_custom_objects(self, custom_objects):

def get_custom_objects(self):
return self.getOrDefault(self.custom_objects)


class HasPredictClasses(Params):
def __init__(self):
super(HasPredictClasses, self).__init__()
self.predict_classes = Param(self, "predict_classes", "Flag to predict class or probability in classification "
"problems")
self._setDefault(predict_classes=True)

def set_predict_classes(self, predict_classes):
self._paramMap[self.predict_classes] = predict_classes
return self

def get_predict_classes(self):
return self.getOrDefault(self.predict_classes)
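
A minimal sketch (not part of the commit) of the new mixin in isolation; the bare Demo host class is hypothetical, and only pyspark plus this file are assumed:

from elephas.ml.params import HasPredictClasses

class Demo(HasPredictClasses):
    # any estimator/transformer mixing this in gets the same get/set pair
    pass

params = Demo()
assert params.get_predict_classes() is True    # registered default
params.set_predict_classes(False)              # request class probabilities instead
assert params.get_predict_classes() is False
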
41 changes: 25 additions & 16 deletions elephas/ml_model.py
@@ -7,10 +7,10 @@
import json

from pyspark.ml.param.shared import HasOutputCol, HasFeaturesCol, HasLabelCol
-from pyspark import keyword_only, RDD
+from pyspark import keyword_only
from pyspark.ml import Estimator, Model
from pyspark.sql import DataFrame
-from pyspark.sql.types import StringType, DoubleType, StructField
+from pyspark.sql.types import DoubleType, StructField, ArrayType

from tensorflow.keras.models import model_from_yaml
from tensorflow.keras.optimizers import get as get_optimizer
@@ -20,12 +20,13 @@
from .utils.model_utils import LossModelTypeMapper, ModelType, determine_predict_function
from .ml.adapter import df_to_simple_rdd
from .ml.params import *
+from .utils.warnings import ElephasWarning


class ElephasEstimator(Estimator, HasCategoricalLabels, HasValidationSplit, HasKerasModelConfig, HasFeaturesCol,
HasLabelCol, HasMode, HasEpochs, HasBatchSize, HasFrequency, HasVerbosity, HasNumberOfClasses,
HasNumberOfWorkers, HasOutputCol, HasLoss,
-                       HasMetrics, HasKerasOptimizerConfig, HasCustomObjects):
+                       HasMetrics, HasKerasOptimizerConfig, HasCustomObjects, HasPredictClasses):
"""
SparkML Estimator implementation of an elephas model. This estimator takes all relevant arguments for model
compilation and training.
@@ -108,6 +109,7 @@ def _fit(self, df: DataFrame):
keras_model_config=spark_model.master_network.to_yaml(),
weights=weights,
custom_objects=self.get_custom_objects(),
+                                  predict_classes=self.get_predict_classes(),
loss=loss)

def setFeaturesCol(self, value):
@@ -125,6 +127,12 @@ def setOutputCol(self, value):
" ElephasEstimator(outputCol='foo')", DeprecationWarning)
return self._set(outputCol=value)

+    def set_predict_classes(self, predict_classes):
+        if LossModelTypeMapper().get_model_type(self.get_loss()) == ModelType.REGRESSION:
+            warnings.warn("Setting `predict_classes` doesn't have any effect when training a regression problem.",
+                          ElephasWarning)
+        super().set_predict_classes(predict_classes)


def load_ml_estimator(file_name):
f = h5py.File(file_name, mode='r')
@@ -133,7 +141,8 @@ def load_ml_estimator(file_name):
return ElephasEstimator(**config)


-class ElephasTransformer(Model, HasKerasModelConfig, HasLabelCol, HasOutputCol, HasFeaturesCol, HasCustomObjects):
+class ElephasTransformer(Model, HasKerasModelConfig, HasLabelCol, HasOutputCol, HasFeaturesCol, HasCustomObjects,
+                         HasPredictClasses):
"""SparkML Transformer implementation. Contains a trained model,
with which new feature data can be transformed into labels.
"""
@@ -178,40 +187,40 @@ def _transform(self, df):
"""Private transform method of a Transformer. This serves as batch-prediction method for our purposes.
"""
output_col = self.getOutputCol()
-        label_col = self.getLabelCol()
        new_schema = copy.deepcopy(df.schema)
-        new_schema.add(StructField(output_col, StringType(), True))
rdd = df.rdd

weights = self.weights

def extract_features_and_predict(model_yaml: str,
custom_objects: dict,
features_col: str,
model_type: ModelType,
+                                         predict_classes: bool,
data):
model = model_from_yaml(model_yaml, custom_objects)
model.set_weights(weights.value)
-            predict_function = determine_predict_function(model, model_type)
+            predict_function = determine_predict_function(model, model_type, predict_classes)
return predict_function(np.stack([from_vector(x[features_col]) for x in data]))

predictions = rdd.mapPartitions(
partial(extract_features_and_predict,
self.get_keras_model_config(),
self.get_custom_objects(),
self.getFeaturesCol(),
-                    self.model_type))
-        if self.model_type == ModelType.CLASSIFICATION:
-            predictions = predictions.map(lambda x: tuple(str(x)))
-        else:
+                    self.model_type,
+                    self.get_predict_classes()))
+        if (self.model_type == ModelType.CLASSIFICATION and self.get_predict_classes()) \
+                or self.model_type == ModelType.REGRESSION:
            predictions = predictions.map(lambda x: tuple([float(x)]))
+            output_col_field = StructField(output_col, DoubleType(), True)
+        else:
+            # we're doing classification and predicting class probabilities
+            predictions = predictions.map(lambda x: tuple([x.tolist()]))
+            output_col_field = StructField(output_col, ArrayType(DoubleType()), True)
results_rdd = rdd.zip(predictions).map(lambda x: x[0] + x[1])

+        new_schema.add(output_col_field)
results_df = df.sql_ctx.createDataFrame(results_rdd, new_schema)
-        results_df = results_df.withColumn(
-            output_col, results_df[output_col].cast(DoubleType()))
-        results_df = results_df.withColumn(
-            label_col, results_df[label_col].cast(DoubleType()))

return results_df

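For orientation (not part of the commit): with the default predict_classes=True, a classification transform now emits one DoubleType class index per row, while predict_classes=False emits an ArrayType(DoubleType()) of per-class probabilities. A rough sketch of the two per-row shapes, assuming a 3-class softmax output:

import numpy as np

probs = np.array([0.1, 0.7, 0.2])             # one row of model.predict output
as_class = tuple([float(np.argmax(probs))])   # predict_classes=True  -> (1.0,)
as_probs = tuple([probs.tolist()])            # predict_classes=False -> ([0.1, 0.7, 0.2],)
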
5 changes: 3 additions & 2 deletions elephas/utils/model_utils.py
@@ -55,8 +55,9 @@ def register_loss(self, loss, model_type):


def determine_predict_function(model: tensorflow.keras.models.Model,
-                               model_type: ModelType):
-    if model_type == ModelType.CLASSIFICATION:
+                               model_type: ModelType,
+                               predict_classes: bool = True):
+    if model_type == ModelType.CLASSIFICATION and predict_classes:
if isinstance(model, tensorflow.keras.models.Sequential):
predict_function = model.predict_classes
else:
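
The hunk is truncated below the Sequential branch, so the following standalone reconstruction is hedged: everything after the isinstance check is an assumption about the likely fallback, not the committed code.

import numpy as np
import tensorflow
from elephas.utils.model_utils import ModelType

def determine_predict_function_sketch(model, model_type, predict_classes=True):
    if model_type == ModelType.CLASSIFICATION and predict_classes:
        if isinstance(model, tensorflow.keras.models.Sequential):
            # Sequential still exposed predict_classes in the TF releases of this era
            return model.predict_classes
        # assumed fallback for functional models: argmax over the final axis
        return lambda data: np.argmax(model.predict(data), axis=-1)
    # regression, or classification with predict_classes=False:
    # raw model output (per-class probabilities for classification)
    return model.predict
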
2 changes: 2 additions & 0 deletions elephas/utils/warnings.py
@@ -0,0 +1,2 @@
class ElephasWarning(Warning):
"""Custom warning class for any Elephas issues"""
46 changes: 46 additions & 0 deletions tests/test_ml_model.py
@@ -10,6 +10,8 @@
from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics
from pyspark.ml import Pipeline

from elephas.utils.warnings import ElephasWarning


def test_serialization_transformer(classification_model):
transformer = ElephasTransformer()
@@ -242,4 +244,48 @@ def custom_activation(x):
prediction = fitted_pipeline.transform(test_df)


def test_predict_classes_probability(spark_context, classification_model, mnist_data):
batch_size = 64
nb_classes = 10
epochs = 1

x_train, y_train, x_test, y_test = mnist_data
x_train = x_train[:1000]
y_train = y_train[:1000]
df = to_data_frame(spark_context, x_train, y_train, categorical=True)
test_df = to_data_frame(spark_context, x_test, y_test, categorical=True)

sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
sgd_conf = optimizers.serialize(sgd)

# Initialize Spark ML Estimator
estimator = ElephasEstimator()
estimator.set_keras_model_config(classification_model.to_yaml())
estimator.set_optimizer_config(sgd_conf)
estimator.set_mode("synchronous")
estimator.set_loss("categorical_crossentropy")
estimator.set_metrics(['acc'])
estimator.set_predict_classes(False)
estimator.set_epochs(epochs)
estimator.set_batch_size(batch_size)
estimator.set_validation_split(0.1)
estimator.set_categorical_labels(True)
estimator.set_nb_classes(nb_classes)

# Fitting a model returns a Transformer
pipeline = Pipeline(stages=[estimator])
fitted_pipeline = pipeline.fit(df)

results = fitted_pipeline.transform(test_df)
# we should have an array of 10 elements in the prediction column, since we have 10 classes
# and therefore 10 probabilities
assert len(results.take(1)[0].prediction) == 10


def test_set_predict_classes_regression_warning(spark_context, regression_model):
with pytest.warns(ElephasWarning):
estimator = ElephasEstimator()
estimator.set_loss("mae")
estimator.set_metrics(['mae'])
estimator.set_categorical_labels(False)
estimator.set_predict_classes(True)
