Skip to content

Commit

Permalink
Merge pull request #136 from squareup/bradfol/abstract-validation
Browse files Browse the repository at this point in the history
Validate that abstract registrations are fulfilled
  • Loading branch information
bradfol authored Apr 8, 2024
2 parents 7a900ae + 3153a5e commit e187b22
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 24 deletions.
4 changes: 2 additions & 2 deletions Sources/Knit/Module/AbstractAssembly.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@

import Foundation

/// An AbstractAssembly can only contain abstract registrations and should not be initialised.
public protocol AbstractAssembly: ModuleAssembly { }
/// An AbstractAssembly can only contain abstract registrations.
public protocol AbstractAssembly: AutoInitModuleAssembly { }
7 changes: 6 additions & 1 deletion Sources/Knit/Module/DependencyBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@ final class DependencyBuilder {
}
let overrideTypes = allModuleTypes.filter { !$0.implements.isEmpty }

// Collect AbstractAssemblies as they should all be instantiated and added to the container.
// This needs to happen before the filter below as they are all expected to be implemented by other assemblies
// and will therefore be filtered out.
let allAbstractModules = allModuleTypes.filter { $0 is any AbstractAssembly.Type }

// Filter out any types where an override was found
allModuleTypes = allModuleTypes.filter { moduleType in
return !overrideTypes.contains(where: {$0.doesImplement(type: moduleType)})
}

// Instantiate all types
for type in allModuleTypes {
for type in allModuleTypes + allAbstractModules {
guard !self.isRegisteredInParent(type) else {
continue
}
Expand Down
8 changes: 4 additions & 4 deletions Sources/Knit/Module/ModuleAssembly.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ public protocol ModuleAssembly: Assembly {

static var dependencies: [any ModuleAssembly.Type] { get }

/// A ModuleAssembly can implement any number of other modules
/// If this module implements another it is expected to provide all registrations that the base assembly supplies
/// A common case is for a fake assembly that registers fake services from matching those from the original module
/// The override is generally expected to live in a separate module so it can be imported just for tests
/// A ModuleAssembly can implement any number of other modules' assemblies.
/// If this module implements another it is expected to provide all registrations that the implemented assemblies supply.
/// A common case is for an "implementation" assembly to fulfill all the abstract registrations from an AbstractAssembly.
/// Similarly, another common case is a fake assembly that registers fake services matching those from the original module.
static var implements: [any ModuleAssembly.Type] { get }

/// Filter the list of dependencies down to those which match the scope of this assembly
Expand Down
22 changes: 15 additions & 7 deletions Sources/KnitCodeGen/AssemblyParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {

private(set) var moduleName: String?

private(set) var assemblyType: String?
private(set) var assemblyType: Configuration.AssemblyType?

private(set) var directives: KnitDirectives = .empty

Expand Down Expand Up @@ -97,10 +97,18 @@ class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {
assemblyName = names?.0
moduleName = node.namesForAssembly?.1

let inheritedTypes = inheritance?.inheritedTypes.map {
$0.type.description.trimmingCharacters(in: .whitespaces)
let inheritedTypes = inheritance?.inheritedTypes.compactMap {
if let identifier = $0.type.as(IdentifierTypeSyntax.self) {
return identifier.name.text
} else if let member = $0.type.as(MemberTypeSyntax.self) {
return member.name.text
} else {
return nil
}
}
self.assemblyType = inheritedTypes?.first(where: { $0.hasSuffix("Assembly")})
self.assemblyType = inheritedTypes?
.first { $0.hasSuffix(Configuration.AssemblyType.baseAssembly.rawValue) }
.flatMap { Configuration.AssemblyType(rawValue: $0) }
classDeclVisitor = ClassDeclVisitor(viewMode: .fixedUp, directives: directives, assemblyType: assemblyType)
classDeclVisitor?.walk(node)
return .skipChildren
Expand All @@ -111,7 +119,7 @@ class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {
private class ClassDeclVisitor: SyntaxVisitor, IfConfigVisitor {

private let directives: KnitDirectives
private let assemblyType: String?
private let assemblyType: Configuration.AssemblyType?

/// The registrations that were found in the tree.
private(set) var registrations = [Registration]()
Expand All @@ -128,7 +136,7 @@ private class ClassDeclVisitor: SyntaxVisitor, IfConfigVisitor {
/// For any registrations parsed, this #if condition should be applied when it is used
var currentIfConfigCondition: IfConfigVisitorCondition?

init(viewMode: SyntaxTreeViewMode, directives: KnitDirectives, assemblyType: String?) {
init(viewMode: SyntaxTreeViewMode, directives: KnitDirectives, assemblyType: Configuration.AssemblyType?) {
self.directives = directives
self.assemblyType = assemblyType
super.init(viewMode: viewMode)
Expand All @@ -142,7 +150,7 @@ private class ClassDeclVisitor: SyntaxVisitor, IfConfigVisitor {
do {
var (registrations, registrationsIntoCollections) = try node.getRegistrations(
defaultDirectives: directives,
abstractOnly: assemblyType == "AbstractAssembly"
abstractOnly: assemblyType == .abstractAssembly
)
registrations = registrations.map { registration in
var mutable = registration
Expand Down
18 changes: 11 additions & 7 deletions Sources/KnitCodeGen/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@ public struct Configuration: Encodable {
public var assemblyName: String
public let moduleName: String
public var directives: KnitDirectives
public var assemblyType: String

public enum AssemblyType: String, Encodable {
/// `Swinject.Assembly`
case baseAssembly = "Assembly"
case moduleAssembly = "ModuleAssembly"
case autoInitAssembly = "AutoInitModuleAssembly"
case abstractAssembly = "AbstractAssembly"
}
public var assemblyType: AssemblyType

public var registrations: [Registration]
public var registrationsIntoCollections: [RegistrationIntoCollection]
Expand All @@ -24,7 +32,7 @@ public struct Configuration: Encodable {
assemblyName: String,
moduleName: String,
directives: KnitDirectives = .init(),
assemblyType: String = "Assembly",
assemblyType: AssemblyType = .baseAssembly,
registrations: [Registration],
registrationsIntoCollections: [RegistrationIntoCollection],
imports: [ModuleImport] = [],
Expand Down Expand Up @@ -67,10 +75,6 @@ public struct Configuration: Encodable {
return String(assemblyName.dropLast(8))
}

var isAbstract: Bool {
return assemblyType == "AbstractAssembly"
}

}

public extension Configuration {
Expand All @@ -84,7 +88,7 @@ public extension Configuration {
}

func makeUnitTestSourceFile() throws -> SourceFileSyntax {
guard !self.isAbstract else {
guard self.assemblyType != .abstractAssembly else {
// Abstract assemblies don't need unit tests but we should still generate an empty test case
// otherwise unit test jobs will fail if they don't find any test cases in a test target
return .init(stringLiteral: """
Expand Down
64 changes: 62 additions & 2 deletions Tests/KnitCodeGenTests/AssemblyParsingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ final class AssemblyParsingTests: XCTestCase {
]
)
XCTAssertEqual(config.registrations.count, 0, "No registrations")
XCTAssertEqual(config.assemblyType, "ModuleAssembly")
XCTAssertEqual(config.assemblyType, .moduleAssembly)
XCTAssertEqual(config.assemblyShortName, "FooTest")
}

Expand Down Expand Up @@ -110,7 +110,7 @@ final class AssemblyParsingTests: XCTestCase {

let config = try assertParsesSyntaxTree(sourceFile)
XCTAssertEqual(config.assemblyName, "FooTestAssembly")
XCTAssertEqual(config.assemblyType, "Assembly")
XCTAssertEqual(config.assemblyType, .baseAssembly)
}

func testAssemblyStructModuleName() throws {
Expand Down Expand Up @@ -144,6 +144,66 @@ final class AssemblyParsingTests: XCTestCase {
)
}

func testAssemblyTypes() throws {
let sourceFilesAndExpected: [(SourceFileSyntax, Configuration.AssemblyType)] = [
(
"""
class TestAssembly: Assembly {
func assemble(container: Container) {}
}
""",
.baseAssembly
),

(
"""
class TestAssembly: Swinject.Assembly {
func assemble(container: Container) {}
}
""",
.baseAssembly
),

(
"""
class TestAssembly: ModuleAssembly {
func assemble(container: Container) {}
}
""",
.moduleAssembly
),

(
"""
class TestAssembly: AutoInitModuleAssembly {
func assemble(container: Container) {}
}
""",
.autoInitAssembly
),

(
"""
class TestAssembly: AbstractAssembly {
func assemble(container: Container) {}
}
""",
.abstractAssembly
),
]

try sourceFilesAndExpected.enumerated().forEach { (index, tuple) in
let (sourceFile, expectedType) = tuple
let config = try assertParsesSyntaxTree(sourceFile)
XCTAssertEqual(
config.assemblyType,
expectedType,
"Failed for tuple at index \(index)"
)
}

}

func testKnitDirectives() throws {
let sourceFile: SourceFileSyntax = """
// @knit public getter-named
Expand Down
2 changes: 1 addition & 1 deletion Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ final class UnitTestSourceFileTests: XCTestCase {
let configuration = Configuration(
assemblyName: "ModuleAbstractAssembly",
moduleName: "Module",
assemblyType: "AbstractAssembly",
assemblyType: .abstractAssembly,
registrations: [
.init(service: "ServiceA", accessLevel: .internal, arguments: [.init(type: "String")]),
],
Expand Down
62 changes: 62 additions & 0 deletions Tests/KnitTests/ModuleAssemblerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ final class ModuleAssemblerTests: XCTestCase {
XCTAssertNotNil(child.resolver.resolve(Service1.self))
XCTAssertNil(parent.resolver.resolve(Service3.self))
}

func test_abstractAssemblyValidation() {
XCTAssertThrowsError(
try ModuleAssembler(
_modules: [ Assembly4() ]
),
"Should throw an error for missing concrete registration to fulfill abstract registration",
{ error in
guard let abstractRegistrationErrors = error as? Container.AbstractRegistrationErrors else {
XCTFail("Incorrect error type \(error)")
return
}
XCTAssertEqual(abstractRegistrationErrors.errors.count, 1)
XCTAssertEqual(abstractRegistrationErrors.errors.first?.serviceType, "Assembly5Protocol")
}
)
}

}

// Assembly1 depends on Assembly2 and registers Service1
Expand Down Expand Up @@ -92,3 +110,47 @@ private struct Service1 {

private struct Service2 {}
private struct Service3 {}

// MARK: - AbstractAssembly

private struct Assembly4: AutoInitModuleAssembly {

func assemble(container: Swinject.Container) {
// None
}

static var dependencies: [any ModuleAssembly.Type] {
[
AbstractAssembly5.self,
Assembly5.self,
]
}
}

private struct AbstractAssembly5: AbstractAssembly {

static var dependencies: [any ModuleAssembly.Type] {
[]
}

func assemble(container: Swinject.Container) {
container.registerAbstract(Assembly5Protocol.self)
}

}

private protocol Assembly5Protocol { }

private struct Assembly5: AutoInitModuleAssembly {

static var dependencies: [any ModuleAssembly.Type] { [] }

func assemble(container: Swinject.Container) {
// Missing a concrete registration for `Assembly5Protocol`
}

static var implements: [any ModuleAssembly.Type] {
[AbstractAssembly5.self]
}

}

0 comments on commit e187b22

Please sign in to comment.