Skip to content

Commit

Permalink
Propagate the authority pseudoheader
Browse files Browse the repository at this point in the history
Motivation:

We should be sending the ":authority" pseudoheader, somehow this was
missed when doing the stream state machine.

The strategy for propagating the authority will be to:
- Use any override value set on the client transport, otherwise
- Use the value provided by a name resolver (if any), otherwise
- Derive a value from the target address.

Modifications:

- Add client config allowing the authority to be overridden
- Add authority to the name resolver interface
- Propagate the authority through to the stream state machine
- Add a percent encoder, as the authority should be percent encoded
- Remove hostname from the client TLS config and use the authority
  instead. This is generally a usability win as clients generally won't
  have to set this value themselves.

Result:

Authority is set automatically, and can be overridden if necessary.
  • Loading branch information
glbrntt committed Nov 29, 2024
1 parent 28d8171 commit 8cbd49a
Show file tree
Hide file tree
Showing 31 changed files with 624 additions and 145 deletions.
12 changes: 11 additions & 1 deletion Sources/GRPCNIOTransportCore/Client/Connection/Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ package final class Connection: Sendable {
/// The address to connect to.
private let address: SocketAddress

/// The server authority. If `nil`, a value will be computed based on the endpoint being
/// connected to.
private let authority: String?

/// The default compression algorithm used for requests.
private let defaultCompression: CompressionAlgorithm

Expand All @@ -109,11 +113,13 @@ package final class Connection: Sendable {

package init(
address: SocketAddress,
authority: String?,
http2Connector: any HTTP2Connector,
defaultCompression: CompressionAlgorithm,
enabledCompression: CompressionAlgorithmSet
) {
self.address = address
self.authority = authority
self.defaultCompression = defaultCompression
self.enabledCompression = enabledCompression
self.http2Connector = http2Connector
Expand All @@ -129,7 +135,10 @@ package final class Connection: Sendable {
package func run() async {
func establishConnectionOrThrow() async throws(RPCError) -> HTTP2Connection {
do {
return try await self.http2Connector.establishConnection(to: self.address)
return try await self.http2Connector.establishConnection(
to: self.address,
authority: self.authority ?? self.address.sniHostname
)
} catch let error as RPCError {
throw error
} catch {
Expand Down Expand Up @@ -214,6 +223,7 @@ package final class Connection: Sendable {
let streamHandler = GRPCClientStreamHandler(
methodDescriptor: descriptor,
scheme: scheme,
authority: self.authority ?? self.address.authority,
outboundEncoding: compression,
acceptedEncodings: self.enabledCompression,
maxPayloadSize: maxRequestSize
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ package import NIOHTTP2
internal import NIOPosix

package protocol HTTP2Connector: Sendable {
func establishConnection(to address: SocketAddress) async throws -> HTTP2Connection
func establishConnection(
to address: SocketAddress,
authority: String?
) async throws -> HTTP2Connection
}

package struct HTTP2Connection: Sendable {
Expand Down
12 changes: 12 additions & 0 deletions Sources/GRPCNIOTransportCore/Client/Connection/GRPCChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ package final class GRPCChannel: ClientTransport {
/// A factory for connections.
private let connector: any HTTP2Connector

/// The server authority. If `nil`, a value will be computed based on the endpoint being
/// connected to.
private let authority: String?

/// The connection backoff configuration used by the subchannel when establishing a connection.
private let backoff: ConnectionBackoff

Expand Down Expand Up @@ -82,6 +86,12 @@ package final class GRPCChannel: ClientTransport {
self.input = AsyncStream.makeStream()
self.connector = connector

if let authority = config.http2.authority ?? resolver.authority {
self.authority = PercentEncoding.encodeAuthority(authority)
} else {
self.authority = nil
}

self.backoff = ConnectionBackoff(
initial: config.backoff.initial,
max: config.backoff.max,
Expand Down Expand Up @@ -446,6 +456,7 @@ extension GRPCChannel {
state.changeLoadBalancerKind(to: loadBalancerConfig) {
let loadBalancer = RoundRobinLoadBalancer(
connector: self.connector,
authority: self.authority,
backoff: self.backoff,
defaultCompression: self.defaultCompression,
enabledCompression: self.enabledCompression
Expand All @@ -463,6 +474,7 @@ extension GRPCChannel {
state.changeLoadBalancerKind(to: loadBalancerConfig) {
let loadBalancer = PickFirstLoadBalancer(
connector: self.connector,
authority: self.authority,
backoff: self.backoff,
defaultCompression: self.defaultCompression,
enabledCompression: self.enabledCompression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ package final class PickFirstLoadBalancer: Sendable {
/// A connector, capable of creating connections.
private let connector: any HTTP2Connector

/// The server authority. If `nil`, a value will be computed based on the endpoint being
/// connected to.
private let authority: String?

/// Connection backoff configuration.
private let backoff: ConnectionBackoff

Expand All @@ -94,11 +98,13 @@ package final class PickFirstLoadBalancer: Sendable {

package init(
connector: any HTTP2Connector,
authority: String?,
backoff: ConnectionBackoff,
defaultCompression: CompressionAlgorithm,
enabledCompression: CompressionAlgorithmSet
) {
self.connector = connector
self.authority = authority
self.backoff = backoff
self.defaultCompression = defaultCompression
self.enabledCompression = enabledCompression
Expand Down Expand Up @@ -174,6 +180,7 @@ extension PickFirstLoadBalancer {
endpoint: endpoint,
id: id,
connector: self.connector,
authority: self.authority,
backoff: self.backoff,
defaultCompression: self.defaultCompression,
enabledCompression: self.enabledCompression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ package final class RoundRobinLoadBalancer: Sendable {
/// A connector, capable of creating connections.
private let connector: any HTTP2Connector

/// The server authority. If `nil`, a value will be computed based on the endpoint being
/// connected to.
private let authority: String?

/// Connection backoff configuration.
private let backoff: ConnectionBackoff

Expand All @@ -123,11 +127,13 @@ package final class RoundRobinLoadBalancer: Sendable {

package init(
connector: any HTTP2Connector,
authority: String?,
backoff: ConnectionBackoff,
defaultCompression: CompressionAlgorithm,
enabledCompression: CompressionAlgorithmSet
) {
self.connector = connector
self.authority = authority
self.backoff = backoff
self.defaultCompression = defaultCompression
self.enabledCompression = enabledCompression
Expand Down Expand Up @@ -223,6 +229,7 @@ extension RoundRobinLoadBalancer {
endpoint: endpoint,
id: id,
connector: self.connector,
authority: self.authority,
backoff: self.backoff,
defaultCompression: self.defaultCompression,
enabledCompression: self.enabledCompression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ package final class Subchannel: Sendable {
/// A factory for connections.
private let connector: any HTTP2Connector

/// The server authority. If `nil`, a value will be computed based on the endpoint being
/// connected to.
private let authority: String?

/// The connection backoff configuration used by the subchannel when establishing a connection.
private let backoff: ConnectionBackoff

Expand All @@ -96,6 +100,7 @@ package final class Subchannel: Sendable {
endpoint: Endpoint,
id: SubchannelID,
connector: any HTTP2Connector,
authority: String?,
backoff: ConnectionBackoff,
defaultCompression: CompressionAlgorithm,
enabledCompression: CompressionAlgorithmSet
Expand All @@ -106,6 +111,7 @@ package final class Subchannel: Sendable {
self.endpoint = endpoint
self.id = id
self.connector = connector
self.authority = authority
self.backoff = backoff
self.defaultCompression = defaultCompression
self.enabledCompression = enabledCompression
Expand Down Expand Up @@ -194,6 +200,7 @@ extension Subchannel {
state.makeConnection(
to: self.endpoint.addresses,
using: self.connector,
authority: self.authority,
backoff: self.backoff,
defaultCompression: self.defaultCompression,
enabledCompression: self.enabledCompression
Expand Down Expand Up @@ -283,7 +290,10 @@ extension Subchannel {
}

private func handleConnectFailedEvent(in group: inout DiscardingTaskGroup, error: RPCError) {
let onConnectFailed = self.state.withLock { $0.connectFailed(connector: self.connector) }
let onConnectFailed = self.state.withLock {
$0.connectFailed(connector: self.connector, authority: self.authority)
}

switch onConnectFailed {
case .connect(let connection):
// Try the next address.
Expand Down Expand Up @@ -469,6 +479,7 @@ extension Subchannel {
mutating func makeConnection(
to addresses: [SocketAddress],
using connector: any HTTP2Connector,
authority: String?,
backoff: ConnectionBackoff,
defaultCompression: CompressionAlgorithm,
enabledCompression: CompressionAlgorithmSet
Expand All @@ -480,6 +491,7 @@ extension Subchannel {

let connection = Connection(
address: address,
authority: authority,
http2Connector: connector,
defaultCompression: defaultCompression,
enabledCompression: enabledCompression
Expand Down Expand Up @@ -563,14 +575,18 @@ extension Subchannel {
case backoff(Duration)
}

mutating func connectFailed(connector: any HTTP2Connector) -> OnConnectFailed {
mutating func connectFailed(
connector: any HTTP2Connector,
authority: String?
) -> OnConnectFailed {
let onConnectFailed: OnConnectFailed

switch self {
case .connecting(var state):
if let address = state.addressIterator.next() {
state.connection = Connection(
address: address,
authority: authority,
http2Connector: connector,
defaultCompression: .none,
enabledCompression: .all
Expand All @@ -582,6 +598,7 @@ extension Subchannel {
let address = state.addressIterator.next()!
state.connection = Connection(
address: address,
authority: authority,
http2Connector: connector,
defaultCompression: .none,
enabledCompression: .all
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ final class GRPCClientStreamHandler: ChannelDuplexHandler {
init(
methodDescriptor: MethodDescriptor,
scheme: Scheme,
authority: String?,
outboundEncoding: CompressionAlgorithm,
acceptedEncodings: CompressionAlgorithmSet,
maxPayloadSize: Int,
Expand All @@ -43,6 +44,7 @@ final class GRPCClientStreamHandler: ChannelDuplexHandler {
.init(
methodDescriptor: methodDescriptor,
scheme: scheme,
authority: authority,
outboundEncoding: outboundEncoding,
acceptedEncodings: acceptedEncodings
)
Expand Down
13 changes: 11 additions & 2 deletions Sources/GRPCNIOTransportCore/Client/HTTP2ClientTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,24 @@ extension HTTP2ClientTransport.Config {
/// The value is clamped to `... (1 << 31) - 1`.
public var targetWindowSize: Int

/// The authority of the server.
///
/// Any value set here will unconditionally override any value derived from the target address.
///
/// The server authority is used in the ":authority" pseudoheader and in the TLS SNI
/// extension, if applicable.
public var authority: String?

/// Creates a new HTTP/2 configuration.
public init(maxFrameSize: Int, targetWindowSize: Int) {
public init(maxFrameSize: Int, targetWindowSize: Int, authority: String?) {
self.maxFrameSize = maxFrameSize
self.targetWindowSize = targetWindowSize
self.authority = authority
}

/// Default values, max frame size is 16KiB, and the target window size is 8MiB.
public static var defaults: Self {
Self(maxFrameSize: 1 << 14, targetWindowSize: 8 * 1024 * 1024)
Self(maxFrameSize: 1 << 14, targetWindowSize: 8 * 1024 * 1024, authority: nil)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ extension ResolvableTargets {
public var host: String

/// The port to use with resolved addresses.
public var port: Int
///
/// If no port is specified then 443 is used.
public var port: Int?

/// Create a new DNS target.
/// - Parameters:
/// - host: The host to resolve via DNS.
/// - port: The port to use with resolved addresses.
public init(host: String, port: Int) {
public init(host: String, port: Int?) {
self.host = host
self.port = port
}
Expand All @@ -43,9 +45,9 @@ extension ResolvableTarget where Self == ResolvableTargets.DNS {
/// Creates a new resolvable DNS target.
/// - Parameters:
/// - host: The host address to resolve.
/// - port: The port to use for each resolved address.
/// - port: The port to use for each resolved address. 443 will be used if unspecified.
/// - Returns: A ``ResolvableTarget``.
public static func dns(host: String, port: Int = 443) -> Self {
public static func dns(host: String, port: Int? = nil) -> Self {
return Self(host: host, port: port)
}
}
Expand All @@ -60,7 +62,14 @@ extension NameResolvers {

public func resolver(for target: Target) -> NameResolver {
let resolver = Self.Resolver(target: target)
return NameResolver(names: RPCAsyncSequence(wrapping: resolver), updateMode: .pull)
// Only append the port if explicitly set. If it's nil the default port of 443 is used
// should be omitted from the authority.
let authority = target.host + (target.port.map { ":\($0)" } ?? "")
return NameResolver(
names: RPCAsyncSequence(wrapping: resolver),
updateMode: .pull,
authority: authority
)
}
}
}
Expand All @@ -79,13 +88,16 @@ extension NameResolvers.DNS {
let addresses: [SocketAddress]

do {
addresses = try await DNSResolver.resolve(host: self.target.host, port: self.target.port)
addresses = try await DNSResolver.resolve(
host: self.target.host,
port: self.target.port ?? 443 // Assume TLS if no port is specified.
)
} catch let error as CancellationError {
throw error
} catch {
throw RPCError(
code: .internalError,
message: "Couldn't resolve address for \(self.target.host):\(self.target.port)",
message: "Couldn't resolve address for \(self.target.host):\(self.target.port ?? 443)",
cause: error
)
}
Expand Down
Loading

0 comments on commit 8cbd49a

Please sign in to comment.