From 455f12caede96e60a888a3ef923b41724d99cae1 Mon Sep 17 00:00:00 2001 From: Simon Brandeis <33657802+SBrandeis@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:11:04 +0100 Subject: [PATCH] [Inference] Proxy calls to 3rd party providers (#1108) Companion to https://github.com/huggingface-internal/moon-landing/pull/12072 (internal) --- packages/inference/src/config.ts | 2 + packages/inference/src/lib/getDefaultTask.ts | 2 +- .../inference/src/lib/makeRequestOptions.ts | 69 ++++++++++++++----- packages/inference/test/vcr.ts | 2 +- 4 files changed, 54 insertions(+), 21 deletions(-) create mode 100644 packages/inference/src/config.ts diff --git a/packages/inference/src/config.ts b/packages/inference/src/config.ts new file mode 100644 index 000000000..96ff9393b --- /dev/null +++ b/packages/inference/src/config.ts @@ -0,0 +1,2 @@ +export const HF_HUB_URL = "https://huggingface.co"; +export const HF_INFERENCE_API_URL = "https://api-inference.huggingface.co"; diff --git a/packages/inference/src/lib/getDefaultTask.ts b/packages/inference/src/lib/getDefaultTask.ts index 3149998d6..38133d379 100644 --- a/packages/inference/src/lib/getDefaultTask.ts +++ b/packages/inference/src/lib/getDefaultTask.ts @@ -1,3 +1,4 @@ +import { HF_HUB_URL } from "../config"; import { isUrl } from "./isUrl"; /** @@ -8,7 +9,6 @@ import { isUrl } from "./isUrl"; const taskCache = new Map(); const CACHE_DURATION = 10 * 60 * 1000; const MAX_CACHE_ITEMS = 1000; -export const HF_HUB_URL = "https://huggingface.co"; export interface DefaultTaskOptions { fetch?: typeof fetch; diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index 4c8d0be7f..31913568f 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -1,13 +1,13 @@ +import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config"; import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai"; import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate"; import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova"; import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together"; import type { InferenceProvider } from "../types"; import type { InferenceTask, Options, RequestArgs } from "../types"; -import { HF_HUB_URL } from "./getDefaultTask"; import { isUrl } from "./isUrl"; -const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; +const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`; /** * Lazy-loaded from huggingface.co/api/tasks when needed @@ -59,21 +59,32 @@ export async function makeRequestOptions( model = mapModel({ model: maybeModel, provider }); } + /// If accessToken is passed, it should take precedence over includeCredentials + const authMethod = accessToken + ? accessToken.startsWith("hf_") + ? "hf-token" + : "provider-key" + : includeCredentials === "include" + ? "credentials-include" + : "none"; + const url = endpointUrl ? chatCompletion ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({ + authMethod, + chatCompletion: chatCompletion ?? false, + forceTask, model, provider: provider ?? "hf-inference", taskHint, - chatCompletion: chatCompletion ?? false, - forceTask, }); const headers: Record = {}; if (accessToken) { - headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`; + headers["Authorization"] = + provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`; } const binary = "data" in args && !!args.data; @@ -155,46 +166,66 @@ function mapModel(params: { model: string; provider: InferenceProvider }): strin } function makeUrl(params: { + authMethod: "none" | "hf-token" | "credentials-include" | "provider-key"; + chatCompletion: boolean; model: string; provider: InferenceProvider; taskHint: InferenceTask | undefined; - chatCompletion: boolean; forceTask?: string | InferenceTask; }): string { + if (params.authMethod === "none" && params.provider !== "hf-inference") { + throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken"); + } + + const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key"; switch (params.provider) { - case "fal-ai": - return `${FAL_AI_API_BASE_URL}/${params.model}`; + case "fal-ai": { + const baseUrl = shouldProxy + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) + : FAL_AI_API_BASE_URL; + return `${baseUrl}/${params.model}`; + } case "replicate": { + const baseUrl = shouldProxy + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) + : REPLICATE_API_BASE_URL; if (params.model.includes(":")) { /// Versioned model - return `${REPLICATE_API_BASE_URL}/v1/predictions`; + return `${baseUrl}/v1/predictions`; } /// Evergreen / Canonical model - return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`; + return `${baseUrl}/v1/models/${params.model}/predictions`; } - case "sambanova": + case "sambanova": { + const baseUrl = shouldProxy + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) + : SAMBANOVA_API_BASE_URL; /// Sambanova API matches OpenAI-like APIs: model is defined in the request body if (params.taskHint === "text-generation" && params.chatCompletion) { - return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`; + return `${baseUrl}/v1/chat/completions`; } - return SAMBANOVA_API_BASE_URL; + return baseUrl; + } case "together": { + const baseUrl = shouldProxy + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) + : TOGETHER_API_BASE_URL; /// Together API matches OpenAI-like APIs: model is defined in the request body if (params.taskHint === "text-to-image") { - return `${TOGETHER_API_BASE_URL}/v1/images/generations`; + return `${baseUrl}/v1/images/generations`; } if (params.taskHint === "text-generation") { if (params.chatCompletion) { - return `${TOGETHER_API_BASE_URL}/v1/chat/completions`; + return `${baseUrl}/v1/chat/completions`; } - return `${TOGETHER_API_BASE_URL}/v1/completions`; + return `${baseUrl}/v1/completions`; } - return TOGETHER_API_BASE_URL; + return baseUrl; } default: { const url = params.forceTask - ? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}` - : `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`; + ? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` + : `${HF_INFERENCE_API_URL}/models/${params.model}`; if (params.taskHint === "text-generation" && params.chatCompletion) { return url + `/v1/chat/completions`; } diff --git a/packages/inference/test/vcr.ts b/packages/inference/test/vcr.ts index 7175e350e..8d7460c58 100644 --- a/packages/inference/test/vcr.ts +++ b/packages/inference/test/vcr.ts @@ -1,5 +1,5 @@ import { omit } from "../src/utils/omit"; -import { HF_HUB_URL } from "../src/lib/getDefaultTask"; +import { HF_HUB_URL } from "../src/config"; import { isBackend } from "../src/utils/isBackend"; import { isFrontend } from "../src/utils/isFrontend";