Skip to content

Commit

Permalink
chore: Add RagManagedDb to RAG corpus creation as well as default db …
Browse files Browse the repository at this point in the history
…config when no vector_db specified

PiperOrigin-RevId: 678343933
  • Loading branch information
speedstorm1 authored and copybara-github committed Sep 24, 2024
1 parent c669d33 commit 6e5cbfd
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 11 deletions.
2 changes: 2 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Pinecone,
RagCorpus,
RagFile,
RagManagedDb,
RagResource,
SharePointSource,
SharePointSources,
Expand All @@ -61,6 +62,7 @@
"Pinecone",
"RagCorpus",
"RagFile",
"RagManagedDb",
"RagResource",
"Retrieval",
"SharePointSource",
Expand Down
12 changes: 6 additions & 6 deletions vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
Pinecone,
RagCorpus,
RagFile,
RagManagedDb,
SharePointSources,
SlackChannelsSource,
VertexFeatureStore,
Expand All @@ -61,7 +62,7 @@ def create_corpus(
description: Optional[str] = None,
embedding_model_config: Optional[EmbeddingModelConfig] = None,
vector_db: Optional[
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone]
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]
] = None,
) -> RagCorpus:
"""Creates a new RagCorpus resource.
Expand Down Expand Up @@ -102,11 +103,10 @@ def create_corpus(
embedding_model_config=embedding_model_config,
rag_corpus=rag_corpus,
)
if vector_db is not None:
_gapic_utils.set_vector_db(
vector_db=vector_db,
rag_corpus=rag_corpus,
)
_gapic_utils.set_vector_db(
vector_db=vector_db,
rag_corpus=rag_corpus,
)

request = CreateRagCorpusRequest(
parent=parent,
Expand Down
26 changes: 21 additions & 5 deletions vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Pinecone,
RagCorpus,
RagFile,
RagManagedDb,
SharePointSources,
SlackChannelsSource,
JiraSource,
Expand Down Expand Up @@ -107,6 +108,13 @@ def _check_weaviate(gapic_vector_db: RagVectorDbConfig) -> bool:
return gapic_vector_db.weaviate.ByteSize() > 0


def _check_rag_managed_db(gapic_vector_db: RagVectorDbConfig) -> bool:
try:
return gapic_vector_db.__contains__("rag_managed_db")
except AttributeError:
return gapic_vector_db.rag_managed_db.ByteSize() > 0


def _check_vertex_feature_store(gapic_vector_db: RagVectorDbConfig) -> bool:
try:
return gapic_vector_db.__contains__("vertex_feature_store")
Expand All @@ -130,8 +138,8 @@ def _check_vertex_vector_search(gapic_vector_db: RagVectorDbConfig) -> bool:

def convert_gapic_to_vector_db(
gapic_vector_db: RagVectorDbConfig,
) -> Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone]:
"""Convert Gapic RagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, or Pinecone."""
) -> Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]:
"""Convert Gapic RagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone."""
if _check_weaviate(gapic_vector_db):
return Weaviate(
weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint,
Expand All @@ -152,6 +160,8 @@ def convert_gapic_to_vector_db(
index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
index=gapic_vector_db.vertex_vector_search.index,
)
elif _check_rag_managed_db(gapic_vector_db):
return RagManagedDb()
else:
return None

Expand Down Expand Up @@ -499,11 +509,17 @@ def set_embedding_model_config(


def set_vector_db(
vector_db: Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone],
vector_db: Union[
Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb, None
],
rag_corpus: GapicRagCorpus,
) -> None:
"""Sets the vector db configuration for the rag corpus."""
if isinstance(vector_db, Weaviate):
if vector_db is None or isinstance(vector_db, RagManagedDb):
rag_corpus.rag_vector_db_config = RagVectorDbConfig(
rag_managed_db=RagVectorDbConfig.RagManagedDb(),
)
elif isinstance(vector_db, Weaviate):
http_endpoint = vector_db.weaviate_http_endpoint
collection_name = vector_db.collection_name
api_key = vector_db.api_key
Expand Down Expand Up @@ -553,5 +569,5 @@ def set_vector_db(
)
else:
raise TypeError(
"vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, or Pinecone."
"vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone."
)
5 changes: 5 additions & 0 deletions vertexai/preview/rag/utils/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ class VertexVectorSearch:
index: str


@dataclasses.dataclass
class RagManagedDb:
"""RagManagedDb."""


@dataclasses.dataclass
class Pinecone:
"""Pinecone.
Expand Down

0 comments on commit 6e5cbfd

Please sign in to comment.