Commit

Aligned results
AlexKoff88 committed Jan 2, 2025
1 parent 65d15c0 commit 1d1fd96
Showing 3 changed files with 41 additions and 27 deletions.
9 changes: 1 addition & 8 deletions tools/who_what_benchmark/whowhatbench/model_loaders.py
@@ -48,14 +48,7 @@ def load_text_llamacpp_pipeline(model_dir):
         logger.error(
             "Failed to import llama_cpp package. Please install llama-cpp-python.")
         exit(-1)
-    # from llama_cpp.llama_tokenizer import LlamaHFTokenizer
-    # tokenizer = LlamaHFTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-    # tokenizer.chat_template = None
-    # model = Llama(model_dir,
-    #               chat_format="functionary-v1",
-    #               tokenizer=tokenizer)
-    model = Llama(model_dir, chat_format="")#, chat_format="llama-2")
-    #model.create_chat_completion(messages = [])
+    model = Llama(model_dir)
     return model
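
For context, a minimal sketch (not part of the commit) of how the simplified loader is exercised downstream, using only calls that already appear in this diff; the GGUF path is a placeholder. Dropping the empty chat_format lets llama-cpp-python choose the chat handler itself (typically inferred from the GGUF metadata), while plain completion calls keep working unchanged.

from llama_cpp import Llama

# Placeholder path, for illustration only.
model = Llama("model.gguf")

# Plain completion, as used by llamacpp_gen_text without --chat-template.
out = model("What is OpenVINO?", max_tokens=64, temperature=0.0, echo=True)
print(out["choices"][0]["text"])

# Chat-style completion, as used by llamacpp_gen_text with --chat-template;
# with no explicit chat_format, the handler is typically taken from the model metadata.
chat = model.create_chat_completion(
    messages=[{"role": "user", "content": "What is OpenVINO?"}],
    max_tokens=64,
    temperature=0.0,
)
print(chat["choices"][0]["message"]["content"])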
28 changes: 19 additions & 9 deletions tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -108,6 +108,7 @@ def __init__(
         generation_config=None,
         generation_config_base=None,
         seqs_per_request=None,
+        use_chat_template=None,
     ) -> None:
         assert (
             base_model is not None or gt_data is not None
@@ -123,6 +124,7 @@ def __init__(
         self.generation_config_base = generation_config
         self.seqs_per_request = seqs_per_request
         self.generation_fn = gen_answer_fn
+        self.use_chat_template = use_chat_template
         if self.generation_config is not None:
             assert self.seqs_per_request is not None

@@ -202,15 +204,22 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
         return res

     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
-        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question):
-            inputs = self.tokenizer(prompt, return_tensors="pt")
-
-            tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-
-            if crop_question:
-                tokens = tokens[:, inputs["input_ids"].shape[-1] :]
-
-            return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
+        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
+            if use_chat_template:
+                message = [{"role": "user", "content": prompt}]
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+                tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if crop_question:
+                    tokens = tokens[:, inputs.shape[-1]:]
+                res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
+                print(res)
+                return res
+            else:
+                inputs = self.tokenizer(prompt, return_tensors="pt")
+                tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if crop_question:
+                    tokens = tokens[:, inputs["input_ids"].shape[-1] :]
+                return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]

         gen_answer_fn = gen_answer_fn or default_gen_answer
@@ -250,6 +259,7 @@ def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question):
                         p,
                         self.max_new_tokens,
                         self._crop_question,
+                        self.use_chat_template
                     )
                 )
             else:
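
To make the new chat branch in default_gen_answer concrete, a small illustrative sketch (not part of the commit); it reuses the Qwen/Qwen2-0.5B-Instruct id mentioned in the removed comments purely as an example. apply_chat_template wraps the prompt in the model's chat markup, so its token count, not the raw prompt's, is what the crop_question logic subtracts; unlike the plain tokenizer call it also returns a bare tensor rather than a dict, hence inputs.shape[-1] instead of inputs["input_ids"].shape[-1].

from transformers import AutoTokenizer

# Example model id only; taken from the comments removed in model_loaders.py.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

prompt = "Why is the sky blue?"
message = [{"role": "user", "content": prompt}]

# Rendered chat prompt: the user text surrounded by role/turn markers.
print(tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True))

chat_ids = tokenizer.apply_chat_template(
    message, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
plain_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

# The chat encoding is longer, so cropping must use chat_ids.shape[-1].
print(chat_ids.shape[-1], plain_ids.shape[-1])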
31 changes: 21 additions & 10 deletions tools/who_what_benchmark/whowhatbench/wwb.py
@@ -40,6 +40,11 @@ def parse_args():
         default=None,
         help="Tokenizer for divergency metric. If not provided, it will be load from base_model or target_model.",
     )
+    parser.add_argument(
+        "--chat-template",
+        action="store_true",
+        help="Whether apply the default chat template.",
+    )
     parser.add_argument(
         "--gt-data",
         default=None,
@@ -255,18 +260,23 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str:
     return "".join(output)


-def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question):
+def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
     return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)


-def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question):
-    output = model(question, max_tokens=max_new_tokens, echo=True, temperature=0.0)
-    return output["choices"][0]["text"]
-    # output = model.create_completion(question, max_tokens=max_new_tokens, temperature=0.0, echo=True)
-    # print(output)
-    # return output["choices"][0]["text"]#
-    # output = model.create_chat_completion(messages=[{"role": "user", "content": question}], max_tokens=max_new_tokens, temperature=0.0)
-    # return output["choices"][0]["message"]["content"]
+def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
+    if use_chat_template:
+        output = model.create_chat_completion(messages=[{"role": "user", "content": question}], max_tokens=max_new_tokens, temperature=0.0)
+        text = output["choices"][0]["message"]["content"]
+        if skip_question:
+            text = text[len(question):]
+        return text
+    else:
+        output = model(question, max_tokens=max_new_tokens, echo=True, temperature=0.0)
+        text = output["choices"][0]["text"]
+        if skip_question:
+            text = text[len(question):]
+        return text


 def genai_gen_image(model, prompt, num_inference_steps, generator=None):
@@ -358,7 +368,8 @@ def create_evaluator(base_model, args):
             similarity_model_id=args.data_encoder,
             num_samples=args.num_samples,
             language=args.language,
-            gen_answer_fn=gen_answer_fn
+            gen_answer_fn=gen_answer_fn,
+            use_chat_template=args.chat_template,
         )
     elif task == "text-to-image":
         return EvaluatorCLS(
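
A brief sketch of how the new flag behaves end to end (assumptions, not part of the commit): because the argument uses action="store_true", args.chat_template defaults to False, so existing wwb invocations keep the plain-prompt path; only when the flag is passed do default_gen_answer and llamacpp_gen_text take their chat branches.

import argparse

# Re-creating just the relevant piece of parse_args() for illustration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--chat-template",
    action="store_true",
    help="Whether to apply the default chat template.",
)

print(parser.parse_args([]).chat_template)                   # False: existing runs are unaffected
print(parser.parse_args(["--chat-template"]).chat_template)  # True: evaluators get use_chat_template=True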
