From e934175ca9a0460378426fd2e59f76c78d80fba4 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Wed, 18 Dec 2024 15:03:04 +0000 Subject: [PATCH 01/25] Update package versions and target frameworks - Upgraded several package versions in `Directory.Packages.props`, including `Microsoft.Bcl.AsyncInterfaces`, `System.Collections.Immutable`, `System.IO.Pipelines`, and `System.Text.Json`. Added `Nerdbank.MessagePack`. - Modified `nuget.config` to add a new package source for `nuget.org` and included package source mappings. - Updated `StreamJsonRpc.csproj` to target only `net8.0`, removing `net6.0`, and added a reference to `Nerdbank.MessagePack`. - Changed `Benchmarks.csproj` to target `net8.0` instead of `net6.0`. - Adjusted `StreamJsonRpc.Tests.csproj` to exclusively target `net8.0`. --- Directory.Packages.props | 9 +++++---- nuget.config | 12 ++++++++++++ src/StreamJsonRpc/StreamJsonRpc.csproj | 3 ++- test/Benchmarks/Benchmarks.csproj | 2 +- test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj | 2 +- 5 files changed, 21 insertions(+), 7 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index 462cbe578..270bd542d 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -15,21 +15,22 @@ - + + - + - + - + diff --git a/nuget.config b/nuget.config index 0a73357e8..6e30104a0 100644 --- a/nuget.config +++ b/nuget.config @@ -5,10 +5,22 @@ + + + + + + + + + + + + diff --git a/src/StreamJsonRpc/StreamJsonRpc.csproj b/src/StreamJsonRpc/StreamJsonRpc.csproj index 70fa95f12..b51740ebc 100644 --- a/src/StreamJsonRpc/StreamJsonRpc.csproj +++ b/src/StreamJsonRpc/StreamJsonRpc.csproj @@ -1,6 +1,6 @@  - netstandard2.0;netstandard2.1;net6.0;net8.0 + netstandard2.0;netstandard2.1;net8.0 prompt 4 true @@ -17,6 +17,7 @@ + diff --git a/test/Benchmarks/Benchmarks.csproj b/test/Benchmarks/Benchmarks.csproj index b3a91e41d..bbcc37f3f 100644 --- a/test/Benchmarks/Benchmarks.csproj +++ b/test/Benchmarks/Benchmarks.csproj @@ -2,7 +2,7 @@ Exe - net6.0;net472 + net8.0;net472 diff --git a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj index 82fb00a07..e391c2ac1 100644 --- a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj +++ b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj @@ -1,7 +1,7 @@  - net6.0;net8.0 + net8.0 $(TargetFrameworks);net472 From 6819e6b7cd3e4ad8f4d357021cd02cb3a35d1136 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Fri, 20 Dec 2024 18:24:36 +0000 Subject: [PATCH 02/25] Continue implementing NerdbankMessagePackFormatter. --- Directory.Packages.props | 2 +- nuget.config | 5 +- src/StreamJsonRpc/FormatterBase.cs | 1 + src/StreamJsonRpc/JsonRpc.cs | 2 +- ...rdbankMessagePackFormatter.CommonString.cs | 104 + ...tter.ICompositeTypeShapeProviderBuilder.cs | 42 + ...kFormatter.ISerializationContextBuilder.cs | 40 + ...bankMessagePackFormatter.RawMessagePack.cs | 88 + .../NerdbankMessagePackFormatter.cs | 2446 +++++++++++++++++ src/StreamJsonRpc/Protocol/JsonRpcError.cs | 10 +- src/StreamJsonRpc/Protocol/JsonRpcMessage.cs | 8 + src/StreamJsonRpc/Protocol/JsonRpcRequest.cs | 19 +- src/StreamJsonRpc/Protocol/JsonRpcResult.cs | 8 +- 13 files changed, 2769 insertions(+), 6 deletions(-) create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.ICompositeTypeShapeProviderBuilder.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.RawMessagePack.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 270bd542d..f04dabd42 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -22,7 +22,7 @@ - + diff --git a/nuget.config b/nuget.config index 6e30104a0..8ef7b7880 100644 --- a/nuget.config +++ b/nuget.config @@ -5,6 +5,7 @@ + @@ -18,9 +19,11 @@ - + + + diff --git a/src/StreamJsonRpc/FormatterBase.cs b/src/StreamJsonRpc/FormatterBase.cs index 7dd4c6479..c7bdb6ffa 100644 --- a/src/StreamJsonRpc/FormatterBase.cs +++ b/src/StreamJsonRpc/FormatterBase.cs @@ -6,6 +6,7 @@ using System.IO.Pipelines; using System.Reflection; using System.Runtime.Serialization; +using Nerdbank.MessagePack; using Nerdbank.Streams; using StreamJsonRpc.Protocol; using StreamJsonRpc.Reflection; diff --git a/src/StreamJsonRpc/JsonRpc.cs b/src/StreamJsonRpc/JsonRpc.cs index e8ab2712a..e5f1e596a 100644 --- a/src/StreamJsonRpc/JsonRpc.cs +++ b/src/StreamJsonRpc/JsonRpc.cs @@ -1697,7 +1697,7 @@ protected virtual async ValueTask DispatchRequestAsync(JsonRpcRe } /// - /// Sends the JSON-RPC message to intance to be transmitted. + /// Sends the JSON-RPC message to instance to be transmitted. /// /// The message to send. /// A token to cancel the send request. diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs new file mode 100644 index 000000000..dfc6ef831 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics; +using NBMP = Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +public partial class NerdbankMessagePackFormatter +{ + [DebuggerDisplay("{" + nameof(Value) + "}")] + private struct CommonString + { + internal CommonString(string value) + { + Requires.Argument(value.Length > 0 && value.Length <= 16, nameof(value), "Length must be >0 and <=16."); + this.Value = value; + ReadOnlyMemory encodedBytes = MessagePack.Internal.CodeGenHelpers.GetEncodedStringBytes(value); + this.EncodedBytes = encodedBytes; + + ReadOnlySpan span = this.EncodedBytes.Span.Slice(1); + this.Key = MessagePack.Internal.AutomataKeyGen.GetKey(ref span); // header is 1 byte because string length <= 16 + this.Key2 = span.Length > 0 ? (ulong?)MessagePack.Internal.AutomataKeyGen.GetKey(ref span) : null; + } + + /// + /// Gets the original string. + /// + internal string Value { get; } + + /// + /// Gets the 64-bit integer that represents the string without decoding it. + /// + private ulong Key { get; } + + /// + /// Gets the next 64-bit integer that represents the string without decoding it. + /// + private ulong? Key2 { get; } + + /// + /// Gets the messagepack header and UTF-8 bytes for this string. + /// + private ReadOnlyMemory EncodedBytes { get; } + + /// + /// Writes out the messagepack binary for this common string, if it matches the given value. + /// + /// The writer to use. + /// The value to be written, if it matches this . + /// if matches this and it was written; otherwise. + internal bool TryWrite(ref NBMP::MessagePackWriter writer, string value) + { + if (value == this.Value) + { + this.Write(ref writer); + return true; + } + + return false; + } + + internal readonly void Write(ref NBMP::MessagePackWriter writer) => writer.WriteRaw(this.EncodedBytes.Span); + + /// + /// Checks whether a span of UTF-8 bytes equal this common string. + /// + /// The UTF-8 string. + /// if the UTF-8 bytes are the encoding of this common string; otherwise. + internal readonly bool TryRead(ReadOnlySpan utf8String) + { + if (utf8String.Length != this.EncodedBytes.Length - 1) + { + return false; + } + + ulong key1 = MessagePack.Internal.AutomataKeyGen.GetKey(ref utf8String); + if (key1 != this.Key) + { + return false; + } + + if (utf8String.Length > 0) + { + if (!this.Key2.HasValue) + { + return false; + } + + ulong key2 = MessagePack.Internal.AutomataKeyGen.GetKey(ref utf8String); + if (key2 != this.Key2.Value) + { + return false; + } + } + else if (this.Key2.HasValue) + { + return false; + } + + return true; + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ICompositeTypeShapeProviderBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ICompositeTypeShapeProviderBuilder.cs new file mode 100644 index 000000000..6d681dea8 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ICompositeTypeShapeProviderBuilder.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using PolyType; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +public sealed partial class NerdbankMessagePackFormatter +{ + /// + /// Provides a builder interface for adding type shape providers. + /// + public interface ICompositeTypeShapeProviderBuilder + { + /// + /// Adds a single type shape provider to the builder. + /// + /// The type shape provider to add. + /// The current builder instance. + ICompositeTypeShapeProviderBuilder Add(ITypeShapeProvider provider); + + /// + /// Adds a range of type shape providers to the builder. + /// + /// The collection of type shape providers to add. + /// The current builder instance. + ICompositeTypeShapeProviderBuilder AddRange(IEnumerable providers); + + /// + /// Adds a reflection-based type shape provider to the builder. + /// + /// A value indicating whether to use Reflection.Emit for dynamic type generation. + /// The current builder instance. + ICompositeTypeShapeProviderBuilder AddReflectionTypeShapeProvider(bool useReflectionEmit); + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs new file mode 100644 index 000000000..d898b41d2 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +public sealed partial class NerdbankMessagePackFormatter +{ + /// + /// Provides a builder interface for configuring the serialization context. + /// + public interface IFormatterContextBuilder + { + /// + /// Gets the type shape provider builder. + /// + ICompositeTypeShapeProviderBuilder TypeShapeProviderBuilder { get; } + + /// + /// Registers a custom converter for a specific type. + /// + /// The type for which the converter is registered. + /// The converter to register. + void RegisterConverter(MessagePackConverter converter); + + /// + /// Registers known subtypes for a base type. + /// + /// The base type for which the subtypes are registered. + /// The mapping of known subtypes. + void RegisterKnownSubTypes(KnownSubTypeMapping mapping); + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RawMessagePack.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RawMessagePack.cs new file mode 100644 index 000000000..895980392 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RawMessagePack.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using Nerdbank.MessagePack; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +public partial class NerdbankMessagePackFormatter +{ + private struct RawMessagePack + { + private readonly ReadOnlySequence rawSequence; + + private readonly ReadOnlyMemory rawMemory; + + private RawMessagePack(ReadOnlySequence raw) + { + this.rawSequence = raw; + this.rawMemory = default; + } + + private RawMessagePack(ReadOnlyMemory raw) + { + this.rawSequence = default; + this.rawMemory = raw; + } + + internal readonly bool IsDefault => this.rawMemory.IsEmpty && this.rawSequence.IsEmpty; + + public override readonly string ToString() => ""; + + /// + /// Reads one raw messagepack token. + /// + /// The reader to use. + /// if the token must outlive the lifetime of the reader's underlying buffer; otherwise. + /// The serialization context to use. + /// The raw messagepack slice. + internal static RawMessagePack ReadRaw(ref MessagePackReader reader, bool copy, Nerdbank.MessagePack.SerializationContext context) + { + SequencePosition initialPosition = reader.Position; + reader.Skip(context); + ReadOnlySequence slice = reader.Sequence.Slice(initialPosition, reader.Position); + return copy ? new RawMessagePack(slice.ToArray()) : new RawMessagePack(slice); + } + + internal readonly void WriteRaw(ref MessagePackWriter writer) + { + if (this.rawSequence.IsEmpty) + { + writer.WriteRaw(this.rawMemory.Span); + } + else + { + writer.WriteRaw(this.rawSequence); + } + } + + internal readonly object? Deserialize(Type type, FormatterContext options) + { + MessagePackReader reader = this.rawSequence.IsEmpty + ? new MessagePackReader(this.rawMemory) + : new MessagePackReader(this.rawSequence); + + return options.Serializer.DeserializeObject( + ref reader, + options.ShapeProvider.Resolve(type)); + } + + internal readonly T Deserialize(FormatterContext options) + { + MessagePackReader reader = this.rawSequence.IsEmpty + ? new MessagePackReader(this.rawMemory) + : new MessagePackReader(this.rawSequence); + + return options.Serializer.Deserialize(ref reader, options.ShapeProvider) + ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs new file mode 100644 index 000000000..48f4a53ca --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -0,0 +1,2446 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Collections.Immutable; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO.Pipelines; +using System.Reflection; +using System.Runtime.ExceptionServices; +using System.Runtime.Serialization; +using System.Text; +using System.Text.Json.Nodes; +using MessagePack; +using MessagePack.Formatters; +using MessagePack.Resolvers; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType; +using PolyType.Abstractions; +using PolyType.ReflectionProvider; +using PolyType.SourceGenerator; +using StreamJsonRpc.Protocol; +using StreamJsonRpc.Reflection; +using NBMP = Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +public sealed partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory +{ + /// + /// The constant "jsonrpc", in its various forms. + /// + private static readonly CommonString VersionPropertyName = new(Constants.jsonrpc); + + /// + /// The constant "id", in its various forms. + /// + private static readonly CommonString IdPropertyName = new(Constants.id); + + /// + /// The constant "method", in its various forms. + /// + private static readonly CommonString MethodPropertyName = new(Constants.Request.method); + + /// + /// The constant "result", in its various forms. + /// + private static readonly CommonString ResultPropertyName = new(Constants.Result.result); + + /// + /// The constant "error", in its various forms. + /// + private static readonly CommonString ErrorPropertyName = new(Constants.Error.error); + + /// + /// The constant "params", in its various forms. + /// + private static readonly CommonString ParamsPropertyName = new(Constants.Request.@params); + + /// + /// The constant "traceparent", in its various forms. + /// + private static readonly CommonString TraceParentPropertyName = new(Constants.Request.traceparent); + + /// + /// The constant "tracestate", in its various forms. + /// + private static readonly CommonString TraceStatePropertyName = new(Constants.Request.tracestate); + + /// + /// The constant "2.0", in its various forms. + /// + private static readonly CommonString Version2 = new("2.0"); + + /// + /// A cache of property names to declared property types, indexed by their containing parameter object type. + /// + /// + /// All access to this field should be while holding a lock on this member's value. + /// + private static readonly Dictionary> ParameterObjectPropertyTypes = new Dictionary>(); + + /// + /// The serializer context to use for top-level RPC messages. + /// + private readonly FormatterContext rpcContext; + + private readonly ProgressFormatterResolver progressFormatterResolver; + + private readonly AsyncEnumerableFormatterResolver asyncEnumerableFormatterResolver; + + private readonly PipeFormatterResolver pipeFormatterResolver; + + private readonly MessagePackExceptionResolver exceptionResolver; + + private readonly ToStringHelper serializationToStringHelper = new(); + + private readonly ToStringHelper deserializationToStringHelper = new(); + + /// + /// The serializer to use for user data (e.g. arguments, return values and errors). + /// + private FormatterContext userDataContext; + + /// + /// Initializes a new instance of the class. + /// + public NerdbankMessagePackFormatter() + { + // Set up initial options for our own message types. + NBMP::MessagePackSerializer serializer = new() + { + InternStrings = true, + SerializeDefaultValues = false, + }; + + serializer.RegisterConverter(new RequestIdConverter()); + serializer.RegisterConverter(new JsonRpcMessageConverter(this)); + serializer.RegisterConverter(new JsonRpcRequestConverter(this)); + serializer.RegisterConverter(new JsonRpcResultConverter(this)); + serializer.RegisterConverter(new JsonRpcErrorConverter(this)); + serializer.RegisterConverter(new JsonRpcErrorDetailConverter(this)); + serializer.RegisterConverter(new TraceParentConverter()); + + this.rpcContext = new FormatterContext(serializer, ShapeProvider_StreamJsonRpc.Default); + + // Create the specialized formatters/resolvers that we will inject into the chain for user data. + this.progressFormatterResolver = new ProgressFormatterResolver(this); + this.asyncEnumerableFormatterResolver = new AsyncEnumerableFormatterResolver(this); + this.pipeFormatterResolver = new PipeFormatterResolver(this); + this.exceptionResolver = new MessagePackExceptionResolver(this); + + FormatterContext userDataContext = new( + new() + { + InternStrings = true, + SerializeDefaultValues = false, + }, + ReflectionTypeShapeProvider.Default); + + this.MassageUserDataContext(userDataContext); + this.userDataContext = userDataContext; + } + + private interface IJsonRpcMessagePackRetention + { + /// + /// Gets the original msgpack sequence that was deserialized into this message. + /// + /// + /// The buffer is only retained for a short time. If it has already been cleared, the result of this property is an empty sequence. + /// + ReadOnlySequence OriginalMessagePack { get; } + } + + /// + public new MultiplexingStream? MultiplexingStream + { + get => base.MultiplexingStream; + set => base.MultiplexingStream = value; + } + + /// + /// Configures the serialization context for user data with the specified configuration action. + /// + /// The action to configure the serialization context. + public void SetFormatterContext(Action configure) + { + Requires.NotNull(configure, nameof(configure)); + + var builder = new FormatterContextBuilder(this.userDataContext.Serializer); + configure(builder); + + FormatterContext context = builder.Build(); + this.MassageUserDataContext(context); + + this.userDataContext = context; + } + + /// + public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) + { + JsonRpcMessage message = this.rpcContext.Serializer.Deserialize(contentBuffer, ShapeProvider_StreamJsonRpc.Default) + ?? throw new NBMP::MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + + IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; + this.deserializationToStringHelper.Activate(contentBuffer); + try + { + tracingCallbacks?.OnMessageDeserialized(message, this.deserializationToStringHelper); + } + finally + { + this.deserializationToStringHelper.Deactivate(); + } + + return message; + } + + /// + public void Serialize(IBufferWriter contentBuffer, JsonRpcMessage message) + { + if (message is Protocol.JsonRpcRequest request + && request.Arguments is not null + && request.ArgumentsList is null + && request.Arguments is not IReadOnlyDictionary) + { + // This request contains named arguments, but not using a standard dictionary. Convert it to a dictionary so that + // the parameters can be matched to the method we're invoking. + if (GetParamsObjectDictionary(request.Arguments) is { } namedArgs) + { + request.Arguments = namedArgs.ArgumentValues; + request.NamedArgumentDeclaredTypes = namedArgs.ArgumentTypes; + } + } + + var writer = new NBMP::MessagePackWriter(contentBuffer); + try + { + this.rpcContext.Serializer.Serialize(ref writer, message, this.rpcContext.ShapeProvider); + writer.Flush(); + } + catch (Exception ex) + { + throw new NBMP::MessagePackSerializationException(string.Format(CultureInfo.CurrentCulture, Resources.ErrorWritingJsonRpcMessage, ex.GetType().Name, ex.Message), ex); + } + } + + /// + public object GetJsonText(JsonRpcMessage message) => message is IJsonRpcMessagePackRetention retainedMsgPack + ? NBMP::MessagePackSerializer.ConvertToJson(retainedMsgPack.OriginalMessagePack) + : throw new NotSupportedException(); + + /// + Protocol.JsonRpcRequest IJsonRpcMessageFactory.CreateRequestMessage() => new OutboundJsonRpcRequest(this); + + /// + Protocol.JsonRpcError IJsonRpcMessageFactory.CreateErrorMessage() => new JsonRpcError(this.userDataContext); + + /// + Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this, this.rpcContext); + + void IJsonRpcFormatterTracingCallbacks.OnSerializationComplete(JsonRpcMessage message, ReadOnlySequence encodedMessage) + { + IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; + this.serializationToStringHelper.Activate(encodedMessage); + try + { + tracingCallbacks?.OnMessageSerialized(message, this.serializationToStringHelper); + } + finally + { + this.serializationToStringHelper.Deactivate(); + } + } + + /// + /// Extracts a dictionary of property names and values from the specified params object. + /// + /// The params object. + /// A dictionary of argument values and another of declared argument types, or if is null. + /// + /// This method supports DataContractSerializer-compliant types. This includes C# anonymous types. + /// + [return: NotNullIfNotNull(nameof(paramsObject))] + private static (IReadOnlyDictionary ArgumentValues, IReadOnlyDictionary ArgumentTypes)? GetParamsObjectDictionary(object? paramsObject) + { + if (paramsObject is null) + { + return default; + } + + // Look up the argument types dictionary if we saved it before. + Type paramsObjectType = paramsObject.GetType(); + IReadOnlyDictionary? argumentTypes; + lock (ParameterObjectPropertyTypes) + { + ParameterObjectPropertyTypes.TryGetValue(paramsObjectType, out argumentTypes); + } + + // If we couldn't find a previously created argument types dictionary, create a mutable one that we'll build this time. + Dictionary? mutableArgumentTypes = argumentTypes is null ? new Dictionary() : null; + + var result = new Dictionary(StringComparer.Ordinal); + + TypeInfo paramsTypeInfo = paramsObject.GetType().GetTypeInfo(); + bool isDataContract = paramsTypeInfo.GetCustomAttribute() is not null; + + BindingFlags bindingFlags = BindingFlags.FlattenHierarchy | BindingFlags.Public | BindingFlags.Instance; + if (isDataContract) + { + bindingFlags |= BindingFlags.NonPublic; + } + + bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) + { + key = memberInfo.Name; + if (isDataContract) + { + DataMemberAttribute? dataMemberAttribute = memberInfo.GetCustomAttribute(); + if (dataMemberAttribute is null) + { + return false; + } + + if (!dataMemberAttribute.EmitDefaultValue) + { + throw new NotSupportedException($"(DataMemberAttribute.EmitDefaultValue == false) is not supported but was found on: {memberInfo.DeclaringType!.FullName}.{memberInfo.Name}."); + } + + key = dataMemberAttribute.Name ?? memberInfo.Name; + return true; + } + else + { + return memberInfo.GetCustomAttribute() is null; + } + } + + foreach (PropertyInfo property in paramsTypeInfo.GetProperties(bindingFlags)) + { + if (property.GetMethod is not null) + { + if (TryGetSerializationInfo(property, out string key)) + { + result[key] = property.GetValue(paramsObject); + if (mutableArgumentTypes is object) + { + mutableArgumentTypes[key] = property.PropertyType; + } + } + } + } + + foreach (FieldInfo field in paramsTypeInfo.GetFields(bindingFlags)) + { + if (TryGetSerializationInfo(field, out string key)) + { + result[key] = field.GetValue(paramsObject); + if (mutableArgumentTypes is object) + { + mutableArgumentTypes[key] = field.FieldType; + } + } + } + + // If we assembled the argument types dictionary this time, save it for next time. + if (mutableArgumentTypes is not null) + { + lock (ParameterObjectPropertyTypes) + { + if (ParameterObjectPropertyTypes.TryGetValue(paramsObjectType, out IReadOnlyDictionary? lostRace)) + { + // Of the two, pick the winner to use ourselves so we consolidate on one and allow the GC to collect the loser sooner. + argumentTypes = lostRace; + } + else + { + ParameterObjectPropertyTypes.Add(paramsObjectType, argumentTypes = mutableArgumentTypes); + } + } + } + + return (result, argumentTypes!); + } + + private static ReadOnlySequence GetSliceForNextToken(ref NBMP::MessagePackReader reader, in NBMP::SerializationContext context) + { + SequencePosition startingPosition = reader.Position; + reader.Skip(context); + SequencePosition endingPosition = reader.Position; + return reader.Sequence.Slice(startingPosition, endingPosition); + } + + /// + /// Reads a string with an optimized path for the value "2.0". + /// + /// The reader to use. + /// The decoded string. + private static unsafe string ReadProtocolVersion(ref NBMP::MessagePackReader reader) + { + if (!reader.TryReadStringSpan(out ReadOnlySpan valueBytes)) + { + // TODO: More specific exception type + throw new NBMP::MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + } + + // Recognize "2.0" since we expect it and can avoid decoding and allocating a new string for it. + if (Version2.TryRead(valueBytes)) + { + return Version2.Value; + } + else + { + // It wasn't the expected value, so decode it. + fixed (byte* pValueBytes = valueBytes) + { + return Encoding.UTF8.GetString(pValueBytes, valueBytes.Length); + } + } + } + + /// + /// Writes the JSON-RPC version property name and value in a highly optimized way. + /// + private static void WriteProtocolVersionPropertyAndValue(ref NBMP::MessagePackWriter writer, string version) + { + VersionPropertyName.Write(ref writer); + if (!Version2.TryWrite(ref writer, version)) + { + writer.Write(version); + } + } + + private static void ReadUnknownProperty(ref NBMP::MessagePackReader reader, in NBMP::SerializationContext context, ref Dictionary>? topLevelProperties, ReadOnlySpan stringKey) + { + topLevelProperties ??= new Dictionary>(StringComparer.Ordinal); +#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + string name = Encoding.UTF8.GetString(stringKey); +#else + string name = Encoding.UTF8.GetString(stringKey.ToArray()); +#endif + topLevelProperties.Add(name, GetSliceForNextToken(ref reader, context)); + } + + /// + /// Takes the user-supplied resolver for their data types and prepares the wrapping options + /// and the dynamic object wrapper for serialization. + /// + /// The options for user data that is supplied by the user (or the default). + private void MassageUserDataContext(FormatterContext userDataContext) + { + // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. + userDataContext.Serializer.RegisterConverter(RequestIdConverter.Instance); + userDataContext.Serializer.RegisterConverter(RawMessagePackConverter.Instance); + userDataContext.Serializer.RegisterConverter(EventArgsConverter.Instance); + + var resolvers = new IFormatterResolver[] + { + // Support for marshalled objects. + // new RpcMarshalableResolver(this) + + // Stateful or per-connection resolvers. + this.progressFormatterResolver, + this.asyncEnumerableFormatterResolver, + this.pipeFormatterResolver, + this.exceptionResolver, + }; + + // Wrap the resolver in another class as a way to pass information to our custom formatters. + IFormatterResolver userDataResolver = new ResolverWrapper(CompositeResolver.Create(resolvers), this); + } + + private class ResolverWrapper : IFormatterResolver + { + private readonly IFormatterResolver inner; + + internal ResolverWrapper(IFormatterResolver inner, NerdbankMessagePackFormatter formatter) + { + this.inner = inner; + this.Formatter = formatter; + } + + internal NerdbankMessagePackFormatter Formatter { get; } + + public IMessagePackFormatter? GetFormatter() => this.inner.GetFormatter(); + } + + private class MessagePackFormatterConverter : IFormatterConverter + { + private readonly FormatterContext options; + + internal MessagePackFormatterConverter(FormatterContext options) + { + this.options = options; + } + +#pragma warning disable CS8766 // This method may in fact return null, and no one cares. + public object? Convert(object value, Type type) +#pragma warning restore CS8766 + => ((RawMessagePack)value).Deserialize(type, this.options); + + public object Convert(object value, TypeCode typeCode) + { + return typeCode switch + { + TypeCode.Object => ((RawMessagePack)value).Deserialize(this.options), + _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), + }; + } + + public bool ToBoolean(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public byte ToByte(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public char ToChar(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public DateTime ToDateTime(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public decimal ToDecimal(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public double ToDouble(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public short ToInt16(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public int ToInt32(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public long ToInt64(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public sbyte ToSByte(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public float ToSingle(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public string? ToString(object value) => value is null ? null : ((RawMessagePack)value).Deserialize(this.options); + + public ushort ToUInt16(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public uint ToUInt32(object value) => ((RawMessagePack)value).Deserialize(this.options); + + public ulong ToUInt64(object value) => ((RawMessagePack)value).Deserialize(this.options); + } + + /// + /// A recyclable object that can serialize a message to JSON on demand. + /// + /// + /// In perf traces, creation of this object used to show up as one of the most allocated objects. + /// It is used even when tracing isn't active. So we changed its design to be reused, + /// since its lifetime is only required during a synchronous call to a trace API. + /// + private class ToStringHelper + { + private ReadOnlySequence? encodedMessage; + private string? jsonString; + + public override string ToString() + { + Verify.Operation(this.encodedMessage.HasValue, "This object has not been activated. It may have already been recycled."); + + return this.jsonString ??= NBMP::MessagePackSerializer.ConvertToJson(this.encodedMessage.Value); + } + + /// + /// Initializes this object to represent a message. + /// + internal void Activate(ReadOnlySequence encodedMessage) + { + this.encodedMessage = encodedMessage; + } + + /// + /// Cleans out this object to release memory and ensure throws if someone uses it after deactivation. + /// + internal void Deactivate() + { + this.encodedMessage = null; + this.jsonString = null; + } + } + + private class RequestIdConverter : NBMP::MessagePackConverter + { + internal static readonly RequestIdConverter Instance = new(); + + public override RequestId Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.NextMessagePackType == NBMP.MessagePackType.Integer) + { + return new RequestId(reader.ReadInt64()); + } + else + { + // Do *not* read as an interned string here because this ID should be unique. + return new RequestId(reader.ReadString()); + } + } + + public override void Write(ref NBMP.MessagePackWriter writer, in RequestId value, SerializationContext context) + { + context.DepthStep(); + + if (value.Number.HasValue) + { + writer.Write(value.Number.Value); + } + else + { + writer.Write(value.String); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" + { + "type": ["string", { "type": "integer", "format": "int64" }] + } + """)?.AsObject(); + } + + private class RawMessagePackConverter : MessagePackConverter + { + internal static readonly RawMessagePackConverter Instance = new(); + + private RawMessagePackConverter() + { + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override RawMessagePack Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + return RawMessagePack.ReadRaw(ref reader, copy: false, context); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref NBMP.MessagePackWriter writer, in RawMessagePack value, SerializationContext context) + { + value.WriteRaw(ref writer); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(RawMessagePackConverter)); + } + } + + private class ProgressFormatterResolver : IFormatterResolver + { + private readonly MessagePackFormatter mainFormatter; + + private readonly Dictionary progressFormatters = []; + + internal ProgressFormatterResolver(MessagePackFormatter formatter) + { + this.mainFormatter = formatter; + } + + public IMessagePackFormatter? GetFormatter() + { + lock (this.progressFormatters) + { + if (!this.progressFormatters.TryGetValue(typeof(T), out IMessagePackFormatter? formatter)) + { + if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) + { + formatter = new PreciseTypeFormatter(this.mainFormatter); + } + else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) + { + formatter = new ProgressClientFormatter(this.mainFormatter); + } + + this.progressFormatters.Add(typeof(T), formatter); + } + + return (IMessagePackFormatter?)formatter; + } + } + + /// + /// Converts an instance of to a progress token. + /// + private class ProgressClientConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal ProgressClientConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override TClass Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException("This formatter only serializes IProgress instances."); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in TClass? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ProgressClientConverter)); + } + } + + /// + /// Converts a progress token to an or an into a token. + /// + private class PreciseTypeConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + [return: MaybeNull] + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override TClass? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default!; + } + + Assumes.NotNull(this.formatter.JsonRpc); + RawMessagePack token = RawMessagePack.ReadRaw(ref reader, copy: true, context); + bool clientRequiresNamedArgs = this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; + return (TClass)this.formatter.FormatterProgressTracker.CreateProgress(this.formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in TClass? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); + } + } + } + + private class AsyncEnumerableFormatterResolver : IFormatterResolver + { + private readonly MessagePackFormatter mainFormatter; + + private readonly Dictionary enumerableFormatters = new Dictionary(); + + internal AsyncEnumerableFormatterResolver(MessagePackFormatter formatter) + { + this.mainFormatter = formatter; + } + + public IMessagePackFormatter? GetFormatter() + { + lock (this.enumerableFormatters) + { + if (!this.enumerableFormatters.TryGetValue(typeof(T), out IMessagePackFormatter? formatter)) + { + if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) + { + formatter = (IMessagePackFormatter?)Activator.CreateInstance(typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), new object[] { this.mainFormatter }); + } + else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) + { + formatter = (IMessagePackFormatter?)Activator.CreateInstance(typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), new object[] { this.mainFormatter }); + } + + this.enumerableFormatters.Add(typeof(T), formatter); + } + + return (IMessagePackFormatter?)formatter; + } + } + + /// + /// Converts an enumeration token to an + /// or an into an enumeration token. + /// +#pragma warning disable CA1812 + private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter> +#pragma warning restore CA1812 + { + /// + /// The constant "token", in its various forms. + /// + private static readonly CommonString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); + + /// + /// The constant "values", in its various forms. + /// + private static readonly CommonString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); + + public override IAsyncEnumerable? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default; + } + + context.DepthStep(); + RawMessagePack token = default; + IReadOnlyList? initialElements = null; + int propertyCount = reader.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) + { + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new NBMP.MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + } + + if (TokenPropertyName.TryRead(stringKey)) + { + token = RawMessagePack.ReadRaw(ref reader, copy: true, context); + } + else if (ValuesPropertyName.TryRead(stringKey)) + { + initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + reader.Skip(context); + } + } + + return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.IsDefault ? null : token, initialElements); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref NBMP.MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) + { + Serialize_Shared(mainFormatter, ref writer, value, context); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); + } + + internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref NBMP::MessagePackWriter writer, IAsyncEnumerable? value, NBMP::SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + (IReadOnlyList Elements, bool Finished) prefetched = value.TearOffPrefetchedElements(); + long token = mainFormatter.EnumerableTracker.GetToken(value); + + int propertyCount = 0; + if (prefetched.Elements.Count > 0) + { + propertyCount++; + } + + if (!prefetched.Finished) + { + propertyCount++; + } + + writer.WriteMapHeader(propertyCount); + + if (!prefetched.Finished) + { + writer.Write(MessageFormatterEnumerableTracker.TokenPropertyName); + writer.Write(token); + } + + if (prefetched.Elements.Count > 0) + { + writer.Write(MessageFormatterEnumerableTracker.ValuesPropertyName); + context.GetConverter>(context.TypeShapeProvider).Write(ref writer, prefetched.Elements, context); + } + } + } + } + + /// + /// Converts an instance of to an enumeration token. + /// +#pragma warning disable CA1812 + private class GeneratorConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter where TClass : IAsyncEnumerable +#pragma warning restore CA1812 + { + public override TClass Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException(); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref NBMP.MessagePackWriter writer, in TClass? value, SerializationContext context) + { + PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(GeneratorConverter)); + } + } + } + + private class PipeFormatterResolver : IFormatterResolver + { + private readonly NerdbankMessagePackFormatter mainFormatter; + + private readonly Dictionary pipeFormatters = []; + + internal PipeFormatterResolver(NerdbankMessagePackFormatter formatter) + { + this.mainFormatter = formatter; + } + + public IMessagePackFormatter? GetFormatter() + { + lock (this.pipeFormatters) + { + if (!this.pipeFormatters.TryGetValue(typeof(T), out IMessagePackFormatter? formatter)) + { + if (typeof(IDuplexPipe).IsAssignableFrom(typeof(T))) + { + formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(PipeReader).IsAssignableFrom(typeof(T))) + { + formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(PipeWriter).IsAssignableFrom(typeof(T))) + { + formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + + this.pipeFormatters.Add(typeof(T), formatter); + } + + return (IMessagePackFormatter?)formatter; + } + } + +#pragma warning disable CA1812 +#pragma warning disable NBMsgPack032 // Converters should override GetJsonSchema + private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : class, IDuplexPipe +#pragma warning restore NBMsgPack032 // Converters should override GetJsonSchema +#pragma warning restore CA1812 + { + public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + { + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + } + +#pragma warning disable CA1812 +#pragma warning disable NBMsgPack032 // Converters should override GetJsonSchema + private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : PipeReader +#pragma warning restore NBMsgPack032 // Converters should override GetJsonSchema +#pragma warning restore CA1812 + { + public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + { + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + } + +#pragma warning disable CA1812 +#pragma warning disable NBMsgPack032 // Converters should override GetJsonSchema + private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : PipeWriter +#pragma warning restore NBMsgPack032 // Converters should override GetJsonSchema +#pragma warning restore CA1812 + { + public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + { + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + } + +#pragma warning disable CA1812 +#pragma warning disable NBMsgPack032 // Converters should override GetJsonSchema + private class StreamConverter : MessagePackConverter + where T : Stream +#pragma warning restore NBMsgPack032 // Converters should override GetJsonSchema +#pragma warning restore CA1812 + { + private readonly NerdbankMessagePackFormatter formatter; + + public StreamConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + return (T)this.formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + { + if (this.formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + } + } + + private class RpcMarshalableResolver : IFormatterResolver + { + private readonly NerdbankMessagePackFormatter formatter; + private readonly Dictionary formatters = new Dictionary(); + + internal RpcMarshalableResolver(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public IMessagePackFormatter? GetFormatter() + { + if (typeof(T).IsValueType) + { + return null; + } + + lock (this.formatters) + { + if (this.formatters.TryGetValue(typeof(T), out object? cachedFormatter)) + { + return (IMessagePackFormatter)cachedFormatter; + } + } + + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( + typeof(T), + out JsonRpcProxyOptions? proxyOptions, + out JsonRpcTargetOptions? targetOptions, + out RpcMarshalableAttribute? attribute)) + { + object formatter = Activator.CreateInstance( + typeof(RpcMarshalableFormatter<>).MakeGenericType(typeof(T)), + this.formatter, + proxyOptions, + targetOptions, + attribute)!; + + lock (this.formatters) + { + if (!this.formatters.TryGetValue(typeof(T), out object? cachedFormatter)) + { + this.formatters.Add(typeof(T), cachedFormatter = formatter); + } + + return (IMessagePackFormatter)cachedFormatter; + } + } + + return null; + } + } + +#pragma warning disable CA1812 + private class RpcMarshalableFormatter(NerdbankMessagePackFormatter messagePackFormatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions, RpcMarshalableAttribute rpcMarshalableAttribute) : IMessagePackFormatter + where T : class +#pragma warning restore CA1812 + { + public T? Deserialize(ref MessagePack.MessagePackReader reader, MessagePackSerializerOptions options) + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = MessagePack.MessagePackSerializer.Deserialize(ref reader, options); + return token.HasValue ? (T?)messagePackFormatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; + } + + public void Serialize(ref MessagePack.MessagePackWriter writer, T? value, MessagePackSerializerOptions options) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = messagePackFormatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); + MessagePack.MessagePackSerializer.Serialize(ref writer, token, options); + } + } + } + + /// + /// Manages serialization of any -derived type that follows standard rules. + /// + /// + /// A serializable class will: + /// 1. Derive from + /// 2. Be attributed with + /// 3. Declare a constructor with a signature of (, ). + /// + private class MessagePackExceptionResolver : IFormatterResolver + { + /// + /// Tracks recursion count while serializing or deserializing an exception. + /// + /// + /// This is placed here (outside the generic class) + /// so that it's one counter shared across all exception types that may be serialized or deserialized. + /// + private static ThreadLocal exceptionRecursionCounter = new(); + + private readonly object[] formatterActivationArgs; + + private readonly Dictionary formatterCache = new Dictionary(); + + internal MessagePackExceptionResolver(MessagePackFormatter formatter) + { + this.formatterActivationArgs = new object[] { formatter }; + } + + public IMessagePackFormatter? GetFormatter() + { + lock (this.formatterCache) + { + if (this.formatterCache.TryGetValue(typeof(T), out object? cachedFormatter)) + { + return (IMessagePackFormatter?)cachedFormatter; + } + + IMessagePackFormatter? formatter = null; + if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is object) + { + formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(ExceptionFormatter<>).MakeGenericType(typeof(T)), this.formatterActivationArgs)!; + } + + this.formatterCache.Add(typeof(T), formatter); + return formatter; + } + } + +#pragma warning disable CA1812 + private partial class ExceptionFormatter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : Exception +#pragma warning restore CA1812 + { + public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + Assumes.NotNull(formatter.JsonRpc); + if (reader.TryReadNil()) + { + return null; + } + + // We have to guard our own recursion because the serializer has no visibility into inner exceptions. + // Each exception in the russian doll is a new serialization job from its perspective. + exceptionRecursionCounter.Value++; + try + { + if (exceptionRecursionCounter.Value > formatter.JsonRpc.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception. + // Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded. + reader.Skip(context); + return null; + } + + // TODO: Is this the right context? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcContext)); + int memberCount = reader.ReadMapHeader(); + for (int i = 0; i < memberCount; i++) + { + string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) + ?? throw new NBMP::MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + + // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, + // so the caller will get a boxed RawMessagePack struct in that case. + // Although we can't do much about *that* in general, we can at least ensure that null values + // are represented as null instead of this boxed struct. + var value = reader.TryReadNil() ? null : (object)RawMessagePack.ReadRaw(ref reader, false, context); + + info.AddSafeValue(name, value); + } + + return ExceptionSerializationHelpers.Deserialize(formatter.JsonRpc, info, formatter.JsonRpc.TraceSource); + } + finally + { + exceptionRecursionCounter.Value--; + } + } + + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + exceptionRecursionCounter.Value++; + try + { + if (exceptionRecursionCounter.Value > formatter.JsonRpc?.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception. + writer.WriteNil(); + return; + } + + // TODO: Is this the right context? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcContext)); + ExceptionSerializationHelpers.Serialize(value, info); + writer.WriteMapHeader(info.GetSafeMemberCount()); + foreach (SerializationEntry element in info.GetSafeMembers()) + { + writer.Write(element.Name); +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + formatter.rpcContext.Serializer.SerializeObject( + ref writer, + element.Value, + formatter.rpcContext.ShapeProvider.Resolve(element.ObjectType)); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + } + finally + { + exceptionRecursionCounter.Value--; + } + } + } + } + + private class JsonRpcMessageConverter : NBMP::MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override JsonRpcMessage? Read(ref NBMP.MessagePackReader reader, NBMP::SerializationContext context) + { + context.DepthStep(); + + NBMP::MessagePackReader readAhead = reader.CreatePeekReader(); + int propertyCount = readAhead.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) + { + // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. + // MessagePackFormatter: ReadOnlySpan stringKey = MessagePack.Internal.CodeGenHelpers.ReadStringSpan(ref readAhead); + if (!readAhead.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (MethodPropertyName.TryRead(stringKey)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(stringKey)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(stringKey)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + readAhead.Skip(context); + } + } + + throw new UnrecognizedJsonRpcMessageException(); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in JsonRpcMessage? value, NBMP::SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + using (this.formatter.TrackSerialization(value)) + { + context.DepthStep(); + + switch (value) + { + case Protocol.JsonRpcRequest request: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, request, context); + break; + case Protocol.JsonRpcResult result: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, result, context); + break; + case Protocol.JsonRpcError error: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, error, context); + break; + default: + throw new NotSupportedException("Unexpected JsonRpcMessage-derived type: " + value.GetType().Name); + } + } + } + + public override JsonObject? GetJsonSchema(NBMP::JsonSchemaContext context, ITypeShape typeShape) + { + return base.GetJsonSchema(context, typeShape); + } + } + + private class JsonRpcRequestConverter : NBMP::MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override Protocol.JsonRpcRequest? Read(ref NBMP::MessagePackReader reader, NBMP::SerializationContext context) + { + var result = new JsonRpcRequest(this.formatter) + { + OriginalMessagePack = reader.Sequence, + }; + + context.DepthStep(); + + int propertyCount = reader.ReadMapHeader(); + Dictionary>? topLevelProperties = null; + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + { + // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (VersionPropertyName.TryRead(stringKey)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(stringKey)) + { + result.RequestId = context.GetConverter(null).Read(ref reader, context); + } + else if (MethodPropertyName.TryRead(stringKey)) + { + result.Method = context.GetConverter(null).Read(ref reader, context); + } + else if (ParamsPropertyName.TryRead(stringKey)) + { + SequencePosition paramsTokenStartPosition = reader.Position; + + // Parse out the arguments into a dictionary or array, but don't deserialize them because we don't yet know what types to deserialize them to. + switch (reader.NextMessagePackType) + { + case NBMP::MessagePackType.Array: + var positionalArgs = new ReadOnlySequence[reader.ReadArrayHeader()]; + for (int i = 0; i < positionalArgs.Length; i++) + { + positionalArgs[i] = GetSliceForNextToken(ref reader, context); + } + + result.MsgPackPositionalArguments = positionalArgs; + break; + case NBMP::MessagePackType.Map: + int namedArgsCount = reader.ReadMapHeader(); + var namedArgs = new Dictionary>(namedArgsCount); + for (int i = 0; i < namedArgsCount; i++) + { + string? propertyName = context.GetConverter(null).Read(ref reader, context); + if (propertyName is null) + { + throw new NBMP::MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + } + + namedArgs.Add(propertyName, GetSliceForNextToken(ref reader, context)); + } + + result.MsgPackNamedArguments = namedArgs; + break; + case NBMP::MessagePackType.Nil: + result.MsgPackPositionalArguments = Array.Empty>(); + reader.ReadNil(); + break; + case NBMP::MessagePackType type: + throw new NBMP::MessagePackSerializationException("Expected a map or array of arguments but got " + type); + } + + result.MsgPackArguments = reader.Sequence.Slice(paramsTokenStartPosition, reader.Position); + } + else if (TraceParentPropertyName.TryRead(stringKey)) + { + TraceParent traceParent = context.GetConverter(null).Read(ref reader, context); + result.TraceParent = traceParent.ToString(); + } + else if (TraceStatePropertyName.TryRead(stringKey)) + { + result.TraceState = ReadTraceState(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + } + } + + if (topLevelProperties is not null) + { + result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + } + + this.formatter.TryHandleSpecialIncomingMessage(result); + + return result; + } + + public override void Write(ref NBMP.MessagePackWriter writer, in Protocol.JsonRpcRequest? value, NBMP::SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + context.DepthStep(); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + int mapElementCount = value.RequestId.IsEmpty ? 3 : 4; + if (value.TraceParent?.Length > 0) + { + mapElementCount++; + if (value.TraceState?.Length > 0) + { + mapElementCount++; + } + } + + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + if (!value.RequestId.IsEmpty) + { + IdPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.RequestId, context); + } + + MethodPropertyName.Write(ref writer); + writer.Write(value.Method); + + ParamsPropertyName.Write(ref writer); + + // TODO: Get from SetOptions + ITypeShapeProvider? userShapeProvider = context.TypeShapeProvider; + + if (value.ArgumentsList is not null) + { + writer.WriteArrayHeader(value.ArgumentsList.Count); + + + for (int i = 0; i < value.ArgumentsList.Count; i++) + { + object? arg = value.ArgumentsList[i]; + ITypeShape? argShape = arg is null + ? null + : value.ArgumentListDeclaredTypes is not null + ? userShapeProvider?.GetShape(value.ArgumentListDeclaredTypes[i]) + : ReflectionTypeShapeProvider.Default.Resolve(arg.GetType()); + + if (argShape is not null) + { +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serializer.SerializeObject(ref writer, arg, argShape, context.CancellationToken); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + else + { + // TODO: NOT REALLY SURE ABOUT THIS YET + writer.WriteNil(); + } + } + } + else if (value.NamedArguments is not null) + { + writer.WriteMapHeader(value.NamedArguments.Count); + foreach (KeyValuePair entry in value.NamedArguments) + { + writer.Write(entry.Key); + ITypeShape? argShape = value.NamedArgumentDeclaredTypes?[entry.Key] is Type argType + ? userShapeProvider?.GetShape(argType) + : null; + + if (argShape is not null) + { +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serializer.SerializeObject(ref writer, entry.Value, argShape, context.CancellationToken); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + else + { + // TODO: NOT REALLY SURE ABOUT THIS YET + writer.WriteNil(); + } + } + } + else + { + writer.WriteNil(); + } + + if (value.TraceParent?.Length > 0) + { + TraceParentPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, new TraceParent(value.TraceParent), context); + + if (value.TraceState?.Length > 0) + { + TraceStatePropertyName.Write(ref writer); + WriteTraceState(ref writer, value.TraceState); + } + } + + topLevelPropertyBag?.WriteProperties(ref writer); + } + + public override JsonObject? GetJsonSchema(NBMP::JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcRequestConverter)); + } + + private static void WriteTraceState(ref NBMP::MessagePackWriter writer, string traceState) + { + ReadOnlySpan traceStateChars = traceState.AsSpan(); + + // Count elements first so we can write the header. + int elementCount = 1; + int commaIndex; + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) + { + elementCount++; + traceStateChars = traceStateChars.Slice(commaIndex + 1); + } + + // For every element, we have a key and value to record. + writer.WriteArrayHeader(elementCount * 2); + + traceStateChars = traceState.AsSpan(); + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) + { + ReadOnlySpan element = traceStateChars.Slice(0, commaIndex); + WritePair(ref writer, element); + traceStateChars = traceStateChars.Slice(commaIndex + 1); + } + + // Write out the last one. + WritePair(ref writer, traceStateChars); + + static void WritePair(ref NBMP::MessagePackWriter writer, ReadOnlySpan pair) + { + int equalsIndex = pair.IndexOf('='); + ReadOnlySpan key = pair.Slice(0, equalsIndex); + ReadOnlySpan value = pair.Slice(equalsIndex + 1); + writer.Write(key); + writer.Write(value); + } + } + + private static unsafe string ReadTraceState(ref NBMP::MessagePackReader reader, NBMP::SerializationContext context) + { + int elements = reader.ReadArrayHeader(); + if (elements % 2 != 0) + { + throw new NotSupportedException("Odd number of elements not expected."); + } + + // With care, we could probably assemble this string with just two allocations (the string + a char[]). + var resultBuilder = new StringBuilder(); + for (int i = 0; i < elements; i += 2) + { + if (resultBuilder.Length > 0) + { + resultBuilder.Append(','); + } + + // We assume the key is a frequent string, and the value is unique, + // so we optimize whether to use string interning or not on that basis. + resultBuilder.Append(context.GetConverter(null).Read(ref reader, context)); + resultBuilder.Append('='); + resultBuilder.Append(reader.ReadString()); + } + + return resultBuilder.ToString(); + } + } + + private partial class JsonRpcResultConverter : NBMP::MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcResultConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override Protocol.JsonRpcResult Read(ref NBMP.MessagePackReader reader, NBMP::SerializationContext context) + { + var result = new JsonRpcResult(this.formatter, this.formatter.userDataContext) + { + OriginalMessagePack = reader.Sequence, + }; + + Dictionary>? topLevelProperties = null; + context.DepthStep(); + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + { + // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (VersionPropertyName.TryRead(stringKey)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(stringKey)) + { + result.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(stringKey)) + { + result.MsgPackResult = GetSliceForNextToken(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + } + } + + if (topLevelProperties is not null) + { + result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + } + + return result; + } + + public override void Write(ref NBMP.MessagePackWriter writer, in Protocol.JsonRpcResult? value, NBMP::SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + var topLevelPropertyBagMessage = value as IMessageWithTopLevelPropertyBag; + + int mapElementCount = 3; + mapElementCount += (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + IdPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); + + ResultPropertyName.Write(ref writer); + + ITypeShape? typeShape = value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void) + ? this.formatter.userDataContext.ShapeProvider.Resolve(value.ResultDeclaredType) + : value.Result is null + ? null + : this.formatter.userDataContext.ShapeProvider.Resolve(value.Result.GetType()); + + if (typeShape is not null) + { +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serializer.SerializeObject(ref writer, value.Result, typeShape, context.CancellationToken); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + else + { + // TODO: NOT REALLY SURE ABOUT THIS YET + writer.WriteNil(); + } + + (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); + } + + public override JsonObject? GetJsonSchema(NBMP::JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcResultConverter)); + } + } + + private partial class JsonRpcErrorConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcErrorConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override Protocol.JsonRpcError Read(ref NBMP::MessagePackReader reader, SerializationContext context) + { + var error = new JsonRpcError(this.formatter.rpcContext) + { + OriginalMessagePack = reader.Sequence, + }; + + Dictionary>? topLevelProperties = null; + + context.DepthStep(); + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (VersionPropertyName.TryRead(stringKey)) + { + error.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(stringKey)) + { + error.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(stringKey)) + { + error.Error = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + } + } + + if (topLevelProperties is not null) + { + error.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + } + + return error; + } + + public override void Write(ref NBMP::MessagePackWriter writer, in Protocol.JsonRpcError? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + int mapElementCount = 3; + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + IdPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); + + ErrorPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.Error, context); + + topLevelPropertyBag?.WriteProperties(ref writer); + } + + public override JsonObject? GetJsonSchema(NBMP::JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcErrorConverter)); + } + } + + private partial class JsonRpcErrorDetailConverter : MessagePackConverter + { + private static readonly CommonString CodePropertyName = new("code"); + private static readonly CommonString MessagePropertyName = new("message"); + private static readonly CommonString DataPropertyName = new("data"); + + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcErrorDetailConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override Protocol.JsonRpcError.ErrorDetail Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + var result = new JsonRpcError.ErrorDetail(this.formatter.userDataContext); + context.DepthStep(); + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (CodePropertyName.TryRead(stringKey)) + { + result.Code = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (MessagePropertyName.TryRead(stringKey)) + { + result.Message = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (DataPropertyName.TryRead(stringKey)) + { + result.MsgPackData = GetSliceForNextToken(ref reader, context); + } + else + { + reader.Skip(context); + } + } + + return result; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref NBMP.MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + writer.WriteMapHeader(3); + + CodePropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.Code, context); + + MessagePropertyName.Write(ref writer); + writer.Write(value.Message); + + DataPropertyName.Write(ref writer); +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serializer.SerializeObject( + ref writer, + value.Data, + this.formatter.userDataContext.ShapeProvider.Resolve()); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcErrorDetailConverter)); + } + } + + /// + /// Enables formatting the default/empty class. + /// + private class EventArgsConverter : MessagePackConverter + { + internal static readonly EventArgsConverter Instance = new(); + + private EventArgsConverter() + { + } + + /// + public override void Write(ref NBMP.MessagePackWriter writer, in EventArgs? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + writer.WriteMapHeader(0); + } + + /// + public override EventArgs Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + reader.Skip(context); + return EventArgs.Empty; + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(EventArgsConverter)); + } + } + + private class TraceParentConverter : MessagePackConverter + { + public unsafe override TraceParent Read(ref NBMP.MessagePackReader reader, SerializationContext context) + { + if (reader.ReadArrayHeader() != 2) + { + throw new NotSupportedException("Unexpected array length."); + } + + var result = default(TraceParent); + result.Version = reader.ReadByte(); + if (result.Version != 0) + { + throw new NotSupportedException("traceparent version " + result.Version + " is not supported."); + } + + if (reader.ReadArrayHeader() != 3) + { + throw new NotSupportedException("Unexpected array length in version-format."); + } + + ReadOnlySequence bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected traceid not found."); + bytes.CopyTo(new Span(result.TraceId, TraceParent.TraceIdByteCount)); + + bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected parentid not found."); + bytes.CopyTo(new Span(result.ParentId, TraceParent.ParentIdByteCount)); + + result.Flags = (TraceParent.TraceFlags)reader.ReadByte(); + + return result; + } + + public unsafe override void Write(ref NBMP.MessagePackWriter writer, in TraceParent value, SerializationContext context) + { + if (value.Version != 0) + { + throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); + } + + writer.WriteArrayHeader(2); + + writer.Write(value.Version); + + writer.WriteArrayHeader(3); + + fixed (byte* traceId = value.TraceId) + { + writer.Write(new ReadOnlySpan(traceId, TraceParent.TraceIdByteCount)); + } + + fixed (byte* parentId = value.ParentId) + { + writer.Write(new ReadOnlySpan(parentId, TraceParent.ParentIdByteCount)); + } + + writer.Write((byte)value.Flags); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(TraceParentConverter)); + } + } + + private class TopLevelPropertyBag : TopLevelPropertyBagBase + { + private readonly FormatterContext formatterContext; + private readonly IReadOnlyDictionary>? inboundUnknownProperties; + + /// + /// Initializes a new instance of the class + /// for an incoming message. + /// + /// The serializer options to use for this data. + /// The map of unrecognized inbound properties. + internal TopLevelPropertyBag(FormatterContext userDataContext, IReadOnlyDictionary> inboundUnknownProperties) + : base(isOutbound: false) + { + this.formatterContext = userDataContext; + this.inboundUnknownProperties = inboundUnknownProperties; + } + + /// + /// Initializes a new instance of the class + /// for an outbound message. + /// + /// The serializer options to use for this data. + internal TopLevelPropertyBag(FormatterContext formatterContext) + : base(isOutbound: true) + { + this.formatterContext = formatterContext; + } + + internal int PropertyCount => this.inboundUnknownProperties?.Count ?? this.OutboundProperties?.Count ?? 0; + + /// + /// Writes the properties tracked by this collection to a messagepack writer. + /// + /// The writer to use. + internal void WriteProperties(ref NBMP::MessagePackWriter writer) + { + if (this.inboundUnknownProperties is not null) + { + // We're actually re-transmitting an incoming message (remote target feature). + // We need to copy all the properties that were in the original message. + // Don't implement this without enabling the tests for the scenario found in JsonRpcRemoteTargetMessagePackFormatterTests.cs. + // The tests fail for reasons even without this support, so there's work to do beyond just implementing this. + throw new NotImplementedException(); + + ////foreach (KeyValuePair> entry in this.inboundUnknownProperties) + ////{ + //// writer.Write(entry.Key); + //// writer.Write(entry.Value); + ////} + } + else + { + foreach (KeyValuePair entry in this.OutboundProperties) + { + ITypeShape shape = this.formatterContext.ShapeProvider.Resolve(entry.Value.DeclaredType); + + writer.Write(entry.Key); + this.formatterContext.Serializer.SerializeObject(ref writer, entry.Value.Value, shape); + } + } + } + + protected internal override bool TryGetTopLevelProperty(string name, [MaybeNull] out T value) + { + if (this.inboundUnknownProperties is null) + { + throw new InvalidOperationException(Resources.InboundMessageOnly); + } + + value = default; + + if (this.inboundUnknownProperties.TryGetValue(name, out ReadOnlySequence serializedValue) is true) + { + var reader = new NBMP::MessagePackReader(serializedValue); + value = this.formatterContext.Serializer.Deserialize(ref reader, this.formatterContext.ShapeProvider); + return true; + } + + return false; + } + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + [DataContract] + private class OutboundJsonRpcRequest : JsonRpcRequestBase + { + private readonly NerdbankMessagePackFormatter formatter; + + internal OutboundJsonRpcRequest(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter ?? throw new ArgumentNullException(nameof(formatter)); + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataContext); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + [DataContract] + private class JsonRpcRequest : JsonRpcRequestBase, IJsonRpcMessagePackRetention + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcRequest(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter ?? throw new ArgumentNullException(nameof(formatter)); + } + + public override int ArgumentCount => this.MsgPackNamedArguments?.Count ?? this.MsgPackPositionalArguments?.Count ?? base.ArgumentCount; + + public override IEnumerable? ArgumentNames => this.MsgPackNamedArguments?.Keys; + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + internal ReadOnlySequence MsgPackArguments { get; set; } + + internal IReadOnlyDictionary>? MsgPackNamedArguments { get; set; } + + internal IReadOnlyList>? MsgPackPositionalArguments { get; set; } + + public override ArgumentMatchResult TryGetTypedArguments(ReadOnlySpan parameters, Span typedArguments) + { + using (this.formatter.TrackDeserialization(this, parameters)) + { + if (parameters.Length == 1 && this.MsgPackNamedArguments is not null) + { + if (this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.UseSingleObjectParameterDeserialization ?? false) + { + var reader = new NBMP::MessagePackReader(this.MsgPackArguments); + try + { + typedArguments[0] = this.formatter.userDataContext.Serializer.DeserializeObject( + ref reader, + this.formatter.userDataContext.ShapeProvider.Resolve(parameters[0].ParameterType)); + + return ArgumentMatchResult.Success; + } + catch (NBMP::MessagePackSerializationException) + { + return ArgumentMatchResult.ParameterArgumentTypeMismatch; + } + } + } + + return base.TryGetTypedArguments(parameters, typedArguments); + } + } + + public override bool TryGetArgumentByNameOrIndex(string? name, int position, Type? typeHint, out object? value) + { + // If anyone asks us for an argument *after* we've been told deserialization is done, there's something very wrong. + Assumes.True(this.MsgPackNamedArguments is not null || this.MsgPackPositionalArguments is not null); + + ReadOnlySequence msgpackArgument = default; + if (position >= 0 && this.MsgPackPositionalArguments?.Count > position) + { + msgpackArgument = this.MsgPackPositionalArguments[position]; + } + else if (name is not null && this.MsgPackNamedArguments is not null) + { + this.MsgPackNamedArguments.TryGetValue(name, out msgpackArgument); + } + + if (msgpackArgument.IsEmpty) + { + value = null; + return false; + } + + var reader = new NBMP::MessagePackReader(msgpackArgument); + using (this.formatter.TrackDeserialization(this)) + { + try + { + value = this.formatter.userDataContext.Serializer.DeserializeObject( + ref reader, + this.formatter.userDataContext.ShapeProvider.Resolve(typeHint ?? typeof(object))); + + return true; + } + catch (NBMP::MessagePackSerializationException ex) + { + if (this.formatter.JsonRpc?.TraceSource.Switch.ShouldTrace(TraceEventType.Warning) ?? false) + { + this.formatter.JsonRpc.TraceSource.TraceEvent(TraceEventType.Warning, (int)JsonRpc.TraceEvents.MethodArgumentDeserializationFailure, Resources.FailureDeserializingRpcArgument, name, position, typeHint, ex); + } + + throw new RpcArgumentDeserializationException(name, position, typeHint, ex); + } + } + } + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + this.MsgPackNamedArguments = null; + this.MsgPackPositionalArguments = null; + this.TopLevelPropertyBag = null; + this.MsgPackArguments = default; + this.OriginalMessagePack = default; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataContext); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + [DataContract] + private class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention + { + private readonly NerdbankMessagePackFormatter formatter; + private readonly FormatterContext serializerOptions; + + private Exception? resultDeserializationException; + + internal JsonRpcResult(NerdbankMessagePackFormatter formatter, FormatterContext serializationOptions) + { + this.formatter = formatter; + this.serializerOptions = serializationOptions; + } + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + internal ReadOnlySequence MsgPackResult { get; set; } + + public override T GetResult() + { + if (this.resultDeserializationException is not null) + { + ExceptionDispatchInfo.Capture(this.resultDeserializationException).Throw(); + } + + return this.MsgPackResult.IsEmpty + ? (T)this.Result! + : this.serializerOptions.Serializer.Deserialize(this.MsgPackResult, this.serializerOptions.ShapeProvider) + ?? throw new NBMP::MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + } + + protected internal override void SetExpectedResultType(Type resultType) + { + Verify.Operation(!this.MsgPackResult.IsEmpty, "Result is no longer available or has already been deserialized."); + + var reader = new NBMP::MessagePackReader(this.MsgPackResult); + try + { + using (this.formatter.TrackDeserialization(this)) + { + this.Result = this.serializerOptions.Serializer.DeserializeObject( + ref reader, + this.serializerOptions.ShapeProvider.Resolve(resultType)); + } + + this.MsgPackResult = default; + } + catch (NBMP::MessagePackSerializationException ex) + { + // This was a best effort anyway. We'll throw again later at a more convenient time for JsonRpc. + this.resultDeserializationException = ex; + } + } + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + this.MsgPackResult = default; + this.OriginalMessagePack = default; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.serializerOptions); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + [DataContract] + private class JsonRpcError : JsonRpcErrorBase, IJsonRpcMessagePackRetention + { + private readonly FormatterContext serializerOptions; + + public JsonRpcError(FormatterContext serializerOptions) + { + this.serializerOptions = serializerOptions; + } + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.serializerOptions); + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + if (this.Error is ErrorDetail privateDetail) + { + privateDetail.MsgPackData = default; + } + + this.OriginalMessagePack = default; + } + + [DataContract] + internal new class ErrorDetail : Protocol.JsonRpcError.ErrorDetail + { + private readonly FormatterContext serializerOptions; + + internal ErrorDetail(FormatterContext serializerOptions) + { + this.serializerOptions = serializerOptions ?? throw new ArgumentNullException(nameof(serializerOptions)); + } + + internal ReadOnlySequence MsgPackData { get; set; } + + public override object? GetData(Type dataType) + { + Requires.NotNull(dataType, nameof(dataType)); + if (this.MsgPackData.IsEmpty) + { + return this.Data; + } + + var reader = new NBMP::MessagePackReader(this.MsgPackData); + try + { + return this.serializerOptions.Serializer.DeserializeObject( + ref reader, + this.serializerOptions.ShapeProvider.Resolve(dataType)) + ?? throw new NBMP::MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + } + catch (NBMP::MessagePackSerializationException) + { + // Deserialization failed. Try returning array/dictionary based primitive objects. + try + { + // return MessagePackSerializer.Deserialize(this.MsgPackData, this.serializerOptions.WithResolver(PrimitiveObjectResolver.Instance)); + // TODO: Which Shape Provider to use? + return this.serializerOptions.Serializer.Deserialize(this.MsgPackData, this.serializerOptions.ShapeProvider); + } + catch (NBMP::MessagePackSerializationException) + { + return null; + } + } + } + + protected internal override void SetExpectedDataType(Type dataType) + { + Verify.Operation(!this.MsgPackData.IsEmpty, "Data is no longer available or has already been deserialized."); + + this.Data = this.GetData(dataType); + + // Clear the source now that we've deserialized to prevent GetData from attempting + // deserialization later when the buffer may be recycled on another thread. + this.MsgPackData = default; + } + } + } + + private record FormatterContext(NBMP::MessagePackSerializer Serializer, ITypeShapeProvider ShapeProvider); + + private class FormatterContextBuilder(NBMP::MessagePackSerializer serializer) : IFormatterContextBuilder + { + private readonly CompositeTypeShapeProviderBuilder providerBuilder = new(); + + public ICompositeTypeShapeProviderBuilder TypeShapeProviderBuilder => this.providerBuilder; + + public void RegisterConverter(NBMP::MessagePackConverter converter) => serializer.RegisterConverter(converter); + + public void RegisterKnownSubTypes(NBMP::KnownSubTypeMapping mapping) + { + Requires.NotNull(mapping, nameof(mapping)); + serializer.RegisterKnownSubTypes(mapping); + } + + internal FormatterContext Build() => new(serializer, this.providerBuilder.Build()); + } + + private class CompositeTypeShapeProviderBuilder : ICompositeTypeShapeProviderBuilder + { + private readonly List providers = []; + + public ICompositeTypeShapeProviderBuilder Add(ITypeShapeProvider provider) + { + this.providers.Add(provider); + return this; + } + + public ICompositeTypeShapeProviderBuilder AddRange(IEnumerable providers) + { + this.providers.AddRange(providers); + return this; + } + + public ICompositeTypeShapeProviderBuilder AddReflectionTypeShapeProvider(bool useReflectionEmit) + { + ReflectionTypeShapeProviderOptions options = new() + { + UseReflectionEmit = useReflectionEmit, + }; + + this.providers.Add(ReflectionTypeShapeProvider.Create(options)); + return this; + } + + public ITypeShapeProvider Build() + { + return this.providers.Count switch + { + 0 => ReflectionTypeShapeProvider.Default, + 1 => this.providers[0], + _ => new CompositeTypeShapeProvider(this.providers.AsReadOnly()), + }; + } + + private class CompositeTypeShapeProvider : ITypeShapeProvider + { + private readonly ReadOnlyCollection providers; + + internal CompositeTypeShapeProvider(ReadOnlyCollection providers) + { + this.providers = providers; + } + + public ITypeShape? GetShape(Type type) + { + foreach (ITypeShapeProvider provider in this.providers) + { + ITypeShape? shape = provider.GetShape(type); + if (shape is not null) + { + return shape; + } + } + + return null; + } + } + } +} diff --git a/src/StreamJsonRpc/Protocol/JsonRpcError.cs b/src/StreamJsonRpc/Protocol/JsonRpcError.cs index 7905eb811..d8500c09a 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcError.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcError.cs @@ -5,6 +5,7 @@ using System.Runtime.Serialization; using StreamJsonRpc.Reflection; using JsonNET = Newtonsoft.Json.Linq; +using PT = PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -13,14 +14,16 @@ namespace StreamJsonRpc.Protocol; /// Describes the error resulting from a that failed on the server. /// [DataContract] +[PT.GenerateShape] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + "}")] -public class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId { /// /// Gets or sets the detail about the error. /// [DataMember(Name = "error", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("error"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PT.PropertyShape(Name = "error", Order = 2)] public ErrorDetail? Error { get; set; } /// @@ -30,6 +33,7 @@ public class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -41,6 +45,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PT.PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -77,6 +82,7 @@ public class ErrorDetail /// [DataMember(Name = "code", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("code"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] + [PT.PropertyShape(Name = "code", Order = 0)] public JsonRpcErrorCode Code { get; set; } /// @@ -87,6 +93,7 @@ public class ErrorDetail /// [DataMember(Name = "message", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("message"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PT.PropertyShape(Name = "message", Order = 1)] public string? Message { get; set; } /// @@ -95,6 +102,7 @@ public class ErrorDetail [DataMember(Name = "data", Order = 2, IsRequired = false)] [Newtonsoft.Json.JsonProperty(DefaultValueHandling = Newtonsoft.Json.DefaultValueHandling.Ignore)] [STJ.JsonPropertyName("data"), STJ.JsonPropertyOrder(2)] + [PT.PropertyShape(Name = "data", Order = 2)] public object? Data { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs index 84acc9373..e717a2b2a 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs @@ -3,6 +3,8 @@ using System.Diagnostics.CodeAnalysis; using System.Runtime.Serialization; +using NBMP = Nerdbank.MessagePack; +using PT = PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -14,6 +16,11 @@ namespace StreamJsonRpc.Protocol; [KnownType(typeof(JsonRpcRequest))] [KnownType(typeof(JsonRpcResult))] [KnownType(typeof(JsonRpcError))] +#pragma warning disable CS0618 //'KnownSubTypeAttribute.KnownSubTypeAttribute(Type)' is obsolete: 'Use the generic version of this attribute instead.' +[NBMP::KnownSubType(typeof(JsonRpcRequest))] +[NBMP::KnownSubType(typeof(JsonRpcResult))] +[NBMP::KnownSubType(typeof(JsonRpcError))] +#pragma warning restore CS0618 public abstract class JsonRpcMessage { /// @@ -22,6 +29,7 @@ public abstract class JsonRpcMessage /// Defaults to "2.0". [DataMember(Name = "jsonrpc", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("jsonrpc"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] + [PT.PropertyShape(Name = "jsonrpc", Order = 0)] public string Version { get; set; } = "2.0"; /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs index c41239ac6..4f3fca685 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs @@ -5,6 +5,7 @@ using System.Reflection; using System.Runtime.Serialization; using JsonNET = Newtonsoft.Json.Linq; +using PT = PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -13,8 +14,9 @@ namespace StreamJsonRpc.Protocol; /// Describes a method to be invoked on the server. /// [DataContract] +[PT.GenerateShape] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -public class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId { /// /// The result of an attempt to match request arguments with a candidate method's parameters. @@ -47,6 +49,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "method", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("method"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PT.PropertyShape(Name = "method", Order = 2)] public string? Method { get; set; } /// @@ -61,6 +64,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "params", Order = 3, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("params"), STJ.JsonPropertyOrder(3), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PT.PropertyShape(Name = "params", Order = 1)] public object? Arguments { get; set; } /// @@ -70,6 +74,7 @@ public enum ArgumentMatchResult [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -81,6 +86,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingDefault)] + [PT.PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -88,6 +94,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public bool IsResponseExpected => !this.RequestId.IsEmpty; /// @@ -95,6 +102,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public bool IsNotification => this.RequestId.IsEmpty; /// @@ -102,6 +110,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public virtual int ArgumentCount => this.NamedArguments?.Count ?? this.ArgumentsList?.Count ?? 0; /// @@ -109,6 +118,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArguments { get => this.Arguments as IReadOnlyDictionary; @@ -127,6 +137,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArgumentDeclaredTypes { get; set; } /// @@ -134,6 +145,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] [Obsolete("Use " + nameof(ArgumentsList) + " instead.")] public object?[]? ArgumentsArray { @@ -146,6 +158,7 @@ public object?[]? ArgumentsArray /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentsList { get => this.Arguments as IReadOnlyList; @@ -166,6 +179,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentListDeclaredTypes { get; set; } /// @@ -173,6 +187,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public virtual IEnumerable? ArgumentNames => this.NamedArguments?.Keys; /// @@ -180,6 +195,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "traceparent", EmitDefaultValue = false)] [STJ.JsonPropertyName("traceparent"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PT.PropertyShape(Name = "traceparent")] public string? TraceParent { get; set; } /// @@ -187,6 +203,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "tracestate", EmitDefaultValue = false)] [STJ.JsonPropertyName("tracestate"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PT.PropertyShape(Name = "tracestate")] public string? TraceState { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs index 6bd3157e6..e81ba931a 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Runtime.Serialization; using JsonNET = Newtonsoft.Json.Linq; +using PT = PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -13,13 +14,15 @@ namespace StreamJsonRpc.Protocol; /// [DataContract] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId +[PT.GenerateShape] +public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId { /// /// Gets or sets the value of the result of an invocation, if any. /// [DataMember(Name = "result", Order = 2, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("result"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PT.PropertyShape(Name = "result", Order = 2)] public object? Result { get; set; } /// @@ -30,6 +33,7 @@ public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId /// [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public Type? ResultDeclaredType { get; set; } /// @@ -39,6 +43,7 @@ public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PT.PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -50,6 +55,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PT.PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// From 052592e2ddbb9b2c9d29a19ea9c37d6566e4f973 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Fri, 20 Dec 2024 23:29:46 +0000 Subject: [PATCH 03/25] Refactor NerdbankMessagePackFormatter for converters Replaces the `ProgressFormatterResolver` and `AsyncEnumerableFormatterResolver` with `ProgressConverterResolver` and `AsyncEnumerableConverterResolver`. --- .../NerdbankMessagePackFormatter.cs | 95 +++++++++---------- 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 48f4a53ca..25ebda05d 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -15,7 +15,6 @@ using System.Text.Json.Nodes; using MessagePack; using MessagePack.Formatters; -using MessagePack.Resolvers; using Nerdbank.MessagePack; using Nerdbank.Streams; using PolyType; @@ -94,9 +93,9 @@ public sealed partial class NerdbankMessagePackFormatter : FormatterBase, IJsonR /// private readonly FormatterContext rpcContext; - private readonly ProgressFormatterResolver progressFormatterResolver; + private readonly ProgressConverterResolver progressConverterResolver; - private readonly AsyncEnumerableFormatterResolver asyncEnumerableFormatterResolver; + private readonly AsyncEnumerableConverterResolver asyncEnumerableConverterResolver; private readonly PipeFormatterResolver pipeFormatterResolver; @@ -134,8 +133,10 @@ public NerdbankMessagePackFormatter() this.rpcContext = new FormatterContext(serializer, ShapeProvider_StreamJsonRpc.Default); // Create the specialized formatters/resolvers that we will inject into the chain for user data. - this.progressFormatterResolver = new ProgressFormatterResolver(this); - this.asyncEnumerableFormatterResolver = new AsyncEnumerableFormatterResolver(this); + this.progressConverterResolver = new ProgressConverterResolver(this); + this.asyncEnumerableConverterResolver = new AsyncEnumerableConverterResolver(this); + + // TODO: Convert these to converter resolvers? this.pipeFormatterResolver = new PipeFormatterResolver(this); this.exceptionResolver = new MessagePackExceptionResolver(this); @@ -177,7 +178,7 @@ public void SetFormatterContext(Action configure) { Requires.NotNull(configure, nameof(configure)); - var builder = new FormatterContextBuilder(this.userDataContext.Serializer); + var builder = new FormatterContextBuilder(this, this.userDataContext.Serializer); configure(builder); FormatterContext context = builder.Build(); @@ -449,30 +450,11 @@ private void MassageUserDataContext(FormatterContext userDataContext) // Support for marshalled objects. // new RpcMarshalableResolver(this) + // TODO: Add support for exotic types // Stateful or per-connection resolvers. - this.progressFormatterResolver, - this.asyncEnumerableFormatterResolver, this.pipeFormatterResolver, this.exceptionResolver, }; - - // Wrap the resolver in another class as a way to pass information to our custom formatters. - IFormatterResolver userDataResolver = new ResolverWrapper(CompositeResolver.Create(resolvers), this); - } - - private class ResolverWrapper : IFormatterResolver - { - private readonly IFormatterResolver inner; - - internal ResolverWrapper(IFormatterResolver inner, NerdbankMessagePackFormatter formatter) - { - this.inner = inner; - this.Formatter = formatter; - } - - internal NerdbankMessagePackFormatter Formatter { get; } - - public IMessagePackFormatter? GetFormatter() => this.inner.GetFormatter(); } private class MessagePackFormatterConverter : IFormatterConverter @@ -633,36 +615,36 @@ public override void Write(ref NBMP.MessagePackWriter writer, in RawMessagePack } } - private class ProgressFormatterResolver : IFormatterResolver + private class ProgressConverterResolver { - private readonly MessagePackFormatter mainFormatter; + private readonly NerdbankMessagePackFormatter mainFormatter; - private readonly Dictionary progressFormatters = []; + private readonly Dictionary progressConverters = []; - internal ProgressFormatterResolver(MessagePackFormatter formatter) + internal ProgressConverterResolver(NerdbankMessagePackFormatter formatter) { this.mainFormatter = formatter; } - public IMessagePackFormatter? GetFormatter() + public MessagePackConverter? GetConverter() { - lock (this.progressFormatters) + lock (this.progressConverters) { - if (!this.progressFormatters.TryGetValue(typeof(T), out IMessagePackFormatter? formatter)) + if (!this.progressConverters.TryGetValue(typeof(T), out IMessagePackConverter? converter)) { if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) { - formatter = new PreciseTypeFormatter(this.mainFormatter); + converter = new PreciseTypeConverter(this.mainFormatter); } else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) { - formatter = new ProgressClientFormatter(this.mainFormatter); + converter = new ProgressClientConverter(this.mainFormatter); } - this.progressFormatters.Add(typeof(T), formatter); + this.progressConverters.Add(typeof(T), converter); } - return (IMessagePackFormatter?)formatter; + return (MessagePackConverter?)converter; } } @@ -749,36 +731,36 @@ public override void Write(ref NBMP.MessagePackWriter writer, in TClass? value, } } - private class AsyncEnumerableFormatterResolver : IFormatterResolver + private class AsyncEnumerableConverterResolver { - private readonly MessagePackFormatter mainFormatter; + private readonly NerdbankMessagePackFormatter mainFormatter; - private readonly Dictionary enumerableFormatters = new Dictionary(); + private readonly Dictionary enumerableFormatters = []; - internal AsyncEnumerableFormatterResolver(MessagePackFormatter formatter) + internal AsyncEnumerableConverterResolver(NerdbankMessagePackFormatter formatter) { this.mainFormatter = formatter; } - public IMessagePackFormatter? GetFormatter() + public MessagePackConverter? GetConverter() { lock (this.enumerableFormatters) { - if (!this.enumerableFormatters.TryGetValue(typeof(T), out IMessagePackFormatter? formatter)) + if (!this.enumerableFormatters.TryGetValue(typeof(T), out IMessagePackConverter? converter)) { if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) { - formatter = (IMessagePackFormatter?)Activator.CreateInstance(typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), new object[] { this.mainFormatter }); + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), new object[] { this.mainFormatter }); } else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) { - formatter = (IMessagePackFormatter?)Activator.CreateInstance(typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), new object[] { this.mainFormatter }); + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), new object[] { this.mainFormatter }); } - this.enumerableFormatters.Add(typeof(T), formatter); + this.enumerableFormatters.Add(typeof(T), converter); } - return (IMessagePackFormatter?)formatter; + return (MessagePackConverter?)converter; } } @@ -2365,7 +2347,7 @@ protected internal override void SetExpectedDataType(Type dataType) private record FormatterContext(NBMP::MessagePackSerializer Serializer, ITypeShapeProvider ShapeProvider); - private class FormatterContextBuilder(NBMP::MessagePackSerializer serializer) : IFormatterContextBuilder + private class FormatterContextBuilder(NerdbankMessagePackFormatter formatter, NBMP::MessagePackSerializer serializer) : IFormatterContextBuilder { private readonly CompositeTypeShapeProviderBuilder providerBuilder = new(); @@ -2373,6 +2355,23 @@ private class FormatterContextBuilder(NBMP::MessagePackSerializer serializer) : public void RegisterConverter(NBMP::MessagePackConverter converter) => serializer.RegisterConverter(converter); + public void RegisterProgressTypeConverter() + { + // TODO: Improve Exception + MessagePackConverter converter = formatter.progressConverterResolver.GetConverter() + ?? throw new InvalidOperationException("No converter found for " + typeof(TProgress).FullName); + + serializer.RegisterConverter(converter); + } + + public void RegisterAsyncEnumerableTypeConverter() + { + MessagePackConverter converter = formatter.asyncEnumerableConverterResolver.GetConverter() + ?? throw new InvalidOperationException("No converter found for " + typeof(TElement).FullName); + + serializer.RegisterConverter(converter); + } + public void RegisterKnownSubTypes(NBMP::KnownSubTypeMapping mapping) { Requires.NotNull(mapping, nameof(mapping)); From 259558a5bb15a0fecd80b2d2f508c7d85de203cc Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 21 Dec 2024 00:45:37 +0000 Subject: [PATCH 04/25] Add type converter registration methods to interface --- ...sagePackFormatter.ISerializationContextBuilder.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs index d898b41d2..dc3acc1a0 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs @@ -23,6 +23,12 @@ public interface IFormatterContextBuilder /// ICompositeTypeShapeProviderBuilder TypeShapeProviderBuilder { get; } + /// + /// Registers a type converter for asynchronous enumerable types. + /// + /// The element type of the asynchronous enumerable. + void RegisterAsyncEnumerableTypeConverter(); + /// /// Registers a custom converter for a specific type. /// @@ -36,5 +42,11 @@ public interface IFormatterContextBuilder /// The base type for which the subtypes are registered. /// The mapping of known subtypes. void RegisterKnownSubTypes(KnownSubTypeMapping mapping); + + /// + /// Registers a type converter for progress types. + /// + /// The type of the progress to register. + void RegisterProgressTypeConverter(); } } From e87da44fff3c55bfc288c7a94457118c81ab9528 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 21 Dec 2024 15:56:38 +0000 Subject: [PATCH 05/25] Enhance serialization context and type shape providers Significant modifications to the `NerdbankMessagePackFormatter` and related classes to improve serialization context and type shape provider functionalities. - Introduced new methods in `ISerializationContextBuilder.cs` for registering various type converters, enhancing flexibility for asynchronous enumerables, duplex pipes, and streams. - Replaced `PipeFormatterResolver` with `PipeConverterResolver` in `NerdbankMessagePackFormatter` to align with new converter registration methods. - Updated `FormatterContextBuilder` to utilize new context and type shape provider structures for a more modular design. --- ...tter.ICompositeTypeShapeProviderBuilder.cs | 42 --- ...kFormatter.ISerializationContextBuilder.cs | 213 ++++++++++++- .../NerdbankMessagePackFormatter.cs | 294 +++++------------- .../Reflection/RpcMarshalableAttribute.cs | 2 +- 4 files changed, 277 insertions(+), 274 deletions(-) delete mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.ICompositeTypeShapeProviderBuilder.cs diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ICompositeTypeShapeProviderBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ICompositeTypeShapeProviderBuilder.cs deleted file mode 100644 index 6d681dea8..000000000 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ICompositeTypeShapeProviderBuilder.cs +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using PolyType; - -namespace StreamJsonRpc; - -/// -/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). -/// -/// -/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. -/// -public sealed partial class NerdbankMessagePackFormatter -{ - /// - /// Provides a builder interface for adding type shape providers. - /// - public interface ICompositeTypeShapeProviderBuilder - { - /// - /// Adds a single type shape provider to the builder. - /// - /// The type shape provider to add. - /// The current builder instance. - ICompositeTypeShapeProviderBuilder Add(ITypeShapeProvider provider); - - /// - /// Adds a range of type shape providers to the builder. - /// - /// The collection of type shape providers to add. - /// The current builder instance. - ICompositeTypeShapeProviderBuilder AddRange(IEnumerable providers); - - /// - /// Adds a reflection-based type shape provider to the builder. - /// - /// A value indicating whether to use Reflection.Emit for dynamic type generation. - /// The current builder instance. - ICompositeTypeShapeProviderBuilder AddReflectionTypeShapeProvider(bool useReflectionEmit); - } -} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs index dc3acc1a0..64a92d6e2 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs @@ -1,7 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Collections.Immutable; +using System.Collections.ObjectModel; +using System.IO.Pipelines; +using MessagePack.Formatters; using Nerdbank.MessagePack; +using PolyType; +using PolyType.Abstractions; +using PolyType.ReflectionProvider; +using StreamJsonRpc.Reflection; namespace StreamJsonRpc; @@ -18,17 +26,6 @@ public sealed partial class NerdbankMessagePackFormatter /// public interface IFormatterContextBuilder { - /// - /// Gets the type shape provider builder. - /// - ICompositeTypeShapeProviderBuilder TypeShapeProviderBuilder { get; } - - /// - /// Registers a type converter for asynchronous enumerable types. - /// - /// The element type of the asynchronous enumerable. - void RegisterAsyncEnumerableTypeConverter(); - /// /// Registers a custom converter for a specific type. /// @@ -47,6 +44,198 @@ public interface IFormatterContextBuilder /// Registers a type converter for progress types. /// /// The type of the progress to register. - void RegisterProgressTypeConverter(); + /// The type of the progress report. + void RegisterProgressType() + where TProgress : IProgress; + + /// + /// Registers a type converter for asynchronous enumerable types. + /// + /// The type of the asynchronous enumerable to register. + /// The element type of the asynchronous enumerable. + void RegisterAsyncEnumerableType() + where TEnumerable : IAsyncEnumerable; + + /// + /// Registers a type converter for duplex pipe types. + /// + /// The type of the duplex pipe to register. + void RegisterDuplexPipeType() + where TPipe : IDuplexPipe; + + /// + /// Registers a type converter for pipe reader types. + /// + /// The type of the pipe reader to register. + void RegisterPipeReaderType() + where TReader : PipeReader; + + /// + /// Registers a type converter for pipe writer types. + /// + /// The type of the pipe writer to register. + void RegisterPipeWriterType() + where TWriter : PipeWriter; + + /// + /// Registers a type converter for stream types. + /// + /// The type of the stream to register. + void RegisterStreamType() + where TStream : Stream; + + /// + /// Adds a type shape provider to the formatter context. + /// + /// The type shape provider to add. + void AddTypeShapeProvider(ITypeShapeProvider provider); + + /// + /// Enable the reflection fallback for dynamic type generation. + /// + /// A value indicating whether to use Reflection.Emit for dynamic type generation. + void EnableReflectionFallback(bool useReflectionEmit); + } + + private class FormatterContextBuilder( + NerdbankMessagePackFormatter formatter, + FormatterContext baseContext) + : IFormatterContextBuilder + { + private ImmutableArray.Builder? typeShapeProvidersBuilder = null; + private ReflectionTypeShapeProvider? reflectionTypeShapeProvider = null; + + public void AddTypeShapeProvider(ITypeShapeProvider provider) + { + this.typeShapeProvidersBuilder ??= ImmutableArray.CreateBuilder(); + this.typeShapeProvidersBuilder.Add(provider); + } + + public void EnableReflectionFallback(bool useReflectionEmit) + { + ReflectionTypeShapeProviderOptions options = new() + { + UseReflectionEmit = useReflectionEmit, + }; + + this.reflectionTypeShapeProvider = ReflectionTypeShapeProvider.Create(options); + } + + public void RegisterAsyncEnumerableType() + where TEnumerable : IAsyncEnumerable + { + MessagePackConverter converter = formatter.asyncEnumerableConverterResolver.GetConverter(); + baseContext.Serializer.RegisterConverter(converter); + } + + public void RegisterConverter(MessagePackConverter converter) + { + baseContext.Serializer.RegisterConverter(converter); + } + + public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) + { + baseContext.Serializer.RegisterKnownSubTypes(mapping); + } + + public void RegisterProgressType() + where TProgress : IProgress + { + MessagePackConverter converter = formatter.progressConverterResolver.GetConverter(); + baseContext.Serializer.RegisterConverter(converter); + } + + public void RegisterDuplexPipeType() + where TPipe : IDuplexPipe + { + MessagePackConverter converter = formatter.pipeConverterResolver.GetConverter(); + baseContext.Serializer.RegisterConverter(converter); + } + + public void RegisterPipeReaderType() + where TReader : PipeReader + { + MessagePackConverter converter = formatter.pipeConverterResolver.GetConverter(); + baseContext.Serializer.RegisterConverter(converter); + } + + public void RegisterPipeWriterType() + where TWriter : PipeWriter + { + MessagePackConverter converter = formatter.pipeConverterResolver.GetConverter(); + baseContext.Serializer.RegisterConverter(converter); + } + + public void RegisterStreamType() + where TStream : Stream + { + MessagePackConverter converter = formatter.pipeConverterResolver.GetConverter(); + baseContext.Serializer.RegisterConverter(converter); + } + + public void RegisterRpcMarshalableType() + where T : class + { + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( + typeof(T), + out JsonRpcProxyOptions? proxyOptions, + out JsonRpcTargetOptions? targetOptions, + out RpcMarshalableAttribute? attribute)) + { + var converter = (RpcMarshalableConverter)Activator.CreateInstance( + typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), + formatter, + proxyOptions, + targetOptions, + attribute)!; + + baseContext.Serializer.RegisterConverter(converter); + } + + // TODO: Throw? + } + + internal FormatterContext Build() + { + if (this.reflectionTypeShapeProvider is not null) + { + this.AddTypeShapeProvider(this.reflectionTypeShapeProvider); + } + + if (this.typeShapeProvidersBuilder is null || this.typeShapeProvidersBuilder.Count < 1) + { + return baseContext; + } + + ITypeShapeProvider provider = this.typeShapeProvidersBuilder.Count == 1 + ? this.typeShapeProvidersBuilder[0] + : new CompositeTypeShapeProvider(this.typeShapeProvidersBuilder.ToImmutable()); + + return new FormatterContext(baseContext.Serializer, provider); + } + } + + private class CompositeTypeShapeProvider : ITypeShapeProvider + { + private readonly ImmutableArray providers; + + internal CompositeTypeShapeProvider(ImmutableArray providers) + { + this.providers = providers; + } + + public ITypeShape? GetShape(Type type) + { + foreach (ITypeShapeProvider provider in this.providers) + { + ITypeShape? shape = provider.GetShape(type); + if (shape is not null) + { + return shape; + } + } + + return null; + } } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 25ebda05d..b7bd4f4dc 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -10,6 +10,7 @@ using System.IO.Pipelines; using System.Reflection; using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices.ComTypes; using System.Runtime.Serialization; using System.Text; using System.Text.Json.Nodes; @@ -97,7 +98,7 @@ public sealed partial class NerdbankMessagePackFormatter : FormatterBase, IJsonR private readonly AsyncEnumerableConverterResolver asyncEnumerableConverterResolver; - private readonly PipeFormatterResolver pipeFormatterResolver; + private readonly PipeConverterResolver pipeConverterResolver; private readonly MessagePackExceptionResolver exceptionResolver; @@ -135,9 +136,9 @@ public NerdbankMessagePackFormatter() // Create the specialized formatters/resolvers that we will inject into the chain for user data. this.progressConverterResolver = new ProgressConverterResolver(this); this.asyncEnumerableConverterResolver = new AsyncEnumerableConverterResolver(this); + this.pipeConverterResolver = new PipeConverterResolver(this); // TODO: Convert these to converter resolvers? - this.pipeFormatterResolver = new PipeFormatterResolver(this); this.exceptionResolver = new MessagePackExceptionResolver(this); FormatterContext userDataContext = new( @@ -178,7 +179,7 @@ public void SetFormatterContext(Action configure) { Requires.NotNull(configure, nameof(configure)); - var builder = new FormatterContextBuilder(this, this.userDataContext.Serializer); + var builder = new FormatterContextBuilder(this, this.userDataContext); configure(builder); FormatterContext context = builder.Build(); @@ -452,7 +453,6 @@ private void MassageUserDataContext(FormatterContext userDataContext) // TODO: Add support for exotic types // Stateful or per-connection resolvers. - this.pipeFormatterResolver, this.exceptionResolver, }; } @@ -619,33 +619,26 @@ private class ProgressConverterResolver { private readonly NerdbankMessagePackFormatter mainFormatter; - private readonly Dictionary progressConverters = []; - internal ProgressConverterResolver(NerdbankMessagePackFormatter formatter) { this.mainFormatter = formatter; } - public MessagePackConverter? GetConverter() + public MessagePackConverter GetConverter() { - lock (this.progressConverters) - { - if (!this.progressConverters.TryGetValue(typeof(T), out IMessagePackConverter? converter)) - { - if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) - { - converter = new PreciseTypeConverter(this.mainFormatter); - } - else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) - { - converter = new ProgressClientConverter(this.mainFormatter); - } - - this.progressConverters.Add(typeof(T), converter); - } + MessagePackConverter? converter = default; - return (MessagePackConverter?)converter; + if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) + { + converter = new PreciseTypeConverter(this.mainFormatter); + } + else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) + { + converter = new ProgressClientConverter(this.mainFormatter); } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); } /// @@ -735,33 +728,26 @@ private class AsyncEnumerableConverterResolver { private readonly NerdbankMessagePackFormatter mainFormatter; - private readonly Dictionary enumerableFormatters = []; - internal AsyncEnumerableConverterResolver(NerdbankMessagePackFormatter formatter) { this.mainFormatter = formatter; } - public MessagePackConverter? GetConverter() + public MessagePackConverter GetConverter() { - lock (this.enumerableFormatters) - { - if (!this.enumerableFormatters.TryGetValue(typeof(T), out IMessagePackConverter? converter)) - { - if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) - { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), new object[] { this.mainFormatter }); - } - else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) - { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), new object[] { this.mainFormatter }); - } + MessagePackConverter? converter = default; - this.enumerableFormatters.Add(typeof(T), converter); - } - - return (MessagePackConverter?)converter; + if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), new object[] { this.mainFormatter }); + } + else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), new object[] { this.mainFormatter }); } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); } /// @@ -892,45 +878,38 @@ public override void Write(ref NBMP.MessagePackWriter writer, in TClass? value, } } - private class PipeFormatterResolver : IFormatterResolver + private class PipeConverterResolver { private readonly NerdbankMessagePackFormatter mainFormatter; - private readonly Dictionary pipeFormatters = []; - - internal PipeFormatterResolver(NerdbankMessagePackFormatter formatter) + internal PipeConverterResolver(NerdbankMessagePackFormatter formatter) { this.mainFormatter = formatter; } - public IMessagePackFormatter? GetFormatter() + public MessagePackConverter GetConverter() { - lock (this.pipeFormatters) - { - if (!this.pipeFormatters.TryGetValue(typeof(T), out IMessagePackFormatter? formatter)) - { - if (typeof(IDuplexPipe).IsAssignableFrom(typeof(T))) - { - formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; - } - else if (typeof(PipeReader).IsAssignableFrom(typeof(T))) - { - formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; - } - else if (typeof(PipeWriter).IsAssignableFrom(typeof(T))) - { - formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; - } - else if (typeof(Stream).IsAssignableFrom(typeof(T))) - { - formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; - } - - this.pipeFormatters.Add(typeof(T), formatter); - } + MessagePackConverter? converter = default; - return (IMessagePackFormatter?)formatter; + if (typeof(IDuplexPipe).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(PipeReader).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(PipeWriter).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); } #pragma warning disable CA1812 @@ -1061,71 +1040,24 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria } } - private class RpcMarshalableResolver : IFormatterResolver - { - private readonly NerdbankMessagePackFormatter formatter; - private readonly Dictionary formatters = new Dictionary(); - - internal RpcMarshalableResolver(NerdbankMessagePackFormatter formatter) - { - this.formatter = formatter; - } - - public IMessagePackFormatter? GetFormatter() - { - if (typeof(T).IsValueType) - { - return null; - } - - lock (this.formatters) - { - if (this.formatters.TryGetValue(typeof(T), out object? cachedFormatter)) - { - return (IMessagePackFormatter)cachedFormatter; - } - } - - if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( - typeof(T), - out JsonRpcProxyOptions? proxyOptions, - out JsonRpcTargetOptions? targetOptions, - out RpcMarshalableAttribute? attribute)) - { - object formatter = Activator.CreateInstance( - typeof(RpcMarshalableFormatter<>).MakeGenericType(typeof(T)), - this.formatter, - proxyOptions, - targetOptions, - attribute)!; - - lock (this.formatters) - { - if (!this.formatters.TryGetValue(typeof(T), out object? cachedFormatter)) - { - this.formatters.Add(typeof(T), cachedFormatter = formatter); - } - - return (IMessagePackFormatter)cachedFormatter; - } - } - - return null; - } - } - #pragma warning disable CA1812 - private class RpcMarshalableFormatter(NerdbankMessagePackFormatter messagePackFormatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions, RpcMarshalableAttribute rpcMarshalableAttribute) : IMessagePackFormatter + private class RpcMarshalableConverter( + NerdbankMessagePackFormatter formatter, + JsonRpcProxyOptions proxyOptions, + JsonRpcTargetOptions targetOptions, + RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter where T : class #pragma warning restore CA1812 { - public T? Deserialize(ref MessagePack.MessagePackReader reader, MessagePackSerializerOptions options) + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) { - MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = MessagePack.MessagePackSerializer.Deserialize(ref reader, options); - return token.HasValue ? (T?)messagePackFormatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcContext.Deserialize(ref reader); + return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; } - public void Serialize(ref MessagePack.MessagePackWriter writer, T? value, MessagePackSerializerOptions options) + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) { if (value is null) { @@ -1133,10 +1065,15 @@ public void Serialize(ref MessagePack.MessagePackWriter writer, T? value, Messag } else { - MessageFormatterRpcMarshaledContextTracker.MarshalToken token = messagePackFormatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); - MessagePack.MessagePackSerializer.Serialize(ref writer, token, options); + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); + formatter.rpcContext.Serialize(ref writer, token); } } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); + } } /// @@ -2345,101 +2282,20 @@ protected internal override void SetExpectedDataType(Type dataType) } } - private record FormatterContext(NBMP::MessagePackSerializer Serializer, ITypeShapeProvider ShapeProvider); - - private class FormatterContextBuilder(NerdbankMessagePackFormatter formatter, NBMP::MessagePackSerializer serializer) : IFormatterContextBuilder - { - private readonly CompositeTypeShapeProviderBuilder providerBuilder = new(); - - public ICompositeTypeShapeProviderBuilder TypeShapeProviderBuilder => this.providerBuilder; - - public void RegisterConverter(NBMP::MessagePackConverter converter) => serializer.RegisterConverter(converter); - - public void RegisterProgressTypeConverter() - { - // TODO: Improve Exception - MessagePackConverter converter = formatter.progressConverterResolver.GetConverter() - ?? throw new InvalidOperationException("No converter found for " + typeof(TProgress).FullName); - - serializer.RegisterConverter(converter); - } - - public void RegisterAsyncEnumerableTypeConverter() - { - MessagePackConverter converter = formatter.asyncEnumerableConverterResolver.GetConverter() - ?? throw new InvalidOperationException("No converter found for " + typeof(TElement).FullName); - - serializer.RegisterConverter(converter); - } - - public void RegisterKnownSubTypes(NBMP::KnownSubTypeMapping mapping) - { - Requires.NotNull(mapping, nameof(mapping)); - serializer.RegisterKnownSubTypes(mapping); - } - - internal FormatterContext Build() => new(serializer, this.providerBuilder.Build()); - } - - private class CompositeTypeShapeProviderBuilder : ICompositeTypeShapeProviderBuilder + private class FormatterContext(NBMP::MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) { - private readonly List providers = []; - - public ICompositeTypeShapeProviderBuilder Add(ITypeShapeProvider provider) - { - this.providers.Add(provider); - return this; - } - - public ICompositeTypeShapeProviderBuilder AddRange(IEnumerable providers) - { - this.providers.AddRange(providers); - return this; - } - - public ICompositeTypeShapeProviderBuilder AddReflectionTypeShapeProvider(bool useReflectionEmit) - { - ReflectionTypeShapeProviderOptions options = new() - { - UseReflectionEmit = useReflectionEmit, - }; + public NBMP::MessagePackSerializer Serializer => serializer; - this.providers.Add(ReflectionTypeShapeProvider.Create(options)); - return this; - } + public ITypeShapeProvider ShapeProvider => shapeProvider; - public ITypeShapeProvider Build() + public T? Deserialize(ref NBMP.MessagePackReader reader, CancellationToken cancellationToken = default) { - return this.providers.Count switch - { - 0 => ReflectionTypeShapeProvider.Default, - 1 => this.providers[0], - _ => new CompositeTypeShapeProvider(this.providers.AsReadOnly()), - }; + return serializer.Deserialize(ref reader, shapeProvider, cancellationToken); } - private class CompositeTypeShapeProvider : ITypeShapeProvider + public void Serialize(ref NBMP.MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) { - private readonly ReadOnlyCollection providers; - - internal CompositeTypeShapeProvider(ReadOnlyCollection providers) - { - this.providers = providers; - } - - public ITypeShape? GetShape(Type type) - { - foreach (ITypeShapeProvider provider in this.providers) - { - ITypeShape? shape = provider.GetShape(type); - if (shape is not null) - { - return shape; - } - } - - return null; - } + serializer.Serialize(ref writer, value, shapeProvider, cancellationToken); } } } diff --git a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs index 07e5f29d6..e2158249e 100644 --- a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs +++ b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs @@ -7,7 +7,7 @@ namespace StreamJsonRpc; /// Designates an interface that is used in an RPC contract to marshal the object so the receiver can invoke remote methods on it instead of serializing the object to send its data to the remote end. /// /// -/// Learn more about marshable interfaces. +/// Learn more about marshalable interfaces. /// [AttributeUsage(AttributeTargets.Interface, AllowMultiple = false, Inherited = false)] public class RpcMarshalableAttribute : Attribute From ca439ea388a3bcc3570ab884f9f68421f841e33a Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 21 Dec 2024 17:30:39 +0000 Subject: [PATCH 06/25] Enhance NerdbankMessagePackFormatter with new methods --- ...kFormatter.ISerializationContextBuilder.cs | 23 ++++- .../NerdbankMessagePackFormatter.cs | 90 +++++++++---------- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs index 64a92d6e2..a8308e423 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs @@ -2,9 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Collections.Immutable; -using System.Collections.ObjectModel; using System.IO.Pipelines; -using MessagePack.Formatters; using Nerdbank.MessagePack; using PolyType; using PolyType.Abstractions; @@ -84,6 +82,20 @@ void RegisterPipeWriterType() void RegisterStreamType() where TStream : Stream; + /// + /// Registers a type that can be marshaled over RPC. + /// + /// The type to register. + void RegisterRpcMarshalableType() + where T : class; + + /// + /// Registers a custom exception type for serialization. + /// + /// The type of the exception to register. + void RegisterExceptionType() + where TException : Exception; + /// /// Adds a type shape provider to the formatter context. /// @@ -173,6 +185,13 @@ public void RegisterStreamType() baseContext.Serializer.RegisterConverter(converter); } + public void RegisterExceptionType() + where TException : Exception + { + MessagePackConverter converter = formatter.exceptionResolver.GetConverter(); + baseContext.Serializer.RegisterConverter(converter); + } + public void RegisterRpcMarshalableType() where T : class { diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index b7bd4f4dc..af70ecae5 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -100,7 +100,7 @@ public sealed partial class NerdbankMessagePackFormatter : FormatterBase, IJsonR private readonly PipeConverterResolver pipeConverterResolver; - private readonly MessagePackExceptionResolver exceptionResolver; + private readonly MessagePackExceptionConverterResolver exceptionResolver; private readonly ToStringHelper serializationToStringHelper = new(); @@ -137,9 +137,7 @@ public NerdbankMessagePackFormatter() this.progressConverterResolver = new ProgressConverterResolver(this); this.asyncEnumerableConverterResolver = new AsyncEnumerableConverterResolver(this); this.pipeConverterResolver = new PipeConverterResolver(this); - - // TODO: Convert these to converter resolvers? - this.exceptionResolver = new MessagePackExceptionResolver(this); + this.exceptionResolver = new MessagePackExceptionConverterResolver(this); FormatterContext userDataContext = new( new() @@ -445,16 +443,6 @@ private void MassageUserDataContext(FormatterContext userDataContext) userDataContext.Serializer.RegisterConverter(RequestIdConverter.Instance); userDataContext.Serializer.RegisterConverter(RawMessagePackConverter.Instance); userDataContext.Serializer.RegisterConverter(EventArgsConverter.Instance); - - var resolvers = new IFormatterResolver[] - { - // Support for marshalled objects. - // new RpcMarshalableResolver(this) - - // TODO: Add support for exotic types - // Stateful or per-connection resolvers. - this.exceptionResolver, - }; } private class MessagePackFormatterConverter : IFormatterConverter @@ -913,10 +901,8 @@ public MessagePackConverter GetConverter() } #pragma warning disable CA1812 -#pragma warning disable NBMsgPack032 // Converters should override GetJsonSchema private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter where T : class, IDuplexPipe -#pragma warning restore NBMsgPack032 // Converters should override GetJsonSchema #pragma warning restore CA1812 { public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) @@ -940,13 +926,16 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria writer.WriteNil(); } } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(DuplexPipeConverter)); + } } #pragma warning disable CA1812 -#pragma warning disable NBMsgPack032 // Converters should override GetJsonSchema private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter where T : PipeReader -#pragma warning restore NBMsgPack032 // Converters should override GetJsonSchema #pragma warning restore CA1812 { public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) @@ -970,13 +959,16 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria writer.WriteNil(); } } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeReaderConverter)); + } } #pragma warning disable CA1812 -#pragma warning disable NBMsgPack032 // Converters should override GetJsonSchema private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter where T : PipeWriter -#pragma warning restore NBMsgPack032 // Converters should override GetJsonSchema #pragma warning restore CA1812 { public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) @@ -1000,13 +992,16 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria writer.WriteNil(); } } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeWriterConverter)); + } } #pragma warning disable CA1812 -#pragma warning disable NBMsgPack032 // Converters should override GetJsonSchema private class StreamConverter : MessagePackConverter where T : Stream -#pragma warning restore NBMsgPack032 // Converters should override GetJsonSchema #pragma warning restore CA1812 { private readonly NerdbankMessagePackFormatter formatter; @@ -1037,6 +1032,11 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria writer.WriteNil(); } } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(StreamConverter)); + } } } @@ -1085,48 +1085,38 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria /// 2. Be attributed with /// 3. Declare a constructor with a signature of (, ). /// - private class MessagePackExceptionResolver : IFormatterResolver + private class MessagePackExceptionConverterResolver { /// /// Tracks recursion count while serializing or deserializing an exception. /// /// - /// This is placed here (outside the generic class) + /// This is placed here (outside the generic class) /// so that it's one counter shared across all exception types that may be serialized or deserialized. /// private static ThreadLocal exceptionRecursionCounter = new(); private readonly object[] formatterActivationArgs; - private readonly Dictionary formatterCache = new Dictionary(); - - internal MessagePackExceptionResolver(MessagePackFormatter formatter) + internal MessagePackExceptionConverterResolver(NerdbankMessagePackFormatter formatter) { this.formatterActivationArgs = new object[] { formatter }; } - public IMessagePackFormatter? GetFormatter() + public MessagePackConverter GetConverter() { - lock (this.formatterCache) + MessagePackConverter? formatter = null; + if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is object) { - if (this.formatterCache.TryGetValue(typeof(T), out object? cachedFormatter)) - { - return (IMessagePackFormatter?)cachedFormatter; - } - - IMessagePackFormatter? formatter = null; - if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is object) - { - formatter = (IMessagePackFormatter)Activator.CreateInstance(typeof(ExceptionFormatter<>).MakeGenericType(typeof(T)), this.formatterActivationArgs)!; - } - - this.formatterCache.Add(typeof(T), formatter); - return formatter; + formatter = (MessagePackConverter)Activator.CreateInstance(typeof(ExceptionConverter<>).MakeGenericType(typeof(T)), this.formatterActivationArgs)!; } + + // TODO: Improve Exception + return formatter ?? throw new NotSupportedException(); } #pragma warning disable CA1812 - private partial class ExceptionFormatter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + private partial class ExceptionConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter where T : Exception #pragma warning restore CA1812 { @@ -1202,10 +1192,10 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria { writer.Write(element.Name); #pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods - formatter.rpcContext.Serializer.SerializeObject( + formatter.rpcContext.SerializeObject( ref writer, element.Value, - formatter.rpcContext.ShapeProvider.Resolve(element.ObjectType)); + element.ObjectType); #pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods } } @@ -1214,6 +1204,11 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria exceptionRecursionCounter.Value--; } } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ExceptionConverter)); + } } } @@ -2297,5 +2292,10 @@ public void Serialize(ref NBMP.MessagePackWriter writer, T? value, Cancellati { serializer.Serialize(ref writer, value, shapeProvider, cancellationToken); } + + internal void SerializeObject(ref NBMP.MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) + { + serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(objectType), cancellationToken); + } } } From b12c357aab5bc6229ade0678dddd43e3add1b90f Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 21 Dec 2024 18:23:58 +0000 Subject: [PATCH 07/25] Replaced the `IFormatterContextBuilder` interface with a concrete `FormatterContextBuilder` class. The `NerdbankMessagePackFormatter` class has been modified to eliminate the `NBMP` namespace prefix, favoring direct usage of `MessagePack` types. --- ...kFormatter.ISerializationContextBuilder.cs | 205 +++++++----------- .../NerdbankMessagePackFormatter.cs | 185 ++++++++-------- 2 files changed, 172 insertions(+), 218 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs index a8308e423..b511cbf22 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs @@ -6,7 +6,6 @@ using Nerdbank.MessagePack; using PolyType; using PolyType.Abstractions; -using PolyType.ReflectionProvider; using StreamJsonRpc.Reflection; namespace StreamJsonRpc; @@ -20,178 +19,139 @@ namespace StreamJsonRpc; public sealed partial class NerdbankMessagePackFormatter { /// - /// Provides a builder interface for configuring the serialization context. + /// Provides methods to build a serialization context for the . /// - public interface IFormatterContextBuilder + public class FormatterContextBuilder { - /// - /// Registers a custom converter for a specific type. - /// - /// The type for which the converter is registered. - /// The converter to register. - void RegisterConverter(MessagePackConverter converter); - - /// - /// Registers known subtypes for a base type. - /// - /// The base type for which the subtypes are registered. - /// The mapping of known subtypes. - void RegisterKnownSubTypes(KnownSubTypeMapping mapping); - - /// - /// Registers a type converter for progress types. - /// - /// The type of the progress to register. - /// The type of the progress report. - void RegisterProgressType() - where TProgress : IProgress; - - /// - /// Registers a type converter for asynchronous enumerable types. - /// - /// The type of the asynchronous enumerable to register. - /// The element type of the asynchronous enumerable. - void RegisterAsyncEnumerableType() - where TEnumerable : IAsyncEnumerable; - - /// - /// Registers a type converter for duplex pipe types. - /// - /// The type of the duplex pipe to register. - void RegisterDuplexPipeType() - where TPipe : IDuplexPipe; - - /// - /// Registers a type converter for pipe reader types. - /// - /// The type of the pipe reader to register. - void RegisterPipeReaderType() - where TReader : PipeReader; + private readonly NerdbankMessagePackFormatter formatter; + private readonly FormatterContext baseContext; - /// - /// Registers a type converter for pipe writer types. - /// - /// The type of the pipe writer to register. - void RegisterPipeWriterType() - where TWriter : PipeWriter; - - /// - /// Registers a type converter for stream types. - /// - /// The type of the stream to register. - void RegisterStreamType() - where TStream : Stream; - - /// - /// Registers a type that can be marshaled over RPC. - /// - /// The type to register. - void RegisterRpcMarshalableType() - where T : class; + private ImmutableArray.Builder? typeShapeProvidersBuilder = null; /// - /// Registers a custom exception type for serialization. + /// Initializes a new instance of the class. /// - /// The type of the exception to register. - void RegisterExceptionType() - where TException : Exception; + /// The formatter to use. + /// The base context to build upon. + internal FormatterContextBuilder(NerdbankMessagePackFormatter formatter, FormatterContext baseContext) + { + this.formatter = formatter; + this.baseContext = baseContext; + } /// - /// Adds a type shape provider to the formatter context. + /// Adds a type shape provider to the context. /// /// The type shape provider to add. - void AddTypeShapeProvider(ITypeShapeProvider provider); - - /// - /// Enable the reflection fallback for dynamic type generation. - /// - /// A value indicating whether to use Reflection.Emit for dynamic type generation. - void EnableReflectionFallback(bool useReflectionEmit); - } - - private class FormatterContextBuilder( - NerdbankMessagePackFormatter formatter, - FormatterContext baseContext) - : IFormatterContextBuilder - { - private ImmutableArray.Builder? typeShapeProvidersBuilder = null; - private ReflectionTypeShapeProvider? reflectionTypeShapeProvider = null; - public void AddTypeShapeProvider(ITypeShapeProvider provider) { this.typeShapeProvidersBuilder ??= ImmutableArray.CreateBuilder(); this.typeShapeProvidersBuilder.Add(provider); } - public void EnableReflectionFallback(bool useReflectionEmit) - { - ReflectionTypeShapeProviderOptions options = new() - { - UseReflectionEmit = useReflectionEmit, - }; - - this.reflectionTypeShapeProvider = ReflectionTypeShapeProvider.Create(options); - } - + /// + /// Registers an async enumerable type with the context. + /// + /// The type of the async enumerable. + /// The type of the elements in the async enumerable. public void RegisterAsyncEnumerableType() where TEnumerable : IAsyncEnumerable { - MessagePackConverter converter = formatter.asyncEnumerableConverterResolver.GetConverter(); - baseContext.Serializer.RegisterConverter(converter); + MessagePackConverter converter = this.formatter.asyncEnumerableConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); } + /// + /// Registers a converter with the context. + /// + /// The type the converter handles. + /// The converter to register. public void RegisterConverter(MessagePackConverter converter) { - baseContext.Serializer.RegisterConverter(converter); + this.baseContext.Serializer.RegisterConverter(converter); } + /// + /// Registers known subtypes for a base type with the context. + /// + /// The base type. + /// The mapping of known subtypes. public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) { - baseContext.Serializer.RegisterKnownSubTypes(mapping); + this.baseContext.Serializer.RegisterKnownSubTypes(mapping); } + /// + /// Registers a progress type with the context. + /// + /// The type of the progress. + /// The type of the report. public void RegisterProgressType() where TProgress : IProgress { - MessagePackConverter converter = formatter.progressConverterResolver.GetConverter(); - baseContext.Serializer.RegisterConverter(converter); + MessagePackConverter converter = this.formatter.progressConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); } + /// + /// Registers a duplex pipe type with the context. + /// + /// The type of the duplex pipe. public void RegisterDuplexPipeType() where TPipe : IDuplexPipe { - MessagePackConverter converter = formatter.pipeConverterResolver.GetConverter(); - baseContext.Serializer.RegisterConverter(converter); + MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); } + /// + /// Registers a pipe reader type with the context. + /// + /// The type of the pipe reader. public void RegisterPipeReaderType() where TReader : PipeReader { - MessagePackConverter converter = formatter.pipeConverterResolver.GetConverter(); - baseContext.Serializer.RegisterConverter(converter); + MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); } + /// + /// Registers a pipe writer type with the context. + /// + /// The type of the pipe writer. public void RegisterPipeWriterType() where TWriter : PipeWriter { - MessagePackConverter converter = formatter.pipeConverterResolver.GetConverter(); - baseContext.Serializer.RegisterConverter(converter); + MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); } + /// + /// Registers a stream type with the context. + /// + /// The type of the stream. public void RegisterStreamType() where TStream : Stream { - MessagePackConverter converter = formatter.pipeConverterResolver.GetConverter(); - baseContext.Serializer.RegisterConverter(converter); + MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); } + /// + /// Registers an exception type with the context. + /// + /// The type of the exception. public void RegisterExceptionType() where TException : Exception { - MessagePackConverter converter = formatter.exceptionResolver.GetConverter(); - baseContext.Serializer.RegisterConverter(converter); + MessagePackConverter converter = this.formatter.exceptionResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); } + /// + /// Registers an RPC marshalable type with the context. + /// + /// The type to register. public void RegisterRpcMarshalableType() where T : class { @@ -203,34 +163,33 @@ public void RegisterRpcMarshalableType() { var converter = (RpcMarshalableConverter)Activator.CreateInstance( typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), - formatter, + this.formatter, proxyOptions, targetOptions, attribute)!; - baseContext.Serializer.RegisterConverter(converter); + this.baseContext.Serializer.RegisterConverter(converter); } // TODO: Throw? } + /// + /// Builds the formatter context. + /// + /// The built formatter context. internal FormatterContext Build() { - if (this.reflectionTypeShapeProvider is not null) - { - this.AddTypeShapeProvider(this.reflectionTypeShapeProvider); - } - if (this.typeShapeProvidersBuilder is null || this.typeShapeProvidersBuilder.Count < 1) { - return baseContext; + return this.baseContext; } ITypeShapeProvider provider = this.typeShapeProvidersBuilder.Count == 1 ? this.typeShapeProvidersBuilder[0] : new CompositeTypeShapeProvider(this.typeShapeProvidersBuilder.ToImmutable()); - return new FormatterContext(baseContext.Serializer, provider); + return new FormatterContext(this.baseContext.Serializer, provider); } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index af70ecae5..78566beb5 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -3,19 +3,15 @@ using System.Buffers; using System.Collections.Immutable; -using System.Collections.ObjectModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO.Pipelines; using System.Reflection; using System.Runtime.ExceptionServices; -using System.Runtime.InteropServices.ComTypes; using System.Runtime.Serialization; using System.Text; using System.Text.Json.Nodes; -using MessagePack; -using MessagePack.Formatters; using Nerdbank.MessagePack; using Nerdbank.Streams; using PolyType; @@ -24,7 +20,6 @@ using PolyType.SourceGenerator; using StreamJsonRpc.Protocol; using StreamJsonRpc.Reflection; -using NBMP = Nerdbank.MessagePack; namespace StreamJsonRpc; @@ -117,7 +112,7 @@ public sealed partial class NerdbankMessagePackFormatter : FormatterBase, IJsonR public NerdbankMessagePackFormatter() { // Set up initial options for our own message types. - NBMP::MessagePackSerializer serializer = new() + MessagePackSerializer serializer = new() { InternStrings = true, SerializeDefaultValues = false, @@ -173,7 +168,7 @@ private interface IJsonRpcMessagePackRetention /// Configures the serialization context for user data with the specified configuration action. /// /// The action to configure the serialization context. - public void SetFormatterContext(Action configure) + public void SetFormatterContext(Action configure) { Requires.NotNull(configure, nameof(configure)); @@ -190,7 +185,7 @@ public void SetFormatterContext(Action configure) public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) { JsonRpcMessage message = this.rpcContext.Serializer.Deserialize(contentBuffer, ShapeProvider_StreamJsonRpc.Default) - ?? throw new NBMP::MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; this.deserializationToStringHelper.Activate(contentBuffer); @@ -223,7 +218,7 @@ public void Serialize(IBufferWriter contentBuffer, JsonRpcMessage message) } } - var writer = new NBMP::MessagePackWriter(contentBuffer); + var writer = new MessagePackWriter(contentBuffer); try { this.rpcContext.Serializer.Serialize(ref writer, message, this.rpcContext.ShapeProvider); @@ -231,13 +226,13 @@ public void Serialize(IBufferWriter contentBuffer, JsonRpcMessage message) } catch (Exception ex) { - throw new NBMP::MessagePackSerializationException(string.Format(CultureInfo.CurrentCulture, Resources.ErrorWritingJsonRpcMessage, ex.GetType().Name, ex.Message), ex); + throw new MessagePackSerializationException(string.Format(CultureInfo.CurrentCulture, Resources.ErrorWritingJsonRpcMessage, ex.GetType().Name, ex.Message), ex); } } /// public object GetJsonText(JsonRpcMessage message) => message is IJsonRpcMessagePackRetention retainedMsgPack - ? NBMP::MessagePackSerializer.ConvertToJson(retainedMsgPack.OriginalMessagePack) + ? MessagePackSerializer.ConvertToJson(retainedMsgPack.OriginalMessagePack) : throw new NotSupportedException(); /// @@ -373,7 +368,7 @@ bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) return (result, argumentTypes!); } - private static ReadOnlySequence GetSliceForNextToken(ref NBMP::MessagePackReader reader, in NBMP::SerializationContext context) + private static ReadOnlySequence GetSliceForNextToken(ref MessagePackReader reader, in SerializationContext context) { SequencePosition startingPosition = reader.Position; reader.Skip(context); @@ -386,12 +381,12 @@ private static ReadOnlySequence GetSliceForNextToken(ref NBMP::MessagePack /// /// The reader to use. /// The decoded string. - private static unsafe string ReadProtocolVersion(ref NBMP::MessagePackReader reader) + private static unsafe string ReadProtocolVersion(ref MessagePackReader reader) { if (!reader.TryReadStringSpan(out ReadOnlySpan valueBytes)) { // TODO: More specific exception type - throw new NBMP::MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); } // Recognize "2.0" since we expect it and can avoid decoding and allocating a new string for it. @@ -412,7 +407,7 @@ private static unsafe string ReadProtocolVersion(ref NBMP::MessagePackReader rea /// /// Writes the JSON-RPC version property name and value in a highly optimized way. /// - private static void WriteProtocolVersionPropertyAndValue(ref NBMP::MessagePackWriter writer, string version) + private static void WriteProtocolVersionPropertyAndValue(ref MessagePackWriter writer, string version) { VersionPropertyName.Write(ref writer); if (!Version2.TryWrite(ref writer, version)) @@ -421,7 +416,7 @@ private static void WriteProtocolVersionPropertyAndValue(ref NBMP::MessagePackWr } } - private static void ReadUnknownProperty(ref NBMP::MessagePackReader reader, in NBMP::SerializationContext context, ref Dictionary>? topLevelProperties, ReadOnlySpan stringKey) + private static void ReadUnknownProperty(ref MessagePackReader reader, in SerializationContext context, ref Dictionary>? topLevelProperties, ReadOnlySpan stringKey) { topLevelProperties ??= new Dictionary>(StringComparer.Ordinal); #if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER @@ -516,7 +511,7 @@ public override string ToString() { Verify.Operation(this.encodedMessage.HasValue, "This object has not been activated. It may have already been recycled."); - return this.jsonString ??= NBMP::MessagePackSerializer.ConvertToJson(this.encodedMessage.Value); + return this.jsonString ??= MessagePackSerializer.ConvertToJson(this.encodedMessage.Value); } /// @@ -537,15 +532,15 @@ internal void Deactivate() } } - private class RequestIdConverter : NBMP::MessagePackConverter + private class RequestIdConverter : MessagePackConverter { internal static readonly RequestIdConverter Instance = new(); - public override RequestId Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override RequestId Read(ref MessagePackReader reader, SerializationContext context) { context.DepthStep(); - if (reader.NextMessagePackType == NBMP.MessagePackType.Integer) + if (reader.NextMessagePackType == MessagePackType.Integer) { return new RequestId(reader.ReadInt64()); } @@ -556,7 +551,7 @@ public override RequestId Read(ref NBMP.MessagePackReader reader, SerializationC } } - public override void Write(ref NBMP.MessagePackWriter writer, in RequestId value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) { context.DepthStep(); @@ -586,13 +581,13 @@ private RawMessagePackConverter() } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override RawMessagePack Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override RawMessagePack Read(ref MessagePackReader reader, SerializationContext context) { return RawMessagePack.ReadRaw(ref reader, copy: false, context); } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override void Write(ref NBMP.MessagePackWriter writer, in RawMessagePack value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in RawMessagePack value, SerializationContext context) { value.WriteRaw(ref writer); } @@ -641,12 +636,12 @@ internal ProgressClientConverter(NerdbankMessagePackFormatter formatter) this.formatter = formatter; } - public override TClass Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override TClass Read(ref MessagePackReader reader, SerializationContext context) { throw new NotSupportedException("This formatter only serializes IProgress instances."); } - public override void Write(ref NBMP.MessagePackWriter writer, in TClass? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { if (value is null) { @@ -679,7 +674,7 @@ internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) [return: MaybeNull] [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override TClass? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override TClass? Read(ref MessagePackReader reader, SerializationContext context) { if (reader.TryReadNil()) { @@ -692,7 +687,7 @@ internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) return (TClass)this.formatter.FormatterProgressTracker.CreateProgress(this.formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); } - public override void Write(ref NBMP.MessagePackWriter writer, in TClass? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { if (value is null) { @@ -756,7 +751,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF /// private static readonly CommonString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); - public override IAsyncEnumerable? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) { if (reader.TryReadNil()) { @@ -771,7 +766,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF { if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) { - throw new NBMP.MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); } if (TokenPropertyName.TryRead(stringKey)) @@ -792,7 +787,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override void Write(ref NBMP.MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) { Serialize_Shared(mainFormatter, ref writer, value, context); } @@ -802,7 +797,7 @@ public override void Write(ref NBMP.MessagePackWriter writer, in IAsyncEnumerabl return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); } - internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref NBMP::MessagePackWriter writer, IAsyncEnumerable? value, NBMP::SerializationContext context) + internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref MessagePackWriter writer, IAsyncEnumerable? value, SerializationContext context) { if (value is null) { @@ -848,13 +843,13 @@ internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter private class GeneratorConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter where TClass : IAsyncEnumerable #pragma warning restore CA1812 { - public override TClass Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override TClass Read(ref MessagePackReader reader, SerializationContext context) { throw new NotSupportedException(); } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override void Write(ref NBMP.MessagePackWriter writer, in TClass? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); } @@ -905,7 +900,7 @@ private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : M where T : class, IDuplexPipe #pragma warning restore CA1812 { - public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override T? Read(ref MessagePackReader reader, SerializationContext context) { if (reader.TryReadNil()) { @@ -915,7 +910,7 @@ private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : M return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); } - public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) { @@ -938,7 +933,7 @@ private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : M where T : PipeReader #pragma warning restore CA1812 { - public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override T? Read(ref MessagePackReader reader, SerializationContext context) { if (reader.TryReadNil()) { @@ -948,7 +943,7 @@ private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : M return (T)formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); } - public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) { @@ -971,7 +966,7 @@ private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : M where T : PipeWriter #pragma warning restore CA1812 { - public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override T? Read(ref MessagePackReader reader, SerializationContext context) { if (reader.TryReadNil()) { @@ -981,7 +976,7 @@ private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : M return (T)formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); } - public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) { @@ -1011,7 +1006,7 @@ public StreamConverter(NerdbankMessagePackFormatter formatter) this.formatter = formatter; } - public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override T? Read(ref MessagePackReader reader, SerializationContext context) { if (reader.TryReadNil()) { @@ -1021,7 +1016,7 @@ public StreamConverter(NerdbankMessagePackFormatter formatter) return (T)this.formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); } - public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { if (this.formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) { @@ -1050,14 +1045,14 @@ private class RpcMarshalableConverter( #pragma warning restore CA1812 { [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override T? Read(ref MessagePackReader reader, SerializationContext context) { MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcContext.Deserialize(ref reader); return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { if (value is null) { @@ -1120,7 +1115,7 @@ private partial class ExceptionConverter(NerdbankMessagePackFormatter formatt where T : Exception #pragma warning restore CA1812 { - public override T? Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override T? Read(ref MessagePackReader reader, SerializationContext context) { Assumes.NotNull(formatter.JsonRpc); if (reader.TryReadNil()) @@ -1147,7 +1142,7 @@ private partial class ExceptionConverter(NerdbankMessagePackFormatter formatt for (int i = 0; i < memberCount; i++) { string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) - ?? throw new NBMP::MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, // so the caller will get a boxed RawMessagePack struct in that case. @@ -1166,7 +1161,7 @@ private partial class ExceptionConverter(NerdbankMessagePackFormatter formatt } } - public override void Write(ref NBMP.MessagePackWriter writer, in T? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { if (value is null) { @@ -1212,7 +1207,7 @@ public override void Write(ref NBMP.MessagePackWriter writer, in T? value, Seria } } - private class JsonRpcMessageConverter : NBMP::MessagePackConverter + private class JsonRpcMessageConverter : MessagePackConverter { private readonly NerdbankMessagePackFormatter formatter; @@ -1221,11 +1216,11 @@ internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) this.formatter = formatter; } - public override JsonRpcMessage? Read(ref NBMP.MessagePackReader reader, NBMP::SerializationContext context) + public override JsonRpcMessage? Read(ref MessagePackReader reader, SerializationContext context) { context.DepthStep(); - NBMP::MessagePackReader readAhead = reader.CreatePeekReader(); + MessagePackReader readAhead = reader.CreatePeekReader(); int propertyCount = readAhead.ReadMapHeader(); for (int i = 0; i < propertyCount; i++) { @@ -1257,7 +1252,7 @@ internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) throw new UnrecognizedJsonRpcMessageException(); } - public override void Write(ref NBMP.MessagePackWriter writer, in JsonRpcMessage? value, NBMP::SerializationContext context) + public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? value, SerializationContext context) { Requires.NotNull(value!, nameof(value)); @@ -1282,13 +1277,13 @@ public override void Write(ref NBMP.MessagePackWriter writer, in JsonRpcMessage? } } - public override JsonObject? GetJsonSchema(NBMP::JsonSchemaContext context, ITypeShape typeShape) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { return base.GetJsonSchema(context, typeShape); } } - private class JsonRpcRequestConverter : NBMP::MessagePackConverter + private class JsonRpcRequestConverter : MessagePackConverter { private readonly NerdbankMessagePackFormatter formatter; @@ -1297,7 +1292,7 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) this.formatter = formatter; } - public override Protocol.JsonRpcRequest? Read(ref NBMP::MessagePackReader reader, NBMP::SerializationContext context) + public override Protocol.JsonRpcRequest? Read(ref MessagePackReader reader, SerializationContext context) { var result = new JsonRpcRequest(this.formatter) { @@ -1335,7 +1330,7 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) // Parse out the arguments into a dictionary or array, but don't deserialize them because we don't yet know what types to deserialize them to. switch (reader.NextMessagePackType) { - case NBMP::MessagePackType.Array: + case MessagePackType.Array: var positionalArgs = new ReadOnlySequence[reader.ReadArrayHeader()]; for (int i = 0; i < positionalArgs.Length; i++) { @@ -1344,7 +1339,7 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) result.MsgPackPositionalArguments = positionalArgs; break; - case NBMP::MessagePackType.Map: + case MessagePackType.Map: int namedArgsCount = reader.ReadMapHeader(); var namedArgs = new Dictionary>(namedArgsCount); for (int i = 0; i < namedArgsCount; i++) @@ -1352,7 +1347,7 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) string? propertyName = context.GetConverter(null).Read(ref reader, context); if (propertyName is null) { - throw new NBMP::MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); } namedArgs.Add(propertyName, GetSliceForNextToken(ref reader, context)); @@ -1360,12 +1355,12 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) result.MsgPackNamedArguments = namedArgs; break; - case NBMP::MessagePackType.Nil: + case MessagePackType.Nil: result.MsgPackPositionalArguments = Array.Empty>(); reader.ReadNil(); break; - case NBMP::MessagePackType type: - throw new NBMP::MessagePackSerializationException("Expected a map or array of arguments but got " + type); + case MessagePackType type: + throw new MessagePackSerializationException("Expected a map or array of arguments but got " + type); } result.MsgPackArguments = reader.Sequence.Slice(paramsTokenStartPosition, reader.Position); @@ -1395,7 +1390,7 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) return result; } - public override void Write(ref NBMP.MessagePackWriter writer, in Protocol.JsonRpcRequest? value, NBMP::SerializationContext context) + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequest? value, SerializationContext context) { Requires.NotNull(value!, nameof(value)); @@ -1504,12 +1499,12 @@ public override void Write(ref NBMP.MessagePackWriter writer, in Protocol.JsonRp topLevelPropertyBag?.WriteProperties(ref writer); } - public override JsonObject? GetJsonSchema(NBMP::JsonSchemaContext context, ITypeShape typeShape) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { return CreateUndocumentedSchema(typeof(JsonRpcRequestConverter)); } - private static void WriteTraceState(ref NBMP::MessagePackWriter writer, string traceState) + private static void WriteTraceState(ref MessagePackWriter writer, string traceState) { ReadOnlySpan traceStateChars = traceState.AsSpan(); @@ -1536,7 +1531,7 @@ private static void WriteTraceState(ref NBMP::MessagePackWriter writer, string t // Write out the last one. WritePair(ref writer, traceStateChars); - static void WritePair(ref NBMP::MessagePackWriter writer, ReadOnlySpan pair) + static void WritePair(ref MessagePackWriter writer, ReadOnlySpan pair) { int equalsIndex = pair.IndexOf('='); ReadOnlySpan key = pair.Slice(0, equalsIndex); @@ -1546,7 +1541,7 @@ static void WritePair(ref NBMP::MessagePackWriter writer, ReadOnlySpan pai } } - private static unsafe string ReadTraceState(ref NBMP::MessagePackReader reader, NBMP::SerializationContext context) + private static unsafe string ReadTraceState(ref MessagePackReader reader, SerializationContext context) { int elements = reader.ReadArrayHeader(); if (elements % 2 != 0) @@ -1574,7 +1569,7 @@ private static unsafe string ReadTraceState(ref NBMP::MessagePackReader reader, } } - private partial class JsonRpcResultConverter : NBMP::MessagePackConverter + private partial class JsonRpcResultConverter : MessagePackConverter { private readonly NerdbankMessagePackFormatter formatter; @@ -1583,7 +1578,7 @@ internal JsonRpcResultConverter(NerdbankMessagePackFormatter formatter) this.formatter = formatter; } - public override Protocol.JsonRpcResult Read(ref NBMP.MessagePackReader reader, NBMP::SerializationContext context) + public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, SerializationContext context) { var result = new JsonRpcResult(this.formatter, this.formatter.userDataContext) { @@ -1628,7 +1623,7 @@ public override Protocol.JsonRpcResult Read(ref NBMP.MessagePackReader reader, N return result; } - public override void Write(ref NBMP.MessagePackWriter writer, in Protocol.JsonRpcResult? value, NBMP::SerializationContext context) + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResult? value, SerializationContext context) { Requires.NotNull(value!, nameof(value)); @@ -1666,7 +1661,7 @@ public override void Write(ref NBMP.MessagePackWriter writer, in Protocol.JsonRp (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); } - public override JsonObject? GetJsonSchema(NBMP::JsonSchemaContext context, ITypeShape typeShape) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { return CreateUndocumentedSchema(typeof(JsonRpcResultConverter)); } @@ -1681,7 +1676,7 @@ internal JsonRpcErrorConverter(NerdbankMessagePackFormatter formatter) this.formatter = formatter; } - public override Protocol.JsonRpcError Read(ref NBMP::MessagePackReader reader, SerializationContext context) + public override Protocol.JsonRpcError Read(ref MessagePackReader reader, SerializationContext context) { var error = new JsonRpcError(this.formatter.rpcContext) { @@ -1727,7 +1722,7 @@ public override Protocol.JsonRpcError Read(ref NBMP::MessagePackReader reader, S return error; } - public override void Write(ref NBMP::MessagePackWriter writer, in Protocol.JsonRpcError? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError? value, SerializationContext context) { Requires.NotNull(value!, nameof(value)); @@ -1748,7 +1743,7 @@ public override void Write(ref NBMP::MessagePackWriter writer, in Protocol.JsonR topLevelPropertyBag?.WriteProperties(ref writer); } - public override JsonObject? GetJsonSchema(NBMP::JsonSchemaContext context, ITypeShape typeShape) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { return CreateUndocumentedSchema(typeof(JsonRpcErrorConverter)); } @@ -1767,7 +1762,7 @@ internal JsonRpcErrorDetailConverter(NerdbankMessagePackFormatter formatter) this.formatter = formatter; } - public override Protocol.JsonRpcError.ErrorDetail Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader reader, SerializationContext context) { var result = new JsonRpcError.ErrorDetail(this.formatter.userDataContext); context.DepthStep(); @@ -1802,7 +1797,7 @@ public override Protocol.JsonRpcError.ErrorDetail Read(ref NBMP.MessagePackReade } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override void Write(ref NBMP.MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) { Requires.NotNull(value!, nameof(value)); @@ -1841,14 +1836,14 @@ private EventArgsConverter() } /// - public override void Write(ref NBMP.MessagePackWriter writer, in EventArgs? value, SerializationContext context) + public override void Write(ref MessagePackWriter writer, in EventArgs? value, SerializationContext context) { Requires.NotNull(value!, nameof(value)); writer.WriteMapHeader(0); } /// - public override EventArgs Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public override EventArgs Read(ref MessagePackReader reader, SerializationContext context) { reader.Skip(context); return EventArgs.Empty; @@ -1862,7 +1857,7 @@ public override EventArgs Read(ref NBMP.MessagePackReader reader, SerializationC private class TraceParentConverter : MessagePackConverter { - public unsafe override TraceParent Read(ref NBMP.MessagePackReader reader, SerializationContext context) + public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) { if (reader.ReadArrayHeader() != 2) { @@ -1892,7 +1887,7 @@ public unsafe override TraceParent Read(ref NBMP.MessagePackReader reader, Seria return result; } - public unsafe override void Write(ref NBMP.MessagePackWriter writer, in TraceParent value, SerializationContext context) + public unsafe override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) { if (value.Version != 0) { @@ -1959,7 +1954,7 @@ internal TopLevelPropertyBag(FormatterContext formatterContext) /// Writes the properties tracked by this collection to a messagepack writer. /// /// The writer to use. - internal void WriteProperties(ref NBMP::MessagePackWriter writer) + internal void WriteProperties(ref MessagePackWriter writer) { if (this.inboundUnknownProperties is not null) { @@ -1998,7 +1993,7 @@ protected internal override bool TryGetTopLevelProperty(string name, [MaybeNu if (this.inboundUnknownProperties.TryGetValue(name, out ReadOnlySequence serializedValue) is true) { - var reader = new NBMP::MessagePackReader(serializedValue); + var reader = new MessagePackReader(serializedValue); value = this.formatterContext.Serializer.Deserialize(ref reader, this.formatterContext.ShapeProvider); return true; } @@ -2052,7 +2047,7 @@ public override ArgumentMatchResult TryGetTypedArguments(ReadOnlySpan() return this.MsgPackResult.IsEmpty ? (T)this.Result! : this.serializerOptions.Serializer.Deserialize(this.MsgPackResult, this.serializerOptions.ShapeProvider) - ?? throw new NBMP::MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); } protected internal override void SetExpectedResultType(Type resultType) { Verify.Operation(!this.MsgPackResult.IsEmpty, "Result is no longer available or has already been deserialized."); - var reader = new NBMP::MessagePackReader(this.MsgPackResult); + var reader = new MessagePackReader(this.MsgPackResult); try { using (this.formatter.TrackDeserialization(this)) @@ -2177,7 +2172,7 @@ protected internal override void SetExpectedResultType(Type resultType) this.MsgPackResult = default; } - catch (NBMP::MessagePackSerializationException ex) + catch (MessagePackSerializationException ex) { // This was a best effort anyway. We'll throw again later at a more convenient time for JsonRpc. this.resultDeserializationException = ex; @@ -2240,15 +2235,15 @@ internal ErrorDetail(FormatterContext serializerOptions) return this.Data; } - var reader = new NBMP::MessagePackReader(this.MsgPackData); + var reader = new MessagePackReader(this.MsgPackData); try { return this.serializerOptions.Serializer.DeserializeObject( ref reader, this.serializerOptions.ShapeProvider.Resolve(dataType)) - ?? throw new NBMP::MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); } - catch (NBMP::MessagePackSerializationException) + catch (MessagePackSerializationException) { // Deserialization failed. Try returning array/dictionary based primitive objects. try @@ -2257,7 +2252,7 @@ internal ErrorDetail(FormatterContext serializerOptions) // TODO: Which Shape Provider to use? return this.serializerOptions.Serializer.Deserialize(this.MsgPackData, this.serializerOptions.ShapeProvider); } - catch (NBMP::MessagePackSerializationException) + catch (MessagePackSerializationException) { return null; } @@ -2277,23 +2272,23 @@ protected internal override void SetExpectedDataType(Type dataType) } } - private class FormatterContext(NBMP::MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) + internal class FormatterContext(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) { - public NBMP::MessagePackSerializer Serializer => serializer; + public MessagePackSerializer Serializer => serializer; public ITypeShapeProvider ShapeProvider => shapeProvider; - public T? Deserialize(ref NBMP.MessagePackReader reader, CancellationToken cancellationToken = default) + public T? Deserialize(ref MessagePackReader reader, CancellationToken cancellationToken = default) { return serializer.Deserialize(ref reader, shapeProvider, cancellationToken); } - public void Serialize(ref NBMP.MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) + public void Serialize(ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) { serializer.Serialize(ref writer, value, shapeProvider, cancellationToken); } - internal void SerializeObject(ref NBMP.MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) + internal void SerializeObject(ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) { serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(objectType), cancellationToken); } From 345a39c5b05f9b88bda85389c7443ca3fc1146a0 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 21 Dec 2024 18:33:37 +0000 Subject: [PATCH 08/25] Downgrade package versions for compatibility --- Directory.Packages.props | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index f04dabd42..6a87cb5f6 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -15,7 +15,7 @@ - + @@ -25,12 +25,12 @@ - + - + From ad436f73b8e801cc1277564087fa11f0310e9bc6 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 21 Dec 2024 19:25:35 +0000 Subject: [PATCH 09/25] User RawMessagePack from Nerdbank.MessagePack --- ...nkMessagePackFormatter.FormatterContext.cs | 55 ++++++++++ ...bankMessagePackFormatter.RawMessagePack.cs | 88 --------------- .../NerdbankMessagePackFormatter.cs | 101 +++++------------- 3 files changed, 82 insertions(+), 162 deletions(-) create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs delete mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.RawMessagePack.cs diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs new file mode 100644 index 000000000..3d560bbc5 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; +using PolyType; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +public sealed partial class NerdbankMessagePackFormatter +{ + internal class FormatterContext(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) + { + public MessagePackSerializer Serializer => serializer; + + public ITypeShapeProvider ShapeProvider => shapeProvider; + + public T? Deserialize(ref MessagePackReader reader, CancellationToken cancellationToken = default) + { + return serializer.Deserialize(ref reader, shapeProvider, cancellationToken); + } + + public T Deserialize(in RawMessagePack pack, CancellationToken cancellationToken = default) + { + // TODO: Improve the exception + return serializer.Deserialize(pack, shapeProvider, cancellationToken) + ?? throw new InvalidOperationException("Deserialization failed."); + } + + public object? DeserializeObject(in RawMessagePack pack, Type objectType, CancellationToken cancellationToken = default) + { + MessagePackReader reader = new(pack); + return serializer.DeserializeObject( + ref reader, + shapeProvider.Resolve(objectType), + cancellationToken); + } + + public void Serialize(ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) + { + serializer.Serialize(ref writer, value, shapeProvider, cancellationToken); + } + + internal void SerializeObject(ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) + { + serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(objectType), cancellationToken); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RawMessagePack.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RawMessagePack.cs deleted file mode 100644 index 895980392..000000000 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RawMessagePack.cs +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System.Buffers; -using Nerdbank.MessagePack; -using PolyType.Abstractions; - -namespace StreamJsonRpc; - -/// -/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). -/// -/// -/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. -/// -public partial class NerdbankMessagePackFormatter -{ - private struct RawMessagePack - { - private readonly ReadOnlySequence rawSequence; - - private readonly ReadOnlyMemory rawMemory; - - private RawMessagePack(ReadOnlySequence raw) - { - this.rawSequence = raw; - this.rawMemory = default; - } - - private RawMessagePack(ReadOnlyMemory raw) - { - this.rawSequence = default; - this.rawMemory = raw; - } - - internal readonly bool IsDefault => this.rawMemory.IsEmpty && this.rawSequence.IsEmpty; - - public override readonly string ToString() => ""; - - /// - /// Reads one raw messagepack token. - /// - /// The reader to use. - /// if the token must outlive the lifetime of the reader's underlying buffer; otherwise. - /// The serialization context to use. - /// The raw messagepack slice. - internal static RawMessagePack ReadRaw(ref MessagePackReader reader, bool copy, Nerdbank.MessagePack.SerializationContext context) - { - SequencePosition initialPosition = reader.Position; - reader.Skip(context); - ReadOnlySequence slice = reader.Sequence.Slice(initialPosition, reader.Position); - return copy ? new RawMessagePack(slice.ToArray()) : new RawMessagePack(slice); - } - - internal readonly void WriteRaw(ref MessagePackWriter writer) - { - if (this.rawSequence.IsEmpty) - { - writer.WriteRaw(this.rawMemory.Span); - } - else - { - writer.WriteRaw(this.rawSequence); - } - } - - internal readonly object? Deserialize(Type type, FormatterContext options) - { - MessagePackReader reader = this.rawSequence.IsEmpty - ? new MessagePackReader(this.rawMemory) - : new MessagePackReader(this.rawSequence); - - return options.Serializer.DeserializeObject( - ref reader, - options.ShapeProvider.Resolve(type)); - } - - internal readonly T Deserialize(FormatterContext options) - { - MessagePackReader reader = this.rawSequence.IsEmpty - ? new MessagePackReader(this.rawMemory) - : new MessagePackReader(this.rawSequence); - - return options.Serializer.Deserialize(ref reader, options.ShapeProvider) - ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); - } - } -} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 78566beb5..41c336a17 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -436,62 +436,63 @@ private void MassageUserDataContext(FormatterContext userDataContext) { // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. userDataContext.Serializer.RegisterConverter(RequestIdConverter.Instance); - userDataContext.Serializer.RegisterConverter(RawMessagePackConverter.Instance); userDataContext.Serializer.RegisterConverter(EventArgsConverter.Instance); } private class MessagePackFormatterConverter : IFormatterConverter { - private readonly FormatterContext options; + private readonly FormatterContext context; - internal MessagePackFormatterConverter(FormatterContext options) + internal MessagePackFormatterConverter(FormatterContext formatterContext) { - this.options = options; + this.context = formatterContext; } #pragma warning disable CS8766 // This method may in fact return null, and no one cares. public object? Convert(object value, Type type) #pragma warning restore CS8766 - => ((RawMessagePack)value).Deserialize(type, this.options); + { + return this.context.DeserializeObject((RawMessagePack)value, type); + } public object Convert(object value, TypeCode typeCode) { return typeCode switch { - TypeCode.Object => ((RawMessagePack)value).Deserialize(this.options), + TypeCode.Object => this.context.Deserialize((RawMessagePack)value), _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), }; } - public bool ToBoolean(object value) => ((RawMessagePack)value).Deserialize(this.options); + public bool ToBoolean(object value) => this.context.Deserialize((RawMessagePack)value); - public byte ToByte(object value) => ((RawMessagePack)value).Deserialize(this.options); + public byte ToByte(object value) => this.context.Deserialize((RawMessagePack)value); - public char ToChar(object value) => ((RawMessagePack)value).Deserialize(this.options); + public char ToChar(object value) => this.context.Deserialize((RawMessagePack)value); - public DateTime ToDateTime(object value) => ((RawMessagePack)value).Deserialize(this.options); + public DateTime ToDateTime(object value) => this.context.Deserialize((RawMessagePack)value); - public decimal ToDecimal(object value) => ((RawMessagePack)value).Deserialize(this.options); + public decimal ToDecimal(object value) => this.context.Deserialize((RawMessagePack)value); - public double ToDouble(object value) => ((RawMessagePack)value).Deserialize(this.options); + public double ToDouble(object value) => this.context.Deserialize((RawMessagePack)value); - public short ToInt16(object value) => ((RawMessagePack)value).Deserialize(this.options); + public short ToInt16(object value) => this.context.Deserialize((RawMessagePack)value); - public int ToInt32(object value) => ((RawMessagePack)value).Deserialize(this.options); + public int ToInt32(object value) => this.context.Deserialize((RawMessagePack)value); - public long ToInt64(object value) => ((RawMessagePack)value).Deserialize(this.options); + public long ToInt64(object value) => this.context.Deserialize((RawMessagePack)value); - public sbyte ToSByte(object value) => ((RawMessagePack)value).Deserialize(this.options); + public sbyte ToSByte(object value) => this.context.Deserialize((RawMessagePack)value); - public float ToSingle(object value) => ((RawMessagePack)value).Deserialize(this.options); + public float ToSingle(object value) => this.context.Deserialize((RawMessagePack)value); - public string? ToString(object value) => value is null ? null : ((RawMessagePack)value).Deserialize(this.options); + public string? ToString(object value) => value is null ? null : this.context.Deserialize((RawMessagePack)value); - public ushort ToUInt16(object value) => ((RawMessagePack)value).Deserialize(this.options); + public ushort ToUInt16(object value) => this.context.Deserialize((RawMessagePack)value); - public uint ToUInt32(object value) => ((RawMessagePack)value).Deserialize(this.options); + public uint ToUInt32(object value) => this.context.Deserialize((RawMessagePack)value); - public ulong ToUInt64(object value) => ((RawMessagePack)value).Deserialize(this.options); + public ulong ToUInt64(object value) => this.context.Deserialize((RawMessagePack)value); } /// @@ -572,32 +573,6 @@ public override void Write(ref MessagePackWriter writer, in RequestId value, Ser """)?.AsObject(); } - private class RawMessagePackConverter : MessagePackConverter - { - internal static readonly RawMessagePackConverter Instance = new(); - - private RawMessagePackConverter() - { - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override RawMessagePack Read(ref MessagePackReader reader, SerializationContext context) - { - return RawMessagePack.ReadRaw(ref reader, copy: false, context); - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] - public override void Write(ref MessagePackWriter writer, in RawMessagePack value, SerializationContext context) - { - value.WriteRaw(ref writer); - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(RawMessagePackConverter)); - } - } - private class ProgressConverterResolver { private readonly NerdbankMessagePackFormatter mainFormatter; @@ -682,7 +657,7 @@ internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) } Assumes.NotNull(this.formatter.JsonRpc); - RawMessagePack token = RawMessagePack.ReadRaw(ref reader, copy: true, context); + RawMessagePack token = reader.ReadRaw(context); bool clientRequiresNamedArgs = this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; return (TClass)this.formatter.FormatterProgressTracker.CreateProgress(this.formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); } @@ -759,7 +734,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF } context.DepthStep(); - RawMessagePack token = default; + RawMessagePack? token = default; IReadOnlyList? initialElements = null; int propertyCount = reader.ReadMapHeader(); for (int i = 0; i < propertyCount; i++) @@ -771,7 +746,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF if (TokenPropertyName.TryRead(stringKey)) { - token = RawMessagePack.ReadRaw(ref reader, copy: true, context); + token = reader.ReadRaw(context); } else if (ValuesPropertyName.TryRead(stringKey)) { @@ -783,7 +758,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF } } - return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.IsDefault ? null : token, initialElements); + return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? token : null, initialElements); } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] @@ -1148,7 +1123,7 @@ private partial class ExceptionConverter(NerdbankMessagePackFormatter formatt // so the caller will get a boxed RawMessagePack struct in that case. // Although we can't do much about *that* in general, we can at least ensure that null values // are represented as null instead of this boxed struct. - var value = reader.TryReadNil() ? null : (object)RawMessagePack.ReadRaw(ref reader, false, context); + var value = reader.TryReadNil() ? null : (object)reader.ReadRaw(context); info.AddSafeValue(name, value); } @@ -2271,26 +2246,4 @@ protected internal override void SetExpectedDataType(Type dataType) } } } - - internal class FormatterContext(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) - { - public MessagePackSerializer Serializer => serializer; - - public ITypeShapeProvider ShapeProvider => shapeProvider; - - public T? Deserialize(ref MessagePackReader reader, CancellationToken cancellationToken = default) - { - return serializer.Deserialize(ref reader, shapeProvider, cancellationToken); - } - - public void Serialize(ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) - { - serializer.Serialize(ref writer, value, shapeProvider, cancellationToken); - } - - internal void SerializeObject(ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) - { - serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(objectType), cancellationToken); - } - } } From 1a795c7368a2844cde90484376eca9b6699ed4bd Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 21 Dec 2024 20:49:17 +0000 Subject: [PATCH 10/25] Remove namespace aliases --- .../NerdbankMessagePackFormatter.cs | 11 ++---- src/StreamJsonRpc/Protocol/JsonRpcError.cs | 16 ++++----- src/StreamJsonRpc/Protocol/JsonRpcMessage.cs | 12 +++---- src/StreamJsonRpc/Protocol/JsonRpcRequest.cs | 34 +++++++++---------- src/StreamJsonRpc/Protocol/JsonRpcResult.cs | 12 +++---- 5 files changed, 39 insertions(+), 46 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 41c336a17..204196138 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -157,13 +157,6 @@ private interface IJsonRpcMessagePackRetention ReadOnlySequence OriginalMessagePack { get; } } - /// - public new MultiplexingStream? MultiplexingStream - { - get => base.MultiplexingStream; - set => base.MultiplexingStream = value; - } - /// /// Configures the serialization context for user data with the specified configuration action. /// @@ -328,7 +321,7 @@ bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) if (TryGetSerializationInfo(property, out string key)) { result[key] = property.GetValue(paramsObject); - if (mutableArgumentTypes is object) + if (mutableArgumentTypes is not null) { mutableArgumentTypes[key] = property.PropertyType; } @@ -341,7 +334,7 @@ bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) if (TryGetSerializationInfo(field, out string key)) { result[key] = field.GetValue(paramsObject); - if (mutableArgumentTypes is object) + if (mutableArgumentTypes is not null) { mutableArgumentTypes[key] = field.FieldType; } diff --git a/src/StreamJsonRpc/Protocol/JsonRpcError.cs b/src/StreamJsonRpc/Protocol/JsonRpcError.cs index d8500c09a..143cd82cb 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcError.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcError.cs @@ -3,9 +3,9 @@ using System.Diagnostics; using System.Runtime.Serialization; +using PolyType; using StreamJsonRpc.Reflection; using JsonNET = Newtonsoft.Json.Linq; -using PT = PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -14,7 +14,7 @@ namespace StreamJsonRpc.Protocol; /// Describes the error resulting from a that failed on the server. /// [DataContract] -[PT.GenerateShape] +[GenerateShape] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + "}")] public partial class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId { @@ -23,7 +23,7 @@ public partial class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId /// [DataMember(Name = "error", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("error"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] - [PT.PropertyShape(Name = "error", Order = 2)] + [PropertyShape(Name = "error", Order = 2)] public ErrorDetail? Error { get; set; } /// @@ -33,7 +33,7 @@ public partial class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -45,7 +45,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] - [PT.PropertyShape(Name = "id", Order = 1)] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -82,7 +82,7 @@ public class ErrorDetail /// [DataMember(Name = "code", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("code"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] - [PT.PropertyShape(Name = "code", Order = 0)] + [PropertyShape(Name = "code", Order = 0)] public JsonRpcErrorCode Code { get; set; } /// @@ -93,7 +93,7 @@ public class ErrorDetail /// [DataMember(Name = "message", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("message"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] - [PT.PropertyShape(Name = "message", Order = 1)] + [PropertyShape(Name = "message", Order = 1)] public string? Message { get; set; } /// @@ -102,7 +102,7 @@ public class ErrorDetail [DataMember(Name = "data", Order = 2, IsRequired = false)] [Newtonsoft.Json.JsonProperty(DefaultValueHandling = Newtonsoft.Json.DefaultValueHandling.Ignore)] [STJ.JsonPropertyName("data"), STJ.JsonPropertyOrder(2)] - [PT.PropertyShape(Name = "data", Order = 2)] + [PropertyShape(Name = "data", Order = 2)] public object? Data { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs index e717a2b2a..1a9a6edc9 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs @@ -3,8 +3,8 @@ using System.Diagnostics.CodeAnalysis; using System.Runtime.Serialization; -using NBMP = Nerdbank.MessagePack; -using PT = PolyType; +using Nerdbank.MessagePack; +using PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -17,9 +17,9 @@ namespace StreamJsonRpc.Protocol; [KnownType(typeof(JsonRpcResult))] [KnownType(typeof(JsonRpcError))] #pragma warning disable CS0618 //'KnownSubTypeAttribute.KnownSubTypeAttribute(Type)' is obsolete: 'Use the generic version of this attribute instead.' -[NBMP::KnownSubType(typeof(JsonRpcRequest))] -[NBMP::KnownSubType(typeof(JsonRpcResult))] -[NBMP::KnownSubType(typeof(JsonRpcError))] +[KnownSubType(typeof(JsonRpcRequest))] +[KnownSubType(typeof(JsonRpcResult))] +[KnownSubType(typeof(JsonRpcError))] #pragma warning restore CS0618 public abstract class JsonRpcMessage { @@ -29,7 +29,7 @@ public abstract class JsonRpcMessage /// Defaults to "2.0". [DataMember(Name = "jsonrpc", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("jsonrpc"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] - [PT.PropertyShape(Name = "jsonrpc", Order = 0)] + [PropertyShape(Name = "jsonrpc", Order = 0)] public string Version { get; set; } = "2.0"; /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs index 4f3fca685..33f5a96cd 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs @@ -4,8 +4,8 @@ using System.Diagnostics; using System.Reflection; using System.Runtime.Serialization; +using PolyType; using JsonNET = Newtonsoft.Json.Linq; -using PT = PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -14,7 +14,7 @@ namespace StreamJsonRpc.Protocol; /// Describes a method to be invoked on the server. /// [DataContract] -[PT.GenerateShape] +[GenerateShape] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] public partial class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId { @@ -49,7 +49,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "method", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("method"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] - [PT.PropertyShape(Name = "method", Order = 2)] + [PropertyShape(Name = "method", Order = 2)] public string? Method { get; set; } /// @@ -64,7 +64,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "params", Order = 3, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("params"), STJ.JsonPropertyOrder(3), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] - [PT.PropertyShape(Name = "params", Order = 1)] + [PropertyShape(Name = "params", Order = 3)] public object? Arguments { get; set; } /// @@ -74,7 +74,7 @@ public enum ArgumentMatchResult [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -86,7 +86,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingDefault)] - [PT.PropertyShape(Name = "id", Order = 1)] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -94,7 +94,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public bool IsResponseExpected => !this.RequestId.IsEmpty; /// @@ -102,7 +102,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public bool IsNotification => this.RequestId.IsEmpty; /// @@ -110,7 +110,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public virtual int ArgumentCount => this.NamedArguments?.Count ?? this.ArgumentsList?.Count ?? 0; /// @@ -118,7 +118,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArguments { get => this.Arguments as IReadOnlyDictionary; @@ -137,7 +137,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArgumentDeclaredTypes { get; set; } /// @@ -145,7 +145,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] [Obsolete("Use " + nameof(ArgumentsList) + " instead.")] public object?[]? ArgumentsArray { @@ -158,7 +158,7 @@ public object?[]? ArgumentsArray /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentsList { get => this.Arguments as IReadOnlyList; @@ -179,7 +179,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentListDeclaredTypes { get; set; } /// @@ -187,7 +187,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public virtual IEnumerable? ArgumentNames => this.NamedArguments?.Keys; /// @@ -195,7 +195,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "traceparent", EmitDefaultValue = false)] [STJ.JsonPropertyName("traceparent"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] - [PT.PropertyShape(Name = "traceparent")] + [PropertyShape(Name = "traceparent")] public string? TraceParent { get; set; } /// @@ -203,7 +203,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "tracestate", EmitDefaultValue = false)] [STJ.JsonPropertyName("tracestate"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] - [PT.PropertyShape(Name = "tracestate")] + [PropertyShape(Name = "tracestate")] public string? TraceState { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs index e81ba931a..0019e660e 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs @@ -3,8 +3,8 @@ using System.Diagnostics; using System.Runtime.Serialization; +using PolyType; using JsonNET = Newtonsoft.Json.Linq; -using PT = PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -14,7 +14,7 @@ namespace StreamJsonRpc.Protocol; /// [DataContract] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -[PT.GenerateShape] +[GenerateShape] public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId { /// @@ -22,7 +22,7 @@ public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId /// [DataMember(Name = "result", Order = 2, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("result"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] - [PT.PropertyShape(Name = "result", Order = 2)] + [PropertyShape(Name = "result", Order = 2)] public object? Result { get; set; } /// @@ -33,7 +33,7 @@ public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId /// [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public Type? ResultDeclaredType { get; set; } /// @@ -43,7 +43,7 @@ public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] - [PT.PropertyShape(Ignore = true)] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -55,7 +55,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] - [PT.PropertyShape(Name = "id", Order = 1)] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// From 0e6b4aa9c8ffee13394a5abe720a2b21c9d5228e Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 21 Dec 2024 22:59:24 +0000 Subject: [PATCH 11/25] Add unit tests --- ...ePackFormatter.FormatterContextBuilder.cs} | 0 test/StreamJsonRpc.Tests/AssemblyLoadTests.cs | 26 + ...AsyncEnumerableNerdbankMessagePackTests.cs | 16 + ...DisposableProxyNerdbankMessagePackTests.cs | 16 + ...xPipeMarshalingNerdbankMessagePackTests.cs | 16 + .../JsonRpcNerdbankMessagePackLengthTests.cs | 543 ++++++++++++++++++ ...arshalableProxyNerdbankMessagePackTests.cs | 16 + .../NerdbankMessagePackFormatterTests.cs | 453 +++++++++++++++ ...erverMarshalingNerdbankMessagePackTests.cs | 12 + .../StreamJsonRpc.Tests.csproj | 9 + ...getObjectEventsNerdbankMessagePackTests.cs | 16 + ...tMessageHandlerNerdbankMessagePackTests.cs | 7 + 12 files changed, 1130 insertions(+) rename src/StreamJsonRpc/{NerdbankMessagePackFormatter.ISerializationContextBuilder.cs => NerdbankMessagePackFormatter.FormatterContextBuilder.cs} (100%) create mode 100644 test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs create mode 100644 test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs create mode 100644 test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs create mode 100644 test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs create mode 100644 test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs create mode 100644 test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs create mode 100644 test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs create mode 100644 test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs create mode 100644 test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs similarity index 100% rename from src/StreamJsonRpc/NerdbankMessagePackFormatter.ISerializationContextBuilder.cs rename to src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs diff --git a/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs b/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs index ebf80d0d1..441ab0215 100644 --- a/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs +++ b/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs @@ -63,6 +63,27 @@ public void MessagePackDoesNotLoadNewtonsoftJsonUnnecessarily() } } + [Fact] + public void NerdbankMessagePackDoesNotLoadNewtonsoftJsonUnnecessarily() + { + AppDomain testDomain = CreateTestAppDomain(); + try + { + var driver = (AppDomainTestDriver)testDomain.CreateInstanceAndUnwrap(typeof(AppDomainTestDriver).Assembly.FullName, typeof(AppDomainTestDriver).FullName); + + this.PrintLoadedAssemblies(driver); + + driver.CreateNerdbankMessagePackConnection(); + + this.PrintLoadedAssemblies(driver); + driver.ThrowIfAssembliesLoaded("Newtonsoft.Json"); + } + finally + { + AppDomain.Unload(testDomain); + } + } + [Fact] public void MockFormatterDoesNotLoadJsonOrMessagePackUnnecessarily() { @@ -142,6 +163,11 @@ internal void CreateMessagePackConnection() var jsonRpc = new JsonRpc(new LengthHeaderMessageHandler(FullDuplexStream.CreatePipePair().Item1, new MessagePackFormatter())); } + internal void CreateNerdbankMessagePackConnection() + { + var jsonRpc = new JsonRpc(new LengthHeaderMessageHandler(FullDuplexStream.CreatePipePair().Item1, new NerdbankMessagePackFormatter())); + } + #pragma warning restore CA1822 // Mark members as static private class MockFormatter : IJsonRpcMessageFormatter diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs new file mode 100644 index 000000000..d4c5ee933 --- /dev/null +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +public class AsyncEnumerableNerdbankMessagePackTests : AsyncEnumerableTests +{ + public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + this.serverMessageFormatter = new NerdbankMessagePackFormatter(); + this.clientMessageFormatter = new NerdbankMessagePackFormatter(); + } +} diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs new file mode 100644 index 000000000..c4a381e65 --- /dev/null +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; + +public class DisposableProxyNerdbankMessagePackTests : DisposableProxyTests +{ + public DisposableProxyNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); +} diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs new file mode 100644 index 000000000..924929ae3 --- /dev/null +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +public class DuplexPipeMarshalingNerdbankMessagePackTests : DuplexPipeMarshalingTests +{ + public DuplexPipeMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + this.serverMessageFormatter = new NerdbankMessagePackFormatter { MultiplexingStream = this.serverMx }; + this.clientMessageFormatter = new NerdbankMessagePackFormatter { MultiplexingStream = this.clientMx }; + } +} diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs new file mode 100644 index 000000000..277ee3287 --- /dev/null +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -0,0 +1,543 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.CompilerServices; +using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; +using PolyType; +using PolyType.SourceGenerator; + +public partial class JsonRpcNerdbankMessagePackLengthTests : JsonRpcTests +{ + public JsonRpcNerdbankMessagePackLengthTests(ITestOutputHelper logger) + : base(logger) + { + } + + internal interface IMessagePackServer + { + Task ReturnUnionTypeAsync(CancellationToken cancellationToken); + + Task AcceptUnionTypeAndReturnStringAsync(UnionBaseClass value, CancellationToken cancellationToken); + + Task AcceptUnionTypeAsync(UnionBaseClass value, CancellationToken cancellationToken); + + Task ProgressUnionType(IProgress progress, CancellationToken cancellationToken); + + IAsyncEnumerable GetAsyncEnumerableOfUnionType(CancellationToken cancellationToken); + + Task IsExtensionArgNonNull(CustomExtensionType extensionValue); + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + [Fact] + public override async Task CanPassAndCallPrivateMethodsObjects() + { + var result = await this.clientRpc.InvokeAsync(nameof(Server.MethodThatAcceptsFoo), new Foo { Bar = "bar", Bazz = 1000 }); + Assert.NotNull(result); + Assert.Equal("bar!", result.Bar); + Assert.Equal(1001, result.Bazz); + } + + [Fact] + public async Task ExceptionControllingErrorData() + { + var exception = await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.ThrowLocalRpcException))).WithCancellation(this.TimeoutToken); + + IDictionary? data = (IDictionary?)exception.ErrorData; + Assert.NotNull(data); + object myCustomData = data["myCustomData"]; + string actual = (string)myCustomData; + Assert.Equal("hi", actual); + } + + [Fact] + public override async Task CanPassExceptionFromServer_ErrorData() + { + RemoteInvocationException exception = await Assert.ThrowsAnyAsync(() => this.clientRpc.InvokeAsync(nameof(Server.MethodThatThrowsUnauthorizedAccessException))); + Assert.Equal((int)JsonRpcErrorCode.InvocationError, exception.ErrorCode); + + var errorData = Assert.IsType(exception.ErrorData); + Assert.NotNull(errorData.StackTrace); + Assert.StrictEqual(COR_E_UNAUTHORIZEDACCESS, errorData.HResult); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + UnionBaseClass result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.ReturnUnionTypeAsync), null, this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue_NonAsync() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + UnionBaseClass result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.ReturnUnionType), null, this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_PositionalParameter_NoReturnValue(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_AndReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + string? result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }, this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_NamedParameter_NoReturnValue_UntypedDictionary(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var argument = new Dictionary { { "value", new UnionDerivedClass() } }; + var argumentDeclaredTypes = new Dictionary { { "value", typeof(UnionBaseClass) } }; + + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + + // Exercise the non-init path by repeating + server.ReceivedValueSource = new TaskCompletionSource(); + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue_UntypedDictionary() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + string? result = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), + new Dictionary { { "value", new UnionDerivedClass() } }, + new Dictionary { { "value", typeof(UnionBaseClass) } }, + this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + Assert.IsType(server.ReceivedValue); + + // Exercise the non-init path by repeating + result = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), + new Dictionary { { "value", new UnionDerivedClass() } }, + new Dictionary { { "value", typeof(UnionBaseClass) } }, + this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + Assert.IsType(server.ReceivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_NamedParameter_NoReturnValue(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var namedArgs = new { value = (UnionBaseClass)new UnionDerivedClass() }; + + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + + // Exercise the non-init path by repeating + server.ReceivedValueSource = new TaskCompletionSource(); + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + string? result = await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), new { value = (UnionBaseClass)new UnionDerivedClass() }, this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(); + UnionBaseClass result = await clientProxy.ReturnUnionTypeAsync(this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_AndReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(); + string? result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = true }); + string? result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_NoReturnValue_Proxy() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + server.ReceivedValueSource = new TaskCompletionSource(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_NoReturnValue_Proxy() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = true }); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + server.ReceivedValueSource = new TaskCompletionSource(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + } + + [Fact] + public async Task UnionType_AsIProgressTypeArgument() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + + var reportSource = new TaskCompletionSource(); + var progress = new Progress(v => reportSource.SetResult(v)); + await clientProxy.ProgressUnionType(progress, this.TimeoutToken); + Assert.IsType(await reportSource.Task.WithCancellation(this.TimeoutToken)); + } + + [Fact] + public async Task UnionType_AsAsyncEnumerableTypeArgument() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + + UnionBaseClass? actualItem = null; + await foreach (UnionBaseClass item in clientProxy.GetAsyncEnumerableOfUnionType(this.TimeoutToken)) + { + actualItem = item; + } + + Assert.IsType(actualItem); + } + + /// + /// Verifies that an argument that cannot be deserialized by the msgpack primitive formatter will not cause a failure. + /// + /// + /// This is a regression test for a bug where + /// verbose ETW tracing would fail to deserialize arguments with the primitive formatter that deserialize just fine for the actual method dispatch. + /// + [SkippableTheory, PairwiseData] + public async Task VerboseLoggingDoesNotFailWhenArgsDoNotDeserializePrimitively(bool namedArguments) + { + Skip.IfNot(SharedUtilities.GetEventSourceTestMode() == SharedUtilities.EventSourceTestMode.EmulateProduction, $"This test specifically verifies behavior when the EventSource should swallow exceptions. Current mode: {SharedUtilities.GetEventSourceTestMode()}."); + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = namedArguments }); + + Assert.True(await clientProxy.IsExtensionArgNonNull(new CustomExtensionType())); + } + + protected override void InitializeFormattersAndHandlers( + Stream serverStream, + Stream clientStream, + out IJsonRpcMessageFormatter serverMessageFormatter, + out IJsonRpcMessageFormatter clientMessageFormatter, + out IJsonRpcMessageHandler serverMessageHandler, + out IJsonRpcMessageHandler clientMessageHandler, + bool controlledFlushingClient) + { + serverMessageFormatter = new NerdbankMessagePackFormatter(); + clientMessageFormatter = new NerdbankMessagePackFormatter(); + + ((NerdbankMessagePackFormatter)serverMessageFormatter).SetFormatterContext(Configure); + ((NerdbankMessagePackFormatter)clientMessageFormatter).SetFormatterContext(Configure); + + serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter); + clientMessageHandler = controlledFlushingClient + ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) + : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); + + static void Configure(NerdbankMessagePackFormatter.FormatterContextBuilder b) + { + b.RegisterConverter(new UnserializableTypeConverter()); + b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); + b.RegisterConverter(new CustomExtensionConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + } + } + + protected override object[] CreateFormatterIntrinsicParamsObject(string arg) => []; + + [GenerateShape] +#pragma warning disable CS0618 + [KnownSubType(typeof(UnionDerivedClass))] +#pragma warning restore CS0618 + public abstract partial class UnionBaseClass + { + } + + [GenerateShape] + public partial class UnionDerivedClass : UnionBaseClass + { + } + + [GenerateShape] + internal partial class CustomExtensionType + { + } + + private class CustomExtensionConverter : MessagePackConverter + { + public override CustomExtensionType? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + if (reader.ReadExtensionHeader() is { TypeCode: 1, Length: 0 }) + { + return new(); + } + + throw new Exception("Unexpected extension header."); + } + + public override void Write(ref MessagePackWriter writer, in CustomExtensionType? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + writer.Write(new Extension(1, default(Memory))); + } + } + } + + private class UnserializableTypeConverter : MessagePackConverter + { + public override CustomSerializedType Read(ref MessagePackReader reader, SerializationContext context) + { + return new CustomSerializedType { Value = reader.ReadString() }; + } + + public override void Write(ref MessagePackWriter writer, in CustomSerializedType? value, SerializationContext context) + { + writer.Write(value?.Value); + } + } + + private class TypeThrowsWhenDeserializedConverter : MessagePackConverter + { + public override TypeThrowsWhenDeserialized Read(ref MessagePackReader reader, SerializationContext context) + { + throw CreateExceptionToBeThrownByDeserializer(); + } + + public override void Write(ref MessagePackWriter writer, in TypeThrowsWhenDeserialized? value, SerializationContext context) + { + writer.WriteArrayHeader(0); + } + } + + private class MessagePackServer : IMessagePackServer + { + internal UnionBaseClass? ReceivedValue { get; private set; } + + internal TaskCompletionSource ReceivedValueSource { get; set; } = new TaskCompletionSource(); + + public Task ReturnUnionTypeAsync(CancellationToken cancellationToken) => Task.FromResult(new UnionDerivedClass()); + + public Task AcceptUnionTypeAndReturnStringAsync(UnionBaseClass value, CancellationToken cancellationToken) => Task.FromResult((this.ReceivedValue = value)?.GetType().Name); + + public Task AcceptUnionTypeAsync(UnionBaseClass value, CancellationToken cancellationToken) + { + this.ReceivedValue = value; + this.ReceivedValueSource.SetResult(value); + return Task.CompletedTask; + } + + public UnionBaseClass ReturnUnionType() => new UnionDerivedClass(); + + public Task ProgressUnionType(IProgress progress, CancellationToken cancellationToken) + { + progress.Report(new UnionDerivedClass()); + return Task.CompletedTask; + } + + public async IAsyncEnumerable GetAsyncEnumerableOfUnionType([EnumeratorCancellation] CancellationToken cancellationToken) + { + await Task.Yield(); + yield return new UnionDerivedClass(); + } + + public Task IsExtensionArgNonNull(CustomExtensionType extensionValue) => Task.FromResult(extensionValue is not null); + } + + private class DelayedFlushingHandler : LengthHeaderMessageHandler, IControlledFlushHandler + { + public DelayedFlushingHandler(Stream stream, IJsonRpcMessageFormatter formatter) + : base(stream, stream, formatter) + { + } + + public AsyncAutoResetEvent FlushEntered { get; } = new AsyncAutoResetEvent(); + + public AsyncManualResetEvent AllowFlushAsyncExit { get; } = new AsyncManualResetEvent(); + + protected override async ValueTask FlushAsync(CancellationToken cancellationToken) + { + this.FlushEntered.Set(); + await this.AllowFlushAsyncExit.WaitAsync(CancellationToken.None); + await base.FlushAsync(cancellationToken); + } + } +} diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs new file mode 100644 index 000000000..13d44420a --- /dev/null +++ b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; + +public class MarshalableProxyNerdbankMessagePackTests : MarshalableProxyTests +{ + public MarshalableProxyNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); +} diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs new file mode 100644 index 000000000..9a62e6a5b --- /dev/null +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -0,0 +1,453 @@ +using System.Diagnostics; +using System.Runtime.Serialization; +using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType; +using PolyType.ReflectionProvider; +using PolyType.SourceGenerator; + +public partial class NerdbankMessagePackFormatterTests : FormatterTestBase +{ + public NerdbankMessagePackFormatterTests(ITestOutputHelper logger) + : base(logger) + { + } + + [Fact] + public void JsonRpcRequest_PositionalArgs() + { + var original = new JsonRpcRequest + { + RequestId = new RequestId(5), + Method = "test", + ArgumentsList = new object[] { 5, "hi", new CustomType { Age = 8 } }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Method, actual.Method); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 0, typeof(int), out object? actualArg0)); + Assert.Equal(original.ArgumentsList[0], actualArg0); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 1, typeof(string), out object? actualArg1)); + Assert.Equal(original.ArgumentsList[1], actualArg1); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 2, typeof(CustomType), out object? actualArg2)); + Assert.Equal(((CustomType?)original.ArgumentsList[2])!.Age, ((CustomType)actualArg2!).Age); + } + + [Fact] + public void JsonRpcRequest_NamedArgs() + { + var original = new JsonRpcRequest + { + RequestId = new RequestId(5), + Method = "test", + NamedArguments = new Dictionary + { + { "Number", 5 }, + { "Message", "hi" }, + { "Custom", new CustomType { Age = 8 } }, + }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Method, actual.Method); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Number", -1, typeof(int), out object? actualArg0)); + Assert.Equal(original.NamedArguments["Number"], actualArg0); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Message", -1, typeof(string), out object? actualArg1)); + Assert.Equal(original.NamedArguments["Message"], actualArg1); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Custom", -1, typeof(CustomType), out object? actualArg2)); + Assert.Equal(((CustomType?)original.NamedArguments["Custom"])!.Age, ((CustomType)actualArg2!).Age); + } + + [Fact] + public void JsonRpcResult() + { + var original = new JsonRpcResult + { + RequestId = new RequestId(5), + Result = new CustomType { Age = 7 }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(((CustomType?)original.Result)!.Age, actual.GetResult().Age); + } + + [Fact] + public void JsonRpcError() + { + var original = new JsonRpcError + { + RequestId = new RequestId(5), + Error = new JsonRpcError.ErrorDetail + { + Code = JsonRpcErrorCode.InvocationError, + Message = "Oops", + Data = new CustomType { Age = 15 }, + }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Error.Code, actual.Error!.Code); + Assert.Equal(original.Error.Message, actual.Error.Message); + Assert.Equal(((CustomType)original.Error.Data).Age, actual.Error.GetData().Age); + } + + [Fact] + public async Task BasicJsonRpc() + { + var (clientStream, serverStream) = FullDuplexStream.CreatePair(); + var clientFormatter = new NerdbankMessagePackFormatter(); + var serverFormatter = new NerdbankMessagePackFormatter(); + + var clientHandler = new LengthHeaderMessageHandler(clientStream.UsePipe(), clientFormatter); + var serverHandler = new LengthHeaderMessageHandler(serverStream.UsePipe(), serverFormatter); + + var clientRpc = new JsonRpc(clientHandler); + var serverRpc = new JsonRpc(serverHandler, new Server()); + + serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose); + clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose); + + serverRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); + clientRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); + + clientRpc.StartListening(); + serverRpc.StartListening(); + + int result = await clientRpc.InvokeAsync(nameof(Server.Add), 3, 5).WithCancellation(this.TimeoutToken); + Assert.Equal(8, result); + } + + [Fact] + public void Resolver_RequestArgInArray() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + ArgumentsList = new object[] { originalArg }, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(null, 0, typeof(TypeRequiringCustomFormatter), out object? roundtripArgObj)); + var roundtripArg = (TypeRequiringCustomFormatter)roundtripArgObj!; + Assert.Equal(originalArg.Prop1, roundtripArg.Prop1); + Assert.Equal(originalArg.Prop2, roundtripArg.Prop2); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_AnonymousType() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new { Prop1 = 3, Prop2 = 5 }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.Prop1), -1, typeof(int), out object? prop1)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.Prop2), -1, typeof(int), out object? prop2)); + Assert.Equal(originalArg.Prop1, prop1); + Assert.Equal(originalArg.Prop2, prop2); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_DataContractObject() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new DataContractWithSubsetOfMembersIncluded { ExcludedField = "A", ExcludedProperty = "B", IncludedField = "C", IncludedProperty = "D" }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedProperty), -1, typeof(string), out object? _)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.IncludedField), -1, typeof(string), out object? includedField)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.IncludedProperty), -1, typeof(string), out object? includedProperty)); + Assert.Equal(originalArg.IncludedProperty, includedProperty); + Assert.Equal(originalArg.IncludedField, includedField); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_NonDataContractObject() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new NonDataContractWithExcludedMembers { ExcludedField = "A", ExcludedProperty = "B", InternalField = "C", InternalProperty = "D", PublicField = "E", PublicProperty = "F" }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedProperty), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.InternalField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.InternalProperty), -1, typeof(string), out object? _)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.PublicField), -1, typeof(string), out object? publicField)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.PublicProperty), -1, typeof(string), out object? publicProperty)); + Assert.Equal(originalArg.PublicProperty, publicProperty); + Assert.Equal(originalArg.PublicField, publicField); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_NullObject() + { + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = null, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.Null(roundtripRequest.Arguments); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex("AnythingReally", -1, typeof(string), out object? _)); + } + + [Fact] + public void Resolver_Result() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalResultValue = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalResult = new JsonRpcResult + { + RequestId = new RequestId(1), + Result = originalResultValue, + }; + var roundtripResult = this.Roundtrip(originalResult); + var roundtripResultValue = roundtripResult.GetResult(); + Assert.Equal(originalResultValue.Prop1, roundtripResultValue.Prop1); + Assert.Equal(originalResultValue.Prop2, roundtripResultValue.Prop2); + } + + [Fact] + public void Resolver_ErrorData() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalErrorData = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalError = new JsonRpcError + { + RequestId = new RequestId(1), + Error = new JsonRpcError.ErrorDetail + { + Data = originalErrorData, + }, + }; + var roundtripError = this.Roundtrip(originalError); + var roundtripErrorData = roundtripError.Error!.GetData(); + Assert.Equal(originalErrorData.Prop1, roundtripErrorData.Prop1); + Assert.Equal(originalErrorData.Prop2, roundtripErrorData.Prop2); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcRequest() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = new object[] { "hi" }, + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.method, request.Method); + Assert.Equal(dynamic.@params.Length, request.ArgumentCount); + Assert.True(request.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg)); + Assert.Equal(dynamic.@params[0], arg); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcResult() + { + var dynamic = new + { + jsonrpc = "2.0", + id = 2, + extra = (object?)null, + result = "hi", + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.id, request.RequestId.Number); + Assert.Equal(dynamic.result, request.GetResult()); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcError() + { + var dynamic = new + { + jsonrpc = "2.0", + id = 2, + extra = (object?)null, + error = new { extra = 2, code = 5 }, + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.id, request.RequestId.Number); + Assert.Equal(dynamic.error.code, (int?)request.Error?.Code); + } + + [Fact] + public void StringsInUserDataAreInterned() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = new object[] { "hi" }, + }; + var request1 = this.Read(dynamic); + var request2 = this.Read(dynamic); + Assert.True(request1.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg1)); + Assert.True(request2.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg2)); + Assert.Same(arg2, arg1); // reference equality to ensure it was interned. + } + + [Fact] + public void StringValuesOfStandardPropertiesAreInterned() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = Array.Empty(), + }; + var request1 = this.Read(dynamic); + var request2 = this.Read(dynamic); + Assert.Same(request1.Method, request2.Method); // reference equality to ensure it was interned. + } + + protected override NerdbankMessagePackFormatter CreateFormatter() => new(); + + private T Read(object anonymousObject) + where T : JsonRpcMessage + { + var sequence = new Sequence(); + var writer = new MessagePackWriter(sequence); + new MessagePackSerializer().Serialize(ref writer, anonymousObject, ReflectionTypeShapeProvider.Default); + writer.Flush(); + return (T)this.Formatter.Deserialize(sequence); + } + + [DataContract] + [GenerateShape] + public partial class DataContractWithSubsetOfMembersIncluded + { + [PropertyShape(Ignore = true)] + public string? ExcludedField; + + [DataMember] + internal string? IncludedField; + + [PropertyShape(Ignore = true)] + public string? ExcludedProperty { get; set; } + + [DataMember] + internal string? IncludedProperty { get; set; } + } + + [GenerateShape] + public partial class NonDataContractWithExcludedMembers + { + [IgnoreDataMember] + [PropertyShape(Ignore = true)] + public string? ExcludedField; + + public string? PublicField; + + internal string? InternalField; + + [IgnoreDataMember] + [PropertyShape(Ignore = true)] + public string? ExcludedProperty { get; set; } + + public string? PublicProperty { get; set; } + + internal string? InternalProperty { get; set; } + } + + [GenerateShape] + public partial class TypeRequiringCustomFormatter + { + internal int Prop1 { get; set; } + + internal int Prop2 { get; set; } + } + + private class CustomConverter : MessagePackConverter + { + public override TypeRequiringCustomFormatter Read(ref MessagePackReader reader, SerializationContext context) + { + Assert.Equal(2, reader.ReadArrayHeader()); + return new TypeRequiringCustomFormatter + { + Prop1 = reader.ReadInt32(), + Prop2 = reader.ReadInt32(), + }; + } + + public override void Write(ref MessagePackWriter writer, in TypeRequiringCustomFormatter? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + writer.WriteArrayHeader(2); + writer.Write(value.Prop1); + writer.Write(value.Prop2); + } + } + + private class Server + { + public int Add(int a, int b) => a + b; + } +} diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs new file mode 100644 index 000000000..7f46b1da4 --- /dev/null +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +public class ObserverMarshalingNerdbankMessagePackTests : ObserverMarshalingTests +{ + public ObserverMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); +} diff --git a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj index e391c2ac1..f7bf2232d 100644 --- a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj +++ b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj @@ -11,30 +11,39 @@ + + + + + + + + + diff --git a/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs new file mode 100644 index 000000000..12152f55d --- /dev/null +++ b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +public class TargetObjectEventsNerdbankMessagePackTests : TargetObjectEventsTests +{ + public TargetObjectEventsNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + var serverMessageFormatter = new NerdbankMessagePackFormatter(); + var clientMessageFormatter = new NerdbankMessagePackFormatter(); + + this.serverMessageHandler = new LengthHeaderMessageHandler(this.serverStream, this.serverStream, serverMessageFormatter); + this.clientMessageHandler = new LengthHeaderMessageHandler(this.clientStream, this.clientStream, clientMessageFormatter); + } +} diff --git a/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs new file mode 100644 index 000000000..b14c7b25b --- /dev/null +++ b/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs @@ -0,0 +1,7 @@ +public class WebSocketMessageHandlerNerdbankMessagePackTests : WebSocketMessageHandlerTests +{ + public WebSocketMessageHandlerNerdbankMessagePackTests(ITestOutputHelper logger) + : base(new NerdbankMessagePackFormatter(), logger) + { + } +} From bca2af4417c6e5ae0baf12bacc84c085ea2c976e Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sun, 22 Dec 2024 18:53:41 +0000 Subject: [PATCH 12/25] Some tests are now passing. --- src/StreamJsonRpc/FormatterBase.cs | 5 +- ...rdbankMessagePackFormatter.CommonString.cs | 3 + ...nkMessagePackFormatter.FormatterContext.cs | 7 +- ...gePackFormatter.FormatterContextBuilder.cs | 3 - .../NerdbankMessagePackFormatter.cs | 68 ++++++++++--------- src/StreamJsonRpc/Protocol/JsonRpcError.cs | 2 +- src/StreamJsonRpc/Protocol/JsonRpcMessage.cs | 16 +++-- ...AsyncEnumerableNerdbankMessagePackTests.cs | 15 +++- 8 files changed, 70 insertions(+), 49 deletions(-) diff --git a/src/StreamJsonRpc/FormatterBase.cs b/src/StreamJsonRpc/FormatterBase.cs index c7bdb6ffa..42b283359 100644 --- a/src/StreamJsonRpc/FormatterBase.cs +++ b/src/StreamJsonRpc/FormatterBase.cs @@ -8,6 +8,7 @@ using System.Runtime.Serialization; using Nerdbank.MessagePack; using Nerdbank.Streams; +using PolyType; using StreamJsonRpc.Protocol; using StreamJsonRpc.Reflection; @@ -321,7 +322,7 @@ public void Dispose() /// /// A base class for top-level property bags that should be declared in the derived formatter class. /// - protected abstract class TopLevelPropertyBagBase + protected internal abstract class TopLevelPropertyBagBase { private readonly bool isOutbound; private Dictionary? outboundProperties; @@ -437,6 +438,7 @@ protected abstract class JsonRpcErrorBase : JsonRpcError, IJsonRpcMessageBufferM /// [Newtonsoft.Json.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public TopLevelPropertyBagBase? TopLevelPropertyBag { get; set; } void IJsonRpcMessageBufferManager.DeserializationComplete(JsonRpcMessage message) @@ -481,6 +483,7 @@ protected abstract class JsonRpcResultBase : JsonRpcResult, IJsonRpcMessageBuffe /// [Newtonsoft.Json.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public TopLevelPropertyBagBase? TopLevelPropertyBag { get; set; } void IJsonRpcMessageBufferManager.DeserializationComplete(JsonRpcMessage message) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs index dfc6ef831..9b03c9158 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs @@ -6,6 +6,9 @@ namespace StreamJsonRpc; +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// public partial class NerdbankMessagePackFormatter { [DebuggerDisplay("{" + nameof(Value) + "}")] diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs index 3d560bbc5..56f86fa5f 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs @@ -10,9 +10,6 @@ namespace StreamJsonRpc; /// /// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). /// -/// -/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. -/// public sealed partial class NerdbankMessagePackFormatter { internal class FormatterContext(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) @@ -30,7 +27,7 @@ public T Deserialize(in RawMessagePack pack, CancellationToken cancellationTo { // TODO: Improve the exception return serializer.Deserialize(pack, shapeProvider, cancellationToken) - ?? throw new InvalidOperationException("Deserialization failed."); + ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); } public object? DeserializeObject(in RawMessagePack pack, Type objectType, CancellationToken cancellationToken = default) @@ -47,7 +44,7 @@ public void Serialize(ref MessagePackWriter writer, T? value, CancellationTok serializer.Serialize(ref writer, value, shapeProvider, cancellationToken); } - internal void SerializeObject(ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) + public void SerializeObject(ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) { serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(objectType), cancellationToken); } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs index b511cbf22..b36b3a9c4 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs @@ -13,9 +13,6 @@ namespace StreamJsonRpc; /// /// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). /// -/// -/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. -/// public sealed partial class NerdbankMessagePackFormatter { /// diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 204196138..af0816064 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -29,7 +29,8 @@ namespace StreamJsonRpc; /// /// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. /// -public sealed partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory +[System.Diagnostics.CodeAnalysis.SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Suppressed for Development")] +public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory { /// /// The constant "jsonrpc", in its various forms. @@ -177,8 +178,7 @@ public void SetFormatterContext(Action configure) /// public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) { - JsonRpcMessage message = this.rpcContext.Serializer.Deserialize(contentBuffer, ShapeProvider_StreamJsonRpc.Default) - ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + JsonRpcMessage message = this.rpcContext.Deserialize(contentBuffer); IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; this.deserializationToStringHelper.Activate(contentBuffer); @@ -214,7 +214,7 @@ public void Serialize(IBufferWriter contentBuffer, JsonRpcMessage message) var writer = new MessagePackWriter(contentBuffer); try { - this.rpcContext.Serializer.Serialize(ref writer, message, this.rpcContext.ShapeProvider); + this.rpcContext.Serialize(ref writer, message); writer.Flush(); } catch (Exception ex) @@ -641,7 +641,6 @@ internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) } [return: MaybeNull] - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] public override TClass? Read(ref MessagePackReader reader, SerializationContext context) { if (reader.TryReadNil()) @@ -754,7 +753,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? token : null, initialElements); } - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) { Serialize_Shared(mainFormatter, ref writer, value, context); @@ -808,7 +807,8 @@ internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter /// Converts an instance of to an enumeration token. /// #pragma warning disable CA1812 - private class GeneratorConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter where TClass : IAsyncEnumerable + private class GeneratorConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter + where TClass : IAsyncEnumerable #pragma warning restore CA1812 { public override TClass Read(ref MessagePackReader reader, SerializationContext context) @@ -816,7 +816,7 @@ public override TClass Read(ref MessagePackReader reader, SerializationContext c throw new NotSupportedException(); } - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); @@ -1012,14 +1012,14 @@ private class RpcMarshalableConverter( where T : class #pragma warning restore CA1812 { - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] public override T? Read(ref MessagePackReader reader, SerializationContext context) { MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcContext.Deserialize(ref reader); return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; } - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { if (value is null) @@ -1175,7 +1175,8 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } } - private class JsonRpcMessageConverter : MessagePackConverter + [GenerateShape] + private partial class JsonRpcMessageConverter : MessagePackConverter { private readonly NerdbankMessagePackFormatter formatter; @@ -1251,7 +1252,7 @@ public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? valu } } - private class JsonRpcRequestConverter : MessagePackConverter + private partial class JsonRpcRequestConverter : MessagePackConverter { private readonly NerdbankMessagePackFormatter formatter; @@ -1276,7 +1277,10 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) { - throw new UnrecognizedJsonRpcMessageException(); + ReadOnlySequence? keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + stringKey = keySequence.Value.IsSingleSegment + ? keySequence.Value.First.Span + : keySequence.Value.ToArray(); } if (VersionPropertyName.TryRead(stringKey)) @@ -1400,7 +1404,6 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ { writer.WriteArrayHeader(value.ArgumentsList.Count); - for (int i = 0; i < value.ArgumentsList.Count; i++) { object? arg = value.ArgumentsList[i]; @@ -1553,16 +1556,19 @@ public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, Serial OriginalMessagePack = reader.Sequence, }; - Dictionary>? topLevelProperties = null; context.DepthStep(); + Dictionary>? topLevelProperties = null; int propertyCount = reader.ReadMapHeader(); for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) { // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) { - throw new UnrecognizedJsonRpcMessageException(); + ReadOnlySequence? keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + stringKey = keySequence.Value.IsSingleSegment + ? keySequence.Value.First.Span + : keySequence.Value.ToArray(); } if (VersionPropertyName.TryRead(stringKey)) @@ -1661,7 +1667,10 @@ public override Protocol.JsonRpcError Read(ref MessagePackReader reader, Seriali // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) { - throw new UnrecognizedJsonRpcMessageException(); + ReadOnlySequence? keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + stringKey = keySequence.Value.IsSingleSegment + ? keySequence.Value.First.Span + : keySequence.Value.ToArray(); } if (VersionPropertyName.TryRead(stringKey)) @@ -1740,7 +1749,10 @@ public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader rea { if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) { - throw new UnrecognizedJsonRpcMessageException(); + ReadOnlySequence? keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + stringKey = keySequence.Value.IsSingleSegment + ? keySequence.Value.First.Span + : keySequence.Value.ToArray(); } if (CodePropertyName.TryRead(stringKey)) @@ -1764,26 +1776,24 @@ public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader rea return result; } - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) { Requires.NotNull(value!, nameof(value)); + context.DepthStep(); + writer.WriteMapHeader(3); CodePropertyName.Write(ref writer); - context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.Code, context); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.Code, context); MessagePropertyName.Write(ref writer); writer.Write(value.Message); DataPropertyName.Write(ref writer); -#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods - this.formatter.userDataContext.Serializer.SerializeObject( - ref writer, - value.Data, - this.formatter.userDataContext.ShapeProvider.Resolve()); -#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serialize(ref writer, value.Data); } public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) @@ -1971,7 +1981,6 @@ protected internal override bool TryGetTopLevelProperty(string name, [MaybeNu } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] - [DataContract] private class OutboundJsonRpcRequest : JsonRpcRequestBase { private readonly NerdbankMessagePackFormatter formatter; @@ -1985,7 +1994,6 @@ internal OutboundJsonRpcRequest(NerdbankMessagePackFormatter formatter) } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] - [DataContract] private class JsonRpcRequest : JsonRpcRequestBase, IJsonRpcMessagePackRetention { private readonly NerdbankMessagePackFormatter formatter; @@ -2093,8 +2101,7 @@ protected override void ReleaseBuffers() } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] - [DataContract] - private class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention + private partial class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention { private readonly NerdbankMessagePackFormatter formatter; private readonly FormatterContext serializerOptions; @@ -2158,7 +2165,6 @@ protected override void ReleaseBuffers() } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] - [DataContract] private class JsonRpcError : JsonRpcErrorBase, IJsonRpcMessagePackRetention { private readonly FormatterContext serializerOptions; diff --git a/src/StreamJsonRpc/Protocol/JsonRpcError.cs b/src/StreamJsonRpc/Protocol/JsonRpcError.cs index 143cd82cb..9842d5719 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcError.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcError.cs @@ -137,7 +137,7 @@ public class ErrorDetail /// /// The type that will be used as the generic type argument to . /// - /// Overridding methods in types that retain buffers used to deserialize should deserialize within this method and clear those buffers + /// Overriding methods in types that retain buffers used to deserialize should deserialize within this method and clear those buffers /// to prevent further access to these buffers which may otherwise happen concurrently with a call to /// that would recycle the same buffer being deserialized from. /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs index 1a9a6edc9..0b09d0769 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs @@ -16,12 +16,16 @@ namespace StreamJsonRpc.Protocol; [KnownType(typeof(JsonRpcRequest))] [KnownType(typeof(JsonRpcResult))] [KnownType(typeof(JsonRpcError))] -#pragma warning disable CS0618 //'KnownSubTypeAttribute.KnownSubTypeAttribute(Type)' is obsolete: 'Use the generic version of this attribute instead.' -[KnownSubType(typeof(JsonRpcRequest))] -[KnownSubType(typeof(JsonRpcResult))] -[KnownSubType(typeof(JsonRpcError))] -#pragma warning restore CS0618 -public abstract class JsonRpcMessage +#if NETSTANDARD2_0_OR_GREATER +[KnownSubType(typeof(JsonRpcRequest), 1)] +[KnownSubType(typeof(JsonRpcResult), 2)] +[KnownSubType(typeof(JsonRpcError), 3)] +#elif NET +[KnownSubType(1)] +[KnownSubType(2)] +[KnownSubType(3)] +#endif +public abstract partial class JsonRpcMessage { /// /// Gets or sets the version of the JSON-RPC protocol that this message conforms to. diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index d4c5ee933..b0e10234a 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -10,7 +10,18 @@ public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) protected override void InitializeFormattersAndHandlers() { - this.serverMessageFormatter = new NerdbankMessagePackFormatter(); - this.clientMessageFormatter = new NerdbankMessagePackFormatter(); + NerdbankMessagePackFormatter serverFormatter = new(); + serverFormatter.SetFormatterContext(ConfigureContext); + + NerdbankMessagePackFormatter clientFormatter = new(); + clientFormatter.SetFormatterContext(ConfigureContext); + + this.serverMessageFormatter = serverFormatter; + this.clientMessageFormatter = clientFormatter; + + static void ConfigureContext(NerdbankMessagePackFormatter.FormatterContextBuilder contextBuilder) + { + contextBuilder.RegisterAsyncEnumerableType, int>(); + } } } From d693a1cbe7ca53831ff0dd52ac881d95406b835a Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sun, 22 Dec 2024 22:22:49 +0000 Subject: [PATCH 13/25] Cleaning up formatters. Rename `FormatterContext` to `FormatterProfile`. --- ...gePackFormatter.FormatterContextBuilder.cs | 44 +-- ...kMessagePackFormatter.FormatterProfile.cs} | 23 +- .../NerdbankMessagePackFormatter.cs | 288 ++++++++++-------- src/StreamJsonRpc/Protocol/JsonRpcError.cs | 3 +- ...AsyncEnumerableNerdbankMessagePackTests.cs | 6 +- .../JsonRpcNerdbankMessagePackLengthTests.cs | 6 +- .../NerdbankMessagePackFormatterTests.cs | 12 +- 7 files changed, 217 insertions(+), 165 deletions(-) rename src/StreamJsonRpc/{NerdbankMessagePackFormatter.FormatterContext.cs => NerdbankMessagePackFormatter.FormatterProfile.cs} (73%) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs index b36b3a9c4..19fe149ea 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs @@ -16,24 +16,24 @@ namespace StreamJsonRpc; public sealed partial class NerdbankMessagePackFormatter { /// - /// Provides methods to build a serialization context for the . + /// Provides methods to build a serialization profile for the . /// - public class FormatterContextBuilder + public class FormatterProfileBuilder { private readonly NerdbankMessagePackFormatter formatter; - private readonly FormatterContext baseContext; + private readonly FormatterProfile baseProfile; private ImmutableArray.Builder? typeShapeProvidersBuilder = null; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The formatter to use. - /// The base context to build upon. - internal FormatterContextBuilder(NerdbankMessagePackFormatter formatter, FormatterContext baseContext) + /// The base profile to build upon. + internal FormatterProfileBuilder(NerdbankMessagePackFormatter formatter, FormatterProfile baseProfile) { this.formatter = formatter; - this.baseContext = baseContext; + this.baseProfile = baseProfile; } /// @@ -55,7 +55,7 @@ public void RegisterAsyncEnumerableType() where TEnumerable : IAsyncEnumerable { MessagePackConverter converter = this.formatter.asyncEnumerableConverterResolver.GetConverter(); - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } /// @@ -65,7 +65,7 @@ public void RegisterAsyncEnumerableType() /// The converter to register. public void RegisterConverter(MessagePackConverter converter) { - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } /// @@ -75,7 +75,7 @@ public void RegisterConverter(MessagePackConverter converter) /// The mapping of known subtypes. public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) { - this.baseContext.Serializer.RegisterKnownSubTypes(mapping); + this.baseProfile.Serializer.RegisterKnownSubTypes(mapping); } /// @@ -87,7 +87,7 @@ public void RegisterProgressType() where TProgress : IProgress { MessagePackConverter converter = this.formatter.progressConverterResolver.GetConverter(); - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } /// @@ -98,7 +98,7 @@ public void RegisterDuplexPipeType() where TPipe : IDuplexPipe { MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } /// @@ -109,7 +109,7 @@ public void RegisterPipeReaderType() where TReader : PipeReader { MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } /// @@ -120,7 +120,7 @@ public void RegisterPipeWriterType() where TWriter : PipeWriter { MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } /// @@ -131,7 +131,7 @@ public void RegisterStreamType() where TStream : Stream { MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } /// @@ -142,7 +142,7 @@ public void RegisterExceptionType() where TException : Exception { MessagePackConverter converter = this.formatter.exceptionResolver.GetConverter(); - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } /// @@ -165,28 +165,28 @@ public void RegisterRpcMarshalableType() targetOptions, attribute)!; - this.baseContext.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterConverter(converter); } // TODO: Throw? } /// - /// Builds the formatter context. + /// Builds the formatter profile. /// - /// The built formatter context. - internal FormatterContext Build() + /// The built formatter profile. + internal FormatterProfile Build() { if (this.typeShapeProvidersBuilder is null || this.typeShapeProvidersBuilder.Count < 1) { - return this.baseContext; + return this.baseProfile; } ITypeShapeProvider provider = this.typeShapeProvidersBuilder.Count == 1 ? this.typeShapeProvidersBuilder[0] : new CompositeTypeShapeProvider(this.typeShapeProvidersBuilder.ToImmutable()); - return new FormatterContext(this.baseContext.Serializer, provider); + return new FormatterProfile(this.baseProfile.Serializer, provider); } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs similarity index 73% rename from src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs rename to src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs index 56f86fa5f..14ce67500 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs @@ -12,11 +12,11 @@ namespace StreamJsonRpc; /// public sealed partial class NerdbankMessagePackFormatter { - internal class FormatterContext(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) + internal class FormatterProfile(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) { - public MessagePackSerializer Serializer => serializer; + internal MessagePackSerializer Serializer => serializer; - public ITypeShapeProvider ShapeProvider => shapeProvider; + internal ITypeShapeProvider ShapeProvider => shapeProvider; public T? Deserialize(ref MessagePackReader reader, CancellationToken cancellationToken = default) { @@ -46,7 +46,24 @@ public void Serialize(ref MessagePackWriter writer, T? value, CancellationTok public void SerializeObject(ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) { + if (value is null) + { + writer.WriteNil(); + return; + } + serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(objectType), cancellationToken); } + + public void SerializeObject(ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(value.GetType()), cancellationToken); + } } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index af0816064..cc53867d0 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -83,12 +83,12 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa /// /// All access to this field should be while holding a lock on this member's value. /// - private static readonly Dictionary> ParameterObjectPropertyTypes = new Dictionary>(); + private static readonly Dictionary> ParameterObjectPropertyTypes = []; /// /// The serializer context to use for top-level RPC messages. /// - private readonly FormatterContext rpcContext; + private readonly FormatterProfile rpcContext; private readonly ProgressConverterResolver progressConverterResolver; @@ -105,7 +105,7 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa /// /// The serializer to use for user data (e.g. arguments, return values and errors). /// - private FormatterContext userDataContext; + private FormatterProfile userDataContext; /// /// Initializes a new instance of the class. @@ -127,7 +127,7 @@ public NerdbankMessagePackFormatter() serializer.RegisterConverter(new JsonRpcErrorDetailConverter(this)); serializer.RegisterConverter(new TraceParentConverter()); - this.rpcContext = new FormatterContext(serializer, ShapeProvider_StreamJsonRpc.Default); + this.rpcContext = new FormatterProfile(serializer, ShapeProvider_StreamJsonRpc.Default); // Create the specialized formatters/resolvers that we will inject into the chain for user data. this.progressConverterResolver = new ProgressConverterResolver(this); @@ -135,7 +135,7 @@ public NerdbankMessagePackFormatter() this.pipeConverterResolver = new PipeConverterResolver(this); this.exceptionResolver = new MessagePackExceptionConverterResolver(this); - FormatterContext userDataContext = new( + FormatterProfile userDataContext = new( new() { InternStrings = true, @@ -162,14 +162,14 @@ private interface IJsonRpcMessagePackRetention /// Configures the serialization context for user data with the specified configuration action. /// /// The action to configure the serialization context. - public void SetFormatterContext(Action configure) + public void SetFormatterProfile(Action configure) { Requires.NotNull(configure, nameof(configure)); - var builder = new FormatterContextBuilder(this, this.userDataContext); + var builder = new FormatterProfileBuilder(this, this.userDataContext); configure(builder); - FormatterContext context = builder.Build(); + FormatterProfile context = builder.Build(); this.MassageUserDataContext(context); this.userDataContext = context; @@ -276,7 +276,7 @@ private static (IReadOnlyDictionary ArgumentValues, IReadOnlyDi } // If we couldn't find a previously created argument types dictionary, create a mutable one that we'll build this time. - Dictionary? mutableArgumentTypes = argumentTypes is null ? new Dictionary() : null; + Dictionary? mutableArgumentTypes = argumentTypes is null ? [] : null; var result = new Dictionary(StringComparer.Ordinal); @@ -378,8 +378,10 @@ private static unsafe string ReadProtocolVersion(ref MessagePackReader reader) { if (!reader.TryReadStringSpan(out ReadOnlySpan valueBytes)) { - // TODO: More specific exception type - throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + ReadOnlySequence? valueSequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + valueBytes = valueSequence.Value.IsSingleSegment + ? valueSequence.Value.First.Span + : valueSequence.Value.ToArray(); } // Recognize "2.0" since we expect it and can avoid decoding and allocating a new string for it. @@ -425,7 +427,7 @@ private static void ReadUnknownProperty(ref MessagePackReader reader, in Seriali /// and the dynamic object wrapper for serialization. /// /// The options for user data that is supplied by the user (or the default). - private void MassageUserDataContext(FormatterContext userDataContext) + private void MassageUserDataContext(FormatterProfile userDataContext) { // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. userDataContext.Serializer.RegisterConverter(RequestIdConverter.Instance); @@ -434,58 +436,58 @@ private void MassageUserDataContext(FormatterContext userDataContext) private class MessagePackFormatterConverter : IFormatterConverter { - private readonly FormatterContext context; + private readonly FormatterProfile formatterContext; - internal MessagePackFormatterConverter(FormatterContext formatterContext) + internal MessagePackFormatterConverter(FormatterProfile formatterContext) { - this.context = formatterContext; + this.formatterContext = formatterContext; } #pragma warning disable CS8766 // This method may in fact return null, and no one cares. public object? Convert(object value, Type type) #pragma warning restore CS8766 { - return this.context.DeserializeObject((RawMessagePack)value, type); + return this.formatterContext.DeserializeObject((RawMessagePack)value, type); } public object Convert(object value, TypeCode typeCode) { return typeCode switch { - TypeCode.Object => this.context.Deserialize((RawMessagePack)value), + TypeCode.Object => this.formatterContext.Deserialize((RawMessagePack)value), _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), }; } - public bool ToBoolean(object value) => this.context.Deserialize((RawMessagePack)value); + public bool ToBoolean(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public byte ToByte(object value) => this.context.Deserialize((RawMessagePack)value); + public byte ToByte(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public char ToChar(object value) => this.context.Deserialize((RawMessagePack)value); + public char ToChar(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public DateTime ToDateTime(object value) => this.context.Deserialize((RawMessagePack)value); + public DateTime ToDateTime(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public decimal ToDecimal(object value) => this.context.Deserialize((RawMessagePack)value); + public decimal ToDecimal(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public double ToDouble(object value) => this.context.Deserialize((RawMessagePack)value); + public double ToDouble(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public short ToInt16(object value) => this.context.Deserialize((RawMessagePack)value); + public short ToInt16(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public int ToInt32(object value) => this.context.Deserialize((RawMessagePack)value); + public int ToInt32(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public long ToInt64(object value) => this.context.Deserialize((RawMessagePack)value); + public long ToInt64(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public sbyte ToSByte(object value) => this.context.Deserialize((RawMessagePack)value); + public sbyte ToSByte(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public float ToSingle(object value) => this.context.Deserialize((RawMessagePack)value); + public float ToSingle(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public string? ToString(object value) => value is null ? null : this.context.Deserialize((RawMessagePack)value); + public string? ToString(object value) => value is null ? null : this.formatterContext.Deserialize((RawMessagePack)value); - public ushort ToUInt16(object value) => this.context.Deserialize((RawMessagePack)value); + public ushort ToUInt16(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public uint ToUInt32(object value) => this.context.Deserialize((RawMessagePack)value); + public uint ToUInt32(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - public ulong ToUInt64(object value) => this.context.Deserialize((RawMessagePack)value); + public ulong ToUInt64(object value) => this.formatterContext.Deserialize((RawMessagePack)value); } /// @@ -611,6 +613,8 @@ public override TClass Read(ref MessagePackReader reader, SerializationContext c public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { + context.DepthStep(); + if (value is null) { writer.WriteNil(); @@ -643,6 +647,8 @@ internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) [return: MaybeNull] public override TClass? Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); + if (reader.TryReadNil()) { return default!; @@ -656,6 +662,8 @@ internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { + context.DepthStep(); + if (value is null) { writer.WriteNil(); @@ -720,12 +728,13 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); + if (reader.TryReadNil()) { return default; } - context.DepthStep(); RawMessagePack? token = default; IReadOnlyList? initialElements = null; int propertyCount = reader.ReadMapHeader(); @@ -733,7 +742,10 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF { if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) { - throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + ReadOnlySequence keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + stringKey = keySequence.IsSingleSegment + ? keySequence.First.Span + : keySequence.ToArray(); } if (TokenPropertyName.TryRead(stringKey)) @@ -756,6 +768,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) { + context.DepthStep(); Serialize_Shared(mainFormatter, ref writer, value, context); } @@ -819,6 +832,7 @@ public override TClass Read(ref MessagePackReader reader, SerializationContext c [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { + context.DepthStep(); PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); } @@ -870,6 +884,8 @@ private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : M { public override T? Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); + if (reader.TryReadNil()) { return null; @@ -880,6 +896,8 @@ private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : M public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) { writer.Write(token); @@ -903,6 +921,7 @@ private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : M { public override T? Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); if (reader.TryReadNil()) { return null; @@ -913,6 +932,7 @@ private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : M public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + context.DepthStep(); if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) { writer.Write(token); @@ -936,6 +956,7 @@ private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : M { public override T? Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); if (reader.TryReadNil()) { return null; @@ -946,6 +967,7 @@ private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : M public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + context.DepthStep(); if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) { writer.Write(token); @@ -976,6 +998,8 @@ public StreamConverter(NerdbankMessagePackFormatter formatter) public override T? Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); + if (reader.TryReadNil()) { return null; @@ -986,6 +1010,8 @@ public StreamConverter(NerdbankMessagePackFormatter formatter) public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + context.DepthStep(); + if (this.formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) { writer.Write(token); @@ -1015,13 +1041,22 @@ private class RpcMarshalableConverter( [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] public override T? Read(ref MessagePackReader reader, SerializationContext context) { - MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcContext.Deserialize(ref reader); + context.DepthStep(); + + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter + .rpcContext + .Deserialize( + ref reader, + context.CancellationToken); + return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + context.DepthStep(); + if (value is null) { writer.WriteNil(); @@ -1029,7 +1064,7 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat else { MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); - formatter.rpcContext.Serialize(ref writer, token); + formatter.rpcContext.Serialize(ref writer, token, context.CancellationToken); } } @@ -1086,6 +1121,9 @@ private partial class ExceptionConverter(NerdbankMessagePackFormatter formatt public override T? Read(ref MessagePackReader reader, SerializationContext context) { Assumes.NotNull(formatter.JsonRpc); + + context.DepthStep(); + if (reader.TryReadNil()) { return null; @@ -1131,6 +1169,8 @@ private partial class ExceptionConverter(NerdbankMessagePackFormatter formatt public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + context.DepthStep(); + if (value is null) { writer.WriteNil(); @@ -1154,12 +1194,11 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat foreach (SerializationEntry element in info.GetSafeMembers()) { writer.Write(element.Name); -#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods formatter.rpcContext.SerializeObject( ref writer, element.Value, - element.ObjectType); -#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + element.ObjectType, + context.CancellationToken); } } finally @@ -1197,7 +1236,10 @@ internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) // MessagePackFormatter: ReadOnlySpan stringKey = MessagePack.Internal.CodeGenHelpers.ReadStringSpan(ref readAhead); if (!readAhead.TryReadStringSpan(out ReadOnlySpan stringKey)) { - throw new UnrecognizedJsonRpcMessageException(); + ReadOnlySequence? keySequence = readAhead.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + stringKey = keySequence.Value.IsSingleSegment + ? keySequence.Value.First.Span + : keySequence.Value.ToArray(); } if (MethodPropertyName.TryRead(stringKey)) @@ -1225,10 +1267,10 @@ public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? valu { Requires.NotNull(value!, nameof(value)); + context.DepthStep(); + using (this.formatter.TrackSerialization(value)) { - context.DepthStep(); - switch (value) { case Protocol.JsonRpcRequest request: @@ -1263,15 +1305,16 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) public override Protocol.JsonRpcRequest? Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); + var result = new JsonRpcRequest(this.formatter) { OriginalMessagePack = reader.Sequence, }; - context.DepthStep(); + Dictionary>? topLevelProperties = null; int propertyCount = reader.ReadMapHeader(); - Dictionary>? topLevelProperties = null; for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) { // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. @@ -1316,12 +1359,7 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) var namedArgs = new Dictionary>(namedArgsCount); for (int i = 0; i < namedArgsCount; i++) { - string? propertyName = context.GetConverter(null).Read(ref reader, context); - if (propertyName is null) - { - throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); - } - + string? propertyName = context.GetConverter(null).Read(ref reader, context) ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); namedArgs.Add(propertyName, GetSliceForNextToken(ref reader, context)); } @@ -1396,10 +1434,6 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ writer.Write(value.Method); ParamsPropertyName.Write(ref writer); - - // TODO: Get from SetOptions - ITypeShapeProvider? userShapeProvider = context.TypeShapeProvider; - if (value.ArgumentsList is not null) { writer.WriteArrayHeader(value.ArgumentsList.Count); @@ -1407,22 +1441,21 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ for (int i = 0; i < value.ArgumentsList.Count; i++) { object? arg = value.ArgumentsList[i]; - ITypeShape? argShape = arg is null - ? null - : value.ArgumentListDeclaredTypes is not null - ? userShapeProvider?.GetShape(value.ArgumentListDeclaredTypes[i]) - : ReflectionTypeShapeProvider.Default.Resolve(arg.GetType()); - if (argShape is not null) + if (value.ArgumentListDeclaredTypes is null) { -#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods - this.formatter.userDataContext.Serializer.SerializeObject(ref writer, arg, argShape, context.CancellationToken); -#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.SerializeObject( + ref writer, + arg, + context.CancellationToken); } else { - // TODO: NOT REALLY SURE ABOUT THIS YET - writer.WriteNil(); + this.formatter.userDataContext.SerializeObject( + ref writer, + arg, + value.ArgumentListDeclaredTypes[i], + context.CancellationToken); } } } @@ -1432,20 +1465,22 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ foreach (KeyValuePair entry in value.NamedArguments) { writer.Write(entry.Key); - ITypeShape? argShape = value.NamedArgumentDeclaredTypes?[entry.Key] is Type argType - ? userShapeProvider?.GetShape(argType) - : null; - if (argShape is not null) + if (value.NamedArgumentDeclaredTypes is null) { -#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods - this.formatter.userDataContext.Serializer.SerializeObject(ref writer, entry.Value, argShape, context.CancellationToken); -#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.SerializeObject( + ref writer, + entry.Value, + context.CancellationToken); } else { - // TODO: NOT REALLY SURE ABOUT THIS YET - writer.WriteNil(); + Type argType = value.NamedArgumentDeclaredTypes[entry.Key]; + this.formatter.userDataContext.SerializeObject( + ref writer, + entry.Value, + argType, + context.CancellationToken); } } } @@ -1551,14 +1586,15 @@ internal JsonRpcResultConverter(NerdbankMessagePackFormatter formatter) public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); + var result = new JsonRpcResult(this.formatter, this.formatter.userDataContext) { OriginalMessagePack = reader.Sequence, }; - context.DepthStep(); - Dictionary>? topLevelProperties = null; + int propertyCount = reader.ReadMapHeader(); for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) { @@ -1601,6 +1637,8 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResu { Requires.NotNull(value!, nameof(value)); + context.DepthStep(); + var topLevelPropertyBagMessage = value as IMessageWithTopLevelPropertyBag; int mapElementCount = 3; @@ -1614,22 +1652,18 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResu ResultPropertyName.Write(ref writer); - ITypeShape? typeShape = value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void) - ? this.formatter.userDataContext.ShapeProvider.Resolve(value.ResultDeclaredType) - : value.Result is null - ? null - : this.formatter.userDataContext.ShapeProvider.Resolve(value.Result.GetType()); + if (value.Result is null) + { + writer.WriteNil(); + } - if (typeShape is not null) + if (value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void)) { -#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods - this.formatter.userDataContext.Serializer.SerializeObject(ref writer, value.Result, typeShape, context.CancellationToken); -#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.SerializeObject(ref writer, value.Result, value.ResultDeclaredType, context.CancellationToken); } else { - // TODO: NOT REALLY SURE ABOUT THIS YET - writer.WriteNil(); + this.formatter.userDataContext.SerializeObject(ref writer, value.Result, context.CancellationToken); } (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); @@ -1657,10 +1691,10 @@ public override Protocol.JsonRpcError Read(ref MessagePackReader reader, Seriali OriginalMessagePack = reader.Sequence, }; - Dictionary>? topLevelProperties = null; - context.DepthStep(); + Dictionary>? topLevelProperties = null; + int propertyCount = reader.ReadMapHeader(); for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) { @@ -1703,6 +1737,8 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro { Requires.NotNull(value!, nameof(value)); + context.DepthStep(); + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; int mapElementCount = 3; @@ -1712,10 +1748,12 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro WriteProtocolVersionPropertyAndValue(ref writer, value.Version); IdPropertyName.Write(ref writer); - context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.RequestId, context); ErrorPropertyName.Write(ref writer); - context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.Error, context); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.Error, context); topLevelPropertyBag?.WriteProperties(ref writer); } @@ -1741,9 +1779,10 @@ internal JsonRpcErrorDetailConverter(NerdbankMessagePackFormatter formatter) public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader reader, SerializationContext context) { - var result = new JsonRpcError.ErrorDetail(this.formatter.userDataContext); context.DepthStep(); + var result = new JsonRpcError.ErrorDetail(this.formatter.userDataContext); + int propertyCount = reader.ReadMapHeader(); for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) { @@ -1793,7 +1832,7 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro writer.Write(value.Message); DataPropertyName.Write(ref writer); - this.formatter.userDataContext.Serialize(ref writer, value.Data); + this.formatter.userDataContext.Serialize(ref writer, value.Data, context.CancellationToken); } public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) @@ -1817,12 +1856,14 @@ private EventArgsConverter() public override void Write(ref MessagePackWriter writer, in EventArgs? value, SerializationContext context) { Requires.NotNull(value!, nameof(value)); + context.DepthStep(); writer.WriteMapHeader(0); } /// public override EventArgs Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); reader.Skip(context); return EventArgs.Empty; } @@ -1837,6 +1878,8 @@ private class TraceParentConverter : MessagePackConverter { public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) { + context.DepthStep(); + if (reader.ReadArrayHeader() != 2) { throw new NotSupportedException("Unexpected array length."); @@ -1872,6 +1915,8 @@ public unsafe override void Write(ref MessagePackWriter writer, in TraceParent v throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); } + context.DepthStep(); + writer.WriteArrayHeader(2); writer.Write(value.Version); @@ -1899,7 +1944,7 @@ public unsafe override void Write(ref MessagePackWriter writer, in TraceParent v private class TopLevelPropertyBag : TopLevelPropertyBagBase { - private readonly FormatterContext formatterContext; + private readonly FormatterProfile formatterContext; private readonly IReadOnlyDictionary>? inboundUnknownProperties; /// @@ -1908,7 +1953,7 @@ private class TopLevelPropertyBag : TopLevelPropertyBagBase /// /// The serializer options to use for this data. /// The map of unrecognized inbound properties. - internal TopLevelPropertyBag(FormatterContext userDataContext, IReadOnlyDictionary> inboundUnknownProperties) + internal TopLevelPropertyBag(FormatterProfile userDataContext, IReadOnlyDictionary> inboundUnknownProperties) : base(isOutbound: false) { this.formatterContext = userDataContext; @@ -1920,7 +1965,7 @@ internal TopLevelPropertyBag(FormatterContext userDataContext, IReadOnlyDictiona /// for an outbound message. /// /// The serializer options to use for this data. - internal TopLevelPropertyBag(FormatterContext formatterContext) + internal TopLevelPropertyBag(FormatterProfile formatterContext) : base(isOutbound: true) { this.formatterContext = formatterContext; @@ -1952,10 +1997,8 @@ internal void WriteProperties(ref MessagePackWriter writer) { foreach (KeyValuePair entry in this.OutboundProperties) { - ITypeShape shape = this.formatterContext.ShapeProvider.Resolve(entry.Value.DeclaredType); - writer.Write(entry.Key); - this.formatterContext.Serializer.SerializeObject(ref writer, entry.Value.Value, shape); + this.formatterContext.SerializeObject(ref writer, entry.Value.Value, entry.Value.DeclaredType); } } } @@ -1971,8 +2014,7 @@ protected internal override bool TryGetTopLevelProperty(string name, [MaybeNu if (this.inboundUnknownProperties.TryGetValue(name, out ReadOnlySequence serializedValue) is true) { - var reader = new MessagePackReader(serializedValue); - value = this.formatterContext.Serializer.Deserialize(ref reader, this.formatterContext.ShapeProvider); + value = this.formatterContext.Deserialize(serializedValue); return true; } @@ -2064,14 +2106,13 @@ public override bool TryGetArgumentByNameOrIndex(string? name, int position, Typ return false; } - var reader = new MessagePackReader(msgpackArgument); using (this.formatter.TrackDeserialization(this)) { try { - value = this.formatter.userDataContext.Serializer.DeserializeObject( - ref reader, - this.formatter.userDataContext.ShapeProvider.Resolve(typeHint ?? typeof(object))); + value = this.formatter.userDataContext.DeserializeObject( + msgpackArgument, + typeHint ?? typeof(object)); return true; } @@ -2104,14 +2145,14 @@ protected override void ReleaseBuffers() private partial class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention { private readonly NerdbankMessagePackFormatter formatter; - private readonly FormatterContext serializerOptions; + private readonly FormatterProfile formatterContext; private Exception? resultDeserializationException; - internal JsonRpcResult(NerdbankMessagePackFormatter formatter, FormatterContext serializationOptions) + internal JsonRpcResult(NerdbankMessagePackFormatter formatter, FormatterProfile serializationOptions) { this.formatter = formatter; - this.serializerOptions = serializationOptions; + this.formatterContext = serializationOptions; } public ReadOnlySequence OriginalMessagePack { get; internal set; } @@ -2127,7 +2168,7 @@ public override T GetResult() return this.MsgPackResult.IsEmpty ? (T)this.Result! - : this.serializerOptions.Serializer.Deserialize(this.MsgPackResult, this.serializerOptions.ShapeProvider) + : this.formatterContext.Deserialize(this.MsgPackResult) ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); } @@ -2135,14 +2176,11 @@ protected internal override void SetExpectedResultType(Type resultType) { Verify.Operation(!this.MsgPackResult.IsEmpty, "Result is no longer available or has already been deserialized."); - var reader = new MessagePackReader(this.MsgPackResult); try { using (this.formatter.TrackDeserialization(this)) { - this.Result = this.serializerOptions.Serializer.DeserializeObject( - ref reader, - this.serializerOptions.ShapeProvider.Resolve(resultType)); + this.Result = this.formatterContext.DeserializeObject(this.MsgPackResult, resultType); } this.MsgPackResult = default; @@ -2161,22 +2199,22 @@ protected override void ReleaseBuffers() this.OriginalMessagePack = default; } - protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.serializerOptions); + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatterContext); } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] private class JsonRpcError : JsonRpcErrorBase, IJsonRpcMessagePackRetention { - private readonly FormatterContext serializerOptions; + private readonly FormatterProfile formatterContext; - public JsonRpcError(FormatterContext serializerOptions) + public JsonRpcError(FormatterProfile serializerOptions) { - this.serializerOptions = serializerOptions; + this.formatterContext = serializerOptions; } public ReadOnlySequence OriginalMessagePack { get; internal set; } - protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.serializerOptions); + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatterContext); protected override void ReleaseBuffers() { @@ -2189,14 +2227,13 @@ protected override void ReleaseBuffers() this.OriginalMessagePack = default; } - [DataContract] internal new class ErrorDetail : Protocol.JsonRpcError.ErrorDetail { - private readonly FormatterContext serializerOptions; + private readonly FormatterProfile formatterContext; - internal ErrorDetail(FormatterContext serializerOptions) + internal ErrorDetail(FormatterProfile serializerOptions) { - this.serializerOptions = serializerOptions ?? throw new ArgumentNullException(nameof(serializerOptions)); + this.formatterContext = serializerOptions ?? throw new ArgumentNullException(nameof(serializerOptions)); } internal ReadOnlySequence MsgPackData { get; set; } @@ -2209,12 +2246,9 @@ internal ErrorDetail(FormatterContext serializerOptions) return this.Data; } - var reader = new MessagePackReader(this.MsgPackData); try { - return this.serializerOptions.Serializer.DeserializeObject( - ref reader, - this.serializerOptions.ShapeProvider.Resolve(dataType)) + return this.formatterContext.DeserializeObject(this.MsgPackData, dataType) ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); } catch (MessagePackSerializationException) @@ -2224,7 +2258,7 @@ internal ErrorDetail(FormatterContext serializerOptions) { // return MessagePackSerializer.Deserialize(this.MsgPackData, this.serializerOptions.WithResolver(PrimitiveObjectResolver.Instance)); // TODO: Which Shape Provider to use? - return this.serializerOptions.Serializer.Deserialize(this.MsgPackData, this.serializerOptions.ShapeProvider); + return this.formatterContext.Deserialize(this.MsgPackData); } catch (MessagePackSerializationException) { diff --git a/src/StreamJsonRpc/Protocol/JsonRpcError.cs b/src/StreamJsonRpc/Protocol/JsonRpcError.cs index 9842d5719..244beea36 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcError.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcError.cs @@ -71,7 +71,8 @@ public override string ToString() /// Describes the error. /// [DataContract] - public class ErrorDetail + [GenerateShape] + public partial class ErrorDetail { /// /// Gets or sets a number that indicates the error type that occurred. diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index b0e10234a..1e40c9187 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -11,15 +11,15 @@ public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) protected override void InitializeFormattersAndHandlers() { NerdbankMessagePackFormatter serverFormatter = new(); - serverFormatter.SetFormatterContext(ConfigureContext); + serverFormatter.SetFormatterProfile(ConfigureContext); NerdbankMessagePackFormatter clientFormatter = new(); - clientFormatter.SetFormatterContext(ConfigureContext); + clientFormatter.SetFormatterProfile(ConfigureContext); this.serverMessageFormatter = serverFormatter; this.clientMessageFormatter = clientFormatter; - static void ConfigureContext(NerdbankMessagePackFormatter.FormatterContextBuilder contextBuilder) + static void ConfigureContext(NerdbankMessagePackFormatter.FormatterProfileBuilder contextBuilder) { contextBuilder.RegisterAsyncEnumerableType, int>(); } diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 277ee3287..2de5c5cd2 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -395,15 +395,15 @@ protected override void InitializeFormattersAndHandlers( serverMessageFormatter = new NerdbankMessagePackFormatter(); clientMessageFormatter = new NerdbankMessagePackFormatter(); - ((NerdbankMessagePackFormatter)serverMessageFormatter).SetFormatterContext(Configure); - ((NerdbankMessagePackFormatter)clientMessageFormatter).SetFormatterContext(Configure); + ((NerdbankMessagePackFormatter)serverMessageFormatter).SetFormatterProfile(Configure); + ((NerdbankMessagePackFormatter)clientMessageFormatter).SetFormatterProfile(Configure); serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter); clientMessageHandler = controlledFlushingClient ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); - static void Configure(NerdbankMessagePackFormatter.FormatterContextBuilder b) + static void Configure(NerdbankMessagePackFormatter.FormatterProfileBuilder b) { b.RegisterConverter(new UnserializableTypeConverter()); b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs index 9a62e6a5b..519f0bf71 100644 --- a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -131,7 +131,7 @@ public async Task BasicJsonRpc() [Fact] public void Resolver_RequestArgInArray() { - this.Formatter.SetFormatterContext(b => + this.Formatter.SetFormatterProfile(b => { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); @@ -154,7 +154,7 @@ public void Resolver_RequestArgInArray() [Fact] public void Resolver_RequestArgInNamedArgs_AnonymousType() { - this.Formatter.SetFormatterContext(b => + this.Formatter.SetFormatterProfile(b => { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); @@ -177,7 +177,7 @@ public void Resolver_RequestArgInNamedArgs_AnonymousType() [Fact] public void Resolver_RequestArgInNamedArgs_DataContractObject() { - this.Formatter.SetFormatterContext(b => + this.Formatter.SetFormatterProfile(b => { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); @@ -202,7 +202,7 @@ public void Resolver_RequestArgInNamedArgs_DataContractObject() [Fact] public void Resolver_RequestArgInNamedArgs_NonDataContractObject() { - this.Formatter.SetFormatterContext(b => + this.Formatter.SetFormatterProfile(b => { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); @@ -243,7 +243,7 @@ public void Resolver_RequestArgInNamedArgs_NullObject() [Fact] public void Resolver_Result() { - this.Formatter.SetFormatterContext(b => + this.Formatter.SetFormatterProfile(b => { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); @@ -264,7 +264,7 @@ public void Resolver_Result() [Fact] public void Resolver_ErrorData() { - this.Formatter.SetFormatterContext(b => + this.Formatter.SetFormatterProfile(b => { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); From 0c06ae50bbbfd91ff9f204a26c44a417c7ae29be Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Tue, 24 Dec 2024 01:30:09 +0000 Subject: [PATCH 14/25] Adopt MessagePackString and configure test formatters. --- Directory.Packages.props | 2 +- ...rdbankMessagePackFormatter.CommonString.cs | 107 ---- .../NerdbankMessagePackFormatter.Constants.cs | 60 +++ ...nkMessagePackFormatter.FormatterProfile.cs | 62 +-- ...ePackFormatter.FormatterProfileBuilder.cs} | 4 +- ...Formatter.MessagePackFormatterConverter.cs | 70 +++ ...bankMessagePackFormatter.ToStringHelper.cs | 51 ++ .../NerdbankMessagePackFormatter.cs | 487 ++++++------------ ...nkMessagePackFormatterProfileExtensions.cs | 60 +++ ...AsyncEnumerableNerdbankMessagePackTests.cs | 13 +- .../AsyncEnumerableTests.cs | 4 +- ...xPipeMarshalingNerdbankMessagePackTests.cs | 32 +- .../JsonRpcNerdbankMessagePackLengthTests.cs | 5 +- test/StreamJsonRpc.Tests/JsonRpcTests.cs | 13 +- ...arshalableProxyNerdbankMessagePackTests.cs | 28 +- .../MarshalableProxyTests.cs | 20 +- .../NerdbankMessagePackFormatterTests.cs | 6 + 17 files changed, 519 insertions(+), 505 deletions(-) delete mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs rename src/StreamJsonRpc/{NerdbankMessagePackFormatter.FormatterContextBuilder.cs => NerdbankMessagePackFormatter.FormatterProfileBuilder.cs} (98%) create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 6a87cb5f6..bd3c7d2a2 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -22,7 +22,7 @@ - + diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs deleted file mode 100644 index 9b03c9158..000000000 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System.Diagnostics; -using NBMP = Nerdbank.MessagePack; - -namespace StreamJsonRpc; - -/// -/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). -/// -public partial class NerdbankMessagePackFormatter -{ - [DebuggerDisplay("{" + nameof(Value) + "}")] - private struct CommonString - { - internal CommonString(string value) - { - Requires.Argument(value.Length > 0 && value.Length <= 16, nameof(value), "Length must be >0 and <=16."); - this.Value = value; - ReadOnlyMemory encodedBytes = MessagePack.Internal.CodeGenHelpers.GetEncodedStringBytes(value); - this.EncodedBytes = encodedBytes; - - ReadOnlySpan span = this.EncodedBytes.Span.Slice(1); - this.Key = MessagePack.Internal.AutomataKeyGen.GetKey(ref span); // header is 1 byte because string length <= 16 - this.Key2 = span.Length > 0 ? (ulong?)MessagePack.Internal.AutomataKeyGen.GetKey(ref span) : null; - } - - /// - /// Gets the original string. - /// - internal string Value { get; } - - /// - /// Gets the 64-bit integer that represents the string without decoding it. - /// - private ulong Key { get; } - - /// - /// Gets the next 64-bit integer that represents the string without decoding it. - /// - private ulong? Key2 { get; } - - /// - /// Gets the messagepack header and UTF-8 bytes for this string. - /// - private ReadOnlyMemory EncodedBytes { get; } - - /// - /// Writes out the messagepack binary for this common string, if it matches the given value. - /// - /// The writer to use. - /// The value to be written, if it matches this . - /// if matches this and it was written; otherwise. - internal bool TryWrite(ref NBMP::MessagePackWriter writer, string value) - { - if (value == this.Value) - { - this.Write(ref writer); - return true; - } - - return false; - } - - internal readonly void Write(ref NBMP::MessagePackWriter writer) => writer.WriteRaw(this.EncodedBytes.Span); - - /// - /// Checks whether a span of UTF-8 bytes equal this common string. - /// - /// The UTF-8 string. - /// if the UTF-8 bytes are the encoding of this common string; otherwise. - internal readonly bool TryRead(ReadOnlySpan utf8String) - { - if (utf8String.Length != this.EncodedBytes.Length - 1) - { - return false; - } - - ulong key1 = MessagePack.Internal.AutomataKeyGen.GetKey(ref utf8String); - if (key1 != this.Key) - { - return false; - } - - if (utf8String.Length > 0) - { - if (!this.Key2.HasValue) - { - return false; - } - - ulong key2 = MessagePack.Internal.AutomataKeyGen.GetKey(ref utf8String); - if (key2 != this.Key2.Value) - { - return false; - } - } - else if (this.Key2.HasValue) - { - return false; - } - - return true; - } - } -} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs new file mode 100644 index 000000000..52699eff1 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics; +using Nerdbank.MessagePack; +using StreamJsonRpc.Protocol; +using NBMP = Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// The constant "jsonrpc", in its various forms. + /// + private static readonly MessagePackString VersionPropertyName = new(Constants.jsonrpc); + + /// + /// The constant "id", in its various forms. + /// + private static readonly MessagePackString IdPropertyName = new(Constants.id); + + /// + /// The constant "method", in its various forms. + /// + private static readonly MessagePackString MethodPropertyName = new(Constants.Request.method); + + /// + /// The constant "result", in its various forms. + /// + private static readonly MessagePackString ResultPropertyName = new(Constants.Result.result); + + /// + /// The constant "error", in its various forms. + /// + private static readonly MessagePackString ErrorPropertyName = new(Constants.Error.error); + + /// + /// The constant "params", in its various forms. + /// + private static readonly MessagePackString ParamsPropertyName = new(Constants.Request.@params); + + /// + /// The constant "traceparent", in its various forms. + /// + private static readonly MessagePackString TraceParentPropertyName = new(Constants.Request.traceparent); + + /// + /// The constant "tracestate", in its various forms. + /// + private static readonly MessagePackString TraceStatePropertyName = new(Constants.Request.tracestate); + + /// + /// The constant "2.0", in its various forms. + /// + private static readonly MessagePackString Version2 = new("2.0"); +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs index 14ce67500..6edcbfa5a 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs @@ -3,7 +3,6 @@ using Nerdbank.MessagePack; using PolyType; -using PolyType.Abstractions; namespace StreamJsonRpc; @@ -12,58 +11,21 @@ namespace StreamJsonRpc; /// public sealed partial class NerdbankMessagePackFormatter { - internal class FormatterProfile(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) + /// + /// Initializes a new instance of the class. + /// + /// The MessagePack serializer to use. + /// The type shape provider to use. + public class FormatterProfile(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) { + /// + /// Gets the MessagePack serializer. + /// internal MessagePackSerializer Serializer => serializer; + /// + /// Gets the type shape provider. + /// internal ITypeShapeProvider ShapeProvider => shapeProvider; - - public T? Deserialize(ref MessagePackReader reader, CancellationToken cancellationToken = default) - { - return serializer.Deserialize(ref reader, shapeProvider, cancellationToken); - } - - public T Deserialize(in RawMessagePack pack, CancellationToken cancellationToken = default) - { - // TODO: Improve the exception - return serializer.Deserialize(pack, shapeProvider, cancellationToken) - ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); - } - - public object? DeserializeObject(in RawMessagePack pack, Type objectType, CancellationToken cancellationToken = default) - { - MessagePackReader reader = new(pack); - return serializer.DeserializeObject( - ref reader, - shapeProvider.Resolve(objectType), - cancellationToken); - } - - public void Serialize(ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) - { - serializer.Serialize(ref writer, value, shapeProvider, cancellationToken); - } - - public void SerializeObject(ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) - { - if (value is null) - { - writer.WriteNil(); - return; - } - - serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(objectType), cancellationToken); - } - - public void SerializeObject(ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) - { - if (value is null) - { - writer.WriteNil(); - return; - } - - serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(value.GetType()), cancellationToken); - } } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs similarity index 98% rename from src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs rename to src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs index 19fe149ea..24f0c3780 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs @@ -166,16 +166,18 @@ public void RegisterRpcMarshalableType() attribute)!; this.baseProfile.Serializer.RegisterConverter(converter); + return; } // TODO: Throw? + throw new NotSupportedException(); } /// /// Builds the formatter profile. /// /// The built formatter profile. - internal FormatterProfile Build() + public FormatterProfile Build() { if (this.typeShapeProvidersBuilder is null || this.typeShapeProvidersBuilder.Count < 1) { diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs new file mode 100644 index 000000000..d2de639cd --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class MessagePackFormatterConverter : IFormatterConverter + { + private readonly FormatterProfile formatterContext; + + internal MessagePackFormatterConverter(FormatterProfile formatterContext) + { + this.formatterContext = formatterContext; + } + +#pragma warning disable CS8766 // This method may in fact return null, and no one cares. + public object? Convert(object value, Type type) +#pragma warning restore CS8766 + { + return this.formatterContext.DeserializeObject((RawMessagePack)value, type); + } + + public object Convert(object value, TypeCode typeCode) + { + return typeCode switch + { + TypeCode.Object => this.formatterContext.Deserialize((RawMessagePack)value), + _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), + }; + } + + public bool ToBoolean(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public byte ToByte(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public char ToChar(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public DateTime ToDateTime(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public decimal ToDecimal(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public double ToDouble(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public short ToInt16(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public int ToInt32(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public long ToInt64(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public sbyte ToSByte(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public float ToSingle(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public string? ToString(object value) => value is null ? null : this.formatterContext.Deserialize((RawMessagePack)value); + + public ushort ToUInt16(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public uint ToUInt32(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + + public ulong ToUInt64(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs new file mode 100644 index 000000000..ba740d2b5 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// A recyclable object that can serialize a message to JSON on demand. + /// + /// + /// In perf traces, creation of this object used to show up as one of the most allocated objects. + /// It is used even when tracing isn't active. So we changed its design to be reused, + /// since its lifetime is only required during a synchronous call to a trace API. + /// + private class ToStringHelper + { + private ReadOnlySequence? encodedMessage; + private string? jsonString; + + public override string ToString() + { + Verify.Operation(this.encodedMessage.HasValue, "This object has not been activated. It may have already been recycled."); + + return this.jsonString ??= MessagePackSerializer.ConvertToJson(this.encodedMessage.Value); + } + + /// + /// Initializes this object to represent a message. + /// + internal void Activate(ReadOnlySequence encodedMessage) + { + this.encodedMessage = encodedMessage; + } + + /// + /// Cleans out this object to release memory and ensure throws if someone uses it after deactivation. + /// + internal void Deactivate() + { + this.encodedMessage = null; + this.jsonString = null; + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index cc53867d0..4243ae456 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -29,54 +29,9 @@ namespace StreamJsonRpc; /// /// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. /// -[System.Diagnostics.CodeAnalysis.SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Suppressed for Development")] +[SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Suppressed for Development")] public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory { - /// - /// The constant "jsonrpc", in its various forms. - /// - private static readonly CommonString VersionPropertyName = new(Constants.jsonrpc); - - /// - /// The constant "id", in its various forms. - /// - private static readonly CommonString IdPropertyName = new(Constants.id); - - /// - /// The constant "method", in its various forms. - /// - private static readonly CommonString MethodPropertyName = new(Constants.Request.method); - - /// - /// The constant "result", in its various forms. - /// - private static readonly CommonString ResultPropertyName = new(Constants.Result.result); - - /// - /// The constant "error", in its various forms. - /// - private static readonly CommonString ErrorPropertyName = new(Constants.Error.error); - - /// - /// The constant "params", in its various forms. - /// - private static readonly CommonString ParamsPropertyName = new(Constants.Request.@params); - - /// - /// The constant "traceparent", in its various forms. - /// - private static readonly CommonString TraceParentPropertyName = new(Constants.Request.traceparent); - - /// - /// The constant "tracestate", in its various forms. - /// - private static readonly CommonString TraceStatePropertyName = new(Constants.Request.tracestate); - - /// - /// The constant "2.0", in its various forms. - /// - private static readonly CommonString Version2 = new("2.0"); - /// /// A cache of property names to declared property types, indexed by their containing parameter object type. /// @@ -88,7 +43,7 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa /// /// The serializer context to use for top-level RPC messages. /// - private readonly FormatterProfile rpcContext; + private readonly FormatterProfile rpcProfile; private readonly ProgressConverterResolver progressConverterResolver; @@ -105,7 +60,7 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa /// /// The serializer to use for user data (e.g. arguments, return values and errors). /// - private FormatterProfile userDataContext; + private FormatterProfile userDataProfile; /// /// Initializes a new instance of the class. @@ -116,10 +71,10 @@ public NerdbankMessagePackFormatter() MessagePackSerializer serializer = new() { InternStrings = true, - SerializeDefaultValues = false, + SerializeDefaultValues = true, }; - serializer.RegisterConverter(new RequestIdConverter()); + serializer.RegisterConverter(RequestIdConverter.Instance); serializer.RegisterConverter(new JsonRpcMessageConverter(this)); serializer.RegisterConverter(new JsonRpcRequestConverter(this)); serializer.RegisterConverter(new JsonRpcResultConverter(this)); @@ -127,7 +82,7 @@ public NerdbankMessagePackFormatter() serializer.RegisterConverter(new JsonRpcErrorDetailConverter(this)); serializer.RegisterConverter(new TraceParentConverter()); - this.rpcContext = new FormatterProfile(serializer, ShapeProvider_StreamJsonRpc.Default); + this.rpcProfile = new FormatterProfile(serializer, ShapeProvider_StreamJsonRpc.Default); // Create the specialized formatters/resolvers that we will inject into the chain for user data. this.progressConverterResolver = new ProgressConverterResolver(this); @@ -135,16 +90,23 @@ public NerdbankMessagePackFormatter() this.pipeConverterResolver = new PipeConverterResolver(this); this.exceptionResolver = new MessagePackExceptionConverterResolver(this); - FormatterProfile userDataContext = new( - new() - { - InternStrings = true, - SerializeDefaultValues = false, - }, - ReflectionTypeShapeProvider.Default); + // Create a serializer for user data. + MessagePackSerializer userSerializer = new() + { + InternStrings = true, + SerializeDefaultValues = true, + }; + + // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. + // We preset this one in user data because $/cancellation methods can carry RequestId values as arguments. + userSerializer.RegisterConverter(RequestIdConverter.Instance); + + // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. + userSerializer.RegisterConverter(RawMessagePackFormatter.Instance); + userSerializer.RegisterConverter(EventArgsConverter.Instance); - this.MassageUserDataContext(userDataContext); - this.userDataContext = userDataContext; + this.userDataProfile = new FormatterProfile(userSerializer, ReflectionTypeShapeProvider.Default); + this.ProfileBuilder = new FormatterProfileBuilder(this, this.userDataProfile); } private interface IJsonRpcMessagePackRetention @@ -158,27 +120,40 @@ private interface IJsonRpcMessagePackRetention ReadOnlySequence OriginalMessagePack { get; } } + /// + /// Gets the profile builder for the formatter. + /// + public FormatterProfileBuilder ProfileBuilder { get; } + + /// + /// Sets the formatter profile. + /// + /// The formatter profile to set. + public void SetFormatterProfile(FormatterProfile profile) + { + Requires.NotNull(profile, nameof(profile)); + + this.userDataProfile = profile; + } + /// /// Configures the serialization context for user data with the specified configuration action. /// /// The action to configure the serialization context. - public void SetFormatterProfile(Action configure) + public void SetFormatterProfile(Func configure) { Requires.NotNull(configure, nameof(configure)); - var builder = new FormatterProfileBuilder(this, this.userDataContext); - configure(builder); + var builder = new FormatterProfileBuilder(this, this.userDataProfile); + FormatterProfile profile = configure(builder); - FormatterProfile context = builder.Build(); - this.MassageUserDataContext(context); - - this.userDataContext = context; + this.userDataProfile = profile; } /// public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) { - JsonRpcMessage message = this.rpcContext.Deserialize(contentBuffer); + JsonRpcMessage message = this.rpcProfile.Deserialize(contentBuffer); IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; this.deserializationToStringHelper.Activate(contentBuffer); @@ -214,7 +189,7 @@ public void Serialize(IBufferWriter contentBuffer, JsonRpcMessage message) var writer = new MessagePackWriter(contentBuffer); try { - this.rpcContext.Serialize(ref writer, message); + this.rpcProfile.Serialize(ref writer, message); writer.Flush(); } catch (Exception ex) @@ -232,10 +207,10 @@ public object GetJsonText(JsonRpcMessage message) => message is IJsonRpcMessageP Protocol.JsonRpcRequest IJsonRpcMessageFactory.CreateRequestMessage() => new OutboundJsonRpcRequest(this); /// - Protocol.JsonRpcError IJsonRpcMessageFactory.CreateErrorMessage() => new JsonRpcError(this.userDataContext); + Protocol.JsonRpcError IJsonRpcMessageFactory.CreateErrorMessage() => new JsonRpcError(this.rpcProfile); /// - Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this, this.rpcContext); + Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this, this.rpcProfile); void IJsonRpcFormatterTracingCallbacks.OnSerializationComplete(JsonRpcMessage message, ReadOnlySequence encodedMessage) { @@ -376,26 +351,15 @@ private static ReadOnlySequence GetSliceForNextToken(ref MessagePackReader /// The decoded string. private static unsafe string ReadProtocolVersion(ref MessagePackReader reader) { - if (!reader.TryReadStringSpan(out ReadOnlySpan valueBytes)) - { - ReadOnlySequence? valueSequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); - valueBytes = valueSequence.Value.IsSingleSegment - ? valueSequence.Value.First.Span - : valueSequence.Value.ToArray(); - } - // Recognize "2.0" since we expect it and can avoid decoding and allocating a new string for it. - if (Version2.TryRead(valueBytes)) + if (Version2.TryRead(ref reader)) { return Version2.Value; } else { - // It wasn't the expected value, so decode it. - fixed (byte* pValueBytes = valueBytes) - { - return Encoding.UTF8.GetString(pValueBytes, valueBytes.Length); - } + // TODO: Should throw? + return reader.ReadString() ?? string.Empty; } } @@ -404,11 +368,8 @@ private static unsafe string ReadProtocolVersion(ref MessagePackReader reader) /// private static void WriteProtocolVersionPropertyAndValue(ref MessagePackWriter writer, string version) { - VersionPropertyName.Write(ref writer); - if (!Version2.TryWrite(ref writer, version)) - { - writer.Write(version); - } + writer.WriteRaw(VersionPropertyName.MsgPack.Span); + writer.WriteRaw(Version2.MsgPack.Span); } private static void ReadUnknownProperty(ref MessagePackReader reader, in SerializationContext context, ref Dictionary>? topLevelProperties, ReadOnlySpan stringKey) @@ -422,116 +383,14 @@ private static void ReadUnknownProperty(ref MessagePackReader reader, in Seriali topLevelProperties.Add(name, GetSliceForNextToken(ref reader, context)); } - /// - /// Takes the user-supplied resolver for their data types and prepares the wrapping options - /// and the dynamic object wrapper for serialization. - /// - /// The options for user data that is supplied by the user (or the default). - private void MassageUserDataContext(FormatterProfile userDataContext) - { - // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. - userDataContext.Serializer.RegisterConverter(RequestIdConverter.Instance); - userDataContext.Serializer.RegisterConverter(EventArgsConverter.Instance); - } - - private class MessagePackFormatterConverter : IFormatterConverter - { - private readonly FormatterProfile formatterContext; - - internal MessagePackFormatterConverter(FormatterProfile formatterContext) - { - this.formatterContext = formatterContext; - } - -#pragma warning disable CS8766 // This method may in fact return null, and no one cares. - public object? Convert(object value, Type type) -#pragma warning restore CS8766 - { - return this.formatterContext.DeserializeObject((RawMessagePack)value, type); - } - - public object Convert(object value, TypeCode typeCode) - { - return typeCode switch - { - TypeCode.Object => this.formatterContext.Deserialize((RawMessagePack)value), - _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), - }; - } - - public bool ToBoolean(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public byte ToByte(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public char ToChar(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public DateTime ToDateTime(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public decimal ToDecimal(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public double ToDouble(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public short ToInt16(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public int ToInt32(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public long ToInt64(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public sbyte ToSByte(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public float ToSingle(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public string? ToString(object value) => value is null ? null : this.formatterContext.Deserialize((RawMessagePack)value); - - public ushort ToUInt16(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public uint ToUInt32(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - - public ulong ToUInt64(object value) => this.formatterContext.Deserialize((RawMessagePack)value); - } - - /// - /// A recyclable object that can serialize a message to JSON on demand. - /// - /// - /// In perf traces, creation of this object used to show up as one of the most allocated objects. - /// It is used even when tracing isn't active. So we changed its design to be reused, - /// since its lifetime is only required during a synchronous call to a trace API. - /// - private class ToStringHelper + private class RequestIdConverter : MessagePackConverter { - private ReadOnlySequence? encodedMessage; - private string? jsonString; - - public override string ToString() - { - Verify.Operation(this.encodedMessage.HasValue, "This object has not been activated. It may have already been recycled."); - - return this.jsonString ??= MessagePackSerializer.ConvertToJson(this.encodedMessage.Value); - } + internal static readonly RequestIdConverter Instance = new(); - /// - /// Initializes this object to represent a message. - /// - internal void Activate(ReadOnlySequence encodedMessage) + private RequestIdConverter() { - this.encodedMessage = encodedMessage; } - /// - /// Cleans out this object to release memory and ensure throws if someone uses it after deactivation. - /// - internal void Deactivate() - { - this.encodedMessage = null; - this.jsonString = null; - } - } - - private class RequestIdConverter : MessagePackConverter - { - internal static readonly RequestIdConverter Instance = new(); - public override RequestId Read(ref MessagePackReader reader, SerializationContext context) { context.DepthStep(); @@ -568,6 +427,25 @@ public override void Write(ref MessagePackWriter writer, in RequestId value, Ser """)?.AsObject(); } + private class RawMessagePackFormatter : MessagePackConverter + { + internal static readonly RawMessagePackFormatter Instance = new(); + + private RawMessagePackFormatter() + { + } + + public override RawMessagePack Read(ref MessagePackReader reader, SerializationContext context) + { + return new RawMessagePack(reader.ReadRaw(context)); + } + + public override void Write(ref MessagePackWriter writer, in RawMessagePack value, SerializationContext context) + { + writer.WriteRaw(value); + } + } + private class ProgressConverterResolver { private readonly NerdbankMessagePackFormatter mainFormatter; @@ -655,7 +533,7 @@ internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) } Assumes.NotNull(this.formatter.JsonRpc); - RawMessagePack token = reader.ReadRaw(context); + RawMessagePack token = (RawMessagePack)reader.ReadRaw(context); bool clientRequiresNamedArgs = this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; return (TClass)this.formatter.FormatterProgressTracker.CreateProgress(this.formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); } @@ -697,11 +575,15 @@ public MessagePackConverter GetConverter() if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), new object[] { this.mainFormatter }); + converter = (MessagePackConverter?)Activator.CreateInstance( + typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), + [this.mainFormatter]); } else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), new object[] { this.mainFormatter }); + converter = (MessagePackConverter?)Activator.CreateInstance( + typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), + [this.mainFormatter]); } // TODO: Improve Exception @@ -719,40 +601,32 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF /// /// The constant "token", in its various forms. /// - private static readonly CommonString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); + private static readonly MessagePackString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); /// /// The constant "values", in its various forms. /// - private static readonly CommonString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); + private static readonly MessagePackString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) { - context.DepthStep(); - if (reader.TryReadNil()) { return default; } + context.DepthStep(); + RawMessagePack? token = default; IReadOnlyList? initialElements = null; int propertyCount = reader.ReadMapHeader(); for (int i = 0; i < propertyCount; i++) { - if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + if (TokenPropertyName.TryRead(ref reader)) { - ReadOnlySequence keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); - stringKey = keySequence.IsSingleSegment - ? keySequence.First.Span - : keySequence.ToArray(); + token = (RawMessagePack)reader.ReadRaw(context); } - - if (TokenPropertyName.TryRead(stringKey)) - { - token = reader.ReadRaw(context); - } - else if (ValuesPropertyName.TryRead(stringKey)) + else if (ValuesPropertyName.TryRead(ref reader)) { initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); } @@ -762,7 +636,7 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF } } - return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? token : null, initialElements); + return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? (object)token : null, initialElements); } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] @@ -785,32 +659,32 @@ internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter } else { - (IReadOnlyList Elements, bool Finished) prefetched = value.TearOffPrefetchedElements(); + (IReadOnlyList elements, bool finished) = value.TearOffPrefetchedElements(); long token = mainFormatter.EnumerableTracker.GetToken(value); int propertyCount = 0; - if (prefetched.Elements.Count > 0) + if (elements.Count > 0) { propertyCount++; } - if (!prefetched.Finished) + if (!finished) { propertyCount++; } writer.WriteMapHeader(propertyCount); - if (!prefetched.Finished) + if (!finished) { - writer.Write(MessageFormatterEnumerableTracker.TokenPropertyName); + writer.WriteRaw(TokenPropertyName.MsgPack.Span); writer.Write(token); } - if (prefetched.Elements.Count > 0) + if (elements.Count > 0) { - writer.Write(MessageFormatterEnumerableTracker.ValuesPropertyName); - context.GetConverter>(context.TypeShapeProvider).Write(ref writer, prefetched.Elements, context); + writer.WriteRaw(ValuesPropertyName.MsgPack.Span); + context.GetConverter>(context.TypeShapeProvider).Write(ref writer, elements, context); } } } @@ -1044,7 +918,7 @@ private class RpcMarshalableConverter( context.DepthStep(); MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter - .rpcContext + .rpcProfile .Deserialize( ref reader, context.CancellationToken); @@ -1064,7 +938,7 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat else { MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); - formatter.rpcContext.Serialize(ref writer, token, context.CancellationToken); + formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); } } @@ -1143,7 +1017,7 @@ private partial class ExceptionConverter(NerdbankMessagePackFormatter formatt } // TODO: Is this the right context? - var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcContext)); + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); int memberCount = reader.ReadMapHeader(); for (int i = 0; i < memberCount; i++) { @@ -1188,13 +1062,13 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } // TODO: Is this the right context? - var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcContext)); + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); ExceptionSerializationHelpers.Serialize(value, info); writer.WriteMapHeader(info.GetSafeMemberCount()); foreach (SerializationEntry element in info.GetSafeMembers()) { writer.Write(element.Name); - formatter.rpcContext.SerializeObject( + formatter.rpcProfile.SerializeObject( ref writer, element.Value, element.ObjectType, @@ -1232,25 +1106,15 @@ internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) int propertyCount = readAhead.ReadMapHeader(); for (int i = 0; i < propertyCount; i++) { - // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. - // MessagePackFormatter: ReadOnlySpan stringKey = MessagePack.Internal.CodeGenHelpers.ReadStringSpan(ref readAhead); - if (!readAhead.TryReadStringSpan(out ReadOnlySpan stringKey)) - { - ReadOnlySequence? keySequence = readAhead.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); - stringKey = keySequence.Value.IsSingleSegment - ? keySequence.Value.First.Span - : keySequence.Value.ToArray(); - } - - if (MethodPropertyName.TryRead(stringKey)) + if (MethodPropertyName.TryRead(ref readAhead)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (ResultPropertyName.TryRead(stringKey)) + else if (ResultPropertyName.TryRead(ref readAhead)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (ErrorPropertyName.TryRead(stringKey)) + else if (ErrorPropertyName.TryRead(ref readAhead)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } @@ -1317,28 +1181,19 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) int propertyCount = reader.ReadMapHeader(); for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) { - // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. - if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) - { - ReadOnlySequence? keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); - stringKey = keySequence.Value.IsSingleSegment - ? keySequence.Value.First.Span - : keySequence.Value.ToArray(); - } - - if (VersionPropertyName.TryRead(stringKey)) + if (VersionPropertyName.TryRead(ref reader)) { result.Version = ReadProtocolVersion(ref reader); } - else if (IdPropertyName.TryRead(stringKey)) + else if (IdPropertyName.TryRead(ref reader)) { result.RequestId = context.GetConverter(null).Read(ref reader, context); } - else if (MethodPropertyName.TryRead(stringKey)) + else if (MethodPropertyName.TryRead(ref reader)) { result.Method = context.GetConverter(null).Read(ref reader, context); } - else if (ParamsPropertyName.TryRead(stringKey)) + else if (ParamsPropertyName.TryRead(ref reader)) { SequencePosition paramsTokenStartPosition = reader.Position; @@ -1375,24 +1230,24 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) result.MsgPackArguments = reader.Sequence.Slice(paramsTokenStartPosition, reader.Position); } - else if (TraceParentPropertyName.TryRead(stringKey)) + else if (TraceParentPropertyName.TryRead(ref reader)) { TraceParent traceParent = context.GetConverter(null).Read(ref reader, context); result.TraceParent = traceParent.ToString(); } - else if (TraceStatePropertyName.TryRead(stringKey)) + else if (TraceStatePropertyName.TryRead(ref reader)) { result.TraceState = ReadTraceState(ref reader, context); } else { - ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); } } if (topLevelProperties is not null) { - result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataProfile, topLevelProperties); } this.formatter.TryHandleSpecialIncomingMessage(result); @@ -1425,15 +1280,15 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ if (!value.RequestId.IsEmpty) { - IdPropertyName.Write(ref writer); + writer.WriteRaw(IdPropertyName.MsgPack.Span); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, value.RequestId, context); } - MethodPropertyName.Write(ref writer); + writer.WriteRaw(MethodPropertyName.MsgPack.Span); writer.Write(value.Method); - ParamsPropertyName.Write(ref writer); + writer.WriteRaw(ParamsPropertyName.MsgPack.Span); if (value.ArgumentsList is not null) { writer.WriteArrayHeader(value.ArgumentsList.Count); @@ -1444,14 +1299,14 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ if (value.ArgumentListDeclaredTypes is null) { - this.formatter.userDataContext.SerializeObject( + this.formatter.userDataProfile.SerializeObject( ref writer, arg, context.CancellationToken); } else { - this.formatter.userDataContext.SerializeObject( + this.formatter.userDataProfile.SerializeObject( ref writer, arg, value.ArgumentListDeclaredTypes[i], @@ -1468,7 +1323,7 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ if (value.NamedArgumentDeclaredTypes is null) { - this.formatter.userDataContext.SerializeObject( + this.formatter.userDataProfile.SerializeObject( ref writer, entry.Value, context.CancellationToken); @@ -1476,7 +1331,7 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ else { Type argType = value.NamedArgumentDeclaredTypes[entry.Key]; - this.formatter.userDataContext.SerializeObject( + this.formatter.userDataProfile.SerializeObject( ref writer, entry.Value, argType, @@ -1491,13 +1346,13 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ if (value.TraceParent?.Length > 0) { - TraceParentPropertyName.Write(ref writer); + writer.WriteRaw(TraceParentPropertyName.MsgPack.Span); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, new TraceParent(value.TraceParent), context); if (value.TraceState?.Length > 0) { - TraceStatePropertyName.Write(ref writer); + writer.WriteRaw(TraceStatePropertyName.MsgPack.Span); WriteTraceState(ref writer, value.TraceState); } } @@ -1588,7 +1443,7 @@ public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, Serial { context.DepthStep(); - var result = new JsonRpcResult(this.formatter, this.formatter.userDataContext) + var result = new JsonRpcResult(this.formatter, this.formatter.userDataProfile) { OriginalMessagePack = reader.Sequence, }; @@ -1598,36 +1453,27 @@ public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, Serial int propertyCount = reader.ReadMapHeader(); for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) { - // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. - if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) - { - ReadOnlySequence? keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); - stringKey = keySequence.Value.IsSingleSegment - ? keySequence.Value.First.Span - : keySequence.Value.ToArray(); - } - - if (VersionPropertyName.TryRead(stringKey)) + if (VersionPropertyName.TryRead(ref reader)) { result.Version = ReadProtocolVersion(ref reader); } - else if (IdPropertyName.TryRead(stringKey)) + else if (IdPropertyName.TryRead(ref reader)) { result.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (ResultPropertyName.TryRead(stringKey)) + else if (ResultPropertyName.TryRead(ref reader)) { result.MsgPackResult = GetSliceForNextToken(ref reader, context); } else { - ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); } } if (topLevelProperties is not null) { - result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataProfile, topLevelProperties); } return result; @@ -1647,10 +1493,10 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResu WriteProtocolVersionPropertyAndValue(ref writer, value.Version); - IdPropertyName.Write(ref writer); + writer.WriteRaw(IdPropertyName.MsgPack.Span); context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); - ResultPropertyName.Write(ref writer); + writer.WriteRaw(ResultPropertyName.MsgPack.Span); if (value.Result is null) { @@ -1659,11 +1505,11 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResu if (value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void)) { - this.formatter.userDataContext.SerializeObject(ref writer, value.Result, value.ResultDeclaredType, context.CancellationToken); + this.formatter.userDataProfile.SerializeObject(ref writer, value.Result, value.ResultDeclaredType, context.CancellationToken); } else { - this.formatter.userDataContext.SerializeObject(ref writer, value.Result, context.CancellationToken); + this.formatter.userDataProfile.SerializeObject(ref writer, value.Result, context.CancellationToken); } (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); @@ -1686,48 +1532,38 @@ internal JsonRpcErrorConverter(NerdbankMessagePackFormatter formatter) public override Protocol.JsonRpcError Read(ref MessagePackReader reader, SerializationContext context) { - var error = new JsonRpcError(this.formatter.rpcContext) + var error = new JsonRpcError(this.formatter.rpcProfile) { OriginalMessagePack = reader.Sequence, }; - context.DepthStep(); - Dictionary>? topLevelProperties = null; + context.DepthStep(); int propertyCount = reader.ReadMapHeader(); for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) { - // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. - if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) - { - ReadOnlySequence? keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); - stringKey = keySequence.Value.IsSingleSegment - ? keySequence.Value.First.Span - : keySequence.Value.ToArray(); - } - - if (VersionPropertyName.TryRead(stringKey)) + if (VersionPropertyName.TryRead(ref reader)) { error.Version = ReadProtocolVersion(ref reader); } - else if (IdPropertyName.TryRead(stringKey)) + else if (IdPropertyName.TryRead(ref reader)) { error.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (ErrorPropertyName.TryRead(stringKey)) + else if (ErrorPropertyName.TryRead(ref reader)) { error.Error = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } else { - ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); } } if (topLevelProperties is not null) { - error.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + error.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataProfile, topLevelProperties); } return error; @@ -1737,21 +1573,20 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro { Requires.NotNull(value!, nameof(value)); - context.DepthStep(); - var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + context.DepthStep(); int mapElementCount = 3; mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; writer.WriteMapHeader(mapElementCount); WriteProtocolVersionPropertyAndValue(ref writer, value.Version); - IdPropertyName.Write(ref writer); + writer.WriteRaw(IdPropertyName.MsgPack.Span); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, value.RequestId, context); - ErrorPropertyName.Write(ref writer); + writer.WriteRaw(ErrorPropertyName.MsgPack.Span); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, value.Error, context); @@ -1766,9 +1601,9 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro private partial class JsonRpcErrorDetailConverter : MessagePackConverter { - private static readonly CommonString CodePropertyName = new("code"); - private static readonly CommonString MessagePropertyName = new("message"); - private static readonly CommonString DataPropertyName = new("data"); + private static readonly MessagePackString CodePropertyName = new("code"); + private static readonly MessagePackString MessagePropertyName = new("message"); + private static readonly MessagePackString DataPropertyName = new("data"); private readonly NerdbankMessagePackFormatter formatter; @@ -1781,28 +1616,20 @@ public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader rea { context.DepthStep(); - var result = new JsonRpcError.ErrorDetail(this.formatter.userDataContext); + var result = new JsonRpcError.ErrorDetail(this.formatter.userDataProfile); int propertyCount = reader.ReadMapHeader(); for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) { - if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) - { - ReadOnlySequence? keySequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); - stringKey = keySequence.Value.IsSingleSegment - ? keySequence.Value.First.Span - : keySequence.Value.ToArray(); - } - - if (CodePropertyName.TryRead(stringKey)) + if (CodePropertyName.TryRead(ref reader)) { result.Code = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (MessagePropertyName.TryRead(stringKey)) + else if (MessagePropertyName.TryRead(ref reader)) { result.Message = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (DataPropertyName.TryRead(stringKey)) + else if (DataPropertyName.TryRead(ref reader)) { result.MsgPackData = GetSliceForNextToken(ref reader, context); } @@ -1824,15 +1651,15 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro writer.WriteMapHeader(3); - CodePropertyName.Write(ref writer); + writer.WriteRaw(CodePropertyName.MsgPack.Span); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, value.Code, context); - MessagePropertyName.Write(ref writer); + writer.WriteRaw(MessagePropertyName.MsgPack.Span); writer.Write(value.Message); - DataPropertyName.Write(ref writer); - this.formatter.userDataContext.Serialize(ref writer, value.Data, context.CancellationToken); + writer.WriteRaw(DataPropertyName.MsgPack.Span); + this.formatter.userDataProfile.Serialize(ref writer, value.Data, context.CancellationToken); } public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) @@ -2032,7 +1859,7 @@ internal OutboundJsonRpcRequest(NerdbankMessagePackFormatter formatter) this.formatter = formatter ?? throw new ArgumentNullException(nameof(formatter)); } - protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataContext); + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataProfile); } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] @@ -2068,9 +1895,9 @@ public override ArgumentMatchResult TryGetTypedArguments(ReadOnlySpan new TopLevelPropertyBag(this.formatter.userDataContext); + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataProfile); } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs new file mode 100644 index 000000000..16a0510c7 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using static StreamJsonRpc.NerdbankMessagePackFormatter; + +namespace StreamJsonRpc; + +internal static class NerdbankMessagePackFormatterProfileExtensions +{ + internal static T? Deserialize(this FormatterProfile profile, ref MessagePackReader reader, CancellationToken cancellationToken = default) + { + return profile.Serializer.Deserialize(ref reader, profile.ShapeProvider, cancellationToken); + } + + internal static T Deserialize(this FormatterProfile profile, in ReadOnlySequence pack, CancellationToken cancellationToken = default) + { + // TODO: Improve the exception + return profile.Serializer.Deserialize(pack, profile.ShapeProvider, cancellationToken) + ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + } + + internal static object? DeserializeObject(this FormatterProfile profile, in ReadOnlySequence pack, Type objectType, CancellationToken cancellationToken = default) + { + MessagePackReader reader = new(pack); + return profile.Serializer.DeserializeObject( + ref reader, + profile.ShapeProvider.Resolve(objectType), + cancellationToken); + } + + internal static void Serialize(this FormatterProfile profile, ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) + { + profile.Serializer.Serialize(ref writer, value, profile.ShapeProvider, cancellationToken); + } + + internal static void SerializeObject(this FormatterProfile profile, ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + profile.Serializer.SerializeObject(ref writer, value, profile.ShapeProvider.Resolve(objectType), cancellationToken); + } + + internal static void SerializeObject(this FormatterProfile profile, ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + profile.Serializer.SerializeObject(ref writer, value, profile.ShapeProvider.Resolve(value.GetType()), cancellationToken); + } +} diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index 1e40c9187..ba5b064bc 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -11,17 +11,22 @@ public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) protected override void InitializeFormattersAndHandlers() { NerdbankMessagePackFormatter serverFormatter = new(); - serverFormatter.SetFormatterProfile(ConfigureContext); + NerdbankMessagePackFormatter.FormatterProfile serverProfile = ConfigureContext(serverFormatter.ProfileBuilder); + serverFormatter.SetFormatterProfile(serverProfile); NerdbankMessagePackFormatter clientFormatter = new(); - clientFormatter.SetFormatterProfile(ConfigureContext); + NerdbankMessagePackFormatter.FormatterProfile clientProfile = ConfigureContext(clientFormatter.ProfileBuilder); + clientFormatter.SetFormatterProfile(clientProfile); this.serverMessageFormatter = serverFormatter; this.clientMessageFormatter = clientFormatter; - static void ConfigureContext(NerdbankMessagePackFormatter.FormatterProfileBuilder contextBuilder) + static NerdbankMessagePackFormatter.FormatterProfile ConfigureContext(NerdbankMessagePackFormatter.FormatterProfileBuilder profileBuilder) { - contextBuilder.RegisterAsyncEnumerableType, int>(); + profileBuilder.RegisterAsyncEnumerableType, int>(); + profileBuilder.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + return profileBuilder.Build(); } } } diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs index 7b1aed635..7e503ad57 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs @@ -371,7 +371,7 @@ public async Task Cancellation_DuringLongRunningServerMoveNext(bool useProxy) await Assert.ThrowsAnyAsync(async () => await moveNextTask).WithCancellation(this.TimeoutToken); } - [Theory] + [Theory(Timeout = 2 * 1000)] // TODO: Temporary for development [PairwiseData] public async Task Cancellation_DuringLongRunningServerBeforeReturning(bool useProxy, [CombinatorialValues(0, 1, 2, 3)] int prefetchStrategy) { @@ -543,7 +543,7 @@ public async Task AsyncIteratorThrows(int minBatchSize, int maxReadAhead, int pr Assert.Equal(Server.FailByDesignExceptionMessage, ex.Message); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task EnumerableIdDisposal() { // This test is specially arranged to create two RPC calls going opposite directions, with the same request ID. diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs index 924929ae3..f506e9f3b 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using Nerdbank.Streams; + public class DuplexPipeMarshalingNerdbankMessagePackTests : DuplexPipeMarshalingTests { public DuplexPipeMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) @@ -10,7 +12,33 @@ public DuplexPipeMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) protected override void InitializeFormattersAndHandlers() { - this.serverMessageFormatter = new NerdbankMessagePackFormatter { MultiplexingStream = this.serverMx }; - this.clientMessageFormatter = new NerdbankMessagePackFormatter { MultiplexingStream = this.clientMx }; + NerdbankMessagePackFormatter serverFormatter = new() + { + MultiplexingStream = this.serverMx, + }; + + serverFormatter.SetFormatterProfile(b => + { + b.RegisterDuplexPipeType(); + b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + return b.Build(); + }); + + NerdbankMessagePackFormatter clientFormatter = new() + { + MultiplexingStream = this.clientMx, + }; + + clientFormatter.SetFormatterProfile(b => + { + b.RegisterDuplexPipeType(); + b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + return b.Build(); + }); + + this.serverMessageFormatter = serverFormatter; + this.clientMessageFormatter = clientFormatter; } } diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 2de5c5cd2..294154beb 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -403,12 +403,15 @@ protected override void InitializeFormattersAndHandlers( ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); - static void Configure(NerdbankMessagePackFormatter.FormatterProfileBuilder b) + static NerdbankMessagePackFormatter.FormatterProfile Configure(NerdbankMessagePackFormatter.FormatterProfileBuilder b) { b.RegisterConverter(new UnserializableTypeConverter()); b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); b.RegisterConverter(new CustomExtensionConverter()); + b.RegisterStreamType(); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + return b.Build(); } } diff --git a/test/StreamJsonRpc.Tests/JsonRpcTests.cs b/test/StreamJsonRpc.Tests/JsonRpcTests.cs index 3ee01caf1..ced85ca53 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcTests.cs @@ -10,6 +10,7 @@ using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; using Newtonsoft.Json.Linq; +using PolyType; using JsonNET = Newtonsoft.Json; using STJ = System.Text.Json.Serialization; @@ -1118,7 +1119,7 @@ public async Task CancelMayStillReturnResultFromServer() } } - [Theory, PairwiseData] + [Theory(Timeout = 2 * 1000), PairwiseData] // TODO: temporary for development public async Task CancelMayStillReturnErrorFromServer(ExceptionProcessing exceptionStrategy) { this.clientRpc.AllowModificationWhileListening = true; @@ -1946,7 +1947,7 @@ public void AddLocalRpcMethod_String_MethodInfo_Object_NullTargetForInstanceMeth Assert.Throws(() => this.serverRpc.AddLocalRpcMethod("biz.bar", methodInfo, null)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ServerMethodIsCanceledWhenConnectionDrops() { this.ReinitializeRpcWithoutListening(); @@ -1962,7 +1963,7 @@ public async Task ServerMethodIsCanceledWhenConnectionDrops() Assert.Null(oce.InnerException); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ServerMethodIsNotCanceledWhenConnectionDrops() { Assert.False(this.serverRpc.CancelLocallyInvokedMethodsWhenConnectionIsClosed); @@ -3315,7 +3316,7 @@ public class BaseClass } #pragma warning disable CA1801 // use all parameters - public class Server : BaseClass, IServerDerived + public partial class Server : BaseClass, IServerDerived { internal const string ExceptionMessage = "some message"; internal const string ThrowAfterCancellationMessage = "Throw after cancellation"; @@ -3970,10 +3971,12 @@ public class Foo { [DataMember(Order = 0, IsRequired = true)] [STJ.JsonRequired, STJ.JsonPropertyOrder(0)] + [PropertyShape(Order = 0)] public string? Bar { get; set; } [DataMember(Order = 1)] [STJ.JsonPropertyOrder(1)] + [PropertyShape(Order = 1)] public int Bazz { get; set; } } @@ -3982,6 +3985,7 @@ public class CustomSerializedType // Ignore this so default serializers will drop it, proving that custom serializers were used if the value propagates. [JsonNET.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public string? Value { get; set; } } @@ -3989,6 +3993,7 @@ public class CustomSerializedType public class CustomISerializableData : ISerializable { [MessagePack.SerializationConstructor] + [ConstructorShape] public CustomISerializableData(int major) { this.Major = major; diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs index 13d44420a..443185613 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs @@ -12,5 +12,31 @@ public MarshalableProxyNerdbankMessagePackTests(ITestOutputHelper logger) protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); - protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); + protected override IJsonRpcMessageFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new(); + formatter.SetFormatterProfile(b => + { + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.RegisterRpcMarshalableType(); + b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + return b.Build(); + }); + + return formatter; + } } diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs index 8d51a31b2..e762d7cdc 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs @@ -9,8 +9,10 @@ using MessagePack; using Microsoft; using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; using Nerdbank.Streams; using Newtonsoft.Json; +using PolyType; using StreamJsonRpc; using Xunit; using Xunit.Abstractions; @@ -49,6 +51,7 @@ protected MarshalableProxyTests(ITestOutputHelper logger) [RpcMarshalable] [JsonConverter(typeof(MarshalableConverter))] [MessagePackFormatter(typeof(MarshalableFormatter))] + [MessagePackConverter(typeof(MarshalableNerdbankConverter))] public interface IMarshalableAndSerializable : IMarshalable { private class MarshalableConverter : JsonConverter @@ -71,12 +74,25 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer private class MarshalableFormatter : MessagePack.Formatters.IMessagePackFormatter { - public IMarshalableAndSerializable Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + public IMarshalableAndSerializable Deserialize(ref MessagePack.MessagePackReader reader, MessagePackSerializerOptions options) { throw new NotImplementedException(); } - public void Serialize(ref MessagePackWriter writer, IMarshalableAndSerializable value, MessagePackSerializerOptions options) + public void Serialize(ref MessagePack.MessagePackWriter writer, IMarshalableAndSerializable value, MessagePackSerializerOptions options) + { + throw new NotImplementedException(); + } + } + + private class MarshalableNerdbankConverter : Nerdbank.MessagePack.MessagePackConverter + { + public override IMarshalableAndSerializable? Read(ref Nerdbank.MessagePack.MessagePackReader reader, Nerdbank.MessagePack.SerializationContext context) + { + throw new NotImplementedException(); + } + + public override void Write(ref Nerdbank.MessagePack.MessagePackWriter writer, in IMarshalableAndSerializable? value, Nerdbank.MessagePack.SerializationContext context) { throw new NotImplementedException(); } diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs index 519f0bf71..9903fb90c 100644 --- a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -135,6 +135,7 @@ public void Resolver_RequestArgInArray() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + return b.Build(); }); var originalArg = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; @@ -158,6 +159,7 @@ public void Resolver_RequestArgInNamedArgs_AnonymousType() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + return b.Build(); }); var originalArg = new { Prop1 = 3, Prop2 = 5 }; @@ -181,6 +183,7 @@ public void Resolver_RequestArgInNamedArgs_DataContractObject() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + return b.Build(); }); var originalArg = new DataContractWithSubsetOfMembersIncluded { ExcludedField = "A", ExcludedProperty = "B", IncludedField = "C", IncludedProperty = "D" }; @@ -206,6 +209,7 @@ public void Resolver_RequestArgInNamedArgs_NonDataContractObject() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + return b.Build(); }); var originalArg = new NonDataContractWithExcludedMembers { ExcludedField = "A", ExcludedProperty = "B", InternalField = "C", InternalProperty = "D", PublicField = "E", PublicProperty = "F" }; @@ -247,6 +251,7 @@ public void Resolver_Result() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + return b.Build(); }); var originalResultValue = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; @@ -268,6 +273,7 @@ public void Resolver_ErrorData() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + return b.Build(); }); var originalErrorData = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; From 166ae498db6ecde69ad5f467bcdeabb35d1fb39f Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Thu, 26 Dec 2024 00:31:22 +0000 Subject: [PATCH 15/25] Adopt simpler MessagePackString write API --- Directory.Packages.props | 2 +- .../NerdbankMessagePackFormatter.cs | 84 ++++++++++--------- ...nkMessagePackFormatterProfileExtensions.cs | 64 ++++++++++---- ...AsyncEnumerableNerdbankMessagePackTests.cs | 1 - ...DisposableProxyNerdbankMessagePackTests.cs | 16 +++- ...xPipeMarshalingNerdbankMessagePackTests.cs | 7 +- .../NerdbankMessagePackFormatterTests.cs | 24 +++--- 7 files changed, 125 insertions(+), 73 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index bd3c7d2a2..1e2caeb7b 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -22,7 +22,7 @@ - + diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 4243ae456..deb519148 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -368,8 +368,8 @@ private static unsafe string ReadProtocolVersion(ref MessagePackReader reader) /// private static void WriteProtocolVersionPropertyAndValue(ref MessagePackWriter writer, string version) { - writer.WriteRaw(VersionPropertyName.MsgPack.Span); - writer.WriteRaw(Version2.MsgPack.Span); + writer.Write(VersionPropertyName); + writer.Write(version); } private static void ReadUnknownProperty(ref MessagePackReader reader, in SerializationContext context, ref Dictionary>? topLevelProperties, ReadOnlySpan stringKey) @@ -444,6 +444,11 @@ public override void Write(ref MessagePackWriter writer, in RawMessagePack value { writer.WriteRaw(value); } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(RawMessagePackFormatter)); + } } private class ProgressConverterResolver @@ -677,13 +682,13 @@ internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter if (!finished) { - writer.WriteRaw(TokenPropertyName.MsgPack.Span); + writer.Write(TokenPropertyName); writer.Write(token); } if (elements.Count > 0) { - writer.WriteRaw(ValuesPropertyName.MsgPack.Span); + writer.Write(ValuesPropertyName); context.GetConverter>(context.TypeShapeProvider).Write(ref writer, elements, context); } } @@ -858,10 +863,8 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } } -#pragma warning disable CA1812 private class StreamConverter : MessagePackConverter where T : Stream -#pragma warning restore CA1812 { private readonly NerdbankMessagePackFormatter formatter; @@ -903,22 +906,19 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } } -#pragma warning disable CA1812 private class RpcMarshalableConverter( NerdbankMessagePackFormatter formatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions, RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter where T : class -#pragma warning restore CA1812 { [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] public override T? Read(ref MessagePackReader reader, SerializationContext context) { context.DepthStep(); - MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter - .rpcProfile + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcProfile .Deserialize( ref reader, context.CancellationToken); @@ -1104,17 +1104,19 @@ internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) MessagePackReader readAhead = reader.CreatePeekReader(); int propertyCount = readAhead.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) { - if (MethodPropertyName.TryRead(ref readAhead)) + ReadOnlySequence stringKey = readAhead.ReadStringSequence() ?? ReadOnlySequence.Empty; + if (MethodPropertyName.IsMatch(stringKey)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (ResultPropertyName.TryRead(ref readAhead)) + else if (ResultPropertyName.IsMatch(stringKey)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (ErrorPropertyName.TryRead(ref readAhead)) + else if (ErrorPropertyName.IsMatch(stringKey)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } @@ -1221,7 +1223,7 @@ internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) result.MsgPackNamedArguments = namedArgs; break; case MessagePackType.Nil: - result.MsgPackPositionalArguments = Array.Empty>(); + result.MsgPackPositionalArguments = []; reader.ReadNil(); break; case MessagePackType type: @@ -1280,15 +1282,15 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ if (!value.RequestId.IsEmpty) { - writer.WriteRaw(IdPropertyName.MsgPack.Span); + writer.Write(IdPropertyName); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, value.RequestId, context); } - writer.WriteRaw(MethodPropertyName.MsgPack.Span); + writer.Write(MethodPropertyName); writer.Write(value.Method); - writer.WriteRaw(ParamsPropertyName.MsgPack.Span); + writer.Write(ParamsPropertyName); if (value.ArgumentsList is not null) { writer.WriteArrayHeader(value.ArgumentsList.Count); @@ -1346,13 +1348,13 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ if (value.TraceParent?.Length > 0) { - writer.WriteRaw(TraceParentPropertyName.MsgPack.Span); + writer.Write(TraceParentPropertyName); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, new TraceParent(value.TraceParent), context); if (value.TraceState?.Length > 0) { - writer.WriteRaw(TraceStatePropertyName.MsgPack.Span); + writer.Write(TraceStatePropertyName); WriteTraceState(ref writer, value.TraceState); } } @@ -1493,10 +1495,10 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResu WriteProtocolVersionPropertyAndValue(ref writer, value.Version); - writer.WriteRaw(IdPropertyName.MsgPack.Span); + writer.Write(IdPropertyName); context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); - writer.WriteRaw(ResultPropertyName.MsgPack.Span); + writer.Write(ResultPropertyName); if (value.Result is null) { @@ -1532,7 +1534,7 @@ internal JsonRpcErrorConverter(NerdbankMessagePackFormatter formatter) public override Protocol.JsonRpcError Read(ref MessagePackReader reader, SerializationContext context) { - var error = new JsonRpcError(this.formatter.rpcProfile) + var error = new JsonRpcError(this.formatter.userDataProfile) { OriginalMessagePack = reader.Sequence, }; @@ -1582,11 +1584,11 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro WriteProtocolVersionPropertyAndValue(ref writer, value.Version); - writer.WriteRaw(IdPropertyName.MsgPack.Span); + writer.Write(IdPropertyName); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, value.RequestId, context); - writer.WriteRaw(ErrorPropertyName.MsgPack.Span); + writer.Write(ErrorPropertyName); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, value.Error, context); @@ -1651,14 +1653,14 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro writer.WriteMapHeader(3); - writer.WriteRaw(CodePropertyName.MsgPack.Span); + writer.Write(CodePropertyName); context.GetConverter(context.TypeShapeProvider) .Write(ref writer, value.Code, context); - writer.WriteRaw(MessagePropertyName.MsgPack.Span); + writer.Write(MessagePropertyName); writer.Write(value.Message); - writer.WriteRaw(DataPropertyName.MsgPack.Span); + writer.Write(DataPropertyName); this.formatter.userDataProfile.Serialize(ref writer, value.Data, context.CancellationToken); } @@ -1771,7 +1773,7 @@ public unsafe override void Write(ref MessagePackWriter writer, in TraceParent v private class TopLevelPropertyBag : TopLevelPropertyBagBase { - private readonly FormatterProfile formatterContext; + private readonly FormatterProfile formatterProfile; private readonly IReadOnlyDictionary>? inboundUnknownProperties; /// @@ -1783,7 +1785,7 @@ private class TopLevelPropertyBag : TopLevelPropertyBagBase internal TopLevelPropertyBag(FormatterProfile userDataContext, IReadOnlyDictionary> inboundUnknownProperties) : base(isOutbound: false) { - this.formatterContext = userDataContext; + this.formatterProfile = userDataContext; this.inboundUnknownProperties = inboundUnknownProperties; } @@ -1791,11 +1793,11 @@ internal TopLevelPropertyBag(FormatterProfile userDataContext, IReadOnlyDictiona /// Initializes a new instance of the class /// for an outbound message. /// - /// The serializer options to use for this data. - internal TopLevelPropertyBag(FormatterProfile formatterContext) + /// The serializer options to use for this data. + internal TopLevelPropertyBag(FormatterProfile formatterProfile) : base(isOutbound: true) { - this.formatterContext = formatterContext; + this.formatterProfile = formatterProfile; } internal int PropertyCount => this.inboundUnknownProperties?.Count ?? this.OutboundProperties?.Count ?? 0; @@ -1825,7 +1827,7 @@ internal void WriteProperties(ref MessagePackWriter writer) foreach (KeyValuePair entry in this.OutboundProperties) { writer.Write(entry.Key); - this.formatterContext.SerializeObject(ref writer, entry.Value.Value, entry.Value.DeclaredType); + this.formatterProfile.SerializeObject(ref writer, entry.Value.Value, entry.Value.DeclaredType); } } } @@ -1841,7 +1843,7 @@ protected internal override bool TryGetTopLevelProperty(string name, [MaybeNu if (this.inboundUnknownProperties.TryGetValue(name, out ReadOnlySequence serializedValue) is true) { - value = this.formatterContext.Deserialize(serializedValue); + value = this.formatterProfile.Deserialize(serializedValue); return true; } @@ -1895,9 +1897,9 @@ public override ArgumentMatchResult TryGetTypedArguments(ReadOnlySpan MsgPackData { get; set; } @@ -2075,7 +2077,7 @@ internal ErrorDetail(FormatterProfile serializerOptions) try { - return this.formatterContext.DeserializeObject(this.MsgPackData, dataType) + return this.formatterProfile.DeserializeObject(this.MsgPackData, dataType) ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); } catch (MessagePackSerializationException) @@ -2085,7 +2087,7 @@ internal ErrorDetail(FormatterProfile serializerOptions) { // return MessagePackSerializer.Deserialize(this.MsgPackData, this.serializerOptions.WithResolver(PrimitiveObjectResolver.Instance)); // TODO: Which Shape Provider to use? - return this.formatterContext.Deserialize(this.MsgPackData); + return this.formatterProfile.Deserialize(this.MsgPackData); } catch (MessagePackSerializationException) { diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs index 16a0510c7..2ce7c0a8c 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs @@ -8,18 +8,53 @@ namespace StreamJsonRpc; -internal static class NerdbankMessagePackFormatterProfileExtensions +/// +/// Extension methods for that are specific to the . +/// +[System.Diagnostics.CodeAnalysis.SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Temporary for development")] +public static class NerdbankMessagePackFormatterProfileExtensions { - internal static T? Deserialize(this FormatterProfile profile, ref MessagePackReader reader, CancellationToken cancellationToken = default) + /// + /// Serializes an object using the specified . + /// + /// The formatter profile to use for serialization. + /// The writer to which the object will be serialized. + /// The object to serialize. + /// A token to monitor for cancellation requests. + public static void SerializeObject(this FormatterProfile profile, ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) { - return profile.Serializer.Deserialize(ref reader, profile.ShapeProvider, cancellationToken); + Requires.NotNull(profile, nameof(profile)); + + if (value is null) + { + writer.WriteNil(); + return; + } + + profile.Serializer.SerializeObject(ref writer, value, profile.ShapeProvider.Resolve(value.GetType()), cancellationToken); } - internal static T Deserialize(this FormatterProfile profile, in ReadOnlySequence pack, CancellationToken cancellationToken = default) + /// + /// Deserializes a sequence of bytes into an object of type using the specified . + /// + /// The type of the object to deserialize. + /// The formatter profile to use for deserialization. + /// The sequence of bytes to deserialize. + /// A token to monitor for cancellation requests. + /// The deserialized object of type . + /// Thrown when deserialization fails. + public static T Deserialize(this FormatterProfile profile, in ReadOnlySequence pack, CancellationToken cancellationToken = default) { + Requires.NotNull(profile, nameof(profile)); + // TODO: Improve the exception return profile.Serializer.Deserialize(pack, profile.ShapeProvider, cancellationToken) - ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + ?? throw new MessagePackSerializationException(Resources.FailureDeserializingRpcResult); + } + + internal static T? Deserialize(this FormatterProfile profile, ref MessagePackReader reader, CancellationToken cancellationToken = default) + { + return profile.Serializer.Deserialize(ref reader, profile.ShapeProvider, cancellationToken); } internal static object? DeserializeObject(this FormatterProfile profile, in ReadOnlySequence pack, Type objectType, CancellationToken cancellationToken = default) @@ -31,6 +66,14 @@ internal static T Deserialize(this FormatterProfile profile, in ReadOnlySeque cancellationToken); } + internal static object? DeserializeObject(this FormatterProfile profile, ref MessagePackReader reader, Type objectType, CancellationToken cancellationToken = default) + { + return profile.Serializer.DeserializeObject( + ref reader, + profile.ShapeProvider.Resolve(objectType), + cancellationToken); + } + internal static void Serialize(this FormatterProfile profile, ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) { profile.Serializer.Serialize(ref writer, value, profile.ShapeProvider, cancellationToken); @@ -46,15 +89,4 @@ internal static void SerializeObject(this FormatterProfile profile, ref MessageP profile.Serializer.SerializeObject(ref writer, value, profile.ShapeProvider.Resolve(objectType), cancellationToken); } - - internal static void SerializeObject(this FormatterProfile profile, ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) - { - if (value is null) - { - writer.WriteNil(); - return; - } - - profile.Serializer.SerializeObject(ref writer, value, profile.ShapeProvider.Resolve(value.GetType()), cancellationToken); - } } diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index ba5b064bc..3dce7a740 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -24,7 +24,6 @@ protected override void InitializeFormattersAndHandlers() static NerdbankMessagePackFormatter.FormatterProfile ConfigureContext(NerdbankMessagePackFormatter.FormatterProfileBuilder profileBuilder) { profileBuilder.RegisterAsyncEnumerableType, int>(); - profileBuilder.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); return profileBuilder.Build(); } diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs index c4a381e65..0184a756a 100644 --- a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.IO.Pipelines; using Nerdbank.MessagePack; public class DisposableProxyNerdbankMessagePackTests : DisposableProxyTests @@ -12,5 +13,18 @@ public DisposableProxyNerdbankMessagePackTests(ITestOutputHelper logger) protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); - protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); + protected override IJsonRpcMessageFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new(); + formatter.SetFormatterProfile(b => + { + b.RegisterStreamType(); + b.RegisterDuplexPipeType(); + b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + return b.Build(); + }); + + return formatter; + } } diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs index f506e9f3b..9b3c56efb 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.IO.Pipelines; using Nerdbank.Streams; public class DuplexPipeMarshalingNerdbankMessagePackTests : DuplexPipeMarshalingTests @@ -19,8 +20,9 @@ protected override void InitializeFormattersAndHandlers() serverFormatter.SetFormatterProfile(b => { + b.RegisterPipeReaderType(); + b.RegisterPipeWriterType(); b.RegisterDuplexPipeType(); - b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); return b.Build(); }); @@ -32,8 +34,9 @@ protected override void InitializeFormattersAndHandlers() clientFormatter.SetFormatterProfile(b => { + b.RegisterPipeReaderType(); + b.RegisterPipeWriterType(); b.RegisterDuplexPipeType(); - b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); return b.Build(); }); diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs index 9903fb90c..3b4832865 100644 --- a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -131,13 +131,6 @@ public async Task BasicJsonRpc() [Fact] public void Resolver_RequestArgInArray() { - this.Formatter.SetFormatterProfile(b => - { - b.RegisterConverter(new CustomConverter()); - b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); - return b.Build(); - }); - var originalArg = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; var originalRequest = new JsonRpcRequest { @@ -272,7 +265,7 @@ public void Resolver_ErrorData() this.Formatter.SetFormatterProfile(b => { b.RegisterConverter(new CustomConverter()); - b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); return b.Build(); }); @@ -373,16 +366,25 @@ public void StringValuesOfStandardPropertiesAreInterned() Assert.Same(request1.Method, request2.Method); // reference equality to ensure it was interned. } - protected override NerdbankMessagePackFormatter CreateFormatter() => new(); + protected override NerdbankMessagePackFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new(); + return formatter; + } private T Read(object anonymousObject) where T : JsonRpcMessage { + NerdbankMessagePackFormatter.FormatterProfileBuilder profileBuilder = this.Formatter.ProfileBuilder; + profileBuilder.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); + profileBuilder.RegisterConverter(new CustomConverter()); + NerdbankMessagePackFormatter.FormatterProfile profile = profileBuilder.Build(); + var sequence = new Sequence(); var writer = new MessagePackWriter(sequence); - new MessagePackSerializer().Serialize(ref writer, anonymousObject, ReflectionTypeShapeProvider.Default); + profile.SerializeObject(ref writer, anonymousObject); writer.Flush(); - return (T)this.Formatter.Deserialize(sequence); + return profile.Deserialize(sequence); } [DataContract] From 3d3cb907e31360723d6973f53ca09c8a0e70b43e Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Thu, 26 Dec 2024 13:03:47 +0000 Subject: [PATCH 16/25] Fix bug when searching property names from peek reader --- .../NerdbankMessagePackFormatter.cs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index deb519148..5aa0b2f21 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -466,7 +466,7 @@ public MessagePackConverter GetConverter() if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) { - converter = new PreciseTypeConverter(this.mainFormatter); + converter = new FullProgressConverter(this.mainFormatter); } else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) { @@ -518,11 +518,11 @@ public override void Write(ref MessagePackWriter writer, in TClass? value, Seria /// /// Converts a progress token to an or an into a token. /// - private class PreciseTypeConverter : MessagePackConverter + private class FullProgressConverter : MessagePackConverter { private readonly NerdbankMessagePackFormatter formatter; - internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) + internal FullProgressConverter(NerdbankMessagePackFormatter formatter) { this.formatter = formatter; } @@ -560,7 +560,7 @@ public override void Write(ref MessagePackWriter writer, in TClass? value, Seria public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); + return CreateUndocumentedSchema(typeof(FullProgressConverter)); } } } @@ -1107,16 +1107,15 @@ internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) for (int i = 0; i < propertyCount; i++) { - ReadOnlySequence stringKey = readAhead.ReadStringSequence() ?? ReadOnlySequence.Empty; - if (MethodPropertyName.IsMatch(stringKey)) + if (MethodPropertyName.TryRead(ref readAhead)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (ResultPropertyName.IsMatch(stringKey)) + else if (ResultPropertyName.TryRead(ref readAhead)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - else if (ErrorPropertyName.IsMatch(stringKey)) + else if (ErrorPropertyName.TryRead(ref readAhead)) { return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } @@ -1124,6 +1123,9 @@ internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) { readAhead.Skip(context); } + + // Skip the value of the property. + readAhead.Skip(context); } throw new UnrecognizedJsonRpcMessageException(); From d8430f4b1b3977df73df8a2f7837243b0e9ebab8 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Fri, 27 Dec 2024 02:38:20 +0000 Subject: [PATCH 17/25] Begin process of making converters stateless --- ...nkMessagePackFormatter.FormatterProfile.cs | 103 +- ...gePackFormatter.FormatterProfileBuilder.cs | 33 +- ...Formatter.MessagePackFormatterConverter.cs | 2 +- .../NerdbankMessagePackFormatter.cs | 2079 +++++++++-------- ...nkMessagePackFormatterProfileExtensions.cs | 52 +- src/StreamJsonRpc/Protocol/JsonRpcError.cs | 3 + src/StreamJsonRpc/Protocol/JsonRpcMessage.cs | 1 + src/StreamJsonRpc/Protocol/JsonRpcRequest.cs | 2 + src/StreamJsonRpc/Protocol/JsonRpcResult.cs | 4 +- src/StreamJsonRpc/Protocol/TraceParent.cs | 2 + .../SerializationContextExtensions.cs | 16 + .../DisposableProxyTests.cs | 4 +- test/StreamJsonRpc.Tests/FormatterTestBase.cs | 5 - .../JsonRpcNerdbankMessagePackLengthTests.cs | 18 +- test/StreamJsonRpc.Tests/JsonRpcTests.cs | 14 +- .../MessagePackFormatterTests.cs | 4 - .../NerdbankMessagePackFormatterTests.cs | 21 +- 17 files changed, 1282 insertions(+), 1081 deletions(-) create mode 100644 src/StreamJsonRpc/SerializationContextExtensions.cs diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs index 6edcbfa5a..b281dfad6 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Collections.Immutable; +using System.Diagnostics; using Nerdbank.MessagePack; using PolyType; +using PolyType.Abstractions; namespace StreamJsonRpc; @@ -15,17 +18,109 @@ public sealed partial class NerdbankMessagePackFormatter /// Initializes a new instance of the class. /// /// The MessagePack serializer to use. - /// The type shape provider to use. - public class FormatterProfile(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) + /// The type shape providers to use. + [DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] + public class FormatterProfile(MessagePackSerializer serializer, ImmutableArray shapeProviders) { + /// + /// Initializes a new instance of the class. + /// + /// The source of the profile. + /// The MessagePack serializer to use. + /// The type shape providers to use. + internal FormatterProfile(ProfileSource source, MessagePackSerializer serializer, ImmutableArray shapeProviders) + : this(serializer, shapeProviders) + { + this.Source = source; + } + + internal enum ProfileSource + { + /// + /// The profile is internal to the formatter. + /// + Internal, + + /// + /// The profile is external to the formatter. + /// + External, + } + + internal ProfileSource Source { get; } = ProfileSource.Internal; + /// /// Gets the MessagePack serializer. /// internal MessagePackSerializer Serializer => serializer; /// - /// Gets the type shape provider. + /// Gets the shape provider resolver. + /// + internal TypeShapeProviderResolver ShapeProviderResolver { get; } = new TypeShapeProviderResolver(shapeProviders); + + private int ProvidersCount => shapeProviders.Length; + + private string GetDebuggerDisplay() => $"{this.Source} [{this.ProvidersCount}]"; + + /// + /// When passing a type shape provider to MessagePackSerializer, it will resolve the type shape + /// from that provider and try to cache it. If the resolved type shape is not sourced from the + /// passed provider, it will throw an ArgumentException with the message: + /// System.ArgumentException : The specified shape provider is not valid for this cache. + /// To avoid this, the resolver does not implement ITypeShapeProvider directly so that it cannot + /// be passed to the serializer. Instead, use to get the + /// provider that will resolve the shape for the specified type, or if the serialization method supports + /// if use to get the shape directly. + /// Related issue: https://github.com/eiriktsarpalis/PolyType/issues/92. /// - internal ITypeShapeProvider ShapeProvider => shapeProvider; + internal class TypeShapeProviderResolver + { + private readonly ImmutableArray providers; + + internal TypeShapeProviderResolver(ImmutableArray providers) + { + this.providers = providers; + } + + public ITypeShape ResolveShape() => (ITypeShape)this.ResolveShape(typeof(T)); + + public ITypeShape ResolveShape(Type type) + { + foreach (ITypeShapeProvider provider in this.providers) + { + ITypeShape? shape = provider.GetShape(type); + if (shape is not null) + { + return shape; + } + } + + // TODO: Loc the exception message. + throw new MessagePackSerializationException($"No shape provider found for type '{type}'."); + } + + /// + /// Find the first provider that can provide a shape for the specified type. + /// + /// The type to resolve. + /// The relevant shape provider. + public ITypeShapeProvider ResolveShapeProvider() => this.ResolveShapeProvider(typeof(T)); + + /// + public ITypeShapeProvider ResolveShapeProvider(Type type) + { + foreach (ITypeShapeProvider provider in this.providers) + { + if (provider.GetShape(type) is not null) + { + return provider; + } + } + + // TODO: Loc the exception message. + throw new MessagePackSerializationException($"No shape provider found for type '{type}'."); + } + } } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs index 24f0c3780..a57f8be13 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs @@ -184,35 +184,10 @@ public FormatterProfile Build() return this.baseProfile; } - ITypeShapeProvider provider = this.typeShapeProvidersBuilder.Count == 1 - ? this.typeShapeProvidersBuilder[0] - : new CompositeTypeShapeProvider(this.typeShapeProvidersBuilder.ToImmutable()); - - return new FormatterProfile(this.baseProfile.Serializer, provider); - } - } - - private class CompositeTypeShapeProvider : ITypeShapeProvider - { - private readonly ImmutableArray providers; - - internal CompositeTypeShapeProvider(ImmutableArray providers) - { - this.providers = providers; - } - - public ITypeShape? GetShape(Type type) - { - foreach (ITypeShapeProvider provider in this.providers) - { - ITypeShape? shape = provider.GetShape(type); - if (shape is not null) - { - return shape; - } - } - - return null; + return new FormatterProfile( + this.baseProfile.Source, + this.baseProfile.Serializer, + this.typeShapeProvidersBuilder.ToImmutable()); } } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs index d2de639cd..e6a28a3ae 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs @@ -32,7 +32,7 @@ public object Convert(object value, TypeCode typeCode) { return typeCode switch { - TypeCode.Object => this.formatterContext.Deserialize((RawMessagePack)value), + TypeCode.Object => this.formatterContext.Deserialize((RawMessagePack)value)!, _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), }; } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 5aa0b2f21..7803f954e 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -72,17 +72,19 @@ public NerdbankMessagePackFormatter() { InternStrings = true, SerializeDefaultValues = true, + StartingContext = new SerializationContext() + { + [SerializationContextExtensions.FormatterKey] = this, + }, }; serializer.RegisterConverter(RequestIdConverter.Instance); - serializer.RegisterConverter(new JsonRpcMessageConverter(this)); - serializer.RegisterConverter(new JsonRpcRequestConverter(this)); - serializer.RegisterConverter(new JsonRpcResultConverter(this)); - serializer.RegisterConverter(new JsonRpcErrorConverter(this)); - serializer.RegisterConverter(new JsonRpcErrorDetailConverter(this)); serializer.RegisterConverter(new TraceParentConverter()); - this.rpcProfile = new FormatterProfile(serializer, ShapeProvider_StreamJsonRpc.Default); + this.rpcProfile = new FormatterProfile( + FormatterProfile.ProfileSource.Internal, + serializer, + [ShapeProvider_StreamJsonRpc.Default]); // Create the specialized formatters/resolvers that we will inject into the chain for user data. this.progressConverterResolver = new ProgressConverterResolver(this); @@ -95,6 +97,10 @@ public NerdbankMessagePackFormatter() { InternStrings = true, SerializeDefaultValues = true, + StartingContext = new SerializationContext() + { + [SerializationContextExtensions.FormatterKey] = this, + }, }; // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. @@ -102,10 +108,14 @@ public NerdbankMessagePackFormatter() userSerializer.RegisterConverter(RequestIdConverter.Instance); // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. - userSerializer.RegisterConverter(RawMessagePackFormatter.Instance); + // userSerializer.RegisterConverter(RawMessagePackFormatter.Instance); userSerializer.RegisterConverter(EventArgsConverter.Instance); - this.userDataProfile = new FormatterProfile(userSerializer, ReflectionTypeShapeProvider.Default); + this.userDataProfile = new FormatterProfile( + FormatterProfile.ProfileSource.External, + userSerializer, + [ReflectionTypeShapeProvider.Default]); + this.ProfileBuilder = new FormatterProfileBuilder(this, this.userDataProfile); } @@ -132,7 +142,6 @@ private interface IJsonRpcMessagePackRetention public void SetFormatterProfile(FormatterProfile profile) { Requires.NotNull(profile, nameof(profile)); - this.userDataProfile = profile; } @@ -145,15 +154,15 @@ public void SetFormatterProfile(Func Requires.NotNull(configure, nameof(configure)); var builder = new FormatterProfileBuilder(this, this.userDataProfile); - FormatterProfile profile = configure(builder); - this.userDataProfile = profile; + this.SetFormatterProfile(configure(builder)); } /// public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) { - JsonRpcMessage message = this.rpcProfile.Deserialize(contentBuffer); + JsonRpcMessage message = this.rpcProfile.Deserialize(contentBuffer) + ?? throw new MessagePackSerializationException("Failed to deserialize JSON-RPC message."); IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; this.deserializationToStringHelper.Activate(contentBuffer); @@ -338,10 +347,11 @@ bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) private static ReadOnlySequence GetSliceForNextToken(ref MessagePackReader reader, in SerializationContext context) { - SequencePosition startingPosition = reader.Position; - reader.Skip(context); - SequencePosition endingPosition = reader.Position; - return reader.Sequence.Slice(startingPosition, endingPosition); + return reader.ReadRaw(context); + ////SequencePosition startingPosition = reader.Position; + ////reader.Skip(context); + ////SequencePosition endingPosition = reader.Position; + ////return reader.Sequence.Slice(startingPosition, endingPosition); } /// @@ -349,7 +359,7 @@ private static ReadOnlySequence GetSliceForNextToken(ref MessagePackReader /// /// The reader to use. /// The decoded string. - private static unsafe string ReadProtocolVersion(ref MessagePackReader reader) + private static string ReadProtocolVersion(ref MessagePackReader reader) { // Recognize "2.0" since we expect it and can avoid decoding and allocating a new string for it. if (Version2.TryRead(ref reader)) @@ -383,1292 +393,1441 @@ private static void ReadUnknownProperty(ref MessagePackReader reader, in Seriali topLevelProperties.Add(name, GetSliceForNextToken(ref reader, context)); } - private class RequestIdConverter : MessagePackConverter + /// + /// Converts JSON-RPC messages to and from MessagePack format. + /// + [GenerateShape] + internal partial class JsonRpcMessageConverter : MessagePackConverter { - internal static readonly RequestIdConverter Instance = new(); - - private RequestIdConverter() - { - } - - public override RequestId Read(ref MessagePackReader reader, SerializationContext context) + /// + /// Reads a JSON-RPC message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC message. + public override JsonRpcMessage? Read(ref MessagePackReader reader, SerializationContext context) { context.DepthStep(); - if (reader.NextMessagePackType == MessagePackType.Integer) - { - return new RequestId(reader.ReadInt64()); - } - else - { - // Do *not* read as an interned string here because this ID should be unique. - return new RequestId(reader.ReadString()); - } - } - - public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) - { - context.DepthStep(); + MessagePackReader readAhead = reader.CreatePeekReader(); + int propertyCount = readAhead.ReadMapHeader(); - if (value.Number.HasValue) - { - writer.Write(value.Number.Value); - } - else + for (int i = 0; i < propertyCount; i++) { - writer.Write(value.String); + if (MethodPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + readAhead.Skip(context); + } + + // Skip the value of the property. + readAhead.Skip(context); } - } - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" - { - "type": ["string", { "type": "integer", "format": "int64" }] + throw new UnrecognizedJsonRpcMessageException(); } - """)?.AsObject(); - } - - private class RawMessagePackFormatter : MessagePackConverter - { - internal static readonly RawMessagePackFormatter Instance = new(); - private RawMessagePackFormatter() + /// + /// Writes a JSON-RPC message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? value, SerializationContext context) { - } + Requires.NotNull(value!, nameof(value)); - public override RawMessagePack Read(ref MessagePackReader reader, SerializationContext context) - { - return new RawMessagePack(reader.ReadRaw(context)); - } + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); - public override void Write(ref MessagePackWriter writer, in RawMessagePack value, SerializationContext context) - { - writer.WriteRaw(value); + using (formatter.TrackSerialization(value)) + { + switch (value) + { + case Protocol.JsonRpcRequest request: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, request, context); + break; + case Protocol.JsonRpcResult result: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, result, context); + break; + case Protocol.JsonRpcError error: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, error, context); + break; + default: + throw new NotSupportedException("Unexpected JsonRpcMessage-derived type: " + value.GetType().Name); + } + } } + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - return CreateUndocumentedSchema(typeof(RawMessagePackFormatter)); + return base.GetJsonSchema(context, typeShape); } } - private class ProgressConverterResolver + /// + /// Converts a JSON-RPC request message to and from MessagePack format. + /// + internal partial class JsonRpcRequestConverter : MessagePackConverter { - private readonly NerdbankMessagePackFormatter mainFormatter; - - internal ProgressConverterResolver(NerdbankMessagePackFormatter formatter) + /// + /// Reads a JSON-RPC request message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC request message. + public override Protocol.JsonRpcRequest? Read(ref MessagePackReader reader, SerializationContext context) { - this.mainFormatter = formatter; - } + NerdbankMessagePackFormatter formatter = context.GetFormatter(); - public MessagePackConverter GetConverter() - { - MessagePackConverter? converter = default; + context.DepthStep(); - if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) - { - converter = new FullProgressConverter(this.mainFormatter); - } - else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) + var result = new JsonRpcRequest(formatter) { - converter = new ProgressClientConverter(this.mainFormatter); - } - - // TODO: Improve Exception - return converter ?? throw new NotSupportedException(); - } + OriginalMessagePack = reader.Sequence, + }; - /// - /// Converts an instance of to a progress token. - /// - private class ProgressClientConverter : MessagePackConverter - { - private readonly NerdbankMessagePackFormatter formatter; + Dictionary>? topLevelProperties = null; - internal ProgressClientConverter(NerdbankMessagePackFormatter formatter) + int propertyCount = reader.ReadMapHeader(); + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) { - this.formatter = formatter; - } + if (VersionPropertyName.TryRead(ref reader)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(ref reader)) + { + result.RequestId = context.GetConverter(null).Read(ref reader, context); + } + else if (MethodPropertyName.TryRead(ref reader)) + { + result.Method = context.GetConverter(null).Read(ref reader, context); + } + else if (ParamsPropertyName.TryRead(ref reader)) + { + SequencePosition paramsTokenStartPosition = reader.Position; - public override TClass Read(ref MessagePackReader reader, SerializationContext context) - { - throw new NotSupportedException("This formatter only serializes IProgress instances."); - } + // Parse out the arguments into a dictionary or array, but don't deserialize them because we don't yet know what types to deserialize them to. + switch (reader.NextMessagePackType) + { + case MessagePackType.Array: + var positionalArgs = new ReadOnlySequence[reader.ReadArrayHeader()]; + for (int i = 0; i < positionalArgs.Length; i++) + { + positionalArgs[i] = GetSliceForNextToken(ref reader, context); + } - public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) - { - context.DepthStep(); + result.MsgPackPositionalArguments = positionalArgs; + break; + case MessagePackType.Map: + int namedArgsCount = reader.ReadMapHeader(); + var namedArgs = new Dictionary>(namedArgsCount); + for (int i = 0; i < namedArgsCount; i++) + { + string? propertyName = context.GetConverter(null).Read(ref reader, context) ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + namedArgs.Add(propertyName, GetSliceForNextToken(ref reader, context)); + } - if (value is null) + result.MsgPackNamedArguments = namedArgs; + break; + case MessagePackType.Nil: + result.MsgPackPositionalArguments = []; + reader.ReadNil(); + break; + case MessagePackType type: + throw new MessagePackSerializationException("Expected a map or array of arguments but got " + type); + } + + result.MsgPackArguments = reader.Sequence.Slice(paramsTokenStartPosition, reader.Position); + } + else if (TraceParentPropertyName.TryRead(ref reader)) { - writer.WriteNil(); + TraceParent traceParent = context.GetConverter(null).Read(ref reader, context); + result.TraceParent = traceParent.ToString(); + } + else if (TraceStatePropertyName.TryRead(ref reader)) + { + result.TraceState = ReadTraceState(ref reader, context); } else { - long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); - writer.Write(progressId); + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); } } - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + if (topLevelProperties is not null) { - return CreateUndocumentedSchema(typeof(ProgressClientConverter)); + result.TopLevelPropertyBag = new TopLevelPropertyBag(formatter.userDataProfile, topLevelProperties); } + + formatter.TryHandleSpecialIncomingMessage(result); + + return result; } /// - /// Converts a progress token to an or an into a token. + /// Writes a JSON-RPC request message to the specified MessagePack writer. /// - private class FullProgressConverter : MessagePackConverter + /// The MessagePack writer to write to. + /// The JSON-RPC request message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequest? value, SerializationContext context) { - private readonly NerdbankMessagePackFormatter formatter; + Requires.NotNull(value!, nameof(value)); - internal FullProgressConverter(NerdbankMessagePackFormatter formatter) - { - this.formatter = formatter; - } + NerdbankMessagePackFormatter formatter = context.GetFormatter(); - [return: MaybeNull] - public override TClass? Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); + context.DepthStep(); - if (reader.TryReadNil()) + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + int mapElementCount = value.RequestId.IsEmpty ? 3 : 4; + if (value.TraceParent?.Length > 0) + { + mapElementCount++; + if (value.TraceState?.Length > 0) { - return default!; + mapElementCount++; } - - Assumes.NotNull(this.formatter.JsonRpc); - RawMessagePack token = (RawMessagePack)reader.ReadRaw(context); - bool clientRequiresNamedArgs = this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; - return (TClass)this.formatter.FormatterProgressTracker.CreateProgress(this.formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); } - public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) - { - context.DepthStep(); + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); - if (value is null) - { - writer.WriteNil(); - } - else - { - long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); - writer.Write(progressId); - } - } + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + if (!value.RequestId.IsEmpty) { - return CreateUndocumentedSchema(typeof(FullProgressConverter)); + writer.Write(IdPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.RequestId, context); } - } - } - - private class AsyncEnumerableConverterResolver - { - private readonly NerdbankMessagePackFormatter mainFormatter; - - internal AsyncEnumerableConverterResolver(NerdbankMessagePackFormatter formatter) - { - this.mainFormatter = formatter; - } - public MessagePackConverter GetConverter() - { - MessagePackConverter? converter = default; + writer.Write(MethodPropertyName); + writer.Write(value.Method); - if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) - { - converter = (MessagePackConverter?)Activator.CreateInstance( - typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), - [this.mainFormatter]); - } - else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) + writer.Write(ParamsPropertyName); + if (value.ArgumentsList is not null) { - converter = (MessagePackConverter?)Activator.CreateInstance( - typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), - [this.mainFormatter]); - } - - // TODO: Improve Exception - return converter ?? throw new NotSupportedException(); - } - - /// - /// Converts an enumeration token to an - /// or an into an enumeration token. - /// -#pragma warning disable CA1812 - private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter> -#pragma warning restore CA1812 - { - /// - /// The constant "token", in its various forms. - /// - private static readonly MessagePackString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); - - /// - /// The constant "values", in its various forms. - /// - private static readonly MessagePackString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); + writer.WriteArrayHeader(value.ArgumentsList.Count); - public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) - { - if (reader.TryReadNil()) + for (int i = 0; i < value.ArgumentsList.Count; i++) { - return default; - } - - context.DepthStep(); + object? arg = value.ArgumentsList[i]; - RawMessagePack? token = default; - IReadOnlyList? initialElements = null; - int propertyCount = reader.ReadMapHeader(); - for (int i = 0; i < propertyCount; i++) - { - if (TokenPropertyName.TryRead(ref reader)) - { - token = (RawMessagePack)reader.ReadRaw(context); - } - else if (ValuesPropertyName.TryRead(ref reader)) + if (value.ArgumentListDeclaredTypes is null) { - initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); + formatter.userDataProfile.SerializeObject( + ref writer, + arg, + context.CancellationToken); } else { - reader.Skip(context); + formatter.userDataProfile.SerializeObject( + ref writer, + arg, + value.ArgumentListDeclaredTypes[i], + context.CancellationToken); } } - - return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? (object)token : null, initialElements); - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] - public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) - { - context.DepthStep(); - Serialize_Shared(mainFormatter, ref writer, value, context); - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); } - - internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref MessagePackWriter writer, IAsyncEnumerable? value, SerializationContext context) + else if (value.NamedArguments is not null) { - if (value is null) - { - writer.WriteNil(); - } - else + writer.WriteMapHeader(value.NamedArguments.Count); + foreach (KeyValuePair entry in value.NamedArguments) { - (IReadOnlyList elements, bool finished) = value.TearOffPrefetchedElements(); - long token = mainFormatter.EnumerableTracker.GetToken(value); - - int propertyCount = 0; - if (elements.Count > 0) - { - propertyCount++; - } - - if (!finished) - { - propertyCount++; - } - - writer.WriteMapHeader(propertyCount); + writer.Write(entry.Key); - if (!finished) + if (value.NamedArgumentDeclaredTypes is null) { - writer.Write(TokenPropertyName); - writer.Write(token); + formatter.userDataProfile.SerializeObject( + ref writer, + entry.Value, + context.CancellationToken); } - - if (elements.Count > 0) + else { - writer.Write(ValuesPropertyName); - context.GetConverter>(context.TypeShapeProvider).Write(ref writer, elements, context); + Type argType = value.NamedArgumentDeclaredTypes[entry.Key]; + formatter.userDataProfile.SerializeObject( + ref writer, + entry.Value, + argType, + context.CancellationToken); } } } - } - - /// - /// Converts an instance of to an enumeration token. - /// -#pragma warning disable CA1812 - private class GeneratorConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter - where TClass : IAsyncEnumerable -#pragma warning restore CA1812 - { - public override TClass Read(ref MessagePackReader reader, SerializationContext context) + else { - throw new NotSupportedException(); + writer.WriteNil(); } - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] - public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + if (value.TraceParent?.Length > 0) { - context.DepthStep(); - PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); - } + writer.Write(TraceParentPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, new TraceParent(value.TraceParent), context); - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(GeneratorConverter)); + if (value.TraceState?.Length > 0) + { + writer.Write(TraceStatePropertyName); + WriteTraceState(ref writer, value.TraceState); + } } - } - } - private class PipeConverterResolver - { - private readonly NerdbankMessagePackFormatter mainFormatter; + topLevelPropertyBag?.WriteProperties(ref writer); + } - internal PipeConverterResolver(NerdbankMessagePackFormatter formatter) + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - this.mainFormatter = formatter; + return CreateUndocumentedSchema(typeof(JsonRpcRequestConverter)); } - public MessagePackConverter GetConverter() + private static void WriteTraceState(ref MessagePackWriter writer, string traceState) { - MessagePackConverter? converter = default; + ReadOnlySpan traceStateChars = traceState.AsSpan(); - if (typeof(IDuplexPipe).IsAssignableFrom(typeof(T))) - { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; - } - else if (typeof(PipeReader).IsAssignableFrom(typeof(T))) + // Count elements first so we can write the header. + int elementCount = 1; + int commaIndex; + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + elementCount++; + traceStateChars = traceStateChars.Slice(commaIndex + 1); } - else if (typeof(PipeWriter).IsAssignableFrom(typeof(T))) + + // For every element, we have a key and value to record. + writer.WriteArrayHeader(elementCount * 2); + + traceStateChars = traceState.AsSpan(); + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + ReadOnlySpan element = traceStateChars.Slice(0, commaIndex); + WritePair(ref writer, element); + traceStateChars = traceStateChars.Slice(commaIndex + 1); } - else if (typeof(Stream).IsAssignableFrom(typeof(T))) + + // Write out the last one. + WritePair(ref writer, traceStateChars); + + static void WritePair(ref MessagePackWriter writer, ReadOnlySpan pair) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + int equalsIndex = pair.IndexOf('='); + ReadOnlySpan key = pair.Slice(0, equalsIndex); + ReadOnlySpan value = pair.Slice(equalsIndex + 1); + writer.Write(key); + writer.Write(value); } - - // TODO: Improve Exception - return converter ?? throw new NotSupportedException(); } -#pragma warning disable CA1812 - private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter - where T : class, IDuplexPipe -#pragma warning restore CA1812 + private static unsafe string ReadTraceState(ref MessagePackReader reader, SerializationContext context) { - public override T? Read(ref MessagePackReader reader, SerializationContext context) + int elements = reader.ReadArrayHeader(); + if (elements % 2 != 0) { - context.DepthStep(); + throw new NotSupportedException("Odd number of elements not expected."); + } - if (reader.TryReadNil()) + // With care, we could probably assemble this string with just two allocations (the string + a char[]). + var resultBuilder = new StringBuilder(); + for (int i = 0; i < elements; i += 2) + { + if (resultBuilder.Length > 0) { - return null; + resultBuilder.Append(','); } - return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); - } - - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - context.DepthStep(); - - if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) - { - writer.Write(token); - } - else - { - writer.WriteNil(); - } + // We assume the key is a frequent string, and the value is unique, + // so we optimize whether to use string interning or not on that basis. + resultBuilder.Append(context.GetConverter(null).Read(ref reader, context)); + resultBuilder.Append('='); + resultBuilder.Append(reader.ReadString()); } - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(DuplexPipeConverter)); - } + return resultBuilder.ToString(); } + } -#pragma warning disable CA1812 - private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter - where T : PipeReader -#pragma warning restore CA1812 + /// + /// Converts a JSON-RPC result message to and from MessagePack format. + /// + internal partial class JsonRpcResultConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC result message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC result message. + public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, SerializationContext context) { - public override T? Read(ref MessagePackReader reader, SerializationContext context) + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + var result = new JsonRpcResult(formatter, formatter.userDataProfile) { - context.DepthStep(); - if (reader.TryReadNil()) - { - return null; - } + OriginalMessagePack = reader.Sequence, + }; - return (T)formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); - } + Dictionary>? topLevelProperties = null; - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + int propertyCount = reader.ReadMapHeader(); + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) { - context.DepthStep(); - if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + if (VersionPropertyName.TryRead(ref reader)) { - writer.Write(token); + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(ref reader)) + { + result.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(ref reader)) + { + result.MsgPackResult = GetSliceForNextToken(ref reader, context); } else { - writer.WriteNil(); + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); } } - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + if (topLevelProperties is not null) { - return CreateUndocumentedSchema(typeof(PipeReaderConverter)); + result.TopLevelPropertyBag = new TopLevelPropertyBag(formatter.userDataProfile, topLevelProperties); } + + return result; } -#pragma warning disable CA1812 - private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter - where T : PipeWriter -#pragma warning restore CA1812 + /// + /// Writes a JSON-RPC result message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC result message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResult? value, SerializationContext context) { - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); - if (reader.TryReadNil()) - { - return null; - } + Requires.NotNull(value!, nameof(value)); - return (T)formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); - } + NerdbankMessagePackFormatter formatter = context.GetFormatter(); - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + context.DepthStep(); + + var topLevelPropertyBagMessage = value as IMessageWithTopLevelPropertyBag; + + int mapElementCount = 3; + mapElementCount += (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + writer.Write(IdPropertyName); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); + + writer.Write(ResultPropertyName); + + if (value.Result is null) { - context.DepthStep(); - if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) - { - writer.Write(token); - } - else - { - writer.WriteNil(); - } + writer.WriteNil(); } - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + if (value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void)) { - return CreateUndocumentedSchema(typeof(PipeWriterConverter)); + formatter.userDataProfile.SerializeObject(ref writer, value.Result, value.ResultDeclaredType, context.CancellationToken); + } + else + { + formatter.userDataProfile.SerializeObject(ref writer, value.Result, context.CancellationToken); } + + (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); } - private class StreamConverter : MessagePackConverter - where T : Stream + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - private readonly NerdbankMessagePackFormatter formatter; + return CreateUndocumentedSchema(typeof(JsonRpcResultConverter)); + } + } - public StreamConverter(NerdbankMessagePackFormatter formatter) + /// + /// Converts a JSON-RPC error message to and from MessagePack format. + /// + internal partial class JsonRpcErrorConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC error message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC error message. + public override Protocol.JsonRpcError Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + var error = new JsonRpcError(formatter.userDataProfile) { - this.formatter = formatter; - } + OriginalMessagePack = reader.Sequence, + }; - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); + Dictionary>? topLevelProperties = null; - if (reader.TryReadNil()) + context.DepthStep(); + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + if (VersionPropertyName.TryRead(ref reader)) { - return null; + error.Version = ReadProtocolVersion(ref reader); } - - return (T)this.formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); - } - - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - context.DepthStep(); - - if (this.formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) + else if (IdPropertyName.TryRead(ref reader)) { - writer.Write(token); + error.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(ref reader)) + { + error.Error = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } else { - writer.WriteNil(); + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); } } - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + if (topLevelProperties is not null) { - return CreateUndocumentedSchema(typeof(StreamConverter)); + error.TopLevelPropertyBag = new TopLevelPropertyBag(formatter.userDataProfile, topLevelProperties); } + + return error; } - } - private class RpcMarshalableConverter( - NerdbankMessagePackFormatter formatter, - JsonRpcProxyOptions proxyOptions, - JsonRpcTargetOptions targetOptions, - RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter - where T : class - { - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] - public override T? Read(ref MessagePackReader reader, SerializationContext context) + /// + /// Writes a JSON-RPC error message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC error message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError? value, SerializationContext context) { + Requires.NotNull(value!, nameof(value)); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + context.DepthStep(); + int mapElementCount = 3; + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); - MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcProfile - .Deserialize( - ref reader, - context.CancellationToken); + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); - return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; - } + writer.Write(IdPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.RequestId, context); - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - context.DepthStep(); + writer.Write(ErrorPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.Error, context); - if (value is null) - { - writer.WriteNil(); - } - else - { - MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); - formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); - } + topLevelPropertyBag?.WriteProperties(ref writer); } + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); + return CreateUndocumentedSchema(typeof(JsonRpcErrorConverter)); } } /// - /// Manages serialization of any -derived type that follows standard rules. + /// Converts a JSON-RPC error detail to and from MessagePack format. /// - /// - /// A serializable class will: - /// 1. Derive from - /// 2. Be attributed with - /// 3. Declare a constructor with a signature of (, ). - /// - private class MessagePackExceptionConverterResolver + internal partial class JsonRpcErrorDetailConverter : MessagePackConverter { + private static readonly MessagePackString CodePropertyName = new("code"); + private static readonly MessagePackString MessagePropertyName = new("message"); + private static readonly MessagePackString DataPropertyName = new("data"); + /// - /// Tracks recursion count while serializing or deserializing an exception. + /// Reads a JSON-RPC error detail from the specified MessagePack reader. /// - /// - /// This is placed here (outside the generic class) - /// so that it's one counter shared across all exception types that may be serialized or deserialized. - /// - private static ThreadLocal exceptionRecursionCounter = new(); - - private readonly object[] formatterActivationArgs; - - internal MessagePackExceptionConverterResolver(NerdbankMessagePackFormatter formatter) + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC error detail. + public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader reader, SerializationContext context) { - this.formatterActivationArgs = new object[] { formatter }; - } + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); - public MessagePackConverter GetConverter() - { - MessagePackConverter? formatter = null; - if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is object) - { - formatter = (MessagePackConverter)Activator.CreateInstance(typeof(ExceptionConverter<>).MakeGenericType(typeof(T)), this.formatterActivationArgs)!; - } - - // TODO: Improve Exception - return formatter ?? throw new NotSupportedException(); - } + var result = new JsonRpcError.ErrorDetail(formatter.userDataProfile); -#pragma warning disable CA1812 - private partial class ExceptionConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter - where T : Exception -#pragma warning restore CA1812 - { - public override T? Read(ref MessagePackReader reader, SerializationContext context) + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) { - Assumes.NotNull(formatter.JsonRpc); - - context.DepthStep(); - - if (reader.TryReadNil()) + if (CodePropertyName.TryRead(ref reader)) { - return null; + result.Code = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - - // We have to guard our own recursion because the serializer has no visibility into inner exceptions. - // Each exception in the russian doll is a new serialization job from its perspective. - exceptionRecursionCounter.Value++; - try + else if (MessagePropertyName.TryRead(ref reader)) { - if (exceptionRecursionCounter.Value > formatter.JsonRpc.ExceptionOptions.RecursionLimit) - { - // Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception. - // Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded. - reader.Skip(context); - return null; - } - - // TODO: Is this the right context? - var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); - int memberCount = reader.ReadMapHeader(); - for (int i = 0; i < memberCount; i++) - { - string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) - ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); - - // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, - // so the caller will get a boxed RawMessagePack struct in that case. - // Although we can't do much about *that* in general, we can at least ensure that null values - // are represented as null instead of this boxed struct. - var value = reader.TryReadNil() ? null : (object)reader.ReadRaw(context); - - info.AddSafeValue(name, value); - } - - return ExceptionSerializationHelpers.Deserialize(formatter.JsonRpc, info, formatter.JsonRpc.TraceSource); + result.Message = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); } - finally + else if (DataPropertyName.TryRead(ref reader)) { - exceptionRecursionCounter.Value--; + result.MsgPackData = GetSliceForNextToken(ref reader, context); + } + else + { + reader.Skip(context); } } - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - context.DepthStep(); + return result; + } - if (value is null) - { - writer.WriteNil(); - return; - } + /// + /// Writes a JSON-RPC error detail to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC error detail to write. + /// The serialization context. + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); - exceptionRecursionCounter.Value++; - try - { - if (exceptionRecursionCounter.Value > formatter.JsonRpc?.ExceptionOptions.RecursionLimit) - { - // Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception. - writer.WriteNil(); - return; - } + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); - // TODO: Is this the right context? - var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); - ExceptionSerializationHelpers.Serialize(value, info); - writer.WriteMapHeader(info.GetSafeMemberCount()); - foreach (SerializationEntry element in info.GetSafeMembers()) - { - writer.Write(element.Name); - formatter.rpcProfile.SerializeObject( - ref writer, - element.Value, - element.ObjectType, - context.CancellationToken); - } - } - finally - { - exceptionRecursionCounter.Value--; - } - } + writer.WriteMapHeader(3); - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(ExceptionConverter)); - } - } - } + writer.Write(CodePropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.Code, context); - [GenerateShape] - private partial class JsonRpcMessageConverter : MessagePackConverter - { - private readonly NerdbankMessagePackFormatter formatter; + writer.Write(MessagePropertyName); + writer.Write(value.Message); - internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) + writer.Write(DataPropertyName); + formatter.userDataProfile.Serialize(ref writer, value.Data, context.CancellationToken); + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - this.formatter = formatter; + return CreateUndocumentedSchema(typeof(JsonRpcErrorDetailConverter)); } + } - public override JsonRpcMessage? Read(ref MessagePackReader reader, SerializationContext context) + internal class TraceParentConverter : MessagePackConverter + { + public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) { context.DepthStep(); - MessagePackReader readAhead = reader.CreatePeekReader(); - int propertyCount = readAhead.ReadMapHeader(); + if (reader.ReadArrayHeader() != 2) + { + throw new NotSupportedException("Unexpected array length."); + } - for (int i = 0; i < propertyCount; i++) + var result = default(TraceParent); + result.Version = reader.ReadByte(); + if (result.Version != 0) { - if (MethodPropertyName.TryRead(ref readAhead)) - { - return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); - } - else if (ResultPropertyName.TryRead(ref readAhead)) - { - return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); - } - else if (ErrorPropertyName.TryRead(ref readAhead)) - { - return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); - } - else - { - readAhead.Skip(context); - } + throw new NotSupportedException("traceparent version " + result.Version + " is not supported."); + } - // Skip the value of the property. - readAhead.Skip(context); + if (reader.ReadArrayHeader() != 3) + { + throw new NotSupportedException("Unexpected array length in version-format."); } - throw new UnrecognizedJsonRpcMessageException(); + ReadOnlySequence bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected traceid not found."); + bytes.CopyTo(new Span(result.TraceId, TraceParent.TraceIdByteCount)); + + bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected parentid not found."); + bytes.CopyTo(new Span(result.ParentId, TraceParent.ParentIdByteCount)); + + result.Flags = (TraceParent.TraceFlags)reader.ReadByte(); + + return result; } - public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? value, SerializationContext context) + public unsafe override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) { - Requires.NotNull(value!, nameof(value)); + if (value.Version != 0) + { + throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); + } context.DepthStep(); - using (this.formatter.TrackSerialization(value)) + writer.WriteArrayHeader(2); + + writer.Write(value.Version); + + writer.WriteArrayHeader(3); + + fixed (byte* traceId = value.TraceId) { - switch (value) - { - case Protocol.JsonRpcRequest request: - context.GetConverter(context.TypeShapeProvider).Write(ref writer, request, context); - break; - case Protocol.JsonRpcResult result: - context.GetConverter(context.TypeShapeProvider).Write(ref writer, result, context); - break; - case Protocol.JsonRpcError error: - context.GetConverter(context.TypeShapeProvider).Write(ref writer, error, context); - break; - default: - throw new NotSupportedException("Unexpected JsonRpcMessage-derived type: " + value.GetType().Name); - } + writer.Write(new ReadOnlySpan(traceId, TraceParent.TraceIdByteCount)); + } + + fixed (byte* parentId = value.ParentId) + { + writer.Write(new ReadOnlySpan(parentId, TraceParent.ParentIdByteCount)); } + + writer.Write((byte)value.Flags); } public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - return base.GetJsonSchema(context, typeShape); + return CreateUndocumentedSchema(typeof(TraceParentConverter)); } } - private partial class JsonRpcRequestConverter : MessagePackConverter + private class RequestIdConverter : MessagePackConverter { - private readonly NerdbankMessagePackFormatter formatter; + internal static readonly RequestIdConverter Instance = new(); - internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) + private RequestIdConverter() { - this.formatter = formatter; } - public override Protocol.JsonRpcRequest? Read(ref MessagePackReader reader, SerializationContext context) + public override RequestId Read(ref MessagePackReader reader, SerializationContext context) { context.DepthStep(); - var result = new JsonRpcRequest(this.formatter) + if (reader.NextMessagePackType == MessagePackType.Integer) { - OriginalMessagePack = reader.Sequence, - }; + return new RequestId(reader.ReadInt64()); + } + else + { + // Do *not* read as an interned string here because this ID should be unique. + return new RequestId(reader.ReadString()); + } + } - Dictionary>? topLevelProperties = null; + public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) + { + context.DepthStep(); - int propertyCount = reader.ReadMapHeader(); - for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + if (value.Number.HasValue) { - if (VersionPropertyName.TryRead(ref reader)) - { - result.Version = ReadProtocolVersion(ref reader); - } - else if (IdPropertyName.TryRead(ref reader)) - { - result.RequestId = context.GetConverter(null).Read(ref reader, context); - } - else if (MethodPropertyName.TryRead(ref reader)) - { - result.Method = context.GetConverter(null).Read(ref reader, context); - } - else if (ParamsPropertyName.TryRead(ref reader)) - { - SequencePosition paramsTokenStartPosition = reader.Position; + writer.Write(value.Number.Value); + } + else + { + writer.Write(value.String); + } + } - // Parse out the arguments into a dictionary or array, but don't deserialize them because we don't yet know what types to deserialize them to. - switch (reader.NextMessagePackType) - { - case MessagePackType.Array: - var positionalArgs = new ReadOnlySequence[reader.ReadArrayHeader()]; - for (int i = 0; i < positionalArgs.Length; i++) - { - positionalArgs[i] = GetSliceForNextToken(ref reader, context); - } + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" + { + "type": ["string", { "type": "integer", "format": "int64" }] + } + """)?.AsObject(); + } - result.MsgPackPositionalArguments = positionalArgs; - break; - case MessagePackType.Map: - int namedArgsCount = reader.ReadMapHeader(); - var namedArgs = new Dictionary>(namedArgsCount); - for (int i = 0; i < namedArgsCount; i++) - { - string? propertyName = context.GetConverter(null).Read(ref reader, context) ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); - namedArgs.Add(propertyName, GetSliceForNextToken(ref reader, context)); - } + private class RawMessagePackFormatter : MessagePackConverter + { + internal static readonly RawMessagePackFormatter Instance = new(); - result.MsgPackNamedArguments = namedArgs; - break; - case MessagePackType.Nil: - result.MsgPackPositionalArguments = []; - reader.ReadNil(); - break; - case MessagePackType type: - throw new MessagePackSerializationException("Expected a map or array of arguments but got " + type); - } + private RawMessagePackFormatter() + { + } - result.MsgPackArguments = reader.Sequence.Slice(paramsTokenStartPosition, reader.Position); - } - else if (TraceParentPropertyName.TryRead(ref reader)) - { - TraceParent traceParent = context.GetConverter(null).Read(ref reader, context); - result.TraceParent = traceParent.ToString(); - } - else if (TraceStatePropertyName.TryRead(ref reader)) + public override RawMessagePack Read(ref MessagePackReader reader, SerializationContext context) + { + return new RawMessagePack(reader.ReadRaw(context)); + } + + public override void Write(ref MessagePackWriter writer, in RawMessagePack value, SerializationContext context) + { + writer.WriteRaw(value); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(RawMessagePackFormatter)); + } + } + + private class ProgressConverterResolver + { + private readonly NerdbankMessagePackFormatter mainFormatter; + + internal ProgressConverterResolver(NerdbankMessagePackFormatter formatter) + { + this.mainFormatter = formatter; + } + + public MessagePackConverter GetConverter() + { + MessagePackConverter? converter = default; + + if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) + { + converter = new FullProgressConverter(this.mainFormatter); + } + else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) + { + converter = new ProgressClientConverter(this.mainFormatter); + } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); + } + + /// + /// Converts an instance of to a progress token. + /// + private class ProgressClientConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal ProgressClientConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override TClass Read(ref MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException("This formatter only serializes IProgress instances."); + } + + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + context.DepthStep(); + + if (value is null) { - result.TraceState = ReadTraceState(ref reader, context); + writer.WriteNil(); } else { - ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); + long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); } } - if (topLevelProperties is not null) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataProfile, topLevelProperties); + return CreateUndocumentedSchema(typeof(ProgressClientConverter)); } - - this.formatter.TryHandleSpecialIncomingMessage(result); - - return result; } - public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequest? value, SerializationContext context) + /// + /// Converts a progress token to an or an into a token. + /// + private class FullProgressConverter : MessagePackConverter { - Requires.NotNull(value!, nameof(value)); - - context.DepthStep(); + private readonly NerdbankMessagePackFormatter formatter; - var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + internal FullProgressConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } - int mapElementCount = value.RequestId.IsEmpty ? 3 : 4; - if (value.TraceParent?.Length > 0) + [return: MaybeNull] + public override TClass? Read(ref MessagePackReader reader, SerializationContext context) { - mapElementCount++; - if (value.TraceState?.Length > 0) + context.DepthStep(); + + if (reader.TryReadNil()) { - mapElementCount++; + return default!; } + + Assumes.NotNull(this.formatter.JsonRpc); + RawMessagePack token = (RawMessagePack)reader.ReadRaw(context); + bool clientRequiresNamedArgs = this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; + return (TClass)this.formatter.FormatterProgressTracker.CreateProgress(this.formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); } - mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; - writer.WriteMapHeader(mapElementCount); + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + context.DepthStep(); - WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } - if (!value.RequestId.IsEmpty) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - writer.Write(IdPropertyName); - context.GetConverter(context.TypeShapeProvider) - .Write(ref writer, value.RequestId, context); + return CreateUndocumentedSchema(typeof(FullProgressConverter)); } + } + } - writer.Write(MethodPropertyName); - writer.Write(value.Method); + private class AsyncEnumerableConverterResolver + { + private readonly NerdbankMessagePackFormatter mainFormatter; - writer.Write(ParamsPropertyName); - if (value.ArgumentsList is not null) - { - writer.WriteArrayHeader(value.ArgumentsList.Count); + internal AsyncEnumerableConverterResolver(NerdbankMessagePackFormatter formatter) + { + this.mainFormatter = formatter; + } - for (int i = 0; i < value.ArgumentsList.Count; i++) - { - object? arg = value.ArgumentsList[i]; + public MessagePackConverter GetConverter() + { + MessagePackConverter? converter = default; - if (value.ArgumentListDeclaredTypes is null) - { - this.formatter.userDataProfile.SerializeObject( - ref writer, - arg, - context.CancellationToken); - } - else - { - this.formatter.userDataProfile.SerializeObject( - ref writer, - arg, - value.ArgumentListDeclaredTypes[i], - context.CancellationToken); - } - } + if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance( + typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), + [this.mainFormatter]); } - else if (value.NamedArguments is not null) + else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) { - writer.WriteMapHeader(value.NamedArguments.Count); - foreach (KeyValuePair entry in value.NamedArguments) + converter = (MessagePackConverter?)Activator.CreateInstance( + typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), + [this.mainFormatter]); + } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); + } + + /// + /// Converts an enumeration token to an + /// or an into an enumeration token. + /// +#pragma warning disable CA1812 + private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter> +#pragma warning restore CA1812 + { + /// + /// The constant "token", in its various forms. + /// + private static readonly MessagePackString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); + + /// + /// The constant "values", in its various forms. + /// + private static readonly MessagePackString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); + + public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) { - writer.Write(entry.Key); + return default; + } - if (value.NamedArgumentDeclaredTypes is null) + context.DepthStep(); + + RawMessagePack? token = default; + IReadOnlyList? initialElements = null; + int propertyCount = reader.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) + { + if (TokenPropertyName.TryRead(ref reader)) { - this.formatter.userDataProfile.SerializeObject( - ref writer, - entry.Value, - context.CancellationToken); + token = (RawMessagePack)reader.ReadRaw(context); + } + else if (ValuesPropertyName.TryRead(ref reader)) + { + initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); } else { - Type argType = value.NamedArgumentDeclaredTypes[entry.Key]; - this.formatter.userDataProfile.SerializeObject( - ref writer, - entry.Value, - argType, - context.CancellationToken); + reader.Skip(context); } } + + return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? (object)token : null, initialElements); } - else + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] + public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) { - writer.WriteNil(); + context.DepthStep(); + Serialize_Shared(mainFormatter, ref writer, value, context); } - if (value.TraceParent?.Length > 0) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - writer.Write(TraceParentPropertyName); - context.GetConverter(context.TypeShapeProvider) - .Write(ref writer, new TraceParent(value.TraceParent), context); + return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); + } - if (value.TraceState?.Length > 0) + internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref MessagePackWriter writer, IAsyncEnumerable? value, SerializationContext context) + { + if (value is null) { - writer.Write(TraceStatePropertyName); - WriteTraceState(ref writer, value.TraceState); + writer.WriteNil(); } - } - - topLevelPropertyBag?.WriteProperties(ref writer); - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(JsonRpcRequestConverter)); - } - - private static void WriteTraceState(ref MessagePackWriter writer, string traceState) - { - ReadOnlySpan traceStateChars = traceState.AsSpan(); + else + { + (IReadOnlyList elements, bool finished) = value.TearOffPrefetchedElements(); + long token = mainFormatter.EnumerableTracker.GetToken(value); - // Count elements first so we can write the header. - int elementCount = 1; - int commaIndex; - while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) - { - elementCount++; - traceStateChars = traceStateChars.Slice(commaIndex + 1); - } + int propertyCount = 0; + if (elements.Count > 0) + { + propertyCount++; + } - // For every element, we have a key and value to record. - writer.WriteArrayHeader(elementCount * 2); + if (!finished) + { + propertyCount++; + } - traceStateChars = traceState.AsSpan(); - while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) - { - ReadOnlySpan element = traceStateChars.Slice(0, commaIndex); - WritePair(ref writer, element); - traceStateChars = traceStateChars.Slice(commaIndex + 1); - } + writer.WriteMapHeader(propertyCount); - // Write out the last one. - WritePair(ref writer, traceStateChars); + if (!finished) + { + writer.Write(TokenPropertyName); + writer.Write(token); + } - static void WritePair(ref MessagePackWriter writer, ReadOnlySpan pair) - { - int equalsIndex = pair.IndexOf('='); - ReadOnlySpan key = pair.Slice(0, equalsIndex); - ReadOnlySpan value = pair.Slice(equalsIndex + 1); - writer.Write(key); - writer.Write(value); + if (elements.Count > 0) + { + writer.Write(ValuesPropertyName); + context.GetConverter>(context.TypeShapeProvider).Write(ref writer, elements, context); + } + } } } - private static unsafe string ReadTraceState(ref MessagePackReader reader, SerializationContext context) + /// + /// Converts an instance of to an enumeration token. + /// +#pragma warning disable CA1812 + private class GeneratorConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter + where TClass : IAsyncEnumerable +#pragma warning restore CA1812 { - int elements = reader.ReadArrayHeader(); - if (elements % 2 != 0) + public override TClass Read(ref MessagePackReader reader, SerializationContext context) { - throw new NotSupportedException("Odd number of elements not expected."); + throw new NotSupportedException(); } - // With care, we could probably assemble this string with just two allocations (the string + a char[]). - var resultBuilder = new StringBuilder(); - for (int i = 0; i < elements; i += 2) + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { - if (resultBuilder.Length > 0) - { - resultBuilder.Append(','); - } - - // We assume the key is a frequent string, and the value is unique, - // so we optimize whether to use string interning or not on that basis. - resultBuilder.Append(context.GetConverter(null).Read(ref reader, context)); - resultBuilder.Append('='); - resultBuilder.Append(reader.ReadString()); + context.DepthStep(); + PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); } - return resultBuilder.ToString(); + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(GeneratorConverter)); + } } } - private partial class JsonRpcResultConverter : MessagePackConverter + private class PipeConverterResolver { - private readonly NerdbankMessagePackFormatter formatter; + private readonly NerdbankMessagePackFormatter mainFormatter; - internal JsonRpcResultConverter(NerdbankMessagePackFormatter formatter) + internal PipeConverterResolver(NerdbankMessagePackFormatter formatter) { - this.formatter = formatter; + this.mainFormatter = formatter; } - public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, SerializationContext context) + public MessagePackConverter GetConverter() { - context.DepthStep(); + MessagePackConverter? converter = default; - var result = new JsonRpcResult(this.formatter, this.formatter.userDataProfile) + if (typeof(IDuplexPipe).IsAssignableFrom(typeof(T))) { - OriginalMessagePack = reader.Sequence, - }; + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(PipeReader).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(PipeWriter).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } - Dictionary>? topLevelProperties = null; + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); + } - int propertyCount = reader.ReadMapHeader(); - for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) +#pragma warning disable CA1812 + private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : class, IDuplexPipe +#pragma warning restore CA1812 + { + public override T? Read(ref MessagePackReader reader, SerializationContext context) { - if (VersionPropertyName.TryRead(ref reader)) - { - result.Version = ReadProtocolVersion(ref reader); - } - else if (IdPropertyName.TryRead(ref reader)) + context.DepthStep(); + + if (reader.TryReadNil()) { - result.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + return null; } - else if (ResultPropertyName.TryRead(ref reader)) + + return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + context.DepthStep(); + + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) { - result.MsgPackResult = GetSliceForNextToken(ref reader, context); + writer.Write(token); } else { - ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); + writer.WriteNil(); } } - if (topLevelProperties is not null) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataProfile, topLevelProperties); + return CreateUndocumentedSchema(typeof(DuplexPipeConverter)); } - - return result; } - public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResult? value, SerializationContext context) +#pragma warning disable CA1812 + private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : PipeReader +#pragma warning restore CA1812 { - Requires.NotNull(value!, nameof(value)); - - context.DepthStep(); - - var topLevelPropertyBagMessage = value as IMessageWithTopLevelPropertyBag; - - int mapElementCount = 3; - mapElementCount += (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.PropertyCount ?? 0; - writer.WriteMapHeader(mapElementCount); - - WriteProtocolVersionPropertyAndValue(ref writer, value.Version); - - writer.Write(IdPropertyName); - context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); - - writer.Write(ResultPropertyName); - - if (value.Result is null) + public override T? Read(ref MessagePackReader reader, SerializationContext context) { - writer.WriteNil(); + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); } - if (value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void)) + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { - this.formatter.userDataProfile.SerializeObject(ref writer, value.Result, value.ResultDeclaredType, context.CancellationToken); + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } } - else + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - this.formatter.userDataProfile.SerializeObject(ref writer, value.Result, context.CancellationToken); + return CreateUndocumentedSchema(typeof(PipeReaderConverter)); } - - (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); } - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) +#pragma warning disable CA1812 + private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : PipeWriter +#pragma warning restore CA1812 { - return CreateUndocumentedSchema(typeof(JsonRpcResultConverter)); - } - } + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } - private partial class JsonRpcErrorConverter : MessagePackConverter - { - private readonly NerdbankMessagePackFormatter formatter; + return (T)formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); + } - internal JsonRpcErrorConverter(NerdbankMessagePackFormatter formatter) - { - this.formatter = formatter; + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeWriterConverter)); + } } - public override Protocol.JsonRpcError Read(ref MessagePackReader reader, SerializationContext context) + private class StreamConverter : MessagePackConverter + where T : Stream { - var error = new JsonRpcError(this.formatter.userDataProfile) - { - OriginalMessagePack = reader.Sequence, - }; + private readonly NerdbankMessagePackFormatter formatter; - Dictionary>? topLevelProperties = null; + public StreamConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } - context.DepthStep(); - int propertyCount = reader.ReadMapHeader(); - for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + public override T? Read(ref MessagePackReader reader, SerializationContext context) { - if (VersionPropertyName.TryRead(ref reader)) - { - error.Version = ReadProtocolVersion(ref reader); - } - else if (IdPropertyName.TryRead(ref reader)) + context.DepthStep(); + + if (reader.TryReadNil()) { - error.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + return null; } - else if (ErrorPropertyName.TryRead(ref reader)) + + return (T)this.formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + context.DepthStep(); + + if (this.formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) { - error.Error = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + writer.Write(token); } else { - ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); + writer.WriteNil(); } } - if (topLevelProperties is not null) + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - error.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataProfile, topLevelProperties); + return CreateUndocumentedSchema(typeof(StreamConverter)); } - - return error; } + } - public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError? value, SerializationContext context) + private class RpcMarshalableConverter( + NerdbankMessagePackFormatter formatter, + JsonRpcProxyOptions proxyOptions, + JsonRpcTargetOptions targetOptions, + RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter + where T : class + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] + public override T? Read(ref MessagePackReader reader, SerializationContext context) { - Requires.NotNull(value!, nameof(value)); - - var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; - - context.DepthStep(); - int mapElementCount = 3; - mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; - writer.WriteMapHeader(mapElementCount); + context.DepthStep(); - WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcProfile + .Deserialize( + ref reader, + context.CancellationToken); - writer.Write(IdPropertyName); - context.GetConverter(context.TypeShapeProvider) - .Write(ref writer, value.RequestId, context); + return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; + } - writer.Write(ErrorPropertyName); - context.GetConverter(context.TypeShapeProvider) - .Write(ref writer, value.Error, context); + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + context.DepthStep(); - topLevelPropertyBag?.WriteProperties(ref writer); + if (value is null) + { + writer.WriteNil(); + } + else + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); + formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); + } } public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) { - return CreateUndocumentedSchema(typeof(JsonRpcErrorConverter)); + return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); } } - private partial class JsonRpcErrorDetailConverter : MessagePackConverter + /// + /// Manages serialization of any -derived type that follows standard rules. + /// + /// + /// A serializable class will: + /// 1. Derive from + /// 2. Be attributed with + /// 3. Declare a constructor with a signature of (, ). + /// + private class MessagePackExceptionConverterResolver { - private static readonly MessagePackString CodePropertyName = new("code"); - private static readonly MessagePackString MessagePropertyName = new("message"); - private static readonly MessagePackString DataPropertyName = new("data"); + /// + /// Tracks recursion count while serializing or deserializing an exception. + /// + /// + /// This is placed here (outside the generic class) + /// so that it's one counter shared across all exception types that may be serialized or deserialized. + /// + private static ThreadLocal exceptionRecursionCounter = new(); - private readonly NerdbankMessagePackFormatter formatter; + private readonly object[] formatterActivationArgs; - internal JsonRpcErrorDetailConverter(NerdbankMessagePackFormatter formatter) + internal MessagePackExceptionConverterResolver(NerdbankMessagePackFormatter formatter) { - this.formatter = formatter; + this.formatterActivationArgs = new object[] { formatter }; } - public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader reader, SerializationContext context) + public MessagePackConverter GetConverter() { - context.DepthStep(); + MessagePackConverter? formatter = null; + if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is object) + { + formatter = (MessagePackConverter)Activator.CreateInstance(typeof(ExceptionConverter<>).MakeGenericType(typeof(T)), this.formatterActivationArgs)!; + } - var result = new JsonRpcError.ErrorDetail(this.formatter.userDataProfile); + // TODO: Improve Exception + return formatter ?? throw new NotSupportedException(); + } - int propertyCount = reader.ReadMapHeader(); - for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) +#pragma warning disable CA1812 + private partial class ExceptionConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : Exception +#pragma warning restore CA1812 + { + public override T? Read(ref MessagePackReader reader, SerializationContext context) { - if (CodePropertyName.TryRead(ref reader)) - { - result.Code = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); - } - else if (MessagePropertyName.TryRead(ref reader)) + Assumes.NotNull(formatter.JsonRpc); + + context.DepthStep(); + + if (reader.TryReadNil()) { - result.Message = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + return null; } - else if (DataPropertyName.TryRead(ref reader)) + + // We have to guard our own recursion because the serializer has no visibility into inner exceptions. + // Each exception in the russian doll is a new serialization job from its perspective. + exceptionRecursionCounter.Value++; + try { - result.MsgPackData = GetSliceForNextToken(ref reader, context); + if (exceptionRecursionCounter.Value > formatter.JsonRpc.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception. + // Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded. + reader.Skip(context); + return null; + } + + // TODO: Is this the right context? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); + int memberCount = reader.ReadMapHeader(); + for (int i = 0; i < memberCount; i++) + { + string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) + ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + + // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, + // so the caller will get a boxed RawMessagePack struct in that case. + // Although we can't do much about *that* in general, we can at least ensure that null values + // are represented as null instead of this boxed struct. + var value = reader.TryReadNil() ? null : (object)reader.ReadRaw(context); + + info.AddSafeValue(name, value); + } + + return ExceptionSerializationHelpers.Deserialize(formatter.JsonRpc, info, formatter.JsonRpc.TraceSource); } - else + finally { - reader.Skip(context); + exceptionRecursionCounter.Value--; } } - return result; - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] - public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) - { - Requires.NotNull(value!, nameof(value)); - - context.DepthStep(); - - writer.WriteMapHeader(3); + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + context.DepthStep(); - writer.Write(CodePropertyName); - context.GetConverter(context.TypeShapeProvider) - .Write(ref writer, value.Code, context); + if (value is null) + { + writer.WriteNil(); + return; + } - writer.Write(MessagePropertyName); - writer.Write(value.Message); + exceptionRecursionCounter.Value++; + try + { + if (exceptionRecursionCounter.Value > formatter.JsonRpc?.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception. + writer.WriteNil(); + return; + } - writer.Write(DataPropertyName); - this.formatter.userDataProfile.Serialize(ref writer, value.Data, context.CancellationToken); - } + // TODO: Is this the right context? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); + ExceptionSerializationHelpers.Serialize(value, info); + writer.WriteMapHeader(info.GetSafeMemberCount()); + foreach (SerializationEntry element in info.GetSafeMembers()) + { + writer.Write(element.Name); + formatter.rpcProfile.SerializeObject( + ref writer, + element.Value, + element.ObjectType, + context.CancellationToken); + } + } + finally + { + exceptionRecursionCounter.Value--; + } + } - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(JsonRpcErrorDetailConverter)); + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ExceptionConverter)); + } } } @@ -1705,74 +1864,6 @@ public override EventArgs Read(ref MessagePackReader reader, SerializationContex } } - private class TraceParentConverter : MessagePackConverter - { - public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); - - if (reader.ReadArrayHeader() != 2) - { - throw new NotSupportedException("Unexpected array length."); - } - - var result = default(TraceParent); - result.Version = reader.ReadByte(); - if (result.Version != 0) - { - throw new NotSupportedException("traceparent version " + result.Version + " is not supported."); - } - - if (reader.ReadArrayHeader() != 3) - { - throw new NotSupportedException("Unexpected array length in version-format."); - } - - ReadOnlySequence bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected traceid not found."); - bytes.CopyTo(new Span(result.TraceId, TraceParent.TraceIdByteCount)); - - bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected parentid not found."); - bytes.CopyTo(new Span(result.ParentId, TraceParent.ParentIdByteCount)); - - result.Flags = (TraceParent.TraceFlags)reader.ReadByte(); - - return result; - } - - public unsafe override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) - { - if (value.Version != 0) - { - throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); - } - - context.DepthStep(); - - writer.WriteArrayHeader(2); - - writer.Write(value.Version); - - writer.WriteArrayHeader(3); - - fixed (byte* traceId = value.TraceId) - { - writer.Write(new ReadOnlySpan(traceId, TraceParent.TraceIdByteCount)); - } - - fixed (byte* parentId = value.ParentId) - { - writer.Write(new ReadOnlySpan(parentId, TraceParent.ParentIdByteCount)); - } - - writer.Write((byte)value.Flags); - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(TraceParentConverter)); - } - } - private class TopLevelPropertyBag : TopLevelPropertyBagBase { private readonly FormatterProfile formatterProfile; @@ -1976,14 +2067,14 @@ protected override void ReleaseBuffers() private partial class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention { private readonly NerdbankMessagePackFormatter formatter; - private readonly FormatterProfile formatterContext; + private readonly FormatterProfile formatterProfile; private Exception? resultDeserializationException; - internal JsonRpcResult(NerdbankMessagePackFormatter formatter, FormatterProfile serializationOptions) + internal JsonRpcResult(NerdbankMessagePackFormatter formatter, FormatterProfile formatterProfile) { this.formatter = formatter; - this.formatterContext = serializationOptions; + this.formatterProfile = formatterProfile; } public ReadOnlySequence OriginalMessagePack { get; internal set; } @@ -1999,8 +2090,8 @@ public override T GetResult() return this.MsgPackResult.IsEmpty ? (T)this.Result! - : this.formatterContext.Deserialize(this.MsgPackResult) - ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + : this.formatterProfile.Deserialize(this.MsgPackResult) + ?? throw new MessagePackSerializationException("Failed to deserialize result."); } protected internal override void SetExpectedResultType(Type resultType) @@ -2011,7 +2102,7 @@ protected internal override void SetExpectedResultType(Type resultType) { using (this.formatter.TrackDeserialization(this)) { - this.Result = this.formatterContext.DeserializeObject(this.MsgPackResult, resultType); + this.Result = this.formatterProfile.DeserializeObject(this.MsgPackResult, resultType); } this.MsgPackResult = default; @@ -2030,7 +2121,7 @@ protected override void ReleaseBuffers() this.OriginalMessagePack = default; } - protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatterContext); + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatterProfile); } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs index 2ce7c0a8c..32ee75adc 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs @@ -3,7 +3,6 @@ using System.Buffers; using Nerdbank.MessagePack; -using PolyType.Abstractions; using static StreamJsonRpc.NerdbankMessagePackFormatter; namespace StreamJsonRpc; @@ -22,6 +21,20 @@ public static class NerdbankMessagePackFormatterProfileExtensions /// The object to serialize. /// A token to monitor for cancellation requests. public static void SerializeObject(this FormatterProfile profile, ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) + { + Requires.NotNull(profile, nameof(profile)); + SerializeObject(profile, ref writer, value, value?.GetType() ?? typeof(object), cancellationToken); + } + + /// + /// Serializes an object using the specified . + /// + /// The formatter profile to use for serialization. + /// The writer to which the object will be serialized. + /// The object to serialize. + /// A token to monitor for cancellation requests. + /// The type of the object to serialize. + public static void Serialize(this FormatterProfile profile, ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) { Requires.NotNull(profile, nameof(profile)); @@ -31,7 +44,11 @@ public static void SerializeObject(this FormatterProfile profile, ref MessagePac return; } - profile.Serializer.SerializeObject(ref writer, value, profile.ShapeProvider.Resolve(value.GetType()), cancellationToken); + profile.Serializer.Serialize( + ref writer, + value, + profile.ShapeProviderResolver.ResolveShape(), + cancellationToken); } /// @@ -43,42 +60,35 @@ public static void SerializeObject(this FormatterProfile profile, ref MessagePac /// A token to monitor for cancellation requests. /// The deserialized object of type . /// Thrown when deserialization fails. - public static T Deserialize(this FormatterProfile profile, in ReadOnlySequence pack, CancellationToken cancellationToken = default) + public static T? Deserialize(this FormatterProfile profile, in ReadOnlySequence pack, CancellationToken cancellationToken = default) { Requires.NotNull(profile, nameof(profile)); - - // TODO: Improve the exception - return profile.Serializer.Deserialize(pack, profile.ShapeProvider, cancellationToken) - ?? throw new MessagePackSerializationException(Resources.FailureDeserializingRpcResult); + MessagePackReader reader = new(pack); + return Deserialize(profile, ref reader, cancellationToken); } internal static T? Deserialize(this FormatterProfile profile, ref MessagePackReader reader, CancellationToken cancellationToken = default) { - return profile.Serializer.Deserialize(ref reader, profile.ShapeProvider, cancellationToken); + return profile.Serializer.Deserialize( + ref reader, + profile.ShapeProviderResolver.ResolveShapeProvider(), + cancellationToken); } internal static object? DeserializeObject(this FormatterProfile profile, in ReadOnlySequence pack, Type objectType, CancellationToken cancellationToken = default) { MessagePackReader reader = new(pack); - return profile.Serializer.DeserializeObject( - ref reader, - profile.ShapeProvider.Resolve(objectType), - cancellationToken); + return DeserializeObject(profile, ref reader, objectType, cancellationToken); } internal static object? DeserializeObject(this FormatterProfile profile, ref MessagePackReader reader, Type objectType, CancellationToken cancellationToken = default) { return profile.Serializer.DeserializeObject( ref reader, - profile.ShapeProvider.Resolve(objectType), + profile.ShapeProviderResolver.ResolveShape(objectType), cancellationToken); } - internal static void Serialize(this FormatterProfile profile, ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) - { - profile.Serializer.Serialize(ref writer, value, profile.ShapeProvider, cancellationToken); - } - internal static void SerializeObject(this FormatterProfile profile, ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) { if (value is null) @@ -87,6 +97,10 @@ internal static void SerializeObject(this FormatterProfile profile, ref MessageP return; } - profile.Serializer.SerializeObject(ref writer, value, profile.ShapeProvider.Resolve(objectType), cancellationToken); + profile.Serializer.SerializeObject( + ref writer, + value, + profile.ShapeProviderResolver.ResolveShape(objectType), + cancellationToken); } } diff --git a/src/StreamJsonRpc/Protocol/JsonRpcError.cs b/src/StreamJsonRpc/Protocol/JsonRpcError.cs index 244beea36..a088772f3 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcError.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcError.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Runtime.Serialization; +using Nerdbank.MessagePack; using PolyType; using StreamJsonRpc.Reflection; using JsonNET = Newtonsoft.Json.Linq; @@ -15,6 +16,7 @@ namespace StreamJsonRpc.Protocol; /// [DataContract] [GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcErrorConverter))] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + "}")] public partial class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId { @@ -72,6 +74,7 @@ public override string ToString() /// [DataContract] [GenerateShape] + [MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcErrorDetailConverter))] public partial class ErrorDetail { /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs index 0b09d0769..c69e34c1f 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs @@ -25,6 +25,7 @@ namespace StreamJsonRpc.Protocol; [KnownSubType(2)] [KnownSubType(3)] #endif +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcMessageConverter))] public abstract partial class JsonRpcMessage { /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs index 33f5a96cd..37fe79216 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Reflection; using System.Runtime.Serialization; +using Nerdbank.MessagePack; using PolyType; using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; @@ -15,6 +16,7 @@ namespace StreamJsonRpc.Protocol; /// [DataContract] [GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcRequestConverter))] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] public partial class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId { diff --git a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs index 0019e660e..d9860a99f 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Runtime.Serialization; +using Nerdbank.MessagePack; using PolyType; using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; @@ -13,8 +14,9 @@ namespace StreamJsonRpc.Protocol; /// Describes the result of a successful method invocation. /// [DataContract] -[DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] [GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcResultConverter))] +[DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId { /// diff --git a/src/StreamJsonRpc/Protocol/TraceParent.cs b/src/StreamJsonRpc/Protocol/TraceParent.cs index b4fe393e2..de125b958 100644 --- a/src/StreamJsonRpc/Protocol/TraceParent.cs +++ b/src/StreamJsonRpc/Protocol/TraceParent.cs @@ -2,9 +2,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Diagnostics; +using Nerdbank.MessagePack; namespace StreamJsonRpc.Protocol; +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.TraceParentConverter))] internal unsafe struct TraceParent { internal const int VersionByteCount = 1; diff --git a/src/StreamJsonRpc/SerializationContextExtensions.cs b/src/StreamJsonRpc/SerializationContextExtensions.cs new file mode 100644 index 000000000..b4f5ab84a --- /dev/null +++ b/src/StreamJsonRpc/SerializationContextExtensions.cs @@ -0,0 +1,16 @@ +using Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +internal static class SerializationContextExtensions +{ + private static readonly object FormatterStateKey = new(); + + internal static object FormatterKey => FormatterStateKey; + + internal static NerdbankMessagePackFormatter GetFormatter(this ref SerializationContext context) + { + return (context[FormatterKey] as NerdbankMessagePackFormatter) + ?? throw new InvalidOperationException(); + } +} diff --git a/test/StreamJsonRpc.Tests/DisposableProxyTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyTests.cs index 8be43c55b..b391b2410 100644 --- a/test/StreamJsonRpc.Tests/DisposableProxyTests.cs +++ b/test/StreamJsonRpc.Tests/DisposableProxyTests.cs @@ -77,7 +77,7 @@ public async Task IDisposableInNotificationArgumentIsRejected() Assert.True(IsExceptionOrInnerOfType(ex)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task DisposableReturnValue_DisposeSwallowsSecondCall() { IDisposable? proxyDisposable = await this.client.GetDisposableAsync(); @@ -86,7 +86,7 @@ public async Task DisposableReturnValue_DisposeSwallowsSecondCall() proxyDisposable.Dispose(); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task DisposableReturnValue_IsMarshaledAndLaterCollected() { var weakRefs = await this.DisposableReturnValue_Helper(); diff --git a/test/StreamJsonRpc.Tests/FormatterTestBase.cs b/test/StreamJsonRpc.Tests/FormatterTestBase.cs index 92803a257..9da28e887 100644 --- a/test/StreamJsonRpc.Tests/FormatterTestBase.cs +++ b/test/StreamJsonRpc.Tests/FormatterTestBase.cs @@ -1,10 +1,5 @@ using System.Runtime.Serialization; -using MessagePack.Formatters; using Nerdbank.Streams; -using StreamJsonRpc; -using StreamJsonRpc.Protocol; -using Xunit; -using Xunit.Abstractions; public abstract class FormatterTestBase : TestBase where TFormatter : IJsonRpcMessageFormatter diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 294154beb..50ee951a3 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -392,11 +392,14 @@ protected override void InitializeFormattersAndHandlers( out IJsonRpcMessageHandler clientMessageHandler, bool controlledFlushingClient) { - serverMessageFormatter = new NerdbankMessagePackFormatter(); - clientMessageFormatter = new NerdbankMessagePackFormatter(); + NerdbankMessagePackFormatter serverFormatter = new(); + serverFormatter.SetFormatterProfile(Configure); - ((NerdbankMessagePackFormatter)serverMessageFormatter).SetFormatterProfile(Configure); - ((NerdbankMessagePackFormatter)clientMessageFormatter).SetFormatterProfile(Configure); + NerdbankMessagePackFormatter clientFormatter = new(); + clientFormatter.SetFormatterProfile(Configure); + + serverMessageFormatter = serverFormatter; + clientMessageFormatter = clientFormatter; serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter); clientMessageHandler = controlledFlushingClient @@ -405,6 +408,7 @@ protected override void InitializeFormattersAndHandlers( static NerdbankMessagePackFormatter.FormatterProfile Configure(NerdbankMessagePackFormatter.FormatterProfileBuilder b) { + b.RegisterAsyncEnumerableType, UnionBaseClass>(); b.RegisterConverter(new UnserializableTypeConverter()); b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); b.RegisterConverter(new CustomExtensionConverter()); @@ -418,9 +422,11 @@ static NerdbankMessagePackFormatter.FormatterProfile Configure(NerdbankMessagePa protected override object[] CreateFormatterIntrinsicParamsObject(string arg) => []; [GenerateShape] -#pragma warning disable CS0618 +#if NET + [KnownSubType] +#else [KnownSubType(typeof(UnionDerivedClass))] -#pragma warning restore CS0618 +#endif public abstract partial class UnionBaseClass { } diff --git a/test/StreamJsonRpc.Tests/JsonRpcTests.cs b/test/StreamJsonRpc.Tests/JsonRpcTests.cs index ced85ca53..f5117dc7d 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcTests.cs @@ -393,7 +393,7 @@ public async Task CanCallAsyncMethod() Assert.Equal("test!", result); } - [Theory, PairwiseData] + [Theory(Timeout = 2 * 1000), PairwiseData] // TODO: Temporary for development public async Task CanCallAsyncMethodThatThrows(ExceptionProcessing exceptionStrategy) { this.clientRpc.AllowModificationWhileListening = true; @@ -433,7 +433,7 @@ public async Task CanCallAsyncMethodThatThrowsNonSerializableException(Exception } } - [Theory, PairwiseData] + [Theory(Timeout = 2 * 1000), PairwiseData] // TODO: Temporary for development public async Task CanCallAsyncMethodThatThrowsExceptionWithoutDeserializingConstructor(ExceptionProcessing exceptionStrategy) { this.clientRpc.AllowModificationWhileListening = true; @@ -458,7 +458,7 @@ public async Task CanCallAsyncMethodThatThrowsExceptionWithoutDeserializingConst } } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task CanCallAsyncMethodThatThrowsExceptionWhileSerializingException() { this.clientRpc.AllowModificationWhileListening = true; @@ -472,7 +472,7 @@ public async Task CanCallAsyncMethodThatThrowsExceptionWhileSerializingException Assert.Null(exception.InnerException); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ThrowCustomExceptionThatImplementsISerializableProperly() { this.clientRpc.AllowModificationWhileListening = true; @@ -2249,7 +2249,7 @@ public async Task CanPassExceptionFromServer_DeserializedErrorData() Assert.StrictEqual(COR_E_UNAUTHORIZEDACCESS, errorData.HResult); } - [Theory, PairwiseData] + [Theory(Timeout = 2 * 1000), PairwiseData] // TODO: Temporary for development public async Task ExceptionTreeThrownFromServerIsDeserializedAtClient(ExceptionProcessing exceptionStrategy) { this.clientRpc.AllowModificationWhileListening = true; @@ -2478,7 +2478,7 @@ public async Task ExceptionRecursionLimit_ArgumentDeserialization() Assert.Equal(this.serverRpc.ExceptionOptions.RecursionLimit, CountRecursionLevel(this.server.ReceivedException)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ExceptionRecursionLimit_ThrownSerialization() { this.serverRpc.AllowModificationWhileListening = true; @@ -2494,7 +2494,7 @@ public async Task ExceptionRecursionLimit_ThrownSerialization() Assert.Equal(this.clientRpc.ExceptionOptions.RecursionLimit, actualRecursionLevel); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ExceptionRecursionLimit_ThrownDeserialization() { this.clientRpc.AllowModificationWhileListening = true; diff --git a/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs index 39f633486..fe15c18b5 100644 --- a/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs @@ -5,10 +5,6 @@ using MessagePack.Resolvers; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; -using StreamJsonRpc; -using StreamJsonRpc.Protocol; -using Xunit; -using Xunit.Abstractions; public class MessagePackFormatterTests : FormatterTestBase { diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs index 3b4832865..e9f364601 100644 --- a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -262,13 +262,6 @@ public void Resolver_Result() [Fact] public void Resolver_ErrorData() { - this.Formatter.SetFormatterProfile(b => - { - b.RegisterConverter(new CustomConverter()); - b.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); - return b.Build(); - }); - var originalErrorData = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; var originalError = new JsonRpcError { @@ -369,6 +362,13 @@ public void StringValuesOfStandardPropertiesAreInterned() protected override NerdbankMessagePackFormatter CreateFormatter() { NerdbankMessagePackFormatter formatter = new(); + formatter.SetFormatterProfile(b => + { + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); + return b.Build(); + }); + return formatter; } @@ -376,15 +376,17 @@ private T Read(object anonymousObject) where T : JsonRpcMessage { NerdbankMessagePackFormatter.FormatterProfileBuilder profileBuilder = this.Formatter.ProfileBuilder; + profileBuilder.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); profileBuilder.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); - profileBuilder.RegisterConverter(new CustomConverter()); NerdbankMessagePackFormatter.FormatterProfile profile = profileBuilder.Build(); + this.Formatter.SetFormatterProfile(profile); + var sequence = new Sequence(); var writer = new MessagePackWriter(sequence); profile.SerializeObject(ref writer, anonymousObject); writer.Flush(); - return profile.Deserialize(sequence); + return (T)this.Formatter.Deserialize(sequence.AsReadOnlySequence); } [DataContract] @@ -424,6 +426,7 @@ public partial class NonDataContractWithExcludedMembers internal string? InternalProperty { get; set; } } + [MessagePackConverter(typeof(CustomConverter))] [GenerateShape] public partial class TypeRequiringCustomFormatter { From 6444f5560ed8d72632e114ff5e71d3071ae4bbe0 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Fri, 27 Dec 2024 19:33:58 +0000 Subject: [PATCH 18/25] Remove state from converters. --- ...nkMessagePackFormatter.FormatterProfile.cs | 11 +- ...gePackFormatter.FormatterProfileBuilder.cs | 21 +- .../NerdbankMessagePackFormatter.cs | 412 ++++++++---------- .../JsonRpcNerdbankMessagePackLengthTests.cs | 3 + test/StreamJsonRpc.Tests/TestBase.cs | 6 +- 5 files changed, 197 insertions(+), 256 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs index b281dfad6..de1499530 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs @@ -61,6 +61,15 @@ internal enum ProfileSource private int ProvidersCount => shapeProviders.Length; + internal FormatterProfile WithFormatterState(NerdbankMessagePackFormatter formatter) + { + SerializationContext nextContext = serializer.StartingContext; + nextContext[SerializationContextExtensions.FormatterKey] = formatter; + MessagePackSerializer nextSerializer = serializer with { StartingContext = nextContext }; + + return new(this.Source, nextSerializer, shapeProviders); + } + private string GetDebuggerDisplay() => $"{this.Source} [{this.ProvidersCount}]"; /// @@ -68,7 +77,7 @@ internal enum ProfileSource /// from that provider and try to cache it. If the resolved type shape is not sourced from the /// passed provider, it will throw an ArgumentException with the message: /// System.ArgumentException : The specified shape provider is not valid for this cache. - /// To avoid this, the resolver does not implement ITypeShapeProvider directly so that it cannot + /// To avoid this, this class does not implement ITypeShapeProvider directly so that it cannot /// be passed to the serializer. Instead, use to get the /// provider that will resolve the shape for the specified type, or if the serialization method supports /// if use to get the shape directly. diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs index a57f8be13..03adb3cc0 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs @@ -5,7 +5,6 @@ using System.IO.Pipelines; using Nerdbank.MessagePack; using PolyType; -using PolyType.Abstractions; using StreamJsonRpc.Reflection; namespace StreamJsonRpc; @@ -20,7 +19,6 @@ public sealed partial class NerdbankMessagePackFormatter /// public class FormatterProfileBuilder { - private readonly NerdbankMessagePackFormatter formatter; private readonly FormatterProfile baseProfile; private ImmutableArray.Builder? typeShapeProvidersBuilder = null; @@ -28,11 +26,9 @@ public class FormatterProfileBuilder /// /// Initializes a new instance of the class. /// - /// The formatter to use. /// The base profile to build upon. - internal FormatterProfileBuilder(NerdbankMessagePackFormatter formatter, FormatterProfile baseProfile) + internal FormatterProfileBuilder(FormatterProfile baseProfile) { - this.formatter = formatter; this.baseProfile = baseProfile; } @@ -54,7 +50,7 @@ public void AddTypeShapeProvider(ITypeShapeProvider provider) public void RegisterAsyncEnumerableType() where TEnumerable : IAsyncEnumerable { - MessagePackConverter converter = this.formatter.asyncEnumerableConverterResolver.GetConverter(); + MessagePackConverter converter = AsyncEnumerableConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } @@ -86,7 +82,7 @@ public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) public void RegisterProgressType() where TProgress : IProgress { - MessagePackConverter converter = this.formatter.progressConverterResolver.GetConverter(); + MessagePackConverter converter = ProgressConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } @@ -97,7 +93,7 @@ public void RegisterProgressType() public void RegisterDuplexPipeType() where TPipe : IDuplexPipe { - MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + MessagePackConverter converter = PipeConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } @@ -108,7 +104,7 @@ public void RegisterDuplexPipeType() public void RegisterPipeReaderType() where TReader : PipeReader { - MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + MessagePackConverter converter = PipeConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } @@ -119,7 +115,7 @@ public void RegisterPipeReaderType() public void RegisterPipeWriterType() where TWriter : PipeWriter { - MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + MessagePackConverter converter = PipeConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } @@ -130,7 +126,7 @@ public void RegisterPipeWriterType() public void RegisterStreamType() where TStream : Stream { - MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + MessagePackConverter converter = PipeConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } @@ -141,7 +137,7 @@ public void RegisterStreamType() public void RegisterExceptionType() where TException : Exception { - MessagePackConverter converter = this.formatter.exceptionResolver.GetConverter(); + MessagePackConverter converter = MessagePackExceptionConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } @@ -160,7 +156,6 @@ public void RegisterRpcMarshalableType() { var converter = (RpcMarshalableConverter)Activator.CreateInstance( typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), - this.formatter, proxyOptions, targetOptions, attribute)!; diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 7803f954e..d8ff95bef 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -45,14 +45,6 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa /// private readonly FormatterProfile rpcProfile; - private readonly ProgressConverterResolver progressConverterResolver; - - private readonly AsyncEnumerableConverterResolver asyncEnumerableConverterResolver; - - private readonly PipeConverterResolver pipeConverterResolver; - - private readonly MessagePackExceptionConverterResolver exceptionResolver; - private readonly ToStringHelper serializationToStringHelper = new(); private readonly ToStringHelper deserializationToStringHelper = new(); @@ -71,7 +63,7 @@ public NerdbankMessagePackFormatter() MessagePackSerializer serializer = new() { InternStrings = true, - SerializeDefaultValues = true, + SerializeDefaultValues = false, StartingContext = new SerializationContext() { [SerializationContextExtensions.FormatterKey] = this, @@ -86,17 +78,11 @@ public NerdbankMessagePackFormatter() serializer, [ShapeProvider_StreamJsonRpc.Default]); - // Create the specialized formatters/resolvers that we will inject into the chain for user data. - this.progressConverterResolver = new ProgressConverterResolver(this); - this.asyncEnumerableConverterResolver = new AsyncEnumerableConverterResolver(this); - this.pipeConverterResolver = new PipeConverterResolver(this); - this.exceptionResolver = new MessagePackExceptionConverterResolver(this); - // Create a serializer for user data. MessagePackSerializer userSerializer = new() { InternStrings = true, - SerializeDefaultValues = true, + SerializeDefaultValues = false, StartingContext = new SerializationContext() { [SerializationContextExtensions.FormatterKey] = this, @@ -108,7 +94,7 @@ public NerdbankMessagePackFormatter() userSerializer.RegisterConverter(RequestIdConverter.Instance); // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. - // userSerializer.RegisterConverter(RawMessagePackFormatter.Instance); + userSerializer.RegisterConverter(new TraceParentConverter()); userSerializer.RegisterConverter(EventArgsConverter.Instance); this.userDataProfile = new FormatterProfile( @@ -116,7 +102,7 @@ public NerdbankMessagePackFormatter() userSerializer, [ReflectionTypeShapeProvider.Default]); - this.ProfileBuilder = new FormatterProfileBuilder(this, this.userDataProfile); + this.ProfileBuilder = new FormatterProfileBuilder(this.userDataProfile); } private interface IJsonRpcMessagePackRetention @@ -142,7 +128,7 @@ private interface IJsonRpcMessagePackRetention public void SetFormatterProfile(FormatterProfile profile) { Requires.NotNull(profile, nameof(profile)); - this.userDataProfile = profile; + this.userDataProfile = profile.WithFormatterState(this); } /// @@ -153,8 +139,7 @@ public void SetFormatterProfile(Func { Requires.NotNull(configure, nameof(configure)); - var builder = new FormatterProfileBuilder(this, this.userDataProfile); - + var builder = new FormatterProfileBuilder(this.userDataProfile); this.SetFormatterProfile(configure(builder)); } @@ -348,10 +333,6 @@ bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) private static ReadOnlySequence GetSliceForNextToken(ref MessagePackReader reader, in SerializationContext context) { return reader.ReadRaw(context); - ////SequencePosition startingPosition = reader.Position; - ////reader.Skip(context); - ////SequencePosition endingPosition = reader.Position; - ////return reader.Sequence.Slice(startingPosition, endingPosition); } /// @@ -369,7 +350,7 @@ private static string ReadProtocolVersion(ref MessagePackReader reader) else { // TODO: Should throw? - return reader.ReadString() ?? string.Empty; + return reader.ReadString() ?? throw new MessagePackSerializationException(Resources.RequiredArgumentMissing); } } @@ -1126,94 +1107,19 @@ public unsafe override void Write(ref MessagePackWriter writer, in TraceParent v } } - private class RequestIdConverter : MessagePackConverter + private static class ProgressConverterResolver { - internal static readonly RequestIdConverter Instance = new(); - - private RequestIdConverter() - { - } - - public override RequestId Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); - - if (reader.NextMessagePackType == MessagePackType.Integer) - { - return new RequestId(reader.ReadInt64()); - } - else - { - // Do *not* read as an interned string here because this ID should be unique. - return new RequestId(reader.ReadString()); - } - } - - public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) - { - context.DepthStep(); - - if (value.Number.HasValue) - { - writer.Write(value.Number.Value); - } - else - { - writer.Write(value.String); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" - { - "type": ["string", { "type": "integer", "format": "int64" }] - } - """)?.AsObject(); - } - - private class RawMessagePackFormatter : MessagePackConverter - { - internal static readonly RawMessagePackFormatter Instance = new(); - - private RawMessagePackFormatter() - { - } - - public override RawMessagePack Read(ref MessagePackReader reader, SerializationContext context) - { - return new RawMessagePack(reader.ReadRaw(context)); - } - - public override void Write(ref MessagePackWriter writer, in RawMessagePack value, SerializationContext context) - { - writer.WriteRaw(value); - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(RawMessagePackFormatter)); - } - } - - private class ProgressConverterResolver - { - private readonly NerdbankMessagePackFormatter mainFormatter; - - internal ProgressConverterResolver(NerdbankMessagePackFormatter formatter) - { - this.mainFormatter = formatter; - } - - public MessagePackConverter GetConverter() + public static MessagePackConverter GetConverter() { MessagePackConverter? converter = default; if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) { - converter = new FullProgressConverter(this.mainFormatter); + converter = new FullProgressConverter(); } else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) { - converter = new ProgressClientConverter(this.mainFormatter); + converter = new ProgressClientConverter(); } // TODO: Improve Exception @@ -1225,13 +1131,6 @@ public MessagePackConverter GetConverter() /// private class ProgressClientConverter : MessagePackConverter { - private readonly NerdbankMessagePackFormatter formatter; - - internal ProgressClientConverter(NerdbankMessagePackFormatter formatter) - { - this.formatter = formatter; - } - public override TClass Read(ref MessagePackReader reader, SerializationContext context) { throw new NotSupportedException("This formatter only serializes IProgress instances."); @@ -1239,6 +1138,8 @@ public override TClass Read(ref MessagePackReader reader, SerializationContext c public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (value is null) @@ -1247,7 +1148,7 @@ public override void Write(ref MessagePackWriter writer, in TClass? value, Seria } else { - long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); + long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); writer.Write(progressId); } } @@ -1263,16 +1164,11 @@ public override void Write(ref MessagePackWriter writer, in TClass? value, Seria /// private class FullProgressConverter : MessagePackConverter { - private readonly NerdbankMessagePackFormatter formatter; - - internal FullProgressConverter(NerdbankMessagePackFormatter formatter) - { - this.formatter = formatter; - } - [return: MaybeNull] public override TClass? Read(ref MessagePackReader reader, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (reader.TryReadNil()) @@ -1280,14 +1176,16 @@ internal FullProgressConverter(NerdbankMessagePackFormatter formatter) return default!; } - Assumes.NotNull(this.formatter.JsonRpc); - RawMessagePack token = (RawMessagePack)reader.ReadRaw(context); - bool clientRequiresNamedArgs = this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; - return (TClass)this.formatter.FormatterProgressTracker.CreateProgress(this.formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); + Assumes.NotNull(formatter.JsonRpc); + ReadOnlySequence token = reader.ReadRaw(context); + bool clientRequiresNamedArgs = formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; + return (TClass)formatter.FormatterProgressTracker.CreateProgress(formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); } public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (value is null) @@ -1296,7 +1194,7 @@ public override void Write(ref MessagePackWriter writer, in TClass? value, Seria } else { - long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); + long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); writer.Write(progressId); } } @@ -1308,30 +1206,21 @@ public override void Write(ref MessagePackWriter writer, in TClass? value, Seria } } - private class AsyncEnumerableConverterResolver + private static class AsyncEnumerableConverterResolver { - private readonly NerdbankMessagePackFormatter mainFormatter; - - internal AsyncEnumerableConverterResolver(NerdbankMessagePackFormatter formatter) - { - this.mainFormatter = formatter; - } - - public MessagePackConverter GetConverter() + public static MessagePackConverter GetConverter() { MessagePackConverter? converter = default; if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) { converter = (MessagePackConverter?)Activator.CreateInstance( - typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), - [this.mainFormatter]); + typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0])); } else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) { converter = (MessagePackConverter?)Activator.CreateInstance( - typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), - [this.mainFormatter]); + typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0])); } // TODO: Improve Exception @@ -1343,7 +1232,7 @@ public MessagePackConverter GetConverter() /// or an into an enumeration token. /// #pragma warning disable CA1812 - private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter> + private partial class PreciseTypeConverter : MessagePackConverter> #pragma warning restore CA1812 { /// @@ -1363,6 +1252,8 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF return default; } + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + context.DepthStep(); RawMessagePack? token = default; @@ -1391,6 +1282,8 @@ private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainF public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) { context.DepthStep(); + + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); Serialize_Shared(mainFormatter, ref writer, value, context); } @@ -1442,7 +1335,7 @@ internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter /// Converts an instance of to an enumeration token. /// #pragma warning disable CA1812 - private class GeneratorConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter + private class GeneratorConverter : MessagePackConverter where TClass : IAsyncEnumerable #pragma warning restore CA1812 { @@ -1454,6 +1347,8 @@ public override TClass Read(ref MessagePackReader reader, SerializationContext c [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) { + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + context.DepthStep(); PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); } @@ -1465,47 +1360,40 @@ public override void Write(ref MessagePackWriter writer, in TClass? value, Seria } } - private class PipeConverterResolver + private static class PipeConverterResolver { - private readonly NerdbankMessagePackFormatter mainFormatter; - - internal PipeConverterResolver(NerdbankMessagePackFormatter formatter) - { - this.mainFormatter = formatter; - } - - public MessagePackConverter GetConverter() + public static MessagePackConverter GetConverter() { MessagePackConverter? converter = default; if (typeof(IDuplexPipe).IsAssignableFrom(typeof(T))) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)))!; } else if (typeof(PipeReader).IsAssignableFrom(typeof(T))) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)))!; } else if (typeof(PipeWriter).IsAssignableFrom(typeof(T))) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)))!; } else if (typeof(Stream).IsAssignableFrom(typeof(T))) { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)))!; } // TODO: Improve Exception return converter ?? throw new NotSupportedException(); } -#pragma warning disable CA1812 - private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + private class DuplexPipeConverter : MessagePackConverter where T : class, IDuplexPipe -#pragma warning restore CA1812 { public override T? Read(ref MessagePackReader reader, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (reader.TryReadNil()) @@ -1518,9 +1406,11 @@ private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : M public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); - if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) { writer.Write(token); } @@ -1536,13 +1426,13 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } } -#pragma warning disable CA1812 - private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + private class PipeReaderConverter : MessagePackConverter where T : PipeReader -#pragma warning restore CA1812 { public override T? Read(ref MessagePackReader reader, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (reader.TryReadNil()) { @@ -1554,6 +1444,8 @@ private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : M public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) { @@ -1571,13 +1463,13 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } } -#pragma warning disable CA1812 - private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + private class PipeWriterConverter : MessagePackConverter where T : PipeWriter -#pragma warning restore CA1812 { public override T? Read(ref MessagePackReader reader, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (reader.TryReadNil()) { @@ -1589,8 +1481,10 @@ private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : M public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); - if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) { writer.Write(token); } @@ -1609,30 +1503,25 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat private class StreamConverter : MessagePackConverter where T : Stream { - private readonly NerdbankMessagePackFormatter formatter; - - public StreamConverter(NerdbankMessagePackFormatter formatter) - { - this.formatter = formatter; - } - public override T? Read(ref MessagePackReader reader, SerializationContext context) { - context.DepthStep(); + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (reader.TryReadNil()) { return null; } - return (T)this.formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); + return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); } public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { - context.DepthStep(); + NerdbankMessagePackFormatter formatter = context.GetFormatter(); - if (this.formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) { writer.Write(token); } @@ -1649,48 +1538,6 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } } - private class RpcMarshalableConverter( - NerdbankMessagePackFormatter formatter, - JsonRpcProxyOptions proxyOptions, - JsonRpcTargetOptions targetOptions, - RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter - where T : class - { - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); - - MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcProfile - .Deserialize( - ref reader, - context.CancellationToken); - - return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - context.DepthStep(); - - if (value is null) - { - writer.WriteNil(); - } - else - { - MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); - formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); - } - } - /// /// Manages serialization of any -derived type that follows standard rules. /// @@ -1700,7 +1547,7 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat /// 2. Be attributed with /// 3. Declare a constructor with a signature of (, ). /// - private class MessagePackExceptionConverterResolver + private static class MessagePackExceptionConverterResolver { /// /// Tracks recursion count while serializing or deserializing an exception. @@ -1711,19 +1558,12 @@ private class MessagePackExceptionConverterResolver /// private static ThreadLocal exceptionRecursionCounter = new(); - private readonly object[] formatterActivationArgs; - - internal MessagePackExceptionConverterResolver(NerdbankMessagePackFormatter formatter) - { - this.formatterActivationArgs = new object[] { formatter }; - } - - public MessagePackConverter GetConverter() + public static MessagePackConverter GetConverter() { MessagePackConverter? formatter = null; if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is object) { - formatter = (MessagePackConverter)Activator.CreateInstance(typeof(ExceptionConverter<>).MakeGenericType(typeof(T)), this.formatterActivationArgs)!; + formatter = (MessagePackConverter)Activator.CreateInstance(typeof(ExceptionConverter<>).MakeGenericType(typeof(T)))!; } // TODO: Improve Exception @@ -1731,12 +1571,13 @@ public MessagePackConverter GetConverter() } #pragma warning disable CA1812 - private partial class ExceptionConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + private partial class ExceptionConverter : MessagePackConverter where T : Exception #pragma warning restore CA1812 { public override T? Read(ref MessagePackReader reader, SerializationContext context) { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); Assumes.NotNull(formatter.JsonRpc); context.DepthStep(); @@ -1786,8 +1627,9 @@ private partial class ExceptionConverter(NerdbankMessagePackFormatter formatt public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) { - context.DepthStep(); + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); if (value is null) { writer.WriteNil(); @@ -1804,7 +1646,7 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat return; } - // TODO: Is this the right context? + // TODO: Is this the right profile? var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); ExceptionSerializationHelpers.Serialize(value, info); writer.WriteMapHeader(info.GetSafeMemberCount()); @@ -1831,6 +1673,51 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } } + private class RpcMarshalableConverter( + JsonRpcProxyOptions proxyOptions, + JsonRpcTargetOptions targetOptions, + RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter + where T : class + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcProfile + .Deserialize( + ref reader, + context.CancellationToken); + + return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); + formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); + } + } + /// /// Enables formatting the default/empty class. /// @@ -1864,6 +1751,50 @@ public override EventArgs Read(ref MessagePackReader reader, SerializationContex } } + private class RequestIdConverter : MessagePackConverter + { + internal static readonly RequestIdConverter Instance = new(); + + private RequestIdConverter() + { + } + + public override RequestId Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.NextMessagePackType == MessagePackType.Integer) + { + return new RequestId(reader.ReadInt64()); + } + else + { + // Do *not* read as an interned string here because this ID should be unique. + return new RequestId(reader.ReadString()); + } + } + + public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) + { + context.DepthStep(); + + if (value.Number.HasValue) + { + writer.Write(value.Number.Value); + } + else + { + writer.Write(value.String); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" + { + "type": ["string", { "type": "integer", "format": "int64" }] + } + """)?.AsObject(); + } + private class TopLevelPropertyBag : TopLevelPropertyBagBase { private readonly FormatterProfile formatterProfile; @@ -1873,12 +1804,12 @@ private class TopLevelPropertyBag : TopLevelPropertyBagBase /// Initializes a new instance of the class /// for an incoming message. /// - /// The serializer options to use for this data. + /// The profile use for this data. /// The map of unrecognized inbound properties. - internal TopLevelPropertyBag(FormatterProfile userDataContext, IReadOnlyDictionary> inboundUnknownProperties) + internal TopLevelPropertyBag(FormatterProfile formatterProfile, IReadOnlyDictionary> inboundUnknownProperties) : base(isOutbound: false) { - this.formatterProfile = userDataContext; + this.formatterProfile = formatterProfile; this.inboundUnknownProperties = inboundUnknownProperties; } @@ -1886,7 +1817,7 @@ internal TopLevelPropertyBag(FormatterProfile userDataContext, IReadOnlyDictiona /// Initializes a new instance of the class /// for an outbound message. /// - /// The serializer options to use for this data. + /// The profile to use for this data. internal TopLevelPropertyBag(FormatterProfile formatterProfile) : base(isOutbound: true) { @@ -1951,7 +1882,7 @@ private class OutboundJsonRpcRequest : JsonRpcRequestBase internal OutboundJsonRpcRequest(NerdbankMessagePackFormatter formatter) { - this.formatter = formatter ?? throw new ArgumentNullException(nameof(formatter)); + this.formatter = formatter; } protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataProfile); @@ -1964,7 +1895,7 @@ private class JsonRpcRequest : JsonRpcRequestBase, IJsonRpcMessagePackRetention internal JsonRpcRequest(NerdbankMessagePackFormatter formatter) { - this.formatter = formatter ?? throw new ArgumentNullException(nameof(formatter)); + this.formatter = formatter; } public override int ArgumentCount => this.MsgPackNamedArguments?.Count ?? this.MsgPackPositionalArguments?.Count ?? base.ArgumentCount; @@ -2127,16 +2058,16 @@ protected override void ReleaseBuffers() [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] private class JsonRpcError : JsonRpcErrorBase, IJsonRpcMessagePackRetention { - private readonly FormatterProfile formatterContext; + private readonly FormatterProfile formatterProfile; - public JsonRpcError(FormatterProfile serializerOptions) + public JsonRpcError(FormatterProfile formatterProfile) { - this.formatterContext = serializerOptions; + this.formatterProfile = formatterProfile; } public ReadOnlySequence OriginalMessagePack { get; internal set; } - protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatterContext); + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatterProfile); protected override void ReleaseBuffers() { @@ -2179,7 +2110,6 @@ internal ErrorDetail(FormatterProfile formatterProfile) try { // return MessagePackSerializer.Deserialize(this.MsgPackData, this.serializerOptions.WithResolver(PrimitiveObjectResolver.Instance)); - // TODO: Which Shape Provider to use? return this.formatterProfile.Deserialize(this.MsgPackData); } catch (MessagePackSerializationException) diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 50ee951a3..1c1134f7d 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -409,10 +409,13 @@ protected override void InitializeFormattersAndHandlers( static NerdbankMessagePackFormatter.FormatterProfile Configure(NerdbankMessagePackFormatter.FormatterProfileBuilder b) { b.RegisterAsyncEnumerableType, UnionBaseClass>(); + b.RegisterAsyncEnumerableType, UnionDerivedClass>(); b.RegisterConverter(new UnserializableTypeConverter()); b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); b.RegisterConverter(new CustomExtensionConverter()); b.RegisterStreamType(); + b.RegisterProgressType, UnionBaseClass>(); + b.RegisterProgressType, UnionDerivedClass>(); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); return b.Build(); diff --git a/test/StreamJsonRpc.Tests/TestBase.cs b/test/StreamJsonRpc.Tests/TestBase.cs index b242a7721..9f8a07a74 100644 --- a/test/StreamJsonRpc.Tests/TestBase.cs +++ b/test/StreamJsonRpc.Tests/TestBase.cs @@ -5,8 +5,9 @@ using System.Reflection; using System.Runtime.Serialization; using Microsoft.VisualStudio.Threading; +using PolyType; -public abstract class TestBase : IDisposable +public abstract partial class TestBase : IDisposable { protected static readonly TimeSpan ExpectedTimeout = TimeSpan.FromMilliseconds(200); @@ -198,6 +199,9 @@ protected virtual async Task CheckGCPressureAsync(Func scenario, int maxBy Assert.True(passingAttemptObserved); } + [GenerateShape.CustomType>] + internal partial class Witness; + #pragma warning disable SYSLIB0050 // Type or member is obsolete private class RoundtripFormatter : IFormatterConverter #pragma warning restore SYSLIB0050 // Type or member is obsolete From bb22ee8102b58f15aaaf560653d151359ac6b07e Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Fri, 27 Dec 2024 20:04:17 +0000 Subject: [PATCH 19/25] Refactor and add comments for Profile and Profile.Builder. --- ...gePackFormatter.FormatterProfileBuilder.cs | 188 ----------------- ...Formatter.MessagePackFormatterConverter.cs | 4 +- ...ankMessagePackFormatter.Profile.Builder.cs | 194 ++++++++++++++++++ ...> NerdbankMessagePackFormatter.Profile.cs} | 10 +- .../NerdbankMessagePackFormatter.cs | 64 +++--- ...nkMessagePackFormatterProfileExtensions.cs | 22 +- ...AsyncEnumerableNerdbankMessagePackTests.cs | 9 +- ...DisposableProxyNerdbankMessagePackTests.cs | 1 - ...xPipeMarshalingNerdbankMessagePackTests.cs | 2 - .../JsonRpcNerdbankMessagePackLengthTests.cs | 3 +- ...arshalableProxyNerdbankMessagePackTests.cs | 1 - .../NerdbankMessagePackFormatterTests.cs | 9 +- 12 files changed, 257 insertions(+), 250 deletions(-) delete mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs rename src/StreamJsonRpc/{NerdbankMessagePackFormatter.FormatterProfile.cs => NerdbankMessagePackFormatter.Profile.cs} (91%) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs deleted file mode 100644 index 03adb3cc0..000000000 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfileBuilder.cs +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System.Collections.Immutable; -using System.IO.Pipelines; -using Nerdbank.MessagePack; -using PolyType; -using StreamJsonRpc.Reflection; - -namespace StreamJsonRpc; - -/// -/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). -/// -public sealed partial class NerdbankMessagePackFormatter -{ - /// - /// Provides methods to build a serialization profile for the . - /// - public class FormatterProfileBuilder - { - private readonly FormatterProfile baseProfile; - - private ImmutableArray.Builder? typeShapeProvidersBuilder = null; - - /// - /// Initializes a new instance of the class. - /// - /// The base profile to build upon. - internal FormatterProfileBuilder(FormatterProfile baseProfile) - { - this.baseProfile = baseProfile; - } - - /// - /// Adds a type shape provider to the context. - /// - /// The type shape provider to add. - public void AddTypeShapeProvider(ITypeShapeProvider provider) - { - this.typeShapeProvidersBuilder ??= ImmutableArray.CreateBuilder(); - this.typeShapeProvidersBuilder.Add(provider); - } - - /// - /// Registers an async enumerable type with the context. - /// - /// The type of the async enumerable. - /// The type of the elements in the async enumerable. - public void RegisterAsyncEnumerableType() - where TEnumerable : IAsyncEnumerable - { - MessagePackConverter converter = AsyncEnumerableConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers a converter with the context. - /// - /// The type the converter handles. - /// The converter to register. - public void RegisterConverter(MessagePackConverter converter) - { - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers known subtypes for a base type with the context. - /// - /// The base type. - /// The mapping of known subtypes. - public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) - { - this.baseProfile.Serializer.RegisterKnownSubTypes(mapping); - } - - /// - /// Registers a progress type with the context. - /// - /// The type of the progress. - /// The type of the report. - public void RegisterProgressType() - where TProgress : IProgress - { - MessagePackConverter converter = ProgressConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers a duplex pipe type with the context. - /// - /// The type of the duplex pipe. - public void RegisterDuplexPipeType() - where TPipe : IDuplexPipe - { - MessagePackConverter converter = PipeConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers a pipe reader type with the context. - /// - /// The type of the pipe reader. - public void RegisterPipeReaderType() - where TReader : PipeReader - { - MessagePackConverter converter = PipeConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers a pipe writer type with the context. - /// - /// The type of the pipe writer. - public void RegisterPipeWriterType() - where TWriter : PipeWriter - { - MessagePackConverter converter = PipeConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers a stream type with the context. - /// - /// The type of the stream. - public void RegisterStreamType() - where TStream : Stream - { - MessagePackConverter converter = PipeConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers an exception type with the context. - /// - /// The type of the exception. - public void RegisterExceptionType() - where TException : Exception - { - MessagePackConverter converter = MessagePackExceptionConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers an RPC marshalable type with the context. - /// - /// The type to register. - public void RegisterRpcMarshalableType() - where T : class - { - if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( - typeof(T), - out JsonRpcProxyOptions? proxyOptions, - out JsonRpcTargetOptions? targetOptions, - out RpcMarshalableAttribute? attribute)) - { - var converter = (RpcMarshalableConverter)Activator.CreateInstance( - typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), - proxyOptions, - targetOptions, - attribute)!; - - this.baseProfile.Serializer.RegisterConverter(converter); - return; - } - - // TODO: Throw? - throw new NotSupportedException(); - } - - /// - /// Builds the formatter profile. - /// - /// The built formatter profile. - public FormatterProfile Build() - { - if (this.typeShapeProvidersBuilder is null || this.typeShapeProvidersBuilder.Count < 1) - { - return this.baseProfile; - } - - return new FormatterProfile( - this.baseProfile.Source, - this.baseProfile.Serializer, - this.typeShapeProvidersBuilder.ToImmutable()); - } - } -} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs index e6a28a3ae..47ab7aa9b 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs @@ -14,9 +14,9 @@ public partial class NerdbankMessagePackFormatter { private class MessagePackFormatterConverter : IFormatterConverter { - private readonly FormatterProfile formatterContext; + private readonly Profile formatterContext; - internal MessagePackFormatterConverter(FormatterProfile formatterContext) + internal MessagePackFormatterConverter(Profile formatterContext) { this.formatterContext = formatterContext; } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs new file mode 100644 index 000000000..53c081398 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs @@ -0,0 +1,194 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Collections.Immutable; +using System.IO.Pipelines; +using Nerdbank.MessagePack; +using PolyType; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// A serialization profile for the . + /// + public partial class Profile + { + /// + /// Provides methods to build a serialization profile for the . + /// + public class Builder + { + private readonly Profile baseProfile; + + private ImmutableArray.Builder? typeShapeProvidersBuilder = null; + + /// + /// Initializes a new instance of the class. + /// + /// The base profile to build upon. + internal Builder(Profile baseProfile) + { + this.baseProfile = baseProfile; + } + + /// + /// Adds a type shape provider to the context. + /// + /// The type shape provider to add. + public void AddTypeShapeProvider(ITypeShapeProvider provider) + { + this.typeShapeProvidersBuilder ??= ImmutableArray.CreateBuilder(); + this.typeShapeProvidersBuilder.Add(provider); + } + + /// + /// Registers an async enumerable type with the context. + /// + /// The type of the async enumerable. + /// The type of the elements in the async enumerable. + public void RegisterAsyncEnumerableType() + where TEnumerable : IAsyncEnumerable + { + MessagePackConverter converter = AsyncEnumerableConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter with the context. + /// + /// The type the converter handles. + /// The converter to register. + public void RegisterConverter(MessagePackConverter converter) + { + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers known subtypes for a base type with the context. + /// + /// The base type. + /// The mapping of known subtypes. + public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) + { + this.baseProfile.Serializer.RegisterKnownSubTypes(mapping); + } + + /// + /// Registers a progress type with the context. + /// + /// The type of the progress. + /// The type of the report. + public void RegisterProgressType() + where TProgress : IProgress + { + MessagePackConverter converter = ProgressConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a duplex pipe type with the context. + /// + /// The type of the duplex pipe. + public void RegisterDuplexPipeType() + where TPipe : IDuplexPipe + { + MessagePackConverter converter = PipeConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a pipe reader type with the context. + /// + /// The type of the pipe reader. + public void RegisterPipeReaderType() + where TReader : PipeReader + { + MessagePackConverter converter = PipeConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a pipe writer type with the context. + /// + /// The type of the pipe writer. + public void RegisterPipeWriterType() + where TWriter : PipeWriter + { + MessagePackConverter converter = PipeConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a stream type with the context. + /// + /// The type of the stream. + public void RegisterStreamType() + where TStream : Stream + { + MessagePackConverter converter = PipeConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers an exception type with the context. + /// + /// The type of the exception. + public void RegisterExceptionType() + where TException : Exception + { + MessagePackConverter converter = MessagePackExceptionConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers an RPC marshalable type with the context. + /// + /// The type to register. + public void RegisterRpcMarshalableType() + where T : class + { + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( + typeof(T), + out JsonRpcProxyOptions? proxyOptions, + out JsonRpcTargetOptions? targetOptions, + out RpcMarshalableAttribute? attribute)) + { + var converter = (RpcMarshalableConverter)Activator.CreateInstance( + typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), + proxyOptions, + targetOptions, + attribute)!; + + this.baseProfile.Serializer.RegisterConverter(converter); + return; + } + + // TODO: Throw? + throw new NotSupportedException(); + } + + /// + /// Builds the formatter profile. + /// + /// The built formatter profile. + public Profile Build() + { + if (this.typeShapeProvidersBuilder is null || this.typeShapeProvidersBuilder.Count < 1) + { + return this.baseProfile; + } + + return new Profile( + this.baseProfile.Source, + this.baseProfile.Serializer, + this.typeShapeProvidersBuilder.ToImmutable()); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.cs similarity index 91% rename from src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs rename to src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.cs index de1499530..e8d1d4579 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterProfile.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.cs @@ -15,20 +15,20 @@ namespace StreamJsonRpc; public sealed partial class NerdbankMessagePackFormatter { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The MessagePack serializer to use. /// The type shape providers to use. [DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] - public class FormatterProfile(MessagePackSerializer serializer, ImmutableArray shapeProviders) + public partial class Profile(MessagePackSerializer serializer, ImmutableArray shapeProviders) { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The source of the profile. /// The MessagePack serializer to use. /// The type shape providers to use. - internal FormatterProfile(ProfileSource source, MessagePackSerializer serializer, ImmutableArray shapeProviders) + internal Profile(ProfileSource source, MessagePackSerializer serializer, ImmutableArray shapeProviders) : this(serializer, shapeProviders) { this.Source = source; @@ -61,7 +61,7 @@ internal enum ProfileSource private int ProvidersCount => shapeProviders.Length; - internal FormatterProfile WithFormatterState(NerdbankMessagePackFormatter formatter) + internal Profile WithFormatterState(NerdbankMessagePackFormatter formatter) { SerializationContext nextContext = serializer.StartingContext; nextContext[SerializationContextExtensions.FormatterKey] = formatter; diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index d8ff95bef..3d5351fc5 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -43,7 +43,7 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa /// /// The serializer context to use for top-level RPC messages. /// - private readonly FormatterProfile rpcProfile; + private readonly Profile rpcProfile; private readonly ToStringHelper serializationToStringHelper = new(); @@ -52,7 +52,7 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa /// /// The serializer to use for user data (e.g. arguments, return values and errors). /// - private FormatterProfile userDataProfile; + private Profile userDataProfile; /// /// Initializes a new instance of the class. @@ -73,8 +73,8 @@ public NerdbankMessagePackFormatter() serializer.RegisterConverter(RequestIdConverter.Instance); serializer.RegisterConverter(new TraceParentConverter()); - this.rpcProfile = new FormatterProfile( - FormatterProfile.ProfileSource.Internal, + this.rpcProfile = new Profile( + Profile.ProfileSource.Internal, serializer, [ShapeProvider_StreamJsonRpc.Default]); @@ -94,15 +94,14 @@ public NerdbankMessagePackFormatter() userSerializer.RegisterConverter(RequestIdConverter.Instance); // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. - userSerializer.RegisterConverter(new TraceParentConverter()); userSerializer.RegisterConverter(EventArgsConverter.Instance); - this.userDataProfile = new FormatterProfile( - FormatterProfile.ProfileSource.External, + this.userDataProfile = new Profile( + Profile.ProfileSource.External, userSerializer, [ReflectionTypeShapeProvider.Default]); - this.ProfileBuilder = new FormatterProfileBuilder(this.userDataProfile); + this.ProfileBuilder = new Profile.Builder(this.userDataProfile); } private interface IJsonRpcMessagePackRetention @@ -119,28 +118,43 @@ private interface IJsonRpcMessagePackRetention /// /// Gets the profile builder for the formatter. /// - public FormatterProfileBuilder ProfileBuilder { get; } + public Profile.Builder ProfileBuilder { get; } /// - /// Sets the formatter profile. + /// Sets the formatter profile for user data. /// + /// + /// + /// For improved startup performance, use + /// to configure a reusable profile and set it here for each instance of this formatter. + /// The profile must be configured before any messages are serialized or deserialized. + /// + /// + /// If not set, a default profile is used which will resolve types using reflection emit. + /// + /// /// The formatter profile to set. - public void SetFormatterProfile(FormatterProfile profile) + public void SetFormatterProfile(Profile profile) { Requires.NotNull(profile, nameof(profile)); this.userDataProfile = profile.WithFormatterState(this); } /// - /// Configures the serialization context for user data with the specified configuration action. + /// Configures the formatter profile for user data with the specified configuration action. /// - /// The action to configure the serialization context. - public void SetFormatterProfile(Func configure) + /// + /// Generally prefer using over this method + /// as it is more efficient to reuse a profile across multiple instances of this formatter. + /// + /// The configuration action. + public void SetFormatterProfile(Action configure) { Requires.NotNull(configure, nameof(configure)); - var builder = new FormatterProfileBuilder(this.userDataProfile); - this.SetFormatterProfile(configure(builder)); + var builder = new Profile.Builder(this.userDataProfile); + configure(builder); + this.SetFormatterProfile(builder.Build()); } /// @@ -1797,7 +1811,7 @@ public override void Write(ref MessagePackWriter writer, in RequestId value, Ser private class TopLevelPropertyBag : TopLevelPropertyBagBase { - private readonly FormatterProfile formatterProfile; + private readonly Profile formatterProfile; private readonly IReadOnlyDictionary>? inboundUnknownProperties; /// @@ -1806,7 +1820,7 @@ private class TopLevelPropertyBag : TopLevelPropertyBagBase /// /// The profile use for this data. /// The map of unrecognized inbound properties. - internal TopLevelPropertyBag(FormatterProfile formatterProfile, IReadOnlyDictionary> inboundUnknownProperties) + internal TopLevelPropertyBag(Profile formatterProfile, IReadOnlyDictionary> inboundUnknownProperties) : base(isOutbound: false) { this.formatterProfile = formatterProfile; @@ -1818,7 +1832,7 @@ internal TopLevelPropertyBag(FormatterProfile formatterProfile, IReadOnlyDiction /// for an outbound message. /// /// The profile to use for this data. - internal TopLevelPropertyBag(FormatterProfile formatterProfile) + internal TopLevelPropertyBag(Profile formatterProfile) : base(isOutbound: true) { this.formatterProfile = formatterProfile; @@ -1998,11 +2012,11 @@ protected override void ReleaseBuffers() private partial class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention { private readonly NerdbankMessagePackFormatter formatter; - private readonly FormatterProfile formatterProfile; + private readonly Profile formatterProfile; private Exception? resultDeserializationException; - internal JsonRpcResult(NerdbankMessagePackFormatter formatter, FormatterProfile formatterProfile) + internal JsonRpcResult(NerdbankMessagePackFormatter formatter, Profile formatterProfile) { this.formatter = formatter; this.formatterProfile = formatterProfile; @@ -2058,9 +2072,9 @@ protected override void ReleaseBuffers() [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] private class JsonRpcError : JsonRpcErrorBase, IJsonRpcMessagePackRetention { - private readonly FormatterProfile formatterProfile; + private readonly Profile formatterProfile; - public JsonRpcError(FormatterProfile formatterProfile) + public JsonRpcError(Profile formatterProfile) { this.formatterProfile = formatterProfile; } @@ -2082,9 +2096,9 @@ protected override void ReleaseBuffers() internal new class ErrorDetail : Protocol.JsonRpcError.ErrorDetail { - private readonly FormatterProfile formatterProfile; + private readonly Profile formatterProfile; - internal ErrorDetail(FormatterProfile formatterProfile) + internal ErrorDetail(Profile formatterProfile) { this.formatterProfile = formatterProfile ?? throw new ArgumentNullException(nameof(formatterProfile)); } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs index 32ee75adc..f3550c49d 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs @@ -8,33 +8,33 @@ namespace StreamJsonRpc; /// -/// Extension methods for that are specific to the . +/// Extension methods for that are specific to the . /// [System.Diagnostics.CodeAnalysis.SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Temporary for development")] public static class NerdbankMessagePackFormatterProfileExtensions { /// - /// Serializes an object using the specified . + /// Serializes an object using the specified . /// /// The formatter profile to use for serialization. /// The writer to which the object will be serialized. /// The object to serialize. /// A token to monitor for cancellation requests. - public static void SerializeObject(this FormatterProfile profile, ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) + public static void SerializeObject(this Profile profile, ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) { Requires.NotNull(profile, nameof(profile)); SerializeObject(profile, ref writer, value, value?.GetType() ?? typeof(object), cancellationToken); } /// - /// Serializes an object using the specified . + /// Serializes an object using the specified . /// /// The formatter profile to use for serialization. /// The writer to which the object will be serialized. /// The object to serialize. /// A token to monitor for cancellation requests. /// The type of the object to serialize. - public static void Serialize(this FormatterProfile profile, ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) + public static void Serialize(this Profile profile, ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) { Requires.NotNull(profile, nameof(profile)); @@ -52,7 +52,7 @@ public static void Serialize(this FormatterProfile profile, ref MessagePackWr } /// - /// Deserializes a sequence of bytes into an object of type using the specified . + /// Deserializes a sequence of bytes into an object of type using the specified . /// /// The type of the object to deserialize. /// The formatter profile to use for deserialization. @@ -60,14 +60,14 @@ public static void Serialize(this FormatterProfile profile, ref MessagePackWr /// A token to monitor for cancellation requests. /// The deserialized object of type . /// Thrown when deserialization fails. - public static T? Deserialize(this FormatterProfile profile, in ReadOnlySequence pack, CancellationToken cancellationToken = default) + public static T? Deserialize(this Profile profile, in ReadOnlySequence pack, CancellationToken cancellationToken = default) { Requires.NotNull(profile, nameof(profile)); MessagePackReader reader = new(pack); return Deserialize(profile, ref reader, cancellationToken); } - internal static T? Deserialize(this FormatterProfile profile, ref MessagePackReader reader, CancellationToken cancellationToken = default) + internal static T? Deserialize(this Profile profile, ref MessagePackReader reader, CancellationToken cancellationToken = default) { return profile.Serializer.Deserialize( ref reader, @@ -75,13 +75,13 @@ public static void Serialize(this FormatterProfile profile, ref MessagePackWr cancellationToken); } - internal static object? DeserializeObject(this FormatterProfile profile, in ReadOnlySequence pack, Type objectType, CancellationToken cancellationToken = default) + internal static object? DeserializeObject(this Profile profile, in ReadOnlySequence pack, Type objectType, CancellationToken cancellationToken = default) { MessagePackReader reader = new(pack); return DeserializeObject(profile, ref reader, objectType, cancellationToken); } - internal static object? DeserializeObject(this FormatterProfile profile, ref MessagePackReader reader, Type objectType, CancellationToken cancellationToken = default) + internal static object? DeserializeObject(this Profile profile, ref MessagePackReader reader, Type objectType, CancellationToken cancellationToken = default) { return profile.Serializer.DeserializeObject( ref reader, @@ -89,7 +89,7 @@ public static void Serialize(this FormatterProfile profile, ref MessagePackWr cancellationToken); } - internal static void SerializeObject(this FormatterProfile profile, ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) + internal static void SerializeObject(this Profile profile, ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) { if (value is null) { diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index 3dce7a740..e1401e0e8 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -11,21 +11,18 @@ public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) protected override void InitializeFormattersAndHandlers() { NerdbankMessagePackFormatter serverFormatter = new(); - NerdbankMessagePackFormatter.FormatterProfile serverProfile = ConfigureContext(serverFormatter.ProfileBuilder); - serverFormatter.SetFormatterProfile(serverProfile); + serverFormatter.SetFormatterProfile(ConfigureContext); NerdbankMessagePackFormatter clientFormatter = new(); - NerdbankMessagePackFormatter.FormatterProfile clientProfile = ConfigureContext(clientFormatter.ProfileBuilder); - clientFormatter.SetFormatterProfile(clientProfile); + clientFormatter.SetFormatterProfile(ConfigureContext); this.serverMessageFormatter = serverFormatter; this.clientMessageFormatter = clientFormatter; - static NerdbankMessagePackFormatter.FormatterProfile ConfigureContext(NerdbankMessagePackFormatter.FormatterProfileBuilder profileBuilder) + static void ConfigureContext(NerdbankMessagePackFormatter.Profile.Builder profileBuilder) { profileBuilder.RegisterAsyncEnumerableType, int>(); profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); - return profileBuilder.Build(); } } } diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs index 0184a756a..052dc089f 100644 --- a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -22,7 +22,6 @@ protected override IJsonRpcMessageFormatter CreateFormatter() b.RegisterDuplexPipeType(); b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); - return b.Build(); }); return formatter; diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs index 9b3c56efb..62da3d17d 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -24,7 +24,6 @@ protected override void InitializeFormattersAndHandlers() b.RegisterPipeWriterType(); b.RegisterDuplexPipeType(); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); - return b.Build(); }); NerdbankMessagePackFormatter clientFormatter = new() @@ -38,7 +37,6 @@ protected override void InitializeFormattersAndHandlers() b.RegisterPipeWriterType(); b.RegisterDuplexPipeType(); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); - return b.Build(); }); this.serverMessageFormatter = serverFormatter; diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 1c1134f7d..ec62ca51f 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -406,7 +406,7 @@ protected override void InitializeFormattersAndHandlers( ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); - static NerdbankMessagePackFormatter.FormatterProfile Configure(NerdbankMessagePackFormatter.FormatterProfileBuilder b) + static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) { b.RegisterAsyncEnumerableType, UnionBaseClass>(); b.RegisterAsyncEnumerableType, UnionDerivedClass>(); @@ -418,7 +418,6 @@ static NerdbankMessagePackFormatter.FormatterProfile Configure(NerdbankMessagePa b.RegisterProgressType, UnionDerivedClass>(); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); - return b.Build(); } } diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs index 443185613..2450ef533 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs @@ -34,7 +34,6 @@ protected override IJsonRpcMessageFormatter CreateFormatter() b.RegisterRpcMarshalableType(); b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); - return b.Build(); }); return formatter; diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs index e9f364601..605049fea 100644 --- a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -152,7 +152,6 @@ public void Resolver_RequestArgInNamedArgs_AnonymousType() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); - return b.Build(); }); var originalArg = new { Prop1 = 3, Prop2 = 5 }; @@ -176,7 +175,6 @@ public void Resolver_RequestArgInNamedArgs_DataContractObject() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); - return b.Build(); }); var originalArg = new DataContractWithSubsetOfMembersIncluded { ExcludedField = "A", ExcludedProperty = "B", IncludedField = "C", IncludedProperty = "D" }; @@ -202,7 +200,6 @@ public void Resolver_RequestArgInNamedArgs_NonDataContractObject() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); - return b.Build(); }); var originalArg = new NonDataContractWithExcludedMembers { ExcludedField = "A", ExcludedProperty = "B", InternalField = "C", InternalProperty = "D", PublicField = "E", PublicProperty = "F" }; @@ -244,7 +241,6 @@ public void Resolver_Result() { b.RegisterConverter(new CustomConverter()); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); - return b.Build(); }); var originalResultValue = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; @@ -366,7 +362,6 @@ protected override NerdbankMessagePackFormatter CreateFormatter() { b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); - return b.Build(); }); return formatter; @@ -375,10 +370,10 @@ protected override NerdbankMessagePackFormatter CreateFormatter() private T Read(object anonymousObject) where T : JsonRpcMessage { - NerdbankMessagePackFormatter.FormatterProfileBuilder profileBuilder = this.Formatter.ProfileBuilder; + NerdbankMessagePackFormatter.Profile.Builder profileBuilder = this.Formatter.ProfileBuilder; profileBuilder.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); profileBuilder.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); - NerdbankMessagePackFormatter.FormatterProfile profile = profileBuilder.Build(); + NerdbankMessagePackFormatter.Profile profile = profileBuilder.Build(); this.Formatter.SetFormatterProfile(profile); From e171cdb9f795b042c1a5a33984710791e18a7359 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 28 Dec 2024 21:22:47 +0000 Subject: [PATCH 20/25] Add additional profile configuration to test suite. --- ...ankMessagePackFormatter.Profile.Builder.cs | 20 +------ .../NerdbankMessagePackFormatter.cs | 57 ++++++++++++++----- src/StreamJsonRpc/Protocol/JsonRpcMessage.cs | 1 + ...sageFormatterRpcMarshaledContextTracker.cs | 4 +- ...AsyncEnumerableNerdbankMessagePackTests.cs | 10 ++++ .../AsyncEnumerableTests.cs | 35 ++++++++---- ...DisposableProxyNerdbankMessagePackTests.cs | 1 + .../DisposableProxyTests.cs | 2 +- ...xPipeMarshalingNerdbankMessagePackTests.cs | 2 + .../JsonRpcNerdbankMessagePackLengthTests.cs | 1 + ...erverMarshalingNerdbankMessagePackTests.cs | 22 ++++++- .../ObserverMarshalingTests.cs | 2 +- 12 files changed, 112 insertions(+), 45 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs index 53c081398..08683991f 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs @@ -153,24 +153,8 @@ public void RegisterExceptionType() public void RegisterRpcMarshalableType() where T : class { - if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( - typeof(T), - out JsonRpcProxyOptions? proxyOptions, - out JsonRpcTargetOptions? targetOptions, - out RpcMarshalableAttribute? attribute)) - { - var converter = (RpcMarshalableConverter)Activator.CreateInstance( - typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), - proxyOptions, - targetOptions, - attribute)!; - - this.baseProfile.Serializer.RegisterConverter(converter); - return; - } - - // TODO: Throw? - throw new NotSupportedException(); + MessagePackConverter converter = GetRpcMarshalableConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); } /// diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 3d5351fc5..1c776b717 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -30,6 +30,8 @@ namespace StreamJsonRpc; /// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. /// [SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Suppressed for Development")] +[GenerateShape] +[GenerateShape] public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory { /// @@ -95,6 +97,7 @@ public NerdbankMessagePackFormatter() // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. userSerializer.RegisterConverter(EventArgsConverter.Instance); + userSerializer.RegisterConverter(new TraceParentConverter()); this.userDataProfile = new Profile( Profile.ProfileSource.External, @@ -234,6 +237,26 @@ void IJsonRpcFormatterTracingCallbacks.OnSerializationComplete(JsonRpcMessage me } } + internal static MessagePackConverter GetRpcMarshalableConverter() + where T : class + { + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( + typeof(T), + out JsonRpcProxyOptions? proxyOptions, + out JsonRpcTargetOptions? targetOptions, + out RpcMarshalableAttribute? attribute)) + { + return (RpcMarshalableConverter)Activator.CreateInstance( + typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), + proxyOptions, + targetOptions, + attribute)!; + } + + // TODO: Improve Exception message. + throw new NotSupportedException($"Type '{typeof(T).FullName}' is not supported for RPC Marshaling."); + } + /// /// Extracts a dictionary of property names and values from the specified params object. /// @@ -391,8 +414,7 @@ private static void ReadUnknownProperty(ref MessagePackReader reader, in Seriali /// /// Converts JSON-RPC messages to and from MessagePack format. /// - [GenerateShape] - internal partial class JsonRpcMessageConverter : MessagePackConverter + internal class JsonRpcMessageConverter : MessagePackConverter { /// /// Reads a JSON-RPC message from the specified MessagePack reader. @@ -480,7 +502,7 @@ public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? valu /// /// Converts a JSON-RPC request message to and from MessagePack format. /// - internal partial class JsonRpcRequestConverter : MessagePackConverter + internal class JsonRpcRequestConverter : MessagePackConverter { /// /// Reads a JSON-RPC request message from the specified MessagePack reader. @@ -678,8 +700,7 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ if (value.TraceParent?.Length > 0) { writer.Write(TraceParentPropertyName); - context.GetConverter(context.TypeShapeProvider) - .Write(ref writer, new TraceParent(value.TraceParent), context); + formatter.rpcProfile.Serialize(ref writer, new TraceParent(value.TraceParent)); if (value.TraceState?.Length > 0) { @@ -770,7 +791,7 @@ private static unsafe string ReadTraceState(ref MessagePackReader reader, Serial /// /// Converts a JSON-RPC result message to and from MessagePack format. /// - internal partial class JsonRpcResultConverter : MessagePackConverter + internal class JsonRpcResultConverter : MessagePackConverter { /// /// Reads a JSON-RPC result message from the specified MessagePack reader. @@ -878,7 +899,7 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResu /// /// Converts a JSON-RPC error message to and from MessagePack format. /// - internal partial class JsonRpcErrorConverter : MessagePackConverter + internal class JsonRpcErrorConverter : MessagePackConverter { /// /// Reads a JSON-RPC error message from the specified MessagePack reader. @@ -971,7 +992,7 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro /// /// Converts a JSON-RPC error detail to and from MessagePack format. /// - internal partial class JsonRpcErrorDetailConverter : MessagePackConverter + internal class JsonRpcErrorDetailConverter : MessagePackConverter { private static readonly MessagePackString CodePropertyName = new("code"); private static readonly MessagePackString MessagePropertyName = new("message"); @@ -1700,10 +1721,16 @@ private class RpcMarshalableConverter( context.DepthStep(); - MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcProfile - .Deserialize( - ref reader, - context.CancellationToken); + // This converter instance is registered with the user data profile, + // however the shape of MarshalToken is defined by the StreamJsonRpc source generator provider. + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = context + .GetConverter(ShapeProvider_StreamJsonRpc.Default) + .Read(ref reader, context); + + ////MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcProfile + //// .Deserialize( + //// ref reader, + //// context.CancellationToken); return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; } @@ -1722,7 +1749,11 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat else { MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); - formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); + ////formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); + // This converter instance is registered with the user data profile, + // however the shape of MarshalToken is defined by the StreamJsonRpc source generator provider. + context.GetConverter(ShapeProvider_StreamJsonRpc.Default) + .Write(ref writer, token, context); } } diff --git a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs index c69e34c1f..be652d36d 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs @@ -26,6 +26,7 @@ namespace StreamJsonRpc.Protocol; [KnownSubType(3)] #endif [MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcMessageConverter))] +[GenerateShape] public abstract partial class JsonRpcMessage { /// diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs index d91b1ea38..d61a0d121 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs @@ -9,6 +9,8 @@ using System.Reflection; using System.Runtime.Serialization; using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; +using PolyType; using static System.FormattableString; using STJ = System.Text.Json.Serialization; @@ -17,7 +19,7 @@ namespace StreamJsonRpc.Reflection; /// /// Tracks objects that get marshaled using the general marshaling protocol. /// -internal class MessageFormatterRpcMarshaledContextTracker +internal partial class MessageFormatterRpcMarshaledContextTracker { private static readonly IReadOnlyCollection<(Type ImplicitlyMarshaledType, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions, RpcMarshalableAttribute Attribute)> ImplicitlyMarshaledTypes = new (Type, JsonRpcProxyOptions, JsonRpcTargetOptions, RpcMarshalableAttribute)[] { diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index e1401e0e8..bb15ead4b 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using PolyType; + public class AsyncEnumerableNerdbankMessagePackTests : AsyncEnumerableTests { public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) @@ -22,7 +24,15 @@ protected override void InitializeFormattersAndHandlers() static void ConfigureContext(NerdbankMessagePackFormatter.Profile.Builder profileBuilder) { profileBuilder.RegisterAsyncEnumerableType, int>(); + profileBuilder.RegisterAsyncEnumerableType, string>(); + profileBuilder.RegisterRpcMarshalableType(); + profileBuilder.AddTypeShapeProvider(AsyncEnumerableWitness.ShapeProvider); profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); } } } + +[GenerateShape] +#pragma warning disable SA1402 // File may only contain a single type +public partial class AsyncEnumerableWitness; +#pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs index 7e503ad57..70245a535 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs @@ -11,6 +11,7 @@ using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; using Newtonsoft.Json; +using NBMP = Nerdbank.MessagePack; public abstract class AsyncEnumerableTests : TestBase, IAsyncLifetime { @@ -618,6 +619,16 @@ private async Task ReturnEnumerable_AutomaticallyReleasedOnErrorF return weakReferenceToSource; } + [DataContract] + protected internal class CompoundEnumerableResult + { + [DataMember] + public string? Message { get; set; } + + [DataMember] + public IAsyncEnumerable? Enumeration { get; set; } + } + protected class Server : IServer { /// @@ -795,18 +806,9 @@ protected class Client : IClient public Task DoSomethingAsync(CancellationToken cancellationToken) => Task.CompletedTask; } - [DataContract] - protected class CompoundEnumerableResult - { - [DataMember] - public string? Message { get; set; } - - [DataMember] - public IAsyncEnumerable? Enumeration { get; set; } - } - [JsonConverter(typeof(ThrowingJsonConverter))] [MessagePackFormatter(typeof(ThrowingMessagePackFormatter))] + [NBMP.MessagePackConverter(typeof(ThrowingMessagePackNerdbankConverter))] protected class UnserializableType { } @@ -836,4 +838,17 @@ public void Serialize(ref MessagePackWriter writer, T value, MessagePackSerializ throw new Exception(); } } + + protected class ThrowingMessagePackNerdbankConverter : NBMP.MessagePackConverter + { + public override T? Read(ref NBMP.MessagePackReader reader, NBMP.SerializationContext context) + { + throw new Exception(); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, NBMP.SerializationContext context) + { + throw new Exception(); + } + } } diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs index 052dc089f..c5affc7ef 100644 --- a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -20,6 +20,7 @@ protected override IJsonRpcMessageFormatter CreateFormatter() { b.RegisterStreamType(); b.RegisterDuplexPipeType(); + b.RegisterRpcMarshalableType(); b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); }); diff --git a/test/StreamJsonRpc.Tests/DisposableProxyTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyTests.cs index b391b2410..57be6d20f 100644 --- a/test/StreamJsonRpc.Tests/DisposableProxyTests.cs +++ b/test/StreamJsonRpc.Tests/DisposableProxyTests.cs @@ -110,7 +110,7 @@ public async Task DisposableWithinArg_IsMarshaledAndLaterCollected() await this.AssertWeakReferenceGetsCollectedAsync(weakRefs.Target); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task DisposableReturnValue_Null() { IDisposable? proxyDisposable = await this.client.GetDisposableAsync(returnNull: true); diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs index 62da3d17d..2b9f0427b 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -23,6 +23,7 @@ protected override void InitializeFormattersAndHandlers() b.RegisterPipeReaderType(); b.RegisterPipeWriterType(); b.RegisterDuplexPipeType(); + b.RegisterDuplexPipeType(); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); }); @@ -36,6 +37,7 @@ protected override void InitializeFormattersAndHandlers() b.RegisterPipeReaderType(); b.RegisterPipeWriterType(); b.RegisterDuplexPipeType(); + b.RegisterDuplexPipeType(); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); }); diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index ec62ca51f..19daa7a43 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -416,6 +416,7 @@ static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) b.RegisterStreamType(); b.RegisterProgressType, UnionBaseClass>(); b.RegisterProgressType, UnionDerivedClass>(); + b.RegisterProgressType, int>(); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); } diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs index 7f46b1da4..e1d82e4ea 100644 --- a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using PolyType; + public class ObserverMarshalingNerdbankMessagePackTests : ObserverMarshalingTests { public ObserverMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) @@ -8,5 +10,23 @@ public ObserverMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) { } - protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); + protected override IJsonRpcMessageFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new(); + formatter.SetFormatterProfile(b => + { + b.RegisterRpcMarshalableType>(); + b.RegisterRpcMarshalableType(); + b.RegisterExceptionType(); + b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + }); + + return formatter; + } } + +[GenerateShape] +#pragma warning disable SA1402 // File may only contain a single type +internal partial class ObserverMarshalingWitness; +#pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs index c17c1ed11..3dbb47897 100644 --- a/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs @@ -84,7 +84,7 @@ public async Task ReturnThenPushSequence() Assert.Equal(Enumerable.Range(1, 3), result); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task FaultImmediately() { var observer = new MockObserver(); From d36d6079adb35795c7899f3e1e87113fee62be6c Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 28 Dec 2024 22:17:46 +0000 Subject: [PATCH 21/25] Update Benchmarks --- test/Benchmarks/InvokeBenchmarks.cs | 3 ++- test/Benchmarks/NotifyBenchmarks.cs | 3 ++- test/Benchmarks/ShortLivedConnectionBenchmarks.cs | 3 ++- test/Benchmarks/run.cmd | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/test/Benchmarks/InvokeBenchmarks.cs b/test/Benchmarks/InvokeBenchmarks.cs index 78fc982cf..7a7a409ac 100644 --- a/test/Benchmarks/InvokeBenchmarks.cs +++ b/test/Benchmarks/InvokeBenchmarks.cs @@ -15,7 +15,7 @@ public class InvokeBenchmarks private JsonRpc clientRpc = null!; private JsonRpc serverRpc = null!; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [GlobalSetup] @@ -35,6 +35,7 @@ IJsonRpcMessageHandler CreateHandler(IDuplexPipe pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, new NerdbankMessagePackFormatter()), _ => throw Assumes.NotReachable(), }; } diff --git a/test/Benchmarks/NotifyBenchmarks.cs b/test/Benchmarks/NotifyBenchmarks.cs index 92fe6ef47..5924e52f4 100644 --- a/test/Benchmarks/NotifyBenchmarks.cs +++ b/test/Benchmarks/NotifyBenchmarks.cs @@ -12,7 +12,7 @@ public class NotifyBenchmarks { private JsonRpc clientRpc = null!; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [GlobalSetup] @@ -26,6 +26,7 @@ IJsonRpcMessageHandler CreateHandler(Stream pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, pipe, new NerdbankMessagePackFormatter()), _ => throw Assumes.NotReachable(), }; } diff --git a/test/Benchmarks/ShortLivedConnectionBenchmarks.cs b/test/Benchmarks/ShortLivedConnectionBenchmarks.cs index fe58eab77..0a59cee53 100644 --- a/test/Benchmarks/ShortLivedConnectionBenchmarks.cs +++ b/test/Benchmarks/ShortLivedConnectionBenchmarks.cs @@ -14,7 +14,7 @@ public class ShortLivedConnectionBenchmarks { private const int Iterations = 1000; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [Benchmark(OperationsPerInvoke = Iterations)] @@ -39,6 +39,7 @@ IJsonRpcMessageHandler CreateHandler(IDuplexPipe pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, new NerdbankMessagePackFormatter()), _ => throw Assumes.NotReachable(), }; } diff --git a/test/Benchmarks/run.cmd b/test/Benchmarks/run.cmd index 81d737eb3..963044933 100644 --- a/test/Benchmarks/run.cmd +++ b/test/Benchmarks/run.cmd @@ -1,3 +1,3 @@ @pushd "%~dp0\" -dotnet run -f net6.0 -c release -- --runtimes net472 net6.0 %* +dotnet run -f net8.0 -c release -- --runtimes net472 net8.0 %* @popd From 949b3abeb1d9feff5fe7bd2c2ac9130eead5ae86 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sun, 29 Dec 2024 23:23:35 +0000 Subject: [PATCH 22/25] Upgrade NB.MB version; Update tests. --- Directory.Packages.props | 2 +- nuget.config | 1 + ...ankMessagePackFormatter.Profile.Builder.cs | 61 +++++++++++++++---- .../NerdbankMessagePackFormatter.cs | 25 ++++---- ...AsyncEnumerableNerdbankMessagePackTests.cs | 6 +- ...DisposableProxyNerdbankMessagePackTests.cs | 13 ++-- ...xPipeMarshalingNerdbankMessagePackTests.cs | 36 +++++------ .../DuplexPipeMarshalingTests.cs | 41 +++++++------ .../JsonRpcNerdbankMessagePackLengthTests.cs | 12 ++-- test/StreamJsonRpc.Tests/JsonRpcTests.cs | 12 ++-- ...erverMarshalingNerdbankMessagePackTests.cs | 3 +- .../ObserverMarshalingTests.cs | 1 + test/StreamJsonRpc.Tests/TestBase.cs | 2 +- 13 files changed, 137 insertions(+), 78 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index 1e2caeb7b..564e04b1e 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -22,7 +22,7 @@ - + diff --git a/nuget.config b/nuget.config index 8ef7b7880..ee3500d0c 100644 --- a/nuget.config +++ b/nuget.config @@ -21,6 +21,7 @@ + diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs index 08683991f..dd1c98559 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs @@ -38,7 +38,7 @@ internal Builder(Profile baseProfile) } /// - /// Adds a type shape provider to the context. + /// Adds a type shape provider to the profile. /// /// The type shape provider to add. public void AddTypeShapeProvider(ITypeShapeProvider provider) @@ -48,7 +48,7 @@ public void AddTypeShapeProvider(ITypeShapeProvider provider) } /// - /// Registers an async enumerable type with the context. + /// Registers an async enumerable type with the profile. /// /// The type of the async enumerable. /// The type of the elements in the async enumerable. @@ -60,7 +60,17 @@ public void RegisterAsyncEnumerableType() } /// - /// Registers a converter with the context. + /// Registers an async enumerable type with the profile. + /// + /// The type of the elements in the async enumerable. + public void RegisterAsyncEnumerableType() + { + MessagePackConverter> converter = AsyncEnumerableConverterResolver.GetConverter>(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter with the profile. /// /// The type the converter handles. /// The converter to register. @@ -70,7 +80,7 @@ public void RegisterConverter(MessagePackConverter converter) } /// - /// Registers known subtypes for a base type with the context. + /// Registers known subtypes for a base type with the profile. /// /// The base type. /// The mapping of known subtypes. @@ -80,7 +90,7 @@ public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) } /// - /// Registers a progress type with the context. + /// Registers a progress type with the profile. /// /// The type of the progress. /// The type of the report. @@ -92,7 +102,17 @@ public void RegisterProgressType() } /// - /// Registers a duplex pipe type with the context. + /// Registers a progress type with the profile. + /// + /// The type of the report. + public void RegisterProgressType() + { + MessagePackConverter> converter = ProgressConverterResolver.GetConverter>(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a duplex pipe type with the profile. /// /// The type of the duplex pipe. public void RegisterDuplexPipeType() @@ -103,7 +123,7 @@ public void RegisterDuplexPipeType() } /// - /// Registers a pipe reader type with the context. + /// Registers a pipe reader type with the profile. /// /// The type of the pipe reader. public void RegisterPipeReaderType() @@ -114,7 +134,7 @@ public void RegisterPipeReaderType() } /// - /// Registers a pipe writer type with the context. + /// Registers a pipe writer type with the profile. /// /// The type of the pipe writer. public void RegisterPipeWriterType() @@ -125,7 +145,7 @@ public void RegisterPipeWriterType() } /// - /// Registers a stream type with the context. + /// Registers a stream type with the profile. /// /// The type of the stream. public void RegisterStreamType() @@ -136,7 +156,7 @@ public void RegisterStreamType() } /// - /// Registers an exception type with the context. + /// Registers an exception type with the profile. /// /// The type of the exception. public void RegisterExceptionType() @@ -147,7 +167,7 @@ public void RegisterExceptionType() } /// - /// Registers an RPC marshalable type with the context. + /// Registers an RPC marshalable type with the profile. /// /// The type to register. public void RegisterRpcMarshalableType() @@ -157,6 +177,25 @@ public void RegisterRpcMarshalableType() this.baseProfile.Serializer.RegisterConverter(converter); } + /// + /// Registers an observer type with the profile. + /// + /// + /// + /// To register with the profile, + /// call this method with the type parameter . + /// + /// + /// RegisterObserver<int>(); + /// + /// + /// The type of the observer. + public void RegisterObserver() + { + MessagePackConverter> converter = GetRpcMarshalableConverter>(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + /// /// Builds the formatter profile. /// diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 1c776b717..f5dad3537 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -73,7 +73,6 @@ public NerdbankMessagePackFormatter() }; serializer.RegisterConverter(RequestIdConverter.Instance); - serializer.RegisterConverter(new TraceParentConverter()); this.rpcProfile = new Profile( Profile.ProfileSource.Internal, @@ -97,7 +96,14 @@ public NerdbankMessagePackFormatter() // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. userSerializer.RegisterConverter(EventArgsConverter.Instance); - userSerializer.RegisterConverter(new TraceParentConverter()); + + // Common exotic types that we want to support. + userSerializer.RegisterConverter(GetRpcMarshalableConverter()); + userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); + userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); + userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); + userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); + userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); this.userDataProfile = new Profile( Profile.ProfileSource.External, @@ -1596,7 +1602,7 @@ private static class MessagePackExceptionConverterResolver public static MessagePackConverter GetConverter() { MessagePackConverter? formatter = null; - if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is object) + if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is not null) { formatter = (MessagePackConverter)Activator.CreateInstance(typeof(ExceptionConverter<>).MakeGenericType(typeof(T)))!; } @@ -1605,10 +1611,8 @@ public static MessagePackConverter GetConverter() return formatter ?? throw new NotSupportedException(); } -#pragma warning disable CA1812 private partial class ExceptionConverter : MessagePackConverter where T : Exception -#pragma warning restore CA1812 { public override T? Read(ref MessagePackReader reader, SerializationContext context) { @@ -1727,12 +1731,9 @@ private class RpcMarshalableConverter( .GetConverter(ShapeProvider_StreamJsonRpc.Default) .Read(ref reader, context); - ////MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcProfile - //// .Deserialize( - //// ref reader, - //// context.CancellationToken); - - return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; + return token.HasValue + ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) + : null; } [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] @@ -1749,7 +1750,7 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat else { MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); - ////formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); + // This converter instance is registered with the user data profile, // however the shape of MarshalToken is defined by the StreamJsonRpc source generator provider. context.GetConverter(ShapeProvider_StreamJsonRpc.Default) diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index bb15ead4b..f6ca692a0 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -23,15 +23,15 @@ protected override void InitializeFormattersAndHandlers() static void ConfigureContext(NerdbankMessagePackFormatter.Profile.Builder profileBuilder) { - profileBuilder.RegisterAsyncEnumerableType, int>(); - profileBuilder.RegisterAsyncEnumerableType, string>(); - profileBuilder.RegisterRpcMarshalableType(); + profileBuilder.RegisterAsyncEnumerableType(); + profileBuilder.RegisterAsyncEnumerableType(); profileBuilder.AddTypeShapeProvider(AsyncEnumerableWitness.ShapeProvider); profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); } } } +[GenerateShape>] [GenerateShape] #pragma warning disable SA1402 // File may only contain a single type public partial class AsyncEnumerableWitness; diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs index c5affc7ef..7ad1ac7b3 100644 --- a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -3,6 +3,7 @@ using System.IO.Pipelines; using Nerdbank.MessagePack; +using PolyType; public class DisposableProxyNerdbankMessagePackTests : DisposableProxyTests { @@ -18,13 +19,17 @@ protected override IJsonRpcMessageFormatter CreateFormatter() NerdbankMessagePackFormatter formatter = new(); formatter.SetFormatterProfile(b => { - b.RegisterStreamType(); - b.RegisterDuplexPipeType(); - b.RegisterRpcMarshalableType(); - b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(DisposableProxyWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); }); return formatter; } } + +[GenerateShape] +[GenerateShape] +[GenerateShape] +#pragma warning disable SA1402 // File may only contain a single type +public partial class DisposableProxyWitness; +#pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs index 2b9f0427b..e7726532a 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -3,6 +3,7 @@ using System.IO.Pipelines; using Nerdbank.Streams; +using PolyType; public class DuplexPipeMarshalingNerdbankMessagePackTests : DuplexPipeMarshalingTests { @@ -18,30 +19,31 @@ protected override void InitializeFormattersAndHandlers() MultiplexingStream = this.serverMx, }; - serverFormatter.SetFormatterProfile(b => - { - b.RegisterPipeReaderType(); - b.RegisterPipeWriterType(); - b.RegisterDuplexPipeType(); - b.RegisterDuplexPipeType(); - b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); - }); - NerdbankMessagePackFormatter clientFormatter = new() { MultiplexingStream = this.clientMx, }; - clientFormatter.SetFormatterProfile(b => - { - b.RegisterPipeReaderType(); - b.RegisterPipeWriterType(); - b.RegisterDuplexPipeType(); - b.RegisterDuplexPipeType(); - b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); - }); + serverFormatter.SetFormatterProfile(Configure); + clientFormatter.SetFormatterProfile(Configure); this.serverMessageFormatter = serverFormatter; this.clientMessageFormatter = clientFormatter; + + static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) + { + b.RegisterDuplexPipeType(); + b.RegisterStreamType(); + b.RegisterStreamType(); + b.RegisterStreamType(); + b.AddTypeShapeProvider(DuplexPipeWitness.ShapeProvider); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + } } } + +[GenerateShape] +[GenerateShape] +#pragma warning disable SA1402 // File may only contain a single type +public partial class DuplexPipeWitness; +#pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs index 7cd6fcf7e..32459382d 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs @@ -8,6 +8,7 @@ using System.Text; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; +using PolyType; using STJ = System.Text.Json.Serialization; public abstract class DuplexPipeMarshalingTests : TestBase, IAsyncLifetime @@ -761,6 +762,21 @@ private async Task ServerStreamIsDisposedWhenClientDisconnects(string rpcMethodN await serverStreamDisposal.WaitAsync(this.TimeoutToken); } + [DataContract] + public class StreamContainingClass + { + [DataMember] + private Stream innerStream; + + public StreamContainingClass(Stream innerStream) + { + this.innerStream = innerStream; + } + + [STJ.JsonPropertyName("innerStream")] + public Stream InnerStream => this.innerStream; + } + #pragma warning disable CA1801 // Review unused parameters protected class ServerWithOverloads { @@ -779,7 +795,9 @@ public async Task OverloadedMethod(bool writeOnOdd, IDuplexPipe pipe, string mes public void OverloadedMethod(bool foo, int value, string[] values) => Assert.NotNull(values); } - protected class Server +#pragma warning disable SA1202 // Elements should be ordered by access + public class Server +#pragma warning restore SA1202 // Elements should be ordered by access { internal Task? ChatLaterTask { get; private set; } @@ -987,7 +1005,9 @@ protected class ServerWithIDuplexPipeReturningMethod public IDuplexPipe? MethodThatReturnsIDuplexPipe() => null; } - protected class OneWayWrapperStream : Stream +#pragma warning disable SA1202 // Elements should be ordered by access + public class OneWayWrapperStream : Stream +#pragma warning restore SA1202 // Elements should be ordered by access { private readonly Stream innerStream; private readonly bool canRead; @@ -1014,8 +1034,10 @@ internal OneWayWrapperStream(Stream innerStream, bool canRead = false, bool canW public override bool CanWrite => this.canWrite && this.innerStream.CanWrite; + [PropertyShape(Ignore = true)] public override long Length => throw new NotSupportedException(); + [PropertyShape(Ignore = true)] public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } public override void Flush() @@ -1092,19 +1114,4 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } } - - [DataContract] - protected class StreamContainingClass - { - [DataMember] - private Stream innerStream; - - public StreamContainingClass(Stream innerStream) - { - this.innerStream = innerStream; - } - - [STJ.JsonPropertyName("innerStream")] - public Stream InnerStream => this.innerStream; - } } diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 19daa7a43..46d8343cc 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -408,14 +408,13 @@ protected override void InitializeFormattersAndHandlers( static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) { - b.RegisterAsyncEnumerableType, UnionBaseClass>(); - b.RegisterAsyncEnumerableType, UnionDerivedClass>(); + b.RegisterAsyncEnumerableType(); + b.RegisterAsyncEnumerableType(); b.RegisterConverter(new UnserializableTypeConverter()); b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); - b.RegisterConverter(new CustomExtensionConverter()); - b.RegisterStreamType(); - b.RegisterProgressType, UnionBaseClass>(); - b.RegisterProgressType, UnionDerivedClass>(); + b.RegisterProgressType(); + b.RegisterProgressType(); + b.RegisterProgressType(); b.RegisterProgressType, int>(); b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); @@ -440,6 +439,7 @@ public partial class UnionDerivedClass : UnionBaseClass } [GenerateShape] + [MessagePackConverter(typeof(CustomExtensionConverter))] internal partial class CustomExtensionType { } diff --git a/test/StreamJsonRpc.Tests/JsonRpcTests.cs b/test/StreamJsonRpc.Tests/JsonRpcTests.cs index f5117dc7d..f751d4171 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcTests.cs @@ -14,7 +14,7 @@ using JsonNET = Newtonsoft.Json; using STJ = System.Text.Json.Serialization; -public abstract class JsonRpcTests : TestBase +public abstract partial class JsonRpcTests : TestBase { #pragma warning disable SA1310 // Field names should not contain underscore protected const int COR_E_UNAUTHORIZEDACCESS = unchecked((int)0x80070005); @@ -3907,10 +3907,12 @@ public ValueTask DisposeAsync() } [DataContract] - public class ParamsObjectWithCustomNames + [GenerateShape] + public partial class ParamsObjectWithCustomNames { [DataMember(Name = "argument")] [STJ.JsonPropertyName("argument")] + [PropertyShape(Name = "argument")] public string? TheArgument { get; set; } } @@ -3967,7 +3969,8 @@ public void IgnoredMethod() } [DataContract] - public class Foo + [GenerateShape] + public partial class Foo { [DataMember(Order = 0, IsRequired = true)] [STJ.JsonRequired, STJ.JsonPropertyOrder(0)] @@ -3980,7 +3983,8 @@ public class Foo public int Bazz { get; set; } } - public class CustomSerializedType + [GenerateShape] + public partial class CustomSerializedType { // Ignore this so default serializers will drop it, proving that custom serializers were used if the value propagates. [JsonNET.JsonIgnore] diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs index e1d82e4ea..ccf419fa2 100644 --- a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs @@ -16,9 +16,8 @@ protected override IJsonRpcMessageFormatter CreateFormatter() formatter.SetFormatterProfile(b => { b.RegisterRpcMarshalableType>(); - b.RegisterRpcMarshalableType(); b.RegisterExceptionType(); - b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(ObserverMarshalingWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); }); diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs index 3dbb47897..7319d595c 100644 --- a/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs @@ -6,6 +6,7 @@ using System.Runtime.CompilerServices; using Microsoft; using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; using Nerdbank.Streams; using StreamJsonRpc; using Xunit; diff --git a/test/StreamJsonRpc.Tests/TestBase.cs b/test/StreamJsonRpc.Tests/TestBase.cs index 9f8a07a74..8053d8bb7 100644 --- a/test/StreamJsonRpc.Tests/TestBase.cs +++ b/test/StreamJsonRpc.Tests/TestBase.cs @@ -200,7 +200,7 @@ protected virtual async Task CheckGCPressureAsync(Func scenario, int maxBy } [GenerateShape.CustomType>] - internal partial class Witness; + internal partial class TestBaseWitness; #pragma warning disable SYSLIB0050 // Type or member is obsolete private class RoundtripFormatter : IFormatterConverter From f94bf1ddcaf0b89677fb4623a259162ca2432d68 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Wed, 1 Jan 2025 01:07:41 +0000 Subject: [PATCH 23/25] Refactor serialization to use ReadOnlySequence Updated serialization and deserialization processes in the NerdbankMessagePackFormatter to utilize ReadOnlySequence instead of RawMessagePack. Introduced new classes for handling enumerator results and enhanced comments for clarity. Updated tests to ensure compatibility with the new serialization methods and added data contract attributes for better interoperability. --- ...Formatter.MessagePackFormatterConverter.cs | 35 ++++++----- ...ankMessagePackFormatter.Profile.Builder.cs | 7 +++ .../NerdbankMessagePackFormatter.cs | 63 ++++++++++++++++++- ...nkMessagePackFormatterProfileExtensions.cs | 12 ++-- .../MessageFormatterEnumerableTracker.cs | 27 ++++---- ...sageFormatterRpcMarshaledContextTracker.cs | 2 +- ...AsyncEnumerableNerdbankMessagePackTests.cs | 5 +- .../AsyncEnumerableTests.cs | 18 ++++-- .../DuplexPipeMarshalingTests.cs | 54 ++++++++++++---- .../JsonRpcNerdbankMessagePackLengthTests.cs | 4 +- ...arshalableProxyNerdbankMessagePackTests.cs | 11 ++-- .../MarshalableProxyTests.cs | 36 ++++++----- 12 files changed, 200 insertions(+), 74 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs index 47ab7aa9b..784455db9 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Buffers; using System.Runtime.Serialization; using Nerdbank.MessagePack; using StreamJsonRpc.Reflection; @@ -25,46 +26,46 @@ internal MessagePackFormatterConverter(Profile formatterContext) public object? Convert(object value, Type type) #pragma warning restore CS8766 { - return this.formatterContext.DeserializeObject((RawMessagePack)value, type); + return this.formatterContext.DeserializeObject((ReadOnlySequence)value, type); } public object Convert(object value, TypeCode typeCode) { return typeCode switch { - TypeCode.Object => this.formatterContext.Deserialize((RawMessagePack)value)!, + TypeCode.Object => this.formatterContext.Deserialize((ReadOnlySequence)value)!, _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), }; } - public bool ToBoolean(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public bool ToBoolean(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public byte ToByte(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public byte ToByte(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public char ToChar(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public char ToChar(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public DateTime ToDateTime(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public DateTime ToDateTime(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public decimal ToDecimal(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public decimal ToDecimal(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public double ToDouble(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public double ToDouble(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public short ToInt16(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public short ToInt16(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public int ToInt32(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public int ToInt32(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public long ToInt64(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public long ToInt64(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public sbyte ToSByte(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public sbyte ToSByte(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public float ToSingle(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public float ToSingle(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public string? ToString(object value) => value is null ? null : this.formatterContext.Deserialize((RawMessagePack)value); + public string? ToString(object value) => value is null ? null : this.formatterContext.Deserialize((ReadOnlySequence)value); - public ushort ToUInt16(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public ushort ToUInt16(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public uint ToUInt32(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public uint ToUInt32(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); - public ulong ToUInt64(object value) => this.formatterContext.Deserialize((RawMessagePack)value); + public ulong ToUInt64(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs index dd1c98559..713987e08 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs @@ -62,11 +62,18 @@ public void RegisterAsyncEnumerableType() /// /// Registers an async enumerable type with the profile. /// + /// + /// To avoid the cost of reflection, ensure is + /// registered with your type shape provider. + /// /// The type of the elements in the async enumerable. public void RegisterAsyncEnumerableType() { MessagePackConverter> converter = AsyncEnumerableConverterResolver.GetConverter>(); this.baseProfile.Serializer.RegisterConverter(converter); + + MessagePackConverter> resultConverter = EnumeratorResultsConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(resultConverter); } /// diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index f5dad3537..c9c15fbfb 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -73,6 +73,7 @@ public NerdbankMessagePackFormatter() }; serializer.RegisterConverter(RequestIdConverter.Instance); + serializer.RegisterConverter(EventArgsConverter.Instance); this.rpcProfile = new Profile( Profile.ProfileSource.Internal, @@ -878,7 +879,9 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResu writer.WriteNil(); } - if (value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void)) + if (value.ResultDeclaredType is not null + && value.ResultDeclaredType != typeof(void) + && value.ResultDeclaredType != typeof(object)) { formatter.userDataProfile.SerializeObject(ref writer, value.Result, value.ResultDeclaredType, context.CancellationToken); } @@ -1712,6 +1715,64 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } } + private static class EnumeratorResultsConverterResolver + { + public static MessagePackConverter> GetConverter() + { + MessagePackConverter>? converter = + (EnumeratorResultsConverter?)Activator + .CreateInstance(typeof(EnumeratorResultsConverter<>) + .MakeGenericType(typeof(T))); + + return converter ?? throw new NotSupportedException($"Could not create {nameof(EnumeratorResultsConverter)}."); + } + + private class EnumeratorResultsConverter : MessagePackConverter> + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to user data context")] + public override MessageFormatterEnumerableTracker.EnumeratorResults? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default; + } + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + Verify.Operation(reader.ReadArrayHeader() == 2, "Expected array of length 2."); + return new MessageFormatterEnumerableTracker.EnumeratorResults() + { + Values = formatter.userDataProfile.Deserialize>(ref reader, context.CancellationToken), + Finished = formatter.userDataProfile.Deserialize(ref reader, context.CancellationToken), + }; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] + public override void Write(ref MessagePackWriter writer, in MessageFormatterEnumerableTracker.EnumeratorResults? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + writer.WriteArrayHeader(2); + formatter.userDataProfile.Serialize(ref writer, value.Values, context.CancellationToken); + formatter.userDataProfile.Serialize(ref writer, value.Finished, context.CancellationToken); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(EnumeratorResultsConverter)); + } + } + } + private class RpcMarshalableConverter( JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions, diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs index f3550c49d..9239f0e0c 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs @@ -44,10 +44,11 @@ public static void Serialize(this Profile profile, ref MessagePackWriter writ return; } + PolyType.Abstractions.ITypeShape shape = profile.ShapeProviderResolver.ResolveShape(); profile.Serializer.Serialize( ref writer, value, - profile.ShapeProviderResolver.ResolveShape(), + shape, cancellationToken); } @@ -69,9 +70,10 @@ public static void Serialize(this Profile profile, ref MessagePackWriter writ internal static T? Deserialize(this Profile profile, ref MessagePackReader reader, CancellationToken cancellationToken = default) { + PolyType.ITypeShapeProvider provider = profile.ShapeProviderResolver.ResolveShapeProvider(); return profile.Serializer.Deserialize( ref reader, - profile.ShapeProviderResolver.ResolveShapeProvider(), + provider, cancellationToken); } @@ -83,9 +85,10 @@ public static void Serialize(this Profile profile, ref MessagePackWriter writ internal static object? DeserializeObject(this Profile profile, ref MessagePackReader reader, Type objectType, CancellationToken cancellationToken = default) { + PolyType.Abstractions.ITypeShape shape = profile.ShapeProviderResolver.ResolveShape(objectType); return profile.Serializer.DeserializeObject( ref reader, - profile.ShapeProviderResolver.ResolveShape(objectType), + shape, cancellationToken); } @@ -97,10 +100,11 @@ internal static void SerializeObject(this Profile profile, ref MessagePackWriter return; } + PolyType.Abstractions.ITypeShape shape = profile.ShapeProviderResolver.ResolveShape(objectType); profile.Serializer.SerializeObject( ref writer, value, - profile.ShapeProviderResolver.ResolveShape(objectType), + shape, cancellationToken); } } diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs index 927fc92cf..c7b69b0de 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks.Dataflow; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; +using PolyType; using StreamJsonRpc.Protocol; using STJ = System.Text.Json.Serialization; @@ -218,6 +219,20 @@ private void CleanUpResources(RequestId outboundRequestId) } } + [DataContract] + internal class EnumeratorResults + { + [DataMember(Name = ValuesPropertyName, Order = 0)] + [STJ.JsonPropertyName(ValuesPropertyName), STJ.JsonPropertyOrder(0)] + [PropertyShape(Name = ValuesPropertyName, Order = 0)] + public IReadOnlyList? Values { get; set; } + + [DataMember(Name = FinishedPropertyName, Order = 1)] + [STJ.JsonPropertyName(FinishedPropertyName), STJ.JsonPropertyOrder(1)] + [PropertyShape(Name = FinishedPropertyName, Order = 1)] + public bool Finished { get; set; } + } + private class GeneratingEnumeratorTracker : IGeneratingEnumeratorTracker { private readonly IAsyncEnumerator enumerator; @@ -525,16 +540,4 @@ private static void Write(IBufferWriter writer, IReadOnlyList values) } } } - - [DataContract] - private class EnumeratorResults - { - [DataMember(Name = ValuesPropertyName, Order = 0)] - [STJ.JsonPropertyName(ValuesPropertyName), STJ.JsonPropertyOrder(0)] - public IReadOnlyList? Values { get; set; } - - [DataMember(Name = FinishedPropertyName, Order = 1)] - [STJ.JsonPropertyName(FinishedPropertyName), STJ.JsonPropertyOrder(1)] - public bool Finished { get; set; } - } } diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs index d61a0d121..65f2020d1 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs @@ -398,7 +398,7 @@ internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions opti private static void ValidateMarshalableInterface(Type type, RpcMarshalableAttribute attribute) { // We only require marshalable interfaces to derive from IDisposable when they are not call-scoped. - if (!attribute.CallScopedLifetime && !typeof(IDisposable).IsAssignableFrom(type)) + if (attribute.CallScopedLifetime && !typeof(IDisposable).IsAssignableFrom(type)) { throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.MarshalableInterfaceNotDisposable, type.FullName)); } diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index f6ca692a0..df4cdeded 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using PolyType; +using static AsyncEnumerableTests; public class AsyncEnumerableNerdbankMessagePackTests : AsyncEnumerableTests { @@ -32,7 +33,9 @@ static void ConfigureContext(NerdbankMessagePackFormatter.Profile.Builder profil } [GenerateShape>] -[GenerateShape] +[GenerateShape>] +[GenerateShape>] +[GenerateShape] #pragma warning disable SA1402 // File may only contain a single type public partial class AsyncEnumerableWitness; #pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs index 70245a535..f993f66b4 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs @@ -277,7 +277,11 @@ async IAsyncEnumerable Generator(CancellationToken cancellationToken) } else { - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.PassInNumbersAsync), new object[] { Generator(this.TimeoutToken) }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.PassInNumbersAsync), + [Generator(this.TimeoutToken)], + [typeof(IAsyncEnumerable)], + this.TimeoutToken); } } @@ -448,8 +452,14 @@ public async Task NotifyAsync_ThrowsIfAsyncEnumerableSent() // But for a notification there's no guarantee the server handles the message and no way to get an error back, // so it simply should not be allowed since the risk of memory leak is too high. var numbers = new int[] { 1, 2, 3 }.AsAsyncEnumerable(); - await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.PassInNumbersAsync), new object?[] { numbers })); - await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.PassInNumbersAsync), new object?[] { new { e = numbers } })); + await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync( + nameof(Server.PassInNumbersAsync), + [numbers], + [typeof(IAsyncEnumerable)])); + await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync( + nameof(Server.PassInNumbersAsync), + new object?[] { new { e = numbers } }, + [typeof(IAsyncEnumerable)])); } [SkippableFact] @@ -544,7 +554,7 @@ public async Task AsyncIteratorThrows(int minBatchSize, int maxReadAhead, int pr Assert.Equal(Server.FailByDesignExceptionMessage, ex.Message); } - [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development + [Fact] public async Task EnumerableIdDisposal() { // This test is specially arranged to create two RPC calls going opposite directions, with the same request ID. diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs index 32459382d..809982869 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs @@ -117,7 +117,8 @@ public async Task ClientCanSendReadOnlyPipeToServer(bool orderedArguments) { bytesReceived = await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptReadablePipe), - new object[] { ExpectedFileName, pipes.Item2 }, + [ExpectedFileName, pipes.Item2], + [typeof(string), typeof(IDuplexPipe)], this.TimeoutToken); } else @@ -144,7 +145,8 @@ public async Task ClientCanSendWriteOnlyPipeToServer(bool orderedArguments) { await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptWritablePipe), - new object[] { pipes.Item2, bytesToReceive }, + [pipes.Item2, bytesToReceive], + [typeof(IDuplexPipe), typeof(int)], this.TimeoutToken); } else @@ -183,7 +185,8 @@ public async Task ClientCanSendPipeReaderToServer() int bytesReceived = await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptPipeReader), - new object[] { ExpectedFileName, pipe.Reader }, + [ExpectedFileName, pipe.Reader], + [typeof(string), typeof(PipeReader)], this.TimeoutToken); Assert.Equal(MemoryBuffer.Length, bytesReceived); @@ -197,7 +200,8 @@ public async Task ClientCanSendPipeWriterToServer() int bytesToReceive = MemoryBuffer.Length - 1; await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptPipeWriter), - new object[] { pipe.Writer, bytesToReceive }, + [pipe.Writer, bytesToReceive], + [typeof(PipeWriter), typeof(int)], this.TimeoutToken); // Read all that the server wanted us to know, and verify it. @@ -454,7 +458,11 @@ public async Task ClientCanSendTwoWayPipeToServer(bool serverUsesStream) { (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); Task twoWayCom = TwoWayTalkAsync(pipePair.Item1, writeOnOdd: true, this.TimeoutToken); - await this.clientRpc.InvokeWithCancellationAsync(serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), new object[] { false, pipePair.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), + [false, pipePair.Item2], + [typeof(bool), typeof(IDuplexPipe)], + this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. // Confirm that we can see the server is no longer writing. @@ -472,7 +480,11 @@ public async Task ClientCanSendTwoWayStreamToServer(bool serverUsesStream) { (Stream, Stream) streamPair = FullDuplexStream.CreatePair(); Task twoWayCom = TwoWayTalkAsync(streamPair.Item1, writeOnOdd: true, this.TimeoutToken); - await this.clientRpc.InvokeWithCancellationAsync(serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), new object[] { false, streamPair.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), + [false, streamPair.Item2], + [typeof(bool), typeof(Stream)], + this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. streamPair.Item1.Dispose(); @@ -482,7 +494,11 @@ public async Task ClientCanSendTwoWayStreamToServer(bool serverUsesStream) public async Task PipeRemainsOpenAfterSuccessfulServerResult() { (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.AcceptPipeAndChatLater), new object[] { false, pipePair.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.AcceptPipeAndChatLater), + [false, pipePair.Item2], + [typeof(bool), typeof(IDuplexPipe)], + this.TimeoutToken); await WhenAllSucceedOrAnyFault(TwoWayTalkAsync(pipePair.Item1, writeOnOdd: true, this.TimeoutToken), this.server.ChatLaterTask!); pipePair.Item1.Output.Complete(); @@ -496,7 +512,11 @@ public async Task PipeRemainsOpenAfterSuccessfulServerResult() public async Task ClientClosesChannelsWhenServerErrorsOut() { (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); - await Assert.ThrowsAsync(() => this.clientRpc.InvokeWithCancellationAsync(nameof(Server.RejectCall), new object[] { pipePair.Item2 }, this.TimeoutToken)); + await Assert.ThrowsAsync(() => this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.RejectCall), + [pipePair.Item2], + [typeof(IDuplexPipe)], + this.TimeoutToken)); // Verify that the pipe is closed. ReadResult readResult = await pipePair.Item1.Input.ReadAsync(this.TimeoutToken); @@ -507,7 +527,11 @@ public async Task ClientClosesChannelsWhenServerErrorsOut() public async Task PipesCloseWhenConnectionCloses() { (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.AcceptPipeAndChatLater), new object[] { false, pipePair.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.AcceptPipeAndChatLater), + [false, pipePair.Item2], + [typeof(bool), typeof(IDuplexPipe)], + this.TimeoutToken); this.clientRpc.Dispose(); @@ -527,7 +551,11 @@ public async Task ClientSendsMultiplePipes() (IDuplexPipe, IDuplexPipe) pipePair1 = FullDuplexStream.CreatePipePair(); (IDuplexPipe, IDuplexPipe) pipePair2 = FullDuplexStream.CreatePipePair(); - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.TwoPipes), new object[] { pipePair1.Item2, pipePair2.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.TwoPipes), + [pipePair1.Item2, pipePair2.Item2], + [typeof(IDuplexPipe), typeof(IDuplexPipe)], + this.TimeoutToken); pipePair1.Item1.Output.Complete(); pipePair2.Item1.Output.Complete(); @@ -604,7 +632,11 @@ public async Task ClientSendsPipeWhereServerHasMultipleOverloads() (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); Task twoWayCom = TwoWayTalkAsync(pipePair.Item1, writeOnOdd: true, this.TimeoutToken); - await this.clientRpc.InvokeWithCancellationAsync(nameof(ServerWithOverloads.OverloadedMethod), new object[] { false, pipePair.Item2, "hi" }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(ServerWithOverloads.OverloadedMethod), + [false, pipePair.Item2, "hi"], + [typeof(bool), typeof(IDuplexPipe), typeof(string)], + this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. pipePair.Item1.Output.Complete(); diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 46d8343cc..3a3f71776 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -408,10 +408,10 @@ protected override void InitializeFormattersAndHandlers( static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) { - b.RegisterAsyncEnumerableType(); - b.RegisterAsyncEnumerableType(); b.RegisterConverter(new UnserializableTypeConverter()); b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); + b.RegisterAsyncEnumerableType(); + b.RegisterAsyncEnumerableType(); b.RegisterProgressType(); b.RegisterProgressType(); b.RegisterProgressType(); diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs index 2450ef533..aa3f0a40b 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs @@ -2,8 +2,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using Nerdbank.MessagePack; +using PolyType; -public class MarshalableProxyNerdbankMessagePackTests : MarshalableProxyTests +public partial class MarshalableProxyNerdbankMessagePackTests : MarshalableProxyTests { public MarshalableProxyNerdbankMessagePackTests(ITestOutputHelper logger) : base(logger) @@ -21,8 +22,6 @@ protected override IJsonRpcMessageFormatter CreateFormatter() b.RegisterRpcMarshalableType(); b.RegisterRpcMarshalableType(); b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); b.RegisterRpcMarshalableType(); b.RegisterRpcMarshalableType(); b.RegisterRpcMarshalableType(); @@ -32,10 +31,14 @@ protected override IJsonRpcMessageFormatter CreateFormatter() b.RegisterRpcMarshalableType(); b.RegisterRpcMarshalableType(); b.RegisterRpcMarshalableType(); - b.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + b.RegisterRpcMarshalableType>(); + b.AddTypeShapeProvider(MarshalableProxyWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); }); return formatter; } + + [GenerateShape] + public partial class MarshalableProxyWitness; } diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs index e762d7cdc..e0ab4ec7a 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs @@ -573,7 +573,7 @@ public async Task GenericMarshalableWithinArg_CanCallMethods() Assert.Equal(99, await ((IGenericMarshalable)this.server.ReceivedProxy!).DoSomethingWithParameterAsync(99)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task MarshalableReturnValue_Null() { IMarshalable? proxyMarshalable = await this.client.GetMarshalableAsync(returnNull: true); @@ -689,7 +689,7 @@ public async Task DisposeOnDisconnect() await disposed.WaitAsync(this.TimeoutToken); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableWithOptionalInterfaces(); @@ -707,7 +707,7 @@ public async Task RpcMarshalableOptionalInterface() AssertIsNot(proxy1, typeof(IMarshalableSubType2Extended)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_JsonRpcMethodAttribute() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableWithOptionalInterfaces(); @@ -725,7 +725,7 @@ public async Task RpcMarshalableOptionalInterface_JsonRpcMethodAttribute() Assert.Equal("foo", await this.clientRpc.InvokeAsync("$/invokeProxy/1/1.RemamedAsync", "foo")); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MethodNameTransform_Prefix() { var server = new Server(); @@ -753,7 +753,7 @@ public async Task RpcMarshalableOptionalInterface_MethodNameTransform_Prefix() Assert.Equal(1, await localRpc.InvokeAsync("$/invokeProxy/0/1.GetAsync", 1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MethodNameTransform_CamelCase() { var server = new Server(); @@ -781,7 +781,7 @@ public async Task RpcMarshalableOptionalInterface_MethodNameTransform_CamelCase( Assert.Equal(1, await localRpc.InvokeAsync("$/invokeProxy/0/1.GetAsync", 1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_Null() { this.server.ReturnedMarshalableWithOptionalInterfaces = null; @@ -789,7 +789,7 @@ public async Task RpcMarshalableOptionalInterface_Null() Assert.Null(proxy); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_IndirectInterfaceImplementation() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubType1Indirect(); @@ -800,7 +800,7 @@ public async Task RpcMarshalableOptionalInterface_IndirectInterfaceImplementatio AssertIsNot(proxy, typeof(IMarshalableSubType2Extended)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_WithExplicitImplementation() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubType2(); @@ -811,7 +811,7 @@ public async Task RpcMarshalableOptionalInterface_WithExplicitImplementation() AssertIsNot(proxy, typeof(IMarshalableSubType2Extended)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_UnknownSubType() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableUnknownSubType(); @@ -822,7 +822,7 @@ public async Task RpcMarshalableOptionalInterface_UnknownSubType() AssertIsNot(proxy, typeof(IMarshalableSubType2Extended)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_OnlyAttibutesOnDeclaredTypeAreHonored() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubType2Extended(); @@ -837,7 +837,7 @@ public async Task RpcMarshalableOptionalInterface_OnlyAttibutesOnDeclaredTypeAre Assert.Equal(4, await ((IMarshalableSubType2Extended)proxy1).GetPlusThreeAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_OptionalInterfaceNotExtendingBase() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableNonExtendingBase(); @@ -847,7 +847,7 @@ public async Task RpcMarshalableOptionalInterface_OptionalInterfaceNotExtendingB Assert.Equal(5, await ((IMarshalableNonExtendingBase)proxy).GetPlusFourAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_IntermediateNonMarshalableInterface() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubTypeWithIntermediateInterface(); @@ -868,7 +868,7 @@ public async Task RpcMarshalableOptionalInterface_IntermediateNonMarshalableInte Assert.Equal(4, await ((IMarshalableSubTypeWithIntermediateInterface)proxy).GetPlusThreeAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MultipleIntermediateInterfaces() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubTypeWithIntermediateInterface1And2(); @@ -895,7 +895,7 @@ public async Task RpcMarshalableOptionalInterface_MultipleIntermediateInterfaces Assert.Equal(-3, await ((IMarshalableSubTypeIntermediateInterface)proxy2).GetPlusTwoAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MultipleImplementations() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubTypeMultipleImplementations(); @@ -920,7 +920,7 @@ public async Task RpcMarshalableOptionalInterface_MultipleImplementations() Assert.Equal(-1, await ((IMarshalableSubType2)proxy).GetMinusTwoAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MultipleImplementationsCombined() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubTypesCombined(); @@ -962,7 +962,7 @@ public async Task RpcMarshalable_CallScopedLifetime() Assert.False(marshaled.IsDisposed); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableReturned() { MarshalableAndSerializable marshaled = new(); @@ -978,7 +978,7 @@ public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableReturned() await Assert.ThrowsAsync(() => this.server.ContinuationResult).WithCancellation(this.TimeoutToken); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableThrown() { this.clientRpc.AllowModificationWhileListening = true; @@ -1266,8 +1266,10 @@ public Data(Action? disposeAction) [DataMember] public int Value { get; set; } + [PropertyShape(Ignore = true)] public bool IsDisposed { get; private set; } + [PropertyShape(Ignore = true)] public bool DoSomethingCalled { get; private set; } public void Dispose() From eecd92c7f942d8b663b314b10be07efc72ebc332 Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Wed, 1 Jan 2025 16:38:41 +0000 Subject: [PATCH 24/25] Enhance NerdbankMessagePackFormatter with new types Added new type shape providers and registered exception types in the NerdbankMessagePackFormatter class. Updated AddTypeShapeProvider to initialize with a capacity of 3. Removed methods for duplex pipes and streams, reflecting a shift in supported types. Registered new converters for handling exceptions like RemoteInvocationException. Updated CommonErrorData with shape generation attributes and modified test classes to align with these changes. --- ...ankMessagePackFormatter.Profile.Builder.cs | 53 ++--------- .../NerdbankMessagePackFormatter.cs | 94 ++++++++++++++++++- src/StreamJsonRpc/Protocol/CommonErrorData.cs | 10 +- ...xPipeMarshalingNerdbankMessagePackTests.cs | 5 +- .../DuplexPipeMarshalingTests.cs | 11 --- .../JsonRpcNerdbankMessagePackLengthTests.cs | 7 +- 6 files changed, 112 insertions(+), 68 deletions(-) diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs index 713987e08..c6dbe1ae1 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Collections.Immutable; -using System.IO.Pipelines; using Nerdbank.MessagePack; using PolyType; using StreamJsonRpc.Reflection; @@ -43,7 +42,7 @@ internal Builder(Profile baseProfile) /// The type shape provider to add. public void AddTypeShapeProvider(ITypeShapeProvider provider) { - this.typeShapeProvidersBuilder ??= ImmutableArray.CreateBuilder(); + this.typeShapeProvidersBuilder ??= ImmutableArray.CreateBuilder(initialCapacity: 3); this.typeShapeProvidersBuilder.Add(provider); } @@ -57,6 +56,9 @@ public void RegisterAsyncEnumerableType() { MessagePackConverter converter = AsyncEnumerableConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); + + MessagePackConverter> resultConverter = EnumeratorResultsConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(resultConverter); } /// @@ -118,50 +120,6 @@ public void RegisterProgressType() this.baseProfile.Serializer.RegisterConverter(converter); } - /// - /// Registers a duplex pipe type with the profile. - /// - /// The type of the duplex pipe. - public void RegisterDuplexPipeType() - where TPipe : IDuplexPipe - { - MessagePackConverter converter = PipeConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers a pipe reader type with the profile. - /// - /// The type of the pipe reader. - public void RegisterPipeReaderType() - where TReader : PipeReader - { - MessagePackConverter converter = PipeConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers a pipe writer type with the profile. - /// - /// The type of the pipe writer. - public void RegisterPipeWriterType() - where TWriter : PipeWriter - { - MessagePackConverter converter = PipeConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers a stream type with the profile. - /// - /// The type of the stream. - public void RegisterStreamType() - where TStream : Stream - { - MessagePackConverter converter = PipeConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); - } - /// /// Registers an exception type with the profile. /// @@ -214,10 +172,11 @@ public Profile Build() return this.baseProfile; } + // ExoticTypeShapeProvider is always first and cannot be overridden. return new Profile( this.baseProfile.Source, this.baseProfile.Serializer, - this.typeShapeProvidersBuilder.ToImmutable()); + [ExoticTypeShapeProvider.Instance, .. this.typeShapeProvidersBuilder]); } } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index c9c15fbfb..9625ec4d6 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Buffers; +using System.Collections; using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -18,6 +19,7 @@ using PolyType.Abstractions; using PolyType.ReflectionProvider; using PolyType.SourceGenerator; +using PolyType.SourceGenModel; using StreamJsonRpc.Protocol; using StreamJsonRpc.Reflection; @@ -32,6 +34,12 @@ namespace StreamJsonRpc; [SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Suppressed for Development")] [GenerateShape] [GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory { /// @@ -61,6 +69,12 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa /// public NerdbankMessagePackFormatter() { + KnownSubTypeMapping exceptionSubtypeMap = new(); + exceptionSubtypeMap.Add(alias: 1, ShapeProvider); + exceptionSubtypeMap.Add(alias: 2, ShapeProvider); + exceptionSubtypeMap.Add(alias: 3, ShapeProvider); + exceptionSubtypeMap.Add(alias: 4, ShapeProvider); + // Set up initial options for our own message types. MessagePackSerializer serializer = new() { @@ -72,13 +86,22 @@ public NerdbankMessagePackFormatter() }, }; + serializer.RegisterKnownSubTypes(exceptionSubtypeMap); serializer.RegisterConverter(RequestIdConverter.Instance); serializer.RegisterConverter(EventArgsConverter.Instance); + serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); this.rpcProfile = new Profile( Profile.ProfileSource.Internal, serializer, - [ShapeProvider_StreamJsonRpc.Default]); + [ + ExoticTypeShapeProvider.Instance, + ShapeProvider_StreamJsonRpc.Default + ]); // Create a serializer for user data. MessagePackSerializer userSerializer = new() @@ -91,6 +114,8 @@ public NerdbankMessagePackFormatter() }, }; + userSerializer.RegisterKnownSubTypes(exceptionSubtypeMap); + // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. // We preset this one in user data because $/cancellation methods can carry RequestId values as arguments. userSerializer.RegisterConverter(RequestIdConverter.Instance); @@ -105,11 +130,18 @@ public NerdbankMessagePackFormatter() userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); this.userDataProfile = new Profile( Profile.ProfileSource.External, userSerializer, - [ReflectionTypeShapeProvider.Default]); + [ + ExoticTypeShapeProvider.Instance, + ReflectionTypeShapeProvider.Default + ]); this.ProfileBuilder = new Profile.Builder(this.userDataProfile); } @@ -1643,7 +1675,7 @@ private partial class ExceptionConverter : MessagePackConverter } // TODO: Is this the right context? - var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.userDataProfile)); int memberCount = reader.ReadMapHeader(); for (int i = 0; i < memberCount; i++) { @@ -1689,7 +1721,7 @@ public override void Write(ref MessagePackWriter writer, in T? value, Serializat } // TODO: Is this the right profile? - var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcProfile)); + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.userDataProfile)); ExceptionSerializationHelpers.Serialize(value, info); writer.WriteMapHeader(info.GetSafeMemberCount()); foreach (SerializationEntry element in info.GetSafeMembers()) @@ -2238,4 +2270,58 @@ protected internal override void SetExpectedDataType(Type dataType) } } } + + /// + /// Ensures certain exotic types are matched to the correct MessagePackConverter. + /// We rely on the caching in NerbdBank.MessagePackSerializer to ensure we don't create multiple instances of these shapes. + /// + private class ExoticTypeShapeProvider : ITypeShapeProvider + { + internal static readonly ExoticTypeShapeProvider Instance = new(); + + public ITypeShape? GetShape(Type type) + { + if (typeof(PipeReader).IsAssignableFrom(type)) + { + return new SourceGenObjectTypeShape() + { + IsRecordType = false, + IsTupleType = false, + Provider = this, + }; + } + + if (typeof(PipeWriter).IsAssignableFrom(type)) + { + return new SourceGenObjectTypeShape() + { + IsRecordType = false, + IsTupleType = false, + Provider = this, + }; + } + + if (typeof(Stream).IsAssignableFrom(type)) + { + return new SourceGenObjectTypeShape() + { + IsRecordType = false, + IsTupleType = false, + Provider = this, + }; + } + + if (typeof(IDuplexPipe).IsAssignableFrom(type)) + { + return new SourceGenObjectTypeShape() + { + IsRecordType = false, + IsTupleType = false, + Provider = this, + }; + } + + return null; + } + } } diff --git a/src/StreamJsonRpc/Protocol/CommonErrorData.cs b/src/StreamJsonRpc/Protocol/CommonErrorData.cs index 823f4d5a5..f9e792dac 100644 --- a/src/StreamJsonRpc/Protocol/CommonErrorData.cs +++ b/src/StreamJsonRpc/Protocol/CommonErrorData.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Runtime.Serialization; +using PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -10,11 +11,13 @@ namespace StreamJsonRpc.Protocol; /// A class that describes useful data that may be found in the JSON-RPC error message's error.data property. /// [DataContract] -public class CommonErrorData +[GenerateShape] +public partial class CommonErrorData { /// /// Initializes a new instance of the class. /// + [ConstructorShape] public CommonErrorData() { } @@ -39,6 +42,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 0, Name = "type")] [STJ.JsonPropertyName("type"), STJ.JsonPropertyOrder(0)] + [PropertyShape(Name = "type", Order = 0)] public string? TypeName { get; set; } /// @@ -46,6 +50,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 1, Name = "message")] [STJ.JsonPropertyName("message"), STJ.JsonPropertyOrder(1)] + [PropertyShape(Name = "message", Order = 1)] public string? Message { get; set; } /// @@ -53,6 +58,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 2, Name = "stack")] [STJ.JsonPropertyName("stack"), STJ.JsonPropertyOrder(2)] + [PropertyShape(Name = "stack", Order = 2)] public string? StackTrace { get; set; } /// @@ -60,6 +66,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 3, Name = "code")] [STJ.JsonPropertyName("code"), STJ.JsonPropertyOrder(3)] + [PropertyShape(Name = "code", Order = 3)] public int HResult { get; set; } /// @@ -67,5 +74,6 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 4, Name = "inner")] [STJ.JsonPropertyName("inner"), STJ.JsonPropertyOrder(4)] + [PropertyShape(Name = "inner", Order = 4)] public CommonErrorData? Inner { get; set; } } diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs index e7726532a..2dd30365b 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.IO.Pipelines; +using Nerdbank.MessagePack; using Nerdbank.Streams; using PolyType; @@ -32,10 +33,6 @@ protected override void InitializeFormattersAndHandlers() static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) { - b.RegisterDuplexPipeType(); - b.RegisterStreamType(); - b.RegisterStreamType(); - b.RegisterStreamType(); b.AddTypeShapeProvider(DuplexPipeWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); } diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs index 809982869..b3a9587ff 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs @@ -118,7 +118,6 @@ public async Task ClientCanSendReadOnlyPipeToServer(bool orderedArguments) bytesReceived = await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptReadablePipe), [ExpectedFileName, pipes.Item2], - [typeof(string), typeof(IDuplexPipe)], this.TimeoutToken); } else @@ -146,7 +145,6 @@ public async Task ClientCanSendWriteOnlyPipeToServer(bool orderedArguments) await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptWritablePipe), [pipes.Item2, bytesToReceive], - [typeof(IDuplexPipe), typeof(int)], this.TimeoutToken); } else @@ -186,7 +184,6 @@ public async Task ClientCanSendPipeReaderToServer() int bytesReceived = await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptPipeReader), [ExpectedFileName, pipe.Reader], - [typeof(string), typeof(PipeReader)], this.TimeoutToken); Assert.Equal(MemoryBuffer.Length, bytesReceived); @@ -201,7 +198,6 @@ public async Task ClientCanSendPipeWriterToServer() await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptPipeWriter), [pipe.Writer, bytesToReceive], - [typeof(PipeWriter), typeof(int)], this.TimeoutToken); // Read all that the server wanted us to know, and verify it. @@ -461,7 +457,6 @@ public async Task ClientCanSendTwoWayPipeToServer(bool serverUsesStream) await this.clientRpc.InvokeWithCancellationAsync( serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), [false, pipePair.Item2], - [typeof(bool), typeof(IDuplexPipe)], this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. @@ -483,7 +478,6 @@ public async Task ClientCanSendTwoWayStreamToServer(bool serverUsesStream) await this.clientRpc.InvokeWithCancellationAsync( serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), [false, streamPair.Item2], - [typeof(bool), typeof(Stream)], this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. @@ -497,7 +491,6 @@ public async Task PipeRemainsOpenAfterSuccessfulServerResult() await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptPipeAndChatLater), [false, pipePair.Item2], - [typeof(bool), typeof(IDuplexPipe)], this.TimeoutToken); await WhenAllSucceedOrAnyFault(TwoWayTalkAsync(pipePair.Item1, writeOnOdd: true, this.TimeoutToken), this.server.ChatLaterTask!); @@ -515,7 +508,6 @@ public async Task ClientClosesChannelsWhenServerErrorsOut() await Assert.ThrowsAsync(() => this.clientRpc.InvokeWithCancellationAsync( nameof(Server.RejectCall), [pipePair.Item2], - [typeof(IDuplexPipe)], this.TimeoutToken)); // Verify that the pipe is closed. @@ -530,7 +522,6 @@ public async Task PipesCloseWhenConnectionCloses() await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptPipeAndChatLater), [false, pipePair.Item2], - [typeof(bool), typeof(IDuplexPipe)], this.TimeoutToken); this.clientRpc.Dispose(); @@ -554,7 +545,6 @@ public async Task ClientSendsMultiplePipes() await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.TwoPipes), [pipePair1.Item2, pipePair2.Item2], - [typeof(IDuplexPipe), typeof(IDuplexPipe)], this.TimeoutToken); pipePair1.Item1.Output.Complete(); pipePair2.Item1.Output.Complete(); @@ -635,7 +625,6 @@ public async Task ClientSendsPipeWhereServerHasMultipleOverloads() await this.clientRpc.InvokeWithCancellationAsync( nameof(ServerWithOverloads.OverloadedMethod), [false, pipePair.Item2, "hi"], - [typeof(bool), typeof(IDuplexPipe), typeof(string)], this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 3a3f71776..9cb94f205 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -416,13 +416,18 @@ static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) b.RegisterProgressType(); b.RegisterProgressType(); b.RegisterProgressType, int>(); - b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + b.RegisterProgressType, CustomSerializedType>(); + b.AddTypeShapeProvider(JsonRpcWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); } } protected override object[] CreateFormatterIntrinsicParamsObject(string arg) => []; + [GenerateShape] + [GenerateShape>] + public partial class JsonRpcWitness; + [GenerateShape] #if NET [KnownSubType] From edad2f825943831d1b4030786e51ecc70890de2c Mon Sep 17 00:00:00 2001 From: Charles Willis Date: Sat, 11 Jan 2025 06:03:52 +0000 Subject: [PATCH 25/25] Update Nerdbank.MessagePack and refactor converters - Bump `Nerdbank.MessagePack` package version to `0.3.98-beta`. - Refactor `Profile.Builder` to register async enumerable converters and known subtypes. - Update `NerdbankMessagePackFormatter` to include new converter registration methods. - Update multiple test files to reflect new converter registrations and functionality. --- Directory.Packages.props | 2 +- ...PackFormatter.AsyncEnumerableConverters.cs | 154 +++ .../NerdbankMessagePackFormatter.Constants.cs | 2 - ...ackFormatter.EnumeratorResultsConverter.cs | 61 ++ ...MessagePackFormatter.EventArgsConverter.cs | 47 + ...MessagePackFormatter.ExceptionConverter.cs | 127 +++ ...kFormatter.IJsonRpcMessagePackRetention.cs | 23 + ...bankMessagePackFormatter.PipeConverters.cs | 177 ++++ ...ankMessagePackFormatter.Profile.Builder.cs | 134 ++- ...PackFormatter.ProgressConverterResolver.cs | 116 +++ ...MessagePackFormatter.RequestIdConverter.cs | 58 ++ ...gePackFormatter.RpcMarshalableConverter.cs | 61 ++ ...ssagePackFormatter.TraceParentConverter.cs | 84 ++ .../NerdbankMessagePackFormatter.cs | 892 +----------------- ...sageFormatterRpcMarshaledContextTracker.cs | 12 +- ...AsyncEnumerableNerdbankMessagePackTests.cs | 4 +- ...DisposableProxyNerdbankMessagePackTests.cs | 6 + ...xPipeMarshalingNerdbankMessagePackTests.cs | 7 - .../DuplexPipeMarshalingTests.cs | 35 +- .../JsonRpcNerdbankMessagePackLengthTests.cs | 14 +- ...arshalableProxyNerdbankMessagePackTests.cs | 30 +- .../NerdbankMessagePackFormatterTests.cs | 7 +- ...erverMarshalingNerdbankMessagePackTests.cs | 2 +- ...getObjectEventsNerdbankMessagePackTests.cs | 9 + .../TargetObjectEventsTests.cs | 13 +- 25 files changed, 1118 insertions(+), 959 deletions(-) create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.EnumeratorResultsConverter.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.EventArgsConverter.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverterResolver.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs create mode 100644 src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 564e04b1e..b16855566 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -22,7 +22,7 @@ - + diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs new file mode 100644 index 000000000..1938eb38f --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private static class AsyncEnumerableConverters + { + /// + /// Converts an enumeration token to an + /// or an into an enumeration token. + /// +#pragma warning disable CA1812 + internal class PreciseTypeConverter : MessagePackConverter> +#pragma warning restore CA1812 + { + /// + /// The constant "token", in its various forms. + /// + private static readonly MessagePackString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); + + /// + /// The constant "values", in its various forms. + /// + private static readonly MessagePackString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); + + public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default; + } + + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + + context.DepthStep(); + + RawMessagePack? token = default; + IReadOnlyList? initialElements = null; + int propertyCount = reader.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) + { + if (TokenPropertyName.TryRead(ref reader)) + { + // The value needs to outlive the reader, so we clone it. + token = new RawMessagePack(reader.ReadRaw(context)).ToOwned(); + } + else if (ValuesPropertyName.TryRead(ref reader)) + { + initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + reader.Skip(context); + } + } + + return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? token.Value : null, initialElements); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] + public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) + { + context.DepthStep(); + + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + Serialize_Shared(mainFormatter, ref writer, value, context); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); + } + + internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref MessagePackWriter writer, IAsyncEnumerable? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + (IReadOnlyList elements, bool finished) = value.TearOffPrefetchedElements(); + long token = mainFormatter.EnumerableTracker.GetToken(value); + + int propertyCount = 0; + if (elements.Count > 0) + { + propertyCount++; + } + + if (!finished) + { + propertyCount++; + } + + writer.WriteMapHeader(propertyCount); + + if (!finished) + { + writer.Write(TokenPropertyName); + writer.Write(token); + } + + if (elements.Count > 0) + { + writer.Write(ValuesPropertyName); + context.GetConverter>(context.TypeShapeProvider).Write(ref writer, elements, context); + } + } + } + } + + /// + /// Converts an instance of to an enumeration token. + /// +#pragma warning disable CA1812 + internal class GeneratorConverter : MessagePackConverter + where TClass : IAsyncEnumerable +#pragma warning restore CA1812 + { + public override TClass Read(ref MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException(); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + + context.DepthStep(); + PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(GeneratorConverter)); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs index 52699eff1..8990d0a14 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs @@ -1,10 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System.Diagnostics; using Nerdbank.MessagePack; using StreamJsonRpc.Protocol; -using NBMP = Nerdbank.MessagePack; namespace StreamJsonRpc; diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.EnumeratorResultsConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.EnumeratorResultsConverter.cs new file mode 100644 index 000000000..379623815 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.EnumeratorResultsConverter.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class EnumeratorResultsConverter : MessagePackConverter> + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to user data context")] + public override MessageFormatterEnumerableTracker.EnumeratorResults? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default; + } + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + Verify.Operation(reader.ReadArrayHeader() == 2, "Expected array of length 2."); + return new MessageFormatterEnumerableTracker.EnumeratorResults() + { + Values = formatter.userDataProfile.Deserialize>(ref reader, context.CancellationToken), + Finished = formatter.userDataProfile.Deserialize(ref reader, context.CancellationToken), + }; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] + public override void Write(ref MessagePackWriter writer, in MessageFormatterEnumerableTracker.EnumeratorResults? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + writer.WriteArrayHeader(2); + formatter.userDataProfile.Serialize(ref writer, value.Values, context.CancellationToken); + formatter.userDataProfile.Serialize(ref writer, value.Finished, context.CancellationToken); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(EnumeratorResultsConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.EventArgsConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.EventArgsConverter.cs new file mode 100644 index 000000000..24e0bf76d --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.EventArgsConverter.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// Enables formatting the default/empty class. + /// + private class EventArgsConverter : MessagePackConverter + { + internal static readonly EventArgsConverter Instance = new(); + + private EventArgsConverter() + { + } + + /// + public override void Write(ref MessagePackWriter writer, in EventArgs? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + context.DepthStep(); + writer.WriteMapHeader(0); + } + + /// + public override EventArgs Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + reader.Skip(context); + return EventArgs.Empty; + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(EventArgsConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs new file mode 100644 index 000000000..6831f6245 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.Serialization; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// Manages serialization of any -derived type that follows standard rules. + /// + /// + /// A serializable class will: + /// 1. Derive from + /// 2. Be attributed with + /// 3. Declare a constructor with a signature of (, ). + /// + private class ExceptionConverter : MessagePackConverter + where T : Exception + { + public static readonly ExceptionConverter Instance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + Assumes.NotNull(formatter.JsonRpc); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return null; + } + + // We have to guard our own recursion because the serializer has no visibility into inner exceptions. + // Each exception in the russian doll is a new serialization job from its perspective. + formatter.exceptionRecursionCounter.Value++; + try + { + if (formatter.exceptionRecursionCounter.Value > formatter.JsonRpc.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception. + // Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded. + reader.Skip(context); + return null; + } + + // TODO: Is this the right context? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.userDataProfile)); + int memberCount = reader.ReadMapHeader(); + for (int i = 0; i < memberCount; i++) + { + string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) + ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + + // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, + // so the caller will get a boxed RawMessagePack struct in that case. + // Although we can't do much about *that* in general, we can at least ensure that null values + // are represented as null instead of this boxed struct. + var value = reader.TryReadNil() ? null : (object)reader.ReadRaw(context); + + info.AddSafeValue(name, value); + } + + return ExceptionSerializationHelpers.Deserialize(formatter.JsonRpc, info, formatter.JsonRpc.TraceSource); + } + finally + { + formatter.exceptionRecursionCounter.Value--; + } + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (value is null) + { + writer.WriteNil(); + return; + } + + formatter.exceptionRecursionCounter.Value++; + try + { + if (formatter.exceptionRecursionCounter.Value > formatter.JsonRpc?.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception. + writer.WriteNil(); + return; + } + + // TODO: Is this the right profile? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.userDataProfile)); + ExceptionSerializationHelpers.Serialize(value, info); + writer.WriteMapHeader(info.GetSafeMemberCount()); + foreach (SerializationEntry element in info.GetSafeMembers()) + { + writer.Write(element.Name); + formatter.rpcProfile.SerializeObject( + ref writer, + element.Value, + element.ObjectType, + context.CancellationToken); + } + } + finally + { + formatter.exceptionRecursionCounter.Value--; + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ExceptionConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs new file mode 100644 index 000000000..6ceb33db0 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private interface IJsonRpcMessagePackRetention + { + /// + /// Gets the original msgpack sequence that was deserialized into this message. + /// + /// + /// The buffer is only retained for a short time. If it has already been cleared, the result of this property is an empty sequence. + /// + ReadOnlySequence OriginalMessagePack { get; } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs new file mode 100644 index 000000000..5e77df4d1 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.IO.Pipelines; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private static class PipeConverters + { + internal class DuplexPipeConverter : MessagePackConverter + where T : class, IDuplexPipe + { + public static readonly DuplexPipeConverter DefaultInstance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(DuplexPipeConverter)); + } + } + + internal class PipeReaderConverter : MessagePackConverter + where T : PipeReader + { + public static readonly PipeReaderConverter DefaultInstance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeReaderConverter)); + } + } + + internal class PipeWriterConverter : MessagePackConverter + where T : PipeWriter + { + public static readonly PipeWriterConverter DefaultInstance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeWriterConverter)); + } + } + + internal class StreamConverter : MessagePackConverter + where T : Stream + { + public static readonly StreamConverter DefaultInstance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(StreamConverter)); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs index c6dbe1ae1..34a5ea1e3 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Collections.Immutable; +using System.IO.Pipelines; using Nerdbank.MessagePack; using PolyType; using StreamJsonRpc.Reflection; @@ -47,76 +48,70 @@ public void AddTypeShapeProvider(ITypeShapeProvider provider) } /// - /// Registers an async enumerable type with the profile. + /// Registers known subtypes for a base type with the profile. /// - /// The type of the async enumerable. - /// The type of the elements in the async enumerable. - public void RegisterAsyncEnumerableType() - where TEnumerable : IAsyncEnumerable + /// The base type. + /// The mapping of known subtypes. + public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) { - MessagePackConverter converter = AsyncEnumerableConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(converter); + this.baseProfile.Serializer.RegisterKnownSubTypes(mapping); + } - MessagePackConverter> resultConverter = EnumeratorResultsConverterResolver.GetConverter(); - this.baseProfile.Serializer.RegisterConverter(resultConverter); + /// + /// Registers a converter with the profile. + /// + /// The type the converter handles. + /// The converter to register. + public void RegisterConverter(MessagePackConverter converter) + { + this.baseProfile.Serializer.RegisterConverter(converter); } /// - /// Registers an async enumerable type with the profile. + /// Registers an async enumerable converter with the profile. /// /// - /// To avoid the cost of reflection, ensure is - /// registered with your type shape provider. + /// Register an on the type shape provider to avoid reflection costs. /// /// The type of the elements in the async enumerable. - public void RegisterAsyncEnumerableType() + public void RegisterAsyncEnumerableConverter() { - MessagePackConverter> converter = AsyncEnumerableConverterResolver.GetConverter>(); + MessagePackConverter> converter = new AsyncEnumerableConverters.PreciseTypeConverter(); this.baseProfile.Serializer.RegisterConverter(converter); - MessagePackConverter> resultConverter = EnumeratorResultsConverterResolver.GetConverter(); + MessagePackConverter> resultConverter = new EnumeratorResultsConverter(); this.baseProfile.Serializer.RegisterConverter(resultConverter); } - /// - /// Registers a converter with the profile. - /// - /// The type the converter handles. - /// The converter to register. - public void RegisterConverter(MessagePackConverter converter) + /// + /// The type of the async enumerable generator. + /// The type of the elements in the async enumerable. + public void RegisterAsyncEnumerableConverter() + where TGenerator : IAsyncEnumerable { + MessagePackConverter converter = new AsyncEnumerableConverters.GeneratorConverter(); this.baseProfile.Serializer.RegisterConverter(converter); - } - - /// - /// Registers known subtypes for a base type with the profile. - /// - /// The base type. - /// The mapping of known subtypes. - public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) - { - this.baseProfile.Serializer.RegisterKnownSubTypes(mapping); + MessagePackConverter> resultConverter = new EnumeratorResultsConverter(); + this.baseProfile.Serializer.RegisterConverter(resultConverter); } /// /// Registers a progress type with the profile. /// - /// The type of the progress. /// The type of the report. - public void RegisterProgressType() - where TProgress : IProgress + public void RegisterProgressConverter() { - MessagePackConverter converter = ProgressConverterResolver.GetConverter(); + MessagePackConverter> converter = ProgressConverterResolver.GetConverter>(); this.baseProfile.Serializer.RegisterConverter(converter); } - /// - /// Registers a progress type with the profile. - /// + /// + /// The type of the progress. /// The type of the report. - public void RegisterProgressType() + public void RegisterProgressConverter() + where TProgress : IProgress { - MessagePackConverter> converter = ProgressConverterResolver.GetConverter>(); + MessagePackConverter converter = ProgressConverterResolver.GetConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } @@ -127,7 +122,7 @@ public void RegisterProgressType() public void RegisterExceptionType() where TException : Exception { - MessagePackConverter converter = MessagePackExceptionConverterResolver.GetConverter(); + MessagePackConverter converter = ExceptionConverter.Instance; this.baseProfile.Serializer.RegisterConverter(converter); } @@ -135,27 +130,62 @@ public void RegisterExceptionType() /// Registers an RPC marshalable type with the profile. /// /// The type to register. - public void RegisterRpcMarshalableType() + public void RegisterRpcMarshalableConverter() where T : class { MessagePackConverter converter = GetRpcMarshalableConverter(); this.baseProfile.Serializer.RegisterConverter(converter); } + /// + /// Registers a converter for the type with the profile. + /// + /// The type of the duplex pipe. + public void RegisterDuplexPipeConverter() + where T : class, IDuplexPipe + { + var converter = new PipeConverters.DuplexPipeConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter for the type with the profile. + /// + /// The type of the pipe reader. + public void RegisterPipeReaderConverter() + where T : PipeReader + { + var converter = new PipeConverters.PipeReaderConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter for the type with the profile. + /// + /// The type of the pipe writer. + public void RegisterPipeWriterConverter() + where T : PipeWriter + { + var converter = new PipeConverters.PipeWriterConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter for the type with the profile. + /// + /// The type of the stream. + public void RegisterStreamConverter() + where T : Stream + { + var converter = new PipeConverters.StreamConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + /// /// Registers an observer type with the profile. /// - /// - /// - /// To register with the profile, - /// call this method with the type parameter . - /// - /// - /// RegisterObserver<int>(); - /// - /// /// The type of the observer. - public void RegisterObserver() + public void RegisterObserverConverter() { MessagePackConverter> converter = GetRpcMarshalableConverter>(); this.baseProfile.Serializer.RegisterConverter(converter); diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverterResolver.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverterResolver.cs new file mode 100644 index 000000000..04ff23d83 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverterResolver.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private static class ProgressConverterResolver + { + public static MessagePackConverter GetConverter() + { + MessagePackConverter? converter = default; + + if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) + { + converter = new FullProgressConverter(); + } + else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) + { + converter = new ProgressClientConverter(); + } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); + } + + /// + /// Converts an instance of to a progress token. + /// + private class ProgressClientConverter : MessagePackConverter + { + public override TClass Read(ref MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException("This formatter only serializes IProgress instances."); + } + + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ProgressClientConverter)); + } + } + + /// + /// Converts a progress token to an or an into a token. + /// + private class FullProgressConverter : MessagePackConverter + { + [return: MaybeNull] + public override TClass? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return default!; + } + + Assumes.NotNull(formatter.JsonRpc); + ReadOnlySequence token = reader.ReadRaw(context); + bool clientRequiresNamedArgs = formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; + return (TClass)formatter.FormatterProgressTracker.CreateProgress(formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); + } + + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(FullProgressConverter)); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs new file mode 100644 index 000000000..0f31df54c --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class RequestIdConverter : MessagePackConverter + { + internal static readonly RequestIdConverter Instance = new(); + + private RequestIdConverter() + { + } + + public override RequestId Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.NextMessagePackType == MessagePackType.Integer) + { + return new RequestId(reader.ReadInt64()); + } + else + { + // Do *not* read as an interned string here because this ID should be unique. + return new RequestId(reader.ReadString()); + } + } + + public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) + { + context.DepthStep(); + + if (value.Number.HasValue) + { + writer.Write(value.Number.Value); + } + else + { + writer.Write(value.String); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" + { + "type": ["string", { "type": "integer", "format": "int64" }] + } + """)?.AsObject(); + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs new file mode 100644 index 000000000..8c4e3c251 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class RpcMarshalableConverter( + JsonRpcProxyOptions proxyOptions, + JsonRpcTargetOptions targetOptions, + RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter + where T : class + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter + .rpcProfile.Deserialize(ref reader, context.CancellationToken); + + return token.HasValue + ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) + : null; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); + formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs new file mode 100644 index 000000000..a46d91c2d --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Protocol; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + internal class TraceParentConverter : MessagePackConverter + { + public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.ReadArrayHeader() != 2) + { + throw new NotSupportedException("Unexpected array length."); + } + + var result = default(TraceParent); + result.Version = reader.ReadByte(); + if (result.Version != 0) + { + throw new NotSupportedException("traceparent version " + result.Version + " is not supported."); + } + + if (reader.ReadArrayHeader() != 3) + { + throw new NotSupportedException("Unexpected array length in version-format."); + } + + ReadOnlySequence bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected traceid not found."); + bytes.CopyTo(new Span(result.TraceId, TraceParent.TraceIdByteCount)); + + bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected parentid not found."); + bytes.CopyTo(new Span(result.ParentId, TraceParent.ParentIdByteCount)); + + result.Flags = (TraceParent.TraceFlags)reader.ReadByte(); + + return result; + } + + public unsafe override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) + { + if (value.Version != 0) + { + throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); + } + + context.DepthStep(); + + writer.WriteArrayHeader(2); + + writer.Write(value.Version); + + writer.WriteArrayHeader(3); + + fixed (byte* traceId = value.TraceId) + { + writer.Write(new ReadOnlySpan(traceId, TraceParent.TraceIdByteCount)); + } + + fixed (byte* parentId = value.ParentId) + { + writer.Write(new ReadOnlySpan(parentId, TraceParent.ParentIdByteCount)); + } + + writer.Write((byte)value.Flags); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(TraceParentConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index 9625ec4d6..c25ea263f 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -14,7 +14,6 @@ using System.Text; using System.Text.Json.Nodes; using Nerdbank.MessagePack; -using Nerdbank.Streams; using PolyType; using PolyType.Abstractions; using PolyType.ReflectionProvider; @@ -31,8 +30,6 @@ namespace StreamJsonRpc; /// /// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. /// -[SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Suppressed for Development")] -[GenerateShape] [GenerateShape] [GenerateShape] [GenerateShape] @@ -40,6 +37,8 @@ namespace StreamJsonRpc; [GenerateShape] [GenerateShape] [GenerateShape] +[GenerateShape] +[SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Suppressed for Development")] public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory { /// @@ -59,6 +58,8 @@ public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessa private readonly ToStringHelper deserializationToStringHelper = new(); + private readonly ThreadLocal exceptionRecursionCounter = new(); + /// /// The serializer to use for user data (e.g. arguments, return values and errors). /// @@ -87,13 +88,7 @@ public NerdbankMessagePackFormatter() }; serializer.RegisterKnownSubTypes(exceptionSubtypeMap); - serializer.RegisterConverter(RequestIdConverter.Instance); - serializer.RegisterConverter(EventArgsConverter.Instance); - serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); - serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); - serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); - serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); - serializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + RegisterCommonConverters(serializer); this.rpcProfile = new Profile( Profile.ProfileSource.Internal, @@ -115,25 +110,7 @@ public NerdbankMessagePackFormatter() }; userSerializer.RegisterKnownSubTypes(exceptionSubtypeMap); - - // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. - // We preset this one in user data because $/cancellation methods can carry RequestId values as arguments. - userSerializer.RegisterConverter(RequestIdConverter.Instance); - - // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. - userSerializer.RegisterConverter(EventArgsConverter.Instance); - - // Common exotic types that we want to support. - userSerializer.RegisterConverter(GetRpcMarshalableConverter()); - userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); - userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); - userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); - userSerializer.RegisterConverter(PipeConverterResolver.GetConverter()); - userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); - userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); - userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); - userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); - userSerializer.RegisterConverter(MessagePackExceptionConverterResolver.GetConverter()); + RegisterCommonConverters(userSerializer); this.userDataProfile = new Profile( Profile.ProfileSource.External, @@ -144,17 +121,27 @@ public NerdbankMessagePackFormatter() ]); this.ProfileBuilder = new Profile.Builder(this.userDataProfile); - } - private interface IJsonRpcMessagePackRetention - { - /// - /// Gets the original msgpack sequence that was deserialized into this message. - /// - /// - /// The buffer is only retained for a short time. If it has already been cleared, the result of this property is an empty sequence. - /// - ReadOnlySequence OriginalMessagePack { get; } + // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. + static void RegisterCommonConverters(MessagePackSerializer serializer) + { + serializer.RegisterConverter(GetRpcMarshalableConverter()); + serializer.RegisterConverter(PipeConverters.PipeReaderConverter.DefaultInstance); + serializer.RegisterConverter(PipeConverters.PipeWriterConverter.DefaultInstance); + serializer.RegisterConverter(PipeConverters.StreamConverter.DefaultInstance); + serializer.RegisterConverter(PipeConverters.DuplexPipeConverter.DefaultInstance); + + // We preset this one in user data because $/cancellation methods can carry RequestId values as arguments. + serializer.RegisterConverter(RequestIdConverter.Instance); + + // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. + serializer.RegisterConverter(EventArgsConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + } } /// @@ -408,7 +395,10 @@ bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) private static ReadOnlySequence GetSliceForNextToken(ref MessagePackReader reader, in SerializationContext context) { - return reader.ReadRaw(context); + SequencePosition startPosition = reader.Position; + reader.Skip(context); + SequencePosition endPosition = reader.Position; + return reader.Sequence.Slice(startPosition, endPosition); } /// @@ -910,8 +900,7 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResu { writer.WriteNil(); } - - if (value.ResultDeclaredType is not null + else if (value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void) && value.ResultDeclaredType != typeof(object)) { @@ -1115,825 +1104,6 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcErro } } - internal class TraceParentConverter : MessagePackConverter - { - public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); - - if (reader.ReadArrayHeader() != 2) - { - throw new NotSupportedException("Unexpected array length."); - } - - var result = default(TraceParent); - result.Version = reader.ReadByte(); - if (result.Version != 0) - { - throw new NotSupportedException("traceparent version " + result.Version + " is not supported."); - } - - if (reader.ReadArrayHeader() != 3) - { - throw new NotSupportedException("Unexpected array length in version-format."); - } - - ReadOnlySequence bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected traceid not found."); - bytes.CopyTo(new Span(result.TraceId, TraceParent.TraceIdByteCount)); - - bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected parentid not found."); - bytes.CopyTo(new Span(result.ParentId, TraceParent.ParentIdByteCount)); - - result.Flags = (TraceParent.TraceFlags)reader.ReadByte(); - - return result; - } - - public unsafe override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) - { - if (value.Version != 0) - { - throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); - } - - context.DepthStep(); - - writer.WriteArrayHeader(2); - - writer.Write(value.Version); - - writer.WriteArrayHeader(3); - - fixed (byte* traceId = value.TraceId) - { - writer.Write(new ReadOnlySpan(traceId, TraceParent.TraceIdByteCount)); - } - - fixed (byte* parentId = value.ParentId) - { - writer.Write(new ReadOnlySpan(parentId, TraceParent.ParentIdByteCount)); - } - - writer.Write((byte)value.Flags); - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(TraceParentConverter)); - } - } - - private static class ProgressConverterResolver - { - public static MessagePackConverter GetConverter() - { - MessagePackConverter? converter = default; - - if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) - { - converter = new FullProgressConverter(); - } - else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) - { - converter = new ProgressClientConverter(); - } - - // TODO: Improve Exception - return converter ?? throw new NotSupportedException(); - } - - /// - /// Converts an instance of to a progress token. - /// - private class ProgressClientConverter : MessagePackConverter - { - public override TClass Read(ref MessagePackReader reader, SerializationContext context) - { - throw new NotSupportedException("This formatter only serializes IProgress instances."); - } - - public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - - if (value is null) - { - writer.WriteNil(); - } - else - { - long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); - writer.Write(progressId); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(ProgressClientConverter)); - } - } - - /// - /// Converts a progress token to an or an into a token. - /// - private class FullProgressConverter : MessagePackConverter - { - [return: MaybeNull] - public override TClass? Read(ref MessagePackReader reader, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - - if (reader.TryReadNil()) - { - return default!; - } - - Assumes.NotNull(formatter.JsonRpc); - ReadOnlySequence token = reader.ReadRaw(context); - bool clientRequiresNamedArgs = formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; - return (TClass)formatter.FormatterProgressTracker.CreateProgress(formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); - } - - public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - - if (value is null) - { - writer.WriteNil(); - } - else - { - long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); - writer.Write(progressId); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(FullProgressConverter)); - } - } - } - - private static class AsyncEnumerableConverterResolver - { - public static MessagePackConverter GetConverter() - { - MessagePackConverter? converter = default; - - if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) - { - converter = (MessagePackConverter?)Activator.CreateInstance( - typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0])); - } - else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) - { - converter = (MessagePackConverter?)Activator.CreateInstance( - typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0])); - } - - // TODO: Improve Exception - return converter ?? throw new NotSupportedException(); - } - - /// - /// Converts an enumeration token to an - /// or an into an enumeration token. - /// -#pragma warning disable CA1812 - private partial class PreciseTypeConverter : MessagePackConverter> -#pragma warning restore CA1812 - { - /// - /// The constant "token", in its various forms. - /// - private static readonly MessagePackString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); - - /// - /// The constant "values", in its various forms. - /// - private static readonly MessagePackString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); - - public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) - { - if (reader.TryReadNil()) - { - return default; - } - - NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); - - context.DepthStep(); - - RawMessagePack? token = default; - IReadOnlyList? initialElements = null; - int propertyCount = reader.ReadMapHeader(); - for (int i = 0; i < propertyCount; i++) - { - if (TokenPropertyName.TryRead(ref reader)) - { - token = (RawMessagePack)reader.ReadRaw(context); - } - else if (ValuesPropertyName.TryRead(ref reader)) - { - initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); - } - else - { - reader.Skip(context); - } - } - - return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? (object)token : null, initialElements); - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] - public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) - { - context.DepthStep(); - - NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); - Serialize_Shared(mainFormatter, ref writer, value, context); - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); - } - - internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref MessagePackWriter writer, IAsyncEnumerable? value, SerializationContext context) - { - if (value is null) - { - writer.WriteNil(); - } - else - { - (IReadOnlyList elements, bool finished) = value.TearOffPrefetchedElements(); - long token = mainFormatter.EnumerableTracker.GetToken(value); - - int propertyCount = 0; - if (elements.Count > 0) - { - propertyCount++; - } - - if (!finished) - { - propertyCount++; - } - - writer.WriteMapHeader(propertyCount); - - if (!finished) - { - writer.Write(TokenPropertyName); - writer.Write(token); - } - - if (elements.Count > 0) - { - writer.Write(ValuesPropertyName); - context.GetConverter>(context.TypeShapeProvider).Write(ref writer, elements, context); - } - } - } - } - - /// - /// Converts an instance of to an enumeration token. - /// -#pragma warning disable CA1812 - private class GeneratorConverter : MessagePackConverter - where TClass : IAsyncEnumerable -#pragma warning restore CA1812 - { - public override TClass Read(ref MessagePackReader reader, SerializationContext context) - { - throw new NotSupportedException(); - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] - public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) - { - NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); - - context.DepthStep(); - PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(GeneratorConverter)); - } - } - } - - private static class PipeConverterResolver - { - public static MessagePackConverter GetConverter() - { - MessagePackConverter? converter = default; - - if (typeof(IDuplexPipe).IsAssignableFrom(typeof(T))) - { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)))!; - } - else if (typeof(PipeReader).IsAssignableFrom(typeof(T))) - { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)))!; - } - else if (typeof(PipeWriter).IsAssignableFrom(typeof(T))) - { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)))!; - } - else if (typeof(Stream).IsAssignableFrom(typeof(T))) - { - converter = (MessagePackConverter?)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)))!; - } - - // TODO: Improve Exception - return converter ?? throw new NotSupportedException(); - } - - private class DuplexPipeConverter : MessagePackConverter - where T : class, IDuplexPipe - { - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - - if (reader.TryReadNil()) - { - return null; - } - - return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); - } - - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - - if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) - { - writer.Write(token); - } - else - { - writer.WriteNil(); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(DuplexPipeConverter)); - } - } - - private class PipeReaderConverter : MessagePackConverter - where T : PipeReader - { - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - if (reader.TryReadNil()) - { - return null; - } - - return (T)formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); - } - - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) - { - writer.Write(token); - } - else - { - writer.WriteNil(); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(PipeReaderConverter)); - } - } - - private class PipeWriterConverter : MessagePackConverter - where T : PipeWriter - { - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - if (reader.TryReadNil()) - { - return null; - } - - return (T)formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); - } - - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) - { - writer.Write(token); - } - else - { - writer.WriteNil(); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(PipeWriterConverter)); - } - } - - private class StreamConverter : MessagePackConverter - where T : Stream - { - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - if (reader.TryReadNil()) - { - return null; - } - - return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); - } - - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - if (formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) - { - writer.Write(token); - } - else - { - writer.WriteNil(); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(StreamConverter)); - } - } - } - - /// - /// Manages serialization of any -derived type that follows standard rules. - /// - /// - /// A serializable class will: - /// 1. Derive from - /// 2. Be attributed with - /// 3. Declare a constructor with a signature of (, ). - /// - private static class MessagePackExceptionConverterResolver - { - /// - /// Tracks recursion count while serializing or deserializing an exception. - /// - /// - /// This is placed here (outside the generic class) - /// so that it's one counter shared across all exception types that may be serialized or deserialized. - /// - private static ThreadLocal exceptionRecursionCounter = new(); - - public static MessagePackConverter GetConverter() - { - MessagePackConverter? formatter = null; - if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is not null) - { - formatter = (MessagePackConverter)Activator.CreateInstance(typeof(ExceptionConverter<>).MakeGenericType(typeof(T)))!; - } - - // TODO: Improve Exception - return formatter ?? throw new NotSupportedException(); - } - - private partial class ExceptionConverter : MessagePackConverter - where T : Exception - { - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - Assumes.NotNull(formatter.JsonRpc); - - context.DepthStep(); - - if (reader.TryReadNil()) - { - return null; - } - - // We have to guard our own recursion because the serializer has no visibility into inner exceptions. - // Each exception in the russian doll is a new serialization job from its perspective. - exceptionRecursionCounter.Value++; - try - { - if (exceptionRecursionCounter.Value > formatter.JsonRpc.ExceptionOptions.RecursionLimit) - { - // Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception. - // Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded. - reader.Skip(context); - return null; - } - - // TODO: Is this the right context? - var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.userDataProfile)); - int memberCount = reader.ReadMapHeader(); - for (int i = 0; i < memberCount; i++) - { - string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) - ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); - - // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, - // so the caller will get a boxed RawMessagePack struct in that case. - // Although we can't do much about *that* in general, we can at least ensure that null values - // are represented as null instead of this boxed struct. - var value = reader.TryReadNil() ? null : (object)reader.ReadRaw(context); - - info.AddSafeValue(name, value); - } - - return ExceptionSerializationHelpers.Deserialize(formatter.JsonRpc, info, formatter.JsonRpc.TraceSource); - } - finally - { - exceptionRecursionCounter.Value--; - } - } - - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - if (value is null) - { - writer.WriteNil(); - return; - } - - exceptionRecursionCounter.Value++; - try - { - if (exceptionRecursionCounter.Value > formatter.JsonRpc?.ExceptionOptions.RecursionLimit) - { - // Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception. - writer.WriteNil(); - return; - } - - // TODO: Is this the right profile? - var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.userDataProfile)); - ExceptionSerializationHelpers.Serialize(value, info); - writer.WriteMapHeader(info.GetSafeMemberCount()); - foreach (SerializationEntry element in info.GetSafeMembers()) - { - writer.Write(element.Name); - formatter.rpcProfile.SerializeObject( - ref writer, - element.Value, - element.ObjectType, - context.CancellationToken); - } - } - finally - { - exceptionRecursionCounter.Value--; - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(ExceptionConverter)); - } - } - } - - private static class EnumeratorResultsConverterResolver - { - public static MessagePackConverter> GetConverter() - { - MessagePackConverter>? converter = - (EnumeratorResultsConverter?)Activator - .CreateInstance(typeof(EnumeratorResultsConverter<>) - .MakeGenericType(typeof(T))); - - return converter ?? throw new NotSupportedException($"Could not create {nameof(EnumeratorResultsConverter)}."); - } - - private class EnumeratorResultsConverter : MessagePackConverter> - { - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to user data context")] - public override MessageFormatterEnumerableTracker.EnumeratorResults? Read(ref MessagePackReader reader, SerializationContext context) - { - if (reader.TryReadNil()) - { - return default; - } - - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - context.DepthStep(); - - Verify.Operation(reader.ReadArrayHeader() == 2, "Expected array of length 2."); - return new MessageFormatterEnumerableTracker.EnumeratorResults() - { - Values = formatter.userDataProfile.Deserialize>(ref reader, context.CancellationToken), - Finished = formatter.userDataProfile.Deserialize(ref reader, context.CancellationToken), - }; - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] - public override void Write(ref MessagePackWriter writer, in MessageFormatterEnumerableTracker.EnumeratorResults? value, SerializationContext context) - { - if (value is null) - { - writer.WriteNil(); - } - else - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - context.DepthStep(); - - writer.WriteArrayHeader(2); - formatter.userDataProfile.Serialize(ref writer, value.Values, context.CancellationToken); - formatter.userDataProfile.Serialize(ref writer, value.Finished, context.CancellationToken); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(EnumeratorResultsConverter)); - } - } - } - - private class RpcMarshalableConverter( - JsonRpcProxyOptions proxyOptions, - JsonRpcTargetOptions targetOptions, - RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter - where T : class - { - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] - public override T? Read(ref MessagePackReader reader, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - - // This converter instance is registered with the user data profile, - // however the shape of MarshalToken is defined by the StreamJsonRpc source generator provider. - MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = context - .GetConverter(ShapeProvider_StreamJsonRpc.Default) - .Read(ref reader, context); - - return token.HasValue - ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) - : null; - } - - [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] - public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) - { - NerdbankMessagePackFormatter formatter = context.GetFormatter(); - - context.DepthStep(); - - if (value is null) - { - writer.WriteNil(); - } - else - { - MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); - - // This converter instance is registered with the user data profile, - // however the shape of MarshalToken is defined by the StreamJsonRpc source generator provider. - context.GetConverter(ShapeProvider_StreamJsonRpc.Default) - .Write(ref writer, token, context); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); - } - } - - /// - /// Enables formatting the default/empty class. - /// - private class EventArgsConverter : MessagePackConverter - { - internal static readonly EventArgsConverter Instance = new(); - - private EventArgsConverter() - { - } - - /// - public override void Write(ref MessagePackWriter writer, in EventArgs? value, SerializationContext context) - { - Requires.NotNull(value!, nameof(value)); - context.DepthStep(); - writer.WriteMapHeader(0); - } - - /// - public override EventArgs Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); - reader.Skip(context); - return EventArgs.Empty; - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) - { - return CreateUndocumentedSchema(typeof(EventArgsConverter)); - } - } - - private class RequestIdConverter : MessagePackConverter - { - internal static readonly RequestIdConverter Instance = new(); - - private RequestIdConverter() - { - } - - public override RequestId Read(ref MessagePackReader reader, SerializationContext context) - { - context.DepthStep(); - - if (reader.NextMessagePackType == MessagePackType.Integer) - { - return new RequestId(reader.ReadInt64()); - } - else - { - // Do *not* read as an interned string here because this ID should be unique. - return new RequestId(reader.ReadString()); - } - } - - public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) - { - context.DepthStep(); - - if (value.Number.HasValue) - { - writer.Write(value.Number.Value); - } - else - { - writer.Write(value.String); - } - } - - public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" - { - "type": ["string", { "type": "integer", "format": "int64" }] - } - """)?.AsObject(); - } - private class TopLevelPropertyBag : TopLevelPropertyBagBase { private readonly Profile formatterProfile; diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs index 65f2020d1..712c7c09b 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs @@ -9,7 +9,6 @@ using System.Reflection; using System.Runtime.Serialization; using Microsoft.VisualStudio.Threading; -using Nerdbank.MessagePack; using PolyType; using static System.FormattableString; using STJ = System.Text.Json.Serialization; @@ -452,10 +451,15 @@ private void CleanUpOutboundResources(RequestId requestId, bool successful) } } + /// + /// A token that represents a marshaled object. + /// [DataContract] - internal struct MarshalToken + [GenerateShape] + internal partial struct MarshalToken { [MessagePack.SerializationConstructor] + [ConstructorShape] #pragma warning disable SA1313 // Parameter names should begin with lower-case letter public MarshalToken(int __jsonrpc_marshaled, long handle, string? lifetime, int[]? optionalInterfaces) #pragma warning restore SA1313 // Parameter names should begin with lower-case letter @@ -468,18 +472,22 @@ public MarshalToken(int __jsonrpc_marshaled, long handle, string? lifetime, int[ [DataMember(Name = "__jsonrpc_marshaled", IsRequired = true)] [STJ.JsonPropertyName("__jsonrpc_marshaled"), STJ.JsonRequired] + [PropertyShape(Name = "__jsonrpc_marshaled")] public int Marshaled { get; set; } [DataMember(Name = "handle", IsRequired = true)] [STJ.JsonPropertyName("handle"), STJ.JsonRequired] + [PropertyShape(Name = "handle")] public long Handle { get; set; } [DataMember(Name = "lifetime", EmitDefaultValue = false)] [STJ.JsonPropertyName("lifetime"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "lifetime")] public string? Lifetime { get; set; } [DataMember(Name = "optionalInterfaces", EmitDefaultValue = false)] [STJ.JsonPropertyName("optionalInterfaces"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "optionalInterfaces")] public int[]? OptionalInterfacesCodes { get; set; } } diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs index df4cdeded..d3dd9ed7b 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -24,8 +24,8 @@ protected override void InitializeFormattersAndHandlers() static void ConfigureContext(NerdbankMessagePackFormatter.Profile.Builder profileBuilder) { - profileBuilder.RegisterAsyncEnumerableType(); - profileBuilder.RegisterAsyncEnumerableType(); + profileBuilder.RegisterAsyncEnumerableConverter(); + profileBuilder.RegisterAsyncEnumerableConverter(); profileBuilder.AddTypeShapeProvider(AsyncEnumerableWitness.ShapeProvider); profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); } diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs index 7ad1ac7b3..de8990007 100644 --- a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Diagnostics; using System.IO.Pipelines; using Nerdbank.MessagePack; using PolyType; @@ -19,6 +20,10 @@ protected override IJsonRpcMessageFormatter CreateFormatter() NerdbankMessagePackFormatter formatter = new(); formatter.SetFormatterProfile(b => { + KnownSubTypeMapping disposableMapping = new(); + disposableMapping.Add(alias: 1, DisposableProxyWitness.ShapeProvider); + + b.RegisterKnownSubTypes(disposableMapping); b.AddTypeShapeProvider(DisposableProxyWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); }); @@ -30,6 +35,7 @@ protected override IJsonRpcMessageFormatter CreateFormatter() [GenerateShape] [GenerateShape] [GenerateShape] +[GenerateShape] #pragma warning disable SA1402 // File may only contain a single type public partial class DisposableProxyWitness; #pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs index 2dd30365b..8b388d722 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -33,14 +33,7 @@ protected override void InitializeFormattersAndHandlers() static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) { - b.AddTypeShapeProvider(DuplexPipeWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); } } } - -[GenerateShape] -[GenerateShape] -#pragma warning disable SA1402 // File may only contain a single type -public partial class DuplexPipeWitness; -#pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs index b3a9587ff..0c367c83e 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs @@ -11,7 +11,7 @@ using PolyType; using STJ = System.Text.Json.Serialization; -public abstract class DuplexPipeMarshalingTests : TestBase, IAsyncLifetime +public abstract partial class DuplexPipeMarshalingTests : TestBase, IAsyncLifetime { protected readonly Server server = new Server(); protected JsonRpc serverRpc; @@ -333,13 +333,17 @@ public async Task ServerMethodThatReturnsCustomTypeWithInnerStream() result.InnerStream.Dispose(); } - [Fact] + [SkippableFact] public async Task PassStreamWithArgsAsSingleObject() { + Skip.If(this.GetType() == typeof(DuplexPipeMarshalingNerdbankMessagePackTests), "Dynamic types are not supported with NerdBankMessagePack."); MemoryStream ms = new(); ms.Write(new byte[] { 1, 2, 3 }, 0, 3); ms.Position = 0; - int bytesRead = await this.clientRpc.InvokeWithParameterObjectAsync(nameof(Server.AcceptStreamArgInFirstParam), new { innerStream = ms }, this.TimeoutToken); + int bytesRead = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(Server.AcceptStreamArgInFirstParam), + new { innerStream = ms }, + this.TimeoutToken); Assert.Equal(ms.Length, bytesRead); } @@ -469,15 +473,32 @@ await this.clientRpc.InvokeWithCancellationAsync( pipePair.Item1.Input.Complete(); } - [Theory] + [SkippableTheory] [CombinatorialData] public async Task ClientCanSendTwoWayStreamToServer(bool serverUsesStream) + { + Skip.If(this.GetType() == typeof(DuplexPipeMarshalingNerdbankMessagePackTests), "This test is not supported with NerdBankMessagePack."); + (Stream, Stream) streamPair = FullDuplexStream.CreatePair(); + Task twoWayCom = TwoWayTalkAsync(streamPair.Item1, writeOnOdd: true, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), + [false, streamPair.Item2], + this.TimeoutToken); + await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. + + streamPair.Item1.Dispose(); + } + + [Theory] + [CombinatorialData] + public async Task ClientCanSendTwoWayStreamToServer_WithExplicitTypes(bool serverUsesStream) { (Stream, Stream) streamPair = FullDuplexStream.CreatePair(); Task twoWayCom = TwoWayTalkAsync(streamPair.Item1, writeOnOdd: true, this.TimeoutToken); await this.clientRpc.InvokeWithCancellationAsync( serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), [false, streamPair.Item2], + [typeof(bool), typeof(Stream)], this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. @@ -784,17 +805,21 @@ private async Task ServerStreamIsDisposedWhenClientDisconnects(string rpcMethodN } [DataContract] - public class StreamContainingClass + [GenerateShape] + public partial class StreamContainingClass { [DataMember] + [PropertyShape(Ignore = false)] private Stream innerStream; + [ConstructorShape] public StreamContainingClass(Stream innerStream) { this.innerStream = innerStream; } [STJ.JsonPropertyName("innerStream")] + [PropertyShape(Name = "innerStream")] public Stream InnerStream => this.innerStream; } diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs index 9cb94f205..0f4192349 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -410,13 +410,13 @@ static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) { b.RegisterConverter(new UnserializableTypeConverter()); b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); - b.RegisterAsyncEnumerableType(); - b.RegisterAsyncEnumerableType(); - b.RegisterProgressType(); - b.RegisterProgressType(); - b.RegisterProgressType(); - b.RegisterProgressType, int>(); - b.RegisterProgressType, CustomSerializedType>(); + b.RegisterAsyncEnumerableConverter(); + b.RegisterAsyncEnumerableConverter(); + b.RegisterProgressConverter(); + b.RegisterProgressConverter(); + b.RegisterProgressConverter(); + b.RegisterProgressConverter, int>(); + b.RegisterProgressConverter, CustomSerializedType>(); b.AddTypeShapeProvider(JsonRpcWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); } diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs index aa3f0a40b..22c7a0eef 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs @@ -18,20 +18,20 @@ protected override IJsonRpcMessageFormatter CreateFormatter() NerdbankMessagePackFormatter formatter = new(); formatter.SetFormatterProfile(b => { - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType(); - b.RegisterRpcMarshalableType>(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter>(); b.AddTypeShapeProvider(MarshalableProxyWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); }); @@ -40,5 +40,7 @@ protected override IJsonRpcMessageFormatter CreateFormatter() } [GenerateShape] + [GenerateShape] + [GenerateShape] public partial class MarshalableProxyWitness; } diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs index 605049fea..3f60df26a 100644 --- a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -273,9 +273,10 @@ public void Resolver_ErrorData() Assert.Equal(originalErrorData.Prop2, roundtripErrorData.Prop2); } - [Fact] + [SkippableFact] public void CanDeserializeWithExtraProperty_JsonRpcRequest() { + Skip.If(this.Formatter is NerdbankMessagePackFormatter, "Dynamic types are not supported for NerdbankMessagePack."); var dynamic = new { jsonrpc = "2.0", @@ -307,9 +308,10 @@ public void CanDeserializeWithExtraProperty_JsonRpcResult() Assert.Equal(dynamic.result, request.GetResult()); } - [Fact] + [SkippableFact] public void CanDeserializeWithExtraProperty_JsonRpcError() { + Skip.If(this.Formatter is NerdbankMessagePackFormatter, "Dynamic types are not supported for NerdbankMessagePack."); var dynamic = new { jsonrpc = "2.0", @@ -326,6 +328,7 @@ public void CanDeserializeWithExtraProperty_JsonRpcError() [Fact] public void StringsInUserDataAreInterned() { + Skip.If(this.Formatter is NerdbankMessagePackFormatter, "Dynamic types are not supported for NerdbankMessagePack."); var dynamic = new { jsonrpc = "2.0", diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs index ccf419fa2..d49d208e3 100644 --- a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs @@ -15,7 +15,7 @@ protected override IJsonRpcMessageFormatter CreateFormatter() NerdbankMessagePackFormatter formatter = new(); formatter.SetFormatterProfile(b => { - b.RegisterRpcMarshalableType>(); + b.RegisterRpcMarshalableConverter>(); b.RegisterExceptionType(); b.AddTypeShapeProvider(ObserverMarshalingWitness.ShapeProvider); b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); diff --git a/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs index 12152f55d..c1dff17b9 100644 --- a/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs +++ b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs @@ -10,7 +10,16 @@ protected override void InitializeFormattersAndHandlers() var serverMessageFormatter = new NerdbankMessagePackFormatter(); var clientMessageFormatter = new NerdbankMessagePackFormatter(); + serverMessageFormatter.SetFormatterProfile(ConfigureContext); + clientMessageFormatter.SetFormatterProfile(ConfigureContext); + this.serverMessageHandler = new LengthHeaderMessageHandler(this.serverStream, this.serverStream, serverMessageFormatter); this.clientMessageHandler = new LengthHeaderMessageHandler(this.clientStream, this.clientStream, clientMessageFormatter); + + void ConfigureContext(NerdbankMessagePackFormatter.Profile.Builder profileBuilder) + { + profileBuilder.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + } } } diff --git a/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs b/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs index 9acc55756..3fe7f6d8b 100644 --- a/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs +++ b/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs @@ -3,11 +3,13 @@ using Microsoft; using Microsoft.VisualStudio.Threading; using Nerdbank; +using Nerdbank.MessagePack; +using PolyType; using StreamJsonRpc; using Xunit; using Xunit.Abstractions; -public abstract class TargetObjectEventsTests : TestBase +public abstract partial class TargetObjectEventsTests : TestBase { protected IJsonRpcMessageHandler serverMessageHandler = null!; protected IJsonRpcMessageHandler clientMessageHandler = null!; @@ -35,7 +37,11 @@ public TargetObjectEventsTests(ITestOutputHelper logger) } [MessagePack.Union(key: 0, typeof(Fruit))] - public interface IFruit + [GenerateShape] +#pragma warning disable CS0618 // Type or member is obsolete + [KnownSubType(typeof(Fruit), 1)] +#pragma warning restore CS0618 // Type or member is obsolete + public partial interface IFruit { string Name { get; } } @@ -360,7 +366,8 @@ private void ReinitializeRpcWithoutListening() } [DataContract] - public class Fruit : IFruit + [GenerateShape] + public partial class Fruit : IFruit { internal Fruit(string name) {