From 08003c26350362192d481b547f0dcc63a9f1020d Mon Sep 17 00:00:00 2001 From: Brad Fol Date: Mon, 7 Aug 2023 13:06:39 -0700 Subject: [PATCH] Upgrade SwiftSyntax to 509.0.1 --- .github/workflows/release.yml | 3 + .github/workflows/swift.yml | 3 + Package.resolved | 4 +- Package.swift | 10 +- Sources/KnitCodeGen/AssemblyParsing.swift | 27 ++-- Sources/KnitCodeGen/Configuration.swift | 18 ++- Sources/KnitCodeGen/ConfigurationSet.swift | 27 ++-- .../FunctionCallRegistrationParsing.swift | 61 ++++---- .../ModuleAssemblyExtensionSourceFile.swift | 49 +++--- .../KnitCodeGen/TypeSafetySourceFile.swift | 32 ++-- Sources/KnitCodeGen/UnitTestSourceFile.swift | 60 ++++--- Sources/KnitCommand/GenCommand.swift | 2 +- .../ModuleDependenciesCommand.swift | 2 +- .../AssemblyParsingTests.swift | 38 +++-- .../ConfigurationSetTests.swift | 26 ++-- ...duleAssemblyExtensionSourceFileTests.swift | 31 ++-- .../RegistrationParsingTests.swift | 147 +++++++++--------- .../TypeSafetySourceFileTests.swift | 15 +- .../UnitTestSourceFileTests.swift | 54 +++---- 19 files changed, 306 insertions(+), 303 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 64f18e2..23ac3c5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -15,6 +15,9 @@ jobs: - name: Checkout uses: actions/checkout@v3 - name: Create Release + uses: swift-actions/setup-swift@v1.25.0 + with: + swift-version: "5.9" run: | set -euo pipefail diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml index cda61a2..c603026 100644 --- a/.github/workflows/swift.yml +++ b/.github/workflows/swift.yml @@ -16,6 +16,9 @@ jobs: steps: - uses: actions/checkout@v3 + - uses: swift-actions/setup-swift@v1.25.0 + with: + swift-version: "5.9" - name: Build run: swift build -v - name: Run tests diff --git a/Package.resolved b/Package.resolved index 86fdbdc..6c49724 100644 --- a/Package.resolved +++ b/Package.resolved @@ -14,8 +14,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-syntax.git", "state" : { - "revision" : "2c49d66d34dfd6f8130afdba889de77504b58ec0", - "version" : "508.0.1" + "revision" : "ffa3cd6fc2aa62adbedd31d3efaf7c0d86a9f029", + "version" : "509.0.1" } }, { diff --git a/Package.swift b/Package.swift index b86bc27..377048b 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version: 5.7 +// swift-tools-version: 5.9 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription @@ -13,7 +13,7 @@ let package = Package( .executable(name: "knit-cli", targets: ["KnitCommand"]) ], dependencies: [ - .package(url: "https://github.com/apple/swift-syntax.git", from: "508.0.1"), + .package(url: "https://github.com/apple/swift-syntax.git", from: "509.0.1"), .package(url: "https://github.com/apple/swift-argument-parser", from: "1.2.0"), .package(url: "https://github.com/Swinject/Swinject.git", from: "2.8.3"), .package(url: "https://github.com/Swinject/SwinjectAutoregistration.git", from: "2.8.3"), @@ -29,8 +29,7 @@ let package = Package( .executableTarget( name: "KnitCommand", dependencies: [ - .product(name: "SwiftSyntaxBuilder", package: "swift-syntax"), - .product(name: "SwiftSyntaxParser", package: "swift-syntax"), + .product(name: "SwiftSyntax", package: "swift-syntax"), .product(name: "ArgumentParser", package: "swift-argument-parser"), .target(name: "KnitCodeGen"), ] @@ -38,8 +37,9 @@ let package = Package( .target( name: "KnitCodeGen", dependencies: [ + .product(name: "SwiftSyntax", package: "swift-syntax"), + .product(name: "SwiftParser", package: "swift-syntax"), .product(name: "SwiftSyntaxBuilder", package: "swift-syntax"), - .product(name: "SwiftSyntaxParser", package: "swift-syntax"), ] ), .testTarget( diff --git a/Sources/KnitCodeGen/AssemblyParsing.swift b/Sources/KnitCodeGen/AssemblyParsing.swift index 471fb52..680a8ec 100644 --- a/Sources/KnitCodeGen/AssemblyParsing.swift +++ b/Sources/KnitCodeGen/AssemblyParsing.swift @@ -1,6 +1,6 @@ import Foundation import SwiftSyntax -import SwiftSyntaxParser +import SwiftParser public func parseAssemblies(at paths: [String]) throws -> ConfigurationSet { var configs = [Configuration]() @@ -8,12 +8,13 @@ public func parseAssemblies(at paths: [String]) throws -> ConfigurationSet { let url = URL(fileURLWithPath: path, isDirectory: false) var errorsToPrint = [Error]() - let syntaxTree: SourceFileSyntax + let source: String do { - syntaxTree = try SwiftSyntaxParser.SyntaxParser.parse(url) + source = try String(contentsOf: url) } catch { - throw AssemblyParsingError.syntaxParsingError(error, path: path) + throw AssemblyParsingError.fileReadError(error, path: path) } + let syntaxTree = Parser.parse(source: source) let configuration = try parseSyntaxTree(syntaxTree, errorsToPrint: &errorsToPrint) configs.append(configuration) printErrors(errorsToPrint, filePath: path, syntaxTree: syntaxTree) @@ -79,11 +80,11 @@ private class AssemblyFileVisitor: SyntaxVisitor { } override func visit(_ node: ImportDeclSyntax) -> SyntaxVisitorContinueKind { - imports.append(node.withoutTrivia()) + imports.append(node.trimmed) return .skipChildren } - private func visitAssemblyType(_ node: IdentifiedDeclSyntax) -> SyntaxVisitorContinueKind { + private func visitAssemblyType(_ node: NamedDeclSyntax) -> SyntaxVisitorContinueKind { guard classDeclVisitor == nil else { // Only the first class declaration should be visited return .skipChildren @@ -139,12 +140,12 @@ private class ClassDeclVisitor: SyntaxVisitor { } -extension IdentifiedDeclSyntax { +extension NamedDeclSyntax { /// Returns the module name for the assembly class. /// If the class is not an assembly returns `nil`. var moduleNameForAssembly: String? { - let className = identifier.text + let className = name.text let assemblySuffx = "Assembly" guard className.hasSuffix(assemblySuffx) else { return nil @@ -157,7 +158,7 @@ extension IdentifiedDeclSyntax { // MARK: - Errors enum AssemblyParsingError: Error { - case syntaxParsingError(Error, path: String) + case fileReadError(Error, path: String) case missingModuleName } @@ -165,9 +166,9 @@ extension AssemblyParsingError: LocalizedError { var errorDescription: String? { switch self { - case let .syntaxParsingError(error, path: path): + case let .fileReadError(error, path: path): return """ - Error parsing assembly file: \(error) + Error reading file: \(error.localizedDescription) File path: \(path) """ @@ -184,12 +185,12 @@ func printErrors(_ errors: [Error], filePath: String, syntaxTree: SyntaxProtocol guard !errors.isEmpty else { return } - let lineConverter = SourceLocationConverter(file: filePath, tree: syntaxTree) + let lineConverter = SourceLocationConverter(fileName: filePath, tree: syntaxTree) for error in errors { if let syntaxError = error as? SyntaxError { let position = syntaxError.syntax.startLocation(converter: lineConverter, afterLeadingTrivia: true) - let line = position.line ?? 1 + let line = position.line print("\(filePath):\(line): error: \(error.localizedDescription)") } else { print("\(filePath): error: \(error.localizedDescription)") diff --git a/Sources/KnitCodeGen/Configuration.swift b/Sources/KnitCodeGen/Configuration.swift index 66e51c0..8c49f33 100644 --- a/Sources/KnitCodeGen/Configuration.swift +++ b/Sources/KnitCodeGen/Configuration.swift @@ -32,20 +32,20 @@ public struct Configuration: Encodable { public extension Configuration { - func makeTypeSafetySourceFile() -> SourceFileSyntax { - return TypeSafetySourceFile.make( + func makeTypeSafetySourceFile() throws -> SourceFileSyntax { + return try TypeSafetySourceFile.make( assemblyName: "\(name)Assembly", extensionTarget: "Resolver", registrations: registrations ) } - func makeUnitTestSourceFile() -> SourceFileSyntax { + func makeUnitTestSourceFile() throws -> SourceFileSyntax { var allImports = imports - allImports.append("@testable import \(raw: self.name)") - allImports.append("import XCTest") + allImports.append(try ImportDeclSyntax("@testable import \(raw: self.name)")) + allImports.append(try ImportDeclSyntax("import XCTest")) - return UnitTestSourceFile.make( + return try UnitTestSourceFile.make( configuration: self ) } @@ -76,7 +76,11 @@ public extension ImportDeclSyntax { init(moduleName: String) { self.init( - path: [ AccessPathComponentSyntax(name: moduleName) ] + path: [ + ImportPathComponentSyntax( + name: "\(raw: moduleName)" + ) + ] ) } diff --git a/Sources/KnitCodeGen/ConfigurationSet.swift b/Sources/KnitCodeGen/ConfigurationSet.swift index 10897f7..1ef0c58 100644 --- a/Sources/KnitCodeGen/ConfigurationSet.swift +++ b/Sources/KnitCodeGen/ConfigurationSet.swift @@ -1,7 +1,6 @@ -// Created by Alexander skorulis on 4/8/2023. - import Foundation import SwiftSyntax +import SwiftSyntaxBuilder // Multiple assemblies that are grouped together public struct ConfigurationSet { @@ -16,17 +15,17 @@ public struct ConfigurationSet { public func writeGeneratedFiles( typeSafetyExtensionsOutputPath: String?, unitTestOutputPath: String? - ) { + ) throws { if let typeSafetyExtensionsOutputPath { write( - text: makeTypeSafetySourceFile(), + text: try makeTypeSafetySourceFile(), to: typeSafetyExtensionsOutputPath ) } if let unitTestOutputPath { write( - text: makeUnitTestSourceFile(), + text: try makeUnitTestSourceFile(), to: unitTestOutputPath ) } @@ -41,24 +40,24 @@ public struct ConfigurationSet { public extension ConfigurationSet { - func makeTypeSafetySourceFile() -> String { + func makeTypeSafetySourceFile() throws -> String { var allImports = allImports - allImports.append("import Swinject") + allImports.append(try ImportDeclSyntax("import Swinject")) let header = HeaderSourceFile.make(importDecls: sortImports(allImports), comment: Self.typeSafetyIntro) - let body = assemblies.map { $0.makeTypeSafetySourceFile() } + let body = try assemblies.map { try $0.makeTypeSafetySourceFile() } let sourceFiles = [header] + body return Self.join(sourceFiles: sourceFiles) } - func makeUnitTestSourceFile() -> String { + func makeUnitTestSourceFile() throws -> String { var allImports = allImports - allImports.append("@testable import \(raw: primaryAssembly.name)") - allImports.append("import XCTest") + allImports.append(try ImportDeclSyntax("@testable import \(raw: primaryAssembly.name)")) + allImports.append(try ImportDeclSyntax("import XCTest")) let header = HeaderSourceFile.make(importDecls: sortImports(allImports), comment: nil) - let body = assemblies.map { $0.makeUnitTestSourceFile() } + let body = try assemblies.map { try $0.makeUnitTestSourceFile() } let allRegistrations = assemblies.flatMap { $0.registrations } let allRegistrationsIntoCollections = assemblies.flatMap { $0.registrationsIntoCollections } - let resolverExtensions = UnitTestSourceFile.resolverExtensions( + let resolverExtensions = try UnitTestSourceFile.resolverExtensions( registrations: allRegistrations, registrationsIntoCollections: allRegistrationsIntoCollections ) @@ -68,7 +67,7 @@ public extension ConfigurationSet { private static func join(sourceFiles: [SourceFileSyntax]) -> String { let result = sourceFiles.map { $0.formatted().description }.joined(separator: "\n") - return result.replacingOccurrences(of: ", \n", with: ",\n") + return result } private func sortImports(_ imports: [ImportDeclSyntax]) -> [ImportDeclSyntax] { diff --git a/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift b/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift index 7f015f1..eb4c78c 100644 --- a/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift +++ b/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift @@ -11,7 +11,7 @@ struct CalledMethod { let calledExpression: MemberAccessExprSyntax // The arguments passed to the method, e.g. `arg1: String` from the example above. - let arguments: TupleExprElementListSyntax + let arguments: LabeledExprListSyntax // A trailing closure after the called method (which is the last argument for that method call). let trailingClosure: ClosureExprSyntax? @@ -34,7 +34,7 @@ extension FunctionCallExprSyntax { let registrationIntoCollection = calledMethods .first { method in - let name = method.calledExpression.name.text + let name = method.calledExpression.declName.baseName.text return name == "registerIntoCollection" || name == "autoregisterIntoCollection" } .flatMap { method in @@ -47,7 +47,7 @@ extension FunctionCallExprSyntax { } let registerMethods = calledMethods.filter { method in - let name = method.calledExpression.name.text + let name = method.calledExpression.declName.baseName.text return name == "register" || name == "autoregister" || name == "registerAbstract" } @@ -77,14 +77,14 @@ extension FunctionCallExprSyntax { } let implementsCalledMethods = calledMethods.filter { method in - method.calledExpression.name.text == "implements" + method.calledExpression.declName.baseName.text == "implements" } var forwardedRegistrations = [Registration]() for implementsCalledMethod in implementsCalledMethods { - // For `.implements()` the leading trivia is attached to the Dot syntax node - let leadingTrivia = implementsCalledMethod.calledExpression.dot.leadingTrivia + // For `.implements()` the leading trivia is attached to the Period syntax node + let leadingTrivia = implementsCalledMethod.calledExpression.period.leadingTrivia if let forwardedRegistration = try makeRegistrationFor( defaultDirectives: defaultDirectives, @@ -121,10 +121,10 @@ func recurseAllCalledMethods( // Append each method call as we recurse calledMethods.append(CalledMethod( calledExpression: calledExpr, - arguments: funcCall.argumentList, + arguments: funcCall.arguments, trailingClosure: funcCall.trailingClosure )) - if let identifierToken = calledExpr.base?.as(IdentifierExprSyntax.self)?.identifier { + if let identifierToken = calledExpr.base?.as(DeclReferenceExprSyntax.self)?.baseName { return identifierToken } else { let innerFunctionCall = calledExpr.base!.as(FunctionCallExprSyntax.self)! @@ -140,16 +140,16 @@ func recurseAllCalledMethods( private func makeRegistrationFor( defaultDirectives: KnitDirectives, - arguments: TupleExprElementListSyntax, + arguments: LabeledExprListSyntax, registrationArguments: [Registration.Argument], leadingTrivia: Trivia?, isForwarded: Bool ) throws -> Registration? { - guard let firstParam = arguments.first?.as(TupleExprElementSyntax.self)? + guard let firstParam = arguments.first?.as(LabeledExprSyntax.self)? .expression.as(MemberAccessExprSyntax.self) else { return nil } - guard firstParam.name.text == "self" else { return nil } + guard firstParam.declName.baseName.text == "self" else { return nil } - let registrationText = firstParam.base!.withoutTrivia().description + let registrationText = firstParam.base!.trimmed.description let name = try getName(arguments: arguments) let directives = try KnitDirectives.parse(leadingTrivia: leadingTrivia) @@ -171,17 +171,17 @@ private func makeRegistrationFor( } private func makeRegistrationIntoCollection( - arguments: TupleExprElementListSyntax + arguments: LabeledExprListSyntax ) -> RegistrationIntoCollection? { - guard let firstParam = arguments.first?.as(TupleExprElementSyntax.self)? + guard let firstParam = arguments.first?.as(LabeledExprSyntax.self)? .expression.as(MemberAccessExprSyntax.self) else { return nil } - guard firstParam.name.text == "self" else { return nil } + guard firstParam.declName.baseName.text == "self" else { return nil } - let registrationText = firstParam.base!.withoutTrivia().description + let registrationText = firstParam.base!.trimmed.description return RegistrationIntoCollection(service: registrationText) } -private func getName(arguments: TupleExprElementListSyntax) throws -> String? { +private func getName(arguments: LabeledExprListSyntax) throws -> String? { guard let nameParam = arguments.first(where: {$0.label?.text == "name"}) else { return nil } @@ -192,7 +192,7 @@ private func getName(arguments: TupleExprElementListSyntax) throws -> String? { } private func getArguments( - arguments: TupleExprElementListSyntax, + arguments: LabeledExprListSyntax, trailingClosure: ClosureExprSyntax? ) throws -> [Registration.Argument] { // `autoregister` parsing @@ -236,20 +236,22 @@ private func getArguments( factoryClosure = nil } - // This type of closure param list syntax cannot include types, so force using `ParameterClauseSyntax` + // This type of closure param list syntax cannot include types, so force using `ClosureParameterClauseSyntax` // when there is more that one param. // If there is only one param then it is always the `Resolver`. - if let paramList = factoryClosure?.signature?.input?.as(ClosureParamListSyntax.self), paramList.count >= 2 { - throw RegistrationParsingError.unwrappedClosureParams(syntax: paramList) + if let paramList = factoryClosure?.signature?.parameterClause?.as(ClosureShorthandParameterListSyntax.self), + paramList.count >= 2 { + throw RegistrationParsingError.unwrappedClosureParams(syntax: paramList) } // Register methods take a closure with resolver and arguments. Argument types must be provided - if let closureParameters = factoryClosure?.signature?.input?.as(ParameterClauseSyntax.self) { - let params = closureParameters.parameterList + if let closureParameters = factoryClosure?.signature?.parameterClause?.as(ClosureParameterClauseSyntax.self) { + let params = closureParameters.parameters // The first param is the resolver, everything after that is an argument return try params[params.index(after: params.startIndex).. String? { +private func getArgumentType(arg: LabeledExprSyntax) -> String? { return arg.expression.as(MemberAccessExprSyntax.self)?.base?.description .replacingOccurrences(of: "@escaping", with: " ") .trimmingCharacters(in: .whitespacesAndNewlines) } -private func getArgumentType(arg: FunctionParameterSyntax) -> String? { - guard let type = arg.type else { - return nil - } - return type.description +private func getArgumentType(arg: ClosureParameterSyntax) -> String? { + return arg.type?.description .replacingOccurrences(of: "@escaping", with: " ") .trimmingCharacters(in: .whitespacesAndNewlines) } diff --git a/Sources/KnitCodeGen/ModuleAssemblyExtensionSourceFile.swift b/Sources/KnitCodeGen/ModuleAssemblyExtensionSourceFile.swift index cbe497b..2242a1b 100644 --- a/Sources/KnitCodeGen/ModuleAssemblyExtensionSourceFile.swift +++ b/Sources/KnitCodeGen/ModuleAssemblyExtensionSourceFile.swift @@ -14,37 +14,46 @@ public enum ModuleAssemblyExtensionSourceFile { currentModuleName: String, dependencyModuleNames: [String], additionalAssemblies: [String] - ) -> SourceFileSyntax { - return SourceFileSyntax(leadingTrivia: TriviaProvider.headerTrivia) { + ) throws -> SourceFileSyntax { + return try SourceFileSyntax(leadingTrivia: TriviaProvider.headerTrivia) { DeclSyntax("import Knit") for dependencyModuleName in dependencyModuleNames where !dependencyModuleName.hasSuffix("Assembly") { DeclSyntax("import \(raw: dependencyModuleName)") } - ExtensionDeclSyntax("extension \(currentModuleName)Assembly: GeneratedModuleAssembly") { + try ExtensionDeclSyntax("extension \(raw: currentModuleName)Assembly: GeneratedModuleAssembly") { // `public static var generatedDependencies: [any ModuleAssembly.Type]` VariableDeclSyntax( modifiers: [ - DeclModifierSyntax(name: TokenSyntax(.publicKeyword, presence: .present)), - DeclModifierSyntax(name: TokenSyntax(.staticKeyword, presence: .present)), + DeclModifierSyntax(name: TokenSyntax(.keyword(.public), presence: .present)), + DeclModifierSyntax(name: TokenSyntax(.keyword(.static), presence: .present)), ], - name: "generatedDependencies", - type: TypeAnnotationSyntax(type: "[any ModuleAssembly.Type]" as TypeSyntax), + bindingSpecifier: .keyword(.var), + bindingsBuilder: { + PatternBindingSyntax( + pattern: IdentifierPatternSyntax(identifier: .identifier("generatedDependencies")), + typeAnnotation: TypeAnnotationSyntax(type: "[any ModuleAssembly.Type]" as TypeSyntax), + accessorBlock: AccessorBlockSyntax( + + // Make the computed property accessor + accessors: .getter(.init( + itemsBuilder: { + let elements = ArrayElementListSyntax { + // Turn each module name string into a meta type of the Assembly + for name in (dependencyModuleNames + additionalAssemblies) { + ArrayElementSyntax( + leadingTrivia: [ .newlines(1) ], + expression: "\(raw: typeName(name)).self" as ExprSyntax + ) + } + } - // Make the computed property accessor - accessor: { - let elements = ArrayElementList { - // Turn each module name string into a meta type of the Assembly - for name in (dependencyModuleNames + additionalAssemblies) { - ArrayElementSyntax( - leadingTrivia: [ .newlines(1) ], - expression: "\(raw: typeName(name)).self" as MemberAccessExprSyntax - ) - } - } - - ArrayExpr(elements: elements) + ArrayExprSyntax(elements: elements) + } + )) + ) + ) } ) } diff --git a/Sources/KnitCodeGen/TypeSafetySourceFile.swift b/Sources/KnitCodeGen/TypeSafetySourceFile.swift index 5a03338..c6756ef 100644 --- a/Sources/KnitCodeGen/TypeSafetySourceFile.swift +++ b/Sources/KnitCodeGen/TypeSafetySourceFile.swift @@ -8,30 +8,30 @@ public enum TypeSafetySourceFile { assemblyName: String, extensionTarget: String, registrations allRegistrations: [Registration] - ) -> SourceFileSyntax { + ) throws -> SourceFileSyntax { let visibleRegistrations = allRegistrations.filter { // Exclude hidden registrations always $0.accessLevel != .hidden } let unnamedRegistrations = visibleRegistrations.filter { $0.name == nil } let namedGroups = NamedRegistrationGroup.make(from: visibleRegistrations) - return SourceFileSyntax() { - ExtensionDeclSyntax(""" - // Generated from \(assemblyName) - extension \(extensionTarget) + return try SourceFileSyntax() { + try ExtensionDeclSyntax(""" + // Generated from \(raw: assemblyName) + extension \(raw: extensionTarget) """) { for registration in unnamedRegistrations { if registration.getterConfig.contains(.callAsFunction) { - makeResolver(registration: registration, getterType: .callAsFunction) + try makeResolver(registration: registration, getterType: .callAsFunction) } if let namedGetter = registration.getterConfig.first(where: { $0.isNamed }) { - makeResolver(registration: registration, getterType: namedGetter) + try makeResolver(registration: registration, getterType: namedGetter) } } for namedGroup in namedGroups { let firstGetterConfig = namedGroup.registrations[0].getterConfig.first ?? .callAsFunction - makeResolver( + try makeResolver( registration: namedGroup.registrations[0], enumName: "\(assemblyName).\(namedGroup.enumName)", getterType: firstGetterConfig @@ -39,7 +39,7 @@ public enum TypeSafetySourceFile { } } if !namedGroups.isEmpty { - makeNamedEnums(assemblyName: assemblyName, namedGroups: namedGroups) + try makeNamedEnums(assemblyName: assemblyName, namedGroups: namedGroups) } } } @@ -49,7 +49,7 @@ public enum TypeSafetySourceFile { registration: Registration, enumName: String? = nil, getterType: GetterConfig = .callAsFunction - ) -> FunctionDeclSyntax { + ) throws -> FunctionDeclSyntax { let modifier = registration.accessLevel == .public ? "public " : "" let nameInput = enumName.map { "name: \($0)" } let nameUsage = enumName != nil ? "name: name.rawValue" : nil @@ -64,8 +64,8 @@ public enum TypeSafetySourceFile { funcName = name ?? TypeNamer.computedIdentifierName(type: registration.service) } - return FunctionDeclSyntax("\(modifier)func \(funcName)(\(inputs)) -> \(registration.service)") { - ForcedValueExprSyntax("self.resolve(\(raw: usages))!") + return try FunctionDeclSyntax("\(raw: modifier)func \(raw: funcName)(\(raw: inputs)) -> \(raw: registration.service)") { + "self.resolve(\(raw: usages))!" } } @@ -82,13 +82,13 @@ public enum TypeSafetySourceFile { private static func makeNamedEnums( assemblyName: String, namedGroups: [NamedRegistrationGroup] - ) -> ExtensionDeclSyntax { - ExtensionDeclSyntax("extension \(assemblyName)") { + ) throws -> ExtensionDeclSyntax { + try ExtensionDeclSyntax("extension \(raw: assemblyName)") { for namedGroup in namedGroups { let modifier = namedGroup.accessLevel == .public ? "public " : "" - EnumDeclSyntax("\(modifier)enum \(namedGroup.enumName): String, CaseIterable") { + try EnumDeclSyntax("\(raw: modifier)enum \(raw: namedGroup.enumName): String, CaseIterable") { for test in namedGroup.registrations { - EnumCaseDeclSyntax("case \(raw: test.name!)") + "case \(raw: test.name!)" as DeclSyntax } } } diff --git a/Sources/KnitCodeGen/UnitTestSourceFile.swift b/Sources/KnitCodeGen/UnitTestSourceFile.swift index 564c0c9..5748644 100644 --- a/Sources/KnitCodeGen/UnitTestSourceFile.swift +++ b/Sources/KnitCodeGen/UnitTestSourceFile.swift @@ -5,13 +5,13 @@ public enum UnitTestSourceFile { public static func make( configuration: Configuration - ) -> SourceFileSyntax { + ) throws -> SourceFileSyntax { let withArguments = configuration.registrations.filter { !$0.arguments.isEmpty } let hasArguments = !withArguments.isEmpty - return SourceFileSyntax() { - ClassDeclSyntax("final class \(configuration.name)RegistrationTests: XCTestCase") { + return try SourceFileSyntax() { + try ClassDeclSyntax("final class \(raw: configuration.name)RegistrationTests: XCTestCase") { - FunctionDeclSyntax("func testRegistrations()") { + try FunctionDeclSyntax("func testRegistrations()") { DeclSyntax(""" // In the test target for your module, please provide a static method that creates a @@ -38,7 +38,7 @@ public enum UnitTestSourceFile { } for (service, count) in groupByService(configuration.registrationsIntoCollections) { - FunctionCallExprSyntax( + ExprSyntax( "resolver.assertCollectionResolves(\(raw: service).self, count: \(raw: count))" ) } @@ -46,7 +46,7 @@ public enum UnitTestSourceFile { } if hasArguments { - makeArgumentStruct(registrations: configuration.registrations, moduleName: configuration.name) + try makeArgumentStruct(registrations: configuration.registrations, moduleName: configuration.name) } } } @@ -54,24 +54,24 @@ public enum UnitTestSourceFile { static func resolverExtensions( registrations: [Registration], registrationsIntoCollections: [RegistrationIntoCollection] - ) -> SourceFileSyntax { + ) throws -> SourceFileSyntax { let withArguments = registrations.filter { !$0.arguments.isEmpty } - return SourceFileSyntax() { + return try SourceFileSyntax() { // swiftlint:disable line_length - ExtensionDeclSyntax("private extension Resolver") { + try ExtensionDeclSyntax("private extension Resolver") { // This assert is only needed if there are registrations without arguments if registrations.count > withArguments.count { - makeTypeAssert() + try makeTypeAssert() } // This assert is only needed if there are registrations with arguments if !withArguments.isEmpty { - makeResultAssert() + try makeResultAssert() } if !groupByService(registrationsIntoCollections).isEmpty { - makeCollectionAssert() + try makeCollectionAssert() } } // swiftlint:enable line_length @@ -88,26 +88,22 @@ public enum UnitTestSourceFile { } /// Generate a function call to test a single registration resolves - static func makeAssertCall(registration: Registration) -> FunctionCallExprSyntax { + static func makeAssertCall(registration: Registration) -> ExprSyntax { if !registration.arguments.isEmpty { let argParams = argumentParams(registration: registration) let nameParam = registration.name.map { "name: \"\($0)\""} let params = ["\(registration.service).self", nameParam, argParams].compactMap { $0 }.joined(separator: ", ") - return FunctionCallExprSyntax( - "resolver.assertTypeResolved(resolver.resolve(\(raw: params)))" - ) + return "resolver.assertTypeResolved(resolver.resolve(\(raw: params)))" } else if let name = registration.name { - return FunctionCallExprSyntax( - "resolver.assertTypeResolves(\(raw: registration.service).self, name: \"\(raw: name)\")" - ) + return "resolver.assertTypeResolves(\(raw: registration.service).self, name: \"\(raw: name)\")" } else { - return FunctionCallExprSyntax("resolver.assertTypeResolves(\(raw: registration.service).self)") + return "resolver.assertTypeResolves(\(raw: registration.service).self)" } } - private static func makeCollectionAssert() -> FunctionDeclSyntax { - let string = #""" - func assertCollectionResolves ( + private static func makeCollectionAssert() throws -> FunctionDeclSyntax { + let string: SyntaxNodeString = #""" + func assertCollectionResolves( _ type: T.Type, count expectedCount: Int, file: StaticString = #filePath, @@ -126,12 +122,12 @@ public enum UnitTestSourceFile { ) } """# - return FunctionDeclSyntax(stringLiteral: string) + return try FunctionDeclSyntax(string) } /// Generate a function to assert that a type can be resolved - private static func makeTypeAssert() -> FunctionDeclSyntax { - let string = """ + private static func makeTypeAssert() throws -> FunctionDeclSyntax { + let string: SyntaxNodeString = """ func assertTypeResolves( _ type: T.Type, name: String? = nil, @@ -146,12 +142,12 @@ public enum UnitTestSourceFile { ) } """ - return FunctionDeclSyntax(stringLiteral: string) + return try FunctionDeclSyntax(string) } /// Generate a function to assert that a value resolved correctly - private static func makeResultAssert() -> FunctionDeclSyntax { - let string = """ + private static func makeResultAssert() throws -> FunctionDeclSyntax { + let string: SyntaxNodeString = """ func assertTypeResolved( _ result: T?, file: StaticString = #filePath, @@ -165,11 +161,11 @@ public enum UnitTestSourceFile { ) } """ - return FunctionDeclSyntax(stringLiteral: string) + return try FunctionDeclSyntax(string) } /// Generate code for a struct that contains all of the parameters used to resolve services - static func makeArgumentStruct(registrations: [Registration], moduleName: String) -> StructDeclSyntax { + static func makeArgumentStruct(registrations: [Registration], moduleName: String) throws -> StructDeclSyntax { let fields = registrations.flatMap { $0.serviceNamedArguments() } var seen: Set = [] // Make sure duplicate parameters don't get created @@ -182,7 +178,7 @@ public enum UnitTestSourceFile { return true } - return StructDeclSyntax("struct \(moduleName)RegistrationTestArguments") { + return try StructDeclSyntax("struct \(raw: moduleName)RegistrationTestArguments") { for field in uniqueFields { DeclSyntax("let \(raw: field.resolvedIdentifier()): \(raw: field.type)") } diff --git a/Sources/KnitCommand/GenCommand.swift b/Sources/KnitCommand/GenCommand.swift index 816662b..29edc07 100644 --- a/Sources/KnitCommand/GenCommand.swift +++ b/Sources/KnitCommand/GenCommand.swift @@ -55,7 +55,7 @@ struct GenCommand: ParsableCommand { throw ExitCode(1) } - parsedConfig.writeGeneratedFiles( + try parsedConfig.writeGeneratedFiles( typeSafetyExtensionsOutputPath: typeSafetyExtensionsOutputPath, unitTestOutputPath: unitTestOutputPath ) diff --git a/Sources/KnitCommand/ModuleDependenciesCommand.swift b/Sources/KnitCommand/ModuleDependenciesCommand.swift index 6f1b81c..d144eab 100644 --- a/Sources/KnitCommand/ModuleDependenciesCommand.swift +++ b/Sources/KnitCommand/ModuleDependenciesCommand.swift @@ -39,7 +39,7 @@ struct ModuleDependenciesCommand: ParsableCommand { func run() throws { - let result = ModuleAssemblyExtensionSourceFile.make( + let result = try ModuleAssemblyExtensionSourceFile.make( currentModuleName: currentModuleName, dependencyModuleNames: dependencyModuleNames, additionalAssemblies: additionalAssemblies diff --git a/Tests/KnitCodeGenTests/AssemblyParsingTests.swift b/Tests/KnitCodeGenTests/AssemblyParsingTests.swift index cf9af5b..a1e6885 100644 --- a/Tests/KnitCodeGenTests/AssemblyParsingTests.swift +++ b/Tests/KnitCodeGenTests/AssemblyParsingTests.swift @@ -10,7 +10,7 @@ import XCTest final class AssemblyParsingTests: XCTestCase { func testAssemblyImports() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ import A import B // Comment after import should be stripped class FooTestAssembly: Assembly { } @@ -29,7 +29,7 @@ final class AssemblyParsingTests: XCTestCase { func testTestableImport() throws { // Unclear if this is a use case we care about, but we will retain attributes before the import statement - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ @testable import A class FooTestAssembly: Assembly { } """ @@ -44,7 +44,7 @@ final class AssemblyParsingTests: XCTestCase { } func testAssemblyModuleName() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ class FooTestAssembly: Assembly { func assemble(container: Container) { container.register(A.self) { } @@ -57,7 +57,7 @@ final class AssemblyParsingTests: XCTestCase { } func testAssemblyStructModuleName() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ struct FooTestAssembly: Assembly { func assemble(container: Container) { container.register(A.self) { } @@ -70,7 +70,7 @@ final class AssemblyParsingTests: XCTestCase { } func testAssemblyRegistrations() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ class TestAssembly: Assembly { func assemble(container: Container) { container.register(A.self) { } @@ -88,7 +88,7 @@ final class AssemblyParsingTests: XCTestCase { } func testKnitDirectives() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ // @knit public getter-named class TestAssembly: Assembly { func assemble(container: Container) { @@ -110,7 +110,7 @@ final class AssemblyParsingTests: XCTestCase { } func testOnlyFirstOfMultipleAssemblies() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ class KeyValueStoreAssembly: Assembly { func assemble(container: Container) { container.register(KeyValueStore.self) { } @@ -135,7 +135,7 @@ final class AssemblyParsingTests: XCTestCase { } func testAdditionalFunctions() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ class ExampleAssembly: Assembly { func assemble(container: Container) { partialAssemble(container: container) @@ -164,7 +164,7 @@ final class AssemblyParsingTests: XCTestCase { } func testAdditionalFunctionsInComputedPropertyAreNotParsed() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ class ExampleAssembly: Assembly { func assemble(container: Container) { container.register(MyService.self) { } @@ -188,23 +188,27 @@ final class AssemblyParsingTests: XCTestCase { // MARK: - ClassDecl Extension - func testClassDeclExtension() { - var classDecl: ClassDecl + func testClassDeclExtension() throws { + var classDecl: ClassDeclSyntax - classDecl = "class BarAssembly {}" + func makeClassDecl(from string: String) throws -> ClassDeclSyntax { + return try ClassDeclSyntax("\(raw: string)") + } + + classDecl = try makeClassDecl(from: "class BarAssembly {}") XCTAssertEqual(classDecl.moduleNameForAssembly, "Bar") - classDecl = "public final class FooAssembly {}" + classDecl = try makeClassDecl(from: "public final class FooAssembly {}") XCTAssertEqual(classDecl.moduleNameForAssembly, "Foo") - classDecl = "class AssemblyMissing {}" + classDecl = try makeClassDecl(from: "class AssemblyMissing {}") XCTAssertNil(classDecl.moduleNameForAssembly) } // MARK: - Error Throwing func testSyntaxParsingError() { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ class SomeClass { } // missing an assembly """ @@ -218,7 +222,7 @@ final class AssemblyParsingTests: XCTestCase { } func testRegistrationParsingErrorToPrint() throws { - let sourceFile: SourceFile = """ + let sourceFile: SourceFileSyntax = """ class MyAssembly: Assembly { func assemble(container: Container) { container.register(A.self) { resolver, arg1 in A(arg: arg1) } @@ -243,7 +247,7 @@ final class AssemblyParsingTests: XCTestCase { } private func assertParsesSyntaxTree( - _ sourceFile: SourceFile, + _ sourceFile: SourceFileSyntax, assertErrorsToPrint assertErrorsCallback: (([Error]) -> Void)? = nil, file: StaticString = #filePath, line: UInt = #line diff --git a/Tests/KnitCodeGenTests/ConfigurationSetTests.swift b/Tests/KnitCodeGenTests/ConfigurationSetTests.swift index bf92871..3e7d663 100644 --- a/Tests/KnitCodeGenTests/ConfigurationSetTests.swift +++ b/Tests/KnitCodeGenTests/ConfigurationSetTests.swift @@ -9,9 +9,8 @@ final class ConfigurationSetTests: XCTestCase { let configSet = ConfigurationSet(assemblies: [Factory.config1, Factory.config2, Factory.config3]) XCTAssertEqual( - configSet.makeTypeSafetySourceFile(), + try configSet.makeTypeSafetySourceFile(), """ - // Generated using Knit // Do not edit directly! @@ -21,14 +20,12 @@ final class ConfigurationSetTests: XCTestCase { // The correct resolution of each of these types is enforced by a matching automated unit test // If a type registration is missing or broken then the automated tests will fail for that PR - // Generated from Module1Assembly extension Resolver { public func service1() -> Service1 { self.resolve(Service1.self)! } } - // Generated from Module2Assembly extension Resolver { func callAsFunction() -> Service2 { @@ -38,7 +35,6 @@ final class ConfigurationSetTests: XCTestCase { self.resolve(ArgumentService.self, argument: string)! } } - // Generated from Module3Assembly extension Resolver { public func service3() -> Service3 { @@ -53,9 +49,8 @@ final class ConfigurationSetTests: XCTestCase { let configSet = ConfigurationSet(assemblies: [Factory.config1, Factory.config2]) XCTAssertEqual( - configSet.makeUnitTestSourceFile(), + try configSet.makeUnitTestSourceFile(), #""" - // Generated using Knit // Do not edit directly! @@ -63,7 +58,6 @@ final class ConfigurationSetTests: XCTestCase { import XCTest import Dependency1 import Dependency2 - final class Module1RegistrationTests: XCTestCase { func testRegistrations() { // In the test target for your module, please provide a static method that creates a @@ -74,7 +68,6 @@ final class ConfigurationSetTests: XCTestCase { resolver.assertCollectionResolves(CollectionService.self, count: 1) } } - final class Module2RegistrationTests: XCTestCase { func testRegistrations() { // In the test target for your module, please provide a static method that creates a @@ -91,9 +84,8 @@ final class ConfigurationSetTests: XCTestCase { struct Module2RegistrationTestArguments { let argumentServiceString: String } - private extension Resolver { - func assertTypeResolves < T > ( + func assertTypeResolves( _ type: T.Type, name: String? = nil, file: StaticString = #filePath, @@ -106,7 +98,7 @@ final class ConfigurationSetTests: XCTestCase { line: line ) } - func assertTypeResolved < T > ( + func assertTypeResolved( _ result: T?, file: StaticString = #filePath, line: UInt = #line @@ -118,7 +110,7 @@ final class ConfigurationSetTests: XCTestCase { line: line ) } - func assertCollectionResolves < T > ( + func assertCollectionResolves( _ type: T.Type, count expectedCount: Int, file: StaticString = #filePath, @@ -128,10 +120,10 @@ final class ConfigurationSetTests: XCTestCase { XCTAssert( actualCount >= expectedCount, """ - The resolved ServiceCollection<\(type)> did not contain the expected number of services \ - (resolved \(actualCount), expected \(expectedCount)). - Make sure your assembler contains a ServiceCollector behavior. - """, + The resolved ServiceCollection<\(type)> did not contain the expected number of services \ + (resolved \(actualCount), expected \(expectedCount)). + Make sure your assembler contains a ServiceCollector behavior. + """, file: file, line: line ) diff --git a/Tests/KnitCodeGenTests/ModuleAssemblyExtensionSourceFileTests.swift b/Tests/KnitCodeGenTests/ModuleAssemblyExtensionSourceFileTests.swift index 7dc0068..76bc077 100644 --- a/Tests/KnitCodeGenTests/ModuleAssemblyExtensionSourceFileTests.swift +++ b/Tests/KnitCodeGenTests/ModuleAssemblyExtensionSourceFileTests.swift @@ -10,8 +10,8 @@ import XCTest final class ModuleAssemblyExtensionSourceFileTests: XCTestCase { - func test_generation() { - let result = ModuleAssemblyExtensionSourceFile.make( + func test_generation() throws { + let result = try ModuleAssemblyExtensionSourceFile.make( currentModuleName: "CurrentModule", dependencyModuleNames: [ "DependencyA", @@ -21,7 +21,6 @@ final class ModuleAssemblyExtensionSourceFileTests: XCTestCase { ) let expected = #""" - // Generated using Knit // Do not edit directly! @@ -31,20 +30,20 @@ final class ModuleAssemblyExtensionSourceFileTests: XCTestCase { extension CurrentModuleAssembly: GeneratedModuleAssembly { public static var generatedDependencies: [any ModuleAssembly.Type] { [ - DependencyAAssembly.self, - DependencyBAssembly.self] + DependencyAAssembly.self, + DependencyBAssembly.self] } } """# XCTAssertEqual( - result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n"), + result.formatted().description, expected ) } - func test_generation_emptyDependencies() { - let result = ModuleAssemblyExtensionSourceFile.make( + func test_generation_emptyDependencies() throws { + let result = try ModuleAssemblyExtensionSourceFile.make( currentModuleName: "CurrentModule", dependencyModuleNames: [ ], @@ -52,7 +51,6 @@ final class ModuleAssemblyExtensionSourceFileTests: XCTestCase { ) let expected = #""" - // Generated using Knit // Do not edit directly! @@ -65,13 +63,13 @@ final class ModuleAssemblyExtensionSourceFileTests: XCTestCase { """# XCTAssertEqual( - result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n"), + result.formatted().description, expected ) } - func test_generation_additionalAssemblies() { - let result = ModuleAssemblyExtensionSourceFile.make( + func test_generation_additionalAssemblies() throws { + let result = try ModuleAssemblyExtensionSourceFile.make( currentModuleName: "CurrentModule", dependencyModuleNames: [ "DependencyA", @@ -83,7 +81,6 @@ final class ModuleAssemblyExtensionSourceFileTests: XCTestCase { ) let expected = #""" - // Generated using Knit // Do not edit directly! @@ -92,15 +89,15 @@ final class ModuleAssemblyExtensionSourceFileTests: XCTestCase { extension CurrentModuleAssembly: GeneratedModuleAssembly { public static var generatedDependencies: [any ModuleAssembly.Type] { [ - DependencyAAssembly.self, - DependencyASubAssembly.self, - DependencyAOtherAssembly.self] + DependencyAAssembly.self, + DependencyASubAssembly.self, + DependencyAOtherAssembly.self] } } """# XCTAssertEqual( - result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n"), + result.formatted().description, expected ) } diff --git a/Tests/KnitCodeGenTests/RegistrationParsingTests.swift b/Tests/KnitCodeGenTests/RegistrationParsingTests.swift index 63574b2..b9aae66 100644 --- a/Tests/KnitCodeGenTests/RegistrationParsingTests.swift +++ b/Tests/KnitCodeGenTests/RegistrationParsingTests.swift @@ -3,21 +3,22 @@ // @testable import KnitCodeGen +import SwiftSyntax import SwiftSyntaxBuilder import XCTest final class RegistrationParsingTests: XCTestCase { - func testRegistrationStatements() { - assertRegistrationString( + func testRegistrationStatements() throws { + try assertRegistrationString( "container.register(AType.self)", serviceName: "AType" ) - assertRegistrationString( + try assertRegistrationString( "container.autoregister(BType.self)", serviceName: "BType" ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(AType.self) { _ in } .implements(AnotherType.self) @@ -28,7 +29,7 @@ final class RegistrationParsingTests: XCTestCase { Registration(service: "AnotherType", name: nil, accessLevel: .internal, isForwarded: true), ] ) - assertRegistrationString( + try assertRegistrationString( """ container.autoregister( AnyPublisher.self, @@ -37,7 +38,7 @@ final class RegistrationParsingTests: XCTestCase { """, serviceName: "AnyPublisher" ) - assertRegistrationString( + try assertRegistrationString( """ container.autoregister( ((String) -> EntityGainLossDataArchiver).self, @@ -47,7 +48,7 @@ final class RegistrationParsingTests: XCTestCase { serviceName: "((String) -> EntityGainLossDataArchiver)" ) - assertRegistrationString( + try assertRegistrationString( """ // @knit public container.register(AType.self) @@ -57,8 +58,8 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testHiddenRegistrations() { - assertRegistrationString( + func testHiddenRegistrations() throws { + try assertRegistrationString( """ // @knit hidden container.register(AType.self) @@ -69,7 +70,7 @@ final class RegistrationParsingTests: XCTestCase { } func testNamedRegistrations() throws { - assertRegistrationString( + try assertRegistrationString( """ container.register(A.self, name: "service") { } """, @@ -77,7 +78,7 @@ final class RegistrationParsingTests: XCTestCase { name: "service" ) - assertRegistrationString( + try assertRegistrationString( """ container.autoregister(A.self, name: "service2", initializer: A.init) """, @@ -87,7 +88,7 @@ final class RegistrationParsingTests: XCTestCase { } func testGetterConfigRegistrations() throws { - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ // @knit public getter-named container.register(A.self) { } @@ -97,7 +98,7 @@ final class RegistrationParsingTests: XCTestCase { ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ // @knit public getter-callAsFunction container.register(A.self) { } @@ -107,7 +108,7 @@ final class RegistrationParsingTests: XCTestCase { ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ // @knit public getter-named getter-callAsFunction container.register(A.self) { } @@ -119,14 +120,14 @@ final class RegistrationParsingTests: XCTestCase { } func testAbstractRegistration() throws { - assertRegistrationString( + try assertRegistrationString( """ container.registerAbstract(AType.self) """, serviceName: "AType" ) - assertRegistrationString( + try assertRegistrationString( """ container.registerAbstract(AType.self, name: "service") """, @@ -136,7 +137,7 @@ final class RegistrationParsingTests: XCTestCase { } func testForwardedRegistration() throws { - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self) { } .implements(B.self) @@ -147,7 +148,7 @@ final class RegistrationParsingTests: XCTestCase { ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.autoregister(A.self, initializer: A.init) .implements(B.self) @@ -163,7 +164,7 @@ final class RegistrationParsingTests: XCTestCase { ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ // @knit hidden container.register(A.self) { } @@ -180,9 +181,9 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testRegisterWithArguments() { + func testRegisterWithArguments() throws { // Single argument, trailing closure - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self) { (_, arg: String) in A(string: arg) @@ -194,7 +195,7 @@ final class RegistrationParsingTests: XCTestCase { ) // Single argument, named parameter - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self, factory: { (_, arg: String) in A(string: arg) @@ -206,7 +207,7 @@ final class RegistrationParsingTests: XCTestCase { ) // Multiple arguments, trailing closure - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self) { (resolver: Resolver, arg: String, arg2: Int) in A() @@ -225,7 +226,7 @@ final class RegistrationParsingTests: XCTestCase { ) // Multiple arguments, named parameter - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self, factory: { (resolver: Resolver, arg: String, arg2: Int) in A() @@ -244,7 +245,7 @@ final class RegistrationParsingTests: XCTestCase { ) // Unused arguments (for test/abstract usages) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self, factory: { (_: Resolver, _: String) in A() @@ -262,9 +263,9 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testAutoregisterWithArguments() { + func testAutoregisterWithArguments() throws { // Single argument - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( "container.autoregister(A.self, argument: URL.self, initializer: A.init)", registrations: [ Registration(service: "A", accessLevel: .internal, arguments: [.init(type: "URL")]) @@ -272,7 +273,7 @@ final class RegistrationParsingTests: XCTestCase { ) // Multiple arguments - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.autoregister( A.self, @@ -296,7 +297,7 @@ final class RegistrationParsingTests: XCTestCase { ) // Single argument with name - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.autoregister(A.self, name: "test", argument: URL.self, initializer: A.init) """, @@ -306,7 +307,7 @@ final class RegistrationParsingTests: XCTestCase { ) // Multiple arguments with name - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.autoregister( A.self, @@ -327,9 +328,9 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testRegisterNonClosureFactoryType() { + func testRegisterNonClosureFactoryType() throws { // This is acceptable syntax but we will not be able to parse any arguments - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self, factory: A.staticFunc) """, @@ -340,9 +341,9 @@ final class RegistrationParsingTests: XCTestCase { } // Arguments on the main registration apply to implements also - func testForwardedWithArgument() { + func testForwardedWithArgument() throws { // Single argument autoregister - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.autoregister(A.self, argument: URL.self, initializer: A.init) .implements(B.self) @@ -354,7 +355,7 @@ final class RegistrationParsingTests: XCTestCase { ) // Single argument register - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self) { (_, arg: String) in A(string: arg) @@ -368,8 +369,8 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testRegistrationWithComplexTypes() { - assertMultipleRegistrationsString( + func testRegistrationWithComplexTypes() throws { + try assertMultipleRegistrationsString( """ container.register(A.self) { (_, arg: A.Argument) in A(string: arg.string) @@ -380,7 +381,7 @@ final class RegistrationParsingTests: XCTestCase { ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self) { (_, arg: Result) in A(string: arg.string) @@ -392,8 +393,8 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testAutoRegistrationWithComplexTypes() { - assertMultipleRegistrationsString( + func testAutoRegistrationWithComplexTypes() throws { + try assertMultipleRegistrationsString( """ container.autoregister(A.self, argument: String?.self, initializer: A.init) """, @@ -402,7 +403,7 @@ final class RegistrationParsingTests: XCTestCase { ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.autoregister(A.self, arguments: Result.self, Optional.self, initializer: A.init) """, @@ -417,7 +418,7 @@ final class RegistrationParsingTests: XCTestCase { ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ // @knit getter-named container.autoregister((String, Int?).self, initializer: Factory.make) @@ -428,15 +429,15 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testClosureArgument() { - assertMultipleRegistrationsString( + func testClosureArgument() throws { + try assertMultipleRegistrationsString( "container.autoregister(A.self, argument: (() -> Void).self, initializer: A.init)", registrations: [ Registration(service: "A", accessLevel: .internal, arguments: [.init(type: "(() -> Void)")]), ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.register(A.self) { (resolver, arg1: @escaping () -> Void) in A(arg: arg1) @@ -448,15 +449,15 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testArgumentMissingType() { + func testArgumentMissingType() throws { // Type of arg can be inferred at build time but cannot be parsed - let string = """ + let expr: ExprSyntax = """ container.register(A.self) { (_, myArg) in A(string: myArg) } """ - let functionCall = FunctionCallExpr(stringLiteral: string) + let functionCall = try XCTUnwrap(FunctionCallExprSyntax(expr)) XCTAssertThrowsError(try functionCall.getRegistrations()) { error in if case let RegistrationParsingError.missingArgumentType(_, name) = error { @@ -471,14 +472,14 @@ final class RegistrationParsingTests: XCTestCase { } } - func testUnsupportedClosureSynatx() { - let string = """ + func testUnsupportedClosureSynatx() throws { + let expr: ExprSyntax = """ container.register(A.self) { _, myArg in A(string: myArg) } """ - let functionCall = FunctionCallExpr(stringLiteral: string) + let functionCall = try XCTUnwrap(FunctionCallExprSyntax(expr)) XCTAssertThrowsError(try functionCall.getRegistrations()) { error in if case RegistrationParsingError.unwrappedClosureParams = error { @@ -493,14 +494,14 @@ final class RegistrationParsingTests: XCTestCase { } } - func testInvalidName() { - let string = """ + func testInvalidName() throws { + let expr: ExprSyntax = """ container.register(A.self, name: name) { _ in A() } """ - let functionCall = FunctionCallExpr(stringLiteral: string) + let functionCall = try XCTUnwrap(FunctionCallExprSyntax(expr)) XCTAssertThrowsError(try functionCall.getRegistrations()) { error in XCTAssertEqual( @@ -510,8 +511,8 @@ final class RegistrationParsingTests: XCTestCase { } } - func testRegistrationIntoCollection() { - assertMultipleRegistrationsString( + func testRegistrationIntoCollection() throws { + try assertMultipleRegistrationsString( """ container.registerIntoCollection(AType.self) {} .inObjectScope(.container) @@ -520,7 +521,7 @@ final class RegistrationParsingTests: XCTestCase { .init(service: "AType"), ] ) - assertMultipleRegistrationsString( + try assertMultipleRegistrationsString( """ container.autoregisterIntoCollection(AType.self, initializer: AType.init) .inObjectScope(.container) @@ -531,8 +532,8 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testMultiLineComments() { - assertMultipleRegistrationsString( + func testMultiLineComments() throws { + try assertMultipleRegistrationsString( """ // General comment // @knit public @@ -545,11 +546,11 @@ final class RegistrationParsingTests: XCTestCase { ) } - func testIncorrectRegistrations() { - assertNoRegistrationsString("container.someOtherMethod(AType.self)", message: "Incorrect method name") - assertNoRegistrationsString("container.register(A)", message: "First param is not a metatype") - assertNoRegistrationsString("doThing()", message:"Unrelated function call") - assertNoRegistrationsString("container.implements(AType.self)", message: "Missing primary registration") + func testIncorrectRegistrations() throws { + try assertNoRegistrationsString("container.someOtherMethod(AType.self)", message: "Incorrect method name") + try assertNoRegistrationsString("container.register(A)", message: "First param is not a metatype") + try assertNoRegistrationsString("doThing()", message:"Unrelated function call") + try assertNoRegistrationsString("container.implements(AType.self)", message: "Missing primary registration") } } @@ -562,10 +563,10 @@ private func assertRegistrationString( name: String? = nil, isForwarded: Bool = false, file: StaticString = #filePath, line: UInt = #line -) { - let functionCall = FunctionCallExpr(stringLiteral: string) +) throws { + let functionCall = try XCTUnwrap(FunctionCallExprSyntax("\(raw: string)" as ExprSyntax)) - let (registrations, registrationsIntoCollecions) = try! functionCall.getRegistrations() + let (registrations, registrationsIntoCollecions) = try functionCall.getRegistrations() XCTAssertEqual(registrations.count, 1, file: file, line: line) XCTAssert(registrationsIntoCollecions.isEmpty, file: file, line: line) @@ -583,10 +584,10 @@ private func assertMultipleRegistrationsString( registrations: [Registration] = [], registrationsIntoCollections: [RegistrationIntoCollection] = [], file: StaticString = #filePath, line: UInt = #line -) { - let functionCall = FunctionCallExpr(stringLiteral: string) +) throws { + let functionCall = try XCTUnwrap(FunctionCallExprSyntax("\(raw: string)" as ExprSyntax)) - let (parsedRegistrations, parsedRegistrationsIntoCollections) = try! functionCall.getRegistrations() + let (parsedRegistrations, parsedRegistrationsIntoCollections) = try functionCall.getRegistrations() XCTAssertEqual(parsedRegistrations.count, registrations.count, file: file, line: line) XCTAssertEqual(parsedRegistrations, registrations, file: file, line: line) @@ -599,9 +600,9 @@ private func assertNoRegistrationsString( _ string: String, message: String = "", file: StaticString = #filePath, line: UInt = #line -) { - let functionCall = FunctionCallExpr(stringLiteral: string) - let (registrations, registrationsIntoCollections) = try! functionCall.getRegistrations() +) throws { + let functionCall = try XCTUnwrap(FunctionCallExprSyntax("\(raw: string)" as ExprSyntax)) + let (registrations, registrationsIntoCollections) = try functionCall.getRegistrations() XCTAssert(registrations.isEmpty, message, file: file, line: line) XCTAssert(registrationsIntoCollections.isEmpty, message, file: file, line: line) } diff --git a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift index 9cc4e8e..79eb1f9 100644 --- a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift +++ b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift @@ -7,8 +7,8 @@ import XCTest final class TypeSafetySourceFileTests: XCTestCase { - func test_generation() { - let result = TypeSafetySourceFile.make( + func test_generation() throws { + let result = try TypeSafetySourceFile.make( assemblyName: "ModuleAssembly", extensionTarget: "Resolve", registrations: [ @@ -24,7 +24,6 @@ final class TypeSafetySourceFileTests: XCTestCase { ) let expected = """ - // Generated from ModuleAssembly extension Resolve { func serviceA() -> ServiceA { @@ -63,7 +62,7 @@ final class TypeSafetySourceFileTests: XCTestCase { func testRegistrationMultipleArguments() { let registration = Registration(service: "A", accessLevel: .public, arguments: [.init(type: "String"), .init(type: "URL")]) XCTAssertEqual( - TypeSafetySourceFile.makeResolver( + try TypeSafetySourceFile.makeResolver( registration: registration, enumName: nil ).formatted().description, @@ -78,7 +77,7 @@ final class TypeSafetySourceFileTests: XCTestCase { func testRegistrationSingleArgument() { let registration = Registration(service: "A", accessLevel: .public, arguments: [.init(type: "String")]) XCTAssertEqual( - TypeSafetySourceFile.makeResolver( + try TypeSafetySourceFile.makeResolver( registration: registration, enumName: nil ).formatted().description, @@ -93,7 +92,7 @@ final class TypeSafetySourceFileTests: XCTestCase { func testRegistrationDuplicateParamType() { let registration = Registration(service: "A", accessLevel: .public, arguments: [.init(type: "String"), .init(type: "String")]) XCTAssertEqual( - TypeSafetySourceFile.makeResolver( + try TypeSafetySourceFile.makeResolver( registration: registration, enumName: nil ).formatted().description, @@ -108,7 +107,7 @@ final class TypeSafetySourceFileTests: XCTestCase { func testRegistrationArgumentAndName() { let registration = Registration(service: "A", name: "test", accessLevel: .public, arguments: [.init(type: "String")]) XCTAssertEqual( - TypeSafetySourceFile.makeResolver( + try TypeSafetySourceFile.makeResolver( registration: registration, enumName: "MyAssembly.A_ResolutionKey" ).formatted().description, @@ -123,7 +122,7 @@ final class TypeSafetySourceFileTests: XCTestCase { func testRegistrationWithPrenamedArguments() { let registration = Registration(service: "A", accessLevel: .public, arguments: [.init(identifier: "arg", type: "String")]) XCTAssertEqual( - TypeSafetySourceFile.makeResolver( + try TypeSafetySourceFile.makeResolver( registration: registration, enumName: nil ).formatted().description, diff --git a/Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift b/Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift index 14596d8..1200612 100644 --- a/Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift +++ b/Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift @@ -7,10 +7,10 @@ import XCTest final class UnitTestSourceFileTests: XCTestCase { - func test_generation() { - let result = UnitTestSourceFile.make( + func test_generation() throws { + let result = try UnitTestSourceFile.make( name: "MyModule", - importDecls: [ImportDeclSyntax("import Swinject")], + importDecls: [try ImportDeclSyntax("import Swinject")], registrations: [ .init(service: "ServiceA", name: nil, accessLevel: .internal, isForwarded: false), .init(service: "ServiceB", name: "name", accessLevel: .internal, isForwarded: false), @@ -25,10 +25,9 @@ final class UnitTestSourceFileTests: XCTestCase { ) //Remote trailing line spaces - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = #""" - final class MyModuleRegistrationTests: XCTestCase { func testRegistrations() { // In the test target for your module, please provide a static method that creates a @@ -54,19 +53,18 @@ final class UnitTestSourceFileTests: XCTestCase { XCTAssertEqual(formattedResult, expected) } - func test_generation_emptyRegistrations() { - let result = UnitTestSourceFile.make( + func test_generation_emptyRegistrations() throws { + let result = try UnitTestSourceFile.make( name: "MyModule", - importDecls: [ImportDeclSyntax("import Swinject")], + importDecls: [try ImportDeclSyntax("import Swinject")], registrations: [], registrationsIntoCollections: [] ) //Remote trailing line spaces - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = #""" - final class MyModuleRegistrationTests: XCTestCase { func testRegistrations() { // In the test target for your module, please provide a static method that creates a @@ -80,10 +78,10 @@ final class UnitTestSourceFileTests: XCTestCase { XCTAssertEqual(formattedResult, expected) } - func test_generation_onlySingleRegistrations() { - let result = UnitTestSourceFile.make( + func test_generation_onlySingleRegistrations() throws { + let result = try UnitTestSourceFile.make( name: "MyModule", - importDecls: [ImportDeclSyntax("import Swinject")], + importDecls: [try ImportDeclSyntax("import Swinject")], registrations: [ .init(service: "ServiceA", name: nil, accessLevel: .internal, isForwarded: false), ], @@ -91,10 +89,9 @@ final class UnitTestSourceFileTests: XCTestCase { ) //Remote trailing line spaces - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = #""" - final class MyModuleRegistrationTests: XCTestCase { func testRegistrations() { // In the test target for your module, please provide a static method that creates a @@ -109,10 +106,10 @@ final class UnitTestSourceFileTests: XCTestCase { XCTAssertEqual(formattedResult, expected) } - func test_generation_onlyRegistrationsIntoCollections() { - let result = UnitTestSourceFile.make( + func test_generation_onlyRegistrationsIntoCollections() throws { + let result = try UnitTestSourceFile.make( name: "MyModule", - importDecls: [ImportDeclSyntax("import Swinject")], + importDecls: [try ImportDeclSyntax("import Swinject")], registrations: [], registrationsIntoCollections: [ .init(service: "ServiceA"), @@ -120,10 +117,9 @@ final class UnitTestSourceFileTests: XCTestCase { ) //Remote trailing line spaces - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = #""" - final class MyModuleRegistrationTests: XCTestCase { func testRegistrations() { // In the test target for your module, please provide a static method that creates a @@ -138,15 +134,15 @@ final class UnitTestSourceFileTests: XCTestCase { XCTAssertEqual(formattedResult, expected) } - func test_argumentStruct() { + func test_argumentStruct() throws { let registrations = [ Registration(service: "A", accessLevel: .public, arguments: [.init(type: "String")]), Registration(service: "B", accessLevel: .public, arguments: [.init(identifier: "field", type: "String"), .init(type: "String")]), Registration(service: "A", accessLevel: .public, arguments: [.init(type: "Int"), .init(type: "String")]), ] - let result = UnitTestSourceFile.makeArgumentStruct(registrations: registrations, moduleName: "MyModule") + let result = try UnitTestSourceFile.makeArgumentStruct(registrations: registrations, moduleName: "MyModule") - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = """ struct MyModuleRegistrationTestArguments { @@ -162,14 +158,14 @@ final class UnitTestSourceFileTests: XCTestCase { func test_registrationAssertPlain() { let result = UnitTestSourceFile.makeAssertCall(registration: .init(service: "A", accessLevel: .hidden)) - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = "resolver.assertTypeResolves(A.self)" XCTAssertEqual(formattedResult, expected) } func test_registrationAssertNamed() { let result = UnitTestSourceFile.makeAssertCall(registration: .init(service: "A", name: "Name", accessLevel: .hidden)) - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = "resolver.assertTypeResolves(A.self, name: \"Name\")" XCTAssertEqual(formattedResult, expected) } @@ -184,7 +180,7 @@ final class UnitTestSourceFileTests: XCTestCase { ] ) let result = UnitTestSourceFile.makeAssertCall(registration: registration) - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = "resolver.assertTypeResolved(resolver.resolve(A.self, name: \"Name\", argument: args.aString))" XCTAssertEqual(formattedResult, expected) } @@ -199,7 +195,7 @@ final class UnitTestSourceFileTests: XCTestCase { ] ) let result = UnitTestSourceFile.makeAssertCall(registration: registration) - let formattedResult = result.formatted().description.replacingOccurrences(of: ", \n", with: ",\n") + let formattedResult = result.formatted().description let expected = "resolver.assertTypeResolved(resolver.resolve(A.self, arguments: args.aString1, args.aString2))" XCTAssertEqual(formattedResult, expected) } @@ -212,13 +208,13 @@ private extension UnitTestSourceFile { importDecls: [ImportDeclSyntax], registrations: [Registration], registrationsIntoCollections: [RegistrationIntoCollection] - ) -> SourceFileSyntax { + ) throws -> SourceFileSyntax { let configuration = Configuration( name: name, registrations: registrations, registrationsIntoCollections: registrationsIntoCollections, imports: importDecls ) - return UnitTestSourceFile.make(configuration: configuration) + return try UnitTestSourceFile.make(configuration: configuration) } }