diff --git a/Directory.Packages.props b/Directory.Packages.props index 462cbe578..b16855566 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -22,11 +22,12 @@ + - + - + diff --git a/nuget.config b/nuget.config index 0a73357e8..ee3500d0c 100644 --- a/nuget.config +++ b/nuget.config @@ -5,10 +5,26 @@ + + + + + + + + + + + + + + + + diff --git a/src/StreamJsonRpc/FormatterBase.cs b/src/StreamJsonRpc/FormatterBase.cs index 7dd4c6479..42b283359 100644 --- a/src/StreamJsonRpc/FormatterBase.cs +++ b/src/StreamJsonRpc/FormatterBase.cs @@ -6,7 +6,9 @@ using System.IO.Pipelines; using System.Reflection; using System.Runtime.Serialization; +using Nerdbank.MessagePack; using Nerdbank.Streams; +using PolyType; using StreamJsonRpc.Protocol; using StreamJsonRpc.Reflection; @@ -320,7 +322,7 @@ public void Dispose() /// /// A base class for top-level property bags that should be declared in the derived formatter class. /// - protected abstract class TopLevelPropertyBagBase + protected internal abstract class TopLevelPropertyBagBase { private readonly bool isOutbound; private Dictionary? outboundProperties; @@ -436,6 +438,7 @@ protected abstract class JsonRpcErrorBase : JsonRpcError, IJsonRpcMessageBufferM /// [Newtonsoft.Json.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public TopLevelPropertyBagBase? TopLevelPropertyBag { get; set; } void IJsonRpcMessageBufferManager.DeserializationComplete(JsonRpcMessage message) @@ -480,6 +483,7 @@ protected abstract class JsonRpcResultBase : JsonRpcResult, IJsonRpcMessageBuffe /// [Newtonsoft.Json.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public TopLevelPropertyBagBase? TopLevelPropertyBag { get; set; } void IJsonRpcMessageBufferManager.DeserializationComplete(JsonRpcMessage message) diff --git a/src/StreamJsonRpc/JsonRpc.cs b/src/StreamJsonRpc/JsonRpc.cs index e8ab2712a..e5f1e596a 100644 --- a/src/StreamJsonRpc/JsonRpc.cs +++ b/src/StreamJsonRpc/JsonRpc.cs @@ -1697,7 +1697,7 @@ protected virtual async ValueTask DispatchRequestAsync(JsonRpcRe } /// - /// Sends the JSON-RPC message to intance to be transmitted. + /// Sends the JSON-RPC message to instance to be transmitted. /// /// The message to send. /// A token to cancel the send request. diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs new file mode 100644 index 000000000..1938eb38f --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private static class AsyncEnumerableConverters + { + /// + /// Converts an enumeration token to an + /// or an into an enumeration token. + /// +#pragma warning disable CA1812 + internal class PreciseTypeConverter : MessagePackConverter> +#pragma warning restore CA1812 + { + /// + /// The constant "token", in its various forms. + /// + private static readonly MessagePackString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); + + /// + /// The constant "values", in its various forms. + /// + private static readonly MessagePackString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); + + public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default; + } + + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + + context.DepthStep(); + + RawMessagePack? token = default; + IReadOnlyList? initialElements = null; + int propertyCount = reader.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) + { + if (TokenPropertyName.TryRead(ref reader)) + { + // The value needs to outlive the reader, so we clone it. + token = new RawMessagePack(reader.ReadRaw(context)).ToOwned(); + } + else if (ValuesPropertyName.TryRead(ref reader)) + { + initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + reader.Skip(context); + } + } + + return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? token.Value : null, initialElements); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] + public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) + { + context.DepthStep(); + + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + Serialize_Shared(mainFormatter, ref writer, value, context); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); + } + + internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref MessagePackWriter writer, IAsyncEnumerable? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + (IReadOnlyList elements, bool finished) = value.TearOffPrefetchedElements(); + long token = mainFormatter.EnumerableTracker.GetToken(value); + + int propertyCount = 0; + if (elements.Count > 0) + { + propertyCount++; + } + + if (!finished) + { + propertyCount++; + } + + writer.WriteMapHeader(propertyCount); + + if (!finished) + { + writer.Write(TokenPropertyName); + writer.Write(token); + } + + if (elements.Count > 0) + { + writer.Write(ValuesPropertyName); + context.GetConverter>(context.TypeShapeProvider).Write(ref writer, elements, context); + } + } + } + } + + /// + /// Converts an instance of to an enumeration token. + /// +#pragma warning disable CA1812 + internal class GeneratorConverter : MessagePackConverter + where TClass : IAsyncEnumerable +#pragma warning restore CA1812 + { + public override TClass Read(ref MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException(); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + + context.DepthStep(); + PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(GeneratorConverter)); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs new file mode 100644 index 000000000..8990d0a14 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; +using StreamJsonRpc.Protocol; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// The constant "jsonrpc", in its various forms. + /// + private static readonly MessagePackString VersionPropertyName = new(Constants.jsonrpc); + + /// + /// The constant "id", in its various forms. + /// + private static readonly MessagePackString IdPropertyName = new(Constants.id); + + /// + /// The constant "method", in its various forms. + /// + private static readonly MessagePackString MethodPropertyName = new(Constants.Request.method); + + /// + /// The constant "result", in its various forms. + /// + private static readonly MessagePackString ResultPropertyName = new(Constants.Result.result); + + /// + /// The constant "error", in its various forms. + /// + private static readonly MessagePackString ErrorPropertyName = new(Constants.Error.error); + + /// + /// The constant "params", in its various forms. + /// + private static readonly MessagePackString ParamsPropertyName = new(Constants.Request.@params); + + /// + /// The constant "traceparent", in its various forms. + /// + private static readonly MessagePackString TraceParentPropertyName = new(Constants.Request.traceparent); + + /// + /// The constant "tracestate", in its various forms. + /// + private static readonly MessagePackString TraceStatePropertyName = new(Constants.Request.tracestate); + + /// + /// The constant "2.0", in its various forms. + /// + private static readonly MessagePackString Version2 = new("2.0"); +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.EnumeratorResultsConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.EnumeratorResultsConverter.cs new file mode 100644 index 000000000..379623815 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.EnumeratorResultsConverter.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class EnumeratorResultsConverter : MessagePackConverter> + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to user data context")] + public override MessageFormatterEnumerableTracker.EnumeratorResults? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default; + } + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + Verify.Operation(reader.ReadArrayHeader() == 2, "Expected array of length 2."); + return new MessageFormatterEnumerableTracker.EnumeratorResults() + { + Values = formatter.userDataProfile.Deserialize>(ref reader, context.CancellationToken), + Finished = formatter.userDataProfile.Deserialize(ref reader, context.CancellationToken), + }; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] + public override void Write(ref MessagePackWriter writer, in MessageFormatterEnumerableTracker.EnumeratorResults? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + writer.WriteArrayHeader(2); + formatter.userDataProfile.Serialize(ref writer, value.Values, context.CancellationToken); + formatter.userDataProfile.Serialize(ref writer, value.Finished, context.CancellationToken); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(EnumeratorResultsConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.EventArgsConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.EventArgsConverter.cs new file mode 100644 index 000000000..24e0bf76d --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.EventArgsConverter.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// Enables formatting the default/empty class. + /// + private class EventArgsConverter : MessagePackConverter + { + internal static readonly EventArgsConverter Instance = new(); + + private EventArgsConverter() + { + } + + /// + public override void Write(ref MessagePackWriter writer, in EventArgs? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + context.DepthStep(); + writer.WriteMapHeader(0); + } + + /// + public override EventArgs Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + reader.Skip(context); + return EventArgs.Empty; + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(EventArgsConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs new file mode 100644 index 000000000..6831f6245 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.Serialization; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// Manages serialization of any -derived type that follows standard rules. + /// + /// + /// A serializable class will: + /// 1. Derive from + /// 2. Be attributed with + /// 3. Declare a constructor with a signature of (, ). + /// + private class ExceptionConverter : MessagePackConverter + where T : Exception + { + public static readonly ExceptionConverter Instance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + Assumes.NotNull(formatter.JsonRpc); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return null; + } + + // We have to guard our own recursion because the serializer has no visibility into inner exceptions. + // Each exception in the russian doll is a new serialization job from its perspective. + formatter.exceptionRecursionCounter.Value++; + try + { + if (formatter.exceptionRecursionCounter.Value > formatter.JsonRpc.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception. + // Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded. + reader.Skip(context); + return null; + } + + // TODO: Is this the right context? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.userDataProfile)); + int memberCount = reader.ReadMapHeader(); + for (int i = 0; i < memberCount; i++) + { + string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) + ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + + // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, + // so the caller will get a boxed RawMessagePack struct in that case. + // Although we can't do much about *that* in general, we can at least ensure that null values + // are represented as null instead of this boxed struct. + var value = reader.TryReadNil() ? null : (object)reader.ReadRaw(context); + + info.AddSafeValue(name, value); + } + + return ExceptionSerializationHelpers.Deserialize(formatter.JsonRpc, info, formatter.JsonRpc.TraceSource); + } + finally + { + formatter.exceptionRecursionCounter.Value--; + } + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (value is null) + { + writer.WriteNil(); + return; + } + + formatter.exceptionRecursionCounter.Value++; + try + { + if (formatter.exceptionRecursionCounter.Value > formatter.JsonRpc?.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception. + writer.WriteNil(); + return; + } + + // TODO: Is this the right profile? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.userDataProfile)); + ExceptionSerializationHelpers.Serialize(value, info); + writer.WriteMapHeader(info.GetSafeMemberCount()); + foreach (SerializationEntry element in info.GetSafeMembers()) + { + writer.Write(element.Name); + formatter.rpcProfile.SerializeObject( + ref writer, + element.Value, + element.ObjectType, + context.CancellationToken); + } + } + finally + { + formatter.exceptionRecursionCounter.Value--; + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ExceptionConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs new file mode 100644 index 000000000..6ceb33db0 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private interface IJsonRpcMessagePackRetention + { + /// + /// Gets the original msgpack sequence that was deserialized into this message. + /// + /// + /// The buffer is only retained for a short time. If it has already been cleared, the result of this property is an empty sequence. + /// + ReadOnlySequence OriginalMessagePack { get; } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs new file mode 100644 index 000000000..784455db9 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class MessagePackFormatterConverter : IFormatterConverter + { + private readonly Profile formatterContext; + + internal MessagePackFormatterConverter(Profile formatterContext) + { + this.formatterContext = formatterContext; + } + +#pragma warning disable CS8766 // This method may in fact return null, and no one cares. + public object? Convert(object value, Type type) +#pragma warning restore CS8766 + { + return this.formatterContext.DeserializeObject((ReadOnlySequence)value, type); + } + + public object Convert(object value, TypeCode typeCode) + { + return typeCode switch + { + TypeCode.Object => this.formatterContext.Deserialize((ReadOnlySequence)value)!, + _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), + }; + } + + public bool ToBoolean(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public byte ToByte(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public char ToChar(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public DateTime ToDateTime(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public decimal ToDecimal(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public double ToDouble(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public short ToInt16(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public int ToInt32(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public long ToInt64(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public sbyte ToSByte(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public float ToSingle(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public string? ToString(object value) => value is null ? null : this.formatterContext.Deserialize((ReadOnlySequence)value); + + public ushort ToUInt16(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public uint ToUInt32(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + + public ulong ToUInt64(object value) => this.formatterContext.Deserialize((ReadOnlySequence)value); + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs new file mode 100644 index 000000000..5e77df4d1 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.IO.Pipelines; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private static class PipeConverters + { + internal class DuplexPipeConverter : MessagePackConverter + where T : class, IDuplexPipe + { + public static readonly DuplexPipeConverter DefaultInstance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(DuplexPipeConverter)); + } + } + + internal class PipeReaderConverter : MessagePackConverter + where T : PipeReader + { + public static readonly PipeReaderConverter DefaultInstance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeReaderConverter)); + } + } + + internal class PipeWriterConverter : MessagePackConverter + where T : PipeWriter + { + public static readonly PipeWriterConverter DefaultInstance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeWriterConverter)); + } + } + + internal class StreamConverter : MessagePackConverter + where T : Stream + { + public static readonly StreamConverter DefaultInstance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(StreamConverter)); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs new file mode 100644 index 000000000..34a5ea1e3 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.Builder.cs @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Collections.Immutable; +using System.IO.Pipelines; +using Nerdbank.MessagePack; +using PolyType; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// A serialization profile for the . + /// + public partial class Profile + { + /// + /// Provides methods to build a serialization profile for the . + /// + public class Builder + { + private readonly Profile baseProfile; + + private ImmutableArray.Builder? typeShapeProvidersBuilder = null; + + /// + /// Initializes a new instance of the class. + /// + /// The base profile to build upon. + internal Builder(Profile baseProfile) + { + this.baseProfile = baseProfile; + } + + /// + /// Adds a type shape provider to the profile. + /// + /// The type shape provider to add. + public void AddTypeShapeProvider(ITypeShapeProvider provider) + { + this.typeShapeProvidersBuilder ??= ImmutableArray.CreateBuilder(initialCapacity: 3); + this.typeShapeProvidersBuilder.Add(provider); + } + + /// + /// Registers known subtypes for a base type with the profile. + /// + /// The base type. + /// The mapping of known subtypes. + public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) + { + this.baseProfile.Serializer.RegisterKnownSubTypes(mapping); + } + + /// + /// Registers a converter with the profile. + /// + /// The type the converter handles. + /// The converter to register. + public void RegisterConverter(MessagePackConverter converter) + { + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers an async enumerable converter with the profile. + /// + /// + /// Register an on the type shape provider to avoid reflection costs. + /// + /// The type of the elements in the async enumerable. + public void RegisterAsyncEnumerableConverter() + { + MessagePackConverter> converter = new AsyncEnumerableConverters.PreciseTypeConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + + MessagePackConverter> resultConverter = new EnumeratorResultsConverter(); + this.baseProfile.Serializer.RegisterConverter(resultConverter); + } + + /// + /// The type of the async enumerable generator. + /// The type of the elements in the async enumerable. + public void RegisterAsyncEnumerableConverter() + where TGenerator : IAsyncEnumerable + { + MessagePackConverter converter = new AsyncEnumerableConverters.GeneratorConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + MessagePackConverter> resultConverter = new EnumeratorResultsConverter(); + this.baseProfile.Serializer.RegisterConverter(resultConverter); + } + + /// + /// Registers a progress type with the profile. + /// + /// The type of the report. + public void RegisterProgressConverter() + { + MessagePackConverter> converter = ProgressConverterResolver.GetConverter>(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// The type of the progress. + /// The type of the report. + public void RegisterProgressConverter() + where TProgress : IProgress + { + MessagePackConverter converter = ProgressConverterResolver.GetConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers an exception type with the profile. + /// + /// The type of the exception. + public void RegisterExceptionType() + where TException : Exception + { + MessagePackConverter converter = ExceptionConverter.Instance; + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers an RPC marshalable type with the profile. + /// + /// The type to register. + public void RegisterRpcMarshalableConverter() + where T : class + { + MessagePackConverter converter = GetRpcMarshalableConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter for the type with the profile. + /// + /// The type of the duplex pipe. + public void RegisterDuplexPipeConverter() + where T : class, IDuplexPipe + { + var converter = new PipeConverters.DuplexPipeConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter for the type with the profile. + /// + /// The type of the pipe reader. + public void RegisterPipeReaderConverter() + where T : PipeReader + { + var converter = new PipeConverters.PipeReaderConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter for the type with the profile. + /// + /// The type of the pipe writer. + public void RegisterPipeWriterConverter() + where T : PipeWriter + { + var converter = new PipeConverters.PipeWriterConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter for the type with the profile. + /// + /// The type of the stream. + public void RegisterStreamConverter() + where T : Stream + { + var converter = new PipeConverters.StreamConverter(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Registers an observer type with the profile. + /// + /// The type of the observer. + public void RegisterObserverConverter() + { + MessagePackConverter> converter = GetRpcMarshalableConverter>(); + this.baseProfile.Serializer.RegisterConverter(converter); + } + + /// + /// Builds the formatter profile. + /// + /// The built formatter profile. + public Profile Build() + { + if (this.typeShapeProvidersBuilder is null || this.typeShapeProvidersBuilder.Count < 1) + { + return this.baseProfile; + } + + // ExoticTypeShapeProvider is always first and cannot be overridden. + return new Profile( + this.baseProfile.Source, + this.baseProfile.Serializer, + [ExoticTypeShapeProvider.Instance, .. this.typeShapeProvidersBuilder]); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.cs new file mode 100644 index 000000000..e8d1d4579 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Profile.cs @@ -0,0 +1,135 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Collections.Immutable; +using System.Diagnostics; +using Nerdbank.MessagePack; +using PolyType; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public sealed partial class NerdbankMessagePackFormatter +{ + /// + /// Initializes a new instance of the class. + /// + /// The MessagePack serializer to use. + /// The type shape providers to use. + [DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] + public partial class Profile(MessagePackSerializer serializer, ImmutableArray shapeProviders) + { + /// + /// Initializes a new instance of the class. + /// + /// The source of the profile. + /// The MessagePack serializer to use. + /// The type shape providers to use. + internal Profile(ProfileSource source, MessagePackSerializer serializer, ImmutableArray shapeProviders) + : this(serializer, shapeProviders) + { + this.Source = source; + } + + internal enum ProfileSource + { + /// + /// The profile is internal to the formatter. + /// + Internal, + + /// + /// The profile is external to the formatter. + /// + External, + } + + internal ProfileSource Source { get; } = ProfileSource.Internal; + + /// + /// Gets the MessagePack serializer. + /// + internal MessagePackSerializer Serializer => serializer; + + /// + /// Gets the shape provider resolver. + /// + internal TypeShapeProviderResolver ShapeProviderResolver { get; } = new TypeShapeProviderResolver(shapeProviders); + + private int ProvidersCount => shapeProviders.Length; + + internal Profile WithFormatterState(NerdbankMessagePackFormatter formatter) + { + SerializationContext nextContext = serializer.StartingContext; + nextContext[SerializationContextExtensions.FormatterKey] = formatter; + MessagePackSerializer nextSerializer = serializer with { StartingContext = nextContext }; + + return new(this.Source, nextSerializer, shapeProviders); + } + + private string GetDebuggerDisplay() => $"{this.Source} [{this.ProvidersCount}]"; + + /// + /// When passing a type shape provider to MessagePackSerializer, it will resolve the type shape + /// from that provider and try to cache it. If the resolved type shape is not sourced from the + /// passed provider, it will throw an ArgumentException with the message: + /// System.ArgumentException : The specified shape provider is not valid for this cache. + /// To avoid this, this class does not implement ITypeShapeProvider directly so that it cannot + /// be passed to the serializer. Instead, use to get the + /// provider that will resolve the shape for the specified type, or if the serialization method supports + /// if use to get the shape directly. + /// Related issue: https://github.com/eiriktsarpalis/PolyType/issues/92. + /// + internal class TypeShapeProviderResolver + { + private readonly ImmutableArray providers; + + internal TypeShapeProviderResolver(ImmutableArray providers) + { + this.providers = providers; + } + + public ITypeShape ResolveShape() => (ITypeShape)this.ResolveShape(typeof(T)); + + public ITypeShape ResolveShape(Type type) + { + foreach (ITypeShapeProvider provider in this.providers) + { + ITypeShape? shape = provider.GetShape(type); + if (shape is not null) + { + return shape; + } + } + + // TODO: Loc the exception message. + throw new MessagePackSerializationException($"No shape provider found for type '{type}'."); + } + + /// + /// Find the first provider that can provide a shape for the specified type. + /// + /// The type to resolve. + /// The relevant shape provider. + public ITypeShapeProvider ResolveShapeProvider() => this.ResolveShapeProvider(typeof(T)); + + /// + public ITypeShapeProvider ResolveShapeProvider(Type type) + { + foreach (ITypeShapeProvider provider in this.providers) + { + if (provider.GetShape(type) is not null) + { + return provider; + } + } + + // TODO: Loc the exception message. + throw new MessagePackSerializationException($"No shape provider found for type '{type}'."); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverterResolver.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverterResolver.cs new file mode 100644 index 000000000..04ff23d83 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverterResolver.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private static class ProgressConverterResolver + { + public static MessagePackConverter GetConverter() + { + MessagePackConverter? converter = default; + + if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) + { + converter = new FullProgressConverter(); + } + else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) + { + converter = new ProgressClientConverter(); + } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); + } + + /// + /// Converts an instance of to a progress token. + /// + private class ProgressClientConverter : MessagePackConverter + { + public override TClass Read(ref MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException("This formatter only serializes IProgress instances."); + } + + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ProgressClientConverter)); + } + } + + /// + /// Converts a progress token to an or an into a token. + /// + private class FullProgressConverter : MessagePackConverter + { + [return: MaybeNull] + public override TClass? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return default!; + } + + Assumes.NotNull(formatter.JsonRpc); + ReadOnlySequence token = reader.ReadRaw(context); + bool clientRequiresNamedArgs = formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; + return (TClass)formatter.FormatterProgressTracker.CreateProgress(formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); + } + + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(FullProgressConverter)); + } + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs new file mode 100644 index 000000000..0f31df54c --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class RequestIdConverter : MessagePackConverter + { + internal static readonly RequestIdConverter Instance = new(); + + private RequestIdConverter() + { + } + + public override RequestId Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.NextMessagePackType == MessagePackType.Integer) + { + return new RequestId(reader.ReadInt64()); + } + else + { + // Do *not* read as an interned string here because this ID should be unique. + return new RequestId(reader.ReadString()); + } + } + + public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) + { + context.DepthStep(); + + if (value.Number.HasValue) + { + writer.Write(value.Number.Value); + } + else + { + writer.Write(value.String); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" + { + "type": ["string", { "type": "integer", "format": "int64" }] + } + """)?.AsObject(); + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs new file mode 100644 index 000000000..8c4e3c251 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class RpcMarshalableConverter( + JsonRpcProxyOptions proxyOptions, + JsonRpcTargetOptions targetOptions, + RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter + where T : class + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter + .rpcProfile.Deserialize(ref reader, context.CancellationToken); + + return token.HasValue + ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) + : null; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); + formatter.rpcProfile.Serialize(ref writer, token, context.CancellationToken); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs new file mode 100644 index 000000000..ba740d2b5 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// A recyclable object that can serialize a message to JSON on demand. + /// + /// + /// In perf traces, creation of this object used to show up as one of the most allocated objects. + /// It is used even when tracing isn't active. So we changed its design to be reused, + /// since its lifetime is only required during a synchronous call to a trace API. + /// + private class ToStringHelper + { + private ReadOnlySequence? encodedMessage; + private string? jsonString; + + public override string ToString() + { + Verify.Operation(this.encodedMessage.HasValue, "This object has not been activated. It may have already been recycled."); + + return this.jsonString ??= MessagePackSerializer.ConvertToJson(this.encodedMessage.Value); + } + + /// + /// Initializes this object to represent a message. + /// + internal void Activate(ReadOnlySequence encodedMessage) + { + this.encodedMessage = encodedMessage; + } + + /// + /// Cleans out this object to release memory and ensure throws if someone uses it after deactivation. + /// + internal void Deactivate() + { + this.encodedMessage = null; + this.jsonString = null; + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs new file mode 100644 index 000000000..a46d91c2d --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Protocol; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + internal class TraceParentConverter : MessagePackConverter + { + public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.ReadArrayHeader() != 2) + { + throw new NotSupportedException("Unexpected array length."); + } + + var result = default(TraceParent); + result.Version = reader.ReadByte(); + if (result.Version != 0) + { + throw new NotSupportedException("traceparent version " + result.Version + " is not supported."); + } + + if (reader.ReadArrayHeader() != 3) + { + throw new NotSupportedException("Unexpected array length in version-format."); + } + + ReadOnlySequence bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected traceid not found."); + bytes.CopyTo(new Span(result.TraceId, TraceParent.TraceIdByteCount)); + + bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected parentid not found."); + bytes.CopyTo(new Span(result.ParentId, TraceParent.ParentIdByteCount)); + + result.Flags = (TraceParent.TraceFlags)reader.ReadByte(); + + return result; + } + + public unsafe override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) + { + if (value.Version != 0) + { + throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); + } + + context.DepthStep(); + + writer.WriteArrayHeader(2); + + writer.Write(value.Version); + + writer.WriteArrayHeader(3); + + fixed (byte* traceId = value.TraceId) + { + writer.Write(new ReadOnlySpan(traceId, TraceParent.TraceIdByteCount)); + } + + fixed (byte* parentId = value.ParentId) + { + writer.Write(new ReadOnlySpan(parentId, TraceParent.ParentIdByteCount)); + } + + writer.Write((byte)value.Flags); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(TraceParentConverter)); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs new file mode 100644 index 000000000..c25ea263f --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -0,0 +1,1497 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Collections; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO.Pipelines; +using System.Reflection; +using System.Runtime.ExceptionServices; +using System.Runtime.Serialization; +using System.Text; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType; +using PolyType.Abstractions; +using PolyType.ReflectionProvider; +using PolyType.SourceGenerator; +using PolyType.SourceGenModel; +using StreamJsonRpc.Protocol; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +[SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Suppressed for Development")] +public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory +{ + /// + /// A cache of property names to declared property types, indexed by their containing parameter object type. + /// + /// + /// All access to this field should be while holding a lock on this member's value. + /// + private static readonly Dictionary> ParameterObjectPropertyTypes = []; + + /// + /// The serializer context to use for top-level RPC messages. + /// + private readonly Profile rpcProfile; + + private readonly ToStringHelper serializationToStringHelper = new(); + + private readonly ToStringHelper deserializationToStringHelper = new(); + + private readonly ThreadLocal exceptionRecursionCounter = new(); + + /// + /// The serializer to use for user data (e.g. arguments, return values and errors). + /// + private Profile userDataProfile; + + /// + /// Initializes a new instance of the class. + /// + public NerdbankMessagePackFormatter() + { + KnownSubTypeMapping exceptionSubtypeMap = new(); + exceptionSubtypeMap.Add(alias: 1, ShapeProvider); + exceptionSubtypeMap.Add(alias: 2, ShapeProvider); + exceptionSubtypeMap.Add(alias: 3, ShapeProvider); + exceptionSubtypeMap.Add(alias: 4, ShapeProvider); + + // Set up initial options for our own message types. + MessagePackSerializer serializer = new() + { + InternStrings = true, + SerializeDefaultValues = false, + StartingContext = new SerializationContext() + { + [SerializationContextExtensions.FormatterKey] = this, + }, + }; + + serializer.RegisterKnownSubTypes(exceptionSubtypeMap); + RegisterCommonConverters(serializer); + + this.rpcProfile = new Profile( + Profile.ProfileSource.Internal, + serializer, + [ + ExoticTypeShapeProvider.Instance, + ShapeProvider_StreamJsonRpc.Default + ]); + + // Create a serializer for user data. + MessagePackSerializer userSerializer = new() + { + InternStrings = true, + SerializeDefaultValues = false, + StartingContext = new SerializationContext() + { + [SerializationContextExtensions.FormatterKey] = this, + }, + }; + + userSerializer.RegisterKnownSubTypes(exceptionSubtypeMap); + RegisterCommonConverters(userSerializer); + + this.userDataProfile = new Profile( + Profile.ProfileSource.External, + userSerializer, + [ + ExoticTypeShapeProvider.Instance, + ReflectionTypeShapeProvider.Default + ]); + + this.ProfileBuilder = new Profile.Builder(this.userDataProfile); + + // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. + static void RegisterCommonConverters(MessagePackSerializer serializer) + { + serializer.RegisterConverter(GetRpcMarshalableConverter()); + serializer.RegisterConverter(PipeConverters.PipeReaderConverter.DefaultInstance); + serializer.RegisterConverter(PipeConverters.PipeWriterConverter.DefaultInstance); + serializer.RegisterConverter(PipeConverters.StreamConverter.DefaultInstance); + serializer.RegisterConverter(PipeConverters.DuplexPipeConverter.DefaultInstance); + + // We preset this one in user data because $/cancellation methods can carry RequestId values as arguments. + serializer.RegisterConverter(RequestIdConverter.Instance); + + // We preset this one because for some protocols like IProgress, tokens are passed in that we must relay exactly back to the client as an argument. + serializer.RegisterConverter(EventArgsConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + serializer.RegisterConverter(ExceptionConverter.Instance); + } + } + + /// + /// Gets the profile builder for the formatter. + /// + public Profile.Builder ProfileBuilder { get; } + + /// + /// Sets the formatter profile for user data. + /// + /// + /// + /// For improved startup performance, use + /// to configure a reusable profile and set it here for each instance of this formatter. + /// The profile must be configured before any messages are serialized or deserialized. + /// + /// + /// If not set, a default profile is used which will resolve types using reflection emit. + /// + /// + /// The formatter profile to set. + public void SetFormatterProfile(Profile profile) + { + Requires.NotNull(profile, nameof(profile)); + this.userDataProfile = profile.WithFormatterState(this); + } + + /// + /// Configures the formatter profile for user data with the specified configuration action. + /// + /// + /// Generally prefer using over this method + /// as it is more efficient to reuse a profile across multiple instances of this formatter. + /// + /// The configuration action. + public void SetFormatterProfile(Action configure) + { + Requires.NotNull(configure, nameof(configure)); + + var builder = new Profile.Builder(this.userDataProfile); + configure(builder); + this.SetFormatterProfile(builder.Build()); + } + + /// + public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) + { + JsonRpcMessage message = this.rpcProfile.Deserialize(contentBuffer) + ?? throw new MessagePackSerializationException("Failed to deserialize JSON-RPC message."); + + IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; + this.deserializationToStringHelper.Activate(contentBuffer); + try + { + tracingCallbacks?.OnMessageDeserialized(message, this.deserializationToStringHelper); + } + finally + { + this.deserializationToStringHelper.Deactivate(); + } + + return message; + } + + /// + public void Serialize(IBufferWriter contentBuffer, JsonRpcMessage message) + { + if (message is Protocol.JsonRpcRequest request + && request.Arguments is not null + && request.ArgumentsList is null + && request.Arguments is not IReadOnlyDictionary) + { + // This request contains named arguments, but not using a standard dictionary. Convert it to a dictionary so that + // the parameters can be matched to the method we're invoking. + if (GetParamsObjectDictionary(request.Arguments) is { } namedArgs) + { + request.Arguments = namedArgs.ArgumentValues; + request.NamedArgumentDeclaredTypes = namedArgs.ArgumentTypes; + } + } + + var writer = new MessagePackWriter(contentBuffer); + try + { + this.rpcProfile.Serialize(ref writer, message); + writer.Flush(); + } + catch (Exception ex) + { + throw new MessagePackSerializationException(string.Format(CultureInfo.CurrentCulture, Resources.ErrorWritingJsonRpcMessage, ex.GetType().Name, ex.Message), ex); + } + } + + /// + public object GetJsonText(JsonRpcMessage message) => message is IJsonRpcMessagePackRetention retainedMsgPack + ? MessagePackSerializer.ConvertToJson(retainedMsgPack.OriginalMessagePack) + : throw new NotSupportedException(); + + /// + Protocol.JsonRpcRequest IJsonRpcMessageFactory.CreateRequestMessage() => new OutboundJsonRpcRequest(this); + + /// + Protocol.JsonRpcError IJsonRpcMessageFactory.CreateErrorMessage() => new JsonRpcError(this.rpcProfile); + + /// + Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this, this.rpcProfile); + + void IJsonRpcFormatterTracingCallbacks.OnSerializationComplete(JsonRpcMessage message, ReadOnlySequence encodedMessage) + { + IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; + this.serializationToStringHelper.Activate(encodedMessage); + try + { + tracingCallbacks?.OnMessageSerialized(message, this.serializationToStringHelper); + } + finally + { + this.serializationToStringHelper.Deactivate(); + } + } + + internal static MessagePackConverter GetRpcMarshalableConverter() + where T : class + { + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( + typeof(T), + out JsonRpcProxyOptions? proxyOptions, + out JsonRpcTargetOptions? targetOptions, + out RpcMarshalableAttribute? attribute)) + { + return (RpcMarshalableConverter)Activator.CreateInstance( + typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), + proxyOptions, + targetOptions, + attribute)!; + } + + // TODO: Improve Exception message. + throw new NotSupportedException($"Type '{typeof(T).FullName}' is not supported for RPC Marshaling."); + } + + /// + /// Extracts a dictionary of property names and values from the specified params object. + /// + /// The params object. + /// A dictionary of argument values and another of declared argument types, or if is null. + /// + /// This method supports DataContractSerializer-compliant types. This includes C# anonymous types. + /// + [return: NotNullIfNotNull(nameof(paramsObject))] + private static (IReadOnlyDictionary ArgumentValues, IReadOnlyDictionary ArgumentTypes)? GetParamsObjectDictionary(object? paramsObject) + { + if (paramsObject is null) + { + return default; + } + + // Look up the argument types dictionary if we saved it before. + Type paramsObjectType = paramsObject.GetType(); + IReadOnlyDictionary? argumentTypes; + lock (ParameterObjectPropertyTypes) + { + ParameterObjectPropertyTypes.TryGetValue(paramsObjectType, out argumentTypes); + } + + // If we couldn't find a previously created argument types dictionary, create a mutable one that we'll build this time. + Dictionary? mutableArgumentTypes = argumentTypes is null ? [] : null; + + var result = new Dictionary(StringComparer.Ordinal); + + TypeInfo paramsTypeInfo = paramsObject.GetType().GetTypeInfo(); + bool isDataContract = paramsTypeInfo.GetCustomAttribute() is not null; + + BindingFlags bindingFlags = BindingFlags.FlattenHierarchy | BindingFlags.Public | BindingFlags.Instance; + if (isDataContract) + { + bindingFlags |= BindingFlags.NonPublic; + } + + bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) + { + key = memberInfo.Name; + if (isDataContract) + { + DataMemberAttribute? dataMemberAttribute = memberInfo.GetCustomAttribute(); + if (dataMemberAttribute is null) + { + return false; + } + + if (!dataMemberAttribute.EmitDefaultValue) + { + throw new NotSupportedException($"(DataMemberAttribute.EmitDefaultValue == false) is not supported but was found on: {memberInfo.DeclaringType!.FullName}.{memberInfo.Name}."); + } + + key = dataMemberAttribute.Name ?? memberInfo.Name; + return true; + } + else + { + return memberInfo.GetCustomAttribute() is null; + } + } + + foreach (PropertyInfo property in paramsTypeInfo.GetProperties(bindingFlags)) + { + if (property.GetMethod is not null) + { + if (TryGetSerializationInfo(property, out string key)) + { + result[key] = property.GetValue(paramsObject); + if (mutableArgumentTypes is not null) + { + mutableArgumentTypes[key] = property.PropertyType; + } + } + } + } + + foreach (FieldInfo field in paramsTypeInfo.GetFields(bindingFlags)) + { + if (TryGetSerializationInfo(field, out string key)) + { + result[key] = field.GetValue(paramsObject); + if (mutableArgumentTypes is not null) + { + mutableArgumentTypes[key] = field.FieldType; + } + } + } + + // If we assembled the argument types dictionary this time, save it for next time. + if (mutableArgumentTypes is not null) + { + lock (ParameterObjectPropertyTypes) + { + if (ParameterObjectPropertyTypes.TryGetValue(paramsObjectType, out IReadOnlyDictionary? lostRace)) + { + // Of the two, pick the winner to use ourselves so we consolidate on one and allow the GC to collect the loser sooner. + argumentTypes = lostRace; + } + else + { + ParameterObjectPropertyTypes.Add(paramsObjectType, argumentTypes = mutableArgumentTypes); + } + } + } + + return (result, argumentTypes!); + } + + private static ReadOnlySequence GetSliceForNextToken(ref MessagePackReader reader, in SerializationContext context) + { + SequencePosition startPosition = reader.Position; + reader.Skip(context); + SequencePosition endPosition = reader.Position; + return reader.Sequence.Slice(startPosition, endPosition); + } + + /// + /// Reads a string with an optimized path for the value "2.0". + /// + /// The reader to use. + /// The decoded string. + private static string ReadProtocolVersion(ref MessagePackReader reader) + { + // Recognize "2.0" since we expect it and can avoid decoding and allocating a new string for it. + if (Version2.TryRead(ref reader)) + { + return Version2.Value; + } + else + { + // TODO: Should throw? + return reader.ReadString() ?? throw new MessagePackSerializationException(Resources.RequiredArgumentMissing); + } + } + + /// + /// Writes the JSON-RPC version property name and value in a highly optimized way. + /// + private static void WriteProtocolVersionPropertyAndValue(ref MessagePackWriter writer, string version) + { + writer.Write(VersionPropertyName); + writer.Write(version); + } + + private static void ReadUnknownProperty(ref MessagePackReader reader, in SerializationContext context, ref Dictionary>? topLevelProperties, ReadOnlySpan stringKey) + { + topLevelProperties ??= new Dictionary>(StringComparer.Ordinal); +#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + string name = Encoding.UTF8.GetString(stringKey); +#else + string name = Encoding.UTF8.GetString(stringKey.ToArray()); +#endif + topLevelProperties.Add(name, GetSliceForNextToken(ref reader, context)); + } + + /// + /// Converts JSON-RPC messages to and from MessagePack format. + /// + internal class JsonRpcMessageConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC message. + public override JsonRpcMessage? Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + MessagePackReader readAhead = reader.CreatePeekReader(); + int propertyCount = readAhead.ReadMapHeader(); + + for (int i = 0; i < propertyCount; i++) + { + if (MethodPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + readAhead.Skip(context); + } + + // Skip the value of the property. + readAhead.Skip(context); + } + + throw new UnrecognizedJsonRpcMessageException(); + } + + /// + /// Writes a JSON-RPC message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + using (formatter.TrackSerialization(value)) + { + switch (value) + { + case Protocol.JsonRpcRequest request: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, request, context); + break; + case Protocol.JsonRpcResult result: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, result, context); + break; + case Protocol.JsonRpcError error: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, error, context); + break; + default: + throw new NotSupportedException("Unexpected JsonRpcMessage-derived type: " + value.GetType().Name); + } + } + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return base.GetJsonSchema(context, typeShape); + } + } + + /// + /// Converts a JSON-RPC request message to and from MessagePack format. + /// + internal class JsonRpcRequestConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC request message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC request message. + public override Protocol.JsonRpcRequest? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + var result = new JsonRpcRequest(formatter) + { + OriginalMessagePack = reader.Sequence, + }; + + Dictionary>? topLevelProperties = null; + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + { + if (VersionPropertyName.TryRead(ref reader)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(ref reader)) + { + result.RequestId = context.GetConverter(null).Read(ref reader, context); + } + else if (MethodPropertyName.TryRead(ref reader)) + { + result.Method = context.GetConverter(null).Read(ref reader, context); + } + else if (ParamsPropertyName.TryRead(ref reader)) + { + SequencePosition paramsTokenStartPosition = reader.Position; + + // Parse out the arguments into a dictionary or array, but don't deserialize them because we don't yet know what types to deserialize them to. + switch (reader.NextMessagePackType) + { + case MessagePackType.Array: + var positionalArgs = new ReadOnlySequence[reader.ReadArrayHeader()]; + for (int i = 0; i < positionalArgs.Length; i++) + { + positionalArgs[i] = GetSliceForNextToken(ref reader, context); + } + + result.MsgPackPositionalArguments = positionalArgs; + break; + case MessagePackType.Map: + int namedArgsCount = reader.ReadMapHeader(); + var namedArgs = new Dictionary>(namedArgsCount); + for (int i = 0; i < namedArgsCount; i++) + { + string? propertyName = context.GetConverter(null).Read(ref reader, context) ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + namedArgs.Add(propertyName, GetSliceForNextToken(ref reader, context)); + } + + result.MsgPackNamedArguments = namedArgs; + break; + case MessagePackType.Nil: + result.MsgPackPositionalArguments = []; + reader.ReadNil(); + break; + case MessagePackType type: + throw new MessagePackSerializationException("Expected a map or array of arguments but got " + type); + } + + result.MsgPackArguments = reader.Sequence.Slice(paramsTokenStartPosition, reader.Position); + } + else if (TraceParentPropertyName.TryRead(ref reader)) + { + TraceParent traceParent = context.GetConverter(null).Read(ref reader, context); + result.TraceParent = traceParent.ToString(); + } + else if (TraceStatePropertyName.TryRead(ref reader)) + { + result.TraceState = ReadTraceState(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); + } + } + + if (topLevelProperties is not null) + { + result.TopLevelPropertyBag = new TopLevelPropertyBag(formatter.userDataProfile, topLevelProperties); + } + + formatter.TryHandleSpecialIncomingMessage(result); + + return result; + } + + /// + /// Writes a JSON-RPC request message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC request message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequest? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + int mapElementCount = value.RequestId.IsEmpty ? 3 : 4; + if (value.TraceParent?.Length > 0) + { + mapElementCount++; + if (value.TraceState?.Length > 0) + { + mapElementCount++; + } + } + + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + if (!value.RequestId.IsEmpty) + { + writer.Write(IdPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.RequestId, context); + } + + writer.Write(MethodPropertyName); + writer.Write(value.Method); + + writer.Write(ParamsPropertyName); + if (value.ArgumentsList is not null) + { + writer.WriteArrayHeader(value.ArgumentsList.Count); + + for (int i = 0; i < value.ArgumentsList.Count; i++) + { + object? arg = value.ArgumentsList[i]; + + if (value.ArgumentListDeclaredTypes is null) + { + formatter.userDataProfile.SerializeObject( + ref writer, + arg, + context.CancellationToken); + } + else + { + formatter.userDataProfile.SerializeObject( + ref writer, + arg, + value.ArgumentListDeclaredTypes[i], + context.CancellationToken); + } + } + } + else if (value.NamedArguments is not null) + { + writer.WriteMapHeader(value.NamedArguments.Count); + foreach (KeyValuePair entry in value.NamedArguments) + { + writer.Write(entry.Key); + + if (value.NamedArgumentDeclaredTypes is null) + { + formatter.userDataProfile.SerializeObject( + ref writer, + entry.Value, + context.CancellationToken); + } + else + { + Type argType = value.NamedArgumentDeclaredTypes[entry.Key]; + formatter.userDataProfile.SerializeObject( + ref writer, + entry.Value, + argType, + context.CancellationToken); + } + } + } + else + { + writer.WriteNil(); + } + + if (value.TraceParent?.Length > 0) + { + writer.Write(TraceParentPropertyName); + formatter.rpcProfile.Serialize(ref writer, new TraceParent(value.TraceParent)); + + if (value.TraceState?.Length > 0) + { + writer.Write(TraceStatePropertyName); + WriteTraceState(ref writer, value.TraceState); + } + } + + topLevelPropertyBag?.WriteProperties(ref writer); + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcRequestConverter)); + } + + private static void WriteTraceState(ref MessagePackWriter writer, string traceState) + { + ReadOnlySpan traceStateChars = traceState.AsSpan(); + + // Count elements first so we can write the header. + int elementCount = 1; + int commaIndex; + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) + { + elementCount++; + traceStateChars = traceStateChars.Slice(commaIndex + 1); + } + + // For every element, we have a key and value to record. + writer.WriteArrayHeader(elementCount * 2); + + traceStateChars = traceState.AsSpan(); + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) + { + ReadOnlySpan element = traceStateChars.Slice(0, commaIndex); + WritePair(ref writer, element); + traceStateChars = traceStateChars.Slice(commaIndex + 1); + } + + // Write out the last one. + WritePair(ref writer, traceStateChars); + + static void WritePair(ref MessagePackWriter writer, ReadOnlySpan pair) + { + int equalsIndex = pair.IndexOf('='); + ReadOnlySpan key = pair.Slice(0, equalsIndex); + ReadOnlySpan value = pair.Slice(equalsIndex + 1); + writer.Write(key); + writer.Write(value); + } + } + + private static unsafe string ReadTraceState(ref MessagePackReader reader, SerializationContext context) + { + int elements = reader.ReadArrayHeader(); + if (elements % 2 != 0) + { + throw new NotSupportedException("Odd number of elements not expected."); + } + + // With care, we could probably assemble this string with just two allocations (the string + a char[]). + var resultBuilder = new StringBuilder(); + for (int i = 0; i < elements; i += 2) + { + if (resultBuilder.Length > 0) + { + resultBuilder.Append(','); + } + + // We assume the key is a frequent string, and the value is unique, + // so we optimize whether to use string interning or not on that basis. + resultBuilder.Append(context.GetConverter(null).Read(ref reader, context)); + resultBuilder.Append('='); + resultBuilder.Append(reader.ReadString()); + } + + return resultBuilder.ToString(); + } + } + + /// + /// Converts a JSON-RPC result message to and from MessagePack format. + /// + internal class JsonRpcResultConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC result message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC result message. + public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + var result = new JsonRpcResult(formatter, formatter.userDataProfile) + { + OriginalMessagePack = reader.Sequence, + }; + + Dictionary>? topLevelProperties = null; + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + { + if (VersionPropertyName.TryRead(ref reader)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(ref reader)) + { + result.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(ref reader)) + { + result.MsgPackResult = GetSliceForNextToken(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); + } + } + + if (topLevelProperties is not null) + { + result.TopLevelPropertyBag = new TopLevelPropertyBag(formatter.userDataProfile, topLevelProperties); + } + + return result; + } + + /// + /// Writes a JSON-RPC result message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC result message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResult? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + var topLevelPropertyBagMessage = value as IMessageWithTopLevelPropertyBag; + + int mapElementCount = 3; + mapElementCount += (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + writer.Write(IdPropertyName); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); + + writer.Write(ResultPropertyName); + + if (value.Result is null) + { + writer.WriteNil(); + } + else if (value.ResultDeclaredType is not null + && value.ResultDeclaredType != typeof(void) + && value.ResultDeclaredType != typeof(object)) + { + formatter.userDataProfile.SerializeObject(ref writer, value.Result, value.ResultDeclaredType, context.CancellationToken); + } + else + { + formatter.userDataProfile.SerializeObject(ref writer, value.Result, context.CancellationToken); + } + + (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcResultConverter)); + } + } + + /// + /// Converts a JSON-RPC error message to and from MessagePack format. + /// + internal class JsonRpcErrorConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC error message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC error message. + public override Protocol.JsonRpcError Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + var error = new JsonRpcError(formatter.userDataProfile) + { + OriginalMessagePack = reader.Sequence, + }; + + Dictionary>? topLevelProperties = null; + + context.DepthStep(); + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + if (VersionPropertyName.TryRead(ref reader)) + { + error.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(ref reader)) + { + error.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(ref reader)) + { + error.Error = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, reader.ReadStringSpan()); + } + } + + if (topLevelProperties is not null) + { + error.TopLevelPropertyBag = new TopLevelPropertyBag(formatter.userDataProfile, topLevelProperties); + } + + return error; + } + + /// + /// Writes a JSON-RPC error message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC error message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + context.DepthStep(); + int mapElementCount = 3; + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + writer.Write(IdPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.RequestId, context); + + writer.Write(ErrorPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.Error, context); + + topLevelPropertyBag?.WriteProperties(ref writer); + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcErrorConverter)); + } + } + + /// + /// Converts a JSON-RPC error detail to and from MessagePack format. + /// + internal class JsonRpcErrorDetailConverter : MessagePackConverter + { + private static readonly MessagePackString CodePropertyName = new("code"); + private static readonly MessagePackString MessagePropertyName = new("message"); + private static readonly MessagePackString DataPropertyName = new("data"); + + /// + /// Reads a JSON-RPC error detail from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC error detail. + public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + var result = new JsonRpcError.ErrorDetail(formatter.userDataProfile); + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + if (CodePropertyName.TryRead(ref reader)) + { + result.Code = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (MessagePropertyName.TryRead(ref reader)) + { + result.Message = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (DataPropertyName.TryRead(ref reader)) + { + result.MsgPackData = GetSliceForNextToken(ref reader, context); + } + else + { + reader.Skip(context); + } + } + + return result; + } + + /// + /// Writes a JSON-RPC error detail to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC error detail to write. + /// The serialization context. + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + writer.WriteMapHeader(3); + + writer.Write(CodePropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.Code, context); + + writer.Write(MessagePropertyName); + writer.Write(value.Message); + + writer.Write(DataPropertyName); + formatter.userDataProfile.Serialize(ref writer, value.Data, context.CancellationToken); + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcErrorDetailConverter)); + } + } + + private class TopLevelPropertyBag : TopLevelPropertyBagBase + { + private readonly Profile formatterProfile; + private readonly IReadOnlyDictionary>? inboundUnknownProperties; + + /// + /// Initializes a new instance of the class + /// for an incoming message. + /// + /// The profile use for this data. + /// The map of unrecognized inbound properties. + internal TopLevelPropertyBag(Profile formatterProfile, IReadOnlyDictionary> inboundUnknownProperties) + : base(isOutbound: false) + { + this.formatterProfile = formatterProfile; + this.inboundUnknownProperties = inboundUnknownProperties; + } + + /// + /// Initializes a new instance of the class + /// for an outbound message. + /// + /// The profile to use for this data. + internal TopLevelPropertyBag(Profile formatterProfile) + : base(isOutbound: true) + { + this.formatterProfile = formatterProfile; + } + + internal int PropertyCount => this.inboundUnknownProperties?.Count ?? this.OutboundProperties?.Count ?? 0; + + /// + /// Writes the properties tracked by this collection to a messagepack writer. + /// + /// The writer to use. + internal void WriteProperties(ref MessagePackWriter writer) + { + if (this.inboundUnknownProperties is not null) + { + // We're actually re-transmitting an incoming message (remote target feature). + // We need to copy all the properties that were in the original message. + // Don't implement this without enabling the tests for the scenario found in JsonRpcRemoteTargetMessagePackFormatterTests.cs. + // The tests fail for reasons even without this support, so there's work to do beyond just implementing this. + throw new NotImplementedException(); + + ////foreach (KeyValuePair> entry in this.inboundUnknownProperties) + ////{ + //// writer.Write(entry.Key); + //// writer.Write(entry.Value); + ////} + } + else + { + foreach (KeyValuePair entry in this.OutboundProperties) + { + writer.Write(entry.Key); + this.formatterProfile.SerializeObject(ref writer, entry.Value.Value, entry.Value.DeclaredType); + } + } + } + + protected internal override bool TryGetTopLevelProperty(string name, [MaybeNull] out T value) + { + if (this.inboundUnknownProperties is null) + { + throw new InvalidOperationException(Resources.InboundMessageOnly); + } + + value = default; + + if (this.inboundUnknownProperties.TryGetValue(name, out ReadOnlySequence serializedValue) is true) + { + value = this.formatterProfile.Deserialize(serializedValue); + return true; + } + + return false; + } + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + private class OutboundJsonRpcRequest : JsonRpcRequestBase + { + private readonly NerdbankMessagePackFormatter formatter; + + internal OutboundJsonRpcRequest(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataProfile); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + private class JsonRpcRequest : JsonRpcRequestBase, IJsonRpcMessagePackRetention + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcRequest(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override int ArgumentCount => this.MsgPackNamedArguments?.Count ?? this.MsgPackPositionalArguments?.Count ?? base.ArgumentCount; + + public override IEnumerable? ArgumentNames => this.MsgPackNamedArguments?.Keys; + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + internal ReadOnlySequence MsgPackArguments { get; set; } + + internal IReadOnlyDictionary>? MsgPackNamedArguments { get; set; } + + internal IReadOnlyList>? MsgPackPositionalArguments { get; set; } + + public override ArgumentMatchResult TryGetTypedArguments(ReadOnlySpan parameters, Span typedArguments) + { + using (this.formatter.TrackDeserialization(this, parameters)) + { + if (parameters.Length == 1 && this.MsgPackNamedArguments is not null) + { + if (this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.UseSingleObjectParameterDeserialization ?? false) + { + var reader = new MessagePackReader(this.MsgPackArguments); + try + { + typedArguments[0] = this.formatter.userDataProfile.DeserializeObject( + ref reader, + parameters[0].ParameterType); + + return ArgumentMatchResult.Success; + } + catch (MessagePackSerializationException) + { + return ArgumentMatchResult.ParameterArgumentTypeMismatch; + } + } + } + + return base.TryGetTypedArguments(parameters, typedArguments); + } + } + + public override bool TryGetArgumentByNameOrIndex(string? name, int position, Type? typeHint, out object? value) + { + // If anyone asks us for an argument *after* we've been told deserialization is done, there's something very wrong. + Assumes.True(this.MsgPackNamedArguments is not null || this.MsgPackPositionalArguments is not null); + + ReadOnlySequence msgpackArgument = default; + if (position >= 0 && this.MsgPackPositionalArguments?.Count > position) + { + msgpackArgument = this.MsgPackPositionalArguments[position]; + } + else if (name is not null && this.MsgPackNamedArguments is not null) + { + this.MsgPackNamedArguments.TryGetValue(name, out msgpackArgument); + } + + if (msgpackArgument.IsEmpty) + { + value = null; + return false; + } + + using (this.formatter.TrackDeserialization(this)) + { + try + { + value = this.formatter.userDataProfile.DeserializeObject( + msgpackArgument, + typeHint ?? typeof(object)); + + return true; + } + catch (MessagePackSerializationException ex) + { + if (this.formatter.JsonRpc?.TraceSource.Switch.ShouldTrace(TraceEventType.Warning) ?? false) + { + this.formatter.JsonRpc.TraceSource.TraceEvent(TraceEventType.Warning, (int)JsonRpc.TraceEvents.MethodArgumentDeserializationFailure, Resources.FailureDeserializingRpcArgument, name, position, typeHint, ex); + } + + throw new RpcArgumentDeserializationException(name, position, typeHint, ex); + } + } + } + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + this.MsgPackNamedArguments = null; + this.MsgPackPositionalArguments = null; + this.TopLevelPropertyBag = null; + this.MsgPackArguments = default; + this.OriginalMessagePack = default; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataProfile); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + private partial class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention + { + private readonly NerdbankMessagePackFormatter formatter; + private readonly Profile formatterProfile; + + private Exception? resultDeserializationException; + + internal JsonRpcResult(NerdbankMessagePackFormatter formatter, Profile formatterProfile) + { + this.formatter = formatter; + this.formatterProfile = formatterProfile; + } + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + internal ReadOnlySequence MsgPackResult { get; set; } + + public override T GetResult() + { + if (this.resultDeserializationException is not null) + { + ExceptionDispatchInfo.Capture(this.resultDeserializationException).Throw(); + } + + return this.MsgPackResult.IsEmpty + ? (T)this.Result! + : this.formatterProfile.Deserialize(this.MsgPackResult) + ?? throw new MessagePackSerializationException("Failed to deserialize result."); + } + + protected internal override void SetExpectedResultType(Type resultType) + { + Verify.Operation(!this.MsgPackResult.IsEmpty, "Result is no longer available or has already been deserialized."); + + try + { + using (this.formatter.TrackDeserialization(this)) + { + this.Result = this.formatterProfile.DeserializeObject(this.MsgPackResult, resultType); + } + + this.MsgPackResult = default; + } + catch (MessagePackSerializationException ex) + { + // This was a best effort anyway. We'll throw again later at a more convenient time for JsonRpc. + this.resultDeserializationException = ex; + } + } + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + this.MsgPackResult = default; + this.OriginalMessagePack = default; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatterProfile); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + private class JsonRpcError : JsonRpcErrorBase, IJsonRpcMessagePackRetention + { + private readonly Profile formatterProfile; + + public JsonRpcError(Profile formatterProfile) + { + this.formatterProfile = formatterProfile; + } + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatterProfile); + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + if (this.Error is ErrorDetail privateDetail) + { + privateDetail.MsgPackData = default; + } + + this.OriginalMessagePack = default; + } + + internal new class ErrorDetail : Protocol.JsonRpcError.ErrorDetail + { + private readonly Profile formatterProfile; + + internal ErrorDetail(Profile formatterProfile) + { + this.formatterProfile = formatterProfile ?? throw new ArgumentNullException(nameof(formatterProfile)); + } + + internal ReadOnlySequence MsgPackData { get; set; } + + public override object? GetData(Type dataType) + { + Requires.NotNull(dataType, nameof(dataType)); + if (this.MsgPackData.IsEmpty) + { + return this.Data; + } + + try + { + return this.formatterProfile.DeserializeObject(this.MsgPackData, dataType) + ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + } + catch (MessagePackSerializationException) + { + // Deserialization failed. Try returning array/dictionary based primitive objects. + try + { + // return MessagePackSerializer.Deserialize(this.MsgPackData, this.serializerOptions.WithResolver(PrimitiveObjectResolver.Instance)); + return this.formatterProfile.Deserialize(this.MsgPackData); + } + catch (MessagePackSerializationException) + { + return null; + } + } + } + + protected internal override void SetExpectedDataType(Type dataType) + { + Verify.Operation(!this.MsgPackData.IsEmpty, "Data is no longer available or has already been deserialized."); + + this.Data = this.GetData(dataType); + + // Clear the source now that we've deserialized to prevent GetData from attempting + // deserialization later when the buffer may be recycled on another thread. + this.MsgPackData = default; + } + } + } + + /// + /// Ensures certain exotic types are matched to the correct MessagePackConverter. + /// We rely on the caching in NerbdBank.MessagePackSerializer to ensure we don't create multiple instances of these shapes. + /// + private class ExoticTypeShapeProvider : ITypeShapeProvider + { + internal static readonly ExoticTypeShapeProvider Instance = new(); + + public ITypeShape? GetShape(Type type) + { + if (typeof(PipeReader).IsAssignableFrom(type)) + { + return new SourceGenObjectTypeShape() + { + IsRecordType = false, + IsTupleType = false, + Provider = this, + }; + } + + if (typeof(PipeWriter).IsAssignableFrom(type)) + { + return new SourceGenObjectTypeShape() + { + IsRecordType = false, + IsTupleType = false, + Provider = this, + }; + } + + if (typeof(Stream).IsAssignableFrom(type)) + { + return new SourceGenObjectTypeShape() + { + IsRecordType = false, + IsTupleType = false, + Provider = this, + }; + } + + if (typeof(IDuplexPipe).IsAssignableFrom(type)) + { + return new SourceGenObjectTypeShape() + { + IsRecordType = false, + IsTupleType = false, + Provider = this, + }; + } + + return null; + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs new file mode 100644 index 000000000..9239f0e0c --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatterProfileExtensions.cs @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using Nerdbank.MessagePack; +using static StreamJsonRpc.NerdbankMessagePackFormatter; + +namespace StreamJsonRpc; + +/// +/// Extension methods for that are specific to the . +/// +[System.Diagnostics.CodeAnalysis.SuppressMessage("ApiDesign", "RS0016:Add public types and members to the declared API", Justification = "TODO: Temporary for development")] +public static class NerdbankMessagePackFormatterProfileExtensions +{ + /// + /// Serializes an object using the specified . + /// + /// The formatter profile to use for serialization. + /// The writer to which the object will be serialized. + /// The object to serialize. + /// A token to monitor for cancellation requests. + public static void SerializeObject(this Profile profile, ref MessagePackWriter writer, object? value, CancellationToken cancellationToken = default) + { + Requires.NotNull(profile, nameof(profile)); + SerializeObject(profile, ref writer, value, value?.GetType() ?? typeof(object), cancellationToken); + } + + /// + /// Serializes an object using the specified . + /// + /// The formatter profile to use for serialization. + /// The writer to which the object will be serialized. + /// The object to serialize. + /// A token to monitor for cancellation requests. + /// The type of the object to serialize. + public static void Serialize(this Profile profile, ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) + { + Requires.NotNull(profile, nameof(profile)); + + if (value is null) + { + writer.WriteNil(); + return; + } + + PolyType.Abstractions.ITypeShape shape = profile.ShapeProviderResolver.ResolveShape(); + profile.Serializer.Serialize( + ref writer, + value, + shape, + cancellationToken); + } + + /// + /// Deserializes a sequence of bytes into an object of type using the specified . + /// + /// The type of the object to deserialize. + /// The formatter profile to use for deserialization. + /// The sequence of bytes to deserialize. + /// A token to monitor for cancellation requests. + /// The deserialized object of type . + /// Thrown when deserialization fails. + public static T? Deserialize(this Profile profile, in ReadOnlySequence pack, CancellationToken cancellationToken = default) + { + Requires.NotNull(profile, nameof(profile)); + MessagePackReader reader = new(pack); + return Deserialize(profile, ref reader, cancellationToken); + } + + internal static T? Deserialize(this Profile profile, ref MessagePackReader reader, CancellationToken cancellationToken = default) + { + PolyType.ITypeShapeProvider provider = profile.ShapeProviderResolver.ResolveShapeProvider(); + return profile.Serializer.Deserialize( + ref reader, + provider, + cancellationToken); + } + + internal static object? DeserializeObject(this Profile profile, in ReadOnlySequence pack, Type objectType, CancellationToken cancellationToken = default) + { + MessagePackReader reader = new(pack); + return DeserializeObject(profile, ref reader, objectType, cancellationToken); + } + + internal static object? DeserializeObject(this Profile profile, ref MessagePackReader reader, Type objectType, CancellationToken cancellationToken = default) + { + PolyType.Abstractions.ITypeShape shape = profile.ShapeProviderResolver.ResolveShape(objectType); + return profile.Serializer.DeserializeObject( + ref reader, + shape, + cancellationToken); + } + + internal static void SerializeObject(this Profile profile, ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + PolyType.Abstractions.ITypeShape shape = profile.ShapeProviderResolver.ResolveShape(objectType); + profile.Serializer.SerializeObject( + ref writer, + value, + shape, + cancellationToken); + } +} diff --git a/src/StreamJsonRpc/Protocol/CommonErrorData.cs b/src/StreamJsonRpc/Protocol/CommonErrorData.cs index 823f4d5a5..f9e792dac 100644 --- a/src/StreamJsonRpc/Protocol/CommonErrorData.cs +++ b/src/StreamJsonRpc/Protocol/CommonErrorData.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Runtime.Serialization; +using PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -10,11 +11,13 @@ namespace StreamJsonRpc.Protocol; /// A class that describes useful data that may be found in the JSON-RPC error message's error.data property. /// [DataContract] -public class CommonErrorData +[GenerateShape] +public partial class CommonErrorData { /// /// Initializes a new instance of the class. /// + [ConstructorShape] public CommonErrorData() { } @@ -39,6 +42,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 0, Name = "type")] [STJ.JsonPropertyName("type"), STJ.JsonPropertyOrder(0)] + [PropertyShape(Name = "type", Order = 0)] public string? TypeName { get; set; } /// @@ -46,6 +50,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 1, Name = "message")] [STJ.JsonPropertyName("message"), STJ.JsonPropertyOrder(1)] + [PropertyShape(Name = "message", Order = 1)] public string? Message { get; set; } /// @@ -53,6 +58,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 2, Name = "stack")] [STJ.JsonPropertyName("stack"), STJ.JsonPropertyOrder(2)] + [PropertyShape(Name = "stack", Order = 2)] public string? StackTrace { get; set; } /// @@ -60,6 +66,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 3, Name = "code")] [STJ.JsonPropertyName("code"), STJ.JsonPropertyOrder(3)] + [PropertyShape(Name = "code", Order = 3)] public int HResult { get; set; } /// @@ -67,5 +74,6 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 4, Name = "inner")] [STJ.JsonPropertyName("inner"), STJ.JsonPropertyOrder(4)] + [PropertyShape(Name = "inner", Order = 4)] public CommonErrorData? Inner { get; set; } } diff --git a/src/StreamJsonRpc/Protocol/JsonRpcError.cs b/src/StreamJsonRpc/Protocol/JsonRpcError.cs index 7905eb811..a088772f3 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcError.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcError.cs @@ -3,6 +3,8 @@ using System.Diagnostics; using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using PolyType; using StreamJsonRpc.Reflection; using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; @@ -13,14 +15,17 @@ namespace StreamJsonRpc.Protocol; /// Describes the error resulting from a that failed on the server. /// [DataContract] +[GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcErrorConverter))] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + "}")] -public class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId { /// /// Gets or sets the detail about the error. /// [DataMember(Name = "error", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("error"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "error", Order = 2)] public ErrorDetail? Error { get; set; } /// @@ -30,6 +35,7 @@ public class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -41,6 +47,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -66,7 +73,9 @@ public override string ToString() /// Describes the error. /// [DataContract] - public class ErrorDetail + [GenerateShape] + [MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcErrorDetailConverter))] + public partial class ErrorDetail { /// /// Gets or sets a number that indicates the error type that occurred. @@ -77,6 +86,7 @@ public class ErrorDetail /// [DataMember(Name = "code", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("code"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] + [PropertyShape(Name = "code", Order = 0)] public JsonRpcErrorCode Code { get; set; } /// @@ -87,6 +97,7 @@ public class ErrorDetail /// [DataMember(Name = "message", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("message"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "message", Order = 1)] public string? Message { get; set; } /// @@ -95,6 +106,7 @@ public class ErrorDetail [DataMember(Name = "data", Order = 2, IsRequired = false)] [Newtonsoft.Json.JsonProperty(DefaultValueHandling = Newtonsoft.Json.DefaultValueHandling.Ignore)] [STJ.JsonPropertyName("data"), STJ.JsonPropertyOrder(2)] + [PropertyShape(Name = "data", Order = 2)] public object? Data { get; set; } /// @@ -129,7 +141,7 @@ public class ErrorDetail /// /// The type that will be used as the generic type argument to . /// - /// Overridding methods in types that retain buffers used to deserialize should deserialize within this method and clear those buffers + /// Overriding methods in types that retain buffers used to deserialize should deserialize within this method and clear those buffers /// to prevent further access to these buffers which may otherwise happen concurrently with a call to /// that would recycle the same buffer being deserialized from. /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs index 84acc9373..be652d36d 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs @@ -3,6 +3,8 @@ using System.Diagnostics.CodeAnalysis; using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -14,7 +16,18 @@ namespace StreamJsonRpc.Protocol; [KnownType(typeof(JsonRpcRequest))] [KnownType(typeof(JsonRpcResult))] [KnownType(typeof(JsonRpcError))] -public abstract class JsonRpcMessage +#if NETSTANDARD2_0_OR_GREATER +[KnownSubType(typeof(JsonRpcRequest), 1)] +[KnownSubType(typeof(JsonRpcResult), 2)] +[KnownSubType(typeof(JsonRpcError), 3)] +#elif NET +[KnownSubType(1)] +[KnownSubType(2)] +[KnownSubType(3)] +#endif +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcMessageConverter))] +[GenerateShape] +public abstract partial class JsonRpcMessage { /// /// Gets or sets the version of the JSON-RPC protocol that this message conforms to. @@ -22,6 +35,7 @@ public abstract class JsonRpcMessage /// Defaults to "2.0". [DataMember(Name = "jsonrpc", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("jsonrpc"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] + [PropertyShape(Name = "jsonrpc", Order = 0)] public string Version { get; set; } = "2.0"; /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs index c41239ac6..37fe79216 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs @@ -4,6 +4,8 @@ using System.Diagnostics; using System.Reflection; using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using PolyType; using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; @@ -13,8 +15,10 @@ namespace StreamJsonRpc.Protocol; /// Describes a method to be invoked on the server. /// [DataContract] +[GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcRequestConverter))] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -public class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId { /// /// The result of an attempt to match request arguments with a candidate method's parameters. @@ -47,6 +51,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "method", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("method"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "method", Order = 2)] public string? Method { get; set; } /// @@ -61,6 +66,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "params", Order = 3, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("params"), STJ.JsonPropertyOrder(3), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "params", Order = 3)] public object? Arguments { get; set; } /// @@ -70,6 +76,7 @@ public enum ArgumentMatchResult [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -81,6 +88,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingDefault)] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -88,6 +96,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public bool IsResponseExpected => !this.RequestId.IsEmpty; /// @@ -95,6 +104,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public bool IsNotification => this.RequestId.IsEmpty; /// @@ -102,6 +112,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public virtual int ArgumentCount => this.NamedArguments?.Count ?? this.ArgumentsList?.Count ?? 0; /// @@ -109,6 +120,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArguments { get => this.Arguments as IReadOnlyDictionary; @@ -127,6 +139,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArgumentDeclaredTypes { get; set; } /// @@ -134,6 +147,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] [Obsolete("Use " + nameof(ArgumentsList) + " instead.")] public object?[]? ArgumentsArray { @@ -146,6 +160,7 @@ public object?[]? ArgumentsArray /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentsList { get => this.Arguments as IReadOnlyList; @@ -166,6 +181,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentListDeclaredTypes { get; set; } /// @@ -173,6 +189,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public virtual IEnumerable? ArgumentNames => this.NamedArguments?.Keys; /// @@ -180,6 +197,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "traceparent", EmitDefaultValue = false)] [STJ.JsonPropertyName("traceparent"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "traceparent")] public string? TraceParent { get; set; } /// @@ -187,6 +205,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "tracestate", EmitDefaultValue = false)] [STJ.JsonPropertyName("tracestate"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "tracestate")] public string? TraceState { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs index 6bd3157e6..d9860a99f 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs @@ -3,6 +3,8 @@ using System.Diagnostics; using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using PolyType; using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; @@ -12,14 +14,17 @@ namespace StreamJsonRpc.Protocol; /// Describes the result of a successful method invocation. /// [DataContract] +[GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcResultConverter))] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId { /// /// Gets or sets the value of the result of an invocation, if any. /// [DataMember(Name = "result", Order = 2, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("result"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "result", Order = 2)] public object? Result { get; set; } /// @@ -30,6 +35,7 @@ public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public Type? ResultDeclaredType { get; set; } /// @@ -39,6 +45,7 @@ public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -50,6 +57,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/TraceParent.cs b/src/StreamJsonRpc/Protocol/TraceParent.cs index b4fe393e2..de125b958 100644 --- a/src/StreamJsonRpc/Protocol/TraceParent.cs +++ b/src/StreamJsonRpc/Protocol/TraceParent.cs @@ -2,9 +2,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Diagnostics; +using Nerdbank.MessagePack; namespace StreamJsonRpc.Protocol; +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.TraceParentConverter))] internal unsafe struct TraceParent { internal const int VersionByteCount = 1; diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs index 927fc92cf..c7b69b0de 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks.Dataflow; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; +using PolyType; using StreamJsonRpc.Protocol; using STJ = System.Text.Json.Serialization; @@ -218,6 +219,20 @@ private void CleanUpResources(RequestId outboundRequestId) } } + [DataContract] + internal class EnumeratorResults + { + [DataMember(Name = ValuesPropertyName, Order = 0)] + [STJ.JsonPropertyName(ValuesPropertyName), STJ.JsonPropertyOrder(0)] + [PropertyShape(Name = ValuesPropertyName, Order = 0)] + public IReadOnlyList? Values { get; set; } + + [DataMember(Name = FinishedPropertyName, Order = 1)] + [STJ.JsonPropertyName(FinishedPropertyName), STJ.JsonPropertyOrder(1)] + [PropertyShape(Name = FinishedPropertyName, Order = 1)] + public bool Finished { get; set; } + } + private class GeneratingEnumeratorTracker : IGeneratingEnumeratorTracker { private readonly IAsyncEnumerator enumerator; @@ -525,16 +540,4 @@ private static void Write(IBufferWriter writer, IReadOnlyList values) } } } - - [DataContract] - private class EnumeratorResults - { - [DataMember(Name = ValuesPropertyName, Order = 0)] - [STJ.JsonPropertyName(ValuesPropertyName), STJ.JsonPropertyOrder(0)] - public IReadOnlyList? Values { get; set; } - - [DataMember(Name = FinishedPropertyName, Order = 1)] - [STJ.JsonPropertyName(FinishedPropertyName), STJ.JsonPropertyOrder(1)] - public bool Finished { get; set; } - } } diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs index d91b1ea38..712c7c09b 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs @@ -9,6 +9,7 @@ using System.Reflection; using System.Runtime.Serialization; using Microsoft.VisualStudio.Threading; +using PolyType; using static System.FormattableString; using STJ = System.Text.Json.Serialization; @@ -17,7 +18,7 @@ namespace StreamJsonRpc.Reflection; /// /// Tracks objects that get marshaled using the general marshaling protocol. /// -internal class MessageFormatterRpcMarshaledContextTracker +internal partial class MessageFormatterRpcMarshaledContextTracker { private static readonly IReadOnlyCollection<(Type ImplicitlyMarshaledType, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions, RpcMarshalableAttribute Attribute)> ImplicitlyMarshaledTypes = new (Type, JsonRpcProxyOptions, JsonRpcTargetOptions, RpcMarshalableAttribute)[] { @@ -396,7 +397,7 @@ internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions opti private static void ValidateMarshalableInterface(Type type, RpcMarshalableAttribute attribute) { // We only require marshalable interfaces to derive from IDisposable when they are not call-scoped. - if (!attribute.CallScopedLifetime && !typeof(IDisposable).IsAssignableFrom(type)) + if (attribute.CallScopedLifetime && !typeof(IDisposable).IsAssignableFrom(type)) { throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.MarshalableInterfaceNotDisposable, type.FullName)); } @@ -450,10 +451,15 @@ private void CleanUpOutboundResources(RequestId requestId, bool successful) } } + /// + /// A token that represents a marshaled object. + /// [DataContract] - internal struct MarshalToken + [GenerateShape] + internal partial struct MarshalToken { [MessagePack.SerializationConstructor] + [ConstructorShape] #pragma warning disable SA1313 // Parameter names should begin with lower-case letter public MarshalToken(int __jsonrpc_marshaled, long handle, string? lifetime, int[]? optionalInterfaces) #pragma warning restore SA1313 // Parameter names should begin with lower-case letter @@ -466,18 +472,22 @@ public MarshalToken(int __jsonrpc_marshaled, long handle, string? lifetime, int[ [DataMember(Name = "__jsonrpc_marshaled", IsRequired = true)] [STJ.JsonPropertyName("__jsonrpc_marshaled"), STJ.JsonRequired] + [PropertyShape(Name = "__jsonrpc_marshaled")] public int Marshaled { get; set; } [DataMember(Name = "handle", IsRequired = true)] [STJ.JsonPropertyName("handle"), STJ.JsonRequired] + [PropertyShape(Name = "handle")] public long Handle { get; set; } [DataMember(Name = "lifetime", EmitDefaultValue = false)] [STJ.JsonPropertyName("lifetime"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "lifetime")] public string? Lifetime { get; set; } [DataMember(Name = "optionalInterfaces", EmitDefaultValue = false)] [STJ.JsonPropertyName("optionalInterfaces"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "optionalInterfaces")] public int[]? OptionalInterfacesCodes { get; set; } } diff --git a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs index 07e5f29d6..e2158249e 100644 --- a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs +++ b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs @@ -7,7 +7,7 @@ namespace StreamJsonRpc; /// Designates an interface that is used in an RPC contract to marshal the object so the receiver can invoke remote methods on it instead of serializing the object to send its data to the remote end. /// /// -/// Learn more about marshable interfaces. +/// Learn more about marshalable interfaces. /// [AttributeUsage(AttributeTargets.Interface, AllowMultiple = false, Inherited = false)] public class RpcMarshalableAttribute : Attribute diff --git a/src/StreamJsonRpc/SerializationContextExtensions.cs b/src/StreamJsonRpc/SerializationContextExtensions.cs new file mode 100644 index 000000000..b4f5ab84a --- /dev/null +++ b/src/StreamJsonRpc/SerializationContextExtensions.cs @@ -0,0 +1,16 @@ +using Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +internal static class SerializationContextExtensions +{ + private static readonly object FormatterStateKey = new(); + + internal static object FormatterKey => FormatterStateKey; + + internal static NerdbankMessagePackFormatter GetFormatter(this ref SerializationContext context) + { + return (context[FormatterKey] as NerdbankMessagePackFormatter) + ?? throw new InvalidOperationException(); + } +} diff --git a/src/StreamJsonRpc/StreamJsonRpc.csproj b/src/StreamJsonRpc/StreamJsonRpc.csproj index 70fa95f12..b51740ebc 100644 --- a/src/StreamJsonRpc/StreamJsonRpc.csproj +++ b/src/StreamJsonRpc/StreamJsonRpc.csproj @@ -1,6 +1,6 @@  - netstandard2.0;netstandard2.1;net6.0;net8.0 + netstandard2.0;netstandard2.1;net8.0 prompt 4 true @@ -17,6 +17,7 @@ + diff --git a/test/Benchmarks/Benchmarks.csproj b/test/Benchmarks/Benchmarks.csproj index b3a91e41d..bbcc37f3f 100644 --- a/test/Benchmarks/Benchmarks.csproj +++ b/test/Benchmarks/Benchmarks.csproj @@ -2,7 +2,7 @@ Exe - net6.0;net472 + net8.0;net472 diff --git a/test/Benchmarks/InvokeBenchmarks.cs b/test/Benchmarks/InvokeBenchmarks.cs index 78fc982cf..7a7a409ac 100644 --- a/test/Benchmarks/InvokeBenchmarks.cs +++ b/test/Benchmarks/InvokeBenchmarks.cs @@ -15,7 +15,7 @@ public class InvokeBenchmarks private JsonRpc clientRpc = null!; private JsonRpc serverRpc = null!; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [GlobalSetup] @@ -35,6 +35,7 @@ IJsonRpcMessageHandler CreateHandler(IDuplexPipe pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, new NerdbankMessagePackFormatter()), _ => throw Assumes.NotReachable(), }; } diff --git a/test/Benchmarks/NotifyBenchmarks.cs b/test/Benchmarks/NotifyBenchmarks.cs index 92fe6ef47..5924e52f4 100644 --- a/test/Benchmarks/NotifyBenchmarks.cs +++ b/test/Benchmarks/NotifyBenchmarks.cs @@ -12,7 +12,7 @@ public class NotifyBenchmarks { private JsonRpc clientRpc = null!; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [GlobalSetup] @@ -26,6 +26,7 @@ IJsonRpcMessageHandler CreateHandler(Stream pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, pipe, new NerdbankMessagePackFormatter()), _ => throw Assumes.NotReachable(), }; } diff --git a/test/Benchmarks/ShortLivedConnectionBenchmarks.cs b/test/Benchmarks/ShortLivedConnectionBenchmarks.cs index fe58eab77..0a59cee53 100644 --- a/test/Benchmarks/ShortLivedConnectionBenchmarks.cs +++ b/test/Benchmarks/ShortLivedConnectionBenchmarks.cs @@ -14,7 +14,7 @@ public class ShortLivedConnectionBenchmarks { private const int Iterations = 1000; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [Benchmark(OperationsPerInvoke = Iterations)] @@ -39,6 +39,7 @@ IJsonRpcMessageHandler CreateHandler(IDuplexPipe pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, new NerdbankMessagePackFormatter()), _ => throw Assumes.NotReachable(), }; } diff --git a/test/Benchmarks/run.cmd b/test/Benchmarks/run.cmd index 81d737eb3..963044933 100644 --- a/test/Benchmarks/run.cmd +++ b/test/Benchmarks/run.cmd @@ -1,3 +1,3 @@ @pushd "%~dp0\" -dotnet run -f net6.0 -c release -- --runtimes net472 net6.0 %* +dotnet run -f net8.0 -c release -- --runtimes net472 net8.0 %* @popd diff --git a/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs b/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs index ebf80d0d1..441ab0215 100644 --- a/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs +++ b/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs @@ -63,6 +63,27 @@ public void MessagePackDoesNotLoadNewtonsoftJsonUnnecessarily() } } + [Fact] + public void NerdbankMessagePackDoesNotLoadNewtonsoftJsonUnnecessarily() + { + AppDomain testDomain = CreateTestAppDomain(); + try + { + var driver = (AppDomainTestDriver)testDomain.CreateInstanceAndUnwrap(typeof(AppDomainTestDriver).Assembly.FullName, typeof(AppDomainTestDriver).FullName); + + this.PrintLoadedAssemblies(driver); + + driver.CreateNerdbankMessagePackConnection(); + + this.PrintLoadedAssemblies(driver); + driver.ThrowIfAssembliesLoaded("Newtonsoft.Json"); + } + finally + { + AppDomain.Unload(testDomain); + } + } + [Fact] public void MockFormatterDoesNotLoadJsonOrMessagePackUnnecessarily() { @@ -142,6 +163,11 @@ internal void CreateMessagePackConnection() var jsonRpc = new JsonRpc(new LengthHeaderMessageHandler(FullDuplexStream.CreatePipePair().Item1, new MessagePackFormatter())); } + internal void CreateNerdbankMessagePackConnection() + { + var jsonRpc = new JsonRpc(new LengthHeaderMessageHandler(FullDuplexStream.CreatePipePair().Item1, new NerdbankMessagePackFormatter())); + } + #pragma warning restore CA1822 // Mark members as static private class MockFormatter : IJsonRpcMessageFormatter diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs new file mode 100644 index 000000000..d3dd9ed7b --- /dev/null +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using PolyType; +using static AsyncEnumerableTests; + +public class AsyncEnumerableNerdbankMessagePackTests : AsyncEnumerableTests +{ + public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + NerdbankMessagePackFormatter serverFormatter = new(); + serverFormatter.SetFormatterProfile(ConfigureContext); + + NerdbankMessagePackFormatter clientFormatter = new(); + clientFormatter.SetFormatterProfile(ConfigureContext); + + this.serverMessageFormatter = serverFormatter; + this.clientMessageFormatter = clientFormatter; + + static void ConfigureContext(NerdbankMessagePackFormatter.Profile.Builder profileBuilder) + { + profileBuilder.RegisterAsyncEnumerableConverter(); + profileBuilder.RegisterAsyncEnumerableConverter(); + profileBuilder.AddTypeShapeProvider(AsyncEnumerableWitness.ShapeProvider); + profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + } + } +} + +[GenerateShape>] +[GenerateShape>] +[GenerateShape>] +[GenerateShape] +#pragma warning disable SA1402 // File may only contain a single type +public partial class AsyncEnumerableWitness; +#pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs index 7b1aed635..f993f66b4 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs @@ -11,6 +11,7 @@ using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; using Newtonsoft.Json; +using NBMP = Nerdbank.MessagePack; public abstract class AsyncEnumerableTests : TestBase, IAsyncLifetime { @@ -276,7 +277,11 @@ async IAsyncEnumerable Generator(CancellationToken cancellationToken) } else { - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.PassInNumbersAsync), new object[] { Generator(this.TimeoutToken) }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.PassInNumbersAsync), + [Generator(this.TimeoutToken)], + [typeof(IAsyncEnumerable)], + this.TimeoutToken); } } @@ -371,7 +376,7 @@ public async Task Cancellation_DuringLongRunningServerMoveNext(bool useProxy) await Assert.ThrowsAnyAsync(async () => await moveNextTask).WithCancellation(this.TimeoutToken); } - [Theory] + [Theory(Timeout = 2 * 1000)] // TODO: Temporary for development [PairwiseData] public async Task Cancellation_DuringLongRunningServerBeforeReturning(bool useProxy, [CombinatorialValues(0, 1, 2, 3)] int prefetchStrategy) { @@ -447,8 +452,14 @@ public async Task NotifyAsync_ThrowsIfAsyncEnumerableSent() // But for a notification there's no guarantee the server handles the message and no way to get an error back, // so it simply should not be allowed since the risk of memory leak is too high. var numbers = new int[] { 1, 2, 3 }.AsAsyncEnumerable(); - await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.PassInNumbersAsync), new object?[] { numbers })); - await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.PassInNumbersAsync), new object?[] { new { e = numbers } })); + await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync( + nameof(Server.PassInNumbersAsync), + [numbers], + [typeof(IAsyncEnumerable)])); + await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync( + nameof(Server.PassInNumbersAsync), + new object?[] { new { e = numbers } }, + [typeof(IAsyncEnumerable)])); } [SkippableFact] @@ -618,6 +629,16 @@ private async Task ReturnEnumerable_AutomaticallyReleasedOnErrorF return weakReferenceToSource; } + [DataContract] + protected internal class CompoundEnumerableResult + { + [DataMember] + public string? Message { get; set; } + + [DataMember] + public IAsyncEnumerable? Enumeration { get; set; } + } + protected class Server : IServer { /// @@ -795,18 +816,9 @@ protected class Client : IClient public Task DoSomethingAsync(CancellationToken cancellationToken) => Task.CompletedTask; } - [DataContract] - protected class CompoundEnumerableResult - { - [DataMember] - public string? Message { get; set; } - - [DataMember] - public IAsyncEnumerable? Enumeration { get; set; } - } - [JsonConverter(typeof(ThrowingJsonConverter))] [MessagePackFormatter(typeof(ThrowingMessagePackFormatter))] + [NBMP.MessagePackConverter(typeof(ThrowingMessagePackNerdbankConverter))] protected class UnserializableType { } @@ -836,4 +848,17 @@ public void Serialize(ref MessagePackWriter writer, T value, MessagePackSerializ throw new Exception(); } } + + protected class ThrowingMessagePackNerdbankConverter : NBMP.MessagePackConverter + { + public override T? Read(ref NBMP.MessagePackReader reader, NBMP.SerializationContext context) + { + throw new Exception(); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, NBMP.SerializationContext context) + { + throw new Exception(); + } + } } diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs new file mode 100644 index 000000000..de8990007 --- /dev/null +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics; +using System.IO.Pipelines; +using Nerdbank.MessagePack; +using PolyType; + +public class DisposableProxyNerdbankMessagePackTests : DisposableProxyTests +{ + public DisposableProxyNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override IJsonRpcMessageFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new(); + formatter.SetFormatterProfile(b => + { + KnownSubTypeMapping disposableMapping = new(); + disposableMapping.Add(alias: 1, DisposableProxyWitness.ShapeProvider); + + b.RegisterKnownSubTypes(disposableMapping); + b.AddTypeShapeProvider(DisposableProxyWitness.ShapeProvider); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + }); + + return formatter; + } +} + +[GenerateShape] +[GenerateShape] +[GenerateShape] +[GenerateShape] +#pragma warning disable SA1402 // File may only contain a single type +public partial class DisposableProxyWitness; +#pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/DisposableProxyTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyTests.cs index 8be43c55b..57be6d20f 100644 --- a/test/StreamJsonRpc.Tests/DisposableProxyTests.cs +++ b/test/StreamJsonRpc.Tests/DisposableProxyTests.cs @@ -77,7 +77,7 @@ public async Task IDisposableInNotificationArgumentIsRejected() Assert.True(IsExceptionOrInnerOfType(ex)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task DisposableReturnValue_DisposeSwallowsSecondCall() { IDisposable? proxyDisposable = await this.client.GetDisposableAsync(); @@ -86,7 +86,7 @@ public async Task DisposableReturnValue_DisposeSwallowsSecondCall() proxyDisposable.Dispose(); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task DisposableReturnValue_IsMarshaledAndLaterCollected() { var weakRefs = await this.DisposableReturnValue_Helper(); @@ -110,7 +110,7 @@ public async Task DisposableWithinArg_IsMarshaledAndLaterCollected() await this.AssertWeakReferenceGetsCollectedAsync(weakRefs.Target); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task DisposableReturnValue_Null() { IDisposable? proxyDisposable = await this.client.GetDisposableAsync(returnNull: true); diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs new file mode 100644 index 000000000..8b388d722 --- /dev/null +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.IO.Pipelines; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType; + +public class DuplexPipeMarshalingNerdbankMessagePackTests : DuplexPipeMarshalingTests +{ + public DuplexPipeMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + NerdbankMessagePackFormatter serverFormatter = new() + { + MultiplexingStream = this.serverMx, + }; + + NerdbankMessagePackFormatter clientFormatter = new() + { + MultiplexingStream = this.clientMx, + }; + + serverFormatter.SetFormatterProfile(Configure); + clientFormatter.SetFormatterProfile(Configure); + + this.serverMessageFormatter = serverFormatter; + this.clientMessageFormatter = clientFormatter; + + static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) + { + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + } + } +} diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs index 7cd6fcf7e..0c367c83e 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs @@ -8,9 +8,10 @@ using System.Text; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; +using PolyType; using STJ = System.Text.Json.Serialization; -public abstract class DuplexPipeMarshalingTests : TestBase, IAsyncLifetime +public abstract partial class DuplexPipeMarshalingTests : TestBase, IAsyncLifetime { protected readonly Server server = new Server(); protected JsonRpc serverRpc; @@ -116,7 +117,7 @@ public async Task ClientCanSendReadOnlyPipeToServer(bool orderedArguments) { bytesReceived = await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptReadablePipe), - new object[] { ExpectedFileName, pipes.Item2 }, + [ExpectedFileName, pipes.Item2], this.TimeoutToken); } else @@ -143,7 +144,7 @@ public async Task ClientCanSendWriteOnlyPipeToServer(bool orderedArguments) { await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptWritablePipe), - new object[] { pipes.Item2, bytesToReceive }, + [pipes.Item2, bytesToReceive], this.TimeoutToken); } else @@ -182,7 +183,7 @@ public async Task ClientCanSendPipeReaderToServer() int bytesReceived = await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptPipeReader), - new object[] { ExpectedFileName, pipe.Reader }, + [ExpectedFileName, pipe.Reader], this.TimeoutToken); Assert.Equal(MemoryBuffer.Length, bytesReceived); @@ -196,7 +197,7 @@ public async Task ClientCanSendPipeWriterToServer() int bytesToReceive = MemoryBuffer.Length - 1; await this.clientRpc.InvokeWithCancellationAsync( nameof(Server.AcceptPipeWriter), - new object[] { pipe.Writer, bytesToReceive }, + [pipe.Writer, bytesToReceive], this.TimeoutToken); // Read all that the server wanted us to know, and verify it. @@ -332,13 +333,17 @@ public async Task ServerMethodThatReturnsCustomTypeWithInnerStream() result.InnerStream.Dispose(); } - [Fact] + [SkippableFact] public async Task PassStreamWithArgsAsSingleObject() { + Skip.If(this.GetType() == typeof(DuplexPipeMarshalingNerdbankMessagePackTests), "Dynamic types are not supported with NerdBankMessagePack."); MemoryStream ms = new(); ms.Write(new byte[] { 1, 2, 3 }, 0, 3); ms.Position = 0; - int bytesRead = await this.clientRpc.InvokeWithParameterObjectAsync(nameof(Server.AcceptStreamArgInFirstParam), new { innerStream = ms }, this.TimeoutToken); + int bytesRead = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(Server.AcceptStreamArgInFirstParam), + new { innerStream = ms }, + this.TimeoutToken); Assert.Equal(ms.Length, bytesRead); } @@ -453,7 +458,10 @@ public async Task ClientCanSendTwoWayPipeToServer(bool serverUsesStream) { (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); Task twoWayCom = TwoWayTalkAsync(pipePair.Item1, writeOnOdd: true, this.TimeoutToken); - await this.clientRpc.InvokeWithCancellationAsync(serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), new object[] { false, pipePair.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), + [false, pipePair.Item2], + this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. // Confirm that we can see the server is no longer writing. @@ -465,13 +473,33 @@ public async Task ClientCanSendTwoWayPipeToServer(bool serverUsesStream) pipePair.Item1.Input.Complete(); } - [Theory] + [SkippableTheory] [CombinatorialData] public async Task ClientCanSendTwoWayStreamToServer(bool serverUsesStream) + { + Skip.If(this.GetType() == typeof(DuplexPipeMarshalingNerdbankMessagePackTests), "This test is not supported with NerdBankMessagePack."); + (Stream, Stream) streamPair = FullDuplexStream.CreatePair(); + Task twoWayCom = TwoWayTalkAsync(streamPair.Item1, writeOnOdd: true, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), + [false, streamPair.Item2], + this.TimeoutToken); + await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. + + streamPair.Item1.Dispose(); + } + + [Theory] + [CombinatorialData] + public async Task ClientCanSendTwoWayStreamToServer_WithExplicitTypes(bool serverUsesStream) { (Stream, Stream) streamPair = FullDuplexStream.CreatePair(); Task twoWayCom = TwoWayTalkAsync(streamPair.Item1, writeOnOdd: true, this.TimeoutToken); - await this.clientRpc.InvokeWithCancellationAsync(serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), new object[] { false, streamPair.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), + [false, streamPair.Item2], + [typeof(bool), typeof(Stream)], + this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. streamPair.Item1.Dispose(); @@ -481,7 +509,10 @@ public async Task ClientCanSendTwoWayStreamToServer(bool serverUsesStream) public async Task PipeRemainsOpenAfterSuccessfulServerResult() { (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.AcceptPipeAndChatLater), new object[] { false, pipePair.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.AcceptPipeAndChatLater), + [false, pipePair.Item2], + this.TimeoutToken); await WhenAllSucceedOrAnyFault(TwoWayTalkAsync(pipePair.Item1, writeOnOdd: true, this.TimeoutToken), this.server.ChatLaterTask!); pipePair.Item1.Output.Complete(); @@ -495,7 +526,10 @@ public async Task PipeRemainsOpenAfterSuccessfulServerResult() public async Task ClientClosesChannelsWhenServerErrorsOut() { (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); - await Assert.ThrowsAsync(() => this.clientRpc.InvokeWithCancellationAsync(nameof(Server.RejectCall), new object[] { pipePair.Item2 }, this.TimeoutToken)); + await Assert.ThrowsAsync(() => this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.RejectCall), + [pipePair.Item2], + this.TimeoutToken)); // Verify that the pipe is closed. ReadResult readResult = await pipePair.Item1.Input.ReadAsync(this.TimeoutToken); @@ -506,7 +540,10 @@ public async Task ClientClosesChannelsWhenServerErrorsOut() public async Task PipesCloseWhenConnectionCloses() { (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.AcceptPipeAndChatLater), new object[] { false, pipePair.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.AcceptPipeAndChatLater), + [false, pipePair.Item2], + this.TimeoutToken); this.clientRpc.Dispose(); @@ -526,7 +563,10 @@ public async Task ClientSendsMultiplePipes() (IDuplexPipe, IDuplexPipe) pipePair1 = FullDuplexStream.CreatePipePair(); (IDuplexPipe, IDuplexPipe) pipePair2 = FullDuplexStream.CreatePipePair(); - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.TwoPipes), new object[] { pipePair1.Item2, pipePair2.Item2 }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(Server.TwoPipes), + [pipePair1.Item2, pipePair2.Item2], + this.TimeoutToken); pipePair1.Item1.Output.Complete(); pipePair2.Item1.Output.Complete(); @@ -603,7 +643,10 @@ public async Task ClientSendsPipeWhereServerHasMultipleOverloads() (IDuplexPipe, IDuplexPipe) pipePair = FullDuplexStream.CreatePipePair(); Task twoWayCom = TwoWayTalkAsync(pipePair.Item1, writeOnOdd: true, this.TimeoutToken); - await this.clientRpc.InvokeWithCancellationAsync(nameof(ServerWithOverloads.OverloadedMethod), new object[] { false, pipePair.Item2, "hi" }, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + nameof(ServerWithOverloads.OverloadedMethod), + [false, pipePair.Item2, "hi"], + this.TimeoutToken); await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. pipePair.Item1.Output.Complete(); @@ -761,6 +804,25 @@ private async Task ServerStreamIsDisposedWhenClientDisconnects(string rpcMethodN await serverStreamDisposal.WaitAsync(this.TimeoutToken); } + [DataContract] + [GenerateShape] + public partial class StreamContainingClass + { + [DataMember] + [PropertyShape(Ignore = false)] + private Stream innerStream; + + [ConstructorShape] + public StreamContainingClass(Stream innerStream) + { + this.innerStream = innerStream; + } + + [STJ.JsonPropertyName("innerStream")] + [PropertyShape(Name = "innerStream")] + public Stream InnerStream => this.innerStream; + } + #pragma warning disable CA1801 // Review unused parameters protected class ServerWithOverloads { @@ -779,7 +841,9 @@ public async Task OverloadedMethod(bool writeOnOdd, IDuplexPipe pipe, string mes public void OverloadedMethod(bool foo, int value, string[] values) => Assert.NotNull(values); } - protected class Server +#pragma warning disable SA1202 // Elements should be ordered by access + public class Server +#pragma warning restore SA1202 // Elements should be ordered by access { internal Task? ChatLaterTask { get; private set; } @@ -987,7 +1051,9 @@ protected class ServerWithIDuplexPipeReturningMethod public IDuplexPipe? MethodThatReturnsIDuplexPipe() => null; } - protected class OneWayWrapperStream : Stream +#pragma warning disable SA1202 // Elements should be ordered by access + public class OneWayWrapperStream : Stream +#pragma warning restore SA1202 // Elements should be ordered by access { private readonly Stream innerStream; private readonly bool canRead; @@ -1014,8 +1080,10 @@ internal OneWayWrapperStream(Stream innerStream, bool canRead = false, bool canW public override bool CanWrite => this.canWrite && this.innerStream.CanWrite; + [PropertyShape(Ignore = true)] public override long Length => throw new NotSupportedException(); + [PropertyShape(Ignore = true)] public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } public override void Flush() @@ -1092,19 +1160,4 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } } - - [DataContract] - protected class StreamContainingClass - { - [DataMember] - private Stream innerStream; - - public StreamContainingClass(Stream innerStream) - { - this.innerStream = innerStream; - } - - [STJ.JsonPropertyName("innerStream")] - public Stream InnerStream => this.innerStream; - } } diff --git a/test/StreamJsonRpc.Tests/FormatterTestBase.cs b/test/StreamJsonRpc.Tests/FormatterTestBase.cs index 92803a257..9da28e887 100644 --- a/test/StreamJsonRpc.Tests/FormatterTestBase.cs +++ b/test/StreamJsonRpc.Tests/FormatterTestBase.cs @@ -1,10 +1,5 @@ using System.Runtime.Serialization; -using MessagePack.Formatters; using Nerdbank.Streams; -using StreamJsonRpc; -using StreamJsonRpc.Protocol; -using Xunit; -using Xunit.Abstractions; public abstract class FormatterTestBase : TestBase where TFormatter : IJsonRpcMessageFormatter diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs new file mode 100644 index 000000000..0f4192349 --- /dev/null +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -0,0 +1,560 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.CompilerServices; +using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; +using PolyType; +using PolyType.SourceGenerator; + +public partial class JsonRpcNerdbankMessagePackLengthTests : JsonRpcTests +{ + public JsonRpcNerdbankMessagePackLengthTests(ITestOutputHelper logger) + : base(logger) + { + } + + internal interface IMessagePackServer + { + Task ReturnUnionTypeAsync(CancellationToken cancellationToken); + + Task AcceptUnionTypeAndReturnStringAsync(UnionBaseClass value, CancellationToken cancellationToken); + + Task AcceptUnionTypeAsync(UnionBaseClass value, CancellationToken cancellationToken); + + Task ProgressUnionType(IProgress progress, CancellationToken cancellationToken); + + IAsyncEnumerable GetAsyncEnumerableOfUnionType(CancellationToken cancellationToken); + + Task IsExtensionArgNonNull(CustomExtensionType extensionValue); + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + [Fact] + public override async Task CanPassAndCallPrivateMethodsObjects() + { + var result = await this.clientRpc.InvokeAsync(nameof(Server.MethodThatAcceptsFoo), new Foo { Bar = "bar", Bazz = 1000 }); + Assert.NotNull(result); + Assert.Equal("bar!", result.Bar); + Assert.Equal(1001, result.Bazz); + } + + [Fact] + public async Task ExceptionControllingErrorData() + { + var exception = await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.ThrowLocalRpcException))).WithCancellation(this.TimeoutToken); + + IDictionary? data = (IDictionary?)exception.ErrorData; + Assert.NotNull(data); + object myCustomData = data["myCustomData"]; + string actual = (string)myCustomData; + Assert.Equal("hi", actual); + } + + [Fact] + public override async Task CanPassExceptionFromServer_ErrorData() + { + RemoteInvocationException exception = await Assert.ThrowsAnyAsync(() => this.clientRpc.InvokeAsync(nameof(Server.MethodThatThrowsUnauthorizedAccessException))); + Assert.Equal((int)JsonRpcErrorCode.InvocationError, exception.ErrorCode); + + var errorData = Assert.IsType(exception.ErrorData); + Assert.NotNull(errorData.StackTrace); + Assert.StrictEqual(COR_E_UNAUTHORIZEDACCESS, errorData.HResult); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + UnionBaseClass result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.ReturnUnionTypeAsync), null, this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue_NonAsync() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + UnionBaseClass result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.ReturnUnionType), null, this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_PositionalParameter_NoReturnValue(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_AndReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + string? result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }, this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_NamedParameter_NoReturnValue_UntypedDictionary(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var argument = new Dictionary { { "value", new UnionDerivedClass() } }; + var argumentDeclaredTypes = new Dictionary { { "value", typeof(UnionBaseClass) } }; + + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + + // Exercise the non-init path by repeating + server.ReceivedValueSource = new TaskCompletionSource(); + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue_UntypedDictionary() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + string? result = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), + new Dictionary { { "value", new UnionDerivedClass() } }, + new Dictionary { { "value", typeof(UnionBaseClass) } }, + this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + Assert.IsType(server.ReceivedValue); + + // Exercise the non-init path by repeating + result = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), + new Dictionary { { "value", new UnionDerivedClass() } }, + new Dictionary { { "value", typeof(UnionBaseClass) } }, + this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + Assert.IsType(server.ReceivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_NamedParameter_NoReturnValue(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var namedArgs = new { value = (UnionBaseClass)new UnionDerivedClass() }; + + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + + // Exercise the non-init path by repeating + server.ReceivedValueSource = new TaskCompletionSource(); + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + string? result = await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), new { value = (UnionBaseClass)new UnionDerivedClass() }, this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(); + UnionBaseClass result = await clientProxy.ReturnUnionTypeAsync(this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_AndReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(); + string? result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = true }); + string? result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_NoReturnValue_Proxy() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + server.ReceivedValueSource = new TaskCompletionSource(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_NoReturnValue_Proxy() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = true }); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + server.ReceivedValueSource = new TaskCompletionSource(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + } + + [Fact] + public async Task UnionType_AsIProgressTypeArgument() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + + var reportSource = new TaskCompletionSource(); + var progress = new Progress(v => reportSource.SetResult(v)); + await clientProxy.ProgressUnionType(progress, this.TimeoutToken); + Assert.IsType(await reportSource.Task.WithCancellation(this.TimeoutToken)); + } + + [Fact] + public async Task UnionType_AsAsyncEnumerableTypeArgument() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + + UnionBaseClass? actualItem = null; + await foreach (UnionBaseClass item in clientProxy.GetAsyncEnumerableOfUnionType(this.TimeoutToken)) + { + actualItem = item; + } + + Assert.IsType(actualItem); + } + + /// + /// Verifies that an argument that cannot be deserialized by the msgpack primitive formatter will not cause a failure. + /// + /// + /// This is a regression test for a bug where + /// verbose ETW tracing would fail to deserialize arguments with the primitive formatter that deserialize just fine for the actual method dispatch. + /// + [SkippableTheory, PairwiseData] + public async Task VerboseLoggingDoesNotFailWhenArgsDoNotDeserializePrimitively(bool namedArguments) + { + Skip.IfNot(SharedUtilities.GetEventSourceTestMode() == SharedUtilities.EventSourceTestMode.EmulateProduction, $"This test specifically verifies behavior when the EventSource should swallow exceptions. Current mode: {SharedUtilities.GetEventSourceTestMode()}."); + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = namedArguments }); + + Assert.True(await clientProxy.IsExtensionArgNonNull(new CustomExtensionType())); + } + + protected override void InitializeFormattersAndHandlers( + Stream serverStream, + Stream clientStream, + out IJsonRpcMessageFormatter serverMessageFormatter, + out IJsonRpcMessageFormatter clientMessageFormatter, + out IJsonRpcMessageHandler serverMessageHandler, + out IJsonRpcMessageHandler clientMessageHandler, + bool controlledFlushingClient) + { + NerdbankMessagePackFormatter serverFormatter = new(); + serverFormatter.SetFormatterProfile(Configure); + + NerdbankMessagePackFormatter clientFormatter = new(); + clientFormatter.SetFormatterProfile(Configure); + + serverMessageFormatter = serverFormatter; + clientMessageFormatter = clientFormatter; + + serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter); + clientMessageHandler = controlledFlushingClient + ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) + : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); + + static void Configure(NerdbankMessagePackFormatter.Profile.Builder b) + { + b.RegisterConverter(new UnserializableTypeConverter()); + b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); + b.RegisterAsyncEnumerableConverter(); + b.RegisterAsyncEnumerableConverter(); + b.RegisterProgressConverter(); + b.RegisterProgressConverter(); + b.RegisterProgressConverter(); + b.RegisterProgressConverter, int>(); + b.RegisterProgressConverter, CustomSerializedType>(); + b.AddTypeShapeProvider(JsonRpcWitness.ShapeProvider); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + } + } + + protected override object[] CreateFormatterIntrinsicParamsObject(string arg) => []; + + [GenerateShape] + [GenerateShape>] + public partial class JsonRpcWitness; + + [GenerateShape] +#if NET + [KnownSubType] +#else + [KnownSubType(typeof(UnionDerivedClass))] +#endif + public abstract partial class UnionBaseClass + { + } + + [GenerateShape] + public partial class UnionDerivedClass : UnionBaseClass + { + } + + [GenerateShape] + [MessagePackConverter(typeof(CustomExtensionConverter))] + internal partial class CustomExtensionType + { + } + + private class CustomExtensionConverter : MessagePackConverter + { + public override CustomExtensionType? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + if (reader.ReadExtensionHeader() is { TypeCode: 1, Length: 0 }) + { + return new(); + } + + throw new Exception("Unexpected extension header."); + } + + public override void Write(ref MessagePackWriter writer, in CustomExtensionType? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + writer.Write(new Extension(1, default(Memory))); + } + } + } + + private class UnserializableTypeConverter : MessagePackConverter + { + public override CustomSerializedType Read(ref MessagePackReader reader, SerializationContext context) + { + return new CustomSerializedType { Value = reader.ReadString() }; + } + + public override void Write(ref MessagePackWriter writer, in CustomSerializedType? value, SerializationContext context) + { + writer.Write(value?.Value); + } + } + + private class TypeThrowsWhenDeserializedConverter : MessagePackConverter + { + public override TypeThrowsWhenDeserialized Read(ref MessagePackReader reader, SerializationContext context) + { + throw CreateExceptionToBeThrownByDeserializer(); + } + + public override void Write(ref MessagePackWriter writer, in TypeThrowsWhenDeserialized? value, SerializationContext context) + { + writer.WriteArrayHeader(0); + } + } + + private class MessagePackServer : IMessagePackServer + { + internal UnionBaseClass? ReceivedValue { get; private set; } + + internal TaskCompletionSource ReceivedValueSource { get; set; } = new TaskCompletionSource(); + + public Task ReturnUnionTypeAsync(CancellationToken cancellationToken) => Task.FromResult(new UnionDerivedClass()); + + public Task AcceptUnionTypeAndReturnStringAsync(UnionBaseClass value, CancellationToken cancellationToken) => Task.FromResult((this.ReceivedValue = value)?.GetType().Name); + + public Task AcceptUnionTypeAsync(UnionBaseClass value, CancellationToken cancellationToken) + { + this.ReceivedValue = value; + this.ReceivedValueSource.SetResult(value); + return Task.CompletedTask; + } + + public UnionBaseClass ReturnUnionType() => new UnionDerivedClass(); + + public Task ProgressUnionType(IProgress progress, CancellationToken cancellationToken) + { + progress.Report(new UnionDerivedClass()); + return Task.CompletedTask; + } + + public async IAsyncEnumerable GetAsyncEnumerableOfUnionType([EnumeratorCancellation] CancellationToken cancellationToken) + { + await Task.Yield(); + yield return new UnionDerivedClass(); + } + + public Task IsExtensionArgNonNull(CustomExtensionType extensionValue) => Task.FromResult(extensionValue is not null); + } + + private class DelayedFlushingHandler : LengthHeaderMessageHandler, IControlledFlushHandler + { + public DelayedFlushingHandler(Stream stream, IJsonRpcMessageFormatter formatter) + : base(stream, stream, formatter) + { + } + + public AsyncAutoResetEvent FlushEntered { get; } = new AsyncAutoResetEvent(); + + public AsyncManualResetEvent AllowFlushAsyncExit { get; } = new AsyncManualResetEvent(); + + protected override async ValueTask FlushAsync(CancellationToken cancellationToken) + { + this.FlushEntered.Set(); + await this.AllowFlushAsyncExit.WaitAsync(CancellationToken.None); + await base.FlushAsync(cancellationToken); + } + } +} diff --git a/test/StreamJsonRpc.Tests/JsonRpcTests.cs b/test/StreamJsonRpc.Tests/JsonRpcTests.cs index 3ee01caf1..f751d4171 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcTests.cs @@ -10,10 +10,11 @@ using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; using Newtonsoft.Json.Linq; +using PolyType; using JsonNET = Newtonsoft.Json; using STJ = System.Text.Json.Serialization; -public abstract class JsonRpcTests : TestBase +public abstract partial class JsonRpcTests : TestBase { #pragma warning disable SA1310 // Field names should not contain underscore protected const int COR_E_UNAUTHORIZEDACCESS = unchecked((int)0x80070005); @@ -392,7 +393,7 @@ public async Task CanCallAsyncMethod() Assert.Equal("test!", result); } - [Theory, PairwiseData] + [Theory(Timeout = 2 * 1000), PairwiseData] // TODO: Temporary for development public async Task CanCallAsyncMethodThatThrows(ExceptionProcessing exceptionStrategy) { this.clientRpc.AllowModificationWhileListening = true; @@ -432,7 +433,7 @@ public async Task CanCallAsyncMethodThatThrowsNonSerializableException(Exception } } - [Theory, PairwiseData] + [Theory(Timeout = 2 * 1000), PairwiseData] // TODO: Temporary for development public async Task CanCallAsyncMethodThatThrowsExceptionWithoutDeserializingConstructor(ExceptionProcessing exceptionStrategy) { this.clientRpc.AllowModificationWhileListening = true; @@ -457,7 +458,7 @@ public async Task CanCallAsyncMethodThatThrowsExceptionWithoutDeserializingConst } } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task CanCallAsyncMethodThatThrowsExceptionWhileSerializingException() { this.clientRpc.AllowModificationWhileListening = true; @@ -471,7 +472,7 @@ public async Task CanCallAsyncMethodThatThrowsExceptionWhileSerializingException Assert.Null(exception.InnerException); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ThrowCustomExceptionThatImplementsISerializableProperly() { this.clientRpc.AllowModificationWhileListening = true; @@ -1118,7 +1119,7 @@ public async Task CancelMayStillReturnResultFromServer() } } - [Theory, PairwiseData] + [Theory(Timeout = 2 * 1000), PairwiseData] // TODO: temporary for development public async Task CancelMayStillReturnErrorFromServer(ExceptionProcessing exceptionStrategy) { this.clientRpc.AllowModificationWhileListening = true; @@ -1946,7 +1947,7 @@ public void AddLocalRpcMethod_String_MethodInfo_Object_NullTargetForInstanceMeth Assert.Throws(() => this.serverRpc.AddLocalRpcMethod("biz.bar", methodInfo, null)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ServerMethodIsCanceledWhenConnectionDrops() { this.ReinitializeRpcWithoutListening(); @@ -1962,7 +1963,7 @@ public async Task ServerMethodIsCanceledWhenConnectionDrops() Assert.Null(oce.InnerException); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ServerMethodIsNotCanceledWhenConnectionDrops() { Assert.False(this.serverRpc.CancelLocallyInvokedMethodsWhenConnectionIsClosed); @@ -2248,7 +2249,7 @@ public async Task CanPassExceptionFromServer_DeserializedErrorData() Assert.StrictEqual(COR_E_UNAUTHORIZEDACCESS, errorData.HResult); } - [Theory, PairwiseData] + [Theory(Timeout = 2 * 1000), PairwiseData] // TODO: Temporary for development public async Task ExceptionTreeThrownFromServerIsDeserializedAtClient(ExceptionProcessing exceptionStrategy) { this.clientRpc.AllowModificationWhileListening = true; @@ -2477,7 +2478,7 @@ public async Task ExceptionRecursionLimit_ArgumentDeserialization() Assert.Equal(this.serverRpc.ExceptionOptions.RecursionLimit, CountRecursionLevel(this.server.ReceivedException)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ExceptionRecursionLimit_ThrownSerialization() { this.serverRpc.AllowModificationWhileListening = true; @@ -2493,7 +2494,7 @@ public async Task ExceptionRecursionLimit_ThrownSerialization() Assert.Equal(this.clientRpc.ExceptionOptions.RecursionLimit, actualRecursionLevel); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task ExceptionRecursionLimit_ThrownDeserialization() { this.clientRpc.AllowModificationWhileListening = true; @@ -3315,7 +3316,7 @@ public class BaseClass } #pragma warning disable CA1801 // use all parameters - public class Server : BaseClass, IServerDerived + public partial class Server : BaseClass, IServerDerived { internal const string ExceptionMessage = "some message"; internal const string ThrowAfterCancellationMessage = "Throw after cancellation"; @@ -3906,10 +3907,12 @@ public ValueTask DisposeAsync() } [DataContract] - public class ParamsObjectWithCustomNames + [GenerateShape] + public partial class ParamsObjectWithCustomNames { [DataMember(Name = "argument")] [STJ.JsonPropertyName("argument")] + [PropertyShape(Name = "argument")] public string? TheArgument { get; set; } } @@ -3966,22 +3969,27 @@ public void IgnoredMethod() } [DataContract] - public class Foo + [GenerateShape] + public partial class Foo { [DataMember(Order = 0, IsRequired = true)] [STJ.JsonRequired, STJ.JsonPropertyOrder(0)] + [PropertyShape(Order = 0)] public string? Bar { get; set; } [DataMember(Order = 1)] [STJ.JsonPropertyOrder(1)] + [PropertyShape(Order = 1)] public int Bazz { get; set; } } - public class CustomSerializedType + [GenerateShape] + public partial class CustomSerializedType { // Ignore this so default serializers will drop it, proving that custom serializers were used if the value propagates. [JsonNET.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public string? Value { get; set; } } @@ -3989,6 +3997,7 @@ public class CustomSerializedType public class CustomISerializableData : ISerializable { [MessagePack.SerializationConstructor] + [ConstructorShape] public CustomISerializableData(int major) { this.Major = major; diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs new file mode 100644 index 000000000..22c7a0eef --- /dev/null +++ b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; +using PolyType; + +public partial class MarshalableProxyNerdbankMessagePackTests : MarshalableProxyTests +{ + public MarshalableProxyNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override IJsonRpcMessageFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new(); + formatter.SetFormatterProfile(b => + { + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter(); + b.RegisterRpcMarshalableConverter>(); + b.AddTypeShapeProvider(MarshalableProxyWitness.ShapeProvider); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + }); + + return formatter; + } + + [GenerateShape] + [GenerateShape] + [GenerateShape] + public partial class MarshalableProxyWitness; +} diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs index 8d51a31b2..e0ab4ec7a 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs @@ -9,8 +9,10 @@ using MessagePack; using Microsoft; using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; using Nerdbank.Streams; using Newtonsoft.Json; +using PolyType; using StreamJsonRpc; using Xunit; using Xunit.Abstractions; @@ -49,6 +51,7 @@ protected MarshalableProxyTests(ITestOutputHelper logger) [RpcMarshalable] [JsonConverter(typeof(MarshalableConverter))] [MessagePackFormatter(typeof(MarshalableFormatter))] + [MessagePackConverter(typeof(MarshalableNerdbankConverter))] public interface IMarshalableAndSerializable : IMarshalable { private class MarshalableConverter : JsonConverter @@ -71,12 +74,25 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer private class MarshalableFormatter : MessagePack.Formatters.IMessagePackFormatter { - public IMarshalableAndSerializable Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + public IMarshalableAndSerializable Deserialize(ref MessagePack.MessagePackReader reader, MessagePackSerializerOptions options) { throw new NotImplementedException(); } - public void Serialize(ref MessagePackWriter writer, IMarshalableAndSerializable value, MessagePackSerializerOptions options) + public void Serialize(ref MessagePack.MessagePackWriter writer, IMarshalableAndSerializable value, MessagePackSerializerOptions options) + { + throw new NotImplementedException(); + } + } + + private class MarshalableNerdbankConverter : Nerdbank.MessagePack.MessagePackConverter + { + public override IMarshalableAndSerializable? Read(ref Nerdbank.MessagePack.MessagePackReader reader, Nerdbank.MessagePack.SerializationContext context) + { + throw new NotImplementedException(); + } + + public override void Write(ref Nerdbank.MessagePack.MessagePackWriter writer, in IMarshalableAndSerializable? value, Nerdbank.MessagePack.SerializationContext context) { throw new NotImplementedException(); } @@ -557,7 +573,7 @@ public async Task GenericMarshalableWithinArg_CanCallMethods() Assert.Equal(99, await ((IGenericMarshalable)this.server.ReceivedProxy!).DoSomethingWithParameterAsync(99)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task MarshalableReturnValue_Null() { IMarshalable? proxyMarshalable = await this.client.GetMarshalableAsync(returnNull: true); @@ -673,7 +689,7 @@ public async Task DisposeOnDisconnect() await disposed.WaitAsync(this.TimeoutToken); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableWithOptionalInterfaces(); @@ -691,7 +707,7 @@ public async Task RpcMarshalableOptionalInterface() AssertIsNot(proxy1, typeof(IMarshalableSubType2Extended)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_JsonRpcMethodAttribute() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableWithOptionalInterfaces(); @@ -709,7 +725,7 @@ public async Task RpcMarshalableOptionalInterface_JsonRpcMethodAttribute() Assert.Equal("foo", await this.clientRpc.InvokeAsync("$/invokeProxy/1/1.RemamedAsync", "foo")); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MethodNameTransform_Prefix() { var server = new Server(); @@ -737,7 +753,7 @@ public async Task RpcMarshalableOptionalInterface_MethodNameTransform_Prefix() Assert.Equal(1, await localRpc.InvokeAsync("$/invokeProxy/0/1.GetAsync", 1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MethodNameTransform_CamelCase() { var server = new Server(); @@ -765,7 +781,7 @@ public async Task RpcMarshalableOptionalInterface_MethodNameTransform_CamelCase( Assert.Equal(1, await localRpc.InvokeAsync("$/invokeProxy/0/1.GetAsync", 1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_Null() { this.server.ReturnedMarshalableWithOptionalInterfaces = null; @@ -773,7 +789,7 @@ public async Task RpcMarshalableOptionalInterface_Null() Assert.Null(proxy); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_IndirectInterfaceImplementation() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubType1Indirect(); @@ -784,7 +800,7 @@ public async Task RpcMarshalableOptionalInterface_IndirectInterfaceImplementatio AssertIsNot(proxy, typeof(IMarshalableSubType2Extended)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_WithExplicitImplementation() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubType2(); @@ -795,7 +811,7 @@ public async Task RpcMarshalableOptionalInterface_WithExplicitImplementation() AssertIsNot(proxy, typeof(IMarshalableSubType2Extended)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_UnknownSubType() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableUnknownSubType(); @@ -806,7 +822,7 @@ public async Task RpcMarshalableOptionalInterface_UnknownSubType() AssertIsNot(proxy, typeof(IMarshalableSubType2Extended)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_OnlyAttibutesOnDeclaredTypeAreHonored() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubType2Extended(); @@ -821,7 +837,7 @@ public async Task RpcMarshalableOptionalInterface_OnlyAttibutesOnDeclaredTypeAre Assert.Equal(4, await ((IMarshalableSubType2Extended)proxy1).GetPlusThreeAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_OptionalInterfaceNotExtendingBase() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableNonExtendingBase(); @@ -831,7 +847,7 @@ public async Task RpcMarshalableOptionalInterface_OptionalInterfaceNotExtendingB Assert.Equal(5, await ((IMarshalableNonExtendingBase)proxy).GetPlusFourAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_IntermediateNonMarshalableInterface() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubTypeWithIntermediateInterface(); @@ -852,7 +868,7 @@ public async Task RpcMarshalableOptionalInterface_IntermediateNonMarshalableInte Assert.Equal(4, await ((IMarshalableSubTypeWithIntermediateInterface)proxy).GetPlusThreeAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MultipleIntermediateInterfaces() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubTypeWithIntermediateInterface1And2(); @@ -879,7 +895,7 @@ public async Task RpcMarshalableOptionalInterface_MultipleIntermediateInterfaces Assert.Equal(-3, await ((IMarshalableSubTypeIntermediateInterface)proxy2).GetPlusTwoAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MultipleImplementations() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubTypeMultipleImplementations(); @@ -904,7 +920,7 @@ public async Task RpcMarshalableOptionalInterface_MultipleImplementations() Assert.Equal(-1, await ((IMarshalableSubType2)proxy).GetMinusTwoAsync(1)); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalableOptionalInterface_MultipleImplementationsCombined() { this.server.ReturnedMarshalableWithOptionalInterfaces = new MarshalableSubTypesCombined(); @@ -946,7 +962,7 @@ public async Task RpcMarshalable_CallScopedLifetime() Assert.False(marshaled.IsDisposed); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableReturned() { MarshalableAndSerializable marshaled = new(); @@ -962,7 +978,7 @@ public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableReturned() await Assert.ThrowsAsync(() => this.server.ContinuationResult).WithCancellation(this.TimeoutToken); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development. public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableThrown() { this.clientRpc.AllowModificationWhileListening = true; @@ -1250,8 +1266,10 @@ public Data(Action? disposeAction) [DataMember] public int Value { get; set; } + [PropertyShape(Ignore = true)] public bool IsDisposed { get; private set; } + [PropertyShape(Ignore = true)] public bool DoSomethingCalled { get; private set; } public void Dispose() diff --git a/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs index 39f633486..fe15c18b5 100644 --- a/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs @@ -5,10 +5,6 @@ using MessagePack.Resolvers; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; -using StreamJsonRpc; -using StreamJsonRpc.Protocol; -using Xunit; -using Xunit.Abstractions; public class MessagePackFormatterTests : FormatterTestBase { diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs new file mode 100644 index 000000000..3f60df26a --- /dev/null +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -0,0 +1,462 @@ +using System.Diagnostics; +using System.Runtime.Serialization; +using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType; +using PolyType.ReflectionProvider; +using PolyType.SourceGenerator; + +public partial class NerdbankMessagePackFormatterTests : FormatterTestBase +{ + public NerdbankMessagePackFormatterTests(ITestOutputHelper logger) + : base(logger) + { + } + + [Fact] + public void JsonRpcRequest_PositionalArgs() + { + var original = new JsonRpcRequest + { + RequestId = new RequestId(5), + Method = "test", + ArgumentsList = new object[] { 5, "hi", new CustomType { Age = 8 } }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Method, actual.Method); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 0, typeof(int), out object? actualArg0)); + Assert.Equal(original.ArgumentsList[0], actualArg0); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 1, typeof(string), out object? actualArg1)); + Assert.Equal(original.ArgumentsList[1], actualArg1); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 2, typeof(CustomType), out object? actualArg2)); + Assert.Equal(((CustomType?)original.ArgumentsList[2])!.Age, ((CustomType)actualArg2!).Age); + } + + [Fact] + public void JsonRpcRequest_NamedArgs() + { + var original = new JsonRpcRequest + { + RequestId = new RequestId(5), + Method = "test", + NamedArguments = new Dictionary + { + { "Number", 5 }, + { "Message", "hi" }, + { "Custom", new CustomType { Age = 8 } }, + }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Method, actual.Method); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Number", -1, typeof(int), out object? actualArg0)); + Assert.Equal(original.NamedArguments["Number"], actualArg0); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Message", -1, typeof(string), out object? actualArg1)); + Assert.Equal(original.NamedArguments["Message"], actualArg1); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Custom", -1, typeof(CustomType), out object? actualArg2)); + Assert.Equal(((CustomType?)original.NamedArguments["Custom"])!.Age, ((CustomType)actualArg2!).Age); + } + + [Fact] + public void JsonRpcResult() + { + var original = new JsonRpcResult + { + RequestId = new RequestId(5), + Result = new CustomType { Age = 7 }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(((CustomType?)original.Result)!.Age, actual.GetResult().Age); + } + + [Fact] + public void JsonRpcError() + { + var original = new JsonRpcError + { + RequestId = new RequestId(5), + Error = new JsonRpcError.ErrorDetail + { + Code = JsonRpcErrorCode.InvocationError, + Message = "Oops", + Data = new CustomType { Age = 15 }, + }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Error.Code, actual.Error!.Code); + Assert.Equal(original.Error.Message, actual.Error.Message); + Assert.Equal(((CustomType)original.Error.Data).Age, actual.Error.GetData().Age); + } + + [Fact] + public async Task BasicJsonRpc() + { + var (clientStream, serverStream) = FullDuplexStream.CreatePair(); + var clientFormatter = new NerdbankMessagePackFormatter(); + var serverFormatter = new NerdbankMessagePackFormatter(); + + var clientHandler = new LengthHeaderMessageHandler(clientStream.UsePipe(), clientFormatter); + var serverHandler = new LengthHeaderMessageHandler(serverStream.UsePipe(), serverFormatter); + + var clientRpc = new JsonRpc(clientHandler); + var serverRpc = new JsonRpc(serverHandler, new Server()); + + serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose); + clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose); + + serverRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); + clientRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); + + clientRpc.StartListening(); + serverRpc.StartListening(); + + int result = await clientRpc.InvokeAsync(nameof(Server.Add), 3, 5).WithCancellation(this.TimeoutToken); + Assert.Equal(8, result); + } + + [Fact] + public void Resolver_RequestArgInArray() + { + var originalArg = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + ArgumentsList = new object[] { originalArg }, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(null, 0, typeof(TypeRequiringCustomFormatter), out object? roundtripArgObj)); + var roundtripArg = (TypeRequiringCustomFormatter)roundtripArgObj!; + Assert.Equal(originalArg.Prop1, roundtripArg.Prop1); + Assert.Equal(originalArg.Prop2, roundtripArg.Prop2); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_AnonymousType() + { + this.Formatter.SetFormatterProfile(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new { Prop1 = 3, Prop2 = 5 }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.Prop1), -1, typeof(int), out object? prop1)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.Prop2), -1, typeof(int), out object? prop2)); + Assert.Equal(originalArg.Prop1, prop1); + Assert.Equal(originalArg.Prop2, prop2); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_DataContractObject() + { + this.Formatter.SetFormatterProfile(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new DataContractWithSubsetOfMembersIncluded { ExcludedField = "A", ExcludedProperty = "B", IncludedField = "C", IncludedProperty = "D" }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedProperty), -1, typeof(string), out object? _)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.IncludedField), -1, typeof(string), out object? includedField)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.IncludedProperty), -1, typeof(string), out object? includedProperty)); + Assert.Equal(originalArg.IncludedProperty, includedProperty); + Assert.Equal(originalArg.IncludedField, includedField); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_NonDataContractObject() + { + this.Formatter.SetFormatterProfile(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new NonDataContractWithExcludedMembers { ExcludedField = "A", ExcludedProperty = "B", InternalField = "C", InternalProperty = "D", PublicField = "E", PublicProperty = "F" }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedProperty), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.InternalField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.InternalProperty), -1, typeof(string), out object? _)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.PublicField), -1, typeof(string), out object? publicField)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.PublicProperty), -1, typeof(string), out object? publicProperty)); + Assert.Equal(originalArg.PublicProperty, publicProperty); + Assert.Equal(originalArg.PublicField, publicField); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_NullObject() + { + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = null, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.Null(roundtripRequest.Arguments); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex("AnythingReally", -1, typeof(string), out object? _)); + } + + [Fact] + public void Resolver_Result() + { + this.Formatter.SetFormatterProfile(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalResultValue = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalResult = new JsonRpcResult + { + RequestId = new RequestId(1), + Result = originalResultValue, + }; + var roundtripResult = this.Roundtrip(originalResult); + var roundtripResultValue = roundtripResult.GetResult(); + Assert.Equal(originalResultValue.Prop1, roundtripResultValue.Prop1); + Assert.Equal(originalResultValue.Prop2, roundtripResultValue.Prop2); + } + + [Fact] + public void Resolver_ErrorData() + { + var originalErrorData = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalError = new JsonRpcError + { + RequestId = new RequestId(1), + Error = new JsonRpcError.ErrorDetail + { + Data = originalErrorData, + }, + }; + var roundtripError = this.Roundtrip(originalError); + var roundtripErrorData = roundtripError.Error!.GetData(); + Assert.Equal(originalErrorData.Prop1, roundtripErrorData.Prop1); + Assert.Equal(originalErrorData.Prop2, roundtripErrorData.Prop2); + } + + [SkippableFact] + public void CanDeserializeWithExtraProperty_JsonRpcRequest() + { + Skip.If(this.Formatter is NerdbankMessagePackFormatter, "Dynamic types are not supported for NerdbankMessagePack."); + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = new object[] { "hi" }, + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.method, request.Method); + Assert.Equal(dynamic.@params.Length, request.ArgumentCount); + Assert.True(request.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg)); + Assert.Equal(dynamic.@params[0], arg); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcResult() + { + var dynamic = new + { + jsonrpc = "2.0", + id = 2, + extra = (object?)null, + result = "hi", + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.id, request.RequestId.Number); + Assert.Equal(dynamic.result, request.GetResult()); + } + + [SkippableFact] + public void CanDeserializeWithExtraProperty_JsonRpcError() + { + Skip.If(this.Formatter is NerdbankMessagePackFormatter, "Dynamic types are not supported for NerdbankMessagePack."); + var dynamic = new + { + jsonrpc = "2.0", + id = 2, + extra = (object?)null, + error = new { extra = 2, code = 5 }, + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.id, request.RequestId.Number); + Assert.Equal(dynamic.error.code, (int?)request.Error?.Code); + } + + [Fact] + public void StringsInUserDataAreInterned() + { + Skip.If(this.Formatter is NerdbankMessagePackFormatter, "Dynamic types are not supported for NerdbankMessagePack."); + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = new object[] { "hi" }, + }; + var request1 = this.Read(dynamic); + var request2 = this.Read(dynamic); + Assert.True(request1.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg1)); + Assert.True(request2.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg2)); + Assert.Same(arg2, arg1); // reference equality to ensure it was interned. + } + + [Fact] + public void StringValuesOfStandardPropertiesAreInterned() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = Array.Empty(), + }; + var request1 = this.Read(dynamic); + var request2 = this.Read(dynamic); + Assert.Same(request1.Method, request2.Method); // reference equality to ensure it was interned. + } + + protected override NerdbankMessagePackFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new(); + formatter.SetFormatterProfile(b => + { + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + b.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); + }); + + return formatter; + } + + private T Read(object anonymousObject) + where T : JsonRpcMessage + { + NerdbankMessagePackFormatter.Profile.Builder profileBuilder = this.Formatter.ProfileBuilder; + profileBuilder.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + profileBuilder.AddTypeShapeProvider(ReflectionTypeShapeProvider.Default); + NerdbankMessagePackFormatter.Profile profile = profileBuilder.Build(); + + this.Formatter.SetFormatterProfile(profile); + + var sequence = new Sequence(); + var writer = new MessagePackWriter(sequence); + profile.SerializeObject(ref writer, anonymousObject); + writer.Flush(); + return (T)this.Formatter.Deserialize(sequence.AsReadOnlySequence); + } + + [DataContract] + [GenerateShape] + public partial class DataContractWithSubsetOfMembersIncluded + { + [PropertyShape(Ignore = true)] + public string? ExcludedField; + + [DataMember] + internal string? IncludedField; + + [PropertyShape(Ignore = true)] + public string? ExcludedProperty { get; set; } + + [DataMember] + internal string? IncludedProperty { get; set; } + } + + [GenerateShape] + public partial class NonDataContractWithExcludedMembers + { + [IgnoreDataMember] + [PropertyShape(Ignore = true)] + public string? ExcludedField; + + public string? PublicField; + + internal string? InternalField; + + [IgnoreDataMember] + [PropertyShape(Ignore = true)] + public string? ExcludedProperty { get; set; } + + public string? PublicProperty { get; set; } + + internal string? InternalProperty { get; set; } + } + + [MessagePackConverter(typeof(CustomConverter))] + [GenerateShape] + public partial class TypeRequiringCustomFormatter + { + internal int Prop1 { get; set; } + + internal int Prop2 { get; set; } + } + + private class CustomConverter : MessagePackConverter + { + public override TypeRequiringCustomFormatter Read(ref MessagePackReader reader, SerializationContext context) + { + Assert.Equal(2, reader.ReadArrayHeader()); + return new TypeRequiringCustomFormatter + { + Prop1 = reader.ReadInt32(), + Prop2 = reader.ReadInt32(), + }; + } + + public override void Write(ref MessagePackWriter writer, in TypeRequiringCustomFormatter? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + writer.WriteArrayHeader(2); + writer.Write(value.Prop1); + writer.Write(value.Prop2); + } + } + + private class Server + { + public int Add(int a, int b) => a + b; + } +} diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs new file mode 100644 index 000000000..d49d208e3 --- /dev/null +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using PolyType; + +public class ObserverMarshalingNerdbankMessagePackTests : ObserverMarshalingTests +{ + public ObserverMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override IJsonRpcMessageFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new(); + formatter.SetFormatterProfile(b => + { + b.RegisterRpcMarshalableConverter>(); + b.RegisterExceptionType(); + b.AddTypeShapeProvider(ObserverMarshalingWitness.ShapeProvider); + b.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + }); + + return formatter; + } +} + +[GenerateShape] +#pragma warning disable SA1402 // File may only contain a single type +internal partial class ObserverMarshalingWitness; +#pragma warning restore SA1402 // File may only contain a single type diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs index c17c1ed11..7319d595c 100644 --- a/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs @@ -6,6 +6,7 @@ using System.Runtime.CompilerServices; using Microsoft; using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; using Nerdbank.Streams; using StreamJsonRpc; using Xunit; @@ -84,7 +85,7 @@ public async Task ReturnThenPushSequence() Assert.Equal(Enumerable.Range(1, 3), result); } - [Fact] + [Fact(Timeout = 2 * 1000)] // TODO: Temporary for development public async Task FaultImmediately() { var observer = new MockObserver(); diff --git a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj index 82fb00a07..f7bf2232d 100644 --- a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj +++ b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj @@ -1,7 +1,7 @@  - net6.0;net8.0 + net8.0 $(TargetFrameworks);net472 @@ -11,30 +11,39 @@ + + + + + + + + + diff --git a/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs new file mode 100644 index 000000000..c1dff17b9 --- /dev/null +++ b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs @@ -0,0 +1,25 @@ +public class TargetObjectEventsNerdbankMessagePackTests : TargetObjectEventsTests +{ + public TargetObjectEventsNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + var serverMessageFormatter = new NerdbankMessagePackFormatter(); + var clientMessageFormatter = new NerdbankMessagePackFormatter(); + + serverMessageFormatter.SetFormatterProfile(ConfigureContext); + clientMessageFormatter.SetFormatterProfile(ConfigureContext); + + this.serverMessageHandler = new LengthHeaderMessageHandler(this.serverStream, this.serverStream, serverMessageFormatter); + this.clientMessageHandler = new LengthHeaderMessageHandler(this.clientStream, this.clientStream, clientMessageFormatter); + + void ConfigureContext(NerdbankMessagePackFormatter.Profile.Builder profileBuilder) + { + profileBuilder.AddTypeShapeProvider(PolyType.SourceGenerator.ShapeProvider_StreamJsonRpc_Tests.Default); + profileBuilder.AddTypeShapeProvider(PolyType.ReflectionProvider.ReflectionTypeShapeProvider.Default); + } + } +} diff --git a/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs b/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs index 9acc55756..3fe7f6d8b 100644 --- a/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs +++ b/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs @@ -3,11 +3,13 @@ using Microsoft; using Microsoft.VisualStudio.Threading; using Nerdbank; +using Nerdbank.MessagePack; +using PolyType; using StreamJsonRpc; using Xunit; using Xunit.Abstractions; -public abstract class TargetObjectEventsTests : TestBase +public abstract partial class TargetObjectEventsTests : TestBase { protected IJsonRpcMessageHandler serverMessageHandler = null!; protected IJsonRpcMessageHandler clientMessageHandler = null!; @@ -35,7 +37,11 @@ public TargetObjectEventsTests(ITestOutputHelper logger) } [MessagePack.Union(key: 0, typeof(Fruit))] - public interface IFruit + [GenerateShape] +#pragma warning disable CS0618 // Type or member is obsolete + [KnownSubType(typeof(Fruit), 1)] +#pragma warning restore CS0618 // Type or member is obsolete + public partial interface IFruit { string Name { get; } } @@ -360,7 +366,8 @@ private void ReinitializeRpcWithoutListening() } [DataContract] - public class Fruit : IFruit + [GenerateShape] + public partial class Fruit : IFruit { internal Fruit(string name) { diff --git a/test/StreamJsonRpc.Tests/TestBase.cs b/test/StreamJsonRpc.Tests/TestBase.cs index b242a7721..8053d8bb7 100644 --- a/test/StreamJsonRpc.Tests/TestBase.cs +++ b/test/StreamJsonRpc.Tests/TestBase.cs @@ -5,8 +5,9 @@ using System.Reflection; using System.Runtime.Serialization; using Microsoft.VisualStudio.Threading; +using PolyType; -public abstract class TestBase : IDisposable +public abstract partial class TestBase : IDisposable { protected static readonly TimeSpan ExpectedTimeout = TimeSpan.FromMilliseconds(200); @@ -198,6 +199,9 @@ protected virtual async Task CheckGCPressureAsync(Func scenario, int maxBy Assert.True(passingAttemptObserved); } + [GenerateShape.CustomType>] + internal partial class TestBaseWitness; + #pragma warning disable SYSLIB0050 // Type or member is obsolete private class RoundtripFormatter : IFormatterConverter #pragma warning restore SYSLIB0050 // Type or member is obsolete diff --git a/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs new file mode 100644 index 000000000..b14c7b25b --- /dev/null +++ b/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs @@ -0,0 +1,7 @@ +public class WebSocketMessageHandlerNerdbankMessagePackTests : WebSocketMessageHandlerTests +{ + public WebSocketMessageHandlerNerdbankMessagePackTests(ITestOutputHelper logger) + : base(new NerdbankMessagePackFormatter(), logger) + { + } +}