Skip to content

Commit

Permalink
Support free-form input for workspace model for providers with no `/m…
Browse files Browse the repository at this point in the history
…odels` endpoint (#2397)

* support generic openai workspace model

* Update UI for free form input for some providers

---------

Co-authored-by: Timothy Carambat <[email protected]>
  • Loading branch information
shatfield4 and timothycarambat authored Oct 15, 2024
1 parent c3723ce commit 6674e5a
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 53 deletions.
17 changes: 0 additions & 17 deletions frontend/src/components/LLMSelection/AzureAiOptions/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,6 @@ export default function AzureAiOptions({ settings }) {
</option>
</select>
</div>

<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Embedding Deployment Name
</label>
<input
type="text"
name="AzureOpenAiEmbeddingModelPref"
className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
placeholder="Azure OpenAI embedding model deployment name"
defaultValue={settings?.AzureOpenAiEmbeddingModelPref}
required={true}
autoComplete="off"
spellCheck={false}
/>
</div>
<div className="flex-flex-col w-60"></div>
</div>
</div>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@ import { useTranslation } from "react-i18next";
import { Link } from "react-router-dom";
import paths from "@/utils/paths";

// Some providers can only be associated with a single model.
// In that case there is no selection to be made so we can just move on.
const NO_MODEL_SELECTION = [
"default",
"huggingface",
"generic-openai",
"bedrock",
];
const DISABLED_PROVIDERS = ["azure", "native"];
// Some providers do not support model selection via /models.
// In that case we allow the user to enter the model name manually and hope they
// type it correctly.
const FREE_FORM_LLM_SELECTION = ["bedrock", "azure", "generic-openai"];

// Some providers do not support model selection via /models
// and only have a fixed single-model they can use.
const NO_MODEL_SELECTION = ["default", "huggingface"];

// Some providers we just fully disable for ease of use.
const DISABLED_PROVIDERS = ["native"];

const LLM_DEFAULT = {
name: "System default",
value: "default",
Expand Down Expand Up @@ -65,8 +68,8 @@ export default function WorkspaceLLMSelection({
);
setFilteredLLMs(filtered);
}, [LLMS, searchQuery, selectedLLM]);

const selectedLLMObject = LLMS.find((llm) => llm.value === selectedLLM);

return (
<div className="border-b border-white/40 pb-8">
<div className="flex flex-col">
Expand Down Expand Up @@ -155,30 +158,66 @@ export default function WorkspaceLLMSelection({
</button>
)}
</div>
{NO_MODEL_SELECTION.includes(selectedLLM) ? (
<>
{selectedLLM !== "default" && (
<div className="w-full h-10 justify-center items-center flex mt-4">
<p className="text-sm font-base text-white text-opacity-60 text-center">
Multi-model support is not supported for this provider yet.
<br />
This workspace will use{" "}
<Link to={paths.settings.llmPreference()} className="underline">
the model set for the system.
</Link>
</p>
</div>
)}
</>
) : (
<div className="mt-4 flex flex-col gap-y-1">
<ChatModelSelection
provider={selectedLLM}
workspace={workspace}
setHasChanges={setHasChanges}
/>
<ModelSelector
selectedLLM={selectedLLM}
workspace={workspace}
setHasChanges={setHasChanges}
/>
</div>
);
}

// TODO: Add this to agent selector as well as make generic component.
function ModelSelector({ selectedLLM, workspace, setHasChanges }) {
if (NO_MODEL_SELECTION.includes(selectedLLM)) {
if (selectedLLM !== "default") {
return (
<div className="w-full h-10 justify-center items-center flex mt-4">
<p className="text-sm font-base text-white text-opacity-60 text-center">
Multi-model support is not supported for this provider yet.
<br />
This workspace will use{" "}
<Link to={paths.settings.llmPreference()} className="underline">
the model set for the system.
</Link>
</p>
</div>
)}
);
}
return null;
}

if (FREE_FORM_LLM_SELECTION.includes(selectedLLM)) {
return (
<FreeFormLLMInput workspace={workspace} setHasChanges={setHasChanges} />
);
}

return (
<ChatModelSelection
provider={selectedLLM}
workspace={workspace}
setHasChanges={setHasChanges}
/>
);
}

function FreeFormLLMInput({ workspace, setHasChanges }) {
const { t } = useTranslation();
return (
<div className="mt-4 flex flex-col gap-y-1">
<label className="block input-label">{t("chat.model.title")}</label>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
{t("chat.model.description")}
</p>
<input
type="text"
name="chatModel"
defaultValue={workspace?.chatModel || ""}
onChange={() => setHasChanges(true)}
className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
placeholder="Enter model name exactly as referenced in the API (e.g., gpt-3.5-turbo)"
/>
</div>
);
}
4 changes: 2 additions & 2 deletions server/utils/AiProviders/azureOpenAi/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ const {
} = require("../../helpers/chat/responses");

class AzureOpenAiLLM {
constructor(embedder = null, _modelPreference = null) {
constructor(embedder = null, modelPreference = null) {
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
if (!process.env.AZURE_OPENAI_ENDPOINT)
throw new Error("No Azure API endpoint was set.");
Expand All @@ -16,7 +16,7 @@ class AzureOpenAiLLM {
process.env.AZURE_OPENAI_ENDPOINT,
new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
);
this.model = process.env.OPEN_MODEL_PREF;
this.model = modelPreference ?? process.env.OPEN_MODEL_PREF;
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,
Expand Down
2 changes: 1 addition & 1 deletion server/utils/AiProviders/bedrock/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class AWSBedrockLLM {
#bedrockClient({ temperature = 0.7 }) {
const { ChatBedrockConverse } = require("@langchain/aws");
return new ChatBedrockConverse({
model: process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE,
model: this.model,
region: process.env.AWS_BEDROCK_LLM_REGION,
credentials: {
accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
Expand Down

0 comments on commit 6674e5a

Please sign in to comment.