Skip to content

Commit

Permalink
add a version of agent development colab that uses GCP hosted model
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688940505
Change-Id: I9d1e96b8db9e58d310934413775d7b0d3f04fcb6
  • Loading branch information
vezhnick authored and copybara-github committed Oct 23, 2024
1 parent f6256d7 commit f048a5b
Show file tree
Hide file tree
Showing 2 changed files with 881 additions and 1 deletion.
46 changes: 45 additions & 1 deletion concordia/language_model/google_cloud_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,45 @@
_DEFAULT_MAX_TOKENS = 5000 # Adjust as needed for the specific model


def _wrap_prompt(prompt: str) -> str:
"""Wraps a prompt with the default conditioning.
Args:
prompt: the prompt to wrap
Returns:
the prompt wrapped with the default conditioning
"""
turns = []
turns.append(
"<start_of_turn>system You always continue sentences provided by the "
"user and you never repeat what the user already said.<end_of_turn>"
)
turns.append(
"<start_of_turn>user Question: Is Jake a turtle?\n"
"Answer: Jake is <end_of_turn>"
)
turns.append(
"<start_of_turn>model not a turtle.<end_of_turn>"
)
turns.append(
"<start_of_turn>user Question: What is Priya doing right now?\n"
"Answer: Priya is currently <end_of_turn>"
)
turns.append(
"<start_of_turn>model sleeping.<end_of_turn>"
)
turns.append(
"<start_of_turn>user Question:\n"
+ prompt
+ "<end_of_turn>"
)
turns.append(
"<start_of_turn>model "
)
return "\n".join(turns)


class VertexAI(language_model.LanguageModel):
"""Language Model that uses Google Cloud Vertex AI models.
Expand Down Expand Up @@ -102,7 +141,9 @@ def sample_text(
try:
response = self._client.predict(
endpoint=self._endpoint_name,
instances=[{"inputs": prompt}],
instances=[{
"inputs": _wrap_prompt(prompt)
}],
parameters=self._parameters,
).predictions[0]

Expand Down Expand Up @@ -153,6 +194,9 @@ def sample_choice(
)

sample = self.sample_text(prompt, temperature=temperature, seed=seed)

# clean up the sample from newlines and spaces
sample = sample.replace("\n", "").replace(" ", "")
answer = sampling.extract_choice_response(sample)
try:
idx = responses.index(answer)
Expand Down
Loading

0 comments on commit f048a5b

Please sign in to comment.