[Inference] compatibility with third-party Inference providers #1077

Merged
merged 25 commits into from
Jan 9, 2025
Changes from 23 commits
12 changes: 12 additions & 0 deletions .github/workflows/test.yml
@@ -41,6 +41,10 @@ jobs:
run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy

browser:
runs-on: ubuntu-latest
@@ -77,6 +81,10 @@ jobs:
run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test:browser
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy

e2e:
runs-on: ubuntu-latest
@@ -140,3 +148,7 @@ jobs:
env:
NPM_CONFIG_REGISTRY: http://localhost:4874/
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
14 changes: 8 additions & 6 deletions README.md
@@ -13,7 +13,7 @@
// Programmatically interact with the Hub

await createRepo({
- repo: {type: "model", name: "my-user/nlp-model"},
+ repo: { type: "model", name: "my-user/nlp-model" },
accessToken: HF_TOKEN
});

@@ -53,11 +53,13 @@ await inference.textToImage({

This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.

- - [@huggingface/inference](packages/inference/README.md): Use Inference Endpoints (dedicated) and Inference API (serverless) to make calls to 100,000+ Machine Learning models
+ - [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless) and Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
- [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
- [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
- [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
- [@huggingface/dduf](packages/dduf/README.md): Similar package for DDUF (DDUF Diffusers Unified Format)
- [@huggingface/tasks](packages/tasks/README.md): The definition files and source-of-truth for the Hub's main primitives like pipeline tasks, model libraries, etc.
- [@huggingface/jinja](packages/jinja/README.md): A minimalistic JS implementation of the Jinja templating engine, to be used for ML chat templates.
Member Author

cc @xenova fyi

- [@huggingface/space-header](packages/space-header/README.md): Use the Space `mini_header` outside Hugging Face


@@ -165,7 +167,7 @@ await inference.imageToText({
const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});

- //Chat Completion
+ // Chat Completion
const llamaEndpoint = inference.endpoint(
"https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct"
);
@@ -185,7 +187,7 @@ import { createRepo, uploadFile, deleteFiles } from "@huggingface/hub";
const HF_TOKEN = "hf_...";

await createRepo({
- repo: "my-user/nlp-model", // or {type: "model", name: "my-user/nlp-test"},
+ repo: "my-user/nlp-model", // or { type: "model", name: "my-user/nlp-test" },
accessToken: HF_TOKEN
});

@@ -200,7 +202,7 @@ await uploadFile({
});

await deleteFiles({
- repo: {type: "space", name: "my-user/my-space"}, // or "spaces/my-user/my-space"
+ repo: { type: "space", name: "my-user/my-space" }, // or "spaces/my-user/my-space"
accessToken: HF_TOKEN,
paths: ["README.md", ".gitattributes"]
});
@@ -209,7 +211,7 @@ await deleteFiles({
### @huggingface/agents example

```ts
- import {HfAgent, LLMFromHub, defaultTools} from '@huggingface/agents';
+ import { HfAgent, LLMFromHub, defaultTools } from '@huggingface/agents';

const HF_TOKEN = "hf_...";

2 changes: 1 addition & 1 deletion packages/inference/LICENSE
@@ -1,6 +1,6 @@
MIT License

- Copyright (c) 2022 Tim Mikeladze
+ Copyright (c) 2022 Tim Mikeladze and the Hugging Face team

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
4 changes: 2 additions & 2 deletions packages/inference/README.md
@@ -1,7 +1,7 @@
# 🤗 Hugging Face Inference Endpoints

- A Typescript powered wrapper for the Hugging Face Inference Endpoints API. Learn more about Inference Endpoints at [Hugging Face](https://huggingface.co/inference-endpoints).
- It works with both [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index).
+ A Typescript powered wrapper for the Hugging Face Inference API (serverless), Inference Endpoints (dedicated), and third-party Inference Providers.
+ It works with [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index), and even with supported third-party Inference Providers.

Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README).

94 changes: 77 additions & 17 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,12 +1,17 @@
- import type { InferenceTask, Options, RequestArgs } from "../types";
+ import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
+ import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
+ import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
+ import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
+ import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
import { omit } from "../utils/omit";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";

/**
- * Loaded from huggingface.co/api/tasks if needed
+ * Lazy-loaded from huggingface.co/api/tasks when needed
+ * Used to determine the default model to use when it's not user defined
*/
let tasks: Record<string, { models: { id: string }[] }> | null = null;

@@ -26,21 +31,14 @@ export async function makeRequestOptions(
chatCompletion?: boolean;
}
): Promise<{ url: string; info: RequestInit }> {
- const { accessToken, endpointUrl, ...otherArgs } = args;
+ const { accessToken, endpointUrl, provider, ...otherArgs } = args;
let { model } = args;
- const {
- forceTask: task,
- includeCredentials,
- taskHint,
- wait_for_model,
- use_cache,
- dont_load_model,
- chatCompletion,
- } = options ?? {};
+ const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
+ options ?? {};

const headers: Record<string, string> = {};
if (accessToken) {
- headers["Authorization"] = `Bearer ${accessToken}`;
+ headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
}

if (!model && !tasks && taskHint) {
@@ -61,6 +59,35 @@ export async function makeRequestOptions(
if (!model) {
throw new Error("No model provided, and no default model found for this task");
}
if (provider) {
if (!INFERENCE_PROVIDERS.includes(provider)) {
throw new Error("Unknown Inference provider");
}
if (!accessToken) {
throw new Error("Specifying an Inference provider requires an accessToken");
}

const modelId = (() => {
switch (provider) {
case "replicate":
return REPLICATE_MODEL_IDS[model];
case "sambanova":
return SAMBANOVA_MODEL_IDS[model];
case "together":
return TOGETHER_MODEL_IDS[model]?.id;
case "fal-ai":
return FAL_AI_MODEL_IDS[model];
default:
return model;
}
Comment on lines +71 to +81
Member

@coyotte508 Jan 9, 2025

I think we should maybe at least return `REPLICATE_MODEL_IDS[model] ?? model;` to fall back to the provided id in case the user directly provided a Replicate model id and not a HF model ID. Same with the others.

Note that we could maybe maintain a mapping in the backend and, in case of errors, try to load it only once (like we do for default models associated with tasks). Just a thought for the future, but it would enable new mappings without updating the lib.

Contributor

> I think we should maybe at least return `REPLICATE_MODEL_IDS[model] ?? model;` to fall back to the provided id in case the user directly provided a Replicate model id and not a HF model ID. Same with the others.

cc @julien-c - we discussed it and decided to stick to HF model IDs for now

> Note that we could maybe maintain a mapping in the backend and, in case of errors, try to load it only once (like we do for default models associated with tasks). Just a thought for the future, but it would enable new mappings without updating the lib.

Yes, we definitely want a way for 3rd party providers to expose the mapping HF model ID -> Provider ID that does not require hardcoding / updating the huggingface.js lib

Member Author

> we discussed it and decided to stick to HF model IDs for now

Yes. simpler to always be "Hub-centric"

Member Author

ah yes and i now remember, the `?? model` in my mind was to work out of the box for models that have the same id on the inference provider as the HF id, NOT to work if you pass the provider's (different) id

})();

if (!modelId) {
throw new Error(`Model ${model} is not supported for provider ${provider}`);
}

model = modelId;
}

const binary = "data" in args && !!args.data;
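The provider branch in this hunk maps a Hub model ID to the provider's own ID and rejects anything that is not in the table. A minimal standalone sketch of that lookup follows. The helper name `resolveProviderModelId` and the trimmed-down tables `REPLICATE_IDS` and `TOGETHER_IDS` are illustrative assumptions, not names from the PR; only the entries and the error message are copied from the diff.

```typescript
type Provider = "fal-ai" | "replicate" | "sambanova" | "together";

// Hypothetical one-entry excerpts of the mapping tables added in this PR.
const REPLICATE_IDS: Record<string, string> = {
	"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
};
const TOGETHER_IDS: Record<string, { id: string; type: string }> = {
	"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
};

// Hypothetical helper: strict lookup, no fallback to the ID the caller passed.
function resolveProviderModelId(provider: Provider, hfModelId: string): string {
	let providerId: string | undefined;
	switch (provider) {
		case "replicate":
			providerId = REPLICATE_IDS[hfModelId];
			break;
		case "together":
			providerId = TOGETHER_IDS[hfModelId]?.id;
			break;
		default:
			// The fal-ai and sambanova tables are left out of this sketch.
			providerId = undefined;
	}
	if (!providerId) {
		throw new Error(`Model ${hfModelId} is not supported for provider ${provider}`);
	}
	return providerId;
}
```

With this strict lookup, passing a provider-native ID such as `black-forest-labs/flux-schnell` throws, which is consistent with the decision in the review thread to accept Hub model IDs only.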

@@ -77,6 +104,9 @@ export async function makeRequestOptions(
if (dont_load_model) {
headers["X-Load-Model"] = "0";
}
if (provider === "replicate") {
headers["Prefer"] = "wait";
}

let url = (() => {
if (endpointUrl && isUrl(model)) {
@@ -89,8 +119,33 @@ export async function makeRequestOptions(
if (endpointUrl) {
return endpointUrl;
}
- if (task) {
- return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
+ if (forceTask) {
+ return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
}
if (provider) {
if (!accessToken) {
throw new Error("Specifying an Inference provider requires an accessToken");
}
if (accessToken.startsWith("hf_")) {
/// TODO we will proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
throw new Error("Inference proxying is not implemented yet");
} else {
switch (provider) {
case "fal-ai":
return `${FAL_AI_API_BASE_URL}/${model}`;
case "replicate":
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
case "sambanova":
return SAMBANOVA_API_BASE_URL;
case "together":
if (taskHint === "text-to-image") {
return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
}
return TOGETHER_API_BASE_URL;
default:
break;
}
}
}

return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
@@ -99,6 +154,9 @@ export async function makeRequestOptions(
if (chatCompletion && !url.endsWith("/chat/completions")) {
url += "/v1/chat/completions";
}
if (provider === "together" && taskHint === "text-generation" && !chatCompletion) {
url += "/v1/completions";
}

coyotte508 marked this conversation as resolved.
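Taken together, the provider `switch` and the two suffix checks decide the final request URL. The sketch below restates that routing as a pure function so the resulting URLs are easy to inspect. `buildProviderUrl`, `UrlParams` and `BASE_URLS` are hypothetical names; the base URLs and path segments are copied from the diff, and the `hf_` token check and the non-provider paths are deliberately left out.

```typescript
type Provider = "fal-ai" | "replicate" | "sambanova" | "together";

// Base URLs as declared in the provider files of this PR.
const BASE_URLS: Record<Provider, string> = {
	"fal-ai": "https://fal.run",
	replicate: "https://api.replicate.com",
	sambanova: "https://api.sambanova.ai",
	together: "https://api.together.xyz",
};

interface UrlParams {
	provider: Provider;
	model: string; // provider-side model ID, already resolved
	taskHint?: string;
	chatCompletion?: boolean;
}

// Hypothetical helper: provider-specific base URL before any suffix is added.
function providerBaseUrl({ provider, model, taskHint }: UrlParams): string {
	switch (provider) {
		case "fal-ai":
			return `${BASE_URLS[provider]}/${model}`;
		case "replicate":
			return `${BASE_URLS[provider]}/v1/models/${model}/predictions`;
		case "sambanova":
			return BASE_URLS[provider];
		case "together":
			return taskHint === "text-to-image"
				? `${BASE_URLS[provider]}/v1/images/generations`
				: BASE_URLS[provider];
		default:
			throw new Error("Unknown Inference provider");
	}
}

// Hypothetical helper: applies the same suffix rules as the hunk above.
function buildProviderUrl(params: UrlParams): string {
	let url = providerBaseUrl(params);
	if (params.chatCompletion && !url.endsWith("/chat/completions")) {
		url += "/v1/chat/completions";
	}
	if (params.provider === "together" && params.taskHint === "text-generation" && !params.chatCompletion) {
		url += "/v1/completions";
	}
	return url;
}
```

Under these assumptions a Sambanova chat request resolves to `https://api.sambanova.ai/v1/chat/completions`, while a Together `text-generation` request without `chatCompletion` resolves to `https://api.together.xyz/v1/completions`.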
/**
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
@@ -116,9 +174,11 @@ export async function makeRequestOptions(
body: binary
? args.data
: JSON.stringify({
- ...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
+ ...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
+ ? omit(otherArgs, "model")
+ : { ...otherArgs, model }),
}),
- ...(credentials && { credentials }),
+ ...(credentials ? { credentials } : undefined),
signal: options?.signal,
};
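The header and body changes in this file can be summarized in two small helpers. This is a sketch under stated assumptions: `buildHeaders` and `buildJsonBody` are hypothetical names, and the case where `model` is itself a URL (handled by `isUrl` in the real code) is omitted. What it does reflect from the diff is the `Key` scheme for fal-ai, the `Prefer: wait` header for Replicate, and dropping `model` from the JSON body for providers that already carry the model in the URL.

```typescript
type Provider = "fal-ai" | "replicate" | "sambanova" | "together";

// Hypothetical helper: fal-ai uses the "Key" scheme, everything else "Bearer".
function buildHeaders(accessToken: string, provider?: Provider): Record<string, string> {
	const headers: Record<string, string> = {
		Authorization: provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`,
	};
	if (provider === "replicate") {
		// Added in this PR for Replicate requests.
		headers["Prefer"] = "wait";
	}
	return headers;
}

// Hypothetical helper: serializes the JSON body, with or without `model`.
function buildJsonBody(args: Record<string, unknown>, resolvedModel: string, provider?: Provider): string {
	const rest: Record<string, unknown> = { ...args };
	delete rest["model"];
	// Replicate and fal-ai carry the model in the URL, so it is omitted from the body.
	const modelInUrl = provider === "replicate" || provider === "fal-ai";
	return JSON.stringify(modelInUrl ? rest : { ...rest, model: resolvedModel });
}
```

For the other providers the body carries the resolved provider-side ID rather than the Hub ID, mirroring `{ ...otherArgs, model }` in the diff.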

17 changes: 17 additions & 0 deletions packages/inference/src/providers/fal-ai.ts
@@ -0,0 +1,17 @@
import type { ModelId } from "../types";

export const FAL_AI_API_BASE_URL = "https://fal.run";

type FalAiId = string;

/**
* Mapping from HF model ID -> fal.ai app id
*/
export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = {
/** text-to-image */
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",

/** automatic-speech-recognition */
"openai/whisper-large-v3": "fal-ai/whisper",
};
Comment on lines +10 to +17
Contributor

Idk if it has been discussed somewhere but I feel that at some point these mappings should be maintained as an API lazy-loaded by the inference client (same as what we are doing for the tasks).
This way, we would be able to retrieve it in huggingface_hub + could be quickly updated without client releases.

Member

scroll up: #1077 (comment) :)

Contributor

🙈

Member Author

great minds think alike
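The idea raised in this thread, a mapping that is loaded lazily and only once instead of being hardcoded, could look roughly like the sketch below. Nothing here is part of the PR: `createLazyMapping` is a hypothetical helper, and the loader is injected by the caller because no mapping endpoint exists in the source.

```typescript
type ModelMapping = Record<string, string>;
type MappingLoader = () => Promise<ModelMapping>;

// Hypothetical helper: wraps a loader so the mapping is fetched at most once.
function createLazyMapping(load: MappingLoader): (hfModelId: string) => Promise<string | undefined> {
	let pending: Promise<ModelMapping> | null = null;
	return async (hfModelId: string) => {
		// The first call starts the load; later calls reuse the same promise.
		if (!pending) {
			pending = load();
		}
		const mapping = await pending;
		return mapping[hfModelId];
	};
}
```

One design choice to be aware of: a rejected load is cached as well, so a real client would probably want to reset `pending` on failure and retry.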

21 changes: 21 additions & 0 deletions packages/inference/src/providers/replicate.ts
@@ -0,0 +1,21 @@
import type { ModelId } from "../types";

export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

type ReplicateId = string;

/**
* Mapping from HF model ID -> Replicate model ID
*
* Available models can be fetched with:
* ```
* curl -s \
* -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
* 'https://api.replicate.com/v1/models'
* ```
*/
export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
/** text-to-image */
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step",
};
32 changes: 32 additions & 0 deletions packages/inference/src/providers/sambanova.ts
@@ -0,0 +1,32 @@
import type { ModelId } from "../types";

export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

/**
* Note for reviewers: our goal would be to ask Sambanova to support
* our model ids too, so we don't have to define a mapping
* or keep it up-to-date.
*
* As a fallback, if the above is not possible, ask Sambanova to
* provide the mapping as a fetchable API.
*/
type SambanovaId = string;

/**
* https://community.sambanova.ai/t/supported-models/193
*/
export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = {
/** Chat completion / conversational */
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
};
60 changes: 60 additions & 0 deletions packages/inference/src/providers/together.ts
@@ -0,0 +1,60 @@
import type { ModelId } from "../types";

export const TOGETHER_API_BASE_URL = "https://api.together.xyz";

/**
* Same comment as in sambanova.ts
*/
type TogetherId = string;

/**
* https://docs.together.ai/reference/models-1
*/
export const TOGETHER_MODEL_IDS: Record<
ModelId,
{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
> = {
/** text-to-image */
"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" },
"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" },
"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" },
"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" },

/** chat completion */
"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" },
"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" },
"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" },
"google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" },
"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" },
"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" },
"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" },
"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" },
"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" },
"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" },
"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" },
"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" },
"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" },
"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" },
"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" },
"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" },
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" },
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" },
"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" },
"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" },
"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" },
"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" },
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" },

/** text-generation */
"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" },
};
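Unlike the other provider tables, each Together entry carries a `type` tag next to the provider ID. As a small illustration of how that shape can be queried, the sketch below filters a three-entry excerpt by type. `TOGETHER_SAMPLE` and `hfModelsOfType` are hypothetical names, and the PR itself routes on `taskHint` rather than on this field.

```typescript
type TogetherType = "chat" | "embedding" | "image" | "language" | "moderation";
type TogetherEntry = { id: string; type: TogetherType };

// Three entries copied from the table above, one per type in use.
const TOGETHER_SAMPLE: Record<string, TogetherEntry> = {
	"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
	"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
	"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
};

// Hypothetical helper: Hub model IDs whose Together entry has the given type.
function hfModelsOfType(mapping: Record<string, TogetherEntry>, type: TogetherType): string[] {
	return Object.keys(mapping).filter((hfId) => mapping[hfId]?.type === type);
}
```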