From ea2d4711685ddc42e966b725e29e62d7a4de2642 Mon Sep 17 00:00:00 2001 From: Simon Brandeis <33657802+SBrandeis@users.noreply.github.com> Date: Mon, 12 Feb 2024 12:01:18 +0100 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20[Tasks]=20JSON=20spec:=20t?= =?UTF-8?q?ext-generation=20(#468)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # TL;DR - Update `text-generation` spec to match TGI API - ~Add `conversational` spec, heavily inspired by TGI messages API (cc @Wauplin @osanseviero @Narsil )~ - ~Relevant related work: https://github.com/huggingface/huggingface.js/pull/457 & https://github.com/huggingface/moon-landing/pull/8723~ - regenerate typescript code for those tasks --- .../src/tasks/text-generation/inference.ts | 114 +++++++++++++++++- .../src/tasks/text-generation/spec/input.json | 20 ++- .../tasks/text-generation/spec/output.json | 108 ++++++++++++++++- 3 files changed, 235 insertions(+), 7 deletions(-) diff --git a/packages/tasks/src/tasks/text-generation/inference.ts b/packages/tasks/src/tasks/text-generation/inference.ts index 0f0d1e8d7..b3c002c6e 100644 --- a/packages/tasks/src/tasks/text-generation/inference.ts +++ b/packages/tasks/src/tasks/text-generation/inference.ts @@ -26,11 +26,24 @@ export interface TextGenerationInput { */ export interface TextGenerationParameters { /** - * Whether to use logit sampling (true) or greedy search (false). + * The number of sampling queries to run. Only the best one (in terms of total logprob) will + * be returned. + */ + best_of?: number; + /** + * Whether or not to output decoder input details + */ + decoder_input_details?: boolean; + /** + * Whether or not to output details + */ + details?: boolean; + /** + * Whether to use logits sampling instead of greedy decoding when generating new tokens. */ do_sample?: boolean; /** - * Maximum number of generated tokens. + * The maximum number of tokens to generate. 
*/ max_new_tokens?: number; /** @@ -42,6 +55,10 @@ export interface TextGenerationParameters { * Whether to prepend the prompt to the generated text. */ return_full_text?: boolean; + /** + * The random sampling seed. + */ + seed?: number; /** * Stop generating tokens if a member of `stop_sequences` is generated. */ @@ -79,10 +96,99 @@ export interface TextGenerationParameters { * Outputs for Text Generation inference */ export interface TextGenerationOutput { - generatedText: unknown; + /** + * When enabled, details about the generation + */ + details?: TextGenerationOutputDetails; + /** + * The generated text + */ + generated_text: string; + [property: string]: unknown; +} + +/** + * When enabled, details about the generation + */ +export interface TextGenerationOutputDetails { + /** + * Details about additional sequences when best_of is provided + */ + best_of_sequences?: TextGenerationSequenceDetails[]; + /** + * The reason why the generation was stopped. + */ + finish_reason: FinishReason; + /** + * The number of generated tokens + */ + generated_tokens: number; + prefill: PrefillToken[]; + /** + * The random seed used for generation + */ + seed?: number; + /** + * The generated tokens and associated details + */ + tokens: Token[]; + [property: string]: unknown; +} + +export interface TextGenerationSequenceDetails { + /** + * The reason why the generation was stopped. 
+ */ + finish_reason: FinishReason; /** * The generated text */ - generated_text?: string; + generated_text: string; + /** + * The number of generated tokens + */ + generated_tokens: number; + prefill: PrefillToken[]; + /** + * The random seed used for generation + */ + seed?: number; + /** + * The generated tokens and associated details + */ + tokens: Token[]; + [property: string]: unknown; +} + +/** + * The generated sequence reached the maximum allowed length + * + * The model generated an end-of-sequence (EOS) token + * + * One of the sequences in stop_sequences was generated + */ +export type FinishReason = "length" | "eos_token" | "stop_sequence"; + +export interface PrefillToken { + id: number; + logprob: number; + /** + * The text associated with that token + */ + text: string; + [property: string]: unknown; +} + +export interface Token { + id: number; + logprob: number; + /** + * Whether or not that token is a special one + */ + special: boolean; + /** + * The text associated with that token + */ + text: string; [property: string]: unknown; } diff --git a/packages/tasks/src/tasks/text-generation/spec/input.json b/packages/tasks/src/tasks/text-generation/spec/input.json index 26fe24c8a..c40d8aab1 100644 --- a/packages/tasks/src/tasks/text-generation/spec/input.json +++ b/packages/tasks/src/tasks/text-generation/spec/input.json @@ -20,13 +20,25 @@ "description": "Additional inference parameters for Text Generation", "type": "object", "properties": { + "best_of": { + "type": "integer", + "description": "The number of sampling queries to run. Only the best one (in terms of total logprob) will be returned." + }, + "decoder_input_details": { + "type": "boolean", + "description": "Whether or not to output decoder input details" + }, + "details": { + "type": "boolean", + "description": "Whether or not to output details" + }, "do_sample": { "type": "boolean", - "description": "Whether to use logit sampling (true) or greedy search (false)." 
+ "description": "Whether to use logits sampling instead of greedy decoding when generating new tokens." }, "max_new_tokens": { "type": "integer", - "description": "Maximum number of generated tokens." + "description": "The maximum number of tokens to generate." }, "repetition_penalty": { "type": "number", @@ -36,6 +48,10 @@ "type": "boolean", "description": "Whether to prepend the prompt to the generated text." }, + "seed": { + "type": "integer", + "description": "The random sampling seed." + }, "stop_sequences": { "type": "array", "items": { diff --git a/packages/tasks/src/tasks/text-generation/spec/output.json b/packages/tasks/src/tasks/text-generation/spec/output.json index b38bc8be3..dc68a3aad 100644 --- a/packages/tasks/src/tasks/text-generation/spec/output.json +++ b/packages/tasks/src/tasks/text-generation/spec/output.json @@ -8,7 +8,113 @@ "generated_text": { "type": "string", "description": "The generated text" + }, + "details": { + "description": "When enabled, details about the generation", + "title": "TextGenerationOutputDetails", + "allOf": [ + { "$ref": "#/$defs/SequenceDetails" }, + { + "type": "object", + "properties": { + "best_of_sequences": { + "type": "array", + "description": "Details about additional sequences when best_of is provided", + "items": { + "allOf": [ + { "$ref": "#/$defs/SequenceDetails" }, + { + "type": "object", + "properties": { + "generated_text": { + "type": "string", + "description": "The generated text" + } + }, + "required": ["generated_text"] + } + ] + } + } + } + } + ] } }, - "required": ["generatedText"] + "required": ["generated_text"], + + "$defs": { + "Token": { + "type": "object", + "title": "Token", + "properties": { + "id": { + "type": "integer" + }, + "logprob": { + "type": "number" + }, + "special": { + "type": "boolean", + "description": "Whether or not that token is a special one" + }, + "text": { + "type": "string", + "description": "The text associated with that token" + } + }, + "required": ["id", 
"logprob", "special", "text"] + }, + "SequenceDetails": { + "type": "object", + "title": "TextGenerationSequenceDetails", + "properties": { + "finish_reason": { + "type": "string", + "description": "The reason why the generation was stopped.", + "oneOf": [ + { "const": "length", "description": "The generated sequence reached the maximum allowed length" }, + { "const": "eos_token", "description": "The model generated an end-of-sequence (EOS) token" }, + { "const": "stop_sequence", "description": "One of the sequences in stop_sequences was generated" } + ] + }, + "generated_tokens": { + "type": "integer", + "description": "The number of generated tokens" + }, + "prefill": { + "type": "array", + "items": { + "title": "PrefillToken", + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "logprob": { + "type": "number" + }, + "text": { + "type": "string", + "description": "The text associated with that token" + } + }, + "required": ["id", "logprob", "text"] + } + }, + "seed": { + "type": "integer", + "description": "The random seed used for generation" + }, + "tokens": { + "type": "array", + "description": "The generated tokens and associated details", + "items": { + "$ref": "#/$defs/Token" + } + } + }, + "required": ["finish_reason", "generated_tokens", "prefill", "tokens"] + } + } }