-
Notifications
You must be signed in to change notification settings - Fork 269
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Inference] compatibility with third-party Inference providers #1077
Changes from 23 commits
21fc18c
f327375
4526b56
5e1ee1f
a2e7872
f1961a8
a413824
5629b86
d96a18e
aa7d1ca
238567a
43f739f
8155828
d8e01ab
884f931
e801bfa
01ccbf6
104f672
3528b7c
7e634bc
7c7a4ff
3939438
1b71e5b
841efab
e3c5264
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,17 @@ | ||
import type { InferenceTask, Options, RequestArgs } from "../types"; | ||
import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai"; | ||
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate"; | ||
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova"; | ||
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together"; | ||
import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types"; | ||
import { omit } from "../utils/omit"; | ||
import { HF_HUB_URL } from "./getDefaultTask"; | ||
import { isUrl } from "./isUrl"; | ||
|
||
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; | ||
|
||
/** | ||
* Loaded from huggingface.co/api/tasks if needed | ||
* Lazy-loaded from huggingface.co/api/tasks when needed | ||
* Used to determine the default model to use when it's not user defined | ||
*/ | ||
let tasks: Record<string, { models: { id: string }[] }> | null = null; | ||
|
||
|
@@ -26,21 +31,14 @@ export async function makeRequestOptions( | |
chatCompletion?: boolean; | ||
} | ||
): Promise<{ url: string; info: RequestInit }> { | ||
const { accessToken, endpointUrl, ...otherArgs } = args; | ||
const { accessToken, endpointUrl, provider, ...otherArgs } = args; | ||
let { model } = args; | ||
const { | ||
forceTask: task, | ||
includeCredentials, | ||
taskHint, | ||
wait_for_model, | ||
use_cache, | ||
dont_load_model, | ||
chatCompletion, | ||
} = options ?? {}; | ||
const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } = | ||
options ?? {}; | ||
|
||
const headers: Record<string, string> = {}; | ||
if (accessToken) { | ||
headers["Authorization"] = `Bearer ${accessToken}`; | ||
headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`; | ||
} | ||
|
||
if (!model && !tasks && taskHint) { | ||
|
@@ -61,6 +59,35 @@ export async function makeRequestOptions( | |
if (!model) { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
} | ||
if (provider) { | ||
if (!INFERENCE_PROVIDERS.includes(provider)) { | ||
throw new Error("Unknown Inference provider"); | ||
} | ||
if (!accessToken) { | ||
throw new Error("Specifying an Inference provider requires an accessToken"); | ||
} | ||
|
||
const modelId = (() => { | ||
switch (provider) { | ||
case "replicate": | ||
return REPLICATE_MODEL_IDS[model]; | ||
case "sambanova": | ||
return SAMBANOVA_MODEL_IDS[model]; | ||
case "together": | ||
return TOGETHER_MODEL_IDS[model]?.id; | ||
case "fal-ai": | ||
return FAL_AI_MODEL_IDS[model]; | ||
default: | ||
return model; | ||
} | ||
Comment on lines
+71
to
+81
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should maybe at least Note that we could maybe maintain a mapping in the backend and in case of errors try to load it -only once (like we do for default models associated to tasks). Just a thought for the future, but it would enable new mappings without updating the lib. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
cc @julien-c - we discussed it and decided to stick to HF model IDs for now
Yes, we definitely want a way for 3rd party providers to expose the mapping There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes. simpler to always be "Hub-centric" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah yes and i now remember, the |
||
})(); | ||
|
||
if (!modelId) { | ||
throw new Error(`Model ${model} is not supported for provider ${provider}`); | ||
} | ||
|
||
model = modelId; | ||
} | ||
|
||
const binary = "data" in args && !!args.data; | ||
|
||
|
@@ -77,6 +104,9 @@ export async function makeRequestOptions( | |
if (dont_load_model) { | ||
headers["X-Load-Model"] = "0"; | ||
} | ||
if (provider === "replicate") { | ||
headers["Prefer"] = "wait"; | ||
} | ||
|
||
let url = (() => { | ||
if (endpointUrl && isUrl(model)) { | ||
|
@@ -89,8 +119,33 @@ export async function makeRequestOptions( | |
if (endpointUrl) { | ||
return endpointUrl; | ||
} | ||
if (task) { | ||
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`; | ||
if (forceTask) { | ||
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`; | ||
} | ||
if (provider) { | ||
if (!accessToken) { | ||
throw new Error("Specifying an Inference provider requires an accessToken"); | ||
} | ||
if (accessToken.startsWith("hf_")) { | ||
/// TODO we wil proxy the request server-side (using our own keys) and handle billing for it on the user's HF account. | ||
throw new Error("Inference proxying is not implemented yet"); | ||
} else { | ||
Comment on lines
+130
to
+131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
switch (provider) { | ||
case "fal-ai": | ||
return `${FAL_AI_API_BASE_URL}/${model}`; | ||
case "replicate": | ||
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`; | ||
case "sambanova": | ||
return SAMBANOVA_API_BASE_URL; | ||
case "together": | ||
if (taskHint === "text-to-image") { | ||
return `${TOGETHER_API_BASE_URL}/v1/images/generations`; | ||
} | ||
return TOGETHER_API_BASE_URL; | ||
default: | ||
break; | ||
} | ||
} | ||
} | ||
|
||
return `${HF_INFERENCE_API_BASE_URL}/models/${model}`; | ||
|
@@ -99,6 +154,9 @@ export async function makeRequestOptions( | |
if (chatCompletion && !url.endsWith("/chat/completions")) { | ||
url += "/v1/chat/completions"; | ||
} | ||
if (provider === "together" && taskHint === "text-generation" && !chatCompletion) { | ||
url += "/v1/completions"; | ||
} | ||
|
||
coyotte508 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/** | ||
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error | ||
|
@@ -116,9 +174,11 @@ export async function makeRequestOptions( | |
body: binary | ||
? args.data | ||
: JSON.stringify({ | ||
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs), | ||
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai" | ||
? omit(otherArgs, "model") | ||
: { ...otherArgs, model }), | ||
}), | ||
...(credentials && { credentials }), | ||
...(credentials ? { credentials } : undefined), | ||
signal: options?.signal, | ||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import type { ModelId } from "../types"; | ||
|
||
export const FAL_AI_API_BASE_URL = "https://fal.run"; | ||
|
||
type FalAiId = string; | ||
|
||
/** | ||
* Mapping from HF model ID -> fal.ai app id | ||
*/ | ||
export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = { | ||
/** text-to-image */ | ||
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell", | ||
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev", | ||
|
||
/** automatic-speech-recognition */ | ||
"openai/whisper-large-v3": "fal-ai/whisper", | ||
}; | ||
Comment on lines
+10
to
+17
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Idk if it has been discussed somewhere but I feel that at some point these mappings should be maintained as an API lazy-loaded by the inference client (same as what we are doing for the tasks). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scroll up: #1077 (comment) :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🙈 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great minds think alike |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import type { ModelId } from "../types"; | ||
|
||
export const REPLICATE_API_BASE_URL = "https://api.replicate.com"; | ||
|
||
type ReplicateId = string; | ||
|
||
/** | ||
* Mapping from HF model ID -> Replicate model ID | ||
* | ||
* Available models can be fetched with: | ||
* ``` | ||
* curl -s \ | ||
* -H "Authorization: Bearer $REPLICATE_API_TOKEN" \ | ||
* 'https://api.replicate.com/v1/models' | ||
* ``` | ||
*/ | ||
export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = { | ||
/** text-to-image */ | ||
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell", | ||
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step", | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import type { ModelId } from "../types"; | ||
|
||
export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai"; | ||
|
||
/** | ||
* Note for reviewers: our goal would be to ask Sambanova to support | ||
* our model ids too, so we don't have to define a mapping | ||
* or keep it up-to-date. | ||
* | ||
* As a fallback, if the above is not possible, ask Sambanova to | ||
* provide the mapping as an fetchable API. | ||
*/ | ||
type SambanovaId = string; | ||
|
||
/** | ||
* https://community.sambanova.ai/t/supported-models/193 | ||
*/ | ||
export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = { | ||
/** Chat completion / conversational */ | ||
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", | ||
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct", | ||
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview", | ||
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct", | ||
"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct", | ||
"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct", | ||
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct", | ||
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct", | ||
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct", | ||
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct", | ||
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct", | ||
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B", | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import type { ModelId } from "../types"; | ||
|
||
export const TOGETHER_API_BASE_URL = "https://api.together.xyz"; | ||
|
||
/** | ||
* Same comment as in sambanova.ts | ||
*/ | ||
type TogetherId = string; | ||
|
||
/** | ||
* https://docs.together.ai/reference/models-1 | ||
*/ | ||
export const TOGETHER_MODEL_IDS: Record< | ||
ModelId, | ||
{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" } | ||
> = { | ||
/** text-to-image */ | ||
"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" }, | ||
"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" }, | ||
"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" }, | ||
"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" }, | ||
"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" }, | ||
"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" }, | ||
|
||
/** chat completion */ | ||
"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" }, | ||
"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" }, | ||
"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" }, | ||
"google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" }, | ||
"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" }, | ||
"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" }, | ||
"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" }, | ||
"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" }, | ||
"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" }, | ||
"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" }, | ||
"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" }, | ||
"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" }, | ||
"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" }, | ||
"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" }, | ||
"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" }, | ||
"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" }, | ||
"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" }, | ||
"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" }, | ||
"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" }, | ||
"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" }, | ||
"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" }, | ||
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" }, | ||
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" }, | ||
"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" }, | ||
"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" }, | ||
"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" }, | ||
"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" }, | ||
"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" }, | ||
"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" }, | ||
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" }, | ||
|
||
/** text-generation */ | ||
"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" }, | ||
"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" }, | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @xenova fyi