diff --git a/packages/inference/src/index.ts b/packages/inference/src/index.ts index 566b5c1ab..3934a0493 100644 --- a/packages/inference/src/index.ts +++ b/packages/inference/src/index.ts @@ -1,4 +1,8 @@ export { HfInference, HfInferenceEndpoint } from "./HfInference"; export { InferenceOutputError } from "./lib/InferenceOutputError"; +export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai"; +export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate"; +export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova"; +export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together"; export * from "./types"; export * from "./tasks"; diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index 31913568f..233297e1e 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -1,8 +1,8 @@ import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config"; -import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai"; -import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate"; -import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova"; -import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together"; +import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai"; +import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate"; +import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova"; +import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together"; import type { InferenceProvider } from "../types"; import type { InferenceTask, Options, RequestArgs } from "../types"; import { isUrl } from "./isUrl"; @@ -50,13 +50,13 @@ export async function makeRequestOptions( let model: string; if (!maybeModel) { if (taskHint) { - model = 
mapModel({ model: await loadDefaultModel(taskHint), provider }); + model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion }); } else { throw new Error("No model provided, and no default model found for this task"); /// TODO : change error message ^ } } else { - model = mapModel({ model: maybeModel, provider }); + model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion }); } /// If accessToken is passed, it should take precedence over includeCredentials @@ -143,24 +143,34 @@ export async function makeRequestOptions( return { url, info }; } -function mapModel(params: { model: string; provider: InferenceProvider }): string { +function mapModel(params: { + model: string; + provider: InferenceProvider; + taskHint: InferenceTask | undefined; + chatCompletion: boolean | undefined; +}): string { + if (params.provider === "hf-inference") { + return params.model; + } + if (!params.taskHint) { + throw new Error("taskHint must be specified when using a third-party provider"); + } + const task = params.taskHint === "text-generation" && params.chatCompletion ? 
"conversational" : params.taskHint; const model = (() => { switch (params.provider) { case "fal-ai": - return FAL_AI_MODEL_IDS[params.model]; + return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model]; case "replicate": - return REPLICATE_MODEL_IDS[params.model]; + return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model]; case "sambanova": - return SAMBANOVA_MODEL_IDS[params.model]; + return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model]; case "together": - return TOGETHER_MODEL_IDS[params.model]?.id; - case "hf-inference": - return params.model; + return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model]; } })(); if (!model) { - throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`); + throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`); } return model; } diff --git a/packages/inference/src/providers/fal-ai.ts b/packages/inference/src/providers/fal-ai.ts index af513f5b0..93d1b4e24 100644 --- a/packages/inference/src/providers/fal-ai.ts +++ b/packages/inference/src/providers/fal-ai.ts @@ -1,17 +1,15 @@ -import type { ModelId } from "../types"; +import type { ProviderMapping } from "./types"; export const FAL_AI_API_BASE_URL = "https://fal.run"; type FalAiId = string; -/** - * Mapping from HF model ID -> fal.ai app id - */ -export const FAL_AI_MODEL_IDS: Partial> = { - /** text-to-image */ - "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell", - "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev", - - /** automatic-speech-recognition */ - "openai/whisper-large-v3": "fal-ai/whisper", +export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping = { + "text-to-image": { + "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell", + "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev", + }, + "automatic-speech-recognition": { + "openai/whisper-large-v3": "fal-ai/whisper", + }, }; diff --git a/packages/inference/src/providers/replicate.ts 
b/packages/inference/src/providers/replicate.ts index 856c07d30..376d6dc01 100644 --- a/packages/inference/src/providers/replicate.ts +++ b/packages/inference/src/providers/replicate.ts @@ -1,22 +1,16 @@ -import type { ModelId } from "../types"; +import type { ProviderMapping } from "./types"; export const REPLICATE_API_BASE_URL = "https://api.replicate.com"; type ReplicateId = string; -/** - * Mapping from HF model ID -> Replicate model ID - * - * Available models can be fetched with: - * ``` - * curl -s \ - * -H "Authorization: Bearer $REPLICATE_API_TOKEN" \ - * 'https://api.replicate.com/v1/models' - * ``` - */ -export const REPLICATE_MODEL_IDS: Partial> = { - /** text-to-image */ - "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell", - "ByteDance/SDXL-Lightning": - "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637", +export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping = { + "text-to-image": { + "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell", + "ByteDance/SDXL-Lightning": + "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637", + }, + // "text-to-speech": { + // "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e" + // }, }; diff --git a/packages/inference/src/providers/sambanova.ts b/packages/inference/src/providers/sambanova.ts index 70d6d57cf..be32901e5 100644 --- a/packages/inference/src/providers/sambanova.ts +++ b/packages/inference/src/providers/sambanova.ts @@ -1,32 +1,23 @@ -import type { ModelId } from "../types"; +import type { ProviderMapping } from "./types"; export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai"; -/** - * Note for reviewers: our goal would be to ask Sambanova to support - * our model ids too, so we don't have to define a mapping - * or keep it up-to-date. 
- * - * As a fallback, if the above is not possible, ask Sambanova to - * provide the mapping as an fetchable API. - */ type SambanovaId = string; -/** - * https://community.sambanova.ai/t/supported-models/193 - */ -export const SAMBANOVA_MODEL_IDS: Partial> = { +export const SAMBANOVA_SUPPORTED_MODEL_IDS: ProviderMapping = { /** Chat completion / conversational */ - "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", - "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct", - "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview", - "meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct", - "meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct", - "meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct", - "meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct", - "meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct", - "meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct", - "meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct", - "meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B", + conversational: { + "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", + "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct", + "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview", + "meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct", + "meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct", + "meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct", + "meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct", + "meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct", + "meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct", + "meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B", + }, }; diff --git 
a/packages/inference/src/providers/together.ts b/packages/inference/src/providers/together.ts index 41bac5a5b..2869c8cca 100644 --- a/packages/inference/src/providers/together.ts +++ b/packages/inference/src/providers/together.ts @@ -1,4 +1,4 @@ -import type { ModelId } from "../types"; +import type { ProviderMapping } from "./types"; export const TOGETHER_API_BASE_URL = "https://api.together.xyz"; @@ -10,50 +10,49 @@ type TogetherId = string; /** * https://docs.together.ai/reference/models-1 */ -export const TOGETHER_MODEL_IDS: Partial< - Record -> = { - /** text-to-image */ - "black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" }, - "black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" }, - "black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" }, - "black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" }, - "black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" }, - "stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" }, - - /** chat completion */ - "databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" }, - "deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" }, - "google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" }, - "google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" }, - "llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" }, - "meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" }, - "meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" }, - "meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" }, - "meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" }, - 
"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" }, - "meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" }, - "meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" }, - "meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" }, - "meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" }, - "meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" }, - "meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" }, - "meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" }, - "microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" }, - "mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" }, - "mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" }, - "mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" }, - "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" }, - "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" }, - "Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" }, - "Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" }, - "Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" }, - "Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" }, - "Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" }, - "scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: 
"scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" }, - "scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" }, - - /** text-generation */ - "meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" }, - "mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" }, +export const TOGETHER_SUPPORTED_MODEL_IDS: ProviderMapping = { + "text-to-image": { + "black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny", + "black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth", + "black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev", + "black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux", + "black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro", + "stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0", + }, + conversational: { + "databricks/dbrx-instruct": "databricks/dbrx-instruct", + "deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat", + "google/gemma-2-9b-it": "google/gemma-2-9b-it", + "google/gemma-2b-it": "google/gemma-2-27b-it", + "llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf", + "meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf", + "meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf", + "meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free", + "meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo", + "meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", + "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo", + "meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf", + "meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4", + "meta-llama/Meta-Llama-3.1-405B-Instruct": 
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", + "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", + "microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B", + "mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3", + "mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1", + "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", + "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", + "Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct", + "Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo", + "Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo", + "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct", + "Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview", + "scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", + "scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", + }, + "text-generation": { + "meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B", + "mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1", + }, }; diff --git a/packages/inference/src/providers/types.ts b/packages/inference/src/providers/types.ts new file mode 100644 index 000000000..c037b46cc --- /dev/null +++ b/packages/inference/src/providers/types.ts @@ -0,0 +1,5 @@ +import type { InferenceTask, ModelId } from "../types"; + +export type ProviderMapping = Partial< + Record>> +>; diff --git a/packages/inference/test/HfInference.spec.ts b/packages/inference/test/HfInference.spec.ts index 545db71ea..4a989462f 100644 --- a/packages/inference/test/HfInference.spec.ts +++ b/packages/inference/test/HfInference.spec.ts @@ -2,7 +2,7 @@ import { expect, 
it, describe, assert } from "vitest"; import type { ChatCompletionStreamOutput } from "@huggingface/tasks"; -import { HfInference } from "../src"; +import { chatCompletion, HfInference } from "../src"; import "./vcr"; import { readTestFile } from "./test-files"; @@ -820,6 +820,15 @@ describe.concurrent("HfInference", () => { }); expect(res).toBeInstanceOf(Blob); }); + + it.skip("textToSpeech versioned", async () => { + const res = await client.textToSpeech({ + model: "SWivid/F5-TTS", + provider: "replicate", + inputs: "Hello, how are you?", + }); + expect(res).toBeInstanceOf(Blob); + }); }, TIMEOUT ); @@ -911,4 +920,19 @@ describe.concurrent("HfInference", () => { }, TIMEOUT ); + + describe.concurrent("3rd party providers", () => { + it("chatCompletion - fails with unsupported model", async () => { + await expect( + chatCompletion({ + model: "black-forest-labs/Flux.1-dev", + provider: "together", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + accessToken: env.HF_TOGETHER_KEY, + }) + ).rejects.toThrowError( + "Model black-forest-labs/Flux.1-dev is not supported for task conversational and provider together" + ); + }); + }); });