Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NIFI-13016 Add groups mapping from OIDC token claim for Registry #9566

Merged
merged 15 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nifi-registry/nifi-registry-assembly/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
<nifi.registry.security.user.oidc.client.id />
<nifi.registry.security.user.oidc.client.secret />
<nifi.registry.security.user.oidc.preferred.jwsalgorithm />
<nifi.registry.security.user.oidc.claim.groups>groups</nifi.registry.security.user.oidc.claim.groups>

<!-- nifi.registry.properties: revision management properties -->
<nifi.registry.revisions.enabled>false</nifi.registry.revisions.enabled>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ If this value is `none`, NiFi will attempt to validate unsecured/plain tokens. O
JSON Web Key (JWK) provided through the jwks_uri in the metadata found at the discovery URL
|`nifi.registry.security.user.oidc.additional.scopes` | Comma separated scopes that are sent to OpenID Connect Provider in addition to `openid` and `email`.
|`nifi.registry.security.user.oidc.claim.identifying.user` | Claim that identifies the authenticated user. The default value is `email`. Claim names may need to be requested using the `nifi.registry.security.user.oidc.additional.scopes` property
|`nifi.registry.security.user.oidc.claim.groups` | Name of the ID token claim that contains an array of group names of which the
user is a member. Application groups must be supplied from a User Group Provider with matching names in order for the
authorization process to use ID token claim groups. The default value is `groups`.
|==================================================================================================================================================

[[authorization]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
import java.io.IOException;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

public class StandardManagedAuthorizer implements ManagedAuthorizer {

Expand Down Expand Up @@ -95,19 +98,29 @@ public AuthorizationResult authorize(AuthorizationRequest request) throws Author

final UserAndGroups userAndGroups = userGroupProvider.getUserAndGroups(request.getIdentity());

final User user = userAndGroups.getUser();
if (user == null) {
return AuthorizationResult.denied(String.format("Unknown user with identity '%s'.", request.getIdentity()));
}
// combine groups from incoming request with groups from UserAndGroups because the request may contain groups from
// an external identity provider and the membership may not be maintained within any of the UserGroupProviders
final Set<Group> userGroups = new HashSet<>();
userGroups.addAll(userAndGroups.getGroups() == null ? Collections.emptySet() : userAndGroups.getGroups());
userGroups.addAll(getGroups(request.getGroups()));

final Set<Group> userGroups = userAndGroups.getGroups();
if (policy.getUsers().contains(user.getIdentifier()) || containsGroup(userGroups, policy)) {
if (containsUser(userAndGroups.getUser(), policy) || containsGroup(userGroups, policy)) {
return AuthorizationResult.approved();
}

return AuthorizationResult.denied(request.getExplanationSupplier().get());
}

private Set<Group> getGroups(final Set<String> groupNames) {
if (groupNames == null || groupNames.isEmpty()) {
return Collections.emptySet();
}

return userGroupProvider.getGroups().stream()
.filter(group -> groupNames.contains(group.getName()))
.collect(Collectors.toSet());
}

/**
* Determines if the policy contains one of the user's groups.
*
Expand All @@ -129,6 +142,20 @@ private boolean containsGroup(final Set<Group> userGroups, final AccessPolicy po
return false;
}

/**
* Determines if the policy contains the user's identifier.
*
* @param user the user
* @param policy the policy
* @return true if the user is non-null and the user's identifies is contained in the policy's users
*/
private boolean containsUser(final User user, final AccessPolicy policy) {
if (user == null || policy.getUsers().isEmpty()) {
return false;
}
return policy.getUsers().contains(user.getIdentifier());
}

@Override
public String getFingerprint() throws AuthorizationAccessException {
XMLStreamWriter writer = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public class NiFiRegistryProperties extends ApplicationProperties {
public static final String SECURITY_USER_OIDC_PREFERRED_JWSALGORITHM = "nifi.registry.security.user.oidc.preferred.jwsalgorithm";
public static final String SECURITY_USER_OIDC_ADDITIONAL_SCOPES = "nifi.registry.security.user.oidc.additional.scopes";
public static final String SECURITY_USER_OIDC_CLAIM_IDENTIFYING_USER = "nifi.registry.security.user.oidc.claim.identifying.user";
public static final String SECURITY_USER_OIDC_CLAIM_GROUPS = "nifi.registry.security.user.oidc.claim.groups";

// Revision Management Properties
public static final String REVISIONS_ENABLED = "nifi.registry.revisions.enabled";
Expand Down Expand Up @@ -481,6 +482,16 @@ public List<String> getOidcAdditionalScopes() {
public String getOidcClaimIdentifyingUser() {
return getProperty(SECURITY_USER_OIDC_CLAIM_IDENTIFYING_USER, "email").trim();
}
/**
* Returns the claim to be used to extract user groups from the OIDC payload.
* Claim must be requested by adding the scope for it.
* Default is 'groups'.
*
* @return The claim to be used to extract user groups.
*/
public String getOidcClaimGroups() {
return getProperty(SECURITY_USER_OIDC_CLAIM_GROUPS, "groups").trim();
}

/**
* Returns the network interface list to use for HTTPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ nifi.registry.security.user.oidc.read.timeout=${nifi.registry.security.user.oidc
nifi.registry.security.user.oidc.client.id=${nifi.registry.security.user.oidc.client.id}
nifi.registry.security.user.oidc.client.secret=${nifi.registry.security.user.oidc.client.secret}
nifi.registry.security.user.oidc.preferred.jwsalgorithm=${nifi.registry.security.user.oidc.preferred.jwsalgorithm}
nifi.registry.security.user.oidc.claim.groups=${nifi.registry.security.user.oidc.claim.groups}

# revision management #
# This feature should remain disabled until a future NiFi release that supports the revision API changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.nifi.registry.security.authentication;

import java.io.Serializable;
import java.util.Collections;
import java.util.Set;

/**
* Authentication response for a user login attempt.
Expand All @@ -27,6 +29,7 @@ public class AuthenticationResponse implements Serializable {
private final String username;
private final long expiration;
private final String issuer;
private final Set<String> groups;

/**
* Creates an authentication response. The username and how long the authentication is valid in milliseconds
Expand All @@ -37,10 +40,24 @@ public class AuthenticationResponse implements Serializable {
* @param issuer The issuer of the token
*/
public AuthenticationResponse(final String identity, final String username, final long expiration, final String issuer) {
this(identity, username, expiration, issuer, Collections.emptySet());
}

/**
* Creates an authentication response. The username and how long the authentication is valid in milliseconds
*
* @param identity The user identity
* @param username The username
* @param expiration The expiration in milliseconds
* @param issuer The issuer of the token
* @param groups The user groups
*/
public AuthenticationResponse(final String identity, final String username, final long expiration, final String issuer, final Set<String> groups) {
this.identity = identity;
this.username = username;
this.expiration = expiration;
this.issuer = issuer;
this.groups = groups;
}

public String getIdentity() {
Expand All @@ -64,6 +81,10 @@ public long getExpiration() {
return expiration;
}

public Set<String> getGroups() {
return groups;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class IdentityAuthenticationProvider implements AuthenticationProvider {

Expand Down Expand Up @@ -94,7 +95,7 @@ protected AuthenticationSuccessToken buildAuthenticatedToken(
return new AuthenticationSuccessToken(new NiFiUserDetails(
new StandardNiFiUser.Builder()
.identity(mappedIdentity)
.groups(getUserGroups(mappedIdentity))
.groups(getUserGroups(mappedIdentity, response))
.clientAddress(requestToken.getClientAddress())
.build()));
}
Expand All @@ -112,6 +113,12 @@ protected Set<String> getUserGroups(final String identity) {
return getUserGroups(authorizer, identity);
}

protected Set<String> getUserGroups(final String identity, AuthenticationResponse response) {
return Stream
.concat(getUserGroups(authorizer, identity).stream(), response.getGroups().stream())
.collect(Collectors.toSet());
}

private static Set<String> getUserGroups(final Authorizer authorizer, final String userIdentity) {
if (authorizer instanceof ManagedAuthorizer) {
final ManagedAuthorizer managedAuthorizer = (ManagedAuthorizer) authorizer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.nifi.registry.web.security.authentication.jwt;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtException;
import org.apache.nifi.registry.properties.NiFiRegistryProperties;
import org.apache.nifi.registry.security.authentication.AuthenticationRequest;
Expand All @@ -34,6 +36,7 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.Set;
import java.util.concurrent.TimeUnit;

@Component
Expand Down Expand Up @@ -61,16 +64,19 @@ public AuthenticationResponse authenticate(AuthenticationRequest authenticationR
}

final Object credentials = authenticationRequest.getCredentials();
String jwtAuthToken = credentials != null && credentials instanceof String ? (String) credentials : null;

if (credentials == null) {
logger.info("JWT not found in authenticationRequest credentials, returning null.");
return null;
}

try {
final String jwtPrincipal = jwtService.getUserIdentityFromToken(jwtAuthToken);
return new AuthenticationResponse(jwtPrincipal, jwtPrincipal, expiration, issuer);
String jwtAuthToken = credentials.toString();
final Jws<Claims> jws = jwtService.parseAndValidateToken(jwtAuthToken);

final String jwtPrincipal = jwtService.getUserIdentityFromToken(jws);
final Set<String> groups = jwtService.getUserGroupsFromToken(jws);

return new AuthenticationResponse(jwtPrincipal, jwtPrincipal, expiration, issuer, groups);
} catch (JwtException e) {
throw new InvalidAuthenticationException(e.getMessage(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@
import org.springframework.stereotype.Service;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;

@Service
Expand All @@ -48,6 +54,7 @@ public class JwtService {
private static final MacAlgorithm SIGNATURE_ALGORITHM = Jwts.SIG.HS256;
private static final String KEY_ID_CLAIM = "kid";
private static final String USERNAME_CLAIM = "preferred_username";
private static final String GROUPS_CLAIM = "groups";

private final KeyService keyService;

Expand All @@ -56,7 +63,7 @@ public JwtService(final KeyService keyService) {
this.keyService = keyService;
}

public String getUserIdentityFromToken(final String base64EncodedToken) throws JwtException {
public Jws<Claims> parseAndValidateToken(final String base64EncodedToken) throws JwtException {
// The library representations of the JWT should be kept internal to this service.
try {
final Jws<Claims> jws = parseTokenFromBase64EncodedString(base64EncodedToken);
Expand All @@ -74,14 +81,24 @@ public String getUserIdentityFromToken(final String base64EncodedToken) throws J
if (StringUtils.isEmpty(jws.getPayload().getIssuer())) {
throw new JwtException("No issuer available in token");
}
return jws.getPayload().getSubject();

return jws;
} catch (JwtException e) {
final String errorMessage = "There was an error validating the JWT";
logger.error(errorMessage, e);
throw e;
throw new JwtException("There was an error validating the JWT", e);
}
}

public String getUserIdentityFromToken(final Jws<Claims> jws) throws JwtException {
return jws.getPayload().getSubject();
}

public Set<String> getUserGroupsFromToken(final Jws<Claims> jws) throws JwtException {
@SuppressWarnings("unchecked")
final List<String> groupsString = jws.getPayload().get(GROUPS_CLAIM, ArrayList.class);

return new HashSet<>(groupsString != null ? groupsString : Collections.emptyList());
}

private Jws<Claims> parseTokenFromBase64EncodedString(final String base64EncodedToken) throws JwtException {
try {
return Jwts.parser().setSigningKeyResolver(new SigningKeyResolverAdapter() {
Expand Down Expand Up @@ -125,11 +142,15 @@ public String generateSignedToken(final AuthenticationResponse authenticationRes
authenticationResponse.getUsername(),
authenticationResponse.getIssuer(),
authenticationResponse.getIssuer(),
authenticationResponse.getExpiration());
authenticationResponse.getExpiration(),
null);
}

public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis) throws JwtException {
return this.generateSignedToken(identity, preferredUsername, issuer, audience, expirationMillis, null);
}

public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis, Collection<String> groups) throws JwtException {
if (identity == null || StringUtils.isEmpty(identity)) {
String errorMessage = "Cannot generate a JWT for a token with an empty identity";
errorMessage = issuer != null ? errorMessage + " issued by " + issuer + "." : ".";
Expand All @@ -155,6 +176,7 @@ public String generateSignedToken(String identity, String preferredUsername, Str
.audience().add(audience).and()
.claim(USERNAME_CLAIM, preferredUsername)
.claim(KEY_ID_CLAIM, key.getId())
.claim(GROUPS_CLAIM, groups != null ? groups : Collections.EMPTY_LIST)
.issuedAt(now.getTime())
.expiration(expiration.getTime())
.signWith(Keys.hmacShaKeyFor(keyBytes), SIGNATURE_ALGORITHM).compact();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,10 @@ private String convertOIDCTokenToNiFiToken(OIDCTokenResponse response) throws Ba
String identityClaim = properties.getOidcClaimIdentifyingUser();
String identity = claimsSet.getStringClaim(identityClaim);

// Attempt to extract groups from the configured claim; default is 'groups'
final String groupsClaim = properties.getOidcClaimGroups();
final List<String> groups = claimsSet.getStringListClaim(groupsClaim);

// If default identity not available, attempt secondary identity extraction
if (StringUtils.isBlank(identity)) {
// Provide clear message to admin that desired claim is missing and present available claims
Expand All @@ -425,7 +429,7 @@ private String convertOIDCTokenToNiFiToken(OIDCTokenResponse response) throws Ba
final String issuer = claimsSet.getIssuer().getValue();

// convert into a nifi jwt for retrieval later
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn);
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn, groups);
}

private String retrieveIdentityFromUserInfoEndpoint(OIDCTokens oidcTokens) throws IOException {
Expand Down
Loading