From 8c24c71cc527fdf5c817309691ac731351cbd750 Mon Sep 17 00:00:00 2001 From: Alex Skorulis Date: Mon, 23 Dec 2024 14:52:47 +1100 Subject: [PATCH] Experiment to improve performance of the dependency builder --- Sources/Knit/Module/DependencyBuilder.swift | 11 ++-- Sources/Knit/Module/ModuleAssembly.swift | 32 ++++++++++ .../KnitCodeGen/TypeSafetySourceFile.swift | 40 ++++++++++++ .../ConfigurationSetTests.swift | 24 ++++++++ .../TypeSafetySourceFileTests.swift | 61 +++++++++++++++++++ 5 files changed, 163 insertions(+), 5 deletions(-) diff --git a/Sources/Knit/Module/DependencyBuilder.swift b/Sources/Knit/Module/DependencyBuilder.swift index 205da18..6b3c887 100644 --- a/Sources/Knit/Module/DependencyBuilder.swift +++ b/Sources/Knit/Module/DependencyBuilder.swift @@ -58,14 +58,14 @@ final class DependencyBuilder { return existingType } if let overrideType = try defaultOverride(moduleType, fromInput: inputModule != nil), - let autoInit = overrideType as? any AutoInitModuleAssembly.Type { - return autoInit.init() + let created = overrideType._autoInstantiate() { + return created } if let inputModule { return inputModule } - if let autoInit = moduleType as? any AutoInitModuleAssembly.Type { - return autoInit.init() + if let created = moduleType._autoInstantiate() { + return created } throw DependencyBuilderError.moduleNotProvided(moduleType, dependencyTree.sourcePathString(moduleType: moduleType)) @@ -230,7 +230,8 @@ internal extension DependencyBuilder { // Collect AbstractAssemblies as they should all be instantiated and added to the container. // This needs to happen before the filter below as they are all expected to be implemented by other assemblies // and will therefore be filtered out. - if ref.type is any AbstractAssembly.Type { + + if ref.type._assemblyFlags.contains(.abstract) { return true } diff --git a/Sources/Knit/Module/ModuleAssembly.swift b/Sources/Knit/Module/ModuleAssembly.swift index ad128b7..a37eb1b 100644 --- a/Sources/Knit/Module/ModuleAssembly.swift +++ b/Sources/Knit/Module/ModuleAssembly.swift @@ -24,6 +24,11 @@ public protocol ModuleAssembly { /// This can be overridden in apps with custom Resolver hierarchies static func scoped(_ dependencies: [any ModuleAssembly.Type]) -> [any ModuleAssembly.Type] + /// Hints about this assembly using by DependencyBuilder. Designed for internal use + static var _assemblyFlags: [ModuleAssemblyFlags] { get } + + /// Creates an instance of this assembly if the assembly conforms to AutoInitModuleAssembly + static func _autoInstantiate() -> (any ModuleAssembly)? } public extension ModuleAssembly { @@ -40,6 +45,24 @@ public extension ModuleAssembly { return self.resolverType == $0.resolverType } } + + static var _assemblyFlags: [ModuleAssemblyFlags] { + var result: [ModuleAssemblyFlags] = [] + if self is any AutoInitModuleAssembly.Type { + result.append(.autoInit) + } + if self is any AbstractAssembly.Type { + result.append(.abstract) + } + return result + } + + static func _autoInstantiate() -> (any ModuleAssembly)? { + if let autoInit = self as? any AutoInitModuleAssembly.Type { + return autoInit.init() + } + return nil + } } /// A ModuleAssembly that can be initialised without any parameters @@ -97,3 +120,12 @@ public struct OverrideBehavior { return NSClassFromString("XCTestCase") != nil } } + +public struct ModuleAssemblyFlags: OptionSet { + public let rawValue: UInt + + public init(rawValue: UInt) { self.rawValue = rawValue } + + public static let autoInit = ModuleAssemblyFlags(rawValue: 1 << 0) + public static let abstract = ModuleAssemblyFlags(rawValue: 1 << 1) +} diff --git a/Sources/KnitCodeGen/TypeSafetySourceFile.swift b/Sources/KnitCodeGen/TypeSafetySourceFile.swift index 8efdc95..1070655 100644 --- a/Sources/KnitCodeGen/TypeSafetySourceFile.swift +++ b/Sources/KnitCodeGen/TypeSafetySourceFile.swift @@ -46,6 +46,7 @@ public enum TypeSafetySourceFile { if let defaultOverrides = try makeDefaultOverrideExtensions(config: config) { defaultOverrides } + try makePerformanceExtension(config: config) } } @@ -160,6 +161,45 @@ public enum TypeSafetySourceFile { ) } + private static func makePerformanceExtension(config: Configuration) throws -> ExtensionDeclSyntax { + let accessorBlock: AccessorBlockSyntax + let isAutoInit: Bool + if config.assemblyType == .abstractAssembly { + accessorBlock = AccessorBlockSyntax( + accessors: .getter(.init(stringLiteral: "[.autoInit, .abstract]")) + ) + isAutoInit = true + } else if config.assemblyType == .autoInitAssembly || config.assemblyType == .fakeAssembly { + accessorBlock = AccessorBlockSyntax( + accessors: .getter(.init(stringLiteral: "[.autoInit]")) + ) + isAutoInit = true + } + else { + accessorBlock = AccessorBlockSyntax( + accessors: .getter(.init(stringLiteral: "[]")) + ) + isAutoInit = false + } + + return try ExtensionDeclSyntax( + extendedType: TypeSyntax(stringLiteral: config.assemblyName), + memberBlockBuilder: { + VariableDeclSyntax.makeVar( + keywords: [.public, .static], + name: "_assemblyFlags", + type: "[ModuleAssemblyFlags]", + accessorBlock: accessorBlock + ) + if isAutoInit { + try FunctionDeclSyntax("public static func _autoInstantiate() -> (any ModuleAssembly)? { \(raw: config.assemblyName)() }") + } else { + try FunctionDeclSyntax("public static func _autoInstantiate() -> (any ModuleAssembly)? { nil }") + } + } + ) + } + } extension Registration { diff --git a/Tests/KnitCodeGenTests/ConfigurationSetTests.swift b/Tests/KnitCodeGenTests/ConfigurationSetTests.swift index d408f55..0c3fb97 100644 --- a/Tests/KnitCodeGenTests/ConfigurationSetTests.swift +++ b/Tests/KnitCodeGenTests/ConfigurationSetTests.swift @@ -32,6 +32,14 @@ final class ConfigurationSetTests: XCTestCase { knitUnwrap(resolve(Service1.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } + extension Module1Assembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + nil + } + } /// Generated from ``Module2Assembly`` extension Resolver { public func callAsFunction(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> Service2 { @@ -41,12 +49,28 @@ final class ConfigurationSetTests: XCTestCase { knitUnwrap(resolve(ArgumentService.self, argument: string), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } + extension Module2Assembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + nil + } + } /// Generated from ``Module3Assembly`` extension Resolver { public func service3(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> Service3 { knitUnwrap(resolve(Service3.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } + extension Module3Assembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + nil + } + } """ ) } diff --git a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift index 139826f..4f92df0 100644 --- a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift +++ b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift @@ -62,6 +62,14 @@ final class TypeSafetySourceFileTests: XCTestCase { case otherName } } + extension ModuleAssembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + nil + } + } """ XCTAssertEqual(expected, result.formatted().description) @@ -257,6 +265,14 @@ final class TypeSafetySourceFileTests: XCTestCase { extension RealAssembly: Knit.DefaultModuleAssemblyOverride { public typealias OverrideType = MyFakeAssembly } + extension MyFakeAssembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [.autoInit] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + MyFakeAssembly() + } + } """ XCTAssertEqual(expected, result.formatted().description) @@ -289,6 +305,43 @@ final class TypeSafetySourceFileTests: XCTestCase { extension OtherRealAssembly: Knit.DefaultModuleAssemblyOverride { public typealias OverrideType = MyFakeAssembly } + extension MyFakeAssembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [.autoInit] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + MyFakeAssembly() + } + } + """ + + XCTAssertEqual(expected, result.formatted().description) + } + + func test_abstract_generation() throws { + let result = try TypeSafetySourceFile.make( + from: Configuration( + assemblyName: "SomeAbstractAssembly", + moduleName: "Module", + assemblyType: .abstractAssembly, + registrations: [], + replaces: [], + targetResolver: "AccountResolver" + ) + ) + + let expected = """ + /// Generated from ``SomeAbstractAssembly`` + extension AccountResolver { + } + extension SomeAbstractAssembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [.autoInit, .abstract] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + SomeAbstractAssembly() + } + } """ XCTAssertEqual(expected, result.formatted().description) @@ -312,6 +365,14 @@ final class TypeSafetySourceFileTests: XCTestCase { knitUnwrap(resolve(ServiceA.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } + extension MainActorAssembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + nil + } + } """ XCTAssertEqual(expected, result.formatted().description)