Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] [Providers] Enforce task in mapping + expose them #1109

Merged
merged 8 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/inference/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
export { HfInference, HfInferenceEndpoint } from "./HfInference";
export { InferenceOutputError } from "./lib/InferenceOutputError";
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate";
export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova";
export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together";
Comment on lines +3 to +6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(for moon's consumption, i presume?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

export * from "./types";
export * from "./tasks";
38 changes: 24 additions & 14 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
Expand Down Expand Up @@ -50,13 +50,13 @@ export async function makeRequestOptions(
let model: string;
if (!maybeModel) {
if (taskHint) {
model = mapModel({ model: await loadDefaultModel(taskHint), provider });
model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion });
} else {
throw new Error("No model provided, and no default model found for this task");
/// TODO : change error message ^
}
} else {
model = mapModel({ model: maybeModel, provider });
model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion });
}

/// If accessToken is passed, it should take precedence over includeCredentials
Expand Down Expand Up @@ -143,24 +143,34 @@ export async function makeRequestOptions(
return { url, info };
}

function mapModel(params: { model: string; provider: InferenceProvider }): string {
/**
 * Resolves the model id to send to the selected provider.
 *
 * - For "hf-inference" the given model id is returned unchanged.
 * - For every other provider a task is required; the id is looked up in that provider's
 *   table, first by task and then by the given model id.
 *
 * @param params.model - model id as given by the caller
 * @param params.provider - provider the request is routed to
 * @param params.taskHint - task of the call; must be set for any provider other than "hf-inference"
 * @param params.chatCompletion - when truthy, a "text-generation" hint is looked up as "conversational"
 * @returns the id found in the provider table, or `params.model` for "hf-inference"
 * @throws Error when `taskHint` is missing for a third-party provider, or when the lookup finds nothing
 */
function mapModel(params: {
model: string;
provider: InferenceProvider;
taskHint: InferenceTask | undefined;
chatCompletion: boolean | undefined;
}): string {
// hf-inference takes the model id as-is, so no table lookup (and no task) is needed.
if (params.provider === "hf-inference") {
return params.model;
}
if (!params.taskHint) {
throw new Error("taskHint must be specified when using a third-party provider");
}
// Chat-completion calls carry a "text-generation" hint; the provider tables list them under "conversational".
const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
const model = (() => {
switch (params.provider) {
case "fal-ai":
return FAL_AI_MODEL_IDS[params.model];
return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "replicate":
return REPLICATE_MODEL_IDS[params.model];
return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "sambanova":
return SAMBANOVA_MODEL_IDS[params.model];
return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "together":
return TOGETHER_MODEL_IDS[params.model]?.id;
case "hf-inference":
return params.model;
return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
}
})();

// A missing task entry or a missing model entry both surface here as "not supported".
if (!model) {
throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`);
throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
}
return model;
}
Expand Down
20 changes: 9 additions & 11 deletions packages/inference/src/providers/fal-ai.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import type { ModelId } from "../types";
import type { ProviderMapping } from "./types";

export const FAL_AI_API_BASE_URL = "https://fal.run";

type FalAiId = string;

/**
* Mapping from HF model ID -> fal.ai app id
*/
export const FAL_AI_MODEL_IDS: Partial<Record<ModelId, FalAiId>> = {
/** text-to-image */
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",

/** automatic-speech-recognition */
"openai/whisper-large-v3": "fal-ai/whisper",
/**
 * fal.ai app ids grouped by task.
 *
 * Outer keys are task names; inner keys are HF model IDs; values are the fal.ai app ids
 * those models map to.
 */
export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FalAiId> = {
"text-to-image": {
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
},
"automatic-speech-recognition": {
"openai/whisper-large-v3": "fal-ai/whisper",
},
};
26 changes: 10 additions & 16 deletions packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
import type { ModelId } from "../types";
import type { ProviderMapping } from "./types";

export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

type ReplicateId = string;

/**
* Mapping from HF model ID -> Replicate model ID
*
* Available models can be fetched with:
* ```
* curl -s \
* -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
* 'https://api.replicate.com/v1/models'
* ```
*/
export const REPLICATE_MODEL_IDS: Partial<Record<ModelId, ReplicateId>> = {
/** text-to-image */
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/SDXL-Lightning":
"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
/**
 * Replicate model ids grouped by task.
 *
 * Outer keys are task names; inner keys are HF model IDs; values are the Replicate model ids
 * those models map to.
 */
export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
"text-to-image": {
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
// NOTE(review): the ":<hash>" suffix looks like a pinned Replicate model version — confirm.
"ByteDance/SDXL-Lightning":
"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
},
// NOTE(review): text-to-speech entry is left disabled; the reason is not visible here — confirm before enabling.
// "text-to-speech": {
// "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e"
// },
};
41 changes: 16 additions & 25 deletions packages/inference/src/providers/sambanova.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,23 @@
import type { ModelId } from "../types";
import type { ProviderMapping } from "./types";

export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

/**
* Note for reviewers: our goal would be to ask Sambanova to support
* our model ids too, so we don't have to define a mapping
* or keep it up-to-date.
*
* As a fallback, if the above is not possible, ask Sambanova to
* provide the mapping as a fetchable API.
*/
type SambanovaId = string;

/**
* https://community.sambanova.ai/t/supported-models/193
*/
export const SAMBANOVA_MODEL_IDS: Partial<Record<ModelId, SambanovaId>> = {
/**
 * SambaNova model ids grouped by task.
 *
 * In the task-keyed form, outer keys are task names, inner keys are HF model IDs, and values
 * are the SambaNova model ids those models map to.
 */
export const SAMBANOVA_SUPPORTED_MODEL_IDS: ProviderMapping<SambanovaId> = {
/** Chat completion / conversational */
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
conversational: {
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
},
};
93 changes: 46 additions & 47 deletions packages/inference/src/providers/together.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { ModelId } from "../types";
import type { ProviderMapping } from "./types";

export const TOGETHER_API_BASE_URL = "https://api.together.xyz";

Expand All @@ -10,50 +10,49 @@ type TogetherId = string;
/**
* https://docs.together.ai/reference/models-1
*/
export const TOGETHER_MODEL_IDS: Partial<
Record<ModelId, { id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }>
> = {
/** text-to-image */
"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" },
"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" },
"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" },
"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" },

/** chat completion */
"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" },
"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" },
"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" },
"google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" },
"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" },
"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" },
"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" },
"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" },
"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" },
"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" },
"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" },
"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" },
"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" },
"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" },
"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" },
"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" },
"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" },
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" },
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" },
"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" },
"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" },
"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" },
"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" },
"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" },
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" },

/** text-generation */
"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" },
/**
 * Together model ids grouped by task.
 *
 * Outer keys are task names; inner keys are HF model IDs; values are the Together model ids
 * those models map to. A model is only found for the task it is listed under, so base
 * (non-chat) language models belong in "text-generation", not "conversational".
 */
export const TOGETHER_SUPPORTED_MODEL_IDS: ProviderMapping<TogetherId> = {
	"text-to-image": {
		"black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
		"black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
		"black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
		"black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
		// NOTE(review): schnell resolves to Together's FLUX.1-pro, which looks unintended — confirm against the Together catalog.
		"black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro",
		"stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0",
	},
	conversational: {
		"databricks/dbrx-instruct": "databricks/dbrx-instruct",
		"deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
		"google/gemma-2-9b-it": "google/gemma-2-9b-it",
		// NOTE(review): key names gemma-2b-it but the target is gemma-2-27b-it; one of the two is probably a typo — confirm.
		"google/gemma-2b-it": "google/gemma-2-27b-it",
		"llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
		"meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
		"meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
		"meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
		"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
		"meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
		"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
		"meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
		"meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4",
		// Previously pointed at the 11B Vision model; a 405B request must resolve to the 405B deployment.
		"meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
		"meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
		"meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K",
		"microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
		"mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
		"mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
		"mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
		"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
		"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
		"Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
		"Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
		"Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
		"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
		"Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
		"scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
		"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316",
	},
	"text-generation": {
		// Base (non-chat) model: it was a "language" model in the previous flat table, so it lives here.
		"meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
		"meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
		"mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
	},
};
5 changes: 5 additions & 0 deletions packages/inference/src/providers/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import type { InferenceTask, ModelId } from "../types";

/**
 * Shape of a provider's table of supported models.
 *
 * Outer keys are tasks (any `InferenceTask`, plus `"conversational"`); each task maps
 * `ModelId` keys to the provider's own identifier (`ProviderId`).
 * Both levels are `Partial`, so a table only has to list the tasks and models it declares
 * and every lookup may yield `undefined`.
 */
export type ProviderMapping<ProviderId extends string> = Partial<
Record<InferenceTask | "conversational", Partial<Record<ModelId, ProviderId>>>
>;
26 changes: 25 additions & 1 deletion packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { expect, it, describe, assert } from "vitest";

import type { ChatCompletionStreamOutput } from "@huggingface/tasks";

import { HfInference } from "../src";
import { chatCompletion, HfInference } from "../src";
import "./vcr";
import { readTestFile } from "./test-files";

Expand Down Expand Up @@ -820,6 +820,15 @@ describe.concurrent("HfInference", () => {
});
expect(res).toBeInstanceOf(Blob);
});

// Text-to-speech through Replicate; expects the audio back as a Blob.
// NOTE(review): skipped — presumably until a Replicate text-to-speech mapping is enabled; confirm.
it.skip("textToSpeech versioned", async () => {
	const speech = await client.textToSpeech({
		model: "SWivid/F5-TTS",
		provider: "replicate",
		inputs: "Hello, how are you?",
	});
	expect(speech).toBeInstanceOf(Blob);
});
},
TIMEOUT
);
Expand Down Expand Up @@ -911,4 +920,19 @@ describe.concurrent("HfInference", () => {
},
TIMEOUT
);

describe.concurrent("3rd party providers", () => {
	it("chatCompletion - fails with unsupported model", async () => {
		// The model id below has no "conversational" entry for Together, so the call must reject.
		// A `rejects` assertion is itself a promise: it has to be awaited, otherwise the test
		// finishes before the assertion runs and a wrong outcome would go unnoticed.
		await expect(
			chatCompletion({
				model: "black-forest-labs/Flux.1-dev",
				provider: "together",
				messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
				accessToken: env.HF_TOGETHER_KEY,
			})
		).rejects.toThrowError(
			"Model black-forest-labs/Flux.1-dev is not supported for task conversational and provider together"
		);
	});
});
});
Loading