Commit 6b37263 (parent 9e6e240). Showing 16 changed files with 708 additions and 6 deletions.
Dockerfile (new file, +78 lines):

```dockerfile
# syntax=docker/dockerfile:1.3

FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-devel

# Default environment variables
ENV PYTHONFAULTHANDLER=1
ENV PYTHONUNBUFFERED=1
ENV PYTHONHASHSEED=random

# Build-time environment variables
ARG PIP_NO_CACHE_DIR=off
ARG PIP_DISABLE_PIP_VERSION_CHECK=on
ARG PIP_DEFAULT_TIMEOUT=100
ARG DEBIAN_FRONTEND=noninteractive
ARG DEBCONF_NONINTERACTIVE_SEEN=true

# Install Debian packages
RUN apt-get -yq update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
    build-essential \
    apt-utils \
    dumb-init \
    git \
    gcc \
    ssh \
    htop \
    iftop \
    vim \
    apt-transport-https \
    ca-certificates \
    gnupg \
    curl \
    zlib1g-dev \
    libjpeg-dev \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libgtk2.0-dev \
    libssl-dev \
    libbz2-dev \
    libreadline-dev \
    libsqlite3-dev \
    wget \
    llvm \
    libncurses5-dev \
    libncursesw5-dev \
    xz-utils \
    tk-dev \
    libffi-dev \
    liblzma-dev \
    python3-openssl \
    libcurl4-openssl-dev \
    python3-dev \
    && apt-get autoremove -y \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

# Install s5cmd for fast reads from S3
RUN conda install -y -c conda-forge s5cmd
RUN pip3 install --upgrade pip --no-cache-dir

# Cache install: copy the repo and install it plus pinned dependencies
COPY . /workspace/research/hugh/lm-evaluation-harness
WORKDIR /workspace/research/hugh/lm-evaluation-harness
RUN pip3 install -e .
RUN pip3 install -r requirements_for_docker.txt
RUN pip3 install transformers==4.41.2
RUN pip3 install flash-attn==2.5.8 --no-build-isolation

# Set FORCE_CUDA because CUDA is not accessible during `docker build`
ENV FORCE_CUDA="1"
ENV DD_SERVICE=lm-evaluation-harness
ENV AWS_DEFAULT_REGION=us-west-2

WORKDIR /workspace/
```
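A quick way to sanity-check this image is to build it and confirm the GPU is visible from inside a container. This is a hedged sketch, not part of the commit: the tag `lm-eval:dev` is made up, and it assumes a GPU host with the NVIDIA container runtime installed.

```python
# Hypothetical build-and-smoke-test driver; the image tag is illustrative.
import subprocess

subprocess.run(["docker", "build", "-t", "lm-eval:dev", "."], check=True)

# FORCE_CUDA=1 only matters at build time; at runtime the GPU must actually
# be visible, which this one-liner verifies.
subprocess.run(
    ["docker", "run", "--rm", "--gpus", "all", "lm-eval:dev",
     "python", "-c", "import torch; print(torch.cuda.is_available())"],
    check=True,
)
```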
LiteLLM API model list (new file, +16 lines):

```text
LITELLM/gpt-4
LITELLM/gpt-4-turbo
LITELLM/gpt-4o
LITELLM/gemini-1.5-pro-preview-0409
LITELLM/gemini-1.5-pro-preview-0514
LITELLM/gemini-1.5-flash-preview-0514
LITELLM/gemini-pro
LITELLM/mistral-large-latest
LITELLM/command
LITELLM/claude-3-opus-20240229
LITELLM/gpt-3.5-turbo
LITELLM/claude-2.1
LITELLM/claude-3-haiku-20240307
LITELLM/claude-3-sonnet-20240229
LITELLM/mistral-small-latest
LITELLM/mistral-medium-latest
```
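Each entry carries a `LITELLM/` prefix that presumably routes it to the new backend. A minimal sketch of consuming one entry through the harness's `simple_evaluate` API; the task choice is an assumption, not part of the commit:

```python
# Hypothetical: evaluate one entry from the list above via the new backend.
import lm_eval

entry = "LITELLM/gpt-3.5-turbo"
model_name = entry.split("/", 1)[1]  # strip the "LITELLM/" routing prefix

results = lm_eval.simple_evaluate(
    model="litellm",                    # name registered by this commit
    model_args=f"model={model_name}",
    tasks=["gsm8k"],                    # assumed task, not from the commit
)
print(results["results"])
```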
HuggingFace model list (new file, +60 lines):

```text
microsoft/Phi-3-mini-4k-instruct
microsoft/Phi-3-mini-128k-instruct
microsoft/Phi-3-small-8k-instruct
microsoft/Phi-3-small-128k-instruct
microsoft/Phi-3-medium-4k-instruct
microsoft/Phi-3-medium-128k-instruct
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-70B-Instruct
meta-llama/Meta-Llama-3-70B
mistralai/Codestral-22B-v0.1
mistralai/Mixtral-8x22B-v0.1
mistralai/Mixtral-8x22B-Instruct-v0.1
01-ai/Yi-34B-Chat
01-ai/Yi-6B-Chat
codellama/CodeLlama-13b-hf
codellama/CodeLlama-13b-Instruct-hf
codellama/CodeLlama-13b-Python-hf
codellama/CodeLlama-34b-hf
codellama/CodeLlama-34b-Instruct-hf
codellama/CodeLlama-34b-Python-hf
codellama/CodeLlama-70b-hf
codellama/CodeLlama-70b-Instruct-hf
codellama/CodeLlama-70b-Python-hf
codellama/CodeLlama-7b-hf
codellama/CodeLlama-7b-Instruct-hf
codellama/CodeLlama-7b-Python-hf
databricks/dbrx-base
databricks/dbrx-instruct
deepseek-ai/deepseek-coder-33b-instruct
deepseek-ai/deepseek-llm-67b-base
EleutherAI/gpt-neox-20b
EleutherAI/pythia-12b
EleutherAI/llemma_34b
EleutherAI/llemma_7b
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
lmsys/vicuna-33b-v1.3
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-7b-hf
microsoft/phi-1_5
microsoft/phi-2
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
mistralai/Mistral-7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mixtral-8x7B-v0.1
openai-community/gpt2-xl
Phind/Phind-CodeLlama-34B-v2
tiiuae/falcon-180B-chat
Xwin-LM/Xwin-Math-13B-V1.0
Xwin-LM/Xwin-Math-70B-V1.0
Xwin-LM/Xwin-Math-7B-V1.0
peiyi9979/math-shepherd-mistral-7b-rl
deepseek-ai/deepseek-math-7b-rl
abacusai/Smaug-2-72B
abacusai/Smaug-34B-v0.1
```
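These entries would instead go through the harness's existing `hf` backend. A hedged sweep sketch; the local file name `hf_models.txt`, the task, and the use of the `lm_eval` console script are assumptions (large models would also need extra `--model_args` for dtype and parallelism, omitted here):

```python
# Hypothetical sweep over the list above, saved locally as "hf_models.txt".
import subprocess

with open("hf_models.txt") as f:
    models = [line.strip() for line in f if line.strip()]

for model in models:
    subprocess.run(
        ["lm_eval",
         "--model", "hf",
         "--model_args", f"pretrained={model}",
         "--tasks", "gsm8k"],  # assumed task
        check=True,
    )
```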
Models package `__init__` (modified, +1 line) — the new module joins the backend imports:

```diff
     optimum_lm,
     textsynth,
     vllm_causallms,
+    litellm_completions,
 )
```
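Adding the module to this import list is what makes the backend discoverable: importing `lm_eval.models` executes the new module, which runs the `@register_model("litellm")` decorator defined below and enters the class into the model registry. A sketch of resolving it by name; the `model=gpt-3.5-turbo` argument mirrors the class default and is only illustrative:

```python
# Minimal registry-lookup sketch.
import lm_eval.models  # import side effect: runs @register_model("litellm")
from lm_eval.api.registry import get_model

lm = get_model("litellm").create_from_arg_string("model=gpt-3.5-turbo")
```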
`litellm_completions` module (new file, +155 lines):

```python
# LiteLLM API engine, designed to largely mimic the OpenAIChatCompletions format

import copy
import time
from collections import defaultdict
from typing import List

from tqdm import tqdm

import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import eval_logger

import litellm

litellm.vertex_project = "sacred-vigil-412721"  # TODO: replace with your own project name
litellm.vertex_location = "us-central1"  # project location

@register_model("litellm")
class LiteLLM(LM):
    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        truncate: bool = False,
        **kwargs,
    ) -> None:
        """Implements an OpenAI-style chat completion API for models routed
        through LiteLLM.

        :param model: str
            LiteLLM model identifier (e.g. gpt-3.5-turbo), called with the
            **gen_kwargs passed on init
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        self.model = model
        self.truncate = truncate

    @property
    def max_length(self) -> int:
        # Assumed context window; actual limits vary across the models LiteLLM serves
        return 4096

    @property
    def max_gen_toks(self) -> int:
        return 4096

    @property
    def batch_size(self):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def generate_until(self, requests) -> List[str]:
        res = defaultdict(list)
        re_ords = {}

        # We group requests by their generation_kwargs, so that we don't try
        # to execute e.g. greedy sampling and temp=0.8 sampling in the same batch.
        grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # Within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer(
                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
            )

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        for key, re_ord in re_ords.items():
            # n needs to be 1 because messages in chat completion are not
            # batched but regarded as a single conversation.
            chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
            for chunk in chunks:
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]

                gen_kwargs = all_gen_kwargs[0]
                until = None
                if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):
                    if "do_sample" in kwargs.keys():
                        kwargs.pop("do_sample")
                    if "until" in kwargs.keys():
                        until = kwargs.pop("until")
                        if isinstance(until, str):
                            until = [until]
                        elif not isinstance(until, list):
                            raise ValueError(
                                f"Expected `kwargs['until']` to be of type Union[str, list] but got {until}"
                            )
                    kwargs["stop"] = until
                    kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)
                else:
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {kwargs}"
                    )

if "mistral" in self.model or "gemini" in self.model: | ||
# Mistral doesn't support the stop parameter, so remove | ||
kwargs.pop("stop", None) | ||
|
||
if "gemini" in self.model: | ||
# Sleep between requests for rate limit | ||
time.sleep(3) | ||
|
||
try: | ||
response = litellm.completion(num_retries=10, | ||
messages=inps, | ||
model=self.model, | ||
**kwargs, | ||
) | ||
|
||
for resp, (context, args_) in zip(response.choices, chunk): | ||
s = resp.message.content | ||
|
||
if until is not None: | ||
for term in until: | ||
if len(term) > 0: | ||
s = s.split(term)[0] | ||
|
||
res[key].append(s) | ||
|
||
self.cache_hook.add_partial( | ||
"generate_until", (context, {"until": until}), s | ||
) | ||
pbar.update(1) | ||
|
||
except Exception as e: | ||
print("Error in generate_until: {}".format(str(e))) | ||
res[key].append("0") | ||
|
||
# reorder this group of results back to original unsorted form | ||
res[key] = re_ord.get_original(res[key]) | ||
|
||
pbar.close() | ||
|
||
return grouper.get_original(res) | ||
|
||
def loglikelihood(self, requests): | ||
raise NotImplementedError("No support for logits.") | ||
|
||
def loglikelihood_rolling(self, requests): | ||
raise NotImplementedError("No support for logits.") |
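Under the hood, each request reduces to a single `litellm.completion` call with OpenAI-style messages. A standalone sketch of that call shape, assuming the relevant provider key (e.g. `OPENAI_API_KEY`) is set in the environment; the prompt and parameter values are illustrative:

```python
import litellm

# One chat turn, mirroring what generate_until sends per request.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Q: What is 2 + 2?\nA:"}],
    max_tokens=16,
    stop=["\n"],      # dropped for Mistral/Gemini, as in the class above
    num_retries=3,
)
print(response.choices[0].message.content)
```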