Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge of p-kimberley:gh-3237 #3275

Open
wants to merge 1 commit into
base: 7.1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ appConfig:
jwksUri: null
logoutEndpoint: null
openIdConfigurationEndpoint: null
redirectUri: null
requestScope: null
tokenEndpoint: null
useInternal: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,11 @@ public static AuthenticationState create(final HttpServletRequest request, final
final HttpSession session = request.getSession(false);
return LogUtil.message("Creating new AuthenticationState, stateId: {}, session: {}, requestUri: {}",
stateId,
session != null
? session.getId()
: null,
request.getRequestURI());
session != null ? session.getId() : null,
url);
});

final AuthenticationState state = new AuthenticationState(stateId, url, nonce);

final Cache<String, AuthenticationState> cache = getOrCreateCache(request);
cache.put(stateId, state);
return state;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public class OpenIdConfig extends AbstractConfig implements IsStroomConfig {
*/
private final String requestScope;

/**
* Redirect URI
*/
private final String redirectUri;

public OpenIdConfig() {
useInternal = true;
openIdConfigurationEndpoint = null;
Expand All @@ -92,6 +97,7 @@ public OpenIdConfig() {
clientSecret = null;
clientId = null;
requestScope = null;
redirectUri = null;
}

@JsonCreator
Expand All @@ -105,7 +111,8 @@ public OpenIdConfig(@JsonProperty("useInternal") final boolean useInternal,
@JsonProperty("formTokenRequest") final boolean formTokenRequest,
@JsonProperty("clientId") final String clientId,
@JsonProperty("clientSecret") final String clientSecret,
@JsonProperty("requestScope") final String requestScope) {
@JsonProperty("requestScope") final String requestScope,
@JsonProperty("redirectUri") final String redirectUri) {
this.useInternal = useInternal;
this.openIdConfigurationEndpoint = openIdConfigurationEndpoint;
this.issuer = issuer;
Expand All @@ -117,6 +124,7 @@ public OpenIdConfig(@JsonProperty("useInternal") final boolean useInternal,
this.clientId = clientId;
this.clientSecret = clientSecret;
this.requestScope = requestScope;
this.redirectUri = redirectUri;
}

/**
Expand All @@ -130,9 +138,9 @@ public boolean isUseInternal() {
return useInternal;
}

@JsonProperty
@JsonPropertyDescription("You can set an openid-configuration URL to automatically configure much of the openid " +
"settings. Without this the other endpoints etc must be set manually.")
@JsonProperty
public String getOpenIdConfigurationEndpoint() {
return openIdConfigurationEndpoint;
}
Expand Down Expand Up @@ -198,6 +206,13 @@ public String getRequestScope() {
return requestScope;
}

@JsonProperty
@JsonPropertyDescription("Redirect URI used to receive the authorisation code from the remote server. " +
"If not specified, the user is redirected to the original request URL.")
public String getRedirectUri() {
return redirectUri;
}

@Override
public String toString() {
return "OpenIdConfig{" +
Expand All @@ -212,6 +227,7 @@ public String toString() {
", clientId='" + clientId + '\'' +
", clientSecret='" + clientSecret + '\'' +
", requestScope='" + requestScope + '\'' +
", redirectUri='" + redirectUri + '\'' +
'}';
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package stroom.security.impl;

import stroom.config.common.UriFactory;
import stroom.security.api.UserIdentity;
import stroom.security.openid.api.OpenId;
import stroom.util.jersey.UriBuilderUtil;
Expand All @@ -23,18 +24,20 @@ class OpenIdManager {

private final ResolvedOpenIdConfig openIdConfig;
private final UserIdentityFactory userIdentityFactory;
private final UriFactory uriFactory;

@Inject
public OpenIdManager(final ResolvedOpenIdConfig openIdConfig,
final UserIdentityFactory userIdentityFactory) {
final UserIdentityFactory userIdentityFactory,
final UriFactory uriFactory) {
this.openIdConfig = openIdConfig;
this.userIdentityFactory = userIdentityFactory;
this.uriFactory = uriFactory;
}

public String redirect(final HttpServletRequest request,
final String code,
final String stateId,
final String postAuthRedirectUri) {
final String stateId) {
String redirectUri = null;

// If we have completed the front channel flow then we will have a state id.
Expand All @@ -43,22 +46,21 @@ public String redirect(final HttpServletRequest request,
}

if (redirectUri == null) {
final String uri = OpenId.removeReservedParams(postAuthRedirectUri);
redirectUri = frontChannelOIDC(request, uri);
redirectUri = frontChannelOIDC(request);
}

return redirectUri;
}

private String frontChannelOIDC(final HttpServletRequest request, final String postAuthRedirectUri) {
private String frontChannelOIDC(final HttpServletRequest request) {
final String endpoint = openIdConfig.getAuthEndpoint();
final String clientId = openIdConfig.getClientId();
Objects.requireNonNull(endpoint,
"To make an authentication request the OpenId config 'authEndpoint' must not be null");
Objects.requireNonNull(clientId,
"To make an authentication request the OpenId config 'clientId' must not be null");
// Create a state for this authentication request.
final AuthenticationState state = AuthenticationStateSessionUtil.create(request, postAuthRedirectUri);
final AuthenticationState state = AuthenticationStateSessionUtil.create(request, buildRedirectUrl(request));
LOGGER.debug(() -> "frontChannelOIDC state=" + state);
return createAuthUri(request, endpoint, clientId, state, false);
}
Expand All @@ -69,17 +71,13 @@ private String backChannelOIDC(final HttpServletRequest request,
Objects.requireNonNull(code, "Null code");
Objects.requireNonNull(stateId, "Null state Id");

boolean loggedIn = false;
String redirectUri = null;

// If we have a state id then this should be a return from the auth service.
LOGGER.debug(() -> "We have the following state: " + stateId);

// Check the state is one we requested.
final AuthenticationState state = AuthenticationStateSessionUtil.pop(request, stateId);
if (state == null) {
LOGGER.warn(() -> "Unexpected state: " + stateId);

} else {
LOGGER.debug(() -> "backChannelOIDC state=" + state);
final HttpSession session = request.getSession(false);
Expand All @@ -91,17 +89,40 @@ private String backChannelOIDC(final HttpServletRequest request,
if (optionalUserIdentity.isPresent()) {
// Set the token in the session.
UserIdentitySessionUtil.set(session, optionalUserIdentity.get());
loggedIn = true;
}

// If we manage to login then redirect to the original URL held in the state.
if (loggedIn) {
LOGGER.info(() -> "Redirecting to initiating URI: " + state.getUri());
redirectUri = state.getUri();
// Login successful. Use the redirect URI if configured, else use the original request URI.
final String redirectUri = state.getUri();
LOGGER.info(() -> "Redirecting to initiating URI: " + redirectUri);
return redirectUri;
}
}

return redirectUri;
return null;
}

/**
* Build a complete redirect URL using the configured public URL.
* This is the URL used to redirect after the authorisation flow has completed.
*/
private String buildRedirectUrl(final HttpServletRequest request) {
return uriFactory.publicUri(UrlUtils.getFullUri(request)).toString();
}

/**
* Prefer the configured redirect URI. Otherwise, use the original request URI when redirecting after successful
* authorisation.
*/
public String getRedirectUri(final HttpServletRequest request) {
return getRedirectUri(buildRedirectUrl(request));
}

public String getRedirectUri(final String originalUrl) {
final String redirectUri = openIdConfig.getRedirectUri();
if (redirectUri == null || redirectUri.isEmpty()) {
return originalUrl;
} else {
return redirectUri;
}
}

/**
Expand Down Expand Up @@ -131,15 +152,14 @@ public Optional<UserIdentity> getOrSetSessionUser(final HttpServletRequest reque
return result;
}

public String logout(final HttpServletRequest request, final String postAuthRedirectUri) {
final String redirectUri = OpenId.removeReservedParams(postAuthRedirectUri);
public String logout(final HttpServletRequest request) {
final String endpoint = openIdConfig.getLogoutEndpoint();
final String clientId = openIdConfig.getClientId();
Objects.requireNonNull(endpoint,
"To make a logout request the OpenId config 'logoutEndpoint' must not be null");
Objects.requireNonNull(clientId,
"To make an authentication request the OpenId config 'clientId' must not be null");
final AuthenticationState state = AuthenticationStateSessionUtil.create(request, redirectUri);
final AuthenticationState state = AuthenticationStateSessionUtil.create(request, buildRedirectUrl(request));
LOGGER.debug(() -> "logout state=" + state);
return createAuthUri(request, endpoint, clientId, state, true);
}
Expand All @@ -154,7 +174,7 @@ private String createAuthUri(final HttpServletRequest request,
UriBuilder uriBuilder = UriBuilder.fromUri(endpoint);
uriBuilder = UriBuilderUtil.addParam(uriBuilder, OpenId.RESPONSE_TYPE, OpenId.CODE);
uriBuilder = UriBuilderUtil.addParam(uriBuilder, OpenId.CLIENT_ID, clientId);
uriBuilder = UriBuilderUtil.addParam(uriBuilder, OpenId.REDIRECT_URI, state.getUri());
uriBuilder = UriBuilderUtil.addParam(uriBuilder, OpenId.REDIRECT_URI, getRedirectUri(state.getUri()));
uriBuilder = UriBuilderUtil.addParam(uriBuilder, OpenId.SCOPE, OpenId.SCOPE__OPENID +
" " +
OpenId.SCOPE__EMAIL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,15 @@ public String getRequestScope() {
}
return openIdConfig.getRequestScope();
}

public String getRedirectUri() {
final OpenIdConfig openIdConfig = openIdConfigProvider.get();
if (openIdConfig.isUseInternal() ||
openIdConfig.getRedirectUri() == null ||
openIdConfig.getRedirectUri().isBlank()) {
return null;
} else {
return openIdConfig.getRedirectUri();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,9 @@ private void filter(final HttpServletRequest request,
// UI request, so instigate an OpenID authentication flow
// like the good relying party we are.
try {
final String postAuthRedirectUri = getPostAuthRedirectUri(request);

LOGGER.debug("Using postAuthRedirectUri: {}", postAuthRedirectUri);

final String code = UrlUtils.getLastParam(request, OpenId.CODE);
final String stateId = UrlUtils.getLastParam(request, OpenId.STATE);
final String redirectUri = openIdManager.redirect(request, code, stateId, postAuthRedirectUri);
final String redirectUri = openIdManager.redirect(request, code, stateId);
response.sendRedirect(redirectUri);

} catch (final RuntimeException e) {
Expand All @@ -193,19 +189,6 @@ private void filter(final HttpServletRequest request,
}
}

private String getPostAuthRedirectUri(final HttpServletRequest request) {
// We have a a new request so we're going to redirect with an AuthenticationRequest.
// Get the redirect URL for the auth service from the current request.
final String originalPath = request.getRequestURI() + Optional.ofNullable(request.getQueryString())
.map(queryStr -> "?" + queryStr)
.orElse("");

// Dropwiz is likely sat behind Nginx with requests reverse proxied to it
// so we need to append just the path/query part to the public URI defined in config
// rather than using the full url of the request
return uriFactory.publicUri(originalPath).toString();
}

private boolean isStaticResource(final HttpServletRequest request) {
final String url = request.getRequestURL().toString();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public ValidateSessionResponse validateSession(final String postAuthRedirectUri)
// If we have completed the front channel flow then we will have a state id.
final String code = getParam(postAuthRedirectUri, OpenId.CODE);
final String stateId = getParam(postAuthRedirectUri, OpenId.STATE);
final String redirectUri = openIdManager.redirect(request, code, stateId, postAuthRedirectUri);
final String redirectUri = openIdManager.redirect(request, code, stateId);

// We might have completed the back channel authentication now so see if we have a user session.
userIdentity = UserIdentitySessionUtil.get(request.getSession(false));
Expand Down Expand Up @@ -129,7 +129,7 @@ public UrlResponse logout(final String redirectUri) {
UserIdentitySessionUtil.set(session, null);
}

final String url = openIdManagerProvider.get().logout(request, redirectUri);
final String url = openIdManagerProvider.get().logout(request);
return new UrlResponse(url);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package stroom.security.impl;

import stroom.config.common.UriFactory;
import stroom.security.api.ProcessingUserIdentityProvider;
import stroom.security.api.UserIdentity;
import stroom.security.impl.exception.AuthenticationException;
Expand Down Expand Up @@ -56,6 +57,7 @@ class UserIdentityFactoryImpl implements UserIdentityFactory {
private final InternalJwtContextFactory internalJwtContextFactory;
private final StandardJwtContextFactory standardJwtContextFactory;
private final OpenIdConfig openIdConfig;
private final OpenIdManager openIdManager;
private final ResolvedOpenIdConfig resolvedOpenIdConfig;
private final DefaultOpenIdCredentials defaultOpenIdCredentials;
private final UserCache userCache;
Expand All @@ -66,6 +68,7 @@ class UserIdentityFactoryImpl implements UserIdentityFactory {
final InternalJwtContextFactory internalJwtContextFactory,
final StandardJwtContextFactory standardJwtContextFactory,
final OpenIdConfig openIdConfig,
final OpenIdManager openIdManager,
final ResolvedOpenIdConfig resolvedOpenIdConfig,
final DefaultOpenIdCredentials defaultOpenIdCredentials,
final UserCache userCache,
Expand All @@ -74,6 +77,7 @@ class UserIdentityFactoryImpl implements UserIdentityFactory {
this.internalJwtContextFactory = internalJwtContextFactory;
this.standardJwtContextFactory = standardJwtContextFactory;
this.openIdConfig = openIdConfig;
this.openIdManager = openIdManager;
this.resolvedOpenIdConfig = resolvedOpenIdConfig;
this.defaultOpenIdCredentials = defaultOpenIdCredentials;
this.userCache = userCache;
Expand Down Expand Up @@ -125,6 +129,7 @@ public Optional<UserIdentity> getAuthFlowUserIdentity(final HttpServletRequest r

final ObjectMapper mapper = getMapper();
final String tokenEndpoint = resolvedOpenIdConfig.getTokenEndpoint();
final String redirectUri = openIdManager.getRedirectUri(state.getUri());
final HttpPost httpPost = new HttpPost(tokenEndpoint);

// AWS requires form content and not a JSON object.
Expand All @@ -134,7 +139,7 @@ public Optional<UserIdentity> getAuthFlowUserIdentity(final HttpServletRequest r
nvps.add(new BasicNameValuePair(OpenId.GRANT_TYPE, OpenId.GRANT_TYPE__AUTHORIZATION_CODE));
nvps.add(new BasicNameValuePair(OpenId.CLIENT_ID, resolvedOpenIdConfig.getClientId()));
nvps.add(new BasicNameValuePair(OpenId.CLIENT_SECRET, resolvedOpenIdConfig.getClientSecret()));
nvps.add(new BasicNameValuePair(OpenId.REDIRECT_URI, state.getUri()));
nvps.add(new BasicNameValuePair(OpenId.REDIRECT_URI, redirectUri));
setFormParams(httpPost, nvps);

} else {
Expand All @@ -144,7 +149,7 @@ public Optional<UserIdentity> getAuthFlowUserIdentity(final HttpServletRequest r
.grantType(OpenId.GRANT_TYPE__AUTHORIZATION_CODE)
.clientId(resolvedOpenIdConfig.getClientId())
.clientSecret(resolvedOpenIdConfig.getClientSecret())
.redirectUri(state.getUri())
.redirectUri(redirectUri)
.build();
final String json = mapper.writeValueAsString(tokenRequest);

Expand Down
13 changes: 13 additions & 0 deletions stroom-util/src/main/java/stroom/util/net/UrlUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,26 @@ private UrlUtils() {
// Utility class.
}

/**
* Return the complete URL, including scheme, hostname, port and query string (if specified)
*/
public static String getFullUrl(final HttpServletRequest request) {
if (request.getQueryString() == null) {
return request.getRequestURL().toString();
}
return request.getRequestURL().toString() + "?" + request.getQueryString();
}

/**
* Return the URI plus the query string (if specified)
*/
public static String getFullUri(final HttpServletRequest request) {
if (request.getQueryString() == null) {
return request.getRequestURI();
}
return request.getRequestURI() + "?" + request.getQueryString();
}

public static Map<String, String> createParamMap(final String url) {
final URI uri = UriBuilder.fromUri(url).build();
final Map<String, String> params = new HashMap<>();
Expand Down