Skip to content

Commit

Permalink
refactor: save one iteration of the extension list in DependencyGraph (
Browse files Browse the repository at this point in the history
…#4600)

refactor: save one extensions iteration in DependencyGraph
  • Loading branch information
ndr-brt authored Nov 6, 2024
1 parent c5dbf9d commit 858d09b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,12 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Optional.ofNullable;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;
import static java.util.stream.Collectors.toList;


/**
Expand Down Expand Up @@ -71,64 +69,65 @@ public DependencyGraph(ServiceExtensionContext context) {
*/
public List<InjectionContainer<ServiceExtension>> of(List<ServiceExtension> extensions) {
Map<Class<?>, ServiceProvider> defaultServiceProviders = new HashMap<>();
Map<ServiceExtension, List<ServiceProvider>> serviceProviders = new HashMap<>();
Map<Class<?>, List<ServiceExtension>> dependencyMap = new HashMap<>();
extensions.forEach(extension -> {
getProvidedFeatures(extension).forEach(feature -> dependencyMap.computeIfAbsent(feature, k -> new ArrayList<>()).add(extension));
// check all @Provider methods
new ProviderMethodScanner(extension).allProviders()
.peek(providerMethod -> {
var serviceProvider = new ServiceProvider(providerMethod, extension);
if (providerMethod.isDefault()) {
defaultServiceProviders.put(providerMethod.getReturnType(), serviceProvider);
} else {
serviceProviders.computeIfAbsent(extension, k -> new ArrayList<>()).add(serviceProvider);
}
})
.map(ProviderMethod::getReturnType)
.forEach(feature -> dependencyMap.computeIfAbsent(feature, k -> new ArrayList<>()).add(extension));
});
Map<Class<?>, List<InjectionContainer<ServiceExtension>>> dependencyMap = new HashMap<>();
var injectionContainers = extensions.stream()
.map(it -> new InjectionContainer<>(it, new HashSet<>(), new ArrayList<>()))
.peek(injectionContainer -> {
getProvidedFeatures(injectionContainer.getInjectionTarget())
.forEach(feature -> dependencyMap.computeIfAbsent(feature, k -> new ArrayList<>()).add(injectionContainer));

// check all @Provider methods
new ProviderMethodScanner(injectionContainer.getInjectionTarget()).allProviders()
.peek(providerMethod -> {
var serviceProvider = new ServiceProvider(providerMethod, injectionContainer.getInjectionTarget());
if (providerMethod.isDefault()) {
defaultServiceProviders.put(providerMethod.getReturnType(), serviceProvider);
} else {
injectionContainer.getServiceProviders().add(serviceProvider);
}
})
.map(ProviderMethod::getReturnType)
.forEach(feature -> dependencyMap.computeIfAbsent(feature, k -> new ArrayList<>()).add(injectionContainer));
})
.collect(toList());

var sort = new TopologicalSort<ServiceExtension>();
var sort = new TopologicalSort<InjectionContainer<ServiceExtension>>();

// check if all injected fields are satisfied, collect missing ones and throw exception otherwise
var unsatisfiedInjectionPoints = new ArrayList<InjectionPoint<ServiceExtension>>();
var unsatisfiedRequirements = new ArrayList<String>();

var injectionPoints = extensions.stream()
.collect(toMap(identity(), ext -> {

//check that all the @Required features are there
getRequiredFeatures(ext.getClass()).forEach(serviceClass -> {
var dependencies = dependencyMap.get(serviceClass);
if (dependencies == null) {
unsatisfiedRequirements.add(serviceClass.getName());
injectionContainers.forEach(container -> {
//check that all the @Required features are there
getRequiredFeatures(container.getInjectionTarget().getClass()).forEach(serviceClass -> {
var dependencies = dependencyMap.get(serviceClass);
if (dependencies == null) {
unsatisfiedRequirements.add(serviceClass.getName());
} else {
dependencies.forEach(dependency -> sort.addDependency(container, dependency));
}
});

injectionPointScanner.getInjectionPoints(container.getInjectionTarget())
.peek(injectionPoint -> {
var maybeProviders = Optional.of(injectionPoint.getType()).map(dependencyMap::get);

if (maybeProviders.isPresent() || context.hasService(injectionPoint.getType())) {
maybeProviders.ifPresent(l -> l.stream()
.filter(d -> !Objects.equals(d, container)) // remove dependencies onto oneself
.forEach(provider -> sort.addDependency(container, provider)));
} else {
dependencies.forEach(dependency -> sort.addDependency(ext, dependency));
if (injectionPoint.isRequired()) {
unsatisfiedInjectionPoints.add(injectionPoint);
}
}
});

return injectionPointScanner.getInjectionPoints(ext)
.peek(injectionPoint -> {
if (!canResolve(dependencyMap, injectionPoint.getType())) {
if (injectionPoint.isRequired()) {
unsatisfiedInjectionPoints.add(injectionPoint);
}
} else {
// get() would return null, if the feature is already in the context's service list
ofNullable(dependencyMap.get(injectionPoint.getType()))
.ifPresent(l -> l.stream()
.filter(d -> !Objects.equals(d, ext)) // remove dependencies onto oneself
.forEach(provider -> sort.addDependency(ext, provider)));
}

var defaultServiceProvider = defaultServiceProviders.get(injectionPoint.getType());
if (defaultServiceProvider != null) {
injectionPoint.setDefaultServiceProvider(defaultServiceProvider);
}
})
.collect(toSet());
}));
var defaultServiceProvider = defaultServiceProviders.get(injectionPoint.getType());
if (defaultServiceProvider != null) {
injectionPoint.setDefaultServiceProvider(defaultServiceProvider);
}
})
.forEach(injectionPoint -> container.getInjectionPoints().add(injectionPoint));
});

if (!unsatisfiedInjectionPoints.isEmpty()) {
var message = "The following injected fields were not provided:\n";
Expand All @@ -141,21 +140,9 @@ public List<InjectionContainer<ServiceExtension>> of(List<ServiceExtension> exte
throw new EdcException(message);
}

sort.sort(extensions);

return extensions.stream()
.map(key -> new InjectionContainer<>(key, injectionPoints.get(key), serviceProviders.get(key)))
.toList();
}
sort.sort(injectionContainers);

private boolean canResolve(Map<Class<?>, List<ServiceExtension>> dependencyMap, Class<?> serviceClass) {
var providers = dependencyMap.get(serviceClass);
if (providers != null) {
return true;
} else {
// attempt to interpret the feature name as class name, instantiate it and see if the context has that service
return context.hasService(serviceClass);
}
return injectionContainers;
}

private Stream<Class<?>> getRequiredFeatures(Class<?> clazz) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,7 @@ public ProviderMethodScanner(ServiceExtension target) {
* Returns all methods annotated with {@link Provider}.
*/
public Stream<ProviderMethod> allProviders() {
return getProviderMethods(target);
}

/**
* Returns all methods annotated with {@link Provider}, where {@link Provider#isDefault()} is {@code false}
*/
public Stream<ProviderMethod> nonDefaultProviders() {
return getProviderMethods(target).filter(pm -> !pm.isDefault());
}

/**
* Returns all methods annotated with {@link Provider}, where {@link Provider#isDefault()} is {@code true}
*/
public Stream<ProviderMethod> defaultProviders() {
return getProviderMethods(target).filter(ProviderMethod::isDefault);
}

private Stream<ProviderMethod> getProviderMethods(Object extension) {
return Arrays.stream(extension.getClass().getDeclaredMethods())
return Arrays.stream(target.getClass().getDeclaredMethods())
.filter(m -> m.getAnnotation(Provider.class) != null)
.map(ProviderMethod::new)
.peek(method -> {
Expand All @@ -66,4 +48,5 @@ private Stream<ProviderMethod> getProviderMethods(Object extension) {
}
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,9 @@ void setup() {
}

@Test
void allProviders() {
assertThat(scanner.allProviders()).hasSize(3);
}

@Test
void providerMethods() {
assertThat(scanner
.nonDefaultProviders())
.hasSize(2);
}

@Test
void defaultProviderMethods() throws NoSuchMethodException {
assertThat(scanner
.defaultProviders())
void allProviders() throws NoSuchMethodException {
assertThat(scanner.allProviders()).hasSize(3)
.filteredOn(ProviderMethod::isDefault)
.hasSize(1)
.extracting(ProviderMethod::getMethod)
.containsOnly(TestExtension.class.getMethod("providerDefault"));
Expand All @@ -56,16 +44,14 @@ void defaultProviderMethods() throws NoSuchMethodException {
@Test
void verifyInvalidReturnType() {
var scanner = new ProviderMethodScanner(new InvalidTestExtension());
assertThatThrownBy(() -> scanner.nonDefaultProviders().toList()).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(() -> scanner.defaultProviders().toList()).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(() -> scanner.allProviders().toList()).isInstanceOf(EdcInjectionException.class);
}

@Test
void verifyInvalidVisibility() {
var scanner = new ProviderMethodScanner(new InvalidTestExtension2());

assertThatThrownBy(() -> scanner.nonDefaultProviders().toList()).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(() -> scanner.defaultProviders().toList()).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(() -> scanner.allProviders().toList()).isInstanceOf(EdcInjectionException.class);
}

private static class TestExtension implements ServiceExtension {
Expand Down

0 comments on commit 858d09b

Please sign in to comment.