Skip to content

Commit

Permalink
[Inference] compatibility with third-party Inference providers (#1077)
Browse files Browse the repository at this point in the history
# TL;DR

Allow users to request 3rd party inference providers (Sambanova,
Replicate, Together, Fal) with `@huggingface/inference` for a curated
set of models on the HF Hub

For now, Requesting a 3rd party inference provider requires users to
pass an api key from this provider as a parameter to the inference
function.

---------

Co-authored-by: SBrandeis <[email protected]>
  • Loading branch information
julien-c and SBrandeis authored Jan 9, 2025
1 parent 29715d4 commit 86b1f2e
Show file tree
Hide file tree
Showing 20 changed files with 1,644 additions and 765 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ jobs:
run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy

browser:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -77,6 +81,10 @@ jobs:
run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test:browser
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy

e2e:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -140,3 +148,7 @@ jobs:
env:
NPM_CONFIG_REGISTRY: http://localhost:4874/
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// Programatically interact with the Hub

await createRepo({
repo: {type: "model", name: "my-user/nlp-model"},
repo: { type: "model", name: "my-user/nlp-model" },
accessToken: HF_TOKEN
});

Expand Down Expand Up @@ -53,11 +53,13 @@ await inference.textToImage({

This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.

- [@huggingface/inference](packages/inference/README.md): Use Inference Endpoints (dedicated) and Inference API (serverless) to make calls to 100,000+ Machine Learning models
- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless) and Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
- [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
- [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
- [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
- [@huggingface/dduf](packages/dduf/README.md): Similar package for DDUF (DDUF Diffusers Unified Format)
- [@huggingface/tasks](packages/tasks/README.md): The definition files and source-of-truth for the Hub's main primitives like pipeline tasks, model libraries, etc.
- [@huggingface/jinja](packages/jinja/README.md): A minimalistic JS implementation of the Jinja templating engine, to be used for ML chat templates.
- [@huggingface/space-header](packages/space-header/README.md): Use the Space `mini_header` outside Hugging Face


Expand Down Expand Up @@ -165,7 +167,7 @@ await inference.imageToText({
const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});

//Chat Completion
// Chat Completion
const llamaEndpoint = inference.endpoint(
"https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct"
);
Expand All @@ -185,7 +187,7 @@ import { createRepo, uploadFile, deleteFiles } from "@huggingface/hub";
const HF_TOKEN = "hf_...";

await createRepo({
repo: "my-user/nlp-model", // or {type: "model", name: "my-user/nlp-test"},
repo: "my-user/nlp-model", // or { type: "model", name: "my-user/nlp-test" },
accessToken: HF_TOKEN
});

Expand All @@ -200,7 +202,7 @@ await uploadFile({
});

await deleteFiles({
repo: {type: "space", name: "my-user/my-space"}, // or "spaces/my-user/my-space"
repo: { type: "space", name: "my-user/my-space" }, // or "spaces/my-user/my-space"
accessToken: HF_TOKEN,
paths: ["README.md", ".gitattributes"]
});
Expand All @@ -209,7 +211,7 @@ await deleteFiles({
### @huggingface/agents example

```ts
import {HfAgent, LLMFromHub, defaultTools} from '@huggingface/agents';
import { HfAgent, LLMFromHub, defaultTools } from '@huggingface/agents';

const HF_TOKEN = "hf_...";

Expand Down
2 changes: 1 addition & 1 deletion packages/inference/LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2022 Tim Mikeladze
Copyright (c) 2022 Tim Mikeladze and the Hugging Face team

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
4 changes: 2 additions & 2 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# 🤗 Hugging Face Inference Endpoints

A Typescript powered wrapper for the Hugging Face Inference Endpoints API. Learn more about Inference Endpoints at [Hugging Face](https://huggingface.co/inference-endpoints).
It works with both [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index).
A Typescript powered wrapper for the Hugging Face Inference API (serverless), Inference Endpoints (dedicated), and third-party Inference Providers.
It works with [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index), and even with supported third-party Inference Providers.

Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README).

Expand Down
94 changes: 77 additions & 17 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import type { InferenceTask, Options, RequestArgs } from "../types";
import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
import { omit } from "../utils/omit";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";

/**
* Loaded from huggingface.co/api/tasks if needed
* Lazy-loaded from huggingface.co/api/tasks when needed
* Used to determine the default model to use when it's not user defined
*/
let tasks: Record<string, { models: { id: string }[] }> | null = null;

Expand All @@ -26,21 +31,14 @@ export async function makeRequestOptions(
chatCompletion?: boolean;
}
): Promise<{ url: string; info: RequestInit }> {
const { accessToken, endpointUrl, ...otherArgs } = args;
const { accessToken, endpointUrl, provider, ...otherArgs } = args;
let { model } = args;
const {
forceTask: task,
includeCredentials,
taskHint,
wait_for_model,
use_cache,
dont_load_model,
chatCompletion,
} = options ?? {};
const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
options ?? {};

const headers: Record<string, string> = {};
if (accessToken) {
headers["Authorization"] = `Bearer ${accessToken}`;
headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
}

if (!model && !tasks && taskHint) {
Expand All @@ -61,6 +59,35 @@ export async function makeRequestOptions(
if (!model) {
throw new Error("No model provided, and no default model found for this task");
}
if (provider) {
if (!INFERENCE_PROVIDERS.includes(provider)) {
throw new Error("Unknown Inference provider");
}
if (!accessToken) {
throw new Error("Specifying an Inference provider requires an accessToken");
}

const modelId = (() => {
switch (provider) {
case "replicate":
return REPLICATE_MODEL_IDS[model];
case "sambanova":
return SAMBANOVA_MODEL_IDS[model];
case "together":
return TOGETHER_MODEL_IDS[model]?.id;
case "fal-ai":
return FAL_AI_MODEL_IDS[model];
default:
return model;
}
})();

if (!modelId) {
throw new Error(`Model ${model} is not supported for provider ${provider}`);
}

model = modelId;
}

const binary = "data" in args && !!args.data;

Expand All @@ -77,6 +104,9 @@ export async function makeRequestOptions(
if (dont_load_model) {
headers["X-Load-Model"] = "0";
}
if (provider === "replicate") {
headers["Prefer"] = "wait";
}

let url = (() => {
if (endpointUrl && isUrl(model)) {
Expand All @@ -89,8 +119,33 @@ export async function makeRequestOptions(
if (endpointUrl) {
return endpointUrl;
}
if (task) {
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
if (forceTask) {
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
}
if (provider) {
if (!accessToken) {
throw new Error("Specifying an Inference provider requires an accessToken");
}
if (accessToken.startsWith("hf_")) {
/// TODO we wil proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
throw new Error("Inference proxying is not implemented yet");
} else {
switch (provider) {
case "fal-ai":
return `${FAL_AI_API_BASE_URL}/${model}`;
case "replicate":
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
case "sambanova":
return SAMBANOVA_API_BASE_URL;
case "together":
if (taskHint === "text-to-image") {
return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
}
return TOGETHER_API_BASE_URL;
default:
break;
}
}
}

return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
Expand All @@ -99,6 +154,9 @@ export async function makeRequestOptions(
if (chatCompletion && !url.endsWith("/chat/completions")) {
url += "/v1/chat/completions";
}
if (provider === "together" && taskHint === "text-generation" && !chatCompletion) {
url += "/v1/completions";
}

/**
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
Expand All @@ -116,9 +174,11 @@ export async function makeRequestOptions(
body: binary
? args.data
: JSON.stringify({
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
? omit(otherArgs, "model")
: { ...otherArgs, model }),
}),
...(credentials && { credentials }),
...(credentials ? { credentials } : undefined),
signal: options?.signal,
};

Expand Down
17 changes: 17 additions & 0 deletions packages/inference/src/providers/fal-ai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import type { ModelId } from "../types";

export const FAL_AI_API_BASE_URL = "https://fal.run";

type FalAiId = string;

/**
* Mapping from HF model ID -> fal.ai app id
*/
export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = {
/** text-to-image */
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",

/** automatic-speech-recognition */
"openai/whisper-large-v3": "fal-ai/whisper",
};
21 changes: 21 additions & 0 deletions packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import type { ModelId } from "../types";

export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

type ReplicateId = string;

/**
* Mapping from HF model ID -> Replicate model ID
*
* Available models can be fetched with:
* ```
* curl -s \
* -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
* 'https://api.replicate.com/v1/models'
* ```
*/
export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
/** text-to-image */
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step",
};
32 changes: 32 additions & 0 deletions packages/inference/src/providers/sambanova.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import type { ModelId } from "../types";

export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

/**
* Note for reviewers: our goal would be to ask Sambanova to support
* our model ids too, so we don't have to define a mapping
* or keep it up-to-date.
*
* As a fallback, if the above is not possible, ask Sambanova to
* provide the mapping as an fetchable API.
*/
type SambanovaId = string;

/**
* https://community.sambanova.ai/t/supported-models/193
*/
export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = {
/** Chat completion / conversational */
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
};
60 changes: 60 additions & 0 deletions packages/inference/src/providers/together.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import type { ModelId } from "../types";

export const TOGETHER_API_BASE_URL = "https://api.together.xyz";

/**
* Same comment as in sambanova.ts
*/
type TogetherId = string;

/**
* https://docs.together.ai/reference/models-1
*/
export const TOGETHER_MODEL_IDS: Record<
ModelId,
{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
> = {
/** text-to-image */
"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" },
"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" },
"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" },
"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" },

/** chat completion */
"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" },
"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" },
"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" },
"google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" },
"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" },
"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" },
"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" },
"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" },
"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" },
"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" },
"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" },
"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" },
"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" },
"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" },
"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" },
"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" },
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" },
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" },
"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" },
"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" },
"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" },
"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" },
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" },

/** text-generation */
"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" },
};
Loading

0 comments on commit 86b1f2e

Please sign in to comment.