Skip to content

Commit

Permalink
[Inference] Proxy calls to 3rd party providers (#1108)
Browse files Browse the repository at this point in the history
  • Loading branch information
SBrandeis authored Jan 17, 2025
1 parent 4b2fbb6 commit 455f12c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 21 deletions.
2 changes: 2 additions & 0 deletions packages/inference/src/config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/**
 * Base URL of the Hugging Face Hub (huggingface.co).
 * Declared in this shared config module so that library and test code import it from a single place.
 */
export const HF_HUB_URL = "https://huggingface.co";
/**
 * Base URL of the Hugging Face Inference API (api-inference.huggingface.co).
 * Request URLs of the form `/models/{model}` and `/pipeline/{task}/{model}` are built on top of it.
 */
export const HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";
2 changes: 1 addition & 1 deletion packages/inference/src/lib/getDefaultTask.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { HF_HUB_URL } from "../config";
import { isUrl } from "./isUrl";

/**
Expand All @@ -8,7 +9,6 @@ import { isUrl } from "./isUrl";
/**
 * Module-level cache mapping a string key to a `{ task, date }` entry.
 * NOTE(review): presumably keyed by model id, with `date` recording when the entry was stored — confirm against the lookup code.
 */
const taskCache = new Map<string, { task: string; date: Date }>();
/** 10 * 60 * 1000 = 600000; presumably a duration in milliseconds (10 minutes) — confirm against usage. */
const CACHE_DURATION = 10 * 60 * 1000;
/** NOTE(review): presumably the maximum number of entries kept in `taskCache` — usage is not visible here. */
const MAX_CACHE_ITEMS = 1000;
export const HF_HUB_URL = "https://huggingface.co";

export interface DefaultTaskOptions {
fetch?: typeof fetch;
Expand Down
69 changes: 50 additions & 19 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
/**
 * URL template for sending third-party provider requests through the Hugging Face Hub
 * instead of directly to the provider's own API.
 * `{{PROVIDER}}` is a literal placeholder: it is substituted with the provider name
 * (via `String.prototype.replace`) when the request URL is built.
 */
const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;

/**
* Lazy-loaded from huggingface.co/api/tasks when needed
Expand Down Expand Up @@ -59,21 +59,32 @@ export async function makeRequestOptions(
model = mapModel({ model: maybeModel, provider });
}

/// If accessToken is passed, it should take precedence over includeCredentials
const authMethod = accessToken
? accessToken.startsWith("hf_")
? "hf-token"
: "provider-key"
: includeCredentials === "include"
? "credentials-include"
: "none";

const url = endpointUrl
? chatCompletion
? endpointUrl + `/v1/chat/completions`
: endpointUrl
: makeUrl({
authMethod,
chatCompletion: chatCompletion ?? false,
forceTask,
model,
provider: provider ?? "hf-inference",
taskHint,
chatCompletion: chatCompletion ?? false,
forceTask,
});

const headers: Record<string, string> = {};
if (accessToken) {
headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
headers["Authorization"] =
provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
}

const binary = "data" in args && !!args.data;
Expand Down Expand Up @@ -155,46 +166,66 @@ function mapModel(params: { model: string; provider: InferenceProvider }): strin
}

function makeUrl(params: {
authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
chatCompletion: boolean;
model: string;
provider: InferenceProvider;
taskHint: InferenceTask | undefined;
chatCompletion: boolean;
forceTask?: string | InferenceTask;
}): string {
if (params.authMethod === "none" && params.provider !== "hf-inference") {
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
}

const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
switch (params.provider) {
case "fal-ai":
return `${FAL_AI_API_BASE_URL}/${params.model}`;
case "fal-ai": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: FAL_AI_API_BASE_URL;
return `${baseUrl}/${params.model}`;
}
case "replicate": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: REPLICATE_API_BASE_URL;
if (params.model.includes(":")) {
/// Versioned model
return `${REPLICATE_API_BASE_URL}/v1/predictions`;
return `${baseUrl}/v1/predictions`;
}
/// Evergreen / Canonical model
return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`;
return `${baseUrl}/v1/models/${params.model}/predictions`;
}
case "sambanova":
case "sambanova": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: SAMBANOVA_API_BASE_URL;
/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
if (params.taskHint === "text-generation" && params.chatCompletion) {
return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`;
return `${baseUrl}/v1/chat/completions`;
}
return SAMBANOVA_API_BASE_URL;
return baseUrl;
}
case "together": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: TOGETHER_API_BASE_URL;
/// Together API matches OpenAI-like APIs: model is defined in the request body
if (params.taskHint === "text-to-image") {
return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
return `${baseUrl}/v1/images/generations`;
}
if (params.taskHint === "text-generation") {
if (params.chatCompletion) {
return `${TOGETHER_API_BASE_URL}/v1/chat/completions`;
return `${baseUrl}/v1/chat/completions`;
}
return `${TOGETHER_API_BASE_URL}/v1/completions`;
return `${baseUrl}/v1/completions`;
}
return TOGETHER_API_BASE_URL;
return baseUrl;
}
default: {
const url = params.forceTask
? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}`
: `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`;
? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}`
: `${HF_INFERENCE_API_URL}/models/${params.model}`;
if (params.taskHint === "text-generation" && params.chatCompletion) {
return url + `/v1/chat/completions`;
}
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/test/vcr.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { omit } from "../src/utils/omit";
import { HF_HUB_URL } from "../src/lib/getDefaultTask";
import { HF_HUB_URL } from "../src/config";
import { isBackend } from "../src/utils/isBackend";
import { isFrontend } from "../src/utils/isFrontend";

Expand Down

0 comments on commit 455f12c

Please sign in to comment.