Commit 89145b1: Fixes

ilya-lavrenov committed Jan 2, 2025
1 parent a107a98 commit 89145b1
Showing 2 changed files with 16 additions and 14 deletions.
24 changes: 13 additions & 11 deletions tests/python_tests/common.py
@@ -175,8 +175,10 @@ def convert_to_hf(
     default_generation_config : HFGenerationConfig,
     generation_config : GenerationConfig
 ) -> HFGenerationConfig:
-    kwargs = {}
+    if generation_config is None:
+        return
 
+    kwargs = {}
     # generic parameters
     kwargs['max_length'] = generation_config.max_length
     # has higher priority than 'max_length'
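
A quick illustration of the guard added above: previously, passing generation_config=None would fail as soon as generation_config.max_length was read, whereas convert_to_hf now simply returns None. The sketch below reuses convert_to_hf and opt_model as defined in this file and is not standalone; the fallback branch is an illustrative assumption, not part of the commit.

    # Minimal sketch: the None guard lets a caller detect "no per-request config".
    hf_config = convert_to_hf(opt_model.generation_config, generation_config=None)
    if hf_config is None:
        # Hypothetical fallback to the model's default HF generation config.
        hf_config = opt_model.generation_config
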
@@ -195,7 +197,7 @@ def convert_to_hf(
         kwargs['eos_token_id'] = generation_config.eos_token_id
     else:
         kwargs['eos_token_id'] = default_generation_config.eos_token_id
-
+
     # copy penalties
     kwargs['repetition_penalty'] = generation_config.repetition_penalty
-
+
@@ -239,14 +241,14 @@ def run_hugging_face(
     opt_model,
     hf_tokenizer,
     prompts: List[str],
-    generation_config: GenerationConfig,
+    generation_configs: List[GenerationConfig] | GenerationConfig,
 ) -> List[GenerationResult]:
-    hf_generation_config = convert_to_hf(opt_model.generation_config, generation_config)
     generation_results = []
 
-    for prompt in prompts:
+    for prompt, generation_config in zip(prompts, generation_configs):
+        hf_generation_config = convert_to_hf(opt_model.generation_config, generation_config)
         inputs = hf_tokenizer(prompt, return_tensors="pt")
         prompt_len = 0 if generation_config.echo else inputs['input_ids'].numel()
 
         generate_outputs = opt_model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                               generation_config=hf_generation_config, return_dict_in_generate=True, tokenizer=hf_tokenizer)
         all_text_batch = hf_tokenizer.batch_decode([generated_ids[prompt_len:] for generated_ids in generate_outputs.sequences], skip_special_tokens=True)
@@ -270,8 +272,8 @@ def run_continuous_batching(
     prompts: List[str],
     generation_configs : List[GenerationConfig] | GenerationConfig
 ) -> List[GenerationResult]:
-    if not type(generation_configs) is List:
-        generation_configs = [generation_configs]
+    if type(generation_configs) is not list:
+        generation_configs = [generation_configs] * len(prompts)
 
     cb_pipe = ContinuousBatchingPipeline(models_path, scheduler_config=scheduler_config, device='CPU')
     output = cb_pipe.generate(prompts, generation_configs)
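
Both run_continuous_batching (above) and the updated run_hugging_face expect one GenerationConfig per prompt, so a single config is broadcast with * len(prompts): cb_pipe.generate(prompts, generation_configs) then receives a matching list, and zip(prompts, generation_configs) in run_hugging_face yields a pair for every prompt instead of stopping after the first one. Below is a minimal, self-contained sketch of that broadcast; the prompt strings and max_new_tokens value are placeholders, and the keyword-style constructor mirrors the GenerationConfig(**generation_config) call later in this file.

    from openvino_genai import GenerationConfig  # assumed to match the import used by these tests

    prompts = ["What is OpenVINO?", "Write a haiku", "Count to three"]  # placeholder prompts
    config = GenerationConfig(max_new_tokens=20)  # assumed keyword field
    generation_configs = [config] * len(prompts)  # one config per prompt
    # zip() now produces len(prompts) pairs; a one-element list would not.
    assert len(list(zip(prompts, generation_configs))) == len(prompts)
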
@@ -331,7 +333,7 @@ def compare_generation_result(hf_result: GenerationResult, ov_result: Generation
 
 
 def compare_generation_results(prompts: List[str], hf_results: List[GenerationResult], ov_results: List[GenerationResult], generation_configs: List[GenerationConfig] | GenerationConfig):
-    if not type(generation_configs) is List:
+    if type(generation_configs) is not list:
         generation_configs = [generation_configs]
 
     assert len(prompts) == len(hf_results)
@@ -383,7 +385,7 @@ def run_cb_pipeline_with_ref(tmp_path: str, model_id: str, scheduler_params: dic
     scheduler_config = get_scheduler_config(scheduler_params)
 
     # override dataset's generation config
-    if not generation_config is None:
+    if generation_config is not None:
         if type(generation_config) is dict:
             generation_config = GenerationConfig(**generation_config)
         generation_configs = [generation_config] * len(prompts)
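
As the hunk above shows, the generation-config override may arrive either as a plain dict or as a ready GenerationConfig, and is then broadcast over the dataset prompts. A hedged sketch of that normalization step, under the same GenerationConfig assumptions as the previous sketch and with placeholder values:

    prompts = ["first prompt", "second prompt"]     # placeholder dataset prompts
    override = {"max_new_tokens": 20}                # placeholder dict override
    if type(override) is dict:
        override = GenerationConfig(**override)      # same conversion as in the hunk
    generation_configs = [override] * len(prompts)   # one config per dataset prompt
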
@@ -395,7 +397,7 @@ def run_cb_pipeline_with_ref(tmp_path: str, model_id: str, scheduler_params: dic
     if use_optimum:
         convert_models(opt_model, hf_tokenizer, models_path)
 
-    hf_results = run_hugging_face(opt_model=opt_model, hf_tokenizer=hf_tokenizer, prompts=prompts, generation_config=generation_config)
+    hf_results = run_hugging_face(opt_model=opt_model, hf_tokenizer=hf_tokenizer, prompts=prompts, generation_configs=generation_configs)
     ov_results = run_continuous_batching(models_path, scheduler_config, prompts, generation_configs)
 
     compare_generation_results(prompts, hf_results, ov_results, generation_configs)
6 changes: 3 additions & 3 deletions tests/python_tests/test_continuous_batching.py
@@ -97,8 +97,8 @@ def test_cb_streamer_vs_return_vs_stateful(prompt):
     ))
     cb_pipe = get_continuous_batching(path)
     streamed = []
-    generated = cb_pipe.generate(prompt, max_new_tokens=20, streamer=lambda subword: streamed.append(subword))
-    reference = ov_pipe.generate(prompt, max_new_tokens=20)
+    generated = cb_pipe.generate(prompt, max_new_tokens=20, streamer=lambda subword: streamed.append(subword)).texts[0]
+    reference = ov_pipe.generate(prompt, max_new_tokens=20).texts[0]
     assert generated == "".join(streamed)
     assert "".join(streamed) == reference

@@ -128,7 +128,7 @@ def test_chat_scenario_vs_stateful(model_descr, generation_config_kwargs: Dict):
 
     for question in questions:
         generated = cb_pipe.generate(question, generation_config=generation_config)
-        reference = ov_pipe.generate(question)
+        reference = ov_pipe.generate(question).texts[0]
         assert generated == reference
 
     # Test that finish_chat() doesn't fail just in case.
