text-generation: Add missing parameters + factorize
SBrandeis committed Feb 9, 2024
1 parent 3cad908 commit c8d0e9a
Showing 3 changed files with 161 additions and 74 deletions.
78 changes: 63 additions & 15 deletions packages/tasks/src/tasks/text-generation/inference.ts
@@ -3,6 +3,7 @@
*
* Using src/scripts/inference-codegen
*/

/**
* Inputs for Text Generation inference
*/
@@ -17,6 +18,7 @@ export interface TextGenerationInput {
parameters?: TextGenerationParameters;
[property: string]: unknown;
}

/**
* Additional inference parameters
*
@@ -40,10 +42,23 @@ export interface TextGenerationParameters {
* Whether to use logits sampling instead of greedy decoding when generating new tokens.
*/
do_sample?: boolean;
/**
* The maximum number of tokens to generate.
*/
max_new_tokens?: number;
/**
* The parameter for repetition penalty. A value of 1.0 means no penalty. See [this
* paper](https://hf.co/papers/1909.05858) for more details.
*/
repetition_penalty?: number;
/**
* Whether to prepend the prompt to the generated text.
*/
return_full_text?: boolean;
/**
* The random sampling seed.
*/
seed?: number;
/**
* Stop generating tokens if a member of `stop_sequences` is generated.
*/
@@ -76,10 +91,30 @@ export interface TextGenerationParameters {
watermark?: boolean;
[property: string]: unknown;
}

/**
* Outputs for Text Generation inference
*/
export interface TextGenerationOutput {
/**
* When enabled, details about the generation
*/
details?: TextGenerationOutputDetails;
/**
* The generated text
*/
generated_text: string;
[property: string]: unknown;
}

/**
* When enabled, details about the generation
*/
export interface TextGenerationOutputDetails {
/**
* Details about additional sequences when best_of is provided
*/
best_of_sequences?: TextGenerationSequenceDetails[];
/**
* The reason why the generation was stopped.
*/
@@ -99,6 +134,32 @@ export interface TextGenerationOutputDetails {
tokens: Token[];
[property: string]: unknown;
}

export interface TextGenerationSequenceDetails {
/**
* The reason why the generation was stopped.
*/
finish_reason: FinishReason;
/**
* The generated text
*/
generated_text: string;
/**
* The number of generated tokens
*/
generated_tokens: number;
prefill: PrefillToken[];
/**
* The random seed used for generation
*/
seed?: number;
/**
* The generated tokens and associated details
*/
tokens: Token[];
[property: string]: unknown;
}

/**
* The generated sequence reached the maximum allowed length
*
@@ -107,6 +168,7 @@ export interface TextGenerationOutputDetails {
* One of the sequences in stop_sequences was generated
*/
export type FinishReason = "length" | "eos_token" | "stop_sequence";
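
Since FinishReason is a closed union, a switch over it is exhaustive. A small hypothetical helper, not part of this file, illustrating the three cases:

// Hypothetical helper (sketch, not in this commit): exhaustive handling of FinishReason.
function describeFinish(reason: FinishReason): string {
	switch (reason) {
		case "length":
			return "the sequence hit the maximum allowed length";
		case "eos_token":
			return "the model emitted an end-of-sentence (EOS) token";
		case "stop_sequence":
			return "a member of stop_sequences was generated";
	}
}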

export interface PrefillToken {
id: number;
logprob: number;
@@ -116,6 +178,7 @@ export interface PrefillToken {
text: string;
[property: string]: unknown;
}

export interface Token {
id: number;
logprob: number;
@@ -129,18 +192,3 @@ export interface Token {
text: string;
[property: string]: unknown;
}
export type TextGenerationOutput = TextGenerationOutputElement[];
/**
* Outputs for Text Generation inference
*/
export interface TextGenerationOutputElement {
/**
* When enabled, details about the generation
*/
details?: TextGenerationOutputDetails;
/**
* The generated text
*/
generated_text: string;
[property: string]: unknown;
}
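
Putting the updated types together: a minimal usage sketch, assuming the collapsed region of TextGenerationInput declares the usual `inputs: string` field. The endpoint URL and fetch plumbing are illustrative assumptions, not part of this package.

import type { TextGenerationInput, TextGenerationOutput } from "./inference";

// Sketch only: the endpoint URL below is a placeholder assumption.
async function generate(prompt: string): Promise<TextGenerationOutput> {
	const input: TextGenerationInput = {
		inputs: prompt,
		parameters: {
			do_sample: true, // pre-existing parameter
			max_new_tokens: 64, // new: cap on the number of generated tokens
			repetition_penalty: 1.2, // new: >1.0 discourages repetition
			return_full_text: false, // new: do not prepend the prompt
			seed: 42, // new: reproducible sampling
		},
	};
	const response = await fetch("https://example.invalid/generate", {
		method: "POST",
		headers: { "Content-Type": "application/json" },
		body: JSON.stringify(input),
	});
	return (await response.json()) as TextGenerationOutput;
}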
12 changes: 12 additions & 0 deletions packages/tasks/src/tasks/text-generation/spec/input.json
@@ -36,10 +36,22 @@
"type": "boolean",
"description": "Whether to use logits sampling instead of greedy decoding when generating new tokens."
},
"max_new_tokens": {
"type": "integer",
"description": "The maximum number of tokens to generate."
},
"repetition_penalty": {
"type": "number",
"description": "The parameter for repetition penalty. A value of 1.0 means no penalty. See [this paper](https://hf.co/papers/1909.05858) for more details."
},
"return_full_text": {
"type": "boolean",
"description": "Whether to prepend the prompt to the generated text."
},
"seed": {
"type": "integer",
"description": "The random sampling seed."
},
"stop_sequences": {
"type": "array",
"items": {
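
For reference, a payload exercising every parameter added to input.json. The values are illustrative; only the parameter names and types come from the schema, and `inputs` is assumed from the unchanged portion of the spec.

import type { TextGenerationInput } from "../inference";

// Illustrative request body; values are made up.
const payload: TextGenerationInput = {
	inputs: "Once upon a time",
	parameters: {
		max_new_tokens: 128, // integer, per the schema
		repetition_penalty: 1.15, // number; 1.0 means no penalty
		return_full_text: true, // echo the prompt back in the output
		seed: 1234, // integer sampling seed
	},
};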
145 changes: 86 additions & 59 deletions packages/tasks/src/tasks/text-generation/spec/output.json
@@ -3,69 +3,45 @@
"$schema": "http://json-schema.org/draft-06/schema#",
"description": "Outputs for Text Generation inference",
"title": "TextGenerationOutput",
"type": "array",
"items": {
"type": "object",
"properties": {
"generated_text": {
"type": "string",
"description": "The generated text"
},
"details": {
"type": "object",
"description": "When enabled, details about the generation",
"title": "TextGenerationOutputDetails",
"properties": {
"generated_tokens": {
"type": "integer",
"description": "The number of generated tokens"
},
"finish_reason": {
"type": "string",
"description": "The reason why the generation was stopped.",
"oneOf": [
{ "const": "length", "description": "The generated sequence reached the maximum allowed length" },
{ "const": "eos_token", "description": "The model generated an end-of-sentence (EOS) token" },
{ "const": "stop_sequence", "description": "One of the sequence in stop_sequences was generated" }
]
},
"seed": {
"type": "integer",
"description": "The random seed used for generation"
},
"tokens": {
"type": "array",
"description": "The generated tokens and associated details",
"items": {
"$ref": "#/$defs/Token"
}
},
"prefill": {
"type": "array",
"items": {
"title": "PrefillToken",
"type": "object",
"properties": {
"id": {
"type": "integer"
},
"logprob": {
"type": "number"
},
"text": {
"type": "string",
"description": "The text associated with that token"
}
},
"required": ["id", "logprob", "text"]
"type": "object",
"properties": {
"generated_text": {
"type": "string",
"description": "The generated text"
},
"details": {
"description": "When enabled, details about the generation",
"title": "TextGenerationOutputDetails",
"allOf": [
{
"type": "object",
"properties": {
"best_of_sequences": {
"type": "array",
"description": "Details about additional sequences when best_of is provided",
"items": {
"allOf": [
{ "$ref": "#/$defs/SequenceDetails" },
{
"type": "object",
"properties": {
"generated_text": {
"type": "integer",
"description": "The generated text"
}
},
"required": ["generated_text"]
}
]
}
}
}
},
"required": ["generated_tokens", "finish_reason", "prefill", "tokens"]
}
},
"required": ["generated_text"]
{ "$ref": "#/$defs/SequenceDetails" }
]
}
},
"required": ["generated_text"],

"$defs": {
"Token": {
@@ -88,6 +64,57 @@
}
},
"required": ["id", "logprob", "special", "text"]
},
"SequenceDetails": {
"type": "object",
"title": "TextGenerationSequenceDetails",
"properties": {
"finish_reason": {
"type": "string",
"description": "The reason why the generation was stopped.",
"oneOf": [
{ "const": "length", "description": "The generated sequence reached the maximum allowed length" },
{ "const": "eos_token", "description": "The model generated an end-of-sentence (EOS) token" },
{ "const": "stop_sequence", "description": "One of the sequence in stop_sequences was generated" }
]
},
"generated_tokens": {
"type": "integer",
"description": "The number of generated tokens"
},
"prefill": {
"type": "array",
"items": {
"title": "PrefillToken",
"type": "object",
"properties": {
"id": {
"type": "integer"
},
"logprob": {
"type": "number"
},
"text": {
"type": "string",
"description": "The text associated with that token"
}
},
"required": ["id", "logprob", "text"]
}
},
"seed": {
"type": "integer",
"description": "The random seed used for generation"
},
"tokens": {
"type": "array",
"description": "The generated tokens and associated details",
"items": {
"$ref": "#/$defs/Token"
}
}
},
"required": ["finish_reason", "generated_tokens", "prefill", "tokens"]
}
}
}
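
Because the factorized schema makes the top-level details a superset of SequenceDetails, a consumer can treat the primary generation and any best_of_sequences alternatives uniformly. A hypothetical helper sketch, assuming the collapsed portion of TextGenerationOutputDetails declares the same finish_reason, generated_tokens, prefill, and seed fields that the allOf reuse implies:

import type { TextGenerationOutput, TextGenerationSequenceDetails } from "../inference";

// Collect every finished sequence from a response: the primary one plus
// any alternatives reported under best_of_sequences.
function allSequences(output: TextGenerationOutput): TextGenerationSequenceDetails[] {
	const details = output.details;
	if (!details) return [];
	// The top-level details mirror SequenceDetails, with generated_text
	// living on the output object itself.
	const primary: TextGenerationSequenceDetails = {
		finish_reason: details.finish_reason,
		generated_text: output.generated_text,
		generated_tokens: details.generated_tokens,
		prefill: details.prefill,
		seed: details.seed,
		tokens: details.tokens,
	};
	return [primary, ...(details.best_of_sequences ?? [])];
}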
