From 1d1fd9639c24862497191c36f8980f794a9b0a6d Mon Sep 17 00:00:00 2001
From: Alexander
Date: Thu, 2 Jan 2025 11:27:10 +0400
Subject: [PATCH] Aligned results

---
 .../whowhatbench/model_loaders.py             |  9 +-----
 .../whowhatbench/text_evaluator.py            | 28 +++++++++++------
 tools/who_what_benchmark/whowhatbench/wwb.py  | 31 +++++++++++++------
 3 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/tools/who_what_benchmark/whowhatbench/model_loaders.py b/tools/who_what_benchmark/whowhatbench/model_loaders.py
index 3c705a7c02..c792a3c0b2 100644
--- a/tools/who_what_benchmark/whowhatbench/model_loaders.py
+++ b/tools/who_what_benchmark/whowhatbench/model_loaders.py
@@ -48,14 +48,7 @@ def load_text_llamacpp_pipeline(model_dir):
         logger.error(
             "Failed to import llama_cpp package. Please install llama-cpp-python.")
         exit(-1)
-    # from llama_cpp.llama_tokenizer import LlamaHFTokenizer
-    # tokenizer = LlamaHFTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-    # tokenizer.chat_template = None
-    # model = Llama(model_dir,
-    #               chat_format="functionary-v1",
-    #               tokenizer=tokenizer)
-    model = Llama(model_dir, chat_format="")#, chat_format="llama-2")
-    #model.create_chat_completion(messages = [])
+    model = Llama(model_dir)
     return model
 
 
diff --git a/tools/who_what_benchmark/whowhatbench/text_evaluator.py b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
index 50ce224def..bbeaf0b762 100644
--- a/tools/who_what_benchmark/whowhatbench/text_evaluator.py
+++ b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -108,6 +108,7 @@ def __init__(
         generation_config=None,
         generation_config_base=None,
         seqs_per_request=None,
+        use_chat_template=None,
     ) -> None:
         assert (
             base_model is not None or gt_data is not None
@@ -123,6 +124,7 @@ def __init__(
         self.generation_config_base = generation_config
         self.seqs_per_request = seqs_per_request
         self.generation_fn = gen_answer_fn
+        self.use_chat_template = use_chat_template
 
         if self.generation_config is not None:
             assert self.seqs_per_request is not None
@@ -202,15 +204,22 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
         return res
 
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
-        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question):
-            inputs = self.tokenizer(prompt, return_tensors="pt")
-
-            tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-
-            if crop_question:
-                tokens = tokens[:, inputs["input_ids"].shape[-1] :]
-
-            return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
+        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
+            if use_chat_template:
+                message = [{"role": "user", "content": prompt}]
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+                tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if crop_question:
+                    tokens = tokens[:, inputs.shape[-1]:]
+                res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
+                print(res)
+                return res
+            else:
+                inputs = self.tokenizer(prompt, return_tensors="pt")
+                tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if crop_question:
+                    tokens = tokens[:, inputs["input_ids"].shape[-1] :]
+                return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
 
         gen_answer_fn = gen_answer_fn or default_gen_answer
 
@@ -250,6 +259,7 @@ def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question):
                         p,
                         self.max_new_tokens,
                         self._crop_question,
+                        self.use_chat_template
                     )
                 )
             else:
diff --git a/tools/who_what_benchmark/whowhatbench/wwb.py b/tools/who_what_benchmark/whowhatbench/wwb.py
index fde2d3174c..47eed2cae6 100644
--- a/tools/who_what_benchmark/whowhatbench/wwb.py
+++ b/tools/who_what_benchmark/whowhatbench/wwb.py
@@ -40,6 +40,11 @@ def parse_args():
         default=None,
         help="Tokenizer for divergency metric. If not provided, it will be load from base_model or target_model.",
     )
+    parser.add_argument(
+        "--chat-template",
+        action="store_true",
+        help="Whether to apply the default chat template.",
+    )
     parser.add_argument(
         "--gt-data",
         default=None,
@@ -255,18 +260,23 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str:
     return "".join(output)
 
 
-def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question):
+def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
     return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)
 
 
-def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question):
-    output = model(question, max_tokens=max_new_tokens, echo=True, temperature=0.0)
-    return output["choices"][0]["text"]
-    # output = model.create_completion(question, max_tokens=max_new_tokens, temperature=0.0, echo=True)
-    # print(output)
-    # return output["choices"][0]["text"]#
-    # output = model.create_chat_completion(messages=[{"role": "user", "content": question}], max_tokens=max_new_tokens, temperature=0.0)
-    # return output["choices"][0]["message"]["content"]
+def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
+    if use_chat_template:
+        output = model.create_chat_completion(messages=[{"role": "user", "content": question}], max_tokens=max_new_tokens, temperature=0.0)
+        text = output["choices"][0]["message"]["content"]
+        if skip_question:
+            text = text[len(question):]
+        return text
+    else:
+        output = model(question, max_tokens=max_new_tokens, echo=True, temperature=0.0)
+        text = output["choices"][0]["text"]
+        if skip_question:
+            text = text[len(question):]
+        return text
 
 
 def genai_gen_image(model, prompt, num_inference_steps, generator=None):
@@ -358,7 +368,8 @@ def create_evaluator(base_model, args):
             similarity_model_id=args.data_encoder,
             num_samples=args.num_samples,
             language=args.language,
-            gen_answer_fn=gen_answer_fn
+            gen_answer_fn=gen_answer_fn,
+            use_chat_template=args.chat_template,
        )
    elif task == "text-to-image":
        return EvaluatorCLS(
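
Below is a minimal standalone sketch of the chat-template generation path that the patched default_gen_answer follows when --chat-template is passed. It assumes plain Hugging Face transformers; the Qwen/Qwen2-0.5B-Instruct id (taken from the comment removed above) is only an illustrative placeholder, not a requirement of the tool.

    # Illustrative sketch only; mirrors the use_chat_template branch added in text_evaluator.py.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "Qwen/Qwen2-0.5B-Instruct"  # placeholder; any chat-tuned HF causal LM works the same way
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    prompt = "What is OpenVINO?"
    message = [{"role": "user", "content": prompt}]
    # apply_chat_template wraps the prompt in the model's chat format and appends the generation prompt
    inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
    tokens = model.generate(inputs, do_sample=False, max_new_tokens=128)
    # drop the prompt tokens, as crop_question=True does in the evaluator
    tokens = tokens[:, inputs.shape[-1]:]
    print(tokenizer.decode(tokens[0], skip_special_tokens=True))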