diff --git a/jupyter-book/_toc.yml b/jupyter-book/_toc.yml index 277bc22eb..01d356e74 100644 --- a/jupyter-book/_toc.yml +++ b/jupyter-book/_toc.yml @@ -102,14 +102,15 @@ parts: - file: python_scripts/linear_models_ex_02 - file: python_scripts/linear_models_sol_02 - file: python_scripts/linear_models_feature_engineering_classification.py - - file: python_scripts/logistic_regression_non_linear + - file: python_scripts/linear_models_ex_03 + - file: python_scripts/linear_models_sol_03 - file: linear_models/linear_models_quiz_m4_02 - file: linear_models/linear_models_regularization_index sections: - file: linear_models/regularized_linear_models_slides - file: python_scripts/linear_models_regularization - - file: python_scripts/linear_models_ex_03 - - file: python_scripts/linear_models_sol_03 + - file: python_scripts/linear_models_ex_04 + - file: python_scripts/linear_models_sol_04 - file: linear_models/linear_models_quiz_m4_03 - file: linear_models/linear_models_wrap_up_quiz - file: linear_models/linear_models_module_take_away diff --git a/python_scripts/linear_models_ex_03.py b/python_scripts/linear_models_ex_03.py index 9c311e817..50fe942cd 100644 --- a/python_scripts/linear_models_ex_03.py +++ b/python_scripts/linear_models_ex_03.py @@ -14,69 +14,118 @@ # %% [markdown] # # 📝 Exercise M4.03 # -# The parameter `penalty` can control the **type** of regularization to use, -# whereas the regularization **strength** is set using the parameter `C`. -# Setting`penalty="none"` is equivalent to an infinitely large value of `C`. In -# this exercise, we ask you to train a logistic regression classifier using the -# `penalty="l2"` regularization (which happens to be the default in -# scikit-learn) to find by yourself the effect of the parameter `C`. -# -# We start by loading the dataset. +# Now, we tackle a more realistic classification problem instead of making a +# synthetic dataset. We start by loading the Adult Census dataset with the +# following snippet. For the moment we retain only the **numerical features**. + +# %% +import pandas as pd + +adult_census = pd.read_csv("../datasets/adult-census.csv") +target = adult_census["class"] +data = adult_census.select_dtypes(["integer", "floating"]) +data = data.drop(columns=["education-num"]) +data # %% [markdown] -# ```{note} -# If you want a deeper overview regarding this dataset, you can refer to the -# Appendix - Datasets description section at the end of this MOOC. -# ``` +# We confirm that all the selected features are numerical. +# +# Compute the generalization performance in terms of accuracy of a linear model +# composed of a `StandardScaler` and a `LogisticRegression`. Use a 10-fold +# cross-validation with `return_estimator=True` to be able to inspect the +# trained estimators. # %% -import pandas as pd +# Write your code here. -penguins = pd.read_csv("../datasets/penguins_classification.csv") -# only keep the Adelie and Chinstrap classes -penguins = ( - penguins.set_index("Species").loc[["Adelie", "Chinstrap"]].reset_index() -) +# %% [markdown] +# What is the most important feature seen by the logistic regression? +# +# You can use a boxplot to compare the absolute values of the coefficients while +# also visualizing the variability induced by the cross-validation resampling. + +# %% +# Write your code here. -culmen_columns = ["Culmen Length (mm)", "Culmen Depth (mm)"] -target_column = "Species" +# %% [markdown] +# Let's now work with **both numerical and categorical features**. 
You can +# reload the Adult Census dataset with the following snippet: # %% -from sklearn.model_selection import train_test_split +adult_census = pd.read_csv("../datasets/adult-census.csv") +target = adult_census["class"] +data = adult_census.drop(columns=["class", "education-num"]) + +# %% [markdown] +# Create a predictive model where: +# - The numerical data must be scaled. +# - The categorical data must be one-hot encoded, set `min_frequency=0.01` to +# group categories concerning less than 1% of the total samples. +# - The predictor is a `LogisticRegression`. You may need to increase the number +# of `max_iter`, which is 100 by default. +# +# Use the same 10-fold cross-validation strategy with `return_estimator=True` as +# above to evaluate this complex pipeline. -penguins_train, penguins_test = train_test_split(penguins, random_state=0) +# %% +# Write your code here. -data_train = penguins_train[culmen_columns] -data_test = penguins_test[culmen_columns] +# %% [markdown] +# By comparing the cross-validation test scores of both models fold-to-fold, +# count the number of times the model using both numerical and categorical +# features has a better test score than the model using only numerical features. -target_train = penguins_train[target_column] -target_test = penguins_test[target_column] +# %% +# Write your code here. # %% [markdown] -# First, let's create our predictive model. +# For the following questions, you can copy adn paste the following snippet to +# get the feature names from the column transformer here named `preprocessor`. +# +# ```python +# preprocessor.fit(data) +# feature_names = ( +# preprocessor.named_transformers_["onehotencoder"].get_feature_names_out( +# categorical_columns +# ) +# ).tolist() +# feature_names += numerical_columns +# feature_names +# ``` # %% -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import StandardScaler -from sklearn.linear_model import LogisticRegression +# Write your code here. -logistic_regression = make_pipeline( - StandardScaler(), LogisticRegression(penalty="l2") -) +# %% [markdown] +# Notice that there are as many feature names as coefficients in the last step +# of your predictive pipeline. # %% [markdown] -# Given the following candidates for the `C` parameter, find out the impact of -# `C` on the classifier decision boundary. You can use -# `sklearn.inspection.DecisionBoundaryDisplay.from_estimator` to plot the -# decision function boundary. +# Which of the following pairs of features is most impacting the predictions of +# the logistic regression classifier based on the absolute magnitude of its +# coefficients? # %% -Cs = [0.01, 0.1, 1, 10] +# Write your code here. + +# %% [markdown] +# Now create a similar pipeline consisting of the same preprocessor as above, +# followed by a `PolynomialFeatures` and a logistic regression with `C=0.01`. +# Set `degree=2` and `interaction_only=True` to the feature engineering step. +# Remember not to include a "bias" feature to avoid introducing a redundancy +# with the intercept of the subsequent logistic regression. +# %% # Write your code here. # %% [markdown] -# Look at the impact of the `C` hyperparameter on the magnitude of the weights. +# By comparing the cross-validation test scores of both models fold-to-fold, +# count the number of times the model using multiplicative interactions and both +# numerical and categorical features has a better test score than the model +# without interactions. + +# %% +# Write your code here. # %% # Write your code here. 
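+
+# %% [markdown]
+# **Hint (optional):** the fold-to-fold comparisons above can be done with an
+# element-wise comparison of the two arrays of test scores. The snippet below is
+# only an illustration of the pattern, using made-up numbers rather than the
+# scores you obtained:
+#
+# ```python
+# import numpy as np
+#
+# scores_a = np.array([0.80, 0.82, 0.79])
+# scores_b = np.array([0.81, 0.80, 0.83])
+# (scores_b > scores_a).sum()
+# ```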
diff --git a/python_scripts/linear_models_ex_04.py b/python_scripts/linear_models_ex_04.py new file mode 100644 index 000000000..dd9ae6bb1 --- /dev/null +++ b/python_scripts/linear_models_ex_04.py @@ -0,0 +1,170 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: Python 3 +# name: python3 +# --- + +# %% [markdown] +# # 📝 Exercise M4.04 +# +# In the previous Module we tuned the hyperparameter `C` of the logistic +# regression without mentioning that it controls the regularization strength. +# Later, on the slides on 🎥 **Intuitions on regularized linear models** we +# metioned that a small `C` provides a more regularized model, whereas a +# non-regularized model is obtained with an infinitely large value of `C`. +# Indeed, `C` behaves as the inverse of the `alpha` coefficient in the `Ridge` +# model. +# +# In this exercise, we ask you to train a logistic regression classifier using +# different values of the parameter `C` to find its effects by yourself. +# +# We start by loading the dataset. We only keep the Adelie and Chinstrap classes +# to keep the discussion simple. + + +# %% [markdown] +# ```{note} +# If you want a deeper overview regarding this dataset, you can refer to the +# Appendix - Datasets description section at the end of this MOOC. +# ``` + +# %% +import pandas as pd + +penguins = pd.read_csv("../datasets/penguins_classification.csv") +penguins = ( + penguins.set_index("Species").loc[["Adelie", "Chinstrap"]].reset_index() +) + +culmen_columns = ["Culmen Length (mm)", "Culmen Depth (mm)"] +target_column = "Species" + +# %% +from sklearn.model_selection import train_test_split + +penguins_train, penguins_test = train_test_split( + penguins, random_state=0, test_size=0.4 +) + +data_train = penguins_train[culmen_columns] +data_test = penguins_test[culmen_columns] + +target_train = penguins_train[target_column] +target_test = penguins_test[target_column] + +# %% [markdown] +# We define a function to help us fit a given `model` and plot its decision +# boundary. We recall that by using a `DecisionBoundaryDisplay` with diverging +# colormap, `vmin=0` and `vmax=1`, we ensure that the 0.5 probability is mapped +# to the white color. Equivalently, the darker the color, the closer the +# predicted probability is to 0 or 1 and the more confident the classifier is in +# its predictions. + +# %% +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.inspection import DecisionBoundaryDisplay + + +def plot_decision_boundary(model): + model.fit(data_train, target_train) + accuracy = model.score(data_test, target_test) + C = model.get_params()["logisticregression__C"] + + disp = DecisionBoundaryDisplay.from_estimator( + model, + data_train, + response_method="predict_proba", + plot_method="pcolormesh", + cmap="RdBu_r", + alpha=0.8, + vmin=0.0, + vmax=1.0, + ) + DecisionBoundaryDisplay.from_estimator( + model, + data_train, + response_method="predict_proba", + plot_method="contour", + linestyles="--", + linewidths=1, + alpha=0.8, + levels=[0.5], + ax=disp.ax_, + ) + sns.scatterplot( + data=penguins_train, + x=culmen_columns[0], + y=culmen_columns[1], + hue=target_column, + palette=["tab:blue", "tab:red"], + ax=disp.ax_, + ) + plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left") + plt.title(f"C: {C} \n Accuracy on the test set: {accuracy:.2f}") + + +# %% [markdown] +# Let's now create our predictive model. 
+ +# %% +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.linear_model import LogisticRegression + +logistic_regression = make_pipeline(StandardScaler(), LogisticRegression()) + +# %% [markdown] +# ## Influence of the parameter `C` on the decision boundary +# +# Given the following candidates for the `C` parameter and the +# `plot_decision_boundary` function, find out the impact of `C` on the +# classifier's decision boundary. +# +# - How does the value of `C` impact the confidence on the predictions? +# - How does it impact the underfit/overfit trade-off? +# - How does it impact the position and orientation of the decision boundary? +# +# Try to give an interpretation on the reason for such behavior. + +# %% +Cs = [1e-6, 0.01, 0.1, 1, 10, 100, 1e6] + +# Write your code here. + +# %% [markdown] +# ## Impact of the regularization on the weights +# +# Look at the impact of the `C` hyperparameter on the magnitude of the weights. +# **Hint**: You can [access pipeline +# steps](https://scikit-learn.org/stable/modules/compose.html#access-pipeline-steps) +# by name or position. Then you can query the attributes of that step such as +# `coef_`. + +# %% +# Write your code here. + +# %% [markdown] +# ## Impact of the regularization on with non-linear feature engineering +# +# Use the `plot_decision_boundary` function to repeat the experiment using a +# non-linear feature engineering pipeline. For such purpose, insert +# `Nystroem(kernel="rbf", gamma=1, n_components=100)` between the +# `StandardScaler` and the `LogisticRegression` steps. +# +# - Does the value of `C` still impact the position of the decision boundary and +# the confidence of the model? +# - What can you say about the impact of `C` on the underfitting vs overfitting +# trade-off? + +# %% +from sklearn.kernel_approximation import Nystroem + +# Write your code here. diff --git a/python_scripts/linear_models_sol_03.py b/python_scripts/linear_models_sol_03.py index dc2a82f5c..c76806a45 100644 --- a/python_scripts/linear_models_sol_03.py +++ b/python_scripts/linear_models_sol_03.py @@ -8,273 +8,267 @@ # %% [markdown] # # 📃 Solution for Exercise M4.03 # -# In the previous Module we tuned the hyperparameter `C` of the logistic -# regression without mentioning that it controls the regularization strength. -# Later, on the slides on 🎥 **Intuitions on regularized linear models** we -# metioned that a small `C` provides a more regularized model, whereas a -# non-regularized model is obtained with an infinitely large value of `C`. -# Indeed, `C` behaves as the inverse of the `alpha` coefficient in the `Ridge` -# model. -# -# In this exercise, we ask you to train a logistic regression classifier using -# different values of the parameter `C` to find its effects by yourself. -# -# We start by loading the dataset. We only keep the Adelie and Chinstrap classes -# to keep the discussion simple. +# Now, we tackle a more realistic classification problem instead of making a +# synthetic dataset. We start by loading the Adult Census dataset with the +# following snippet. For the moment we retain only the **numerical features**. 
+
+# %%
+import pandas as pd
+adult_census = pd.read_csv("../datasets/adult-census.csv")
+target = adult_census["class"]
+data = adult_census.select_dtypes(["integer", "floating"])
+data = data.drop(columns=["education-num"])
+data

# %% [markdown]
-# ```{note}
-# If you want a deeper overview regarding this dataset, you can refer to the
-# Appendix - Datasets description section at the end of this MOOC.
-# ```
+# We confirm that all the selected features are numerical.
+#
+# Compute the generalization performance in terms of accuracy of a linear model
+# composed of a `StandardScaler` and a `LogisticRegression`. Use a 10-fold
+# cross-validation with `return_estimator=True` to be able to inspect the
+# trained estimators.

# %%
-import pandas as pd
+# solution
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import cross_validate

-penguins = pd.read_csv("../datasets/penguins_classification.csv")
-penguins = (
-    penguins.set_index("Species").loc[["Adelie", "Chinstrap"]].reset_index()
+model = make_pipeline(StandardScaler(), LogisticRegression())
+cv_results_lr = cross_validate(
+    model, data, target, cv=10, return_estimator=True
)
+test_score_lr = cv_results_lr["test_score"]
+test_score_lr

-culmen_columns = ["Culmen Length (mm)", "Culmen Depth (mm)"]
-target_column = "Species"
+# %% [markdown]
+# What is the most important feature seen by the logistic regression?
+#
+# You can use a boxplot to compare the absolute values of the coefficients while
+# also visualizing the variability induced by the cross-validation resampling.

# %%
-from sklearn.model_selection import train_test_split
+# solution
+import matplotlib.pyplot as plt

-penguins_train, penguins_test = train_test_split(
-    penguins, random_state=0, test_size=0.4
-)
+coefs = [pipeline[-1].coef_[0] for pipeline in cv_results_lr["estimator"]]
+coefs = pd.DataFrame(coefs, columns=data.columns)

-data_train = penguins_train[culmen_columns]
-data_test = penguins_test[culmen_columns]
+color = {"whiskers": "black", "medians": "black", "caps": "black"}
+_, ax = plt.subplots()
+_ = coefs.abs().plot.box(color=color, vert=False, ax=ax)

-target_train = penguins_train[target_column]
-target_test = penguins_test[target_column]
+# %% [markdown] tags=["solution"]
+# Since we scaled the features, the coefficients of the linear model can be
+# meaningfully compared with each other. `"capital-gain"` is the most impacting
+# feature. Just be aware not to draw conclusions on causal effects from the
+# impact of a feature. Interested readers are referred to the [example on Common
+# pitfalls in the interpretation of coefficients of linear
+# models](https://scikit-learn.org/stable/auto_examples/inspection/plot_linear_model_coefficient_interpretation.html)
+# or the [example on Failure of Machine Learning to infer causal
+# effects](https://scikit-learn.org/stable/auto_examples/inspection/plot_causal_interpretation.html).

# %% [markdown]
-# We define a function to help us fit a given `model` and plot its decision
-# boundary. We recall that by using a `DecisionBoundaryDisplay` with diverging
-# colormap, `vmin=0` and `vmax=1`, we ensure that the 0.5 probability is mapped
-# to the white color. Equivalently, the darker the color, the closer the
-# predicted probability is to 0 or 1 and the more confident the classifier is in
-# its predictions.
+# Let's now work with **both numerical and categorical features**.
You can +# reload the Adult Census dataset with the following snippet: # %% -import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.inspection import DecisionBoundaryDisplay - - -def plot_decision_boundary(model): - model.fit(data_train, target_train) - accuracy = model.score(data_test, target_test) - - disp = DecisionBoundaryDisplay.from_estimator( - model, - data_train, - response_method="predict_proba", - plot_method="pcolormesh", - cmap="RdBu_r", - alpha=0.8, - vmin=0.0, - vmax=1.0, - ) - DecisionBoundaryDisplay.from_estimator( - model, - data_train, - response_method="predict_proba", - plot_method="contour", - linestyles="--", - linewidths=1, - alpha=0.8, - levels=[0.5], - ax=disp.ax_, - ) - sns.scatterplot( - data=penguins_train, - x=culmen_columns[0], - y=culmen_columns[1], - hue=target_column, - palette=["tab:blue", "tab:red"], - ax=disp.ax_, - ) - plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left") - plt.title(f"C: {C} \n Accuracy on the test set: {accuracy:.2f}") +adult_census = pd.read_csv("../datasets/adult-census.csv") +target = adult_census["class"] +data = adult_census.drop(columns=["class", "education-num"]) +# %% [markdown] +# Create a predictive model where: +# - The numerical data must be scaled. +# - The categorical data must be one-hot encoded, set `min_frequency=0.01` to +# group categories concerning less than 1% of the total samples. +# - The predictor is a `LogisticRegression`. You may need to increase the number +# of `max_iter`, which is 100 by default. +# +# Use the same 10-fold cross-validation strategy with `return_estimator=True` as +# above to evaluate this complex pipeline. + +# %% +# solution +from sklearn.compose import make_column_selector as selector +from sklearn.compose import make_column_transformer +from sklearn.preprocessing import OneHotEncoder + +categorical_columns = selector(dtype_include=object)(data) +numerical_columns = selector(dtype_exclude=object)(data) + +preprocessor = make_column_transformer( + ( + OneHotEncoder(handle_unknown="ignore", min_frequency=0.01), + categorical_columns, + ), + (StandardScaler(), numerical_columns), +) +model = make_pipeline(preprocessor, LogisticRegression(max_iter=5_000)) +cv_results_complex_lr = cross_validate( + model, data, target, cv=10, return_estimator=True, n_jobs=2 +) +test_score_complex_lr = cv_results_complex_lr["test_score"] +test_score_complex_lr # %% [markdown] -# Let's now create our predictive model. +# By comparing the cross-validation test scores of both models fold-to-fold, +# count the number of times the model using both numerical and categorical +# features has a better test score than the model using only numerical features. # %% -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import StandardScaler -from sklearn.linear_model import LogisticRegression +# solution +import numpy as np +import matplotlib.pyplot as plt -logistic_regression = make_pipeline(StandardScaler(), LogisticRegression()) +indices = np.arange(len(test_score_lr)) +plt.scatter( + indices, test_score_lr, color="tab:blue", label="numerical features only" +) +plt.scatter( + indices, + test_score_complex_lr, + color="tab:red", + label="all features", +) +plt.ylim((0, 1)) +plt.xlabel("Cross-validation iteration") +plt.ylabel("Accuracy") +_ = plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + +print( + "A model using both all features is better than a" + " model using only numerical features for" + f" {sum(test_score_complex_lr > test_score_lr)} CV iterations out of 10." 
+)

# %% [markdown]
-# ## Influence of the parameter `C` on the decision boundary
-#
-# Given the following candidates for the `C` parameter and the
-# `plot_decision_boundary` function, find out the impact of `C` on the
-# classifier's decision boundary.
-#
-# - How does the value of `C` impact the confidence on the predictions?
-# - How does it impact the underfit/overfit trade-off?
-# - How does it impact the position and orientation of the decision boundary?
+# For the following questions, you can copy and paste the following snippet to
+# get the feature names from the column transformer here named `preprocessor`.
#
-# Try to give an interpretation on the reason for such behavior.
+# ```python
+# preprocessor.fit(data)
+# feature_names = (
+#     preprocessor.named_transformers_["onehotencoder"].get_feature_names_out(
+#         categorical_columns
+#     )
+# ).tolist()
+# feature_names += numerical_columns
+# feature_names
+# ```

# %%
-Cs = [1e-6, 0.01, 0.1, 1, 10, 100, 1e6]
-
# solution
-for C in Cs:
-    logistic_regression.set_params(logisticregression__C=C)
-    plot_decision_boundary(logistic_regression)
+preprocessor.fit(data)
+feature_names = (
+    preprocessor.named_transformers_["onehotencoder"].get_feature_names_out(
+        categorical_columns
+    )
+).tolist()
+feature_names += numerical_columns
+feature_names

-# %% [markdown] tags=["solution"]
-#
-# On this series of plots we can observe several important points. Regarding the
-# confidence on the predictions:
-#
-# - For low values of `C` (strong regularization), the classifier is less
-#   confident in its predictions. We are enforcing a **spread sigmoid**.
-# - For high values of `C` (weak regularization), the classifier is more
-#   confident: the areas with dark blue (very confident in predicting "Adelie")
-#   and dark red (very confident in predicting "Chinstrap") nearly cover the
-#   entire feature space. We are enforcing a **steep sigmoid**.
-#
-# To answer the next question, think that misclassified data points are more
-# costly when the classifier is more confident on the decision. Decision rules
-# are mostly driven by avoiding such cost. From the previous observations we can
-# then deduce that:
-#
-# - The smaller the `C` (the stronger the regularization), the lower the cost
-#   of a misclassification. As more data points lay in the low-confidence
-#   zone, the more the decision rules are influenced almost uniformly by all
-#   the data points. This leads to a less expressive model, which may underfit.
-# - The higher the value of `C` (the weaker the regularization), the more the
-#   decision is influenced by a few training points very close to the boundary,
-#   where decisions are costly. Remember that models may overfit if the number
-#   of samples in the training set is too small, as at least a minimum of
-#   samples is needed to average the noise out.
-#
-# The orientation is the result of two factors: minimizing the number of
-# misclassified training points with high confidence and their distance to the
-# decision boundary (notice how the contour line tries to align with the most
-# misclassified data points in the dark-colored zone). This is closely related
-# to the value of the weights of the model, which is explained in the next part
-# of the exercise.
-#
-# Finally, for small values of `C` the position of the decision boundary is
-# affected by the class imbalance: when `C` is near zero, the model predicts the
-# majority class (as seen in the training set) everywhere in the feature space.
-# In our case, there are approximately two times more "Adelie" than "Chinstrap" -# penguins. This explains why the decision boundary is shifted to the right when -# `C` gets smaller. Indeed, the most regularized model predicts light blue -# almost everywhere in the feature space. +# %% [markdown] +# Notice that there are as many feature names as coefficients in the last step +# of your predictive pipeline. # %% [markdown] -# ## Impact of the regularization on the weights -# -# Look at the impact of the `C` hyperparameter on the magnitude of the weights. -# **Hint**: You can [access pipeline -# steps](https://scikit-learn.org/stable/modules/compose.html#access-pipeline-steps) -# by name or position. Then you can query the attributes of that step such as -# `coef_`. +# Which of the following pairs of features is most impacting the predictions of +# the logistic regression classifier based on the absolute magnitude of its +# coefficients? # %% # solution -lr_weights = [] -for C in Cs: - logistic_regression.set_params(logisticregression__C=C) - logistic_regression.fit(data_train, target_train) - coefs = logistic_regression[-1].coef_[0] - lr_weights.append(pd.Series(coefs, index=culmen_columns)) - -# %% tags=["solution"] -lr_weights = pd.concat(lr_weights, axis=1, keys=[f"C: {C}" for C in Cs]) -lr_weights.plot.barh() -_ = plt.title("LogisticRegression weights depending of C") +coefs = [ + pipeline[-1].coef_[0] for pipeline in cv_results_complex_lr["estimator"] +] +coefs = pd.DataFrame(coefs, columns=feature_names) + +_, ax = plt.subplots(figsize=(10, 35)) +_ = coefs.abs().plot.box(color=color, vert=False, ax=ax) # %% [markdown] tags=["solution"] -# -# As small `C` provides a more regularized model, it shrinks the weights values -# toward zero, as in the `Ridge` model. -# -# In particular, with a strong penalty (e.g. `C = 0.01`), the weight of the feature -# named "Culmen Depth (mm)" is almost zero. It explains why the decision -# separation in the plot is almost perpendicular to the "Culmen Length (mm)" -# feature. -# -# For even stronger penalty strengths (e.g. `C = 1e-6`), the weights of both -# features are almost zero. It explains why the decision separation in the plot -# is almost constant in the feature space: the predicted probability is only -# based on the intercept parameter of the model (which is never regularized). +# We can visually inspect the coefficients and observe that `"capital-gain"` and +# `"education_Doctorate"` are impacting the predictions the most. # %% [markdown] -# ## Impact of the regularization on with non-linear feature engineering -# -# Use the `plot_decision_boundary` function to repeat the experiment using a -# non-linear feature engineering pipeline. For such purpose, insert -# `Nystroem(kernel="rbf", gamma=1, n_components=100)` between the -# `StandardScaler` and the `LogisticRegression` steps. -# -# - Does the value of `C` still impact the position of the decision boundary and -# the confidence of the model? -# - What can you say about the impact of `C` on the underfitting vs overfitting -# trade-off? +# Now create a similar pipeline consisting of the same preprocessor as above, +# followed by a `PolynomialFeatures` and a logistic regression with `C=0.01`. +# Set `degree=2` and `interaction_only=True` to the feature engineering step. +# Remember not to include a "bias" feature to avoid introducing a redundancy +# with the intercept of the subsequent logistic regression. 
# %%
-from sklearn.kernel_approximation import Nystroem
+
# solution
+from sklearn.preprocessing import PolynomialFeatures
+model_with_interaction = make_pipeline(
+    preprocessor,
+    PolynomialFeatures(degree=2, include_bias=False, interaction_only=True),
+    LogisticRegression(C=0.01, max_iter=5_000),
+)
+model_with_interaction
+
+# %% [markdown]
+# By comparing the cross-validation test scores of both models fold-to-fold,
+# count the number of times the model using multiplicative interactions and both
+# numerical and categorical features has a better test score than the model
+# without interactions.
+
+# %%
# solution
-classifier = make_pipeline(
-    StandardScaler(),
-    Nystroem(kernel="rbf", gamma=1.0, n_components=100, random_state=0),
-    LogisticRegression(penalty="l2", max_iter=1000),
+cv_results_interactions = cross_validate(
+    model_with_interaction,
+    data,
+    target,
+    cv=10,
+    return_estimator=True,
+    n_jobs=2,
)
+test_score_interactions = cv_results_interactions["test_score"]
+test_score_interactions
-for C in Cs:
-    classifier.set_params(logisticregression__C=C)
-    plot_decision_boundary(classifier)
+# %%
+# solution
+plt.scatter(
+    indices, test_score_lr, color="tab:blue", label="numerical features only"
+)
+plt.scatter(
+    indices,
+    test_score_complex_lr,
+    color="tab:red",
+    label="all features",
+)
+plt.scatter(
+    indices,
+    test_score_interactions,
+    color="black",
+    label="all features and interactions",
+)
+plt.xlabel("Cross-validation iteration")
+plt.ylabel("Accuracy")
+_ = plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
+
+print(
+    "A model using all features and interactions is better than a model"
+    " without interactions for"
+    f" {sum(test_score_interactions > test_score_complex_lr)} CV iterations"
+    " out of 10."
+)
# %% [markdown] tags=["solution"]
+# When you multiply two one-hot encoded categorical features, the resulting
+# interaction feature is mostly 0, with a 1 only when both original features are
+# active, acting as a logical `AND`. In this case it could mean we are creating
+# new rules such as "has a given education `AND` a given native country", which
+# we expect to be predictive. These new rules map the original feature space into
+# a higher-dimensional space, where the linear model can separate the data more
+# easily.
#
-# - For the lowest values of `C`, the overall pipeline underfits: it predicts
-#   the majority class everywhere, as previously.
-# - When `C` increases, the models starts to predict some datapoints from the
-#   "Chinstrap" class but the model is not very confident anywhere in the
-#   feature space.
-# - The decision boundary is no longer a straight line: the linear model is now
-#   classifying in the 100-dimensional feature space created by the `Nystroem`
-#   transformer. As are result, the decision boundary induced by the overall
-#   pipeline is now expressive enough to wrap around the minority class.
-# - For `C = 1` in particular, it finds a smooth red blob around most of the
-#   "Chinstrap" data points. When moving away from the data points, the model is
-#   less confident in its predictions and again tends to predict the majority
-#   class according to the proportion in the training set.
-# - For higher values of `C`, the model starts to overfit: it is very confident
-#   in its predictions almost everywhere, but it should not be trusted: the
-#   model also makes a larger number of mistakes on the test set (not shown in
-#   the plot) while adopting a very curvy decision boundary to attempt fitting
-#   all the training points, including the noisy ones at the frontier between
-#   the two classes. This makes the decision boundary very sensitive to the
-#   sampling of the training set and as a result, it does not generalize well in
-#   that region. This is confirmed by the (slightly) lower accuracy on the test
-#   set.
-#
-# Finally, we can also note that the linear model on the raw features was as
-# good or better than the best model using non-linear feature engineering. So in
-# this case, we did not really need this extra complexity in our pipeline.
-# **Simpler is better!**
-#
-# So to conclude, when using non-linear feature engineering, it is often
-# possible to make the pipeline overfit, even if the original feature space is
-# low-dimensional. As a result, it is important to tune the regularization
-# parameter in conjunction with the parameters of the transformers (e.g. tuning
-# `gamma` would be important here). This has a direct impact on the certainty of
-# the predictions.
+#
+# Keep in mind that multiplying all pairs of one-hot encoded features may
+# lead to a rapid increase in the number of features, especially if the original
+# categorical variables have many levels. This can increase the computational
+# cost of your model and promote overfitting, as we will see in a future
+# notebook.
diff --git a/python_scripts/linear_models_sol_04.py b/python_scripts/linear_models_sol_04.py
new file mode 100644
index 000000000..942aed56d
--- /dev/null
+++ b/python_scripts/linear_models_sol_04.py
@@ -0,0 +1,281 @@
+# ---
+# jupyter:
+#   kernelspec:
+#     display_name: Python 3
+#     name: python3
+# ---
+
+# %% [markdown]
+# # 📃 Solution for Exercise M4.04
+#
+# In the previous Module we tuned the hyperparameter `C` of the logistic
+# regression without mentioning that it controls the regularization strength.
+# Later, on the slides on 🎥 **Intuitions on regularized linear models** we
+# mentioned that a small `C` provides a more regularized model, whereas a
+# non-regularized model is obtained with an infinitely large value of `C`.
+# Indeed, `C` behaves as the inverse of the `alpha` coefficient in the `Ridge`
+# model.
+#
+# In this exercise, we ask you to train a logistic regression classifier using
+# different values of the parameter `C` to find its effects by yourself.
+#
+# We start by loading the dataset. We only keep the Adelie and Chinstrap classes
+# to keep the discussion simple.
+
+
+# %% [markdown]
+# ```{note}
+# If you want a deeper overview regarding this dataset, you can refer to the
+# Appendix - Datasets description section at the end of this MOOC.
+# ``` + +# %% +import pandas as pd + +penguins = pd.read_csv("../datasets/penguins_classification.csv") +penguins = ( + penguins.set_index("Species").loc[["Adelie", "Chinstrap"]].reset_index() +) + +culmen_columns = ["Culmen Length (mm)", "Culmen Depth (mm)"] +target_column = "Species" + +# %% +from sklearn.model_selection import train_test_split + +penguins_train, penguins_test = train_test_split( + penguins, random_state=0, test_size=0.4 +) + +data_train = penguins_train[culmen_columns] +data_test = penguins_test[culmen_columns] + +target_train = penguins_train[target_column] +target_test = penguins_test[target_column] + +# %% [markdown] +# We define a function to help us fit a given `model` and plot its decision +# boundary. We recall that by using a `DecisionBoundaryDisplay` with diverging +# colormap, `vmin=0` and `vmax=1`, we ensure that the 0.5 probability is mapped +# to the white color. Equivalently, the darker the color, the closer the +# predicted probability is to 0 or 1 and the more confident the classifier is in +# its predictions. + +# %% +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.inspection import DecisionBoundaryDisplay + + +def plot_decision_boundary(model): + model.fit(data_train, target_train) + accuracy = model.score(data_test, target_test) + C = model.get_params()["logisticregression__C"] + + disp = DecisionBoundaryDisplay.from_estimator( + model, + data_train, + response_method="predict_proba", + plot_method="pcolormesh", + cmap="RdBu_r", + alpha=0.8, + vmin=0.0, + vmax=1.0, + ) + DecisionBoundaryDisplay.from_estimator( + model, + data_train, + response_method="predict_proba", + plot_method="contour", + linestyles="--", + linewidths=1, + alpha=0.8, + levels=[0.5], + ax=disp.ax_, + ) + sns.scatterplot( + data=penguins_train, + x=culmen_columns[0], + y=culmen_columns[1], + hue=target_column, + palette=["tab:blue", "tab:red"], + ax=disp.ax_, + ) + plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left") + plt.title(f"C: {C} \n Accuracy on the test set: {accuracy:.2f}") + + +# %% [markdown] +# Let's now create our predictive model. + +# %% +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.linear_model import LogisticRegression + +logistic_regression = make_pipeline(StandardScaler(), LogisticRegression()) + +# %% [markdown] +# ## Influence of the parameter `C` on the decision boundary +# +# Given the following candidates for the `C` parameter and the +# `plot_decision_boundary` function, find out the impact of `C` on the +# classifier's decision boundary. +# +# - How does the value of `C` impact the confidence on the predictions? +# - How does it impact the underfit/overfit trade-off? +# - How does it impact the position and orientation of the decision boundary? +# +# Try to give an interpretation on the reason for such behavior. + +# %% +Cs = [1e-6, 0.01, 0.1, 1, 10, 100, 1e6] + +# solution +for C in Cs: + logistic_regression.set_params(logisticregression__C=C) + plot_decision_boundary(logistic_regression) + +# %% [markdown] tags=["solution"] +# +# On this series of plots we can observe several important points. Regarding the +# confidence on the predictions: +# +# - For low values of `C` (strong regularization), the classifier is less +# confident in its predictions. We are enforcing a **spread sigmoid**. 
+# - For high values of `C` (weak regularization), the classifier is more
+#   confident: the areas with dark blue (very confident in predicting "Adelie")
+#   and dark red (very confident in predicting "Chinstrap") nearly cover the
+#   entire feature space. We are enforcing a **steep sigmoid**.
+#
+# To answer the next question, keep in mind that misclassified data points are
+# more costly when the classifier is more confident in its decision. Decision
+# rules are mostly driven by avoiding such cost. From the previous observations
+# we can then deduce that:
+#
+# - The smaller the `C` (the stronger the regularization), the lower the cost
+#   of a misclassification. As more data points lie in the low-confidence
+#   zone, the more the decision rules are influenced almost uniformly by all
+#   the data points. This leads to a less expressive model, which may underfit.
+# - The higher the value of `C` (the weaker the regularization), the more the
+#   decision is influenced by a few training points very close to the boundary,
+#   where decisions are costly. Remember that models may overfit if the number
+#   of samples in the training set is too small, as at least a minimum of
+#   samples is needed to average the noise out.
+#
+# The orientation is the result of two factors: minimizing the number of
+# misclassified training points with high confidence and their distance to the
+# decision boundary (notice how the contour line tries to align with the most
+# misclassified data points in the dark-colored zone). This is closely related
+# to the value of the weights of the model, which is explained in the next part
+# of the exercise.
+#
+# Finally, for small values of `C` the position of the decision boundary is
+# affected by the class imbalance: when `C` is near zero, the model predicts the
+# majority class (as seen in the training set) everywhere in the feature space.
+# In our case, there are approximately two times more "Adelie" than "Chinstrap"
+# penguins. This explains why the decision boundary is shifted to the right when
+# `C` gets smaller. Indeed, the most regularized model predicts light blue
+# almost everywhere in the feature space.
+
+# %% [markdown]
+# ## Impact of the regularization on the weights
+#
+# Look at the impact of the `C` hyperparameter on the magnitude of the weights.
+# **Hint**: You can [access pipeline
+# steps](https://scikit-learn.org/stable/modules/compose.html#access-pipeline-steps)
+# by name or position. Then you can query the attributes of that step such as
+# `coef_`.
+
+# %%
+# solution
+lr_weights = []
+for C in Cs:
+    logistic_regression.set_params(logisticregression__C=C)
+    logistic_regression.fit(data_train, target_train)
+    coefs = logistic_regression[-1].coef_[0]
+    lr_weights.append(pd.Series(coefs, index=culmen_columns))
+
+# %% tags=["solution"]
+lr_weights = pd.concat(lr_weights, axis=1, keys=[f"C: {C}" for C in Cs])
+lr_weights.plot.barh()
+_ = plt.title("LogisticRegression weights depending on C")
+
+# %% [markdown] tags=["solution"]
+#
+# As small `C` provides a more regularized model, it shrinks the weight values
+# toward zero, as in the `Ridge` model.
+#
+# In particular, with a strong penalty (e.g. `C = 0.01`), the weight of the feature
+# named "Culmen Depth (mm)" is almost zero. This explains why the decision
+# separation in the plot is almost perpendicular to the "Culmen Length (mm)"
+# feature.
+#
+# For even stronger penalty strengths (e.g. `C = 1e-6`), the weights of both
+# features are almost zero. This explains why the decision separation in the plot
+# is almost constant in the feature space: the predicted probability is only
+# based on the intercept parameter of the model (which is never regularized).
+
+# %% [markdown]
+# ## Impact of the regularization with non-linear feature engineering
+#
+# Use the `plot_decision_boundary` function to repeat the experiment using a
+# non-linear feature engineering pipeline. For this purpose, insert
+# `Nystroem(kernel="rbf", gamma=1, n_components=100)` between the
+# `StandardScaler` and the `LogisticRegression` steps.
+#
+# - Does the value of `C` still impact the position of the decision boundary and
+#   the confidence of the model?
+# - What can you say about the impact of `C` on the underfitting vs overfitting
+#   trade-off?
+
+# %%
+from sklearn.kernel_approximation import Nystroem
+
+# solution
+classifier = make_pipeline(
+    StandardScaler(),
+    Nystroem(kernel="rbf", gamma=1.0, n_components=100, random_state=0),
+    LogisticRegression(max_iter=1000),
+)
+
+for C in Cs:
+    classifier.set_params(logisticregression__C=C)
+    plot_decision_boundary(classifier)
+
+# %% [markdown] tags=["solution"]
+#
+# - For the lowest values of `C`, the overall pipeline underfits: it predicts
+#   the majority class everywhere, as previously.
+# - When `C` increases, the model starts to predict some data points from the
+#   "Chinstrap" class but the model is not very confident anywhere in the
+#   feature space.
+# - The decision boundary is no longer a straight line: the linear model is now
+#   classifying in the 100-dimensional feature space created by the `Nystroem`
+#   transformer. As a result, the decision boundary induced by the overall
+#   pipeline is now expressive enough to wrap around the minority class.
+# - For `C = 1` in particular, it finds a smooth red blob around most of the
+#   "Chinstrap" data points. When moving away from the data points, the model is
+#   less confident in its predictions and again tends to predict the majority
+#   class according to the proportion in the training set.
+# - For higher values of `C`, the model starts to overfit: it is very confident
+#   in its predictions almost everywhere, but it should not be trusted: the
+#   model also makes a larger number of mistakes on the test set (not shown in
+#   the plot) while adopting a very curvy decision boundary to attempt fitting
+#   all the training points, including the noisy ones at the frontier between
+#   the two classes. This makes the decision boundary very sensitive to the
+#   sampling of the training set and as a result, it does not generalize well in
+#   that region. This is confirmed by the (slightly) lower accuracy on the test
+#   set.
+#
+# Finally, we can also note that the linear model on the raw features was as
+# good or better than the best model using non-linear feature engineering. So in
+# this case, we did not really need this extra complexity in our pipeline.
+# **Simpler is better!**
+#
+# So to conclude, when using non-linear feature engineering, it is often
+# possible to make the pipeline overfit, even if the original feature space is
+# low-dimensional. As a result, it is important to tune the regularization
+# parameter in conjunction with the parameters of the transformers (e.g. tuning
+# `gamma` would be important here). This has a direct impact on the certainty of
+# the predictions.
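+
+# %% [markdown] tags=["solution"]
+# As an optional, minimal sketch (not part of the exercise), one way to tune `C`
+# and `gamma` jointly is a grid search on the `classifier` pipeline defined
+# above. The parameter names below follow the step names generated by
+# `make_pipeline`:

+# %% tags=["solution"]
+from sklearn.model_selection import GridSearchCV
+
+# candidate values are illustrative, not tuned for this dataset
+param_grid = {
+    "nystroem__gamma": [0.1, 1.0, 10.0],
+    "logisticregression__C": [0.01, 0.1, 1.0, 10.0],
+}
+search = GridSearchCV(classifier, param_grid=param_grid, cv=10)
+search.fit(data_train, target_train)
+search.best_params_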
diff --git a/python_scripts/logistic_regression_non_linear.py b/python_scripts/logistic_regression_non_linear.py deleted file mode 100644 index d28a4a9e6..000000000 --- a/python_scripts/logistic_regression_non_linear.py +++ /dev/null @@ -1,217 +0,0 @@ -# --- -# jupyter: -# kernelspec: -# display_name: Python 3 -# name: python3 -# --- - -# %% [markdown] -# # Beyond linear separation in classification -# -# As we saw in the regression section, the linear classification model expects -# the data to be linearly separable. When this assumption does not hold, the -# model is not expressive enough to properly fit the data. Therefore, we need to -# apply the same tricks as in regression: feature augmentation (potentially -# using expert-knowledge) or using a kernel-based method. -# -# We will provide examples where we will use a kernel support vector machine to -# perform classification on some toy-datasets where it is impossible to find a -# perfect linear separation. -# -# We will generate a first dataset where the data are represented as two -# interlaced half circles. This dataset is generated using the function -# [`sklearn.datasets.make_moons`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html). - -# %% -import numpy as np -import pandas as pd -from sklearn.datasets import make_moons - -feature_names = ["Feature #0", "Features #1"] -target_name = "class" - -X, y = make_moons(n_samples=100, noise=0.13, random_state=42) - -# We store both the data and target in a dataframe to ease plotting -moons = pd.DataFrame( - np.concatenate([X, y[:, np.newaxis]], axis=1), - columns=feature_names + [target_name], -) -data_moons, target_moons = moons[feature_names], moons[target_name] - -# %% [markdown] -# Since the dataset contains only two features, we can make a scatter plot to -# have a look at it. - -# %% -import matplotlib.pyplot as plt -import seaborn as sns - -sns.scatterplot( - data=moons, - x=feature_names[0], - y=feature_names[1], - hue=target_moons, - palette=["tab:red", "tab:blue"], -) -_ = plt.title("Illustration of the moons dataset") - -# %% [markdown] -# From the intuitions that we got by studying linear model, it should be obvious -# that a linear classifier will not be able to find a perfect decision function -# to separate the two classes. -# -# Let's try to see what is the decision boundary of such a linear classifier. We -# will create a predictive model by standardizing the dataset followed by a -# linear support vector machine classifier. - -# %% -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import StandardScaler -from sklearn.svm import SVC - -linear_model = make_pipeline(StandardScaler(), SVC(kernel="linear")) -linear_model.fit(data_moons, target_moons) - -# %% [markdown] -# ```{warning} -# Be aware that we fit and will check the boundary decision of the classifier on -# the same dataset without splitting the dataset into a training set and a -# testing set. While this is a bad practice, we use it for the sake of -# simplicity to depict the model behavior. Always use cross-validation when you -# want to assess the generalization performance of a machine-learning model. -# ``` - -# %% [markdown] -# Let's check the decision boundary of such a linear model on this dataset. 
- -# %% -from sklearn.inspection import DecisionBoundaryDisplay - -DecisionBoundaryDisplay.from_estimator( - linear_model, data_moons, response_method="predict", cmap="RdBu", alpha=0.5 -) -sns.scatterplot( - data=moons, - x=feature_names[0], - y=feature_names[1], - hue=target_moons, - palette=["tab:red", "tab:blue"], -) -_ = plt.title("Decision boundary of a linear model") - -# %% [markdown] -# As expected, a linear decision boundary is not enough flexible to split the -# two classes. -# -# To push this example to the limit, we will create another dataset where -# samples of a class will be surrounded by samples from the other class. - -# %% -from sklearn.datasets import make_gaussian_quantiles - -feature_names = ["Feature #0", "Features #1"] -target_name = "class" - -X, y = make_gaussian_quantiles( - n_samples=100, n_features=2, n_classes=2, random_state=42 -) -gauss = pd.DataFrame( - np.concatenate([X, y[:, np.newaxis]], axis=1), - columns=feature_names + [target_name], -) -data_gauss, target_gauss = gauss[feature_names], gauss[target_name] - -# %% -ax = sns.scatterplot( - data=gauss, - x=feature_names[0], - y=feature_names[1], - hue=target_gauss, - palette=["tab:red", "tab:blue"], -) -_ = plt.title("Illustration of the Gaussian quantiles dataset") - -# %% [markdown] -# Here, this is even more obvious that a linear decision function is not -# adapted. We can check what decision function, a linear support vector machine -# will find. - -# %% -linear_model.fit(data_gauss, target_gauss) -DecisionBoundaryDisplay.from_estimator( - linear_model, data_gauss, response_method="predict", cmap="RdBu", alpha=0.5 -) -sns.scatterplot( - data=gauss, - x=feature_names[0], - y=feature_names[1], - hue=target_gauss, - palette=["tab:red", "tab:blue"], -) -_ = plt.title("Decision boundary of a linear model") - -# %% [markdown] -# As expected, a linear separation cannot be used to separate the classes -# properly: the model will under-fit as it will make errors even on the training -# set. -# -# In the section about linear regression, we saw that we could use several -# tricks to make a linear model more flexible by augmenting features or using a -# kernel. Here, we will use the later solution by using a radial basis function -# (RBF) kernel together with a support vector machine classifier. -# -# We will repeat the two previous experiments and check the obtained decision -# function. - -# %% -kernel_model = make_pipeline(StandardScaler(), SVC(kernel="rbf", gamma=5)) - -# %% -kernel_model.fit(data_moons, target_moons) -DecisionBoundaryDisplay.from_estimator( - kernel_model, data_moons, response_method="predict", cmap="RdBu", alpha=0.5 -) -sns.scatterplot( - data=moons, - x=feature_names[0], - y=feature_names[1], - hue=target_moons, - palette=["tab:red", "tab:blue"], -) -_ = plt.title("Decision boundary with a model using an RBF kernel") - -# %% [markdown] -# We see that the decision boundary is not anymore a straight line. Indeed, an -# area is defined around the red samples and we could imagine that this -# classifier should be able to generalize on unseen data. -# -# Let's check the decision function on the second dataset. 
- -# %% -kernel_model.fit(data_gauss, target_gauss) -DecisionBoundaryDisplay.from_estimator( - kernel_model, data_gauss, response_method="predict", cmap="RdBu", alpha=0.5 -) -ax = sns.scatterplot( - data=gauss, - x=feature_names[0], - y=feature_names[1], - hue=target_gauss, - palette=["tab:red", "tab:blue"], -) -_ = plt.title("Decision boundary with a model using an RBF kernel") - -# %% [markdown] -# We observe something similar than in the previous case. The decision function -# is more flexible and does not underfit anymore. -# -# Thus, kernel trick or feature expansion are the tricks to make a linear -# classifier more expressive, exactly as we saw in regression. -# -# Keep in mind that adding flexibility to a model can also risk increasing -# overfitting by making the decision function to be sensitive to individual -# (possibly noisy) data points of the training set. Here we can observe that the -# decision functions remain smooth enough to preserve good generalization. If -# you are curious, you can try to repeat the above experiment with `gamma=100` -# and look at the decision functions.