Skip to content

Commit

Permalink
Updated code to work with the latest Vertex AI API's prediction service
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687451385
Change-Id: Ia9e66b83cccccf4f3edca571e90aa9ca4e389860
  • Loading branch information
minsukchang authored and copybara-github committed Oct 18, 2024
1 parent b6a2d2c commit 39a390d
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions concordia/language_model/google_cloud_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ class VertexAI(language_model.LanguageModel):

def __init__(
self,
model_name: str,
model_name: str, # endpoint ID, all numbers
*,
project: str,
project: str, # project ID, all numbers
location: str, # e.g., "us-central1"
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
Expand All @@ -69,7 +69,7 @@ def __init__(
self._location = location
self._measurements = measurements
self._channel = channel
aiplatform.init(project=project, location=location)
aiplatform.init(model_name=model_name, project=project, location=location)

@override
# sample_text:
Expand All @@ -91,6 +91,16 @@ def sample_text(
seed: int | None = None, # Vertex doesn't directly support seed.
) -> str:

endpoint_name = (
"projects/"
+ self._project
+ "locations/"
+ self._location
+ "/endpoints/"
+ self._model_name
)
api_endpoint = self._location + "-aiplatform.googleapis.com"

max_tokens = min(max_tokens, _DEFAULT_MAX_TOKENS)

result = ""
Expand All @@ -106,11 +116,16 @@ def sample_text(
)
time.sleep(seconds_to_sleep)

client_options = {"api_endpoint": api_endpoint}
client = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
)

try:
response = (
aiplatform.PredictionServiceClient()
.predict(
[{"content": prompt}],
client().predict(
endpoint=endpoint_name,
instances=[{"inputs": prompt}],
parameters={
"temperature": temperature,
"max_output_tokens": max_tokens,
Expand All @@ -120,9 +135,7 @@ def sample_text(
)
.predictions[0]
)
result = response[0][
"content"
] # Adjust based on API response structure
result = response

# Apply terminators
for terminator in terminators:
Expand Down

0 comments on commit 39a390d

Please sign in to comment.