Sampler tests refactoring: part 1
ilya-lavrenov committed Jan 2, 2025
1 parent 482fa79 commit 4fcfa1d
Showing 5 changed files with 119 additions and 79 deletions.
8 changes: 1 addition & 7 deletions src/python/py_llm_pipeline.cpp
@@ -68,13 +68,7 @@ py::object call_common_generate(
results = py::cast(pipe.generate(tokenized_input, updated_config, streamer));
},
[&](std::string string_input) {
DecodedResults res = pipe.generate(string_input, updated_config, streamer);
// If input was a string return a single string otherwise return DecodedResults.
if (updated_config.has_value() && (*updated_config).num_return_sequences == 1) {
results = py::cast<py::object>(pyutils::handle_utf8(res.texts[0]));
} else {
results = py::cast(res);
}
results = py::cast(pipe.generate(string_input, updated_config, streamer));
},
[&](std::vector<std::string> string_input) {
// For DecodedResults texts getter already handles utf8 decoding.
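This change removes the Python-binding special case that returned a bare string when num_return_sequences == 1; a string prompt is now cast to the same result type as the other input variants. Below is a minimal sketch of how callers read that result, mirroring the new run_llm_pipeline helper added to tests/python_tests/common.py further down (the model path and prompt are placeholders, not taken from this commit):

from openvino_genai import LLMPipeline, GenerationConfig

# Assumes a model that has already been converted to OpenVINO IR at this path.
pipe = LLMPipeline("path/to/converted/model", device="CPU")

config = GenerationConfig()
config.max_new_tokens = 30

# generate() on a string prompt now returns a DecodedResults-like object;
# generated strings are read from .texts instead of a bare str return value.
results = pipe.generate(inputs="What is OpenVINO?", generation_config=config)
for text in results.texts:
    print(text)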
126 changes: 88 additions & 38 deletions tests/python_tests/common.py
@@ -7,7 +7,7 @@

from optimum.intel import OVModelForCausalLM
from pathlib import Path
from openvino_genai import ContinuousBatchingPipeline, SchedulerConfig, GenerationResult, GenerationConfig
from openvino_genai import ContinuousBatchingPipeline, LLMPipeline, SchedulerConfig, GenerationResult, GenerationConfig, DecodedResults
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GenerationConfig as HFGenerationConfig
from typing import List, Tuple
@@ -302,6 +302,8 @@ def convert_to_hf(
kwargs['bos_token_id'] = default_generation_config.bos_token_id
kwargs['eos_token_id'] = default_generation_config.eos_token_id
kwargs['pad_token_id'] = default_generation_config.pad_token_id

# copy penalties
kwargs['repetition_penalty'] = generation_config.repetition_penalty

if generation_config.is_beam_search():
@@ -332,15 +334,15 @@ def run_hugging_face(
opt_model,
hf_tokenizer,
prompts: List[str],
generation_configs: List[GenerationConfig],
generation_config: GenerationConfig,
) -> List[GenerationResult]:
hf_generation_config = convert_to_hf(opt_model.generation_config, generation_config)
generation_results = []
for prompt, generation_config in zip(prompts, generation_configs):
for prompt in prompts:
inputs = hf_tokenizer(prompt, return_tensors="pt")
prompt_len = inputs['input_ids'].numel()
generate_outputs = opt_model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'],
generation_config=convert_to_hf(opt_model.generation_config, generation_config),
return_dict_in_generate=True, tokenizer=hf_tokenizer)
generation_config=hf_generation_config, return_dict_in_generate=True, tokenizer=hf_tokenizer)
all_text_batch = hf_tokenizer.batch_decode([generated_ids[prompt_len:] for generated_ids in generate_outputs.sequences], skip_special_tokens=True)

generation_result = GenerationResult()
@@ -360,16 +362,45 @@ def run_continuous_batching(
models_path : Path,
scheduler_config : SchedulerConfig,
prompts: List[str],
generation_configs : List[GenerationConfig]
generation_configs : List[GenerationConfig] | GenerationConfig
) -> List[GenerationResult]:
pipe = ContinuousBatchingPipeline(models_path, scheduler_config, "CPU")
output = pipe.generate(prompts, generation_configs)
del pipe
if not isinstance(generation_configs, list):
generation_configs = [generation_configs]

cb_pipe = ContinuousBatchingPipeline(models_path, scheduler_config=scheduler_config, device='CPU')
output = cb_pipe.generate(prompts, generation_configs)

del cb_pipe
shutil.rmtree(models_path)

return output


def compare_results(hf_result: GenerationResult, ov_result: GenerationResult, generation_config: GenerationConfig):
def run_llm_pipeline(
models_path : Path,
prompts: List[str],
generation_config : GenerationConfig
) -> List[GenerationResult]:
ov_pipe = LLMPipeline(models_path, device='CPU')

generation_results = []
for prompt in prompts:
generate_outputs : DecodedResults = ov_pipe.generate(inputs=prompt, generation_config=generation_config)

generation_result = GenerationResult()
generation_result.m_generation_ids = generate_outputs.texts
# sequences_scores are available only for beam search case
if generation_config.is_beam_search():
generation_result.m_scores = [score for score in generate_outputs.scores]
generation_results.append(generation_result)

del ov_pipe
shutil.rmtree(models_path)

return generation_results


def compare_generation_result(hf_result: GenerationResult, ov_result: GenerationResult, generation_config: GenerationConfig):
if generation_config.is_beam_search():
assert len(hf_result.m_scores) == len(ov_result.m_scores)
for hf_score, ov_score in zip(hf_result.m_scores, ov_result.m_scores):
@@ -386,46 +417,78 @@ def compare_results(hf_result: GenerationResult, ov_result: GenerationResult, ge
assert hf_text == ov_text


def get_hugging_face_model_and_tokenizer(model_id: str, use_optimum = True):
def compare_generation_results(prompts: List[str], hf_results: List[GenerationResult], ov_results: List[GenerationResult], generation_configs: List[GenerationConfig] | GenerationConfig):
if not isinstance(generation_configs, list):
generation_configs = [generation_configs]

assert len(prompts) == len(hf_results)
assert len(prompts) == len(ov_results)

for prompt, ref_result, ov_result, generation_config in zip(prompts, hf_results, ov_results, generation_configs):
print(f"Prompt = {prompt}\nReference result = {ref_result}\nOpenVINO result = {ov_result.m_generation_ids}")
compare_generation_result(ref_result, ov_result, generation_config)


def get_hugging_face_models(model_id: str, use_optimum = True):
hf_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
opt_model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True) if use_optimum else \
AutoModelForCausalLM.from_pretrained(model_id)
return opt_model, hf_tokenizer


def save_ov_model_from_optimum(model, hf_tokenizer, models_path: Path):
model.save_pretrained(models_path)
def convert_models(opt_model : OVModelForCausalLM, hf_tokenizer : AutoTokenizer, models_path: Path):
opt_model.save_pretrained(models_path)

# convert tokenizers as well
from openvino_tokenizers import convert_tokenizer
from openvino import serialize
tokenizer, detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True, skip_special_tokens=True)

tokenizer, detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
serialize(tokenizer, models_path / "openvino_tokenizer.xml")
serialize(detokenizer, models_path / "openvino_detokenizer.xml")


def _generate_and_compare_with_reference_results(models_path: Path, prompts: List[str], reference_results: List[GenerationResult], generation_configs: List[GenerationConfig], scheduler_config: SchedulerConfig):
ov_results : List[GenerationResult] = run_continuous_batching(models_path, scheduler_config, prompts, generation_configs)
def run_llm_pipeline_with_ref(model_id: str, prompts: List[str], generation_config: GenerationConfig | dict, tmp_path: Path):
use_optimum = True
models_path : Path = tmp_path / model_id
opt_model, hf_tokenizer = get_hugging_face_models(model_id, use_optimum)

if type(generation_config) is dict:
generation_config = GenerationConfig(**generation_config)

if use_optimum:
convert_models(opt_model, hf_tokenizer, models_path)

ov_results = run_llm_pipeline(models_path, prompts, generation_config)
hf_results = run_hugging_face(opt_model, hf_tokenizer, prompts, generation_config)

compare_generation_results(prompts, hf_results, ov_results, generation_config)

assert len(prompts) == len(reference_results)
assert len(prompts) == len(ov_results)

for prompt, ref_result, ov_result, generation_config in zip(prompts, reference_results, ov_results, generation_configs):
print(f"Prompt = {prompt}\nref result = {ref_result}\nOV result = {ov_result.m_generation_ids}")
compare_results(ref_result, ov_result, generation_config)
def run_cb_pipeline_with_ref(tmp_path: str, model_id: str, scheduler_params: dict = {}, generation_config : GenerationConfig | dict = None):
prompts, generation_configs = get_test_dataset()
scheduler_config = get_scheduler_config(scheduler_params)

# override dataset's generation config
if generation_config is not None:
if type(generation_config) is dict:
generation_config = GenerationConfig(**generation_config)
generation_configs = [generation_config] * len(prompts)

def generate_and_compare_with_hf(model_id: str, prompts: List[str], generation_configs: List[GenerationConfig], scheduler_config: SchedulerConfig, tmp_path: Path):
use_optimum = True
models_path : Path = tmp_path / model_id
opt_model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum)
opt_model, hf_tokenizer = get_hugging_face_models(model_id, use_optimum)

if use_optimum:
save_ov_model_from_optimum(opt_model, hf_tokenizer, models_path)
convert_models(opt_model, hf_tokenizer, models_path)

hf_results = run_hugging_face(opt_model=opt_model, hf_tokenizer=hf_tokenizer, prompts=prompts, generation_configs=generation_configs)
_generate_and_compare_with_reference_results(models_path, prompts, hf_results, generation_configs, scheduler_config)
ov_results = run_continuous_batching(models_path, scheduler_config, prompts, generation_configs)

compare_generation_results(prompts, hf_results, ov_results, generation_configs)


# TODO: remove after Generator property is supported by LLMPipeline / VLMPipeline
def generate_and_compare_with_reference_text(models_path: Path, prompts: List[str], reference_texts_per_prompt: List[List[str]], generation_configs: List[GenerationConfig], scheduler_config: SchedulerConfig):
ov_results : List[GenerationResult] = run_continuous_batching(models_path, scheduler_config, prompts, generation_configs)

@@ -440,19 +503,6 @@ def generate_and_compare_with_reference_text(models_path: Path, prompts: List[st
assert ref_text == ov_text


def run_continuous_batching_pipeline_test(tmp_path: str, model_id: str, scheduler_params: dict = None, generation_config = None):
prompts, generation_configs = get_test_dataset()
scheduler_config = get_scheduler_config(scheduler_params)

if generation_config is not None:
generation_config.rng_seed = 0
generation_configs = [generation_config] * len(prompts)

generate_and_compare_with_hf(model_id, prompts, generation_configs, scheduler_config, tmp_path)


DEFAULT_SCHEDULER_CONFIG = get_scheduler_config({"num_kv_blocks": 300, "dynamic_split_fuse": True, "max_num_batched_tokens": 256, "max_num_seqs": 256})

def get_image_by_link(link):
from PIL import Image
import requests
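The two new entry points, run_llm_pipeline_with_ref and run_cb_pipeline_with_ref, replace the former generate_and_compare_with_hf / run_continuous_batching_pipeline_test pair. A hedged sketch of how a test module would call them (the prompt and the max_new_tokens value are illustrative; only facebook/opt-125m and the scheduler keys appear in this commit):

import pytest

from common import run_llm_pipeline_with_ref, run_cb_pipeline_with_ref


@pytest.mark.precommit
def test_greedy_vs_hf(tmp_path):
    # Stateful LLMPipeline compared against the Hugging Face reference;
    # a plain dict is converted to GenerationConfig inside the helper.
    run_llm_pipeline_with_ref(model_id="facebook/opt-125m",
                              prompts=["What is OpenVINO?"],
                              generation_config=dict(max_new_tokens=30),
                              tmp_path=tmp_path)


@pytest.mark.precommit
def test_cb_vs_hf(tmp_path):
    # ContinuousBatchingPipeline compared against the Hugging Face reference,
    # with scheduler parameters and a generation config overriding the dataset defaults.
    run_cb_pipeline_with_ref(tmp_path, "facebook/opt-125m",
                             scheduler_params={"num_kv_blocks": 300, "dynamic_split_fuse": True},
                             generation_config=dict(max_new_tokens=30))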
24 changes: 12 additions & 12 deletions tests/python_tests/test_continuous_batching.py
@@ -9,8 +9,8 @@
from pathlib import Path
from openvino_genai import ContinuousBatchingPipeline, GenerationConfig, Tokenizer

from common import get_hugging_face_model_and_tokenizer, save_ov_model_from_optimum, generate_and_compare_with_reference_text, \
get_scheduler_config, get_greedy, run_continuous_batching_pipeline_test, get_beam_search, get_greedy, \
from common import get_hugging_face_models, convert_models, generate_and_compare_with_reference_text, \
get_scheduler_config, get_greedy, run_cb_pipeline_with_ref, get_beam_search, get_greedy, \
get_multinomial_all_parameters, get_multinomial_temperature_and_num_return_sequence, \
get_multinomial_temperature_and_top_k, get_multinomial_temperature, get_multinomial_temperature_and_top_p
from test_sampling import RandomSamplingTestStruct, get_current_platform_ref_texts
@@ -39,19 +39,19 @@ def read_models_list(file_name: str):
@pytest.mark.precommit
@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "precommit")))
def test_e2e_precommit(tmp_path, model_id):
run_continuous_batching_pipeline_test(tmp_path, model_id)
run_cb_pipeline_with_ref(tmp_path, model_id)


@pytest.mark.nightly
@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "nightly")))
def test_e2e_nightly(tmp_path, model_id):
run_continuous_batching_pipeline_test(tmp_path, model_id)
run_cb_pipeline_with_ref(tmp_path, model_id)


@pytest.mark.real_models
@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "real_models")))
def test_e2e_real_models(tmp_path, model_id):
run_continuous_batching_pipeline_test(tmp_path, model_id)
run_cb_pipeline_with_ref(tmp_path, model_id)

#
# Comparison with stateful
@@ -150,10 +150,10 @@ def test_post_oom_health(tmp_path, sampling_config):
scheduler_config.num_kv_blocks = 10 # Low cache size to trigger OOM quickly

model_id : str = "facebook/opt-125m"
opt_model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True)
opt_model, hf_tokenizer = get_hugging_face_models(model_id, use_optimum=True)

models_path : Path = tmp_path / model_id
save_ov_model_from_optimum(opt_model, hf_tokenizer, models_path)
convert_models(opt_model, hf_tokenizer, models_path)

cb_pipe = ContinuousBatchingPipeline(models_path, Tokenizer(models_path), scheduler_config, "CPU")

@@ -201,7 +201,7 @@ def get_beam_search_seq_len_300() -> GenerationConfig:
@pytest.mark.parametrize("params", scheduler_params_list)
@pytest.mark.precommit
def test_preemption(tmp_path, params):
run_continuous_batching_pipeline_test(tmp_path, "facebook/opt-125m", scheduler_params=params[0], generation_config=params[1])
run_cb_pipeline_with_ref(tmp_path, "facebook/opt-125m", scheduler_params=params[0], generation_config=params[1])


multinomial_params = RandomSamplingTestStruct(
@@ -252,10 +252,10 @@ def test_preemption_with_multinomial(tmp_path, dynamic_split_fuse):
config.rng_seed = 0
config.max_new_tokens = 30
model_id : str = "facebook/opt-125m"
model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True)
model, hf_tokenizer = get_hugging_face_models(model_id, use_optimum=True)

models_path : Path = tmp_path / model_id
save_ov_model_from_optimum(model, hf_tokenizer, models_path)
convert_models(model, hf_tokenizer, models_path)

scheduler_config = get_scheduler_config({"num_kv_blocks": 3, "dynamic_split_fuse": dynamic_split_fuse, "max_num_batched_tokens": 256, "max_num_seqs": 256})
generate_and_compare_with_reference_text(models_path, multinomial_params.prompts, multinomial_params.ref_texts, generation_configs, scheduler_config)
@@ -333,10 +333,10 @@ def test_preemption_with_multinomial_n_seq(tmp_path, dynamic_split_fuse):
for config in generation_configs:
config.rng_seed = 0
model_id : str = "facebook/opt-125m"
model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True)
model, hf_tokenizer = get_hugging_face_models(model_id, use_optimum=True)

models_path : Path = tmp_path / model_id
save_ov_model_from_optimum(model, hf_tokenizer, models_path)
convert_models(model, hf_tokenizer, models_path)

# needed kv_blocks - 16 (2 blocks per sequence (30 tokens to generated text + prompt (> 2 tokens)) * (1 + 3 + 4) seq )
scheduler_config = get_scheduler_config({"num_kv_blocks": 8, "dynamic_split_fuse": dynamic_split_fuse, "max_num_batched_tokens": 256, "max_num_seqs": 256})
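For tests that construct the pipeline by hand, such as test_post_oom_health above, the renamed helpers slot into the same pattern. A sketch under the same assumptions (parameter values are illustrative):

from pathlib import Path

from openvino_genai import ContinuousBatchingPipeline, Tokenizer

from common import get_hugging_face_models, convert_models, get_scheduler_config


def build_cb_pipeline(tmp_path: Path, model_id: str = "facebook/opt-125m") -> ContinuousBatchingPipeline:
    # Export the Hugging Face model via optimum-intel and convert the tokenizer to OpenVINO.
    opt_model, hf_tokenizer = get_hugging_face_models(model_id, use_optimum=True)
    models_path: Path = tmp_path / model_id
    convert_models(opt_model, hf_tokenizer, models_path)

    # Deliberately small KV cache, as in the OOM test above.
    scheduler_config = get_scheduler_config({"num_kv_blocks": 10})
    return ContinuousBatchingPipeline(models_path, Tokenizer(models_path), scheduler_config, "CPU")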
6 changes: 4 additions & 2 deletions tests/python_tests/test_kv_cache_eviction.py
@@ -15,7 +15,7 @@
from openvino import serialize
from transformers import AutoTokenizer

from common import TESTS_ROOT, run_continuous_batching_pipeline_test
from common import TESTS_ROOT, run_cb_pipeline_with_ref


def load_prompts_dataset(file_name : str) -> Dict[str, List[str]]:
@@ -150,6 +150,7 @@ def get_greedy_seq_len_300() -> GenerationConfig:
generation_config.max_new_tokens = 300
return generation_config


def get_beam_search_seq_len_300() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_beam_groups = 3
@@ -159,6 +160,7 @@ def get_beam_search_seq_len_300() -> GenerationConfig:
generation_config.num_return_sequences = generation_config.num_beams
return generation_config


scheduler_params_list = [
({"num_kv_blocks": 0, "cache_size": 0, "dynamic_split_fuse": True, "enable_prefix_caching": True}, get_greedy_seq_len_300()),
({"num_kv_blocks": 0, "cache_size": 0, "dynamic_split_fuse": False, "max_num_batched_tokens": 600, "enable_prefix_caching": True}, get_beam_search_seq_len_300()),
@@ -168,5 +170,5 @@ def get_beam_search_seq_len_300() -> GenerationConfig:
@pytest.mark.parametrize("params", scheduler_params_list)
@pytest.mark.precommit
def test_dynamic_memory_allocation(tmp_path, params):
run_continuous_batching_pipeline_test(tmp_path, "facebook/opt-125m", params[0], params[1])
run_cb_pipeline_with_ref(tmp_path, "facebook/opt-125m", scheduler_params=params[0], generation_config=params[1])

(diff for the fifth changed file is not shown)
