Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NIFI-13016 Add groups mapping from OIDC token claim for Registry #9566

Merged
merged 15 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import java.io.IOException;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

public class StandardManagedAuthorizer implements ManagedAuthorizer {
Expand Down Expand Up @@ -95,19 +97,36 @@ public AuthorizationResult authorize(AuthorizationRequest request) throws Author

final UserAndGroups userAndGroups = userGroupProvider.getUserAndGroups(request.getIdentity());

final User user = userAndGroups.getUser();
if (user == null) {
return AuthorizationResult.denied(String.format("Unknown user with identity '%s'.", request.getIdentity()));
}
// combine groups from incoming request with groups from UserAndGroups because the request may contain groups from
// an external identity provider and the membership may not be maintained with in any of the UserGroupProviders
final Set<Group> userGroups = new HashSet<>();
userGroups.addAll(userAndGroups.getGroups() == null ? Collections.emptySet() : userAndGroups.getGroups());
userGroups.addAll(getGroups(request.getGroups()));

final Set<Group> userGroups = userAndGroups.getGroups();
if (policy.getUsers().contains(user.getIdentifier()) || containsGroup(userGroups, policy)) {
if (containsUser(userAndGroups.getUser(), policy) || containsGroup(userGroups, policy)) {
return AuthorizationResult.approved();
}

return AuthorizationResult.denied(request.getExplanationSupplier().get());
}

private Set<Group> getGroups(final Set<String> groupNames) {
if (groupNames == null || groupNames.isEmpty()) {
return Collections.emptySet();
}

final Set<Group> groups = new HashSet<>();

for (final String requestGroupName : groupNames) {
final Group requestGroup = userGroupProvider.getGroupByName(requestGroupName);
if (requestGroup != null) {
groups.add(requestGroup);
}
}

return groups;
}

/**
* Determines if the policy contains one of the user's groups.
*
Expand All @@ -129,6 +148,20 @@ private boolean containsGroup(final Set<Group> userGroups, final AccessPolicy po
return false;
}

/**
* Determines if the policy contains the user's identifier.
*
* @param user the user
* @param policy the policy
* @return true if the user is non-null and the user's identifies is contained in the policy's users
*/
private boolean containsUser(final User user, final AccessPolicy policy) {
if (user == null || policy.getUsers().isEmpty()) {
return false;
}
return policy.getUsers().contains(user.getIdentifier());
}

@Override
public String getFingerprint() throws AuthorizationAccessException {
XMLStreamWriter writer = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public class NiFiRegistryProperties extends ApplicationProperties {
public static final String SECURITY_USER_OIDC_PREFERRED_JWSALGORITHM = "nifi.registry.security.user.oidc.preferred.jwsalgorithm";
public static final String SECURITY_USER_OIDC_ADDITIONAL_SCOPES = "nifi.registry.security.user.oidc.additional.scopes";
public static final String SECURITY_USER_OIDC_CLAIM_IDENTIFYING_USER = "nifi.registry.security.user.oidc.claim.identifying.user";
public static final String SECURITY_USER_OIDC_CLAIM_GROUPS = "nifi.registry.security.user.oidc.claim.groups";

// Revision Management Properties
public static final String REVISIONS_ENABLED = "nifi.registry.revisions.enabled";
Expand Down Expand Up @@ -481,6 +482,16 @@ public List<String> getOidcAdditionalScopes() {
public String getOidcClaimIdentifyingUser() {
return getProperty(SECURITY_USER_OIDC_CLAIM_IDENTIFYING_USER, "email").trim();
}
/**
* Returns the claim to be used to extract user groups from the OIDC payload.
* Claim must be requested by adding the scope for it.
* Default is 'groups'.
*
* @return The claim to be used to extract user groups.
*/
public String getOidcClaimGroups() {
return getProperty(SECURITY_USER_OIDC_CLAIM_GROUPS, "groups").trim();
}

/**
* Returns the network interface list to use for HTTPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.nifi.registry.security.authentication;

import java.io.Serializable;
import java.util.Collections;
import java.util.Set;

/**
* Authentication response for a user login attempt.
Expand All @@ -27,6 +29,7 @@ public class AuthenticationResponse implements Serializable {
private final String username;
private final long expiration;
private final String issuer;
private final Set<String> groups;

/**
* Creates an authentication response. The username and how long the authentication is valid in milliseconds
Expand All @@ -37,10 +40,24 @@ public class AuthenticationResponse implements Serializable {
* @param issuer The issuer of the token
*/
public AuthenticationResponse(final String identity, final String username, final long expiration, final String issuer) {
this(identity, username, expiration, issuer, Collections.emptySet());
}

/**
* Creates an authentication response. The username and how long the authentication is valid in milliseconds
*
* @param identity The user identity
* @param username The username
* @param expiration The expiration in milliseconds
* @param issuer The issuer of the token
* @param groups The user groups
*/
public AuthenticationResponse(final String identity, final String username, final long expiration, final String issuer, final Set<String> groups) {
this.identity = identity;
this.username = username;
this.expiration = expiration;
this.issuer = issuer;
this.groups = groups;
}

public String getIdentity() {
Expand All @@ -64,6 +81,10 @@ public long getExpiration() {
return expiration;
}

public Set<String> getGroups() {
return groups;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,30 @@ public interface UserGroupProvider {
*/
Group getGroup(String identifier) throws AuthorizationAccessException;

/**
* Retrieves a Group by name.
*
* @param name the name of the group to retrieve
* @return the Group with the given name, or null if no matching group was found
* @throws AuthorizationAccessException if there was an unexpected error performing the operation
*/
default Group getGroupByName(String name) throws AuthorizationAccessException {
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved
final Set<Group> allGroups = getGroups();
if (allGroups == null) {
return null;
}

Group matchingGroup = null;
for (Group group : allGroups) {
if (group.getName().equals(name)) {
matchingGroup = group;
break;
}
}

return matchingGroup;
}

/**
* Gets a user and their groups. Must be non null. If the user is not known the UserAndGroups.getUser() and
* UserAndGroups.getGroups() should return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class IdentityAuthenticationProvider implements AuthenticationProvider {

Expand Down Expand Up @@ -94,7 +95,7 @@ protected AuthenticationSuccessToken buildAuthenticatedToken(
return new AuthenticationSuccessToken(new NiFiUserDetails(
new StandardNiFiUser.Builder()
.identity(mappedIdentity)
.groups(getUserGroups(mappedIdentity))
.groups(getUserGroups(mappedIdentity, response))
.clientAddress(requestToken.getClientAddress())
.build()));
}
Expand All @@ -112,6 +113,12 @@ protected Set<String> getUserGroups(final String identity) {
return getUserGroups(authorizer, identity);
}

protected Set<String> getUserGroups(final String identity, AuthenticationResponse response) {
return Stream
.concat(getUserGroups(authorizer, identity).stream(), response.getGroups().stream())
.collect(Collectors.toSet());
}

private static Set<String> getUserGroups(final Authorizer authorizer, final String userIdentity) {
if (authorizer instanceof ManagedAuthorizer) {
final ManagedAuthorizer managedAuthorizer = (ManagedAuthorizer) authorizer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.Set;
import java.util.concurrent.TimeUnit;

@Component
Expand Down Expand Up @@ -70,7 +71,8 @@ public AuthenticationResponse authenticate(AuthenticationRequest authenticationR

try {
final String jwtPrincipal = jwtService.getUserIdentityFromToken(jwtAuthToken);
return new AuthenticationResponse(jwtPrincipal, jwtPrincipal, expiration, issuer);
final Set<String> groups = jwtService.getUserGroupsFromToken(jwtAuthToken);
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved
return new AuthenticationResponse(jwtPrincipal, jwtPrincipal, expiration, issuer, groups);
} catch (JwtException e) {
throw new InvalidAuthenticationException(e.getMessage(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@
import org.apache.nifi.registry.security.key.KeyService;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.stereotype.Service;

import java.nio.charset.StandardCharsets;
import java.util.Calendar;
import java.util.*;
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

@Service
public class JwtService {
Expand All @@ -48,6 +50,7 @@ public class JwtService {
private static final MacAlgorithm SIGNATURE_ALGORITHM = Jwts.SIG.HS256;
private static final String KEY_ID_CLAIM = "kid";
private static final String USERNAME_CLAIM = "preferred_username";
private static final String GROUPS_CLAIM = "groups";

private final KeyService keyService;

Expand Down Expand Up @@ -82,6 +85,27 @@ public String getUserIdentityFromToken(final String base64EncodedToken) throws J
}
}

public Set<String> getUserGroupsFromToken(final String base64EncodedToken) throws JwtException {
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved
// The library representations of the JWT should be kept internal to this service.
try {
final Jws<Claims> jws = parseTokenFromBase64EncodedString(base64EncodedToken);

if (jws == null) {
throw new JwtException("Unable to parse token");
}

@SuppressWarnings("unchecked")
ArrayList<String> groupsString = jws.getPayload().get(GROUPS_CLAIM, ArrayList.class);

return new HashSet<>(groupsString);
} catch (JwtException e) {
logger.debug("The Base64 encoded JWT: " + base64EncodedToken);
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved
final String errorMessage = "There was an error validating the JWT";
logger.error(errorMessage, e);
throw e;
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved
}
}

private Jws<Claims> parseTokenFromBase64EncodedString(final String base64EncodedToken) throws JwtException {
try {
return Jwts.parser().setSigningKeyResolver(new SigningKeyResolverAdapter() {
Expand Down Expand Up @@ -125,10 +149,15 @@ public String generateSignedToken(final AuthenticationResponse authenticationRes
authenticationResponse.getUsername(),
authenticationResponse.getIssuer(),
authenticationResponse.getIssuer(),
authenticationResponse.getExpiration());
authenticationResponse.getExpiration(),
null);
}

public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis) throws JwtException {
return this.generateSignedToken(identity, preferredUsername, issuer, audience, expirationMillis, null);
}

public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis, Collection<? extends GrantedAuthority> authorities) throws JwtException {

if (identity == null || StringUtils.isEmpty(identity)) {
String errorMessage = "Cannot generate a JWT for a token with an empty identity";
Expand All @@ -155,6 +184,7 @@ public String generateSignedToken(String identity, String preferredUsername, Str
.audience().add(audience).and()
.claim(USERNAME_CLAIM, preferredUsername)
.claim(KEY_ID_CLAIM, key.getId())
.claim(GROUPS_CLAIM, authorities.stream().map(GrantedAuthority::getAuthority).collect(Collectors.toSet()))
.issuedAt(now.getTime())
.expiration(expiration.getTime())
.signWith(Keys.hmacShaKeyFor(keyBytes), SIGNATURE_ALGORITHM).compact();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import java.net.URL;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

Expand All @@ -39,6 +41,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.stereotype.Component;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
Expand Down Expand Up @@ -401,6 +404,10 @@ private String convertOIDCTokenToNiFiToken(OIDCTokenResponse response) throws Ba
String identityClaim = properties.getOidcClaimIdentifyingUser();
String identity = claimsSet.getStringClaim(identityClaim);

// Attempt to extract groups from the configured claim; default is 'groups'
String groupsClaim = properties.getOidcClaimGroups();
List<String> groups = claimsSet.getStringListClaim(groupsClaim);
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved

// If default identity not available, attempt secondary identity extraction
if (StringUtils.isBlank(identity)) {
// Provide clear message to admin that desired claim is missing and present available claims
Expand All @@ -424,8 +431,15 @@ private String convertOIDCTokenToNiFiToken(OIDCTokenResponse response) throws Ba
final long expiresIn = expiration.getTime() - now.getTimeInMillis();
final String issuer = claimsSet.getIssuer().getValue();

Set<SimpleGrantedAuthority> authorities = groups != null ? groups.stream().map(
SimpleGrantedAuthority::new).collect(
Collectors.collectingAndThen(
Collectors.toSet(),
Collections::unmodifiableSet
)) : null;

// convert into a nifi jwt for retrieval later
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn);
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn, authorities);
}

private String retrieveIdentityFromUserInfoEndpoint(OIDCTokens oidcTokens) throws IOException {
Expand Down
Loading