Skip to content

Commit

Permalink
provider fix (#3187)
Browse files Browse the repository at this point in the history
* clean horizontal scrollbar

* provider fix

* ensure proper migration

* k

* update migration

* Revert "clean horizontal scrollbar"

This reverts commit fa592a1.
  • Loading branch information
pablonyx authored Nov 21, 2024
1 parent 366aa2a commit bd9f158
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
59 changes: 59 additions & 0 deletions backend/alembic/versions/177de57c21c9_display_custom_llm_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_

revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
depends_on = None


def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)

excluded_providers = ["openai", "bedrock", "anthropic", "azure"]

providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)

results = conn.execute(providers_to_update).fetchall()

for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []

combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)


def downgrade() -> None:
pass
2 changes: 1 addition & 1 deletion backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,7 @@ class LLMProvider(Base):
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)

# Models to actually disp;aly to users
# Models to actually display to users
# If nulled out, we assume in the application logic we should present all
display_model_names: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ export function CustomLLMProviderUpdateForm({
},
body: JSON.stringify({
...values,
// For custom llm providers, all model names are displayed
display_model_names: values.model_names,
custom_config: customConfigProcessing(values.custom_config_list),
}),
});
Expand Down

0 comments on commit bd9f158

Please sign in to comment.