-
Notifications
You must be signed in to change notification settings - Fork 269
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Inference] compatibility with third-party Inference providers #1077
Merged
Merged
Changes from 10 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
21fc18c
formatting
julien-c f327375
link to jinja in main README
julien-c 4526b56
link to DDUF
julien-c 5e1ee1f
Update LICENSE
julien-c a2e7872
Update README.md
julien-c f1961a8
implem
julien-c a413824
also stream
julien-c 5629b86
Together.ai implem
julien-c d96a18e
textToImage should work too
julien-c aa7d1ca
support for replicate
julien-c 238567a
add fal-ai as a provider
SBrandeis 43f739f
tweaks
SBrandeis 8155828
stringify JSON error when chat completing
SBrandeis d8e01ab
Update tests & VCR tapes
SBrandeis 884f931
add sample for asr on fal
SBrandeis e801bfa
update tapes
SBrandeis 01ccbf6
lint
SBrandeis 104f672
fix test + update tapes
SBrandeis 3528b7c
format + lint
SBrandeis 7e634bc
tests tests tests
SBrandeis 7c7a4ff
Merge branch 'main' into inference-providers
SBrandeis 3939438
add dummy keys to secrets
SBrandeis 1b71e5b
update tests & tapes
SBrandeis 841efab
fix textGeneration typo
SBrandeis e3c5264
fix tape for openai-community/gpt2 test
SBrandeis File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
import type { InferenceTask, Options, RequestArgs } from "../types"; | ||
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate"; | ||
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova"; | ||
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together"; | ||
import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types"; | ||
import { omit } from "../utils/omit"; | ||
import { HF_HUB_URL } from "./getDefaultTask"; | ||
import { isUrl } from "./isUrl"; | ||
|
@@ -26,17 +29,10 @@ export async function makeRequestOptions( | |
chatCompletion?: boolean; | ||
} | ||
): Promise<{ url: string; info: RequestInit }> { | ||
const { accessToken, endpointUrl, ...otherArgs } = args; | ||
const { accessToken, endpointUrl, provider, ...otherArgs } = args; | ||
let { model } = args; | ||
const { | ||
forceTask: task, | ||
includeCredentials, | ||
taskHint, | ||
wait_for_model, | ||
use_cache, | ||
dont_load_model, | ||
chatCompletion, | ||
} = options ?? {}; | ||
const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } = | ||
options ?? {}; | ||
|
||
const headers: Record<string, string> = {}; | ||
if (accessToken) { | ||
|
@@ -61,6 +57,27 @@ export async function makeRequestOptions( | |
if (!model) { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
} | ||
if (provider) { | ||
if (!INFERENCE_PROVIDERS.includes(provider)) { | ||
throw new Error("Unknown Inference provider"); | ||
} | ||
if (!accessToken) { | ||
throw new Error("Specifying an Inference provider requires an accessToken"); | ||
} | ||
switch (provider) { | ||
case "replicate": | ||
model = REPLICATE_MODEL_IDS[model]; | ||
break; | ||
case "sambanova": | ||
model = SAMBANOVA_MODEL_IDS[model]; | ||
break; | ||
case "together": | ||
model = TOGETHER_MODEL_IDS[model]?.id ?? model; | ||
break; | ||
default: | ||
break; | ||
} | ||
SBrandeis marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
const binary = "data" in args && !!args.data; | ||
|
||
|
@@ -77,6 +94,9 @@ export async function makeRequestOptions( | |
if (dont_load_model) { | ||
headers["X-Load-Model"] = "0"; | ||
} | ||
if (provider === "replicate") { | ||
headers["Prefer"] = "wait"; | ||
} | ||
|
||
let url = (() => { | ||
if (endpointUrl && isUrl(model)) { | ||
|
@@ -89,8 +109,32 @@ export async function makeRequestOptions( | |
if (endpointUrl) { | ||
return endpointUrl; | ||
} | ||
if (task) { | ||
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`; | ||
if (forceTask) { | ||
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`; | ||
} | ||
if (provider) { | ||
if (!accessToken) { | ||
throw new Error("Specifying an Inference provider requires an accessToken"); | ||
} | ||
if (accessToken.startsWith("hf_")) { | ||
/// TODO we will proxy the request server-side (using our own keys) and handle billing for it on the user's HF account. | ||
throw new Error("Inference proxying is not implemented yet"); | ||
} else { | ||
Comment on lines
+130
to
+131
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
/// This is an external key | ||
switch (provider) { | ||
case "replicate": | ||
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`; | ||
case "sambanova": | ||
return SAMBANOVA_API_BASE_URL; | ||
case "together": | ||
if (taskHint === "text-to-image") { | ||
return `${TOGETHER_API_BASE_URL}/v1/images/generations`; | ||
} | ||
return TOGETHER_API_BASE_URL; | ||
default: | ||
break; | ||
} | ||
} | ||
} | ||
|
||
return `${HF_INFERENCE_API_BASE_URL}/models/${model}`; | ||
|
@@ -116,9 +160,11 @@ export async function makeRequestOptions( | |
body: binary | ||
? args.data | ||
: JSON.stringify({ | ||
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs), | ||
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" | ||
? omit(otherArgs, "model") | ||
: { ...otherArgs, model }), | ||
}), | ||
...(credentials && { credentials }), | ||
...(credentials ? { credentials } : undefined), | ||
signal: options?.signal, | ||
}; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import type { ModelId } from "../types"; | ||
|
||
/** Root URL of the Replicate HTTP API. */
export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

/**
 * A model identifier in the spelling Replicate uses.
 *
 * The goal would be for Replicate to accept our model ids directly, or to
 * expose the mapping as a fetchable API, so that the table below does not
 * have to be defined and kept up-to-date by hand.
 */
type ReplicateId = string;

/**
 * Lookup table from a model id to the corresponding Replicate model id.
 *
 * The models available on Replicate can be listed with:
 *
 *   curl -s \
 *     -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
 *     https://api.replicate.com/v1/models
 */
export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
	"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
	"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step",
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import type { ModelId } from "../types"; | ||
|
||
/** Root URL of the SambaNova HTTP API. */
export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

/**
 * Note for reviewers: our goal would be to ask Sambanova to support
 * our model ids too, so we don't have to define a mapping
 * or keep it up-to-date.
 *
 * As a fallback, if the above is not possible, ask Sambanova to
 * provide the mapping as a fetchable API.
 */
type SambanovaId = string;

/**
 * Lookup table from a model id to the corresponding SambaNova model id.
 *
 * Looking up an id that is not listed here yields `undefined` at runtime,
 * so callers should handle the missing case.
 *
 * https://community.sambanova.ai/t/supported-models/193
 */
export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = {
	"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
	"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
	"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
	"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
	"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
	"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
	// Legacy keys: these name the base models but resolve to the Instruct
	// deployments. Kept so that existing callers keep working; prefer the
	// "-Instruct" keys above.
	"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
	"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
	"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
	"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
	"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
	"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
	"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
	"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import type { ModelId } from "../types"; | ||
|
||
/** Root URL of the Together HTTP API. */
export const TOGETHER_API_BASE_URL = "https://api.together.xyz";

/**
 * A model identifier in the spelling Together uses.
 *
 * The goal would be for Together to accept our model ids directly, or to
 * expose the mapping as a fetchable API, so that the table below does not
 * have to be defined and kept up-to-date by hand.
 */
type TogetherId = string;

/**
 * Lookup table from a model id to the corresponding Together model, as
 * `{ id, type }` where `id` is the Together model id and `type` is one of
 * the category labels in the union below.
 *
 * Each entry must point at the same model as its key: a mismatched entry
 * silently routes a request to a different model than the one asked for.
 *
 * Looking up an id that is not listed here yields `undefined` at runtime,
 * so callers should handle the missing case.
 *
 * https://docs.together.ai/reference/models-1
 */
export const TOGETHER_MODEL_IDS: Record<
	ModelId,
	{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
> = {
	"BAAI/bge-base-en-v1.5": { id: "BAAI/bge-base-en-v1.5", type: "embedding" },
	"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
	"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" },
	"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
	"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" },
	// Was mapped to "black-forest-labs/FLUX.1-pro", which is a different model.
	"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-schnell", type: "image" },
	"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" },
	"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" },
	"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" },
	// Was mapped to "google/gemma-2-27b-it", which is a different model.
	"google/gemma-2b-it": { id: "google/gemma-2b-it", type: "chat" },
	"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" },
	"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" },
	"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" },
	"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" },
	"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" },
	"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" },
	"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" },
	"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" },
	"meta-llama/Llama-Guard-3-11B-Vision": { id: "meta-llama/Llama-Guard-3-11B-Vision-Turbo", type: "moderation" },
	"meta-llama/LlamaGuard-7b": { id: "Meta-Llama/Llama-Guard-7b", type: "moderation" },
	"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" },
	"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
	"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" },
	// Was mapped to "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", which is a different model.
	"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", type: "chat" },
	"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" },
	"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" },
	"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" },
	"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" },
	"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" },
	"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" },
	"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" },
	"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" },
	"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" },
	"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" },
	"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
	"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" },
	"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" },
	"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" },
	"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" },
	"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" },
	"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" },
	"togethercomputer/m2-bert-80M-32k-retrieval": { id: "togethercomputer/m2-bert-80M-32k-retrieval", type: "embedding" },
	"togethercomputer/m2-bert-80M-8k-retrieval": { id: "togethercomputer/m2-bert-80M-8k-retrieval", type: "embedding" },
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @xenova fyi