Skip to content

Commit

Permalink
token exchange implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
redmitry committed Jul 24, 2024
1 parent d526e00 commit 22a99a3
Show file tree
Hide file tree
Showing 9 changed files with 506 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-docker-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:

env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-snapshot
IMAGE_NAME: ${{ github.repository }}-oidc

jobs:
build:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/maven-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Beacon Network Maven Build
on:
push:
branches:
- dev
- oidc

jobs:
build:
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>es.bsc.inb.ga4gh</groupId>
<artifactId>beacon-network-v2</artifactId>
<version>0.0.12-SNAPSHOT</version>
<version>0.0.12-oidc</version>
<packaging>war</packaging>

<description>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

package es.bsc.inb.ga4gh.beacon.network.config;

import es.bsc.inb.ga4gh.beacon.network.model.OauthProtectedResource;
import es.bsc.inb.ga4gh.beacon.framework.model.v200.configuration.ServiceConfiguration;
import es.bsc.inb.ga4gh.beacon.framework.model.v200.responses.BeaconEntryTypesResponse;
import es.bsc.inb.ga4gh.beacon.framework.model.v200.responses.BeaconFilteringTermsResponse;
Expand Down Expand Up @@ -90,6 +91,12 @@ public class NetworkConfiguration {
*/
private Map<String, String> endpoints;

/**
* Map of beacons' authentication servers
* where key is a beacon's endpoint (e.g. 'https://beacons.bsc.es/beacon/v2.0.0/')
*/
private Map<String, OauthProtectedResource> protected_resources;

private Map<BeaconMetadataSchema, Map<String, ? extends BeaconInformationalResponse>> metadata;

private Map<String, List<BeaconValidationMessage>> errors;
Expand All @@ -107,6 +114,7 @@ public class NetworkConfiguration {
@PostConstruct
public void init() {
endpoints = new ConcurrentHashMap();
protected_resources = new ConcurrentHashMap();
metadata = new ConcurrentHashMap();
errors = new ConcurrentHashMap();
hashes = new ConcurrentHashMap();
Expand Down Expand Up @@ -214,6 +222,13 @@ private void updateBeacon(String endpoint) {
}
errors.put(endpoint, err);
}

final OauthProtectedResource resource = loadOauthProtectedResource(endpoint);
if (resource == null) {
protected_resources.remove(endpoint);
} else {
protected_resources.put(endpoint, resource);
}
}
}

Expand Down Expand Up @@ -323,6 +338,10 @@ public Map<String, String> getEndpoints() {
return endpoints;
}

public Map<String, OauthProtectedResource> getProtectedResources() {
return protected_resources;
}

/**
* Get the metadata JSON Schema parsing errors.
*
Expand Down Expand Up @@ -351,6 +370,28 @@ private String getBeaconId(String endpoint) {
return null;
}

/**
* Load oauth-protected-resource metadata, if provided.
*
* @param endpoint
* @return
*/
private OauthProtectedResource loadOauthProtectedResource(String endpoint) {
try {
final List<BeaconValidationMessage> err = new ArrayList();
final String json = validator.loadMetadata(endpoint + "/.well-known/oauth-protected-resource", new ValidationErrorsCollector(err));
if (err.isEmpty()) {
return validator.parseMetadata(json, OauthProtectedResource.class);
}
} catch(Exception ex) {
Logger.getLogger(NetworkConfiguration.class.getName()).log(
Level.SEVERE, "error loading from {0} {1}",
new Object[]{endpoint, ex.getMessage()});
}
return null;

}

/**
* Load filtering terms by the endpoint.
*
Expand All @@ -374,5 +415,5 @@ public BeaconFilteringTermsResponse loadFilteringTerms(String endpoint) {
new Object[]{endpoint, ex.getMessage()});
}
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ public class BeaconNetworkAggregator {
@Inject
private BeaconNetworkRequestAnalyzer requestAnalyzer;

@Inject
private BeaconNetworkTokenExchanger tokenExchanger;

@Inject
private BeaconEndpointsMatcher matcher;

Expand Down Expand Up @@ -127,7 +130,7 @@ public Response aggregate(HttpServletRequest request) {
}

final UUID xid = UUID.randomUUID();

final List<CompletableFuture<HttpResponse>> invocations = new ArrayList();

Map<String, Map.Entry<String, String>> matched_endpoints = matcher.match(request);
Expand All @@ -150,7 +153,7 @@ public Response aggregate(HttpServletRequest request) {
} else {
final String err_message =
String.format("request timeout '%s'", processor.template);

log(req, 408, err_message);
}
return res;
Expand Down Expand Up @@ -206,7 +209,7 @@ private Builder getInvocation(String endpoint, HttpServletRequest request) {

final Enumeration<String> authorization = request.getHeaders(HttpHeaders.AUTHORIZATION);
if (authorization != null && authorization.hasMoreElements()) {
Collections.list(authorization).stream()
tokenExchanger.exchange(endpoint, Collections.list(authorization)).stream()
.forEach(h -> builder.header(HttpHeaders.AUTHORIZATION, h));
}

Expand All @@ -220,7 +223,7 @@ private void log(HttpResponse<AbstractBeaconResponse> response) {
final BeaconError err = error.getError();
if (err != null) {
message = err.getErrorMessage();
}
}
}

log(response.request(), response.statusCode(), message);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
/**
* *****************************************************************************
* Copyright (C) 2024 ELIXIR ES, Spanish National Bioinformatics Institute (INB)
* and Barcelona Supercomputing Center (BSC)
*
* Modifications to the initial code base are copyright of their respective
* authors, or their employers as appropriate.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
* MA 02110-1301 USA
* *****************************************************************************
*/

package es.bsc.inb.ga4gh.beacon.network.engine;

import es.bsc.inb.ga4gh.beacon.network.config.NetworkConfiguration;
import es.bsc.inb.ga4gh.beacon.network.model.AccessTokenResponse;
import es.bsc.inb.ga4gh.beacon.network.model.OauthProtectedResource;
import es.bsc.inb.ga4gh.beacon.network.model.OidcConfigurationProvider;
import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.json.Json;
import jakarta.json.JsonArray;
import jakarta.json.JsonObject;
import jakarta.json.JsonString;
import jakarta.json.bind.Jsonb;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.UriBuilder;
import java.io.ByteArrayInputStream;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpRequest.BodyPublishers;
import java.net.http.HttpRequest.Builder;
import java.net.http.HttpResponse;
import java.net.http.HttpResponse.BodyHandlers;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
* @author Dmitry Repchevsky
*/

@ApplicationScoped
public class BeaconNetworkTokenExchanger {

@Inject
private NetworkConfiguration network_configuration;

@Inject
private Jsonb jsonb;

private HttpClient http_client;
private Map<String, OidcConfigurationProvider> configuration_providers;

@PostConstruct
public void init() {
http_client = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_2)
.followRedirects(HttpClient.Redirect.ALWAYS)
.connectTimeout(Duration.ofSeconds(30))
.build();

configuration_providers = new ConcurrentHashMap();
}

public List<String> exchange(String endpoint, List<String> headers) {
return headers.stream().map(h -> exchangeHeader(endpoint, h)).toList();
}

private String exchangeHeader(String endpoint, String header) {
if (header != null && header.startsWith("Bearer ")) {
final String token = exchangeToken(endpoint, header.substring(7));
if (token != null) {
return "Bearer " + token;
}
}
return header;
}

private String exchangeToken(String endpoint, String token) {
final OauthProtectedResource resource = network_configuration.getProtectedResources().get(endpoint);
if (resource != null) {
final String client_id = resource.getClientId();
final List<String> authorization_servers = resource.getAuthorizationServers();
if (client_id != null && authorization_servers != null) {
final String[] token_parts = token.split("\\.");
if (token_parts.length == 3) {
final JsonObject payload = decode(token_parts[1]);
if (payload != null) {
final String issuer = payload.getString("iss", null);

final List<String> audiences;
final String audience = payload.getString("aud", null);
if (audience != null) {
audiences = List.of(audience);
} else {
final JsonArray aud = payload.getJsonArray("aud");
audiences = aud != null
? aud.getValuesAs(JsonString::getString)
: null;
}
if (!authorization_servers.contains(issuer) ||
(audiences != null && !audiences.contains(client_id))) {
// need exchange
final List<OidcConfigurationProvider> providers = getWellKnownProviders(authorization_servers);
if (providers != null) {
for (OidcConfigurationProvider provider : providers) {
final Builder builder = createTokenExchangeRequest(provider, client_id, token);
try {
final HttpResponse<String> response = http_client.send(builder.build(),
BodyHandlers.ofString(StandardCharsets.UTF_8));
if (response != null && response.statusCode() < 300) {
final String body = response.body();
if (body != null) {
final AccessTokenResponse atResponse =
jsonb.fromJson(body, AccessTokenResponse.class);

final String accessToken = atResponse.getAccessToken();
if (accessToken != null) {
return accessToken;
}
}
}
} catch (Exception ex) {
Logger.getLogger(BeaconNetworkTokenExchanger.class.getName()).log(
Level.INFO, ex.getMessage());
}
}
}
}
}
}
}
}
return null;
}

private List<OidcConfigurationProvider> getWellKnownProviders(List<String> authorization_servers) {

final List<OidcConfigurationProvider> providers = new ArrayList();

final List<CompletableFuture<HttpResponse<String>>> invocations = new ArrayList();
for (String authorization_server : authorization_servers) {
final OidcConfigurationProvider provider = configuration_providers.get(authorization_server);
if (provider != null) {
providers.add(provider);
} else {
final Builder builder = createWellKnownProviderRequest(authorization_server);
final CompletableFuture<HttpResponse<String>> future =
http_client.sendAsync(builder.build(), BodyHandlers.ofString(StandardCharsets.UTF_8));
invocations.add(future);
}
}
for (CompletableFuture<HttpResponse<String>> invocation : invocations) {
try {
final HttpResponse<String> response = invocation.get(30, TimeUnit.SECONDS);
if (response != null && response.statusCode() < 300) {
final String body = response.body();
if (body != null) {
final OidcConfigurationProvider provider =
jsonb.fromJson(body, OidcConfigurationProvider.class);
providers.add(provider);
}
}
} catch (Exception ex) {
Logger.getLogger(BeaconNetworkTokenExchanger.class.getName()).log(
Level.INFO, ex.getMessage());
}
}
return providers;
}

private Builder createWellKnownProviderRequest(String authorization_server) {
return HttpRequest.newBuilder(UriBuilder.fromUri(authorization_server)
.path(".well-known/openid-configuration").build())
.header(HttpHeaders.USER_AGENT, "BN/2.0.0")
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON);
}

private Builder createTokenExchangeRequest(OidcConfigurationProvider provider,
String client_id, String token) {

final StringBuilder data = new StringBuilder();

data.append("client_id").append('=').append(client_id)
.append("&subject_token").append('=').append(token)
.append("&grant_type").append('=')
.append(URLEncoder.encode("urn:ietf:params:oauth:grant-type:token-exchange", StandardCharsets.UTF_8))
.append("&subject_token_type").append('=')
.append(URLEncoder.encode("urn:ietf:params:oauth:token-type:jwt", StandardCharsets.UTF_8))
.append("&requested_token_type").append('=')
.append(URLEncoder.encode("urn:ietf:params:oauth:token-type:access_token", StandardCharsets.UTF_8));

return HttpRequest.newBuilder(UriBuilder.fromUri(provider.getTokenEndpoint())
.build())
.header(HttpHeaders.USER_AGENT, "BN/2.0.0")
.header(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON)
.POST(BodyPublishers.ofString(data.toString(), StandardCharsets.UTF_8));
}

private JsonObject decode(String base64) {
final Base64.Decoder decoder = Base64.getDecoder();
try {
final byte[] b = decoder.decode(base64);
return Json.createReader(new ByteArrayInputStream(b)).readObject();
} catch (Exception ex) {
return null;
}
}
}
Loading

0 comments on commit 22a99a3

Please sign in to comment.