Skip to content

Commit

Permalink
feat: Refactor rag_store and rag_retrieval to use v1 protos
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700410303
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 26, 2024
1 parent 6d43ad3 commit dfe6d6c
Show file tree
Hide file tree
Showing 12 changed files with 399 additions and 311 deletions.
12 changes: 12 additions & 0 deletions google/cloud/aiplatform/compat/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@
from google.cloud.aiplatform_v1.services.vizier_service import (
client as vizier_service_client_v1,
)
from google.cloud.aiplatform_v1.services.vertex_rag_service import (
client as vertex_rag_service_client_v1,
)
from google.cloud.aiplatform_v1.services.vertex_rag_data_service import (
client as vertex_rag_data_service_client_v1,
)
from google.cloud.aiplatform_v1.services.vertex_rag_data_service import (
async_client as vertex_rag_data_service_async_client_v1,
)

__all__ = (
# v1
Expand All @@ -204,6 +213,9 @@
specialist_pool_service_client_v1,
tensorboard_service_client_v1,
vizier_service_client_v1,
vertex_rag_data_service_async_client_v1,
vertex_rag_data_service_client_v1,
vertex_rag_service_client_v1,
# v1beta1
dataset_service_client_v1beta1,
deployment_resource_pool_service_client_v1beta1,
Expand Down
21 changes: 14 additions & 7 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@
tensorboard_service_client_v1,
vizier_service_client_v1,
persistent_resource_service_client_v1,
vertex_rag_data_service_async_client_v1,
vertex_rag_data_service_client_v1,
vertex_rag_service_client_v1,
)

from google.cloud.aiplatform.compat.types import (
Expand Down Expand Up @@ -138,6 +141,7 @@
schedule_service_client_v1.ScheduleServiceClient,
tensorboard_service_client_v1.TensorboardServiceClient,
vizier_service_client_v1.VizierServiceClient,
vertex_rag_service_client_v1.VertexRagServiceClient,
)


Expand Down Expand Up @@ -967,8 +971,9 @@ class ReasoningEngineExecutionClientWithOverride(ClientWithOverride):

class VertexRagDataClientWithOverride(ClientWithOverride):
_is_temporary = True
_default_version = compat.V1BETA1
_default_version = compat.DEFAULT_VERSION
_version_map = (
(compat.V1, vertex_rag_data_service_client_v1.VertexRagDataServiceClient),
(
compat.V1BETA1,
vertex_rag_data_service_client_v1beta1.VertexRagDataServiceClient,
Expand All @@ -978,8 +983,12 @@ class VertexRagDataClientWithOverride(ClientWithOverride):

class VertexRagDataAsyncClientWithOverride(ClientWithOverride):
_is_temporary = True
_default_version = compat.V1BETA1
_default_version = compat.DEFAULT_VERSION
_version_map = (
(
compat.V1,
vertex_rag_data_service_async_client_v1.VertexRagDataServiceAsyncClient,
),
(
compat.V1BETA1,
vertex_rag_data_service_async_client_v1beta1.VertexRagDataServiceAsyncClient,
Expand All @@ -989,12 +998,10 @@ class VertexRagDataAsyncClientWithOverride(ClientWithOverride):

class VertexRagClientWithOverride(ClientWithOverride):
_is_temporary = True
_default_version = compat.V1BETA1
_default_version = compat.DEFAULT_VERSION
_version_map = (
(
compat.V1BETA1,
vertex_rag_service_client_v1beta1.VertexRagServiceClient,
),
(compat.V1, vertex_rag_service_client_v1.VertexRagServiceClient),
(compat.V1BETA1, vertex_rag_service_client_v1beta1.VertexRagServiceClient),
)


Expand Down
13 changes: 11 additions & 2 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

from google.cloud import aiplatform

from vertexai.preview.rag import (
from vertexai.rag import (
EmbeddingModelConfig,
Filter,
Pinecone,
RagCorpus,
RagFile,
RagResource,
RagRetrievalConfig,
SharePointSource,
SharePointSources,
SlackChannelsSource,
Expand All @@ -34,6 +36,7 @@
VertexVectorSearch,
VertexFeatureStore,
)

from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
RagFileChunkingConfig,
Expand All @@ -46,9 +49,11 @@
RagFile as GapicRagFile,
SharePointSources as GapicSharePointSources,
SlackSource as GapicSlackSource,
RagVectorDbConfig,
)
from google.cloud.aiplatform_v1 import (
RagContexts,
RetrieveContextsResponse,
RagVectorDbConfig,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
from google.protobuf import timestamp_pb2
Expand Down Expand Up @@ -529,3 +534,7 @@
rag_corpus="213lkj-1/23jkl/",
rag_file_ids=[TEST_RAG_FILE_ID],
)
TEST_RAG_RETRIEVAL_CONFIG = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_distance_threshold=0.5),
)
Loading

0 comments on commit dfe6d6c

Please sign in to comment.