diff --git a/concordia/language_model/google_cloud_custom_model.py b/concordia/language_model/google_cloud_custom_model.py
index fcea70f0..e06f942b 100644
--- a/concordia/language_model/google_cloud_custom_model.py
+++ b/concordia/language_model/google_cloud_custom_model.py
@@ -48,9 +48,9 @@ class VertexAI(language_model.LanguageModel):
 
   def __init__(
       self,
-      model_name: str,
+      model_name: str,  # numeric endpoint ID of the deployed model
       *,
-      project: str,
+      project: str,  # project ID (or numeric project number)
       location: str,  # e.g., "us-central1"
       measurements: measurements_lib.Measurements | None = None,
       channel: str = language_model.DEFAULT_STATS_CHANNEL,
@@ -69,7 +69,9 @@ def __init__(
     self._location = location
     self._measurements = measurements
     self._channel = channel
+    # Note: aiplatform.init() does not accept a model_name argument; the
+    # endpoint ID is kept on self._model_name and used in sample_text below.
     aiplatform.init(project=project, location=location)
 
   @override
   # sample_text:
@@ -91,6 +93,16 @@ def sample_text(
       seed: int | None = None,  # Vertex doesn't directly support seed.
   ) -> str:
+    endpoint_name = (
+        "projects/"
+        + self._project
+        + "/locations/"
+        + self._location
+        + "/endpoints/"
+        + self._model_name
+    )
+    api_endpoint = self._location + "-aiplatform.googleapis.com"
+
     max_tokens = min(max_tokens, _DEFAULT_MAX_TOKENS)
 
     result = ""
 
@@ -106,11 +118,16 @@ def sample_text(
         )
         time.sleep(seconds_to_sleep)
 
+      client_options = {"api_endpoint": api_endpoint}
+      client = aiplatform.gapic.PredictionServiceClient(
+          client_options=client_options
+      )
+
       try:
         response = (
-            aiplatform.PredictionServiceClient()
-            .predict(
-                [{"content": prompt}],
+            client.predict(
+                endpoint=endpoint_name,
+                instances=[{"inputs": prompt}],
                 parameters={
                     "temperature": temperature,
                     "max_output_tokens": max_tokens,
@@ -120,9 +137,7 @@ def sample_text(
             )
             .predictions[0]
         )
-        result = response[0][
-            "content"
-        ]  # Adjust based on API response structure
+        result = str(response)  # predictions[0]; coerce to str for the terminator handling below.
 
         # Apply terminators
         for terminator in terminators:
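
--
For reference, a minimal usage sketch of the class this patch modifies, assuming
a custom model already deployed to a Vertex AI endpoint. The endpoint ID and
project values below are hypothetical placeholders; only the module path, class
name, and sample_text signature come from the patch itself.

    from concordia.language_model import google_cloud_custom_model

    # Hypothetical placeholders: a numeric endpoint ID and a project ID.
    model = google_cloud_custom_model.VertexAI(
        model_name="1234567890123456789",
        project="my-gcp-project",
        location="us-central1",
    )

    # sample_text() builds the endpoint resource name
    # "projects/<project>/locations/<location>/endpoints/<endpoint ID>" and
    # sends the prediction request to "<location>-aiplatform.googleapis.com".
    print(model.sample_text("Once upon a time", max_tokens=64))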