diff --git a/concordia/language_model/together_ai.py b/concordia/language_model/together_ai.py
index 6edd0e27..5ec9f39b 100644
--- a/concordia/language_model/together_ai.py
+++ b/concordia/language_model/together_ai.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 """Language Model that uses Together AI api.
 
 Recommended model name is 'google/gemma-2-9b-it'
@@ -21,13 +20,16 @@
 from collections.abc import Collection, Sequence
 import concurrent.futures
 import os
+import time
 
 from concordia.language_model import language_model
 from concordia.utils import measurements as measurements_lib
 import numpy as np
 import together
 from typing_extensions import override
 
-_MAX_MULTIPLE_CHOICE_ATTEMPTS = 20
+_MAX_ATTEMPTS = 20
+_NUM_SILENT_ATTEMPTS = 3
+_MAX_PROMPT_LENGTH = 7500
 
 class Gemma2(language_model.LanguageModel):
@@ -70,9 +72,6 @@ def sample_text(
       timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
       seed: int | None = None,
   ) -> str:
-    # gpt models do not support `max_tokens` > 4096.
-    max_tokens = min(max_tokens, 4000)
-
     messages = [
         {
            'role': 'system',
@@ -98,22 +97,51 @@ def sample_text(
         {'role': 'user', 'content': prompt},
     ]
 
-    response = self._client.chat.completions.create(
-        model=self._model_name,
-        messages=messages,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        timeout=timeout,
-        stop=terminators,
-        seed=seed,
-    )
+    # gemma2 does not support `tokens` + `max_new_tokens` > 8193.
+    # gemma2 interprets our `max_tokens`` as their `max_new_tokens`.
+    if len(prompt) > _MAX_PROMPT_LENGTH:
+      print(f'Warning: Truncating prompt of {len(prompt)} tokens to '
+            f'{_MAX_PROMPT_LENGTH} tokens.')
+      prompt = prompt[-_MAX_PROMPT_LENGTH:]
+    max_tokens = min(max_tokens, _MAX_PROMPT_LENGTH - len(prompt))
+
+    result = ''
+    for attempts in range(_MAX_ATTEMPTS):
+      if attempts > 0:
+        if attempts >= _NUM_SILENT_ATTEMPTS:
+          print(
+              'Sleeping for 10 seconds... '
+              + f'attempt: {attempts} / {_MAX_ATTEMPTS}'
+          )
+        time.sleep(10)
+      try:
+        response = self._client.chat.completions.create(
+            model=self._model_name,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            timeout=timeout,
+            stop=terminators,
+            seed=seed,
+        )
+      except together.error.RateLimitError as err:
+        if attempts >= _NUM_SILENT_ATTEMPTS:
+          print(f'  Exception: {err}')
+        continue
+      else:
+        result = response.choices[0].message.content
+        break
 
     if self._measurements is not None:
       self._measurements.publish_datum(
           self._channel,
-          {'raw_text_length': len(response.choices[0].message.content)},
+          {'raw_text_length': len(result)},
       )
-    return response.choices[0].message.content
+    # Remove the occasional sentence fragment from the end of the result.
+    last_stop = result.rfind('.')
+    if last_stop >= 0:
+      result = result[: last_stop + 1]
+    return result
 
   @override
   def sample_choice(
@@ -150,12 +178,29 @@ def _sample_choice(response: str) -> float:
           {'role': 'user', 'content': prompt + response},
       ]
 
-      result = self._client.chat.completions.create(
-          model=self._model_name,
-          messages=messages,
-          seed=seed,
-          logprobs=1,
-      )
+      result = ''
+      for attempts in range(_MAX_ATTEMPTS):
+        if attempts > 0:
+          if attempts >= _NUM_SILENT_ATTEMPTS:
+            print(
+                'Sleeping for 10 seconds... '
+                + f'attempt: {attempts} / {_MAX_ATTEMPTS}'
+            )
+          time.sleep(10)
+        try:
+          result = self._client.chat.completions.create(
+              model=self._model_name,
+              messages=messages,
+              seed=seed,
+              logprobs=1,
+          )
+        except together.error.RateLimitError as err:
+          if attempts >= _NUM_SILENT_ATTEMPTS:
+            print(f'  Exception: {err}')
+          continue
+        else:
+          break
+
      lp = sum(result.choices[0].logprobs.token_logprobs)
      return lp
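Note (not part of the patch): the same retry-on-rate-limit pattern is added twice above, once in sample_text and once in _sample_choice. The sketch below isolates that pattern so the control flow is easier to follow. The helper name retry_on_rate_limit, the zero-argument request callable, and the stand-in RateLimitError class are illustrative only; the real code catches together.error.RateLimitError around self._client.chat.completions.create(...).

# Minimal, self-contained sketch of the retry loop introduced by this diff.
import time

_MAX_ATTEMPTS = 20        # Give up after this many tries (as in the patch).
_NUM_SILENT_ATTEMPTS = 3  # Only start logging after a few silent retries.


class RateLimitError(Exception):
  """Stand-in for together.error.RateLimitError."""


def retry_on_rate_limit(request):
  """Calls `request()`, sleeping 10 seconds between rate-limited attempts."""
  for attempts in range(_MAX_ATTEMPTS):
    if attempts > 0:
      if attempts >= _NUM_SILENT_ATTEMPTS:
        print(f'Sleeping for 10 seconds... attempt: {attempts} / {_MAX_ATTEMPTS}')
      time.sleep(10)
    try:
      return request()
    except RateLimitError as err:
      if attempts >= _NUM_SILENT_ATTEMPTS:
        print(f'  Exception: {err}')
  return None  # Every attempt was rate limited; sample_text in the patch falls back to ''.

In the patch the request callable corresponds to the completions.create(...) call, and sample_text additionally leaves result as the empty string if no attempt succeeds.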