Skip to content

Commit

Permalink
chore: test change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713733033
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jan 13, 2025
1 parent f019e18 commit 0421d42
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
5 changes: 5 additions & 0 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def create_client(
client: Instantiated Vertex AI Service client with optional overrides
"""
gapic_version = __version__
print("IN BRANCH CL PRINT -- PARKER")

if appended_gapic_version:
gapic_version = f"{gapic_version}+{appended_gapic_version}"
Expand Down Expand Up @@ -586,6 +587,7 @@ def create_client(
gapic_version=gapic_version,
user_agent=user_agent,
)
print("Branch CL client info: " + client_info)

kwargs = {
"credentials": credentials or self.credentials,
Expand All @@ -598,6 +600,7 @@ def create_client(
),
"client_info": client_info,
}
print("Branch CL kwargs: " + kwargs)

# Do not pass "grpc", rely on gapic defaults unless "rest" is specified
if self._api_transport == "rest" and "Async" in client_class.__name__:
Expand All @@ -622,7 +625,9 @@ def create_client(
client = client_class(**kwargs)
# We only wrap the client if the request_metadata is set at the creation time.
if self._request_metadata:
print("Branch CL wrapping client because request metadata is set")
client = _ClientWrapperThatAddsDefaultMetadata(client)
print("Branch CL returning client: " + str(client))
return client

def _get_default_project_and_location(self) -> Tuple[str, str]:
Expand Down
22 changes: 22 additions & 0 deletions vertexai/prompts/_prompt_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,17 +493,28 @@ def _create_prompt_version_resource(

def _get_prompt_resource(prompt: Prompt, prompt_id: str) -> gca_dataset.Dataset:
    """Helper function to get a prompt resource from a prompt id.

    Fetches the Dataset resource backing the prompt from the Vertex AI
    dataset service, then tags the prompt's dataset client so that
    subsequent requests are attributed to prompt management in the
    user-agent / gapic version metadata.

    Args:
        prompt: The Prompt whose ``_dataset_client`` is used for the lookup.
        prompt_id: The id of the prompt dataset to fetch.

    Returns:
        The ``gca_dataset.Dataset`` resource for the prompt.
    """
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location
    name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
    dataset = prompt._dataset_client.get_dataset(name=name)

    # Tag the shared client exactly once: the naive `+=` here is not
    # idempotent, so repeated calls would yield
    # "...+prompt_management+prompt_management".
    suffix = "+prompt_management"
    client_info = prompt._dataset_client.client_info
    if not client_info.gapic_version.endswith(suffix):
        client_info.gapic_version += suffix

    # NOTE(review): `appended_gapic_version` is assigned dynamically on the
    # client here — confirm the client actually consumes this attribute.
    # Default to "" so a client created without it does not raise
    # AttributeError on first read.
    appended = getattr(prompt._dataset_client, "appended_gapic_version", "")
    if not appended.endswith(suffix):
        prompt._dataset_client.appended_gapic_version = appended + suffix

    return dataset


def _get_prompt_resource_from_version(
prompt: Prompt, prompt_id: str, version_id: str
) -> gca_dataset.Dataset:
"""Helper function to get a prompt resource from a prompt version id."""
print("Branch CL _get_prompt_resource_from_version")
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}/datasetVersions/{version_id}"
Expand All @@ -516,6 +527,14 @@ def _get_prompt_resource_from_version(
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
dataset = prompt._dataset_client.get_dataset(name=name)

print("Setting ._dataset_client.client_info.gapic_version")
prompt._dataset_client.client_info.gapic_version = prompt._dataset_client.client_info.gapic_version + "+prompt_management"
print(str(prompt._dataset_client.client_info.gapic_version))

print("Setting ._dataset_client.appended_gapic_version")
prompt._dataset_client.appended_gapic_version = prompt._dataset_client.appended_gapic_version + "+prompt_management"
print(str(prompt._dataset_client.appended_gapic_version))

# Step 3: Convert to DatasetVersion object to Dataset object
dataset = gca_dataset.Dataset(
name=name,
Expand Down Expand Up @@ -573,19 +592,22 @@ def get(prompt_id: str, version_id: Optional[str] = None) -> Prompt:
"""
prompt = Prompt()
if version_id:
print("Branch CL get prompt resource from version")
dataset = _get_prompt_resource_from_version(
prompt=prompt,
prompt_id=prompt_id,
version_id=version_id,
)
else:
print("Branch CL get prompt resource")
dataset = _get_prompt_resource(prompt=prompt, prompt_id=prompt_id)

# Remove etag to avoid error for repeated dataset updates
dataset.etag = None

prompt._dataset = dataset
prompt._version_id = version_id
prompt._used_prompt_management = True

dataset_dict = _proto_to_dict(dataset)

Expand Down
7 changes: 7 additions & 0 deletions vertexai/prompts/_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
self._prompt_name = None
self._version_id = None
self._version_name = None
self._used_prompt_management = None

self.prompt_data = prompt_data
self.variables = variables if variables else [{}]
Expand Down Expand Up @@ -610,6 +611,12 @@ def generate_content(
model = GenerativeModel(
model_name=model_name, system_instruction=system_instruction
)

if self._used_prompt_management:
# Want to update `appended_gapic_version` field here with the
# boolean value...
print("Branch CL generate_content AFTER _used_prompt_management")

return model.generate_content(
contents=contents,
generation_config=generation_config,
Expand Down

0 comments on commit 0421d42

Please sign in to comment.