Commit: GSM1k evaluation code
hughbzhang committed Jun 13, 2024
1 parent 9e6e240 commit 6b37263
Showing 16 changed files with 708 additions and 6 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -21,3 +21,9 @@ lm_eval/caching/.cache
# don't track files created by wandb
wandb
examples/wandb
results/
plots/

st_config/.*

.DS_Store
78 changes: 78 additions & 0 deletions Dockerfile
@@ -0,0 +1,78 @@
# syntax=docker/dockerfile:1.3

FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-devel

# Default Environment Variables
ENV PYTHONFAULTHANDLER=1
ENV PYTHONUNBUFFERED=1
ENV PYTHONHASHSEED=random

# Build-time Environment Variables
ARG PIP_NO_CACHE_DIR=off
ARG PIP_DISABLE_PIP_VERSION_CHECK=on
ARG PIP_DEFAULT_TIMEOUT=100
ARG DEBIAN_FRONTEND=noninteractive
ARG DEBCONF_NONINTERACTIVE_SEEN=true

# Install Debian packages
RUN apt-get -yq update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
apt-utils \
dumb-init \
git \
gcc \
ssh \
htop \
iftop \
vim \
apt-transport-https \
ca-certificates \
gnupg \
curl \
zlib1g-dev \
libjpeg-dev \
libsm6 \
libxext6 \
libxrender-dev \
libgl1-mesa-glx \
libglib2.0-0 \
libgtk2.0-dev \
libssl-dev \
libbz2-dev \
libreadline-dev \
libsqlite3-dev \
wget \
llvm \
libncurses5-dev \
libncursesw5-dev \
xz-utils \
tk-dev \
libffi-dev \
liblzma-dev \
python3-openssl \
libcurl4-openssl-dev \
libssl-dev \
python3-dev \
gcc \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean

# Install s5cmd for fast reads from s3
RUN conda install -y -c conda-forge s5cmd
RUN pip3 install --upgrade pip --no-cache-dir

# CACHE INSTALL
COPY . /workspace/research/hugh/lm-evaluation-harness
WORKDIR /workspace/research/hugh/lm-evaluation-harness
RUN pip3 install -e .
RUN pip3 install -r requirements_for_docker.txt
RUN pip3 install transformers==4.41.2
RUN pip3 install flash-attn==2.5.8 --no-build-isolation

# set FORCE_CUDA because during `docker build` cuda is not accessible
ENV FORCE_CUDA="1"
ENV DD_SERVICE lm-evaluation-harness
ENV AWS_DEFAULT_REGION=us-west-2

WORKDIR /workspace/
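
As a quick sanity check (not part of the commit), one might run a few lines of Python inside the built container to confirm that the pinned libraries resolved as expected; the exact checks below are illustrative.

# Illustrative sanity check, assumed usage only: run inside the built image.
import torch
import transformers

print(torch.__version__)          # base image ships PyTorch 2.1.2 with CUDA 12.1
print(transformers.__version__)   # pinned above to 4.41.2
print(torch.cuda.is_available())  # False during `docker build`; True only at runtime with a GPU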
16 changes: 16 additions & 0 deletions all_api_models.txt
@@ -0,0 +1,16 @@
LITELLM/gpt-4
LITELLM/gpt-4-turbo
LITELLM/gpt-4o
LITELLM/gemini-1.5-pro-preview-0409
LITELLM/gemini-1.5-pro-preview-0514
LITELLM/gemini-1.5-flash-preview-0514
LITELLM/gemini-pro
LITELLM/mistral-large-latest
LITELLM/command
LITELLM/claude-3-opus-20240229
LITELLM/gpt-3.5-turbo
LITELLM/claude-2.1
LITELLM/claude-3-haiku-20240307
LITELLM/claude-3-sonnet-20240229
LITELLM/mistral-small-latest
LITELLM/mistral-medium-latest
60 changes: 60 additions & 0 deletions all_models_to_evaluate.txt
@@ -0,0 +1,60 @@
microsoft/Phi-3-mini-4k-instruct
microsoft/Phi-3-mini-128k-instruct
microsoft/Phi-3-small-8k-instruct
microsoft/Phi-3-small-128k-instruct
microsoft/Phi-3-medium-4k-instruct
microsoft/Phi-3-medium-128k-instruct
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-70B-Instruct
meta-llama/Meta-Llama-3-70B
mistralai/Codestral-22B-v0.1
mistralai/Mixtral-8x22B-v0.1
mistralai/Mixtral-8x22B-Instruct-v0.1
01-ai/Yi-34B-Chat
01-ai/Yi-6B-Chat
codellama/CodeLlama-13b-hf
codellama/CodeLlama-13b-Instruct-hf
codellama/CodeLlama-13b-Python-hf
codellama/CodeLlama-34b-hf
codellama/CodeLlama-34b-Instruct-hf
codellama/CodeLlama-34b-Python-hf
codellama/CodeLlama-70b-hf
codellama/CodeLlama-70b-Instruct-hf
codellama/CodeLlama-70b-Python-hf
codellama/CodeLlama-7b-hf
codellama/CodeLlama-7b-Instruct-hf
codellama/CodeLlama-7b-Python-hf
databricks/dbrx-base
databricks/dbrx-instruct
deepseek-ai/deepseek-coder-33b-instruct
deepseek-ai/deepseek-llm-67b-base
EleutherAI/gpt-neox-20b
EleutherAI/pythia-12b
EleutherAI/llemma_34b
EleutherAI/llemma_7b
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
lmsys/vicuna-33b-v1.3
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-7b-hf
microsoft/phi-1_5
microsoft/phi-2
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
mistralai/Mistral-7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mixtral-8x7B-v0.1
openai-community/gpt2-xl
Phind/Phind-CodeLlama-34B-v2
tiiuae/falcon-180B-chat
Xwin-LM/Xwin-Math-13B-V1.0
Xwin-LM/Xwin-Math-70B-V1.0
Xwin-LM/Xwin-Math-7B-V1.0
peiyi9979/math-shepherd-mistral-7b-rl
deepseek-ai/deepseek-math-7b-rl
abacusai/Smaug-2-72B
abacusai/Smaug-34B-v0.1
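
The two lists above are plain-text inventories: all_api_models.txt holds API-served models (the LITELLM/ prefix appears to be a routing convention for the new litellm backend added below), while all_models_to_evaluate.txt holds Hugging Face checkpoints. The commit does not include the driver that consumes them, so the loop below is only a sketch of one plausible usage; the task, few-shot count, and output path are assumptions, while the CLI flags are standard lm-evaluation-harness options.

# Hypothetical driver (not part of this commit): evaluate each listed
# Hugging Face checkpoint with the harness's vLLM backend.
import subprocess

with open("all_models_to_evaluate.txt") as f:
    models = [line.strip() for line in f if line.strip()]

for model in models:
    subprocess.run(
        [
            "lm_eval",
            "--model", "vllm",
            "--model_args", f"pretrained={model},tensor_parallel_size=1",
            "--tasks", "gsm8k",          # assumed task
            "--num_fewshot", "5",
            "--output_path", f"results/{model.replace('/', '_')}",
        ],
        check=True,
    )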
1 change: 1 addition & 0 deletions lm_eval/models/__init__.py
@@ -9,6 +9,7 @@
optimum_lm,
textsynth,
vllm_causallms,
litellm_completions,
)


155 changes: 155 additions & 0 deletions lm_eval/models/litellm_completions.py
@@ -0,0 +1,155 @@
# LiteLLM API engine, designed to largely mimic the OpenAIChatCompletions Format

import copy
import os
from collections import defaultdict
from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple
import time

from tqdm import tqdm

import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.model import LM, TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import retry_on_specific_exceptions
from lm_eval.utils import eval_logger

import litellm
litellm.vertex_project = "sacred-vigil-412721" # TODO: replace with your own project name
litellm.vertex_location = "us-central1" # proj location

@register_model("litellm")
class LiteLLM(LM):
def __init__(
self,
model: str = "gpt-3.5-turbo",
truncate: bool = False,
**kwargs,
) -> None:
"""
:param model: str
Name of the model to query through LiteLLM's OpenAI-style chat
completion API (e.g. gpt-3.5-turbo); generation settings come from
the gen_kwargs passed with each request.
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
super().__init__()
self.model = model
self.truncate = truncate

@property
def max_length(self) -> int:
# Context window varies by API model; 4096 is used as a conservative default here
return 4096

@property
def max_gen_toks(self) -> int:
return 4096

@property
def batch_size(self):
# Not used: this backend does not support loglikelihood requests
raise NotImplementedError()

@property
def device(self):
# Not used: this backend does not support loglikelihood requests
raise NotImplementedError()

def generate_until(self, requests) -> List[str]:
res = defaultdict(list)
re_ords = {}

# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer(
[req.args for req in reqs], lambda x: (-len(x[0]), x[0])
)

pbar = tqdm(total=len(requests), disable=(self.rank != 0))
for key, re_ord in re_ords.items():
# n needs to be 1 because the messages in a chat completion request
# are not a batch: they are treated as a single conversation.
chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
inps = [{"role": "user", "content": context} for context in contexts]

gen_kwargs = all_gen_kwargs[0]
until = None
if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):
if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
)
kwargs["stop"] = until
kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)
else:
raise ValueError(
f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
)

if "mistral" in self.model or "gemini" in self.model:
# Mistral and Gemini don't support the stop parameter, so remove it
kwargs.pop("stop", None)

if "gemini" in self.model:
# Sleep between requests for rate limit
time.sleep(3)

try:
response = litellm.completion(num_retries=10,
messages=inps,
model=self.model,
**kwargs,
)

for resp, (context, args_) in zip(response.choices, chunk):
s = resp.message.content

if until is not None:
for term in until:
if len(term) > 0:
s = s.split(term)[0]

res[key].append(s)

self.cache_hook.add_partial(
"generate_until", (context, {"until": until}), s
)
pbar.update(1)

except Exception as e:
print("Error in generate_until: {}".format(str(e)))
res[key].append("0")

# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])

pbar.close()

return grouper.get_original(res)

def loglikelihood(self, requests):
raise NotImplementedError("No support for logits.")

def loglikelihood_rolling(self, requests):
raise NotImplementedError("No support for logits.")
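
For reference, a minimal sketch (not part of this commit) of exercising the newly registered litellm backend through the harness's Python API; it assumes the relevant provider key (e.g. OPENAI_API_KEY) is set in the environment, and the task and few-shot settings are illustrative.

# Sketch only: @register_model("litellm") makes the name resolvable by the harness.
from lm_eval.evaluator import simple_evaluate

results = simple_evaluate(
    model="litellm",                    # resolves to the LiteLLM class above
    model_args="model=gpt-3.5-turbo",   # forwarded to LiteLLM.__init__
    tasks=["gsm8k"],
    num_fewshot=5,
)
print(results["results"]["gsm8k"])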
6 changes: 5 additions & 1 deletion lm_eval/models/vllm_causallms.py
@@ -44,7 +44,7 @@ def __init__(
        add_bos_token: Optional[bool] = False,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
-       max_gen_toks: int = 256,
+       max_gen_toks: int = 1000,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
@@ -123,6 +123,10 @@ def __init__(

@property
def eot_token_id(self):

if self.tokenizer.eos_token_id is None:
self.tokenizer.eos_token_id = self.tokenizer.pad_token_id

# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id

5 changes: 0 additions & 5 deletions lm_eval/tasks/gsm8k/gsm8k.yaml
@@ -30,11 +30,6 @@ generation_kwargs:
repeats: 1
num_fewshot: 5
filter_list:
-  - name: "strict-match"
-    filter:
-      - function: "regex"
-        regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
-      - function: "take_first"
  - name: "flexible-extract"
    filter:
      - function: "regex"
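
The removed strict-match filter only credits answers that reproduce GSM8K's canonical "#### <number>" format; only the more permissive flexible-extract filter remains, presumably to accommodate chat-style answers that state the number without the "####" marker. A small illustration of the difference follows (the model completions are invented).

# Illustration (not from the commit): the regex deleted above only matches the
# "#### <number>" answer format, so a chat-style answer would score as wrong.
import re

strict_pattern = r"#### (\-?[0-9\.\,]+)"  # the regex_pattern removed from gsm8k.yaml

canonical = "16 - 3 - 4 = 9 eggs are left, so she makes 9 * 2 = 18 dollars.\n#### 18"
chat_style = "After selling the remaining eggs, she makes $18 per day. The answer is 18."

print(re.findall(strict_pattern, canonical))   # ['18']
print(re.findall(strict_pattern, chat_style))  # [] -> no strict match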