Skip to content

Commit

Permalink
Add originalUser and authenticatedUser as selectors available for res…
Browse files Browse the repository at this point in the history
…ource group selection
  • Loading branch information
xkrogen committed Jan 9, 2025
1 parent e86635d commit f3dfb32
Show file tree
Hide file tree
Showing 26 changed files with 566 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import io.trino.spi.TrinoException;
import io.trino.spi.resourcegroups.SelectionContext;
import io.trino.spi.resourcegroups.SelectionCriteria;
import io.trino.spi.security.Identity;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import org.weakref.jmx.Flatten;
Expand Down Expand Up @@ -229,6 +230,8 @@ private <C> void createQueryInternal(QueryId queryId, Span querySpan, Slug slug,
sessionContext.getIdentity().getPrincipal().isPresent(),
sessionContext.getIdentity().getUser(),
sessionContext.getIdentity().getGroups(),
sessionContext.getOriginalIdentity().getUser(),
sessionContext.getAuthenticatedIdentity().map(Identity::getUser),
sessionContext.getSource(),
sessionContext.getClientTags(),
sessionContext.getResourceEstimates(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public final class SelectionCriteria
private final boolean authenticated;
private final String user;
private final Set<String> userGroups;
private final String originalUser;
private final Optional<String> authenticatedUser;
private final Optional<String> source;
private final Set<String> clientTags;
private final ResourceEstimates resourceEstimates;
Expand All @@ -35,6 +37,8 @@ public SelectionCriteria(
boolean authenticated,
String user,
Set<String> userGroups,
String originalUser,
Optional<String> authenticatedUser,
Optional<String> source,
Set<String> clientTags,
ResourceEstimates resourceEstimates,
Expand All @@ -43,12 +47,39 @@ public SelectionCriteria(
this.authenticated = authenticated;
this.user = requireNonNull(user, "user is null");
this.userGroups = requireNonNull(userGroups, "userGroups is null");
this.originalUser = requireNonNull(originalUser, "originalUser is null");
this.authenticatedUser = requireNonNull(authenticatedUser, "principal is null");
this.source = requireNonNull(source, "source is null");
this.clientTags = Set.copyOf(requireNonNull(clientTags, "clientTags is null"));
this.resourceEstimates = requireNonNull(resourceEstimates, "resourceEstimates is null");
this.queryType = requireNonNull(queryType, "queryType is null");
}

/**
* @deprecated Use {@link #SelectionCriteria(boolean, String, Set, String, Optional, Optional, Set, ResourceEstimates, Optional)} instead.
*/
@Deprecated(since = "469", forRemoval = true)
public SelectionCriteria(
boolean authenticated,
String user,
Set<String> userGroups,
Optional<String> source,
Set<String> clientTags,
ResourceEstimates resourceEstimates,
Optional<String> queryType)
{
this(
authenticated,
user,
userGroups,
user,
Optional.empty(),
source,
clientTags,
resourceEstimates,
queryType);
}

public boolean isAuthenticated()
{
return authenticated;
Expand All @@ -64,6 +95,16 @@ public Set<String> getUserGroups()
return userGroups;
}

public String getOriginalUser()
{
return originalUser;
}

public Optional<String> getAuthenticatedUser()
{
return authenticatedUser;
}

public Optional<String> getSource()
{
return source;
Expand Down Expand Up @@ -91,6 +132,8 @@ public String toString()
.add("authenticated=" + authenticated)
.add("user='" + user + "'")
.add("userGroups=" + userGroups)
.add("originalUser=" + originalUser)
.add("authenticatedUser=" + authenticatedUser)
.add("source=" + source)
.add("clientTags=" + clientTags)
.add("resourceEstimates=" + resourceEstimates)
Expand Down
8 changes: 8 additions & 0 deletions docs/src/main/sphinx/admin/resource-groups-example.json
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@
"user": "bob",
"group": "admin"
},
{
"originalUser": "bob",
"group": "admin"
},
{
"authenticatedUser": "bob",
"group": "admin"
},
{
"userGroup": "admin",
"group": "admin"
Expand Down
27 changes: 20 additions & 7 deletions docs/src/main/sphinx/admin/resource-groups.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ documentation](https://docs.oracle.com/en/java/javase/23/docs/api/java.base/java

- `user` (optional): Java regex to match against user name.

- `originalUser` (optional): Java regex to match against the _original_ user name,
i.e. before any changes to the session user. For example, if user "foo" runs
`SET SESSION AUTHORIZATION 'bar'`, `originalUser` is "foo", while `user` is "bar".

- `authenticatedUser` (optional): Java regex to match against the _authenticated_ user name,
which will always refer to the user that authenticated with the system, regardless of any
changes made to the session user.

- `userGroup` (optional): Java regex to match against every user group the user belongs to.

- `source` (optional): Java regex to match against source string.
Expand Down Expand Up @@ -239,13 +247,17 @@ query. You may also use custom named variables in the `source` and `user` regula
There are four selectors, that define which queries run in which resource group:

- The first selector matches queries from `bob` and places them in the admin group.
- The second selector matches queries from `admin` user group and places them in the admin group.
- The third selector matches all data definition (DDL) queries from a source name that includes `pipeline`
- The next selector matches queries with an _original_ user of `bob`
and places them in the admin group.
- The next selector matches queries with an _authenticated_ user of `bob`
and places them in the admin group.
- The next selector matches queries from `admin` user group and places them in the admin group.
- The next selector matches all data definition (DDL) queries from a source name that includes `pipeline`
and places them in the `global.data_definition` group. This could help reduce queue times for this
class of queries, since they are expected to be fast.
- The fourth selector matches queries from a source name that includes `pipeline`, and places them in a
- The next selector matches queries from a source name that includes `pipeline`, and places them in a
dynamically-created per-user pipeline group under the `global.pipeline` group.
- The fifth selector matches queries that come from BI tools which have a source matching the regular
- The next selector matches queries that come from BI tools which have a source matching the regular
expression `jdbc#(?<toolname>.*)` and have client provided tags that are a superset of `hipri`.
These are placed in a dynamically-created sub-group under the `global.adhoc` group.
The dynamic sub-groups are created based on the values of named variables `toolname` and `user`.
Expand All @@ -257,9 +269,10 @@ There are four selectors, that define which queries run in which resource group:

Together, these selectors implement the following policy:

- The user `bob` and any user belonging to user group `admin`
is an admin and can run up to 50 concurrent queries.
Queries will be run based on user-provided priority.
- The user `bob` and any user belonging to user group `admin` is an admin and can run up to
50 concurrent queries. `bob` will be treated as an admin even if they have changed their session
user to a different user (i.e. via a `SET SESSION AUTHORIZATION` statement or the
`X-Trino-User` request header). Queries will be run based on user-provided priority.

For the remaining users:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ protected List<ResourceGroupSelector> buildSelectors(ManagerSpec managerSpec)
selectors.add(new StaticSelector(
spec.getUserRegex(),
spec.getUserGroupRegex(),
spec.getOriginalUserRegex(),
spec.getAuthenticatedUserRegex(),
spec.getSourceRegex(),
spec.getClientTags(),
spec.getResourceEstimate(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public class SelectorSpec
{
private final Optional<Pattern> userRegex;
private final Optional<Pattern> userGroupRegex;
private final Optional<Pattern> originalUserRegex;
private final Optional<Pattern> authenticatedUserRegex;
private final Optional<Pattern> sourceRegex;
private final Optional<String> queryType;
private final Optional<List<String>> clientTags;
Expand All @@ -38,6 +40,8 @@ public class SelectorSpec
public SelectorSpec(
@JsonProperty("user") Optional<Pattern> userRegex,
@JsonProperty("userGroup") Optional<Pattern> userGroupRegex,
@JsonProperty("originalUser") Optional<Pattern> originalUserRegex,
@JsonProperty("authenticatedUser") Optional<Pattern> authenticatedUserRegex,
@JsonProperty("source") Optional<Pattern> sourceRegex,
@JsonProperty("queryType") Optional<String> queryType,
@JsonProperty("clientTags") Optional<List<String>> clientTags,
Expand All @@ -46,6 +50,8 @@ public SelectorSpec(
{
this.userRegex = requireNonNull(userRegex, "userRegex is null");
this.userGroupRegex = requireNonNull(userGroupRegex, "userGroupRegex is null");
this.originalUserRegex = requireNonNull(originalUserRegex, "originalUserRegex is null");
this.authenticatedUserRegex = requireNonNull(authenticatedUserRegex, "authenticatedUserRegex is null");
this.sourceRegex = requireNonNull(sourceRegex, "sourceRegex is null");
this.queryType = requireNonNull(queryType, "queryType is null");
this.clientTags = requireNonNull(clientTags, "clientTags is null");
Expand All @@ -63,6 +69,16 @@ public Optional<Pattern> getUserGroupRegex()
return userGroupRegex;
}

public Optional<Pattern> getOriginalUserRegex()
{
return originalUserRegex;
}

public Optional<Pattern> getAuthenticatedUserRegex()
{
return authenticatedUserRegex;
}

public Optional<Pattern> getSourceRegex()
{
return sourceRegex;
Expand Down Expand Up @@ -103,6 +119,10 @@ public boolean equals(Object other)
userRegex.map(Pattern::flags).equals(that.userRegex.map(Pattern::flags)) &&
userGroupRegex.map(Pattern::pattern).equals(that.userGroupRegex.map(Pattern::pattern)) &&
userGroupRegex.map(Pattern::flags).equals(that.userGroupRegex.map(Pattern::flags)) &&
originalUserRegex.map(Pattern::pattern).equals(that.originalUserRegex.map(Pattern::pattern)) &&
originalUserRegex.map(Pattern::flags).equals(that.originalUserRegex.map(Pattern::flags)) &&
authenticatedUserRegex.map(Pattern::pattern).equals(that.authenticatedUserRegex.map(Pattern::pattern)) &&
authenticatedUserRegex.map(Pattern::flags).equals(that.authenticatedUserRegex.map(Pattern::flags)) &&
sourceRegex.map(Pattern::pattern).equals(that.sourceRegex.map(Pattern::pattern))) &&
sourceRegex.map(Pattern::flags).equals(that.sourceRegex.map(Pattern::flags)) &&
queryType.equals(that.queryType) &&
Expand All @@ -118,6 +138,10 @@ public int hashCode()
userRegex.map(Pattern::flags),
userGroupRegex.map(Pattern::pattern),
userGroupRegex.map(Pattern::flags),
originalUserRegex.map(Pattern::pattern),
originalUserRegex.map(Pattern::flags),
authenticatedUserRegex.map(Pattern::pattern),
authenticatedUserRegex.map(Pattern::flags),
sourceRegex.map(Pattern::pattern),
sourceRegex.map(Pattern::flags),
queryType,
Expand All @@ -133,6 +157,10 @@ public String toString()
.add("userFlags", userRegex.map(Pattern::flags))
.add("userGroupRegex", userGroupRegex)
.add("userGroupFlags", userGroupRegex.map(Pattern::flags))
.add("originalUserRegex", originalUserRegex)
.add("originalUserFlags", originalUserRegex.map(Pattern::flags))
.add("authenticatedUserRegex", authenticatedUserRegex)
.add("authenticatedUserFlags", authenticatedUserRegex.map(Pattern::flags))
.add("sourceRegex", sourceRegex)
.add("sourceFlags", sourceRegex.map(Pattern::flags))
.add("queryType", queryType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public class StaticSelector

private final Optional<Pattern> userRegex;
private final Optional<Pattern> userGroupRegex;
private final Optional<Pattern> originalUserRegex;
private final Optional<Pattern> authenticatedUserRegex;
private final Optional<Pattern> sourceRegex;
private final Set<String> clientTags;
private final Optional<SelectorResourceEstimate> selectorResourceEstimate;
Expand All @@ -52,6 +54,8 @@ public class StaticSelector
public StaticSelector(
Optional<Pattern> userRegex,
Optional<Pattern> userGroupRegex,
Optional<Pattern> originalUserRegex,
Optional<Pattern> authenticatedUserRegex,
Optional<Pattern> sourceRegex,
Optional<List<String>> clientTags,
Optional<SelectorResourceEstimate> selectorResourceEstimate,
Expand All @@ -60,6 +64,8 @@ public StaticSelector(
{
this.userRegex = requireNonNull(userRegex, "userRegex is null");
this.userGroupRegex = requireNonNull(userGroupRegex, "userGroupRegex is null");
this.originalUserRegex = requireNonNull(originalUserRegex, "originalUserRegex is null");
this.authenticatedUserRegex = requireNonNull(authenticatedUserRegex, "authenticatedUserRegex is null");
this.sourceRegex = requireNonNull(sourceRegex, "sourceRegex is null");
requireNonNull(clientTags, "clientTags is null");
this.clientTags = ImmutableSet.copyOf(clientTags.orElse(ImmutableList.of()));
Expand All @@ -69,6 +75,8 @@ public StaticSelector(

HashSet<String> variableNames = new HashSet<>(ImmutableList.of(USER_VARIABLE, SOURCE_VARIABLE));
userRegex.ifPresent(u -> addNamedGroups(u, variableNames));
originalUserRegex.ifPresent(u -> addNamedGroups(u, variableNames));
authenticatedUserRegex.ifPresent(u -> addNamedGroups(u, variableNames));
sourceRegex.ifPresent(s -> addNamedGroups(s, variableNames));
this.variableNames = ImmutableSet.copyOf(variableNames);

Expand All @@ -81,26 +89,24 @@ public Optional<SelectionContext<ResourceGroupIdTemplate>> match(SelectionCriter
{
Map<String, String> variables = new HashMap<>();

if (userRegex.isPresent()) {
Matcher userMatcher = userRegex.get().matcher(criteria.getUser());
if (!userMatcher.matches()) {
return Optional.empty();
}

addVariableValues(userRegex.get(), criteria.getUser(), variables);
if (!addVariablesForRegexIfMatching(userRegex, criteria.getUser(), variables)) {
return Optional.empty();
}

if (userGroupRegex.isPresent() && criteria.getUserGroups().stream().noneMatch(group -> userGroupRegex.get().matcher(group).matches())) {
return Optional.empty();
}

if (sourceRegex.isPresent()) {
String source = criteria.getSource().orElse("");
if (!sourceRegex.get().matcher(source).matches()) {
return Optional.empty();
}
if (!addVariablesForRegexIfMatching(originalUserRegex, criteria.getOriginalUser(), variables)) {
return Optional.empty();
}

addVariableValues(sourceRegex.get(), source, variables);
if (!addVariablesForRegexIfMatching(authenticatedUserRegex, criteria.getAuthenticatedUser().orElse(""), variables)) {
return Optional.empty();
}

if (!addVariablesForRegexIfMatching(sourceRegex, criteria.getSource().orElse(""), variables)) {
return Optional.empty();
}

if (!clientTags.isEmpty() && !criteria.getTags().containsAll(clientTags)) {
Expand Down Expand Up @@ -137,22 +143,38 @@ private static void addNamedGroups(Pattern pattern, HashSet<String> variables)
}
}

private void addVariableValues(Pattern pattern, String candidate, Map<String, String> mapping)
/**
* @param optionalRegex The optional regex to match against the input.
* @param input The input to match against the regex.
* @param variables Variables to populate with the values from the regex.
* @return False iff the regex is present and the input does not match it, else true,
* indicating matching should continue.
*/
private boolean addVariablesForRegexIfMatching(Optional<Pattern> optionalRegex, String input, Map<String, String> variables)
{
if (optionalRegex.isEmpty()) {
return true;
}
Pattern pattern = optionalRegex.get();
if (!pattern.matcher(input).matches()) {
return false;
}

for (String key : variableNames) {
Matcher keyMatcher = pattern.matcher(candidate);
Matcher keyMatcher = pattern.matcher(input);
if (keyMatcher.find()) {
try {
String value = keyMatcher.group(key);
if (value != null) {
mapping.put(key, value);
variables.put(key, value);
}
}
catch (IllegalArgumentException _) {
// there was no capturing group with the specified name
}
}
}
return true;
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ private synchronized Map.Entry<ManagerSpec, Map<ResourceGroupIdTemplate, Resourc
new SelectorSpec(
selectorRecord.getUserRegex(),
selectorRecord.getUserGroupRegex(),
selectorRecord.getOriginalUserRegex(),
selectorRecord.getAuthenticatedUserRegex(),
selectorRecord.getSourceRegex(),
selectorRecord.getQueryType(),
selectorRecord.getClientTags(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public interface ResourceGroupsDao
@UseRowMapper(ResourceGroupSpecBuilder.Mapper.class)
List<ResourceGroupSpecBuilder> getResourceGroups(@Bind("environment") String environment);

@SqlQuery("SELECT S.resource_group_id, S.priority, S.user_regex, S.source_regex, S.query_type, S.client_tags, S.selector_resource_estimate, S.user_group_regex\n" +
@SqlQuery("SELECT S.resource_group_id, S.priority, S.user_regex, S.source_regex, S.original_user_regex, S.authenticated_user_regex, S.query_type, S.client_tags, S.selector_resource_estimate, S.user_group_regex\n" +
"FROM selectors S\n" +
"JOIN resource_groups R ON (S.resource_group_id = R.resource_group_id)\n" +
"WHERE R.environment = :environment\n" +
Expand All @@ -73,6 +73,8 @@ public interface ResourceGroupsDao
" priority BIGINT NOT NULL,\n" +
" user_regex VARCHAR(512),\n" +
" user_group_regex VARCHAR(512),\n" +
" original_user_regex VARCHAR(512),\n" +
" authenticated_user_regex VARCHAR(512),\n" +
" source_regex VARCHAR(512),\n" +
" query_type VARCHAR(512),\n" +
" client_tags VARCHAR(512),\n" +
Expand Down
Loading

0 comments on commit f3dfb32

Please sign in to comment.