[Inference] compatibility with third-party Inference providers #1077

Merged · 25 commits · Jan 9, 2025 · Diff shown from 10 commits
14 changes: 8 additions & 6 deletions README.md
@@ -13,7 +13,7 @@
// Programmatically interact with the Hub

await createRepo({
repo: {type: "model", name: "my-user/nlp-model"},
repo: { type: "model", name: "my-user/nlp-model" },
accessToken: HF_TOKEN
});

@@ -53,11 +53,13 @@ await inference.textToImage({

This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.

- [@huggingface/inference](packages/inference/README.md): Use Inference Endpoints (dedicated) and Inference API (serverless) to make calls to 100,000+ Machine Learning models
- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless) and Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
- [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
- [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
- [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
- [@huggingface/dduf](packages/dduf/README.md): Similar package for DDUF (DDUF Diffusers Unified Format)
- [@huggingface/tasks](packages/tasks/README.md): The definition files and source-of-truth for the Hub's main primitives like pipeline tasks, model libraries, etc.
- [@huggingface/jinja](packages/jinja/README.md): A minimalistic JS implementation of the Jinja templating engine, to be used for ML chat templates.
Review comment (Member Author): cc @xenova fyi
- [@huggingface/space-header](packages/space-header/README.md): Use the Space `mini_header` outside Hugging Face


@@ -165,7 +167,7 @@ await inference.imageToText({
const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});

//Chat Completion
// Chat Completion
const llamaEndpoint = inference.endpoint(
"https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct"
);
@@ -185,7 +187,7 @@ import { createRepo, uploadFile, deleteFiles } from "@huggingface/hub";
const HF_TOKEN = "hf_...";

await createRepo({
repo: "my-user/nlp-model", // or {type: "model", name: "my-user/nlp-test"},
repo: "my-user/nlp-model", // or { type: "model", name: "my-user/nlp-test" },
accessToken: HF_TOKEN
});

@@ -200,7 +202,7 @@ await uploadFile({
});

await deleteFiles({
repo: {type: "space", name: "my-user/my-space"}, // or "spaces/my-user/my-space"
repo: { type: "space", name: "my-user/my-space" }, // or "spaces/my-user/my-space"
accessToken: HF_TOKEN,
paths: ["README.md", ".gitattributes"]
});
@@ -209,7 +211,7 @@ await deleteFiles({
### @huggingface/agents example

```ts
import {HfAgent, LLMFromHub, defaultTools} from '@huggingface/agents';
import { HfAgent, LLMFromHub, defaultTools } from '@huggingface/agents';

const HF_TOKEN = "hf_...";

2 changes: 1 addition & 1 deletion packages/inference/LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2022 Tim Mikeladze
Copyright (c) 2022 Tim Mikeladze and the Hugging Face team

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
4 changes: 2 additions & 2 deletions packages/inference/README.md
@@ -1,7 +1,7 @@
# 🤗 Hugging Face Inference Endpoints

A TypeScript-powered wrapper for the Hugging Face Inference Endpoints API. Learn more about Inference Endpoints at [Hugging Face](https://huggingface.co/inference-endpoints).
It works with both [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index).
A TypeScript-powered wrapper for the Hugging Face Inference API (serverless), Inference Endpoints (dedicated), and third-party Inference Providers.
It works with the [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index), as well as with supported third-party Inference Providers.

Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README).
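For a concrete picture of what third-party provider support means for callers, here is a minimal sketch of the new `provider` request option (the task-method signature and the Replicate-style token are assumptions; the diff below only confirms that `makeRequestOptions` reads `provider` from the request args):

```ts
import { HfInference } from "@huggingface/inference";

// Assumption: an external (non-"hf_") provider key is passed as the access token;
// per this PR, "hf_" tokens throw "Inference proxying is not implemented yet".
const inference = new HfInference("r8_...");

// `provider` is forwarded to makeRequestOptions, which maps the Hub model ID
// to Replicate's ID and routes the call to api.replicate.com.
await inference.textToImage({
	model: "black-forest-labs/FLUX.1-schnell",
	inputs: "a picture of a green bird",
	provider: "replicate",
});
```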

76 changes: 61 additions & 15 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,4 +1,7 @@
import type { InferenceTask, Options, RequestArgs } from "../types";
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
import { omit } from "../utils/omit";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";
@@ -26,17 +29,10 @@ export async function makeRequestOptions(
chatCompletion?: boolean;
}
): Promise<{ url: string; info: RequestInit }> {
const { accessToken, endpointUrl, ...otherArgs } = args;
const { accessToken, endpointUrl, provider, ...otherArgs } = args;
let { model } = args;
const {
forceTask: task,
includeCredentials,
taskHint,
wait_for_model,
use_cache,
dont_load_model,
chatCompletion,
} = options ?? {};
const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
options ?? {};

const headers: Record<string, string> = {};
if (accessToken) {
@@ -61,6 +57,27 @@
if (!model) {
throw new Error("No model provided, and no default model found for this task");
}
if (provider) {
if (!INFERENCE_PROVIDERS.includes(provider)) {
throw new Error("Unknown Inference provider");
}
if (!accessToken) {
throw new Error("Specifying an Inference provider requires an accessToken");
}
switch (provider) {
case "replicate":
model = REPLICATE_MODEL_IDS[model];
break;
case "sambanova":
model = SAMBANOVA_MODEL_IDS[model];
break;
case "together":
model = TOGETHER_MODEL_IDS[model]?.id ?? model;
break;
default:
break;
}
}
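Note on the mapping semantics above: for `replicate` and `sambanova`, a Hub ID missing from the map leaves `model` undefined, while the `together` branch falls back to the original ID. A standalone sketch of the lookups, using entries from the provider files added below:

```ts
import { REPLICATE_MODEL_IDS } from "../providers/replicate";
import { TOGETHER_MODEL_IDS } from "../providers/together";

// "ByteDance/SDXL-Lightning" -> "bytedance/sdxl-lightning-4step"
const replicateId = REPLICATE_MODEL_IDS["ByteDance/SDXL-Lightning"];

// Together entries also carry a type tag; unknown Hub IDs pass through unchanged.
const hubId = "Qwen/Qwen2.5-72B-Instruct";
const togetherId = TOGETHER_MODEL_IDS[hubId]?.id ?? hubId; // -> "Qwen/Qwen2.5-72B-Instruct-Turbo"
```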

const binary = "data" in args && !!args.data;

@@ -77,6 +94,9 @@
if (dont_load_model) {
headers["X-Load-Model"] = "0";
}
if (provider === "replicate") {
headers["Prefer"] = "wait";
}

let url = (() => {
if (endpointUrl && isUrl(model)) {
@@ -89,8 +109,32 @@
if (endpointUrl) {
return endpointUrl;
}
if (task) {
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
if (forceTask) {
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
}
if (provider) {
if (!accessToken) {
throw new Error("Specifying an Inference provider requires an accessToken");
}
if (accessToken.startsWith("hf_")) {
/// TODO we will proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
throw new Error("Inference proxying is not implemented yet");
} else {
/// This is an external key
switch (provider) {
case "replicate":
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
case "sambanova":
return SAMBANOVA_API_BASE_URL;
case "together":
if (taskHint === "text-to-image") {
return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
}
return TOGETHER_API_BASE_URL;
default:
break;
}
}
}

return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
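To summarize the routing above in one place (base URLs from the providers/*.ts files added in this PR; `<model>` is a placeholder):

```ts
// replicate                -> https://api.replicate.com/v1/models/<model>/predictions
// sambanova                -> https://api.sambanova.ai
// together (text-to-image) -> https://api.together.xyz/v1/images/generations
// together (other tasks)   -> https://api.together.xyz
// no provider              -> https://api-inference.huggingface.co/models/<model>
```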
@@ -116,9 +160,11 @@ export async function makeRequestOptions(
body: binary
? args.data
: JSON.stringify({
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate"
? omit(otherArgs, "model")
: { ...otherArgs, model }),
}),
...(credentials && { credentials }),
...(credentials ? { credentials } : undefined),
signal: options?.signal,
};

18 changes: 18 additions & 0 deletions packages/inference/src/providers/replicate.ts
@@ -0,0 +1,18 @@
import type { ModelId } from "../types";

export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

/**
* Same comment as in sambanova.ts
*/
type ReplicateId = string;

/**
* curl -s \
* -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
* https://api.replicate.com/v1/models
*/
export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step",
};
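For reference, a sketch of the request this mapping produces once the URL routing and the `Prefer: wait` header (both set in `makeRequestOptions`) are applied. The token, prompt, and `input` body shape are illustrative assumptions, not taken from this diff:

```ts
// Illustrative request against Replicate's predictions endpoint.
// "Prefer: wait" asks Replicate to hold the connection until the prediction
// finishes instead of returning a URL to poll.
const response = await fetch(
	"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
	{
		method: "POST",
		headers: {
			Authorization: `Bearer ${process.env.REPLICATE_API_TOKEN}`, // external key, not an hf_ token
			"Content-Type": "application/json",
			Prefer: "wait",
		},
		// Assumed body shape; the diff above only shows that the "model" field
		// is omitted from the JSON payload for Replicate.
		body: JSON.stringify({ input: { prompt: "a picture of a green bird" } }),
	}
);
console.log(await response.json());
```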
31 changes: 31 additions & 0 deletions packages/inference/src/providers/sambanova.ts
@@ -0,0 +1,31 @@
import type { ModelId } from "../types";

export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

/**
 * Note for reviewers: our goal would be to ask Sambanova to support
 * our model IDs too, so we don't have to define a mapping
 * or keep it up-to-date.
 *
 * As a fallback, if the above is not possible, ask Sambanova to
 * provide the mapping as a fetchable API.
*/
type SambanovaId = string;

/**
* https://community.sambanova.ai/t/supported-models/193
*/
export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = {
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
};
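SambaNova's API is OpenAI-compatible (per their docs), so the mapped ID is what ends up in the `model` field of a chat completion request. A minimal sketch (the `/v1/chat/completions` path is an assumption; this file only defines the base URL):

```ts
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "./sambanova";

// "meta-llama/Llama-3.3-70B-Instruct" -> "Meta-Llama-3.3-70B-Instruct"
const model = SAMBANOVA_MODEL_IDS["meta-llama/Llama-3.3-70B-Instruct"];

// Assumed endpoint path; only SAMBANOVA_API_BASE_URL is defined in this PR.
const url = `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`;
const body = JSON.stringify({
	model,
	messages: [{ role: "user", content: "Hello!" }],
});
```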
60 changes: 60 additions & 0 deletions packages/inference/src/providers/together.ts
@@ -0,0 +1,60 @@
import type { ModelId } from "../types";

export const TOGETHER_API_BASE_URL = "https://api.together.xyz";

/**
* Same comment as in sambanova.ts
*/
type TogetherId = string;

/**
* https://docs.together.ai/reference/models-1
*/
export const TOGETHER_MODEL_IDS: Record<
ModelId,
{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
> = {
"BAAI/bge-base-en-v1.5": { id: "BAAI/bge-base-en-v1.5", type: "embedding" },
"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" },
"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" },
"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" },
"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" },
"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" },
"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" },
"google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" },
"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" },
"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" },
"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" },
"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" },
"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" },
"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-Guard-3-11B-Vision": { id: "meta-llama/Llama-Guard-3-11B-Vision-Turbo", type: "moderation" },
"meta-llama/LlamaGuard-7b": { id: "Meta-Llama/Llama-Guard-7b", type: "moderation" },
"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" },
"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" },
"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" },
"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" },
"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" },
"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" },
"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" },
"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" },
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" },
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" },
"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" },
"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" },
"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" },
"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" },
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" },
"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" },
"togethercomputer/m2-bert-80M-32k-retrieval": { id: "togethercomputer/m2-bert-80M-32k-retrieval", type: "embedding" },
"togethercomputer/m2-bert-80M-8k-retrieval": { id: "togethercomputer/m2-bert-80M-8k-retrieval", type: "embedding" },
};
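The `type` tag distinguishes chat, image, embedding, language, and moderation models; note that the URL routing in `makeRequestOptions` above switches on `taskHint === "text-to-image"` rather than reading this tag. A sketch of a lookup plus equivalent routing (model ID illustrative):

```ts
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "./together";

// Resolve a Hub ID, falling back to the original if unmapped
// (mirroring the `?? model` fallback in makeRequestOptions).
const hubId = "black-forest-labs/FLUX.1-dev";
const entry = TOGETHER_MODEL_IDS[hubId];
const togetherId = entry?.id ?? hubId;

// Image models go to Together's image-generation endpoint; everything else
// uses the base URL.
const url =
	entry?.type === "image"
		? `${TOGETHER_API_BASE_URL}/v1/images/generations`
		: TOGETHER_API_BASE_URL;
```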
8 changes: 7 additions & 1 deletion packages/inference/src/tasks/custom/request.ts
@@ -26,13 +26,19 @@ export async function request<T>(
}

if (!response.ok) {
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
if (
["application/json", "application/problem+json"].some(
(contentType) => response.headers.get("Content-Type")?.startsWith(contentType)
)
) {
const output = await response.json();
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
}
if (output.error) {
throw new Error(JSON.stringify(output.error));
} else {
throw new Error(output);
}
}
throw new Error("An error occurred while fetching the blob");
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/custom/streamingRequest.ts
@@ -68,7 +68,9 @@ export async function* streamingRequest<T>(
try {
while (true) {
const { done, value } = await reader.read();
if (done) return;
if (done) {
return;
}
onChunk(value);
for (const event of events) {
if (event.data.length > 0) {