diff --git a/src/csharp/ChatClient.cs b/src/csharp/ChatClient.cs new file mode 100644 index 000000000..f4033d1aa --- /dev/null +++ b/src/csharp/ChatClient.cs @@ -0,0 +1,232 @@ +using Microsoft.Extensions.AI; +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.ML.OnnxRuntimeGenAI; + +/// Provides an implementation based on ONNX Runtime GenAI. +public sealed partial class ChatClient : IChatClient +{ + /// The options used to configure the instance. + private readonly ChatClientConfiguration _config; + /// The wrapped . + private readonly Model _model; + /// The wrapped . + private readonly Tokenizer _tokenizer; + /// Whether to dispose of when this instance is disposed. + private readonly bool _ownsModel; + + /// Initializes an instance of the class. + /// The file path to the model to load. + /// Options used to configure the client instance. + /// is null. + public ChatClient(string modelPath, ChatClientConfiguration configuration) + { + if (modelPath is null) + { + throw new ArgumentNullException(nameof(modelPath)); + } + + _ownsModel = true; + _model = new Model(modelPath); + _tokenizer = new Tokenizer(_model); + + Metadata = new(typeof(ChatClient).Namespace, new Uri($"file://{modelPath}"), modelPath); + } + + /// Initializes an instance of the class. + /// The model to employ. + /// + /// if this owns the and should + /// dispose of it when this is disposed; otherwise, . + /// The default is . + /// + /// is null. + public ChatClient(Model model, bool ownsModel = true) + { + if (model is null) + { + throw new ArgumentNullException(nameof(model)); + } + + _ownsModel = ownsModel; + _model = model; + _tokenizer = new Tokenizer(_model); + + Metadata = new("onnxruntime-genai"); + } + + /// + public ChatClientMetadata Metadata { get; } + + /// + public void Dispose() + { + _tokenizer.Dispose(); + + if (_ownsModel) + { + _model.Dispose(); + } + } + + /// + public async Task CompleteAsync(IList chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default) + { + if (chatMessages is null) + { + throw new ArgumentNullException(nameof(chatMessages)); + } + + StringBuilder text = new(); + await Task.Run(() => + { + using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages)); + using GeneratorParams generatorParams = new(_model); + UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options); + + using Generator generator = new(_model, generatorParams); + generator.AppendTokenSequences(tokens); + + using var tokenizerStream = _tokenizer.CreateStream(); + + var completionId = Guid.NewGuid().ToString(); + while (!generator.IsDone()) + { + cancellationToken.ThrowIfCancellationRequested(); + + generator.GenerateNextToken(); + + ReadOnlySpan outputSequence = generator.GetSequence(0); + string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]); + + if (IsStop(next, options)) + { + break; + } + + text.Append(next); + } + }, cancellationToken); + + return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString())) + { + CompletionId = Guid.NewGuid().ToString(), + CreatedAt = DateTimeOffset.UtcNow, + ModelId = Metadata.ModelId, + }; + } + + /// + public async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (chatMessages is null) + { + throw new ArgumentNullException(nameof(chatMessages)); + } + + using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages)); + using GeneratorParams generatorParams = new(_model); + UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options); + + using Generator generator = new(_model, generatorParams); + generator.AppendTokenSequences(tokens); + + using var tokenizerStream = _tokenizer.CreateStream(); + + var completionId = Guid.NewGuid().ToString(); + while (!generator.IsDone()) + { + string next = await Task.Run(() => + { + generator.GenerateNextToken(); + + ReadOnlySpan outputSequence = generator.GetSequence(0); + return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]); + }, cancellationToken); + + if (IsStop(next, options)) + { + break; + } + + yield return new StreamingChatCompletionUpdate + { + CompletionId = completionId, + CreatedAt = DateTimeOffset.UtcNow, + Role = ChatRole.Assistant, + Text = next, + }; + } + } + + /// + public object GetService(Type serviceType, object key = null) => + key is not null ? null : + serviceType == typeof(Model) ? _model : + serviceType == typeof(Tokenizer) ? _tokenizer : + serviceType?.IsInstanceOfType(this) is true ? this : + null; + + /// Gets whether the specified token is a stop sequence. + private bool IsStop(string token, ChatOptions options) => + options?.StopSequences?.Contains(token) is true || + Array.IndexOf(_config.StopSequences, token) >= 0; + + /// Updates the based on the supplied . + private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options) + { + if (options is null) + { + return; + } + + if (options.MaxOutputTokens.HasValue) + { + generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value); + } + + if (options.Temperature.HasValue) + { + generatorParams.SetSearchOption("temperature", options.Temperature.Value); + } + + if (options.TopP.HasValue || options.TopK.HasValue) + { + if (options.TopP.HasValue) + { + generatorParams.SetSearchOption("top_p", options.TopP.Value); + } + + if (options.TopK.HasValue) + { + generatorParams.SetSearchOption("top_k", options.TopK.Value); + } + } + + if (options.Seed.HasValue) + { + generatorParams.SetSearchOption("random_seed", options.Seed.Value); + } + + if (options.AdditionalProperties is { } props) + { + foreach (var entry in props) + { + switch (entry.Value) + { + case int i: generatorParams.SetSearchOption(entry.Key, i); break; + case long l: generatorParams.SetSearchOption(entry.Key, l); break; + case float f: generatorParams.SetSearchOption(entry.Key, f); break; + case double d: generatorParams.SetSearchOption(entry.Key, d); break; + case bool b: generatorParams.SetSearchOption(entry.Key, b); break; + } + } + } + } +} \ No newline at end of file diff --git a/src/csharp/ChatClientConfiguration.cs b/src/csharp/ChatClientConfiguration.cs new file mode 100644 index 000000000..282ae9362 --- /dev/null +++ b/src/csharp/ChatClientConfiguration.cs @@ -0,0 +1,73 @@ +using Microsoft.Extensions.AI; +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.OnnxRuntimeGenAI; + +/// Provides configuration options used when constructing a . +/// +/// Every model has different requirements for stop sequences and prompt formatting. For best results, +/// the configuration should be tailored to the exact nature of the model being used. For example, +/// when using a Phi3 model, a configuration like the following may be used: +/// +/// static ChatClientConfiguration CreateForPhi3() => +/// new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"], +/// (IEnumerable<ChatMessage> messages) => +/// { +/// StringBuilder prompt = new(); +/// +/// foreach (var message in messages) +/// foreach (var content in message.Contents.OfType<TextContent>()) +/// prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(tc.Text).Append("<|end|>\n"); +/// +/// return prompt.Append("<|assistant|>\n").ToString(); +/// }); +/// +/// +public sealed class ChatClientConfiguration +{ + private string[] _stopSequences; + private Func, string> _promptFormatter; + + /// Initializes a new instance of the class. + /// The stop sequences used by the model. + /// The function to use to format a list of messages for input into the model. + /// is null. + /// is null. + public ChatClientConfiguration( + string[] stopSequences, + Func, string> promptFormatter) + { + if (stopSequences is null) + { + throw new ArgumentNullException(nameof(stopSequences)); + } + + if (promptFormatter is null) + { + throw new ArgumentNullException(nameof(promptFormatter)); + } + + StopSequences = stopSequences; + PromptFormatter = promptFormatter; + } + + /// + /// Gets or sets stop sequences to use during generation. + /// + /// + /// These will apply in addition to any stop sequences that are a part of the . + /// + public string[] StopSequences + { + get => _stopSequences; + set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value)); + } + + /// Gets the function that creates a prompt string from the chat history. + public Func, string> PromptFormatter + { + get => _promptFormatter; + set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value)); + } +} \ No newline at end of file diff --git a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj index ee53c83fb..83f729a7e 100644 --- a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj +++ b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj @@ -121,4 +121,8 @@ + + + + diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs index ddced9e42..9480d0d84 100644 --- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs +++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs @@ -5,9 +5,11 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices; -using System.Runtime.CompilerServices; using Xunit; using Xunit.Abstractions; +using System.Collections.Generic; +using Microsoft.Extensions.AI; +using System.Text; namespace Microsoft.ML.OnnxRuntimeGenAI.Tests { @@ -349,6 +351,32 @@ public void TestTopKTopPSearch() } } + [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")] + public async void TestChatClient() + { + using var client = new ChatClient( + _phi2Path, + new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"], + (IEnumerable messages) => + { + StringBuilder prompt = new(); + + foreach (var message in messages) + foreach (var content in message.Contents.OfType()) + prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n"); + + return prompt.Append("<|assistant|>\n").ToString(); + })); + + var completion = await client.CompleteAsync("What is 2 + 3?", new() + { + MaxOutputTokens = 20, + Temperature = 0f, + }); + + Assert.Contains("5", completion.ToString()); + } + [IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")] public void TestTokenizerBatchEncodeDecode() {