Skip to content

Commit

Permalink
Merge pull request #226 from cashapp/skorulis/cast-performance
Browse files Browse the repository at this point in the history
Improve performance of DependencyBuilder
  • Loading branch information
skorulis-ap authored Jan 15, 2025
2 parents 4794bb0 + 8c24c71 commit 6b852ec
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 5 deletions.
11 changes: 6 additions & 5 deletions Sources/Knit/Module/DependencyBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ final class DependencyBuilder {
return existingType
}
if let overrideType = try defaultOverride(moduleType, fromInput: inputModule != nil),
let autoInit = overrideType as? any AutoInitModuleAssembly.Type {
return autoInit.init()
let created = overrideType._autoInstantiate() {
return created
}
if let inputModule {
return inputModule
}
if let autoInit = moduleType as? any AutoInitModuleAssembly.Type {
return autoInit.init()
if let created = moduleType._autoInstantiate() {
return created
}

throw DependencyBuilderError.moduleNotProvided(moduleType, dependencyTree.sourcePathString(moduleType: moduleType))
Expand Down Expand Up @@ -230,7 +230,8 @@ internal extension DependencyBuilder {
// Collect AbstractAssemblies as they should all be instantiated and added to the container.
// This needs to happen before the filter below as they are all expected to be implemented by other assemblies
// and will therefore be filtered out.
if ref.type is any AbstractAssembly.Type {

if ref.type._assemblyFlags.contains(.abstract) {
return true
}

Expand Down
32 changes: 32 additions & 0 deletions Sources/Knit/Module/ModuleAssembly.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ public protocol ModuleAssembly {
/// This can be overridden in apps with custom Resolver hierarchies
static func scoped(_ dependencies: [any ModuleAssembly.Type]) -> [any ModuleAssembly.Type]

/// Hints about this assembly using by DependencyBuilder. Designed for internal use
static var _assemblyFlags: [ModuleAssemblyFlags] { get }

/// Creates an instance of this assembly if the assembly conforms to AutoInitModuleAssembly
static func _autoInstantiate() -> (any ModuleAssembly)?
}

public extension ModuleAssembly {
Expand All @@ -40,6 +45,24 @@ public extension ModuleAssembly {
return self.resolverType == $0.resolverType
}
}

static var _assemblyFlags: [ModuleAssemblyFlags] {
var result: [ModuleAssemblyFlags] = []
if self is any AutoInitModuleAssembly.Type {
result.append(.autoInit)
}
if self is any AbstractAssembly.Type {
result.append(.abstract)
}
return result
}

static func _autoInstantiate() -> (any ModuleAssembly)? {
if let autoInit = self as? any AutoInitModuleAssembly.Type {
return autoInit.init()
}
return nil
}
}

/// A ModuleAssembly that can be initialised without any parameters
Expand Down Expand Up @@ -97,3 +120,12 @@ public struct OverrideBehavior {
return NSClassFromString("XCTestCase") != nil
}
}

public struct ModuleAssemblyFlags: OptionSet {
public let rawValue: UInt

public init(rawValue: UInt) { self.rawValue = rawValue }

public static let autoInit = ModuleAssemblyFlags(rawValue: 1 << 0)
public static let abstract = ModuleAssemblyFlags(rawValue: 1 << 1)
}
40 changes: 40 additions & 0 deletions Sources/KnitCodeGen/TypeSafetySourceFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public enum TypeSafetySourceFile {
if let defaultOverrides = try makeDefaultOverrideExtensions(config: config) {
defaultOverrides
}
try makePerformanceExtension(config: config)
}
}

Expand Down Expand Up @@ -160,6 +161,45 @@ public enum TypeSafetySourceFile {
)
}

private static func makePerformanceExtension(config: Configuration) throws -> ExtensionDeclSyntax {
let accessorBlock: AccessorBlockSyntax
let isAutoInit: Bool
if config.assemblyType == .abstractAssembly {
accessorBlock = AccessorBlockSyntax(
accessors: .getter(.init(stringLiteral: "[.autoInit, .abstract]"))
)
isAutoInit = true
} else if config.assemblyType == .autoInitAssembly || config.assemblyType == .fakeAssembly {
accessorBlock = AccessorBlockSyntax(
accessors: .getter(.init(stringLiteral: "[.autoInit]"))
)
isAutoInit = true
}
else {
accessorBlock = AccessorBlockSyntax(
accessors: .getter(.init(stringLiteral: "[]"))
)
isAutoInit = false
}

return try ExtensionDeclSyntax(
extendedType: TypeSyntax(stringLiteral: config.assemblyName),
memberBlockBuilder: {
VariableDeclSyntax.makeVar(
keywords: [.public, .static],
name: "_assemblyFlags",
type: "[ModuleAssemblyFlags]",
accessorBlock: accessorBlock
)
if isAutoInit {
try FunctionDeclSyntax("public static func _autoInstantiate() -> (any ModuleAssembly)? { \(raw: config.assemblyName)() }")
} else {
try FunctionDeclSyntax("public static func _autoInstantiate() -> (any ModuleAssembly)? { nil }")
}
}
)
}

}

extension Registration {
Expand Down
24 changes: 24 additions & 0 deletions Tests/KnitCodeGenTests/ConfigurationSetTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ final class ConfigurationSetTests: XCTestCase {
knitUnwrap(resolve(Service1.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
}
extension Module1Assembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
nil
}
}
/// Generated from ``Module2Assembly``
extension Resolver {
public func callAsFunction(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> Service2 {
Expand All @@ -41,12 +49,28 @@ final class ConfigurationSetTests: XCTestCase {
knitUnwrap(resolve(ArgumentService.self, argument: string), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
}
extension Module2Assembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
nil
}
}
/// Generated from ``Module3Assembly``
extension Resolver {
public func service3(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> Service3 {
knitUnwrap(resolve(Service3.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
}
extension Module3Assembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
nil
}
}
"""
)
}
Expand Down
61 changes: 61 additions & 0 deletions Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ final class TypeSafetySourceFileTests: XCTestCase {
case otherName
}
}
extension ModuleAssembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
nil
}
}
"""

XCTAssertEqual(expected, result.formatted().description)
Expand Down Expand Up @@ -257,6 +265,14 @@ final class TypeSafetySourceFileTests: XCTestCase {
extension RealAssembly: Knit.DefaultModuleAssemblyOverride {
public typealias OverrideType = MyFakeAssembly
}
extension MyFakeAssembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[.autoInit]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
MyFakeAssembly()
}
}
"""

XCTAssertEqual(expected, result.formatted().description)
Expand Down Expand Up @@ -289,6 +305,43 @@ final class TypeSafetySourceFileTests: XCTestCase {
extension OtherRealAssembly: Knit.DefaultModuleAssemblyOverride {
public typealias OverrideType = MyFakeAssembly
}
extension MyFakeAssembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[.autoInit]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
MyFakeAssembly()
}
}
"""

XCTAssertEqual(expected, result.formatted().description)
}

func test_abstract_generation() throws {
let result = try TypeSafetySourceFile.make(
from: Configuration(
assemblyName: "SomeAbstractAssembly",
moduleName: "Module",
assemblyType: .abstractAssembly,
registrations: [],
replaces: [],
targetResolver: "AccountResolver"
)
)

let expected = """
/// Generated from ``SomeAbstractAssembly``
extension AccountResolver {
}
extension SomeAbstractAssembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[.autoInit, .abstract]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
SomeAbstractAssembly()
}
}
"""

XCTAssertEqual(expected, result.formatted().description)
Expand All @@ -312,6 +365,14 @@ final class TypeSafetySourceFileTests: XCTestCase {
knitUnwrap(resolve(ServiceA.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
}
extension MainActorAssembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
nil
}
}
"""

XCTAssertEqual(expected, result.formatted().description)
Expand Down

0 comments on commit 6b852ec

Please sign in to comment.