diff --git a/Package.swift b/Package.swift index aa5aa60..33661d8 100644 --- a/Package.swift +++ b/Package.swift @@ -39,7 +39,7 @@ let dependencies: [Package.Dependency] = [ ), .package( url: "https://github.com/apple/swift-nio.git", - from: "2.65.0" + from: "2.75.0" ), .package( url: "https://github.com/apple/swift-nio-http2.git", diff --git a/Sources/GRPCNIOTransportCore/Client/Connection/ClientConnectionHandler.swift b/Sources/GRPCNIOTransportCore/Client/Connection/ClientConnectionHandler.swift index 0b5297b..3d8dbcb 100644 --- a/Sources/GRPCNIOTransportCore/Client/Connection/ClientConnectionHandler.swift +++ b/Sources/GRPCNIOTransportCore/Client/Connection/ClientConnectionHandler.swift @@ -63,17 +63,14 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo /// The `EventLoop` of the `Channel` this handler exists in. private let eventLoop: any EventLoop - /// The maximum amount of time the connection may be idle for. If the connection remains idle - /// (i.e. has no open streams) for this period of time then the connection will be gracefully - /// closed. - private var maxIdleTimer: Timer? + /// The timer used to gracefully close idle connections. + private var maxIdleTimerHandler: Timer? - /// The amount of time to wait before sending a keep alive ping. - private var keepaliveTimer: Timer? + /// The timer used to send keep-alive pings. + private var keepaliveTimerHandler: Timer? - /// The amount of time the client has to reply after sending a keep alive ping. Only used if - /// `keepaliveTimer` is set. - private var keepaliveTimeoutTimer: Timer + /// The timer used to detect keep alive timeouts, if keep-alive pings are enabled. + private var keepaliveTimeoutHandler: Timer? /// Opaque data sent in keep alive pings. private let keepalivePingData: HTTP2PingData @@ -110,14 +107,34 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo keepaliveWithoutCalls: Bool ) { self.eventLoop = eventLoop - self.maxIdleTimer = maxIdleTime.map { Timer(delay: $0) } - self.keepaliveTimer = keepaliveTime.map { Timer(delay: $0, repeat: true) } - self.keepaliveTimeoutTimer = Timer(delay: keepaliveTimeout ?? .seconds(20)) self.keepalivePingData = HTTP2PingData(withInteger: .random(in: .min ... .max)) self.state = StateMachine(allowKeepaliveWithoutCalls: keepaliveWithoutCalls) self.flushPending = false self.inReadLoop = false + if let maxIdleTime { + self.maxIdleTimerHandler = Timer( + eventLoop: eventLoop, + duration: maxIdleTime, + repeating: false, + handler: MaxIdleTimerHandlerView(self) + ) + } + if let keepaliveTime { + let keepaliveTimeout = keepaliveTimeout ?? .seconds(20) + self.keepaliveTimerHandler = Timer( + eventLoop: eventLoop, + duration: keepaliveTime, + repeating: true, + handler: KeepaliveTimerHandlerView(self) + ) + self.keepaliveTimeoutHandler = Timer( + eventLoop: eventLoop, + duration: keepaliveTimeout, + repeating: false, + handler: KeepaliveTimeoutHandlerView(self) + ) + } } package func handlerAdded(context: ChannelHandlerContext) { @@ -142,8 +159,8 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo promise.succeed() } - self.keepaliveTimer?.cancel() - self.keepaliveTimeoutTimer.cancel() + self.keepaliveTimerHandler?.cancel() + self.keepaliveTimeoutHandler?.cancel() context.fireChannelInactive() } @@ -222,11 +239,8 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo // Pings are ack'd by the HTTP/2 handler so we only pay attention to acks here, and in // particular only those carrying the keep-alive data. if ack, data == self.keepalivePingData { - let loopBound = LoopBoundView(handler: self, context: context) - self.keepaliveTimeoutTimer.cancel() - self.keepaliveTimer?.schedule(on: context.eventLoop) { - loopBound.keepaliveTimerFired() - } + self.keepaliveTimeoutHandler?.cancel() + self.keepaliveTimerHandler?.start() } case .settings(.settings(_)): @@ -236,15 +250,8 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo // becoming active is insufficient as, for example, a TLS handshake may fail after // establishing the TCP connection, or the server isn't configured for gRPC (or HTTP/2). if isInitialSettings { - let loopBound = LoopBoundView(handler: self, context: context) - self.keepaliveTimer?.schedule(on: context.eventLoop) { - loopBound.keepaliveTimerFired() - } - - self.maxIdleTimer?.schedule(on: context.eventLoop) { - loopBound.maxIdleTimerFired() - } - + self.keepaliveTimerHandler?.start() + self.maxIdleTimerHandler?.start() context.fireChannelRead(self.wrapInboundOut(.ready)) } @@ -290,29 +297,44 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo } } +// Timer handler views. extension ClientConnectionHandler { - struct LoopBoundView: @unchecked Sendable { + struct MaxIdleTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler { private let handler: ClientConnectionHandler - private let context: ChannelHandlerContext - init(handler: ClientConnectionHandler, context: ChannelHandlerContext) { + init(_ handler: ClientConnectionHandler) { self.handler = handler - self.context = context } - func keepaliveTimerFired() { - self.context.eventLoop.assertInEventLoop() - self.handler.keepaliveTimerFired(context: self.context) + func handleScheduledCallback(eventLoop: some EventLoop) { + self.handler.eventLoop.assertInEventLoop() + self.handler.maxIdleTimerFired() } + } + + struct KeepaliveTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler { + private let handler: ClientConnectionHandler - func keepaliveTimeoutExpired() { - self.context.eventLoop.assertInEventLoop() - self.handler.keepaliveTimeoutExpired(context: self.context) + init(_ handler: ClientConnectionHandler) { + self.handler = handler } - func maxIdleTimerFired() { - self.context.eventLoop.assertInEventLoop() - self.handler.maxIdleTimerFired(context: self.context) + func handleScheduledCallback(eventLoop: some EventLoop) { + self.handler.eventLoop.assertInEventLoop() + self.handler.keepaliveTimerFired() + } + } + + struct KeepaliveTimeoutHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler { + private let handler: ClientConnectionHandler + + init(_ handler: ClientConnectionHandler) { + self.handler = handler + } + + func handleScheduledCallback(eventLoop: some EventLoop) { + self.handler.eventLoop.assertInEventLoop() + self.handler.keepaliveTimeoutExpired() } } } @@ -356,7 +378,7 @@ extension ClientConnectionHandler { self.eventLoop.assertInEventLoop() // Stream created, so the connection isn't idle. - self.maxIdleTimer?.cancel() + self.maxIdleTimerHandler?.cancel() self.state.streamOpened(id) } @@ -368,13 +390,10 @@ extension ClientConnectionHandler { case .startIdleTimer(let cancelKeepalive): // All streams are closed, restart the idle timer, and stop the keep-alive timer (it may // not stop if keep-alive is allowed when there are no active calls). - let loopBound = LoopBoundView(handler: self, context: context) - self.maxIdleTimer?.schedule(on: context.eventLoop) { - loopBound.maxIdleTimerFired() - } + self.maxIdleTimerHandler?.start() if cancelKeepalive { - self.keepaliveTimer?.cancel() + self.keepaliveTimerHandler?.cancel() } case .close: @@ -397,34 +416,31 @@ extension ClientConnectionHandler { } } - private func keepaliveTimerFired(context: ChannelHandlerContext) { - guard self.state.sendKeepalivePing() else { return } + private func keepaliveTimerFired() { + guard self.state.sendKeepalivePing(), let context = self.context else { return } // Cancel the keep alive timer when the client sends a ping. The timer is resumed when the ping // is acknowledged. - self.keepaliveTimer?.cancel() + self.keepaliveTimerHandler?.cancel() let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(self.keepalivePingData, ack: false)) context.write(self.wrapOutboundOut(ping), promise: nil) self.maybeFlush(context: context) // Schedule a timeout on waiting for the response. - let loopBound = LoopBoundView(handler: self, context: context) - self.keepaliveTimeoutTimer.schedule(on: context.eventLoop) { - loopBound.keepaliveTimeoutExpired() - } + self.keepaliveTimeoutHandler?.start() } - private func keepaliveTimeoutExpired(context: ChannelHandlerContext) { - guard self.state.beginClosing() else { return } + private func keepaliveTimeoutExpired() { + guard self.state.beginClosing(), let context = self.context else { return } context.fireChannelRead(self.wrapInboundOut(.closing(.keepaliveExpired))) self.writeAndFlushGoAway(context: context, message: "keepalive_expired") context.close(promise: nil) } - private func maxIdleTimerFired(context: ChannelHandlerContext) { - guard self.state.beginClosing() else { return } + private func maxIdleTimerFired() { + guard self.state.beginClosing(), let context = self.context else { return } context.fireChannelRead(self.wrapInboundOut(.closing(.idle))) self.writeAndFlushGoAway(context: context, message: "idle") diff --git a/Sources/GRPCNIOTransportCore/Internal/Timer.swift b/Sources/GRPCNIOTransportCore/Internal/Timer.swift index bfc4ff2..1180677 100644 --- a/Sources/GRPCNIOTransportCore/Internal/Timer.swift +++ b/Sources/GRPCNIOTransportCore/Internal/Timer.swift @@ -16,55 +16,54 @@ package import NIOCore -package struct Timer { - /// The delay to wait before running the task. - private let delay: TimeAmount - /// The task to run, if scheduled. - private var task: Kind? - /// Whether the task to schedule is repeated. - private let `repeat`: Bool +/// A timer backed by `NIOScheduledCallback`. +package final class Timer where Handler: Sendable { + /// The event loop on which to run this timer. + private let eventLoop: any EventLoop - private enum Kind { - case once(Scheduled) - case repeated(RepeatedTask) + /// The duration of the timer. + private let duration: TimeAmount - func cancel() { - switch self { - case .once(let task): - task.cancel() - case .repeated(let task): - task.cancel() - } - } - } + /// Whether this timer should repeat. + private let repeating: Bool + + /// The handler to call when the timer fires. + private let handler: Handler + + /// The currently scheduled callback if the timer is running. + private var scheduledCallback: NIOScheduledCallback? - package init(delay: TimeAmount, repeat: Bool = false) { - self.delay = delay - self.task = nil - self.repeat = `repeat` + package init(eventLoop: any EventLoop, duration: TimeAmount, repeating: Bool, handler: Handler) { + self.eventLoop = eventLoop + self.duration = duration + self.repeating = repeating + self.handler = handler + self.scheduledCallback = nil } - /// Schedule a task on the given `EventLoop`. - package mutating func schedule( - on eventLoop: any EventLoop, - work: @escaping @Sendable () throws -> Void - ) { - self.task?.cancel() + /// Cancel the timer, if it is running. + package func cancel() { + self.eventLoop.assertInEventLoop() + guard let scheduledCallback = self.scheduledCallback else { return } + scheduledCallback.cancel() + } - if self.repeat { - let task = eventLoop.scheduleRepeatedTask(initialDelay: self.delay, delay: self.delay) { _ in - try work() - } - self.task = .repeated(task) - } else { - let task = eventLoop.scheduleTask(in: self.delay, work) - self.task = .once(task) - } + /// Start or restart the timer. + package func start() { + self.eventLoop.assertInEventLoop() + self.scheduledCallback?.cancel() + // Only throws if the event loop is shutting down, so we'll just swallow the error here. + self.scheduledCallback = try? self.eventLoop.scheduleCallback(in: self.duration, handler: self) } +} - /// Cancels the task, if one was scheduled. - package mutating func cancel() { - self.task?.cancel() - self.task = nil +extension Timer: NIOScheduledCallbackHandler, @unchecked Sendable where Handler: Sendable { + /// For repeated timer support, the timer itself proxies the callback and restarts the timer. + /// + /// - NOTE: Users should not call this function directly. + package func handleScheduledCallback(eventLoop: some EventLoop) { + self.eventLoop.assertInEventLoop() + self.handler.handleScheduledCallback(eventLoop: eventLoop) + if self.repeating { self.start() } } } diff --git a/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift b/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift index 0f43e59..eb66466 100644 --- a/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift +++ b/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift @@ -48,25 +48,21 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { /// The `EventLoop` of the `Channel` this handler exists in. private let eventLoop: any EventLoop - /// The maximum amount of time a connection may be idle for. If the connection remains idle - /// (i.e. has no open streams) for this period of time then the connection will be gracefully - /// closed. - private var maxIdleTimer: Timer? + /// The timer used to gracefully close idle connections. + private var maxIdleTimerHandler: Timer? - /// The maximum age of a connection. If the connection remains open after this amount of time - /// then it will be gracefully closed. - private var maxAgeTimer: Timer? + /// The timer used to gracefully close old connections. + private var maxAgeTimerHandler: Timer? - /// The maximum amount of time a connection may spend closing gracefully, after which it is - /// closed abruptly. The timer starts after the second GOAWAY frame has been sent. - private var maxGraceTimer: Timer? + /// The timer used to forcefully close a connection during a graceful close. + /// The timer starts after the second GOAWAY frame has been sent. + private var maxGraceTimerHandler: Timer? - /// The amount of time to wait before sending a keep alive ping. - private var keepaliveTimer: Timer? + /// The timer used to send keep-alive pings. + private var keepaliveTimerHandler: Timer? - /// The amount of time the client has to reply after sending a keep alive ping. Only used if - /// `keepaliveTimer` is set. - private var keepaliveTimeoutTimer: Timer + /// The timer used to detect keep alive timeouts, if keep-alive pings are enabled. + private var keepaliveTimeoutHandler: Timer? /// Opaque data sent in keep alive pings. private let keepalivePingData: HTTP2PingData @@ -222,14 +218,6 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { ) { self.eventLoop = eventLoop - self.maxIdleTimer = maxIdleTime.map { Timer(delay: $0) } - self.maxAgeTimer = maxAge.map { Timer(delay: $0) } - self.maxGraceTimer = maxGraceTime.map { Timer(delay: $0) } - - self.keepaliveTimer = keepaliveTime.map { Timer(delay: $0) } - // Always create a keep alive timeout timer, it's only used if there is a keep alive timer. - self.keepaliveTimeoutTimer = Timer(delay: keepaliveTimeout ?? .seconds(20)) - // Generate a random value to be used as keep alive ping data. let pingData = UInt64.random(in: .min ... .max) self.keepalivePingData = HTTP2PingData(withInteger: pingData) @@ -246,6 +234,47 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { self.frameStats = FrameStats() self.requireALPN = requireALPN + + if let maxIdleTime { + self.maxIdleTimerHandler = Timer( + eventLoop: eventLoop, + duration: maxIdleTime, + repeating: false, + handler: MaxIdleTimerHandlerView(self) + ) + } + if let maxAge { + self.maxAgeTimerHandler = Timer( + eventLoop: eventLoop, + duration: maxAge, + repeating: false, + handler: MaxAgeTimerHandlerView(self) + ) + } + if let maxGraceTime { + self.maxGraceTimerHandler = Timer( + eventLoop: eventLoop, + duration: maxGraceTime, + repeating: false, + handler: MaxGraceTimerHandlerView(self) + ) + } + if let keepaliveTime { + let keepaliveTimeout = keepaliveTimeout ?? .seconds(20) + // NOTE: The use of a non-repeating timer is deliberate for the server, and is different from the client. + self.keepaliveTimerHandler = Timer( + eventLoop: eventLoop, + duration: keepaliveTime, + repeating: false, + handler: KeepaliveTimerHandlerView(self) + ) + self.keepaliveTimeoutHandler = Timer( + eventLoop: eventLoop, + duration: keepaliveTimeout, + repeating: false, + handler: KeepaliveTimeoutHandlerView(self) + ) + } } package func handlerAdded(context: ChannelHandlerContext) { @@ -258,29 +287,18 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { } package func channelActive(context: ChannelHandlerContext) { - let view = LoopBoundView(handler: self, context: context) - - self.maxAgeTimer?.schedule(on: context.eventLoop) { - view.initiateGracefulShutdown() - } - - self.maxIdleTimer?.schedule(on: context.eventLoop) { - view.initiateGracefulShutdown() - } - - self.keepaliveTimer?.schedule(on: context.eventLoop) { - view.keepaliveTimerFired() - } - + self.maxAgeTimerHandler?.start() + self.maxIdleTimerHandler?.start() + self.keepaliveTimerHandler?.start() context.fireChannelActive() } package func channelInactive(context: ChannelHandlerContext) { - self.maxIdleTimer?.cancel() - self.maxAgeTimer?.cancel() - self.maxGraceTimer?.cancel() - self.keepaliveTimer?.cancel() - self.keepaliveTimeoutTimer.cancel() + self.maxIdleTimerHandler?.cancel() + self.maxAgeTimerHandler?.cancel() + self.maxGraceTimerHandler?.cancel() + self.keepaliveTimerHandler?.cancel() + self.keepaliveTimeoutHandler?.cancel() context.fireChannelInactive() } @@ -293,7 +311,7 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { self._streamClosed(event.streamID, channel: context.channel) case is ChannelShouldQuiesceEvent: - self.initiateGracefulShutdown(context: context) + self.initiateGracefulShutdown() case TLSUserEvent.handshakeCompleted(let negotiatedProtocol): if negotiatedProtocol == nil, self.requireALPN { @@ -349,8 +367,8 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { self.inReadLoop = true // Any read data indicates that the connection is alive so cancel the keep-alive timers. - self.keepaliveTimer?.cancel() - self.keepaliveTimeoutTimer.cancel() + self.keepaliveTimerHandler?.cancel() + self.keepaliveTimeoutHandler?.cancel() let frame = self.unwrapInboundIn(data) switch frame.payload { @@ -377,10 +395,7 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { self.inReadLoop = false // Done reading: schedule the keep-alive timer. - let view = LoopBoundView(handler: self, context: context) - self.keepaliveTimer?.schedule(on: context.eventLoop) { - view.keepaliveTimerFired() - } + self.keepaliveTimerHandler?.start() context.fireChannelReadComplete() } @@ -390,26 +405,71 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { } } +// Timer handler views. extension ServerConnectionManagementHandler { - struct LoopBoundView: @unchecked Sendable { + struct MaxIdleTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler { private let handler: ServerConnectionManagementHandler - private let context: ChannelHandlerContext - init(handler: ServerConnectionManagementHandler, context: ChannelHandlerContext) { + init(_ handler: ServerConnectionManagementHandler) { self.handler = handler - self.context = context } - func initiateGracefulShutdown() { - self.context.eventLoop.assertInEventLoop() - self.handler.initiateGracefulShutdown(context: self.context) + func handleScheduledCallback(eventLoop: some EventLoop) { + self.handler.eventLoop.assertInEventLoop() + self.handler.initiateGracefulShutdown() } + } + + struct MaxAgeTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler { + private let handler: ServerConnectionManagementHandler - func keepaliveTimerFired() { - self.context.eventLoop.assertInEventLoop() - self.handler.keepaliveTimerFired(context: self.context) + init(_ handler: ServerConnectionManagementHandler) { + self.handler = handler } + func handleScheduledCallback(eventLoop: some EventLoop) { + self.handler.eventLoop.assertInEventLoop() + self.handler.initiateGracefulShutdown() + } + } + + struct MaxGraceTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler { + private let handler: ServerConnectionManagementHandler + + init(_ handler: ServerConnectionManagementHandler) { + self.handler = handler + } + + func handleScheduledCallback(eventLoop: some EventLoop) { + self.handler.eventLoop.assertInEventLoop() + self.handler.context?.close(promise: nil) + } + } + + struct KeepaliveTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler { + private let handler: ServerConnectionManagementHandler + + init(_ handler: ServerConnectionManagementHandler) { + self.handler = handler + } + + func handleScheduledCallback(eventLoop: some EventLoop) { + self.handler.eventLoop.assertInEventLoop() + self.handler.keepaliveTimerFired() + } + } + + struct KeepaliveTimeoutHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler { + private let handler: ServerConnectionManagementHandler + + init(_ handler: ServerConnectionManagementHandler) { + self.handler = handler + } + + func handleScheduledCallback(eventLoop: some EventLoop) { + self.handler.eventLoop.assertInEventLoop() + self.handler.initiateGracefulShutdown() + } } } @@ -450,7 +510,7 @@ extension ServerConnectionManagementHandler { private func _streamCreated(_ id: HTTP2StreamID, channel: any Channel) { // The connection isn't idle if a stream is open. - self.maxIdleTimer?.cancel() + self.maxIdleTimerHandler?.cancel() self.state.streamOpened(id) } @@ -459,11 +519,7 @@ extension ServerConnectionManagementHandler { switch self.state.streamClosed(id) { case .startIdleTimer: - let loopBound = LoopBoundView(handler: self, context: context) - self.maxIdleTimer?.schedule(on: context.eventLoop) { - loopBound.initiateGracefulShutdown() - } - + self.maxIdleTimerHandler?.start() case .close: context.close(mode: .all, promise: nil) @@ -482,14 +538,15 @@ extension ServerConnectionManagementHandler { } } - private func initiateGracefulShutdown(context: ChannelHandlerContext) { + private func initiateGracefulShutdown() { + guard let context = self.context else { return } context.eventLoop.assertInEventLoop() // Cancel any timers if initiating shutdown. - self.maxIdleTimer?.cancel() - self.maxAgeTimer?.cancel() - self.keepaliveTimer?.cancel() - self.keepaliveTimeoutTimer.cancel() + self.maxIdleTimerHandler?.cancel() + self.maxAgeTimerHandler?.cancel() + self.keepaliveTimerHandler?.cancel() + self.keepaliveTimeoutHandler?.cancel() switch self.state.startGracefulShutdown() { case .sendGoAwayAndPing(let pingData): @@ -562,10 +619,7 @@ extension ServerConnectionManagementHandler { } else { // RPCs may have a grace period for finishing once the second GOAWAY frame has finished. // If this is set close the connection abruptly once the grace period passes. - let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop) - self.maxGraceTimer?.schedule(on: context.eventLoop) { - loopBound.value.close(promise: nil) - } + self.maxGraceTimerHandler?.start() } case .none: @@ -573,15 +627,13 @@ extension ServerConnectionManagementHandler { } } - private func keepaliveTimerFired(context: ChannelHandlerContext) { + private func keepaliveTimerFired() { + guard let context = self.context else { return } let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(self.keepalivePingData, ack: false)) context.write(self.wrapInboundOut(ping), promise: nil) self.maybeFlush(context: context) // Schedule a timeout on waiting for the response. - let loopBound = LoopBoundView(handler: self, context: context) - self.keepaliveTimeoutTimer.schedule(on: context.eventLoop) { - loopBound.initiateGracefulShutdown() - } + self.keepaliveTimeoutHandler?.start() } } diff --git a/Tests/GRPCNIOTransportCoreTests/Internal/TimerTests.swift b/Tests/GRPCNIOTransportCoreTests/Internal/TimerTests.swift index 47d461a..c112b0f 100644 --- a/Tests/GRPCNIOTransportCoreTests/Internal/TimerTests.swift +++ b/Tests/GRPCNIOTransportCoreTests/Internal/TimerTests.swift @@ -16,82 +16,100 @@ import GRPCCore import GRPCNIOTransportCore +import NIOCore import NIOEmbedded import Synchronization import XCTest internal final class TimerTests: XCTestCase { - func testScheduleOneOffTimer() { + fileprivate struct CounterTimerHandler: NIOScheduledCallbackHandler { + let counter = AtomicCounter(0) + + func handleScheduledCallback(eventLoop: some EventLoop) { + counter.increment() + } + } + + func testOneOffTimer() { let loop = EmbeddedEventLoop() defer { try! loop.close() } - let value = Atomic(0) - var timer = Timer(delay: .seconds(1), repeat: false) - timer.schedule(on: loop) { - let (old, _) = value.add(1, ordering: .releasing) - XCTAssertEqual(old, 0) - } + let handler = CounterTimerHandler() + let timer = Timer(eventLoop: loop, duration: .seconds(1), repeating: false, handler: handler) + timer.start() + // Timer hasn't fired because we haven't reached the required duration. loop.advanceTime(by: .milliseconds(999)) - XCTAssertEqual(value.load(ordering: .acquiring), 0) + XCTAssertEqual(handler.counter.value, 0) + + // Timer has fired once. loop.advanceTime(by: .milliseconds(1)) - XCTAssertEqual(value.load(ordering: .acquiring), 1) + XCTAssertEqual(handler.counter.value, 1) - // Run again to make sure the task wasn't repeated. + // Timer does not repeat. loop.advanceTime(by: .seconds(1)) - XCTAssertEqual(value.load(ordering: .acquiring), 1) - } + XCTAssertEqual(handler.counter.value, 1) - func testCancelOneOffTimer() { - let loop = EmbeddedEventLoop() - defer { try! loop.close() } - - var timer = Timer(delay: .seconds(1), repeat: false) - timer.schedule(on: loop) { - XCTFail("Timer wasn't cancelled") - } + // Timer can be restarted and then fires again after the duration. + timer.start() + loop.advanceTime(by: .seconds(1)) + XCTAssertEqual(handler.counter.value, 2) + // Timer can be cancelled before the duration and then does not fire. + timer.start() loop.advanceTime(by: .milliseconds(999)) timer.cancel() loop.advanceTime(by: .milliseconds(1)) + XCTAssertEqual(handler.counter.value, 2) + + // Timer can be restarted after being cancelled. + timer.start() + loop.advanceTime(by: .seconds(1)) + XCTAssertEqual(handler.counter.value, 3) } - func testScheduleRepeatedTimer() throws { + func testRepeatedTimer() { let loop = EmbeddedEventLoop() defer { try! loop.close() } - let counter = AtomicCounter() - var timer = Timer(delay: .seconds(1), repeat: true) - timer.schedule(on: loop) { - counter.increment() - } + let handler = CounterTimerHandler() + let timer = Timer(eventLoop: loop, duration: .seconds(1), repeating: true, handler: handler) + timer.start() + // Timer hasn't fired because we haven't reached the required duration. loop.advanceTime(by: .milliseconds(999)) - XCTAssertEqual(counter.value, 0) + XCTAssertEqual(handler.counter.value, 0) + + // Timer has fired once. + loop.advanceTime(by: .milliseconds(1)) + XCTAssertEqual(handler.counter.value, 1) + + // Timer hasn't fired again because we haven't reached the required duration again. + loop.advanceTime(by: .milliseconds(999)) + XCTAssertEqual(handler.counter.value, 1) + + // Timer has fired again. loop.advanceTime(by: .milliseconds(1)) - XCTAssertEqual(counter.value, 1) + XCTAssertEqual(handler.counter.value, 2) + // Timer continues to fire on each second. loop.advanceTime(by: .seconds(1)) - XCTAssertEqual(counter.value, 2) + XCTAssertEqual(handler.counter.value, 3) loop.advanceTime(by: .seconds(1)) - XCTAssertEqual(counter.value, 3) - - timer.cancel() + XCTAssertEqual(handler.counter.value, 4) loop.advanceTime(by: .seconds(1)) - XCTAssertEqual(counter.value, 3) - } + XCTAssertEqual(handler.counter.value, 5) + loop.advanceTime(by: .seconds(5)) + XCTAssertEqual(handler.counter.value, 10) - func testCancelRepeatedTimer() { - let loop = EmbeddedEventLoop() - defer { try! loop.close() } - - var timer = Timer(delay: .seconds(1), repeat: true) - timer.schedule(on: loop) { - XCTFail("Timer wasn't cancelled") - } - - loop.advanceTime(by: .milliseconds(999)) + // Timer does not fire again, after being cancelled. timer.cancel() - loop.advanceTime(by: .milliseconds(1)) + loop.advanceTime(by: .seconds(5)) + XCTAssertEqual(handler.counter.value, 10) + + // Timer can be restarted after being cancelled and continues to fire once per second. + timer.start() + loop.advanceTime(by: .seconds(5)) + XCTAssertEqual(handler.counter.value, 15) } }