diff --git a/packages/tasks/src/tasks/text-generation/inference.ts b/packages/tasks/src/tasks/text-generation/inference.ts index 7dd54ab9e..b3c002c6e 100644 --- a/packages/tasks/src/tasks/text-generation/inference.ts +++ b/packages/tasks/src/tasks/text-generation/inference.ts @@ -3,6 +3,7 @@ * * Using src/scripts/inference-codegen */ + /** * Inputs for Text Generation inference */ @@ -17,6 +18,7 @@ export interface TextGenerationInput { parameters?: TextGenerationParameters; [property: string]: unknown; } + /** * Additional inference parameters * @@ -40,10 +42,23 @@ export interface TextGenerationParameters { * Whether to use logits sampling instead of greedy decoding when generating new tokens. */ do_sample?: boolean; + /** + * The maximum number of tokens to generate. + */ + max_new_tokens?: number; + /** + * The parameter for repetition penalty. A value of 1.0 means no penalty. See [this + * paper](https://hf.co/papers/1909.05858) for more details. + */ + repetition_penalty?: number; /** * Whether to prepend the prompt to the generated text. */ return_full_text?: boolean; + /** + * The random sampling seed. + */ + seed?: number; /** * Stop generating tokens if a member of `stop_sequences` is generated. */ @@ -76,10 +91,30 @@ export interface TextGenerationParameters { watermark?: boolean; [property: string]: unknown; } + +/** + * Outputs for Text Generation inference + */ +export interface TextGenerationOutput { + /** + * When enabled, details about the generation + */ + details?: TextGenerationOutputDetails; + /** + * The generated text + */ + generated_text: string; + [property: string]: unknown; +} + /** * When enabled, details about the generation */ export interface TextGenerationOutputDetails { + /** + * Details about additional sequences when best_of is provided + */ + best_of_sequences?: TextGenerationSequenceDetails[]; /** * The reason why the generation was stopped. 
*/ @@ -99,6 +134,32 @@ export interface TextGenerationOutputDetails { tokens: Token[]; [property: string]: unknown; } + +export interface TextGenerationSequenceDetails { + /** + * The reason why the generation was stopped. + */ + finish_reason: FinishReason; + /** + * The generated text + */ + generated_text: string; + /** + * The number of generated tokens + */ + generated_tokens: number; + prefill: PrefillToken[]; + /** + * The random seed used for generation + */ + seed?: number; + /** + * The generated tokens and associated details + */ + tokens: Token[]; + [property: string]: unknown; +} + /** * The generated sequence reached the maximum allowed length * @@ -107,6 +168,7 @@ export interface TextGenerationOutputDetails { * One of the sequence in stop_sequences was generated */ export type FinishReason = "length" | "eos_token" | "stop_sequence"; + export interface PrefillToken { id: number; logprob: number; @@ -116,6 +178,7 @@ export interface PrefillToken { text: string; [property: string]: unknown; } + export interface Token { id: number; logprob: number; @@ -129,18 +192,3 @@ export interface Token { text: string; [property: string]: unknown; } -export type TextGenerationOutput = TextGenerationOutputElement[]; -/** - * Outputs for Text Generation inference - */ -export interface TextGenerationOutputElement { - /** - * When enabled, details about the generation - */ - details?: TextGenerationOutputDetails; - /** - * The generated text - */ - generated_text: string; - [property: string]: unknown; -} diff --git a/packages/tasks/src/tasks/text-generation/spec/input.json b/packages/tasks/src/tasks/text-generation/spec/input.json index f759d5f00..c40d8aab1 100644 --- a/packages/tasks/src/tasks/text-generation/spec/input.json +++ b/packages/tasks/src/tasks/text-generation/spec/input.json @@ -36,10 +36,22 @@ "type": "boolean", "description": "Whether to use logits sampling instead of greedy decoding when generating new tokens." 
}, + "max_new_tokens": { + "type": "integer", + "description": "The maximum number of tokens to generate." + }, + "repetition_penalty": { + "type": "number", + "description": "The parameter for repetition penalty. A value of 1.0 means no penalty. See [this paper](https://hf.co/papers/1909.05858) for more details." + }, "return_full_text": { "type": "boolean", "description": "Whether to prepend the prompt to the generated text." }, + "seed": { + "type": "integer", + "description": "The random sampling seed." + }, "stop_sequences": { "type": "array", "items": { diff --git a/packages/tasks/src/tasks/text-generation/spec/output.json b/packages/tasks/src/tasks/text-generation/spec/output.json index 85e472121..f6d1882c6 100644 --- a/packages/tasks/src/tasks/text-generation/spec/output.json +++ b/packages/tasks/src/tasks/text-generation/spec/output.json @@ -3,69 +3,45 @@ "$schema": "http://json-schema.org/draft-06/schema#", "description": "Outputs for Text Generation inference", "title": "TextGenerationOutput", - "type": "array", - "items": { - "type": "object", - "properties": { - "generated_text": { - "type": "string", - "description": "The generated text" - }, - "details": { - "type": "object", - "description": "When enabled, details about the generation", - "title": "TextGenerationOutputDetails", - "properties": { - "generated_tokens": { - "type": "integer", - "description": "The number of generated tokens" - }, - "finish_reason": { - "type": "string", - "description": "The reason why the generation was stopped.", - "oneOf": [ - { "const": "length", "description": "The generated sequence reached the maximum allowed length" }, - { "const": "eos_token", "description": "The model generated an end-of-sentence (EOS) token" }, - { "const": "stop_sequence", "description": "One of the sequence in stop_sequences was generated" } - ] - }, - "seed": { - "type": "integer", - "description": "The random seed used for generation" - }, - "tokens": { - "type": "array", - 
"description": "The generated tokens and associated details", - "items": { - "$ref": "#/$defs/Token" - } - }, - "prefill": { - "type": "array", - "items": { - "title": "PrefillToken", - "type": "object", - "properties": { - "id": { - "type": "integer" - }, - "logprob": { - "type": "number" - }, - "text": { - "type": "string", - "description": "The text associated with that token" - } - }, - "required": ["id", "logprob", "text"] + "type": "object", + "properties": { + "generated_text": { + "type": "string", + "description": "The generated text" + }, + "details": { + "description": "When enabled, details about the generation", + "title": "TextGenerationOutputDetails", + "allOf": [ + { + "type": "object", + "properties": { + "best_of_sequences": { + "type": "array", + "description": "Details about additional sequences when best_of is provided", + "items": { + "allOf": [ + { "$ref": "#/$defs/SequenceDetails" }, + { + "type": "object", + "properties": { + "generated_text": { + "type": "string", + "description": "The generated text" + } + }, + "required": ["generated_text"] + } + ] + } } } }, - "required": ["generated_tokens", "finish_reason", "prefill", "tokens"] - } - }, - "required": ["generated_text"] + { "$ref": "#/$defs/SequenceDetails" } + ] + } }, + "required": ["generated_text"], "$defs": { "Token": { @@ -88,6 +64,57 @@ } }, "required": ["id", "logprob", "special", "text"] + }, + "SequenceDetails": { + "type": "object", + "title": "TextGenerationSequenceDetails", + "properties": { + "finish_reason": { + "type": "string", + "description": "The reason why the generation was stopped.", + "oneOf": [ + { "const": "length", "description": "The generated sequence reached the maximum allowed length" }, + { "const": "eos_token", "description": "The model generated an end-of-sentence (EOS) token" }, + { "const": "stop_sequence", "description": "One of the sequence in stop_sequences was generated" } + ] + }, + "generated_tokens": { + "type": "integer", + "description": 
"The number of generated tokens" + }, + "prefill": { + "type": "array", + "items": { + "title": "PrefillToken", + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "logprob": { + "type": "number" + }, + "text": { + "type": "string", + "description": "The text associated with that token" + } + }, + "required": ["id", "logprob", "text"] + } + }, + "seed": { + "type": "integer", + "description": "The random seed used for generation" + }, + "tokens": { + "type": "array", + "description": "The generated tokens and associated details", + "items": { + "$ref": "#/$defs/Token" + } + } + }, + "required": ["finish_reason", "generated_tokens", "prefill", "tokens"] } } }