diff --git a/src/csharp/ChatClient.cs b/src/csharp/ChatClient.cs
new file mode 100644
index 000000000..f4033d1aa
--- /dev/null
+++ b/src/csharp/ChatClient.cs
@@ -0,0 +1,232 @@
+using Microsoft.Extensions.AI;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.ML.OnnxRuntimeGenAI;
+
+/// <summary>Provides an <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
+public sealed partial class ChatClient : IChatClient
+{
+    /// <summary>The options used to configure the instance, or null when none were supplied at construction.</summary>
+    private readonly ChatClientConfiguration _config;
+    /// <summary>The wrapped <see cref="Model"/>.</summary>
+    private readonly Model _model;
+    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
+    private readonly Tokenizer _tokenizer;
+    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
+    private readonly bool _ownsModel;
+
+    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
+    /// <param name="modelPath">The file path to the model to load.</param>
+    /// <param name="configuration">Options used to configure the client instance.</param>
+    /// <exception cref="ArgumentNullException">
+    /// <paramref name="modelPath"/> or <paramref name="configuration"/> is null.
+    /// </exception>
+    public ChatClient(string modelPath, ChatClientConfiguration configuration)
+    {
+        if (modelPath is null)
+        {
+            throw new ArgumentNullException(nameof(modelPath));
+        }
+
+        if (configuration is null)
+        {
+            throw new ArgumentNullException(nameof(configuration));
+        }
+
+        _config = configuration;
+        _ownsModel = true;
+        _model = new Model(modelPath);
+
+        try
+        {
+            _tokenizer = new Tokenizer(_model);
+
+            // Resolve to an absolute path before building the URI: prefixing a relative path with
+            // "file://" would turn its first segment into the URI host.
+            Metadata = new(typeof(ChatClient).Namespace, new Uri(System.IO.Path.GetFullPath(modelPath)), modelPath);
+        }
+        catch
+        {
+            // Construction failed, so no caller will ever get the chance to dispose what was created here.
+            _tokenizer?.Dispose();
+            _model.Dispose();
+            throw;
+        }
+    }
+
+    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
+    /// <param name="model">The model to employ.</param>
+    /// <param name="ownsModel">
+    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
+    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
+    /// The default is <see langword="true"/>.
+    /// </param>
+    /// <param name="configuration">
+    /// Options used to configure the client instance. Optional so that existing callers keep compiling,
+    /// but the completion methods throw <see cref="InvalidOperationException"/> when it was not supplied.
+    /// </param>
+    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
+    public ChatClient(Model model, bool ownsModel = true, ChatClientConfiguration configuration = null)
+    {
+        if (model is null)
+        {
+            throw new ArgumentNullException(nameof(model));
+        }
+
+        _config = configuration;
+        _ownsModel = ownsModel;
+        _model = model;
+        _tokenizer = new Tokenizer(_model);
+
+        Metadata = new("onnxruntime-genai");
+    }
+
+    /// <inheritdoc/>
+    public ChatClientMetadata Metadata { get; }
+
+    /// <inheritdoc/>
+    public void Dispose()
+    {
+        _tokenizer.Dispose();
+
+        // The model is only released when this client was told it owns it.
+        if (_ownsModel)
+        {
+            _model.Dispose();
+        }
+    }
+
+    /// <inheritdoc/>
+    /// <exception cref="ArgumentNullException"><paramref name="chatMessages"/> is null.</exception>
+    /// <exception cref="InvalidOperationException">No configuration was supplied at construction.</exception>
+    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
+    {
+        if (chatMessages is null)
+        {
+            throw new ArgumentNullException(nameof(chatMessages));
+        }
+
+        ChatClientConfiguration config = GetConfigurationOrThrow();
+
+        // Generation is synchronous, so the whole loop is moved off the caller's thread.
+        string text = await Task.Run(() =>
+        {
+            using Sequences tokens = _tokenizer.Encode(config.PromptFormatter(chatMessages));
+            using GeneratorParams generatorParams = new(_model);
+            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
+
+            using Generator generator = new(_model, generatorParams);
+            generator.AppendTokenSequences(tokens);
+
+            using var tokenizerStream = _tokenizer.CreateStream();
+
+            StringBuilder builder = new();
+            while (!generator.IsDone())
+            {
+                cancellationToken.ThrowIfCancellationRequested();
+
+                generator.GenerateNextToken();
+
+                // Only the most recently generated token needs decoding on each iteration.
+                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
+                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
+
+                if (IsStop(next, options))
+                {
+                    break;
+                }
+
+                builder.Append(next);
+            }
+
+            return builder.ToString();
+        }, cancellationToken).ConfigureAwait(false);
+
+        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text))
+        {
+            CompletionId = Guid.NewGuid().ToString(),
+            CreatedAt = DateTimeOffset.UtcNow,
+            ModelId = Metadata.ModelId,
+        };
+    }
+
+    /// <inheritdoc/>
+    /// <exception cref="ArgumentNullException"><paramref name="chatMessages"/> is null.</exception>
+    /// <exception cref="InvalidOperationException">No configuration was supplied at construction.</exception>
+    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        if (chatMessages is null)
+        {
+            throw new ArgumentNullException(nameof(chatMessages));
+        }
+
+        ChatClientConfiguration config = GetConfigurationOrThrow();
+
+        using Sequences tokens = _tokenizer.Encode(config.PromptFormatter(chatMessages));
+        using GeneratorParams generatorParams = new(_model);
+        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
+
+        using Generator generator = new(_model, generatorParams);
+        generator.AppendTokenSequences(tokens);
+
+        using var tokenizerStream = _tokenizer.CreateStream();
+
+        // Every update of one streamed response shares the same completion id.
+        var completionId = Guid.NewGuid().ToString();
+        while (!generator.IsDone())
+        {
+            // Each token is produced on the thread pool so the consumer is not blocked between updates.
+            string next = await Task.Run(() =>
+            {
+                generator.GenerateNextToken();
+
+                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
+                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
+            }, cancellationToken).ConfigureAwait(false);
+
+            if (IsStop(next, options))
+            {
+                break;
+            }
+
+            yield return new StreamingChatCompletionUpdate
+            {
+                CompletionId = completionId,
+                CreatedAt = DateTimeOffset.UtcNow,
+                Role = ChatRole.Assistant,
+                Text = next,
+            };
+        }
+    }
+
+    /// <inheritdoc/>
+    public object GetService(Type serviceType, object key = null) =>
+        key is not null ? null :
+        serviceType == typeof(Model) ? _model :
+        serviceType == typeof(Tokenizer) ? _tokenizer :
+        serviceType?.IsInstanceOfType(this) is true ? this :
+        null;
+
+    /// <summary>Gets the configuration, failing with a descriptive error when none was supplied.</summary>
+    /// <exception cref="InvalidOperationException">No configuration was supplied at construction.</exception>
+    private ChatClientConfiguration GetConfigurationOrThrow() =>
+        _config ?? throw new InvalidOperationException(
+            $"A {nameof(ChatClientConfiguration)} must be supplied when constructing the {nameof(ChatClient)} in order to generate completions.");
+
+    /// <summary>Gets whether the specified token is a stop sequence.</summary>
+    /// <remarks>
+    /// The token is compared against both the per-request <see cref="ChatOptions.StopSequences"/> and the
+    /// configured <see cref="ChatClientConfiguration.StopSequences"/>. Callers must have verified that a
+    /// configuration exists (see <see cref="GetConfigurationOrThrow"/>) before calling this.
+    /// </remarks>
+    private bool IsStop(string token, ChatOptions options) =>
+        options?.StopSequences?.Contains(token) is true ||
+        Array.IndexOf(_config.StopSequences, token) >= 0;
+
+    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
+    /// <param name="numInputTokens">The number of tokens in the encoded prompt.</param>
+    /// <param name="generatorParams">The generator parameters to update.</param>
+    /// <param name="options">The per-request options; when null, nothing is changed.</param>
+    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
+    {
+        if (options is null)
+        {
+            return;
+        }
+
+        if (options.MaxOutputTokens is { } maxOutputTokens)
+        {
+            // "max_length" is set to the prompt length plus the requested output budget.
+            generatorParams.SetSearchOption("max_length", numInputTokens + maxOutputTokens);
+        }
+
+        if (options.Temperature is { } temperature)
+        {
+            generatorParams.SetSearchOption("temperature", temperature);
+        }
+
+        if (options.TopP is { } topP)
+        {
+            generatorParams.SetSearchOption("top_p", topP);
+        }
+
+        if (options.TopK is { } topK)
+        {
+            generatorParams.SetSearchOption("top_k", topK);
+        }
+
+        if (options.Seed is { } seed)
+        {
+            generatorParams.SetSearchOption("random_seed", seed);
+        }
+
+        if (options.AdditionalProperties is { } props)
+        {
+            // Numeric and boolean extras are forwarded as search options under their own key;
+            // values of any other type are ignored.
+            foreach (var entry in props)
+            {
+                switch (entry.Value)
+                {
+                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
+                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
+                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
+                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
+                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/csharp/ChatClientConfiguration.cs b/src/csharp/ChatClientConfiguration.cs
new file mode 100644
index 000000000..282ae9362
--- /dev/null
+++ b/src/csharp/ChatClientConfiguration.cs
@@ -0,0 +1,73 @@
+using Microsoft.Extensions.AI;
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.OnnxRuntimeGenAI;
+
+/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
+/// <remarks>
+/// <para>
+/// Every model has different requirements for stop sequences and prompt formatting. For best results,
+/// the configuration should be tailored to the exact nature of the model being used. For example,
+/// when using a Phi3 model, a configuration like the following may be used:
+/// </para>
+/// <code>
+/// static ChatClientConfiguration CreateForPhi3() =>
+///     new(["&lt;|system|&gt;", "&lt;|user|&gt;", "&lt;|assistant|&gt;", "&lt;|end|&gt;"],
+///         (IEnumerable&lt;ChatMessage&gt; messages) =>
+///         {
+///             StringBuilder prompt = new();
+///
+///             foreach (var message in messages)
+///                 foreach (var content in message.Contents.OfType&lt;TextContent&gt;())
+///                     prompt.Append("&lt;|").Append(message.Role.Value).Append("|&gt;\n").Append(content.Text).Append("&lt;|end|&gt;\n");
+///
+///             return prompt.Append("&lt;|assistant|&gt;\n").ToString();
+///         });
+/// </code>
+/// </remarks>
+public sealed class ChatClientConfiguration
+{
+    /// <summary>Backing field for <see cref="StopSequences"/>; never null once constructed.</summary>
+    private string[] _stopSequences;
+    /// <summary>Backing field for <see cref="PromptFormatter"/>; never null once constructed.</summary>
+    private Func<IEnumerable<ChatMessage>, string> _promptFormatter;
+
+    /// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
+    /// <param name="stopSequences">The stop sequences used by the model.</param>
+    /// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="stopSequences"/> is null.</exception>
+    /// <exception cref="ArgumentNullException"><paramref name="promptFormatter"/> is null.</exception>
+    public ChatClientConfiguration(
+        string[] stopSequences,
+        Func<IEnumerable<ChatMessage>, string> promptFormatter)
+    {
+        // Validate here rather than relying on the property setters so that the exception
+        // names the constructor parameter instead of "value".
+        if (stopSequences is null)
+        {
+            throw new ArgumentNullException(nameof(stopSequences));
+        }
+
+        if (promptFormatter is null)
+        {
+            throw new ArgumentNullException(nameof(promptFormatter));
+        }
+
+        StopSequences = stopSequences;
+        PromptFormatter = promptFormatter;
+    }
+
+    /// <summary>
+    /// Gets or sets stop sequences to use during generation.
+    /// </summary>
+    /// <remarks>
+    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>.
+    /// </remarks>
+    /// <exception cref="ArgumentNullException">The value being set is null.</exception>
+    public string[] StopSequences
+    {
+        get => _stopSequences;
+        set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
+    }
+
+    /// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
+    /// <exception cref="ArgumentNullException">The value being set is null.</exception>
+    public Func<IEnumerable<ChatMessage>, string> PromptFormatter
+    {
+        get => _promptFormatter;
+        set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
+    }
+}
\ No newline at end of file
diff --git a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
index ee53c83fb..83f729a7e 100644
--- a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
+++ b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
@@ -121,4 +121,8 @@
+
+
+
+
diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index ddced9e42..9480d0d84 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -5,9 +5,11 @@
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
-using System.Runtime.CompilerServices;
using Xunit;
using Xunit.Abstractions;
+using System.Collections.Generic;
+using Microsoft.Extensions.AI;
+using System.Text;
namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
{
@@ -349,6 +351,32 @@ public void TestTopKTopPSearch()
}
}
+ [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")]
+ // Returns Task rather than void so the test runner can await the test and observe any failure.
+ public async System.Threading.Tasks.Task TestChatClient()
+ {
+     // The configuration supplies the stop sequences and the chat-template prompt formatter
+     // that the client applies to the message history.
+     using var client = new ChatClient(
+         _phi2Path,
+         new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
+             (IEnumerable<ChatMessage> messages) =>
+             {
+                 StringBuilder prompt = new();
+
+                 foreach (var message in messages)
+                 {
+                     foreach (var content in message.Contents.OfType<TextContent>())
+                     {
+                         prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");
+                     }
+                 }
+
+                 return prompt.Append("<|assistant|>\n").ToString();
+             }));
+
+     // A small output budget keeps the test fast.
+     var completion = await client.CompleteAsync("What is 2 + 3?", new()
+     {
+         MaxOutputTokens = 20,
+         Temperature = 0f,
+     });
+
+     Assert.Contains("5", completion.ToString());
+ }
+
[IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]
public void TestTokenizerBatchEncodeDecode()
{