Improve error handling, midsentence cutoffs, and max_tokens limit in Gemma2 model wrapper.

PiperOrigin-RevId: 672868968
Change-Id: I5f357745c8a0c006e2ef32fd3cc1db5110d0cda2
jzleibo authored and copybara-github committed Sep 10, 2024
1 parent 4066851 commit b44d6bd
Showing 1 changed file with 67 additions and 22 deletions.
89 changes: 67 additions & 22 deletions concordia/language_model/together_ai.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


"""Language Model that uses Together AI api.
Recommended model name is 'google/gemma-2-9b-it'
@@ -21,13 +20,16 @@
from collections.abc import Collection, Sequence
import concurrent.futures
import os
import time
from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
import numpy as np
import together
from typing_extensions import override

_MAX_MULTIPLE_CHOICE_ATTEMPTS = 20
_MAX_ATTEMPTS = 20
_NUM_SILENT_ATTEMPTS = 3
_MAX_PROMPT_LENGTH = 7500


class Gemma2(language_model.LanguageModel):
@@ -70,9 +72,6 @@ def sample_text(
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
# gpt models do not support `max_tokens` > 4096.
max_tokens = min(max_tokens, 4000)

messages = [
{
'role': 'system',
@@ -98,22 +97,51 @@
{'role': 'user', 'content': prompt},
]

response = self._client.chat.completions.create(
model=self._model_name,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
stop=terminators,
seed=seed,
)
# gemma2 does not support `tokens` + `max_new_tokens` > 8193.
# gemma2 interprets our `max_tokens` as their `max_new_tokens`.
if len(prompt) > _MAX_PROMPT_LENGTH:
print(f'Warning: Truncating prompt of {len(prompt)} tokens to '
f'{_MAX_PROMPT_LENGTH} tokens.')
prompt = prompt[-_MAX_PROMPT_LENGTH:]
max_tokens = min(max_tokens, _MAX_PROMPT_LENGTH - len(prompt))

result = ''
for attempts in range(_MAX_ATTEMPTS):
if attempts > 0:
if attempts >= _NUM_SILENT_ATTEMPTS:
print(
'Sleeping for 10 seconds... '
+ f'attempt: {attempts} / {_MAX_ATTEMPTS}'
)
time.sleep(10)
try:
response = self._client.chat.completions.create(
model=self._model_name,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
stop=terminators,
seed=seed,
)
except together.error.RateLimitError as err:
if attempts >= _NUM_SILENT_ATTEMPTS:
print(f' Exception: {err}')
continue
else:
result = response.choices[0].message.content
break

if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{'raw_text_length': len(response.choices[0].message.content)},
{'raw_text_length': len(result)},
)
return response.choices[0].message.content
# Remove the occasional sentence fragment from the end of the result.
last_stop = result.rfind('.')
if last_stop >= 0:
result = result[: last_stop + 1]
return result
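
The max_tokens cap above keeps prompt length plus completion budget within _MAX_PROMPT_LENGTH, since Gemma 2 rejects requests where the prompt plus max_new_tokens exceeds 8193 tokens; the wrapper uses the Python string length of the prompt as a rough stand-in for its token count. A quick illustration with invented numbers (not part of the commit):

_MAX_PROMPT_LENGTH = 7500

# Hypothetical call: a 3,000-character prompt requesting up to 6,000 new tokens.
prompt = 'x' * 3000
max_tokens = 6000

if len(prompt) > _MAX_PROMPT_LENGTH:
  prompt = prompt[-_MAX_PROMPT_LENGTH:]  # keep only the most recent context
max_tokens = min(max_tokens, _MAX_PROMPT_LENGTH - len(prompt))

print(max_tokens)  # 4500, so prompt plus completion stays within the 7500 budget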

@override
def sample_choice(
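
Both sample_text above and _sample_choice below wrap their Together AI call in the same retry loop: up to _MAX_ATTEMPTS tries, a 10-second sleep between tries, and no console output for the first _NUM_SILENT_ATTEMPTS rate-limit errors. A minimal standalone sketch of that pattern (the helper name and its arguments are illustrative, not part of this commit; the wrapper passes together.error.RateLimitError as the retryable type):

import time

_MAX_ATTEMPTS = 20
_NUM_SILENT_ATTEMPTS = 3


def call_with_retries(make_request, retryable_errors):
  """Calls make_request until it succeeds or _MAX_ATTEMPTS is exhausted."""
  result = None
  for attempts in range(_MAX_ATTEMPTS):
    if attempts > 0:
      if attempts >= _NUM_SILENT_ATTEMPTS:
        print(f'Sleeping for 10 seconds... attempt: {attempts} / {_MAX_ATTEMPTS}')
      time.sleep(10)
    try:
      result = make_request()
    except retryable_errors as err:
      # Stay silent for the first few rate-limit errors, then start logging.
      if attempts >= _NUM_SILENT_ATTEMPTS:
        print(f'  Exception: {err}')
      continue
    else:
      break
  return result

For example, the sample_text call would become roughly call_with_retries(lambda: client.chat.completions.create(...), (together.error.RateLimitError,)).
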
@@ -150,12 +178,29 @@ def _sample_choice(response: str) -> float:
{'role': 'user', 'content': prompt + response},
]

result = self._client.chat.completions.create(
model=self._model_name,
messages=messages,
seed=seed,
logprobs=1,
)
result = ''
for attempts in range(_MAX_ATTEMPTS):
if attempts > 0:
if attempts >= _NUM_SILENT_ATTEMPTS:
print(
'Sleeping for 10 seconds... '
+ f'attempt: {attempts} / {_MAX_ATTEMPTS}'
)
time.sleep(10)
try:
result = self._client.chat.completions.create(
model=self._model_name,
messages=messages,
seed=seed,
logprobs=1,
)
except together.error.RateLimitError as err:
if attempts >= _NUM_SILENT_ATTEMPTS:
print(f' Exception: {err}')
continue
else:
break

lp = sum(result.choices[0].logprobs.token_logprobs)

return lp
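
_sample_choice scores a candidate continuation by requesting logprobs=1 and summing the per-token log-probabilities the API returns for it: summing log-probabilities multiplies probabilities, so a larger sum marks the continuation the model finds more likely. A toy illustration with invented numbers (not from the API):

import math

# Hypothetical per-token log-probabilities for two candidate continuations.
logprobs_yes = [-0.2, -0.1]        # tokens of 'Yes'
logprobs_no = [-1.5, -0.9, -0.3]   # tokens of 'No way'

score_yes = sum(logprobs_yes)  # -0.3
score_no = sum(logprobs_no)    # -2.7

# exp(sum) recovers the probability of the whole continuation.
print(math.exp(score_yes))  # ~0.74
print(math.exp(score_no))   # ~0.07
print(max([('Yes', score_yes), ('No way', score_no)], key=lambda kv: kv[1]))  # ('Yes', -0.3)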
