Skip to content

Commit

Permalink
[Inference] [Providers] Enforce task in mapping + expose them (#1109)
Browse files Browse the repository at this point in the history
# TL;DR

- Add task metadata to the HF model ID -> provider ID mappings, to forbid,
for example, using a chat model with the `textToImage` inference
function
- Expose the supported models mappings in `index.ts`
  • Loading branch information
SBrandeis authored Jan 17, 2025
1 parent 044942d commit f1fe42c
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 114 deletions.
4 changes: 4 additions & 0 deletions packages/inference/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
export { HfInference, HfInferenceEndpoint } from "./HfInference";
export { InferenceOutputError } from "./lib/InferenceOutputError";
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate";
export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova";
export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together";
export * from "./types";
export * from "./tasks";
38 changes: 24 additions & 14 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
Expand Down Expand Up @@ -50,13 +50,13 @@ export async function makeRequestOptions(
let model: string;
if (!maybeModel) {
if (taskHint) {
model = mapModel({ model: await loadDefaultModel(taskHint), provider });
model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion });
} else {
throw new Error("No model provided, and no default model found for this task");
/// TODO : change error message ^
}
} else {
model = mapModel({ model: maybeModel, provider });
model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion });
}

/// If accessToken is passed, it should take precedence over includeCredentials
Expand Down Expand Up @@ -143,24 +143,34 @@ export async function makeRequestOptions(
return { url, info };
}

function mapModel(params: { model: string; provider: InferenceProvider }): string {
function mapModel(params: {
model: string;
provider: InferenceProvider;
taskHint: InferenceTask | undefined;
chatCompletion: boolean | undefined;
}): string {
if (params.provider === "hf-inference") {
return params.model;
}
if (!params.taskHint) {
throw new Error("taskHint must be specified when using a third-party provider");
}
const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
const model = (() => {
switch (params.provider) {
case "fal-ai":
return FAL_AI_MODEL_IDS[params.model];
return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "replicate":
return REPLICATE_MODEL_IDS[params.model];
return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "sambanova":
return SAMBANOVA_MODEL_IDS[params.model];
return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "together":
return TOGETHER_MODEL_IDS[params.model]?.id;
case "hf-inference":
return params.model;
return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
}
})();

if (!model) {
throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`);
throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
}
return model;
}
Expand Down
20 changes: 9 additions & 11 deletions packages/inference/src/providers/fal-ai.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import type { ModelId } from "../types";
import type { ProviderMapping } from "./types";

export const FAL_AI_API_BASE_URL = "https://fal.run";

type FalAiId = string;

/**
* Mapping from HF model ID -> fal.ai app id
*/
export const FAL_AI_MODEL_IDS: Partial<Record<ModelId, FalAiId>> = {
/** text-to-image */
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",

/** automatic-speech-recognition */
"openai/whisper-large-v3": "fal-ai/whisper",
export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FalAiId> = {
"text-to-image": {
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
},
"automatic-speech-recognition": {
"openai/whisper-large-v3": "fal-ai/whisper",
},
};
26 changes: 10 additions & 16 deletions packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
import type { ModelId } from "../types";
import type { ProviderMapping } from "./types";

export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

type ReplicateId = string;

/**
* Mapping from HF model ID -> Replicate model ID
*
* Available models can be fetched with:
* ```
* curl -s \
* -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
* 'https://api.replicate.com/v1/models'
* ```
*/
export const REPLICATE_MODEL_IDS: Partial<Record<ModelId, ReplicateId>> = {
/** text-to-image */
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/SDXL-Lightning":
"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
"text-to-image": {
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/SDXL-Lightning":
"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
},
// "text-to-speech": {
// "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e"
// },
};
41 changes: 16 additions & 25 deletions packages/inference/src/providers/sambanova.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,23 @@
import type { ModelId } from "../types";
import type { ProviderMapping } from "./types";

export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

/**
* Note for reviewers: our goal would be to ask Sambanova to support
* our model ids too, so we don't have to define a mapping
* or keep it up-to-date.
*
* As a fallback, if the above is not possible, ask Sambanova to
* provide the mapping as an fetchable API.
*/
type SambanovaId = string;

/**
* https://community.sambanova.ai/t/supported-models/193
*/
export const SAMBANOVA_MODEL_IDS: Partial<Record<ModelId, SambanovaId>> = {
export const SAMBANOVA_SUPPORTED_MODEL_IDS: ProviderMapping<SambanovaId> = {
/** Chat completion / conversational */
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
conversational: {
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
},
};
93 changes: 46 additions & 47 deletions packages/inference/src/providers/together.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { ModelId } from "../types";
import type { ProviderMapping } from "./types";

export const TOGETHER_API_BASE_URL = "https://api.together.xyz";

Expand All @@ -10,50 +10,49 @@ type TogetherId = string;
/**
* https://docs.together.ai/reference/models-1
*/
export const TOGETHER_MODEL_IDS: Partial<
Record<ModelId, { id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }>
> = {
/** text-to-image */
"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" },
"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" },
"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" },
"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" },

/** chat completion */
"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" },
"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" },
"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" },
"google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" },
"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" },
"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" },
"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" },
"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" },
"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" },
"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" },
"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" },
"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" },
"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" },
"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" },
"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" },
"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" },
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" },
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" },
"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" },
"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" },
"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" },
"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" },
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" },

/** text-generation */
"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" },
export const TOGETHER_SUPPORTED_MODEL_IDS: ProviderMapping<TogetherId> = {
"text-to-image": {
"black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
"black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
"black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
"black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro",
"stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0",
},
conversational: {
"databricks/dbrx-instruct": "databricks/dbrx-instruct",
"deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
"google/gemma-2-9b-it": "google/gemma-2-9b-it",
"google/gemma-2b-it": "google/gemma-2-27b-it",
"llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
"meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
"meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
"meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4",
"meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
"meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K",
"microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
"mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
"mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
"Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
"Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
"scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316",
},
"text-generation": {
"meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
"mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
},
};
5 changes: 5 additions & 0 deletions packages/inference/src/providers/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import type { InferenceTask, ModelId } from "../types";

/**
 * Per-task mapping from a Hugging Face model ID to the provider-specific model/app ID.
 *
 * Outer keys are inference task names, plus the special `"conversational"` key:
 * chat-completion requests (task `"text-generation"` with `chatCompletion` set)
 * are looked up under `"conversational"` instead of `"text-generation"`.
 * Both levels are `Partial` because a provider typically supports only a subset
 * of tasks, and only a subset of models for each task; a missing entry means
 * "not supported" and callers are expected to treat that as an error.
 */
export type ProviderMapping<ProviderId extends string> = Partial<
Record<InferenceTask | "conversational", Partial<Record<ModelId, ProviderId>>>
>;
26 changes: 25 additions & 1 deletion packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { expect, it, describe, assert } from "vitest";

import type { ChatCompletionStreamOutput } from "@huggingface/tasks";

import { HfInference } from "../src";
import { chatCompletion, HfInference } from "../src";
import "./vcr";
import { readTestFile } from "./test-files";

Expand Down Expand Up @@ -820,6 +820,15 @@ describe.concurrent("HfInference", () => {
});
expect(res).toBeInstanceOf(Blob);
});

it.skip("textToSpeech versioned", async () => {
const res = await client.textToSpeech({
model: "SWivid/F5-TTS",
provider: "replicate",
inputs: "Hello, how are you?",
});
expect(res).toBeInstanceOf(Blob);
});
},
TIMEOUT
);
Expand Down Expand Up @@ -911,4 +920,19 @@ describe.concurrent("HfInference", () => {
},
TIMEOUT
);

describe.concurrent("3rd party providers", () => {
it("chatCompletion - fails with unsupported model", async () => {
expect(
chatCompletion({
model: "black-forest-labs/Flux.1-dev",
provider: "together",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
accessToken: env.HF_TOGETHER_KEY,
})
).rejects.toThrowError(
"Model black-forest-labs/Flux.1-dev is not supported for task conversational and provider together"
);
});
});
});

0 comments on commit f1fe42c

Please sign in to comment.