Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Implement (most of the) MFA support for Okta. #14

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ venv/
ENV/
env.bak/
venv.bak/
venv-awsprocesscreds/

# mypy
.mypy_cache/
222 changes: 212 additions & 10 deletions awsprocesscreds/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,199 @@ def _get_value_of_first_tag(self, root, tag, attr, trait):
class OktaAuthenticator(GenericFormsBasedAuthenticator):
_AUTH_URL = '/api/v1/authn'

_ERROR_AUTH_CANCELLED = (
'Authentication cancelled'
)

_ERROR_LOCKED_OUT = (
"You are locked out of your Okta account. Go to %s to unlock it."
)

_ERROR_PASSWORD_EXPIRED = (
"Your password has expired. Go to %s to change it."
)

_ERROR_MFA_ENROLL = (
"You need to enroll a MFA first."
)

_MSG_AUTH_CODE = (
"Authentication code (RETURN to cancel): "
)

_MSG_ANSWER = (
"Answer (RETURN to cancel): "
)

_MSG_SMS_CODE = (
"SMS authentication code (RETURN to cancel, "
"'RESEND' to get new code sent): "
)

def get_response(self, prompt, allow_cancel=True):
response = self._password_prompter(prompt)
if allow_cancel and response == "":
raise SAMLError(self._ERROR_AUTH_CANCELLED)
return response

def get_assertion_from_response(self, endpoint, parsed):
session_token = parsed['sessionToken']
saml_url = endpoint + '?sessionToken=%s' % session_token
response = self._requests_session.get(saml_url)
logger.info(
'Received HTTP response of status code: %s', response.status_code)
r = self._extract_saml_assertion_from_response(response.text)
logger.info(
'Received the following SAML assertion: \n%s', r,
extra={'is_saml_assertion': True}
)
return r

def process_response(self, response, endpoint):
parsed = json.loads(response.text)
if response.status_code == 200:
return self.get_assertion_from_response(endpoint, parsed)
if response.status_code >= 400:
error = parsed["errorCauses"][0]["errorSummary"]
self.get_response("%s\r\nPress RETURN to continue\r\n"
% error, False)
return None

def process_mfa_totp(self, endpoint, url, statetoken):
while True:
response = self.get_response(self._MSG_AUTH_CODE)
totp_response = self._requests_session.post(
url,
headers={'Content-Type': 'application/json',
'Accept': 'application/json'},
data=json.dumps({'stateToken': statetoken,
'passCode': response})
)
result = self.process_response(totp_response, endpoint)
if result is not None:
return result

def process_mfa_push(self, endpoint, url, statetoken):
self.get_response(("Press RETURN when you are ready to request the "
"push notification"), False)
while True:
totp_response = self._requests_session.post(
url,
headers={'Content-Type': 'application/json',
'Accept': 'application/json'},
data=json.dumps({'stateToken': statetoken})
)
totp_parsed = json.loads(totp_response.text)
if totp_parsed["status"] == "SUCCESS":
return self.get_assertion_from_response(endpoint, totp_parsed)
if totp_parsed["factorResult"] != "WAITING":
raise SAMLError(self._ERROR_AUTH_CANCELLED)

def process_mfa_security_question(self, endpoint, url, statetoken):
while True:
response = self.get_response(self._MSG_ANSWER)
totp_response = self._requests_session.post(
url,
headers={'Content-Type': 'application/json',
'Accept': 'application/json'},
data=json.dumps({'stateToken': statetoken,
'answer': response})
)
result = self.process_response(totp_response, endpoint)
if result is not None:
return result

def verify_sms_factor(self, url, statetoken, passcode):
body = {'stateToken': statetoken}
if passcode != "":
body['passCode'] = passcode
return self._requests_session.post(
url,
headers={'Content-Type': 'application/json',
'Accept': 'application/json'},
data=json.dumps(body)
)

def process_mfa_sms(self, endpoint, url, statetoken):
# Need to trigger the initial code to be sent ...
self.verify_sms_factor(url, statetoken, "")
while True:
response = self.get_response(self._MSG_SMS_CODE)
# If the user has asked for the code to be resent, clear
# the response to retrigger sending the code.
if response == "RESEND":
response = ""
sms_response = self.verify_sms_factor(url, statetoken, response)
# If we've just requested a resend, don't check the result
# - just loop around to get the next response from the user.
if response != "":
result = self.process_response(sms_response, endpoint)
if result is not None:
return result

def display_mfa_choices(self, parsed):
index = 1
prompt = ""
for f in parsed["_embedded"]["factors"]:
if f["factorType"] == "token":
prompt += "%s: %s token\r\n" % (index, f["provider"])
elif f["factorType"] == "token:software:totp":
prompt += ("%s: %s authenticator app\r\n"
% (index, f["provider"]))
elif f["factorType"] == "sms":
prompt += "%s: SMS text message\r\n" % index
elif f["factorType"] == "push":
prompt += "%s: Push notification\r\n" % index
elif f["factorType"] == "question":
prompt += "%s: Security question\r\n" % index
else:
prompt += "%s: %s %s\r\n" % (index,
f["provider"],
f["factorType"])
index += 1
return index, prompt

def get_number(self, prompt):
response = self.get_response(prompt)
choice = 0
try:
choice = int(response)
except ValueError:
pass
return choice

def get_mfa_choice(self, parsed):
count, prompt = self.display_mfa_choices(parsed)
prompt = ("Please choose from the following authentication"
" choices:\r\n") + prompt
prompt += ("Enter the number corresponding to your choice "
"or press RETURN to cancel authentication: ")
while True:
choice = self.get_number(prompt)
if 0 < choice < count:
return choice

def process_mfa_verification(self, endpoint, parsed):
# If we've only got one factor, pick that automatically
if len(parsed["_embedded"]["factors"]) == 1:
choice = 1
else:
choice = self.get_mfa_choice(parsed)
factor = parsed["_embedded"]["factors"][choice - 1]
url = factor["_links"]["verify"]["href"]
statetoken = parsed["stateToken"]
if factor["factorType"] == "token:software:totp":
return self.process_mfa_totp(endpoint, url, statetoken)
if factor["factorType"] == "push":
return self.process_mfa_push(endpoint, url, statetoken)
if factor["factorType"] == "question":
return self.process_mfa_security_question(endpoint,
url, statetoken)
if factor["factorType"] == "sms":
return self.process_mfa_sms(endpoint, url, statetoken)

raise SAMLError("Unsupported factor")

def retrieve_saml_assertion(self, config):
self._validate_config_values(config)
endpoint = config['saml_endpoint']
Expand All @@ -237,17 +430,27 @@ def retrieve_saml_assertion(self, config):
'password': password})
)
parsed = json.loads(response.text)
session_token = parsed['sessionToken']
saml_url = endpoint + '?sessionToken=%s' % session_token
response = self._requests_session.get(saml_url)
logger.info(
'Received HTTP response of status code: %s', response.status_code)
r = self._extract_saml_assertion_from_response(response.text)
logger.info(
'Received the following SAML assertion: \n%s', r,
extra={'is_saml_assertion': True}
'Got status %s and response: %s',
response.status_code, response.text
)
return r
if response.status_code == 401:
raise SAMLError(self._ERROR_LOGIN_FAILED_NON_200 %
parsed["errorSummary"])
if "status" in parsed:
if parsed["status"] == "SUCCESS":
return self.get_assertion_from_response(endpoint, parsed)
if parsed["status"] == "LOCKED_OUT":
raise SAMLError(self._ERROR_LOCKED_OUT %
parsed["_links"]["href"])
if parsed["status"] == "PASSWORD_EXPIRED":
raise SAMLError(self._ERROR_PASSWORD_EXPIRED %
parsed["_links"]["href"])
if parsed["status"] == "MFA_ENROLL":
raise SAMLError(self._ERROR_MFA_ENROLL)
if parsed["status"] == "MFA_REQUIRED":
return self.process_mfa_verification(endpoint, parsed)
raise SAMLError("Code logic failure")

def is_suitable(self, config):
return (config.get('saml_authentication_type') == 'form' and
Expand Down Expand Up @@ -309,7 +512,6 @@ class SAMLCredentialFetcher(CachedCredentialFetcher):
SAML_FORM_AUTHENTICATORS = {
'okta': OktaAuthenticator,
'adfs': ADFSFormsBasedAuthenticator

}

def __init__(self, client_creator, provider_name, saml_config,
Expand Down
Loading