diff --git a/Sources/GRPCCodeGen/Internal/StructuredSwift+Server.swift b/Sources/GRPCCodeGen/Internal/StructuredSwift+Server.swift new file mode 100644 index 000000000..c46986fa3 --- /dev/null +++ b/Sources/GRPCCodeGen/Internal/StructuredSwift+Server.swift @@ -0,0 +1,425 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +extension FunctionSignatureDescription { + /// ``` + /// func ( + /// request: GRPCCore.ServerRequest, + /// context: GRPCCore.ServerContext + /// ) async throws -> GRPCCore.ServerResponse + /// ``` + static func serverMethod( + accessLevel: AccessModifier? = nil, + name: String, + input: String, + output: String, + streamingInput: Bool, + streamingOutput: Bool + ) -> Self { + return FunctionSignatureDescription( + accessModifier: accessLevel, + kind: .function(name: name), + parameters: [ + ParameterDescription( + label: "request", + type: .serverRequest(forType: input, streaming: streamingInput) + ), + ParameterDescription(label: "context", type: .serverContext), + ], + keywords: [.async, .throws], + returnType: .identifierType(.serverResponse(forType: output, streaming: streamingOutput)) + ) + } +} + +extension ProtocolDescription { + /// ``` + /// protocol : GRPCCore.RegistrableRPCService { + /// ... + /// } + /// ``` + static func streamingService( + accessLevel: AccessModifier? = nil, + name: String, + methods: [MethodDescriptor] + ) -> Self { + return ProtocolDescription( + accessModifier: accessLevel, + name: name, + conformances: ["GRPCCore.RegistrableRPCService"], + members: methods.map { method in + .commentable( + .preFormatted(method.documentation), + .function( + signature: .serverMethod( + name: method.name.generatedLowerCase, + input: method.inputType, + output: method.outputType, + streamingInput: true, + streamingOutput: true + ) + ) + ) + } + ) + } +} + +extension ExtensionDescription { + /// ``` + /// extension { + /// func registerMethods(with router: inout GRPCCore.RPCRouter) { + /// // ... + /// } + /// } + /// ``` + static func registrableRPCServiceDefaultImplementation( + accessLevel: AccessModifier? = nil, + on extensionName: String, + serviceNamespace: String, + methods: [MethodDescriptor], + serializer: (String) -> String, + deserializer: (String) -> String + ) -> Self { + return ExtensionDescription( + onType: extensionName, + declarations: [ + .function( + .registerMethods( + accessLevel: accessLevel, + serviceNamespace: serviceNamespace, + methods: methods, + serializer: serializer, + deserializer: deserializer + ) + ) + ] + ) + } +} + +extension ProtocolDescription { + /// ``` + /// protocol : { + /// ... + /// } + /// ``` + static func service( + accessLevel: AccessModifier? = nil, + name: String, + streamingProtocol: String, + methods: [MethodDescriptor] + ) -> Self { + return ProtocolDescription( + accessModifier: accessLevel, + name: name, + conformances: [streamingProtocol], + members: methods.map { method in + .commentable( + .preFormatted(method.documentation), + .function( + signature: .serverMethod( + name: method.name.generatedLowerCase, + input: method.inputType, + output: method.outputType, + streamingInput: method.isInputStreaming, + streamingOutput: method.isOutputStreaming + ) + ) + ) + } + ) + } +} + +extension FunctionCallDescription { + /// ``` + /// self.(request: request, context: context) + /// ``` + static func serverMethodCallOnSelf( + name: String, + requestArgument: Expression = .identifierPattern("request") + ) -> Self { + return FunctionCallDescription( + calledExpression: .memberAccess( + MemberAccessDescription( + left: .identifierPattern("self"), + right: name + ) + ), + arguments: [ + FunctionArgumentDescription( + label: "request", + expression: requestArgument + ), + FunctionArgumentDescription( + label: "context", + expression: .identifierPattern("context") + ), + ] + ) + } +} + +extension ClosureInvocationDescription { + /// ``` + /// { router, context in + /// try await self.( + /// request: request, + /// context: context + /// ) + /// } + /// ``` + static func routerHandlerInvokingRPC(method: String) -> Self { + return ClosureInvocationDescription( + argumentNames: ["request", "context"], + body: [ + .expression( + .unaryKeyword( + kind: .try, + expression: .unaryKeyword( + kind: .await, + expression: .functionCall(.serverMethodCallOnSelf(name: method)) + ) + ) + ) + ] + ) + } +} + +/// ``` +/// router.registerHandler( +/// forMethod: ..., +/// deserializer: ... +/// serializer: ... +/// handler: { request, context in +/// // ... +/// } +/// ) +/// ``` +extension FunctionCallDescription { + static func registerWithRouter( + serviceNamespace: String, + methodNamespace: String, + methodName: String, + inputDeserializer: String, + outputSerializer: String + ) -> Self { + return FunctionCallDescription( + calledExpression: .memberAccess( + .init(left: .identifierPattern("router"), right: "registerHandler") + ), + arguments: [ + FunctionArgumentDescription( + label: "forMethod", + expression: .identifierPattern("\(serviceNamespace).Method.\(methodNamespace).descriptor") + ), + FunctionArgumentDescription( + label: "deserializer", + expression: .identifierPattern(inputDeserializer) + ), + FunctionArgumentDescription( + label: "serializer", + expression: .identifierPattern(outputSerializer) + ), + FunctionArgumentDescription( + label: "handler", + expression: .closureInvocation(.routerHandlerInvokingRPC(method: methodName)) + ), + ] + ) + } +} + +extension FunctionDescription { + /// ``` + /// func registerMethods(with router: inout GRPCCore.RPCRouter) { + /// // ... + /// } + /// ``` + static func registerMethods( + accessLevel: AccessModifier? = nil, + serviceNamespace: String, + methods: [MethodDescriptor], + serializer: (String) -> String, + deserializer: (String) -> String + ) -> Self { + return FunctionDescription( + accessModifier: accessLevel, + kind: .function(name: "registerMethods"), + parameters: [ + ParameterDescription( + label: "with", + name: "router", + type: .rpcRouter, + `inout`: true + ) + ], + body: methods.map { method in + .functionCall( + .registerWithRouter( + serviceNamespace: serviceNamespace, + methodNamespace: method.name.generatedUpperCase, + methodName: method.name.generatedLowerCase, + inputDeserializer: deserializer(method.inputType), + outputSerializer: serializer(method.outputType) + ) + ) + } + ) + } +} + +extension FunctionDescription { + /// ``` + /// func ( + /// request: GRPCCore.StreamingServerRequest + /// context: GRPCCore.ServerContext + /// ) async throws -> GRPCCore.StreamingServerResponse { + /// let response = try await self.( + /// request: GRPCCore.ServerRequest(stream: request), + /// context: context + /// ) + /// return GRPCCore.StreamingServerResponse(single: response) + /// } + /// ``` + static func serverStreamingMethodsCallingMethod( + accessLevel: AccessModifier? = nil, + name: String, + input: String, + output: String, + streamingInput: Bool, + streamingOutput: Bool + ) -> FunctionDescription { + let signature: FunctionSignatureDescription = .serverMethod( + accessLevel: accessLevel, + name: name, + input: input, + output: output, + // This method converts from the fully streamed version to the specified version. + streamingInput: true, + streamingOutput: true + ) + + // Call the underlying function. + let functionCall: Expression = .functionCall( + calledExpression: .memberAccess( + MemberAccessDescription( + left: .identifierPattern("self"), + right: name + ) + ), + arguments: [ + FunctionArgumentDescription( + label: "request", + expression: streamingInput + ? .identifierPattern("request") + : .functionCall( + calledExpression: .identifierType(.serverRequest(forType: nil, streaming: false)), + arguments: [ + FunctionArgumentDescription( + label: "stream", + expression: .identifierPattern("request") + ) + ] + ) + ), + FunctionArgumentDescription( + label: "context", + expression: .identifierPattern("context") + ), + ] + ) + + // Call the function and assign to 'response'. + let response: Declaration = .variable( + kind: .let, + left: "response", + right: .unaryKeyword( + kind: .try, + expression: .unaryKeyword( + kind: .await, + expression: functionCall + ) + ) + ) + + // Build the return statement. + let returnExpression: Expression = .unaryKeyword( + kind: .return, + expression: streamingOutput + ? .identifierPattern("response") + : .functionCall( + calledExpression: .identifierType(.serverResponse(forType: nil, streaming: true)), + arguments: [ + FunctionArgumentDescription( + label: "single", + expression: .identifierPattern("response") + ) + ] + ) + ) + + return Self( + signature: signature, + body: [.declaration(response), .expression(returnExpression)] + ) + } +} + +extension ExtensionDescription { + /// ``` + /// extension { + /// func ( + /// request: GRPCCore.StreamingServerRequest + /// context: GRPCCore.ServerContext + /// ) async throws -> GRPCCore.StreamingServerResponse { + /// let response = try await self.( + /// request: GRPCCore.ServerRequest(stream: request), + /// context: context + /// ) + /// return GRPCCore.StreamingServerResponse(single: response) + /// } + /// ... + /// } + /// ``` + static func streamingServiceProtocolDefaultImplementation( + accessModifier: AccessModifier? = nil, + on extensionName: String, + methods: [MethodDescriptor] + ) -> Self { + return ExtensionDescription( + onType: extensionName, + declarations: methods.compactMap { method -> Declaration? in + // Bidirectional streaming methods don't need a default implementation as their signatures + // match across the two protocols. + if method.isInputStreaming, method.isOutputStreaming { return nil } + + return .function( + .serverStreamingMethodsCallingMethod( + accessLevel: accessModifier, + name: method.name.generatedLowerCase, + input: method.inputType, + output: method.outputType, + streamingInput: method.isInputStreaming, + streamingOutput: method.isOutputStreaming + ) + ) + } + ) + } +} diff --git a/Tests/GRPCCodeGenTests/Internal/StructuredSwift+ServerTests.swift b/Tests/GRPCCodeGenTests/Internal/StructuredSwift+ServerTests.swift new file mode 100644 index 000000000..78dbac42d --- /dev/null +++ b/Tests/GRPCCodeGenTests/Internal/StructuredSwift+ServerTests.swift @@ -0,0 +1,336 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Testing + +@testable import GRPCCodeGen + +extension StructuedSwiftTests { + @Suite("Server") + struct Server { + @Test( + "func (request:context:) async throws -> ...", + arguments: AccessModifier.allCases, + RPCKind.allCases + ) + func serverMethodSignature(access: AccessModifier, kind: RPCKind) { + let decl: FunctionSignatureDescription = .serverMethod( + accessLevel: access, + name: "foo", + input: "Input", + output: "Output", + streamingInput: kind.streamsInput, + streamingOutput: kind.streamsOutput + ) + + let expected: String + + switch kind { + case .unary: + expected = """ + \(access) func foo( + request: GRPCCore.ServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.ServerResponse + """ + case .clientStreaming: + expected = """ + \(access) func foo( + request: GRPCCore.StreamingServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.ServerResponse + """ + case .serverStreaming: + expected = """ + \(access) func foo( + request: GRPCCore.ServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.StreamingServerResponse + """ + case .bidirectionalStreaming: + expected = """ + \(access) func foo( + request: GRPCCore.StreamingServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.StreamingServerResponse + """ + } + + #expect(render(.function(signature: decl)) == expected) + } + + @Test("protocol StreamingServiceProtocol { ... }", arguments: AccessModifier.allCases) + func serverStreamingServiceProtocol(access: AccessModifier) { + let decl: ProtocolDescription = .streamingService( + accessLevel: access, + name: "FooService", + methods: [ + .init( + documentation: "/// Some docs", + name: .init(base: "Foo", generatedUpperCase: "Foo", generatedLowerCase: "foo"), + isInputStreaming: false, + isOutputStreaming: false, + inputType: "FooInput", + outputType: "FooOutput" + ) + ] + ) + + let expected = """ + \(access) protocol FooService: GRPCCore.RegistrableRPCService { + /// Some docs + func foo( + request: GRPCCore.StreamingServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.StreamingServerResponse + } + """ + + #expect(render(.protocol(decl)) == expected) + } + + @Test("protocol ServiceProtocol { ... }", arguments: AccessModifier.allCases) + func serverServiceProtocol(access: AccessModifier) { + let decl: ProtocolDescription = .service( + accessLevel: access, + name: "FooService", + streamingProtocol: "FooService_StreamingServiceProtocol", + methods: [ + .init( + documentation: "/// Some docs", + name: .init(base: "Foo", generatedUpperCase: "Foo", generatedLowerCase: "foo"), + isInputStreaming: false, + isOutputStreaming: false, + inputType: "FooInput", + outputType: "FooOutput" + ) + ] + ) + + let expected = """ + \(access) protocol FooService: FooService_StreamingServiceProtocol { + /// Some docs + func foo( + request: GRPCCore.ServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.ServerResponse + } + """ + + #expect(render(.protocol(decl)) == expected) + } + + @Test("{ router, context in try await self.(...) }") + func routerHandlerInvokingRPC() { + let expression: ClosureInvocationDescription = .routerHandlerInvokingRPC(method: "foo") + let expected = """ + { request, context in + try await self.foo( + request: request, + context: context + ) + } + """ + #expect(render(.closureInvocation(expression)) == expected) + } + + @Test("router.registerHandler(...) { ... }") + func registerMethodsWithRouter() { + let expression: FunctionCallDescription = .registerWithRouter( + serviceNamespace: "FooService", + methodNamespace: "Bar", + methodName: "bar", + inputDeserializer: "Deserialize()", + outputSerializer: "Serialize()" + ) + + let expected = """ + router.registerHandler( + forMethod: FooService.Method.Bar.descriptor, + deserializer: Deserialize(), + serializer: Serialize(), + handler: { request, context in + try await self.bar( + request: request, + context: context + ) + } + ) + """ + + #expect(render(.functionCall(expression)) == expected) + } + + @Test("func registerMethods(router:)", arguments: AccessModifier.allCases) + func registerMethods(access: AccessModifier) { + let expression: FunctionDescription = .registerMethods( + accessLevel: access, + serviceNamespace: "FooService", + methods: [ + .init( + documentation: "", + name: .init(base: "Bar", generatedUpperCase: "Bar", generatedLowerCase: "bar"), + isInputStreaming: false, + isOutputStreaming: false, + inputType: "BarInput", + outputType: "BarOutput" + ) + ] + ) { type in + "Serialize<\(type)>()" + } deserializer: { type in + "Deserialize<\(type)>()" + } + + let expected = """ + \(access) func registerMethods(with router: inout GRPCCore.RPCRouter) { + router.registerHandler( + forMethod: FooService.Method.Bar.descriptor, + deserializer: Deserialize(), + serializer: Serialize(), + handler: { request, context in + try await self.bar( + request: request, + context: context + ) + } + ) + } + """ + + #expect(render(.function(expression)) == expected) + } + + @Test( + "func (request:context:) async throw { ... (convert to/from single) ... }", + arguments: AccessModifier.allCases, + RPCKind.allCases + ) + func serverStreamingMethodsCallingMethod(access: AccessModifier, kind: RPCKind) { + let expression: FunctionDescription = .serverStreamingMethodsCallingMethod( + accessLevel: access, + name: "foo", + input: "Input", + output: "Output", + streamingInput: kind.streamsInput, + streamingOutput: kind.streamsOutput + ) + + let expected: String + + switch kind { + case .unary: + expected = """ + \(access) func foo( + request: GRPCCore.StreamingServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.StreamingServerResponse { + let response = try await self.foo( + request: GRPCCore.ServerRequest(stream: request), + context: context + ) + return GRPCCore.StreamingServerResponse(single: response) + } + """ + case .serverStreaming: + expected = """ + \(access) func foo( + request: GRPCCore.StreamingServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.StreamingServerResponse { + let response = try await self.foo( + request: GRPCCore.ServerRequest(stream: request), + context: context + ) + return response + } + """ + case .clientStreaming: + expected = """ + \(access) func foo( + request: GRPCCore.StreamingServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.StreamingServerResponse { + let response = try await self.foo( + request: request, + context: context + ) + return GRPCCore.StreamingServerResponse(single: response) + } + """ + case .bidirectionalStreaming: + expected = """ + \(access) func foo( + request: GRPCCore.StreamingServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.StreamingServerResponse { + let response = try await self.foo( + request: request, + context: context + ) + return response + } + """ + } + + #expect(render(.function(expression)) == expected) + } + + @Test("extension FooService_ServiceProtocol { ... }", arguments: AccessModifier.allCases) + func streamingServiceProtocolDefaultImplementation(access: AccessModifier) { + let decl: ExtensionDescription = .streamingServiceProtocolDefaultImplementation( + accessModifier: access, + on: "Foo_ServiceProtocol", + methods: [ + .init( + documentation: "", + name: .init(base: "Foo", generatedUpperCase: "Foo", generatedLowerCase: "foo"), + isInputStreaming: false, + isOutputStreaming: false, + inputType: "FooInput", + outputType: "FooOutput" + ), + // Will be ignored as a bidirectional streaming method. + .init( + documentation: "", + name: .init(base: "Bar", generatedUpperCase: "Bar", generatedLowerCase: "bar"), + isInputStreaming: true, + isOutputStreaming: true, + inputType: "BarInput", + outputType: "BarOutput" + ), + ] + ) + + let expected = """ + extension Foo_ServiceProtocol { + \(access) func foo( + request: GRPCCore.StreamingServerRequest, + context: GRPCCore.ServerContext + ) async throws -> GRPCCore.StreamingServerResponse { + let response = try await self.foo( + request: GRPCCore.ServerRequest(stream: request), + context: context + ) + return GRPCCore.StreamingServerResponse(single: response) + } + } + """ + + #expect(render(.extension(decl)) == expected) + } + } +}