-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an IChatClient implementation to OnnxRuntimeGenAI
- Loading branch information
1 parent
c9ffcb9
commit e3beb9c
Showing
2 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,278 @@ | ||
using Microsoft.Extensions.AI; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Runtime.CompilerServices; | ||
using System.Text; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
|
||
namespace Microsoft.ML.OnnxRuntimeGenAI; | ||
|
||
/// <summary>An <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary> | ||
public sealed class ChatClient : IChatClient, IDisposable | ||
{ | ||
/// <summary>The wrapped <see cref="Model"/>.</summary> | ||
private readonly Model _model; | ||
/// <summary>The wrapped <see cref="Tokenizer"/>.</summary> | ||
private readonly Tokenizer _tokenizer; | ||
/// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary> | ||
private readonly bool _ownsModel; | ||
|
||
/// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary> | ||
/// <param name="modelPath">The file path to the model to load.</param> | ||
/// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception> | ||
public ChatClient(string modelPath) | ||
{ | ||
if (modelPath is null) | ||
{ | ||
throw new ArgumentNullException(nameof(modelPath)); | ||
} | ||
|
||
_ownsModel = true; | ||
_model = new Model(modelPath); | ||
_tokenizer = new Tokenizer(_model); | ||
|
||
Metadata = new(typeof(ChatClient).Namespace, new Uri($"file://{modelPath}"), modelPath); | ||
} | ||
|
||
/// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary> | ||
/// <param name="model">The model to employ.</param> | ||
/// <param name="ownsModel"> | ||
/// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should | ||
/// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>. | ||
/// The default is <see langword="true"/>. | ||
/// </param> | ||
/// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception> | ||
public ChatClient(Model model, bool ownsModel = true) | ||
{ | ||
if (model is null) | ||
{ | ||
throw new ArgumentNullException(nameof(model)); | ||
} | ||
|
||
_ownsModel = ownsModel; | ||
_model = model; | ||
_tokenizer = new Tokenizer(_model); | ||
|
||
Metadata = new("Microsoft.ML.OnnxRuntimeGenAI"); | ||
} | ||
|
||
/// <inheritdoc/> | ||
public ChatClientMetadata Metadata { get; } | ||
|
||
/// <summary> | ||
/// Gets or sets stop sequences to use during generation. | ||
/// </summary> | ||
/// <remarks> | ||
/// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>. | ||
/// </remarks> | ||
public IList<string> StopSequences { get; set; } = | ||
[ | ||
// Default stop sequences based on Phi3 | ||
"<|system|>", | ||
"<|user|>", | ||
"<|assistant|>", | ||
"<|end|>" | ||
]; | ||
|
||
/// <summary> | ||
/// Gets or sets a function that creates a prompt string from the chat history. | ||
/// </summary> | ||
public Func<IEnumerable<ChatMessage>, string> PromptFormatter { get; set; } | ||
|
||
/// <inheritdoc/> | ||
public void Dispose() | ||
{ | ||
_tokenizer.Dispose(); | ||
|
||
if (_ownsModel) | ||
{ | ||
_model.Dispose(); | ||
} | ||
} | ||
|
||
/// <inheritdoc/> | ||
public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default) | ||
{ | ||
if (chatMessages is null) | ||
{ | ||
throw new ArgumentNullException(nameof(chatMessages)); | ||
} | ||
|
||
StringBuilder text = new(); | ||
await Task.Run(() => | ||
{ | ||
using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages)); | ||
using GeneratorParams generatorParams = new(_model); | ||
UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options); | ||
generatorParams.SetInputSequences(tokens); | ||
|
||
using Generator generator = new(_model, generatorParams); | ||
using var tokenizerStream = _tokenizer.CreateStream(); | ||
|
||
var completionId = Guid.NewGuid().ToString(); | ||
while (!generator.IsDone()) | ||
{ | ||
cancellationToken.ThrowIfCancellationRequested(); | ||
|
||
generator.ComputeLogits(); | ||
generator.GenerateNextToken(); | ||
|
||
ReadOnlySpan<int> outputSequence = generator.GetSequence(0); | ||
string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]); | ||
|
||
if (IsStop(next, options)) | ||
{ | ||
break; | ||
} | ||
|
||
text.Append(next); | ||
} | ||
}, cancellationToken); | ||
|
||
return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString())) | ||
{ | ||
CompletionId = Guid.NewGuid().ToString(), | ||
CreatedAt = DateTimeOffset.UtcNow, | ||
ModelId = Metadata.ModelId, | ||
}; | ||
} | ||
|
||
/// <inheritdoc/> | ||
public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync( | ||
IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | ||
{ | ||
if (chatMessages is null) | ||
{ | ||
throw new ArgumentNullException(nameof(chatMessages)); | ||
} | ||
|
||
using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages)); | ||
using GeneratorParams generatorParams = new(_model); | ||
UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options); | ||
generatorParams.SetInputSequences(tokens); | ||
|
||
using Generator generator = new(_model, generatorParams); | ||
using var tokenizerStream = _tokenizer.CreateStream(); | ||
|
||
var completionId = Guid.NewGuid().ToString(); | ||
while (!generator.IsDone()) | ||
{ | ||
string next = await Task.Run(() => | ||
{ | ||
generator.ComputeLogits(); | ||
generator.GenerateNextToken(); | ||
|
||
ReadOnlySpan<int> outputSequence = generator.GetSequence(0); | ||
return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]); | ||
}, cancellationToken); | ||
|
||
if (IsStop(next, options)) | ||
{ | ||
break; | ||
} | ||
|
||
yield return new StreamingChatCompletionUpdate | ||
{ | ||
CompletionId = completionId, | ||
CreatedAt = DateTimeOffset.UtcNow, | ||
Role = ChatRole.Assistant, | ||
Text = next, | ||
}; | ||
} | ||
} | ||
|
||
/// <inheritdoc/> | ||
public TService GetService<TService>(object key = null) where TService : class => | ||
typeof(TService) == typeof(Model) ? (TService)(object)_model : | ||
typeof(TService) == typeof(Tokenizer) ? (TService)(object)_tokenizer : | ||
this as TService; | ||
|
||
/// <summary>Gets whether the specified token is a stop sequence.</summary> | ||
private bool IsStop(string token, ChatOptions options) => | ||
options?.StopSequences?.Contains(token) is true || | ||
StopSequences?.Contains(token) is true; | ||
|
||
/// <summary>Creates a prompt string from the supplied chat history.</summary> | ||
private string CreatePrompt(IEnumerable<ChatMessage> messages) | ||
{ | ||
if (messages is null) | ||
{ | ||
throw new ArgumentNullException(nameof(messages)); | ||
} | ||
|
||
if (PromptFormatter is not null) | ||
{ | ||
return PromptFormatter(messages) ?? string.Empty; | ||
} | ||
|
||
// Default formatting based on Phi3. | ||
StringBuilder prompt = new(); | ||
|
||
foreach (var message in messages) | ||
{ | ||
foreach (var content in message.Contents) | ||
{ | ||
switch (content) | ||
{ | ||
case TextContent tc when !string.IsNullOrWhiteSpace(tc.Text): | ||
prompt.Append("<|").Append(message.Role.Value).Append("|>\n") | ||
.Append(tc.Text.Replace("<|end|>\n", "")) | ||
.Append("<|end|>\n"); | ||
break; | ||
} | ||
} | ||
} | ||
|
||
prompt.Append("<|assistant|>"); | ||
|
||
return prompt.ToString(); | ||
} | ||
|
||
/// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary> | ||
private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options) | ||
{ | ||
if (options is null) | ||
{ | ||
return; | ||
} | ||
|
||
if (options.MaxOutputTokens.HasValue) | ||
{ | ||
generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value); | ||
} | ||
|
||
if (options.Temperature.HasValue) | ||
{ | ||
generatorParams.SetSearchOption("temperature", options.Temperature.Value); | ||
} | ||
|
||
if (options.TopP.HasValue || options.TopK.HasValue) | ||
{ | ||
if (options.TopP.HasValue) | ||
{ | ||
generatorParams.SetSearchOption("top_p", options.TopP.Value); | ||
} | ||
|
||
if (options.TopK.HasValue) | ||
{ | ||
generatorParams.SetSearchOption("top_k", options.TopK.Value); | ||
} | ||
} | ||
|
||
if (options.AdditionalProperties is { } props) | ||
{ | ||
foreach (var entry in props) | ||
{ | ||
switch (entry.Value) | ||
{ | ||
case int i: generatorParams.SetSearchOption(entry.Key, i); break; | ||
case long l: generatorParams.SetSearchOption(entry.Key, l); break; | ||
case float f: generatorParams.SetSearchOption(entry.Key, f); break; | ||
case double d: generatorParams.SetSearchOption(entry.Key, d); break; | ||
case bool b: generatorParams.SetSearchOption(entry.Key, b); break; | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters