Skip to content

Commit

Permalink
Add SSO MFA prompt for WebUI MFA flows (#49794)
Browse files Browse the repository at this point in the history
* Include sso channel ID in web mfa challenges.

* Handle SSO MFA challenges.

* Handle sso response in backend.

* Handle non-webauthn mfa response for file transfer, admin actions, and app session.

* Simplify useMfa with new helpers.

* Fix lint.

* Use AuthnDialog for file transfers; Fix json backend logic for file transfers.

* Make useMfa and AuthnDialog more reusable and error proof.

* Use AuthnDialog for App sessions.

* Resolve comments.

* Fix broken app launcher; improve mfaRequired logic in useMfa.

* Fix AuthnDialog test.

* Fix merge conflict with Db web access.

* fix stories.

* Refactor mfa required logic.

* Address bl-nero's comments.

* Address Ryan's comments.

* Add useMfa unit test.

* Fix story lint.

* Replace Promise.withResolvers for compatiblity with older browers; Fix bug where MFA couldn't be retried after a failed attempt; Add extra tests.
  • Loading branch information
Joerger committed Dec 20, 2024
1 parent 9c314ad commit b9adc43
Show file tree
Hide file tree
Showing 36 changed files with 1,045 additions and 602 deletions.
46 changes: 40 additions & 6 deletions lib/client/weblogin.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ type MFAChallengeResponse struct {
WebauthnResponse *wantypes.CredentialAssertionResponse `json:"webauthn_response,omitempty"`
// SSOResponse is a response from an SSO MFA flow.
SSOResponse *SSOResponse `json:"sso_response"`
// TODO(Joerger): DELETE IN v19.0.0, WebauthnResponse used instead.
WebauthnAssertionResponse *wantypes.CredentialAssertionResponse `json:"webauthnAssertionResponse"`
}

// SSOResponse is a json compatible [proto.SSOResponse].
Expand All @@ -124,25 +126,57 @@ type SSOResponse struct {
// GetOptionalMFAResponseProtoReq converts response to a type proto.MFAAuthenticateResponse,
// if there were any responses set. Otherwise returns nil.
func (r *MFAChallengeResponse) GetOptionalMFAResponseProtoReq() (*proto.MFAAuthenticateResponse, error) {
if r.TOTPCode != "" && r.WebauthnResponse != nil {
var availableResponses int
if r.TOTPCode != "" {
availableResponses++
}
if r.WebauthnResponse != nil {
availableResponses++
}
if r.SSOResponse != nil {
availableResponses++
}

if availableResponses > 1 {
return nil, trace.BadParameter("only one MFA response field can be set")
}

if r.TOTPCode != "" {
switch {
case r.WebauthnResponse != nil:
return &proto.MFAAuthenticateResponse{Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(r.WebauthnResponse),
}}, nil
case r.SSOResponse != nil:
return &proto.MFAAuthenticateResponse{Response: &proto.MFAAuthenticateResponse_SSO{
SSO: &proto.SSOResponse{
RequestId: r.SSOResponse.RequestID,
Token: r.SSOResponse.Token,
},
}}, nil
case r.TOTPCode != "":
return &proto.MFAAuthenticateResponse{Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{Code: r.TOTPCode},
}}, nil
}

if r.WebauthnResponse != nil {
case r.WebauthnAssertionResponse != nil:
return &proto.MFAAuthenticateResponse{Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(r.WebauthnResponse),
Webauthn: wantypes.CredentialAssertionResponseToProto(r.WebauthnAssertionResponse),
}}, nil
}

return nil, nil
}

// ParseMFAChallengeResponse parses [MFAChallengeResponse] from JSON and returns it as a [proto.MFAAuthenticateResponse].
func ParseMFAChallengeResponse(mfaResponseJSON []byte) (*proto.MFAAuthenticateResponse, error) {
var resp MFAChallengeResponse
if err := json.Unmarshal(mfaResponseJSON, &resp); err != nil {
return nil, trace.Wrap(err)
}

protoResp, err := resp.GetOptionalMFAResponseProtoReq()
return protoResp, trace.Wrap(err)
}

// CreateSSHCertReq is passed by tsh to authenticate a local user without MFA
// and receive short-lived certificates.
type CreateSSHCertReq struct {
Expand Down
12 changes: 4 additions & 8 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2778,7 +2778,7 @@ func (h *Handler) mfaLoginBegin(w http.ResponseWriter, r *http.Request, p httpro
return nil, trace.AccessDenied("invalid credentials")
}

return makeAuthenticateChallenge(mfaChallenge), nil
return makeAuthenticateChallenge(mfaChallenge, "" /*channelID*/), nil
}

// mfaLoginFinish completes the MFA login ceremony, returning a new SSH
Expand Down Expand Up @@ -4877,16 +4877,12 @@ func parseMFAResponseFromRequest(r *http.Request) error {
// context and returned.
func contextWithMFAResponseFromRequestHeader(ctx context.Context, requestHeader http.Header) (context.Context, error) {
if mfaResponseJSON := requestHeader.Get("Teleport-MFA-Response"); mfaResponseJSON != "" {
var resp mfaResponse
if err := json.Unmarshal([]byte(mfaResponseJSON), &resp); err != nil {
mfaResp, err := client.ParseMFAChallengeResponse([]byte(mfaResponseJSON))
if err != nil {
return nil, trace.Wrap(err)
}

return mfa.ContextWithMFAResponse(ctx, &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(resp.WebauthnAssertionResponse),
},
}), nil
return mfa.ContextWithMFAResponse(ctx, mfaResp), nil
}

return ctx, nil
Expand Down
8 changes: 3 additions & 5 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5594,10 +5594,6 @@ func TestCreateAppSession_RequireSessionMFA(t *testing.T) {
require.NoError(t, err)
mfaResp, err := webauthnDev.SolveAuthn(chal)
require.NoError(t, err)
mfaRespJSON, err := json.Marshal(mfaResponse{
WebauthnAssertionResponse: wantypes.CredentialAssertionResponseFromProto(mfaResp.GetWebauthn()),
})
require.NoError(t, err)

// Extract the session ID and bearer token for the current session.
rawCookie := *pack.cookies[0]
Expand Down Expand Up @@ -5631,7 +5627,9 @@ func TestCreateAppSession_RequireSessionMFA(t *testing.T) {
PublicAddr: "panel.example.com",
ClusterName: "localhost",
},
MFAResponse: string(mfaRespJSON),
MFAResponse: client.MFAChallengeResponse{
WebauthnAssertionResponse: wantypes.CredentialAssertionResponseFromProto(mfaResp.GetWebauthn()),
},
},
expectMFAVerified: true,
},
Expand Down
29 changes: 15 additions & 14 deletions lib/web/apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ package web

import (
"context"
"encoding/json"
"net/http"
"sort"

Expand All @@ -33,7 +32,7 @@ import (
"github.com/gravitational/teleport/api/client/proto"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -191,7 +190,10 @@ type CreateAppSessionRequest struct {
// AWSRole is the AWS role ARN when accessing AWS management console.
AWSRole string `json:"arn,omitempty"`
// MFAResponse is an optional MFA response used to create an MFA verified app session.
MFAResponse string `json:"mfa_response"`
MFAResponse client.MFAChallengeResponse `json:"mfaResponse"`
// TODO(Joerger): DELETE IN v19.0.0
// Backwards compatible version of MFAResponse
MFAResponseJSON string `json:"mfa_response"`
}

// CreateAppSessionResponse is a response to POST /v1/webapi/sessions/app
Expand Down Expand Up @@ -230,17 +232,16 @@ func (h *Handler) createAppSession(w http.ResponseWriter, r *http.Request, p htt
}
}

var mfaProtoResponse *proto.MFAAuthenticateResponse
if req.MFAResponse != "" {
var resp mfaResponse
if err := json.Unmarshal([]byte(req.MFAResponse), &resp); err != nil {
return nil, trace.Wrap(err)
}
mfaResponse, err := req.MFAResponse.GetOptionalMFAResponseProtoReq()
if err != nil {
return nil, trace.Wrap(err)
}

mfaProtoResponse = &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(resp.WebauthnAssertionResponse),
},
// Fallback to backwards compatible mfa response.
if mfaResponse == nil && req.MFAResponseJSON != "" {
mfaResponse, err = client.ParseMFAChallengeResponse([]byte(req.MFAResponseJSON))
if err != nil {
return nil, trace.Wrap(err)
}
}

Expand All @@ -263,7 +264,7 @@ func (h *Handler) createAppSession(w http.ResponseWriter, r *http.Request, p htt
PublicAddr: result.App.GetPublicAddr(),
ClusterName: result.ClusterName,
AWSRoleARN: req.AWSRole,
MFAResponse: mfaProtoResponse,
MFAResponse: mfaResponse,
AppName: result.App.GetName(),
URI: result.App.GetURI(),
ClientAddr: r.RemoteAddr,
Expand Down
47 changes: 22 additions & 25 deletions lib/web/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package web

import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
Expand All @@ -35,7 +34,6 @@ import (
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/reversetunnelclient"
Expand All @@ -56,8 +54,8 @@ type fileTransferRequest struct {
remoteLocation string
// filename is a file name
filename string
// webauthn is an optional parameter that contains a webauthn response string used to issue single use certs
webauthn string
// mfaResponse is an optional parameter that contains an mfa response string used to issue single use certs
mfaResponse string
// fileTransferRequestID is used to find a FileTransferRequest on a session
fileTransferRequestID string
// moderatedSessonID is an ID of a moderated session that has completed a
Expand All @@ -74,11 +72,25 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
remoteLocation: query.Get("location"),
filename: query.Get("filename"),
namespace: defaults.Namespace,
webauthn: query.Get("webauthn"),
mfaResponse: query.Get("mfaResponse"),
fileTransferRequestID: query.Get("fileTransferRequestId"),
moderatedSessionID: query.Get("moderatedSessionId"),
}

// Check for old query parameter, uses the same data structure.
// TODO(Joerger): DELETE IN v19.0.0
if req.mfaResponse == "" {
req.mfaResponse = query.Get("webauthn")
}

var mfaResponse *proto.MFAAuthenticateResponse
if req.mfaResponse != "" {
var err error
if mfaResponse, err = client.ParseMFAChallengeResponse([]byte(req.mfaResponse)); err != nil {
return nil, trace.Wrap(err)
}
}

// Send an error if only one of these params has been sent. Both should exist or not exist together
if (req.fileTransferRequestID != "") != (req.moderatedSessionID != "") {
return nil, trace.BadParameter("fileTransferRequestId and moderatedSessionId must both be included in the same request.")
Expand Down Expand Up @@ -107,7 +119,7 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
return nil, trace.Wrap(err)
}

if mfaReq.Required && query.Get("webauthn") == "" {
if mfaReq.Required && mfaResponse == nil {
return nil, trace.AccessDenied("MFA required for file transfer")
}

Expand Down Expand Up @@ -135,8 +147,8 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
return nil, trace.Wrap(err)
}

if req.webauthn != "" {
err = ft.issueSingleUseCert(req.webauthn, r, tc)
if req.mfaResponse != "" {
err = ft.issueSingleUseCert(mfaResponse, r, tc)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -216,21 +228,10 @@ func (f *fileTransfer) createClient(req fileTransferRequest, httpReq *http.Reque
return tc, nil
}

type mfaResponse struct {
// WebauthnResponse is the response from authenticators.
WebauthnAssertionResponse *wantypes.CredentialAssertionResponse `json:"webauthnAssertionResponse"`
}

// issueSingleUseCert will take an assertion response sent from a solved challenge in the web UI
// and use that to generate a cert. This cert is added to the Teleport Client as an authmethod that
// can be used to connect to a node.
func (f *fileTransfer) issueSingleUseCert(webauthn string, httpReq *http.Request, tc *client.TeleportClient) error {
var mfaResp mfaResponse
err := json.Unmarshal([]byte(webauthn), &mfaResp)
if err != nil {
return trace.Wrap(err)
}

func (f *fileTransfer) issueSingleUseCert(mfaResponse *proto.MFAAuthenticateResponse, httpReq *http.Request, tc *client.TeleportClient) error {
pk, err := keys.ParsePrivateKey(f.sctx.cfg.Session.GetSSHPriv())
if err != nil {
return trace.Wrap(err)
Expand All @@ -241,11 +242,7 @@ func (f *fileTransfer) issueSingleUseCert(webauthn string, httpReq *http.Request
SSHPublicKey: pk.MarshalSSHPublicKey(),
Username: f.sctx.GetUser(),
Expires: time.Now().Add(time.Minute).UTC(),
MFAResponse: &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(mfaResp.WebauthnAssertionResponse),
},
},
MFAResponse: mfaResponse,
})
if err != nil {
return trace.Wrap(err)
Expand Down
27 changes: 23 additions & 4 deletions lib/web/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ package web
import (
"context"
"net/http"
"net/url"
"strings"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"

Expand Down Expand Up @@ -201,6 +203,22 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht
allowReuse = mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_YES
}

// Prepare an sso client redirect URL in case the user has an SSO MFA device.
ssoClientRedirectURL, err := url.Parse(sso.WebMFARedirect)
if err != nil {
return nil, trace.Wrap(err)
}

// id is used by the front end to differentiate between separate ongoing SSO challenges.
id, err := uuid.NewRandom()
if err != nil {
return nil, trace.Wrap(err)
}
channelID := id.String()
query := ssoClientRedirectURL.Query()
query.Set("channel_id", channelID)
ssoClientRedirectURL.RawQuery = query.Encode()

chal, err := clt.CreateAuthenticateChallenge(r.Context(), &proto.CreateAuthenticateChallengeRequest{
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{
ContextUser: &proto.ContextUser{},
Expand All @@ -211,13 +229,13 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht
AllowReuse: allowReuse,
UserVerificationRequirement: req.UserVerificationRequirement,
},
SSOClientRedirectURL: sso.WebMFARedirect,
SSOClientRedirectURL: ssoClientRedirectURL.String(),
})
if err != nil {
return nil, trace.Wrap(err)
}

return makeAuthenticateChallenge(chal), nil
return makeAuthenticateChallenge(chal, channelID), nil
}

// createAuthenticateChallengeWithTokenHandle creates and returns MFA authenticate challenges for the user defined in token.
Expand All @@ -235,7 +253,7 @@ func (h *Handler) createAuthenticateChallengeWithTokenHandle(w http.ResponseWrit
return nil, trace.Wrap(err)
}

return makeAuthenticateChallenge(chal), nil
return makeAuthenticateChallenge(chal, "" /*channelID*/), nil
}

type createRegisterChallengeWithTokenRequest struct {
Expand Down Expand Up @@ -581,7 +599,7 @@ func (h *Handler) checkMFARequired(ctx context.Context, req *isMFARequiredReques
}

// makeAuthenticateChallenge converts proto to JSON format.
func makeAuthenticateChallenge(protoChal *proto.MFAAuthenticateChallenge) *client.MFAAuthenticateChallenge {
func makeAuthenticateChallenge(protoChal *proto.MFAAuthenticateChallenge, ssoChannelID string) *client.MFAAuthenticateChallenge {
chal := &client.MFAAuthenticateChallenge{
TOTPChallenge: protoChal.GetTOTP() != nil,
}
Expand All @@ -590,6 +608,7 @@ func makeAuthenticateChallenge(protoChal *proto.MFAAuthenticateChallenge) *clien
}
if protoChal.GetSSOChallenge() != nil {
chal.SSOChallenge = client.SSOChallengeFromProto(protoChal.GetSSOChallenge())
chal.SSOChallenge.ChannelID = ssoChannelID
}
return chal
}
Loading

0 comments on commit b9adc43

Please sign in to comment.