Skip to content

Commit

Permalink
Update utils.py. Removed the disp=False argument in sm.Logit() and up…
Browse files Browse the repository at this point in the history
…dated the fit() method
  • Loading branch information
Jaydon2005 authored Jan 6, 2025
1 parent cf5b441 commit 743d4da
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions msdbook/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@
import numpy as np
import statsmodels.api as sm


def fit_logit(dta, predictors):
"""Logistic regression"""

# concatenate intercept column of 1s
# Add intercept column of 1s
dta["Intercept"] = np.ones(np.shape(dta)[0])
# get columns of predictors

# Get columns of predictors
cols = dta.columns.tolist()[-1:] + predictors + ["Interaction"]
# fit logistic regression
logit = sm.Logit(dta["Success"], dta[cols], disp=False)
result = logit.fit()


# Fit logistic regression without the deprecated 'disp' argument
logit = sm.Logit(dta["Success"], dta[cols])
result = logit.fit(method='bfgs') # Use method='bfgs' or another supported method

return result


def plot_contour_map(
ax, result, dta, contour_cmap, dot_cmap, levels, xgrid, ygrid, xvar, yvar, base
):
"""Plot the contour map"""

# TODO: see why this warning is being raised about the tight layout
# Ignore tight layout warnings
warnings.filterwarnings("ignore")

# find probability of success for x=xgrid, y=ygrid
# Generate probability of success for x=xgrid, y=ygrid
X, Y = np.meshgrid(xgrid, ygrid)
x = X.flatten()
y = Y.flatten()
Expand All @@ -36,9 +36,12 @@ def plot_contour_map(
Z = np.reshape(z, np.shape(X))

contourset = ax.contourf(X, Y, Z, levels, cmap=contour_cmap, aspect="auto")

# Plot scatter points based on the data
xpoints = np.mean(dta[xvar].values.reshape(-1, 10), axis=1)
ypoints = np.mean(dta[yvar].values.reshape(-1, 10), axis=1)
colors = np.round(np.mean(dta["Success"].values.reshape(-1, 10), axis=1), 0)

ax.scatter(xpoints, ypoints, s=10, c=colors, edgecolor="none", cmap=dot_cmap)
ax.set_xlim(np.min(xgrid), np.max(xgrid))
ax.set_ylim(np.min(ygrid), np.max(ygrid))
Expand Down

0 comments on commit 743d4da

Please sign in to comment.