Commit

Final fixes
ilya-lavrenov committed Dec 27, 2024
1 parent 1ee971e commit 197f901
Showing 3 changed files with 35 additions and 24 deletions.
11 changes: 8 additions & 3 deletions src/cpp/src/generation_config.cpp
@@ -36,10 +36,15 @@ GenerationConfig::GenerationConfig(const std::filesystem::path& json_path) {
// note that stop_token_ids is not present in HF GenerationConfig, but some generation_config.json files define
// multiple eos_token_id values (e.g. https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/generation_config.json),
// so we need to read them as 'stop_token_ids'
read_json_param(data, "eos_token_id", stop_token_ids);
std::vector<int64_t> ordered_stop_token_ids;
read_json_param(data, "eos_token_id", ordered_stop_token_ids);

if (eos_token_id == -1 && !stop_token_ids.empty()) {
eos_token_id = *stop_token_ids.begin();
if (!ordered_stop_token_ids.empty()) {
std::copy(ordered_stop_token_ids.begin(), ordered_stop_token_ids.end(), std::back_inserter(stop_token_ids));

if (eos_token_id == -1) {
eos_token_id = ordered_stop_token_ids[0];
}
}

// note that echo is not present in HF GenerationConfig
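The new logic gathers every eos_token_id value from generation_config.json in order, copies all of them into stop_token_ids, and promotes the first one to eos_token_id when none was set explicitly. Below is a minimal Python sketch of that mapping, for illustration only; the helper name and the sample JSON values are hypothetical and not taken from the commit.

# Illustrative sketch of the eos_token_id -> stop_token_ids mapping above.
# The function name and sample values are hypothetical.
def read_stop_token_ids(config: dict, eos_token_id: int = -1):
    raw = config.get("eos_token_id", [])
    ordered_stop_token_ids = [raw] if isinstance(raw, int) else list(raw)

    stop_token_ids = set()
    if ordered_stop_token_ids:
        stop_token_ids.update(ordered_stop_token_ids)   # every id becomes a stop token
        if eos_token_id == -1:
            eos_token_id = ordered_stop_token_ids[0]    # first id is the primary one

    return eos_token_id, stop_token_ids

eos, stops = read_stop_token_ids({"eos_token_id": [2, 32007, 32001]})
# eos == 2; stops == {2, 32007, 32001}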
28 changes: 16 additions & 12 deletions tests/python_tests/test_continuous_batching.py
@@ -105,27 +105,30 @@ def test_cb_streamer_vs_return_vs_stateful(prompt):

generation_configs = [
dict(do_sample=False, max_new_tokens=20),
dict(do_sample=False, num_beam_groups=3, num_beams=15, num_return_sequences=1, max_new_tokens=10, diversity_penalty=1.0)
dict(do_sample=False, num_beam_groups=3, num_beams=15, num_return_sequences=1, max_new_tokens=10, diversity_penalty=1.0, repetition_penalty=1.0)
]
questions = [
'1+1=',
'What is the previous answer?',
'Why is the Sun yellow?',
'What was my first question?'
]
@pytest.mark.parametrize("generation_config", generation_configs[1:])
@pytest.mark.parametrize("generation_config_kwargs", generation_configs[1:])
@pytest.mark.parametrize("model_descr", get_chat_models_list())
@pytest.mark.precommit
def test_chat_scenario_vs_stateful(model_descr, generation_config: Dict):
def test_chat_scenario_vs_stateful(model_descr, generation_config_kwargs: Dict):
model_id, path, hf_tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))
cb_pipe = get_continuous_batching(path)

ov_pipe.start_chat()
cb_pipe.start_chat()

generation_config = GenerationConfig(**generation_config_kwargs)
ov_pipe.set_generation_config(generation_config)

for question in questions:
generated = cb_pipe.generate(question, **generation_config)
reference = ov_pipe.generate(question, **generation_config)
generated = cb_pipe.generate(question, generation_config=generation_config)
reference = ov_pipe.generate(question)
assert generated == reference

# Test that finish_chat() doesn't fail just in case.
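The reworked test builds a single openvino_genai GenerationConfig from the kwargs dict, attaches it to the stateful pipeline with set_generation_config(), and passes it explicitly to the continuous-batching pipeline on each generate() call. The sketch below condenses the two call styles; the model path and device are placeholders, since the test obtains its pipelines from read_model() and get_continuous_batching().

import openvino_genai as ov_genai

# Placeholder model directory; in the test it comes from the helper fixtures.
pipe = ov_genai.LLMPipeline("path/to/exported/model", "CPU")

config = ov_genai.GenerationConfig(num_beam_groups=3, num_beams=15,
                                   num_return_sequences=1, max_new_tokens=10,
                                   diversity_penalty=1.0, repetition_penalty=1.0)

# Style used for ov_pipe: attach the config once, then generate without extra arguments.
pipe.set_generation_config(config)
reference = pipe.generate("1+1=")

# Style used for cb_pipe: pass the config explicitly per call.
generated = pipe.generate("1+1=", generation_config=config)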
@@ -168,12 +171,13 @@ def test_post_oom_health(tmp_path, sampling_config):
# Pre-emption
#

def get_parallel_samppling_seq_len_300() -> GenerationConfig:
def get_parallel_sampling_seq_len_300() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_return_sequences = 3
generation_config.do_sample = True
generation_config.top_k = 10
generation_config.top_p = 0.5
# TODO: add generation_config.generator and return parameters below
# generation_config.num_return_sequences = 3
# generation_config.do_sample = True
# generation_config.top_k = 10
# generation_config.top_p = 0.5
generation_config.max_new_tokens = 300
return generation_config

@@ -188,8 +192,8 @@ def get_beam_search_seq_len_300() -> GenerationConfig:

scheduler_params_list = [({"num_kv_blocks": 2, "dynamic_split_fuse": True, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_greedy()),
({"num_kv_blocks": 2, "dynamic_split_fuse": False, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_greedy()),
({"num_kv_blocks": 10, "dynamic_split_fuse": True}, get_parallel_samppling_seq_len_300()),
({"num_kv_blocks": 10, "dynamic_split_fuse": False}, get_parallel_samppling_seq_len_300()),
({"num_kv_blocks": 10, "dynamic_split_fuse": True}, get_parallel_sampling_seq_len_300()),
({"num_kv_blocks": 10, "dynamic_split_fuse": False}, get_parallel_sampling_seq_len_300()),
({"num_kv_blocks": 34, "dynamic_split_fuse": True, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_beam_search()),
({"num_kv_blocks": 34, "dynamic_split_fuse": False, "max_num_batched_tokens": 256, "max_num_seqs": 256}, get_beam_search()),
({"num_kv_blocks": 100, "dynamic_split_fuse": True}, get_beam_search_seq_len_300()),
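Each scheduler_params_list entry pairs a scheduler dict with a generation config, and the dict keys mirror SchedulerConfig fields (num_kv_blocks, dynamic_split_fuse, max_num_batched_tokens, max_num_seqs). A hedged sketch of turning such a dict into a SchedulerConfig is shown below; how the test helpers then wire it into a pipeline is omitted.

import openvino_genai as ov_genai

def to_scheduler_config(params: dict) -> ov_genai.SchedulerConfig:
    # Copy dict keys onto the matching SchedulerConfig fields
    # (assumes every key names an existing field, as in the list above).
    scheduler_config = ov_genai.SchedulerConfig()
    for name, value in params.items():
        setattr(scheduler_config, name, value)
    return scheduler_config

config = to_scheduler_config({"num_kv_blocks": 10, "dynamic_split_fuse": False})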
20 changes: 11 additions & 9 deletions tests/python_tests/test_llm_pipeline.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import openvino_genai as ov_genai
from openvino_genai import StopCriteria
from openvino_genai import StopCriteria, GenerationConfig
import pytest
from typing import Union, List, Dict, Optional
import numpy as np
@@ -298,31 +298,33 @@ def test_batch_size_switch():
#

generation_configs = [
dict(do_sample=False, max_new_tokens=20),
dict(do_sample=False, num_beam_groups=3, num_beams=15, num_return_sequences=1, max_new_tokens=10, diversity_penalty=1.0)
dict(max_new_tokens=20),
dict(max_new_tokens=10, num_beam_groups=3, num_beams=15, num_return_sequences=1, diversity_penalty=1.0)
]


questions = [
'1+1=',
'What is the previous answer?',
'Why is the Sun yellow?',
'What was my first question?'
]


@pytest.mark.parametrize("generation_config", generation_configs)
@pytest.mark.parametrize("generation_config_kwargs", generation_configs)
@pytest.mark.parametrize("model_descr", get_chat_models_list())
@pytest.mark.precommit
@pytest.mark.nightly
def test_chat_compare_with_HF(model_descr, generation_config: Dict):
def test_chat_compare_with_HF(model_descr, generation_config_kwargs: Dict):
chat_history_hf = []
chat_history_ov = []
chat_prompt = ''

# Will set add_special_tokens=False inside pipeline when start_chat() is called.
model_id, path, tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))

from transformers import GenerationConfig as HFGenerationConfig
hf_generation_config = HFGenerationConfig(**generation_config_kwargs)
ov_generation_config = GenerationConfig(**generation_config_kwargs)

ov_pipe.start_chat()
for prompt in questions:
chat_history_hf.append({'role': 'user', 'content': prompt})
@@ -331,11 +333,11 @@ def test_chat_compare_with_HF(model_descr, generation_config: Dict):
chat_prompt = tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True)
tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False)

answer = opt_model.generate(**tokenized, **generation_config)
answer = opt_model.generate(**tokenized, generation_config=hf_generation_config)
answer_str = tokenizer.decode(answer[0, tokenized['input_ids'].numel():], skip_special_tokens=True)
chat_history_hf.append({'role': 'assistant', 'content': answer_str})

answer_ov = ov_pipe.generate(prompt, **generation_config)
answer_ov = ov_pipe.generate(prompt, generation_config=ov_generation_config)
chat_history_ov.append({'role': 'assistant', 'content': answer_ov})

ov_pipe.finish_chat()
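For reference, the Hugging Face side of the comparison applies the chat template, tokenizes without special tokens, generates with the HFGenerationConfig built from the same kwargs, and decodes only the newly produced tokens. A standalone hedged sketch of that path follows; the model id is a placeholder rather than one of the models returned by get_chat_models_list().

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GenerationConfig as HFGenerationConfig

model_id = "some-org/some-chat-model"   # placeholder, not from the test
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

hf_generation_config = HFGenerationConfig(max_new_tokens=20)

chat_history = [{"role": "user", "content": "1+1="}]
chat_prompt = tokenizer.apply_chat_template(chat_history, tokenize=False,
                                            add_generation_prompt=True)
tokenized = tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False)

answer = model.generate(**tokenized, generation_config=hf_generation_config)
# Decode only the tokens generated after the prompt.
answer_str = tokenizer.decode(answer[0, tokenized["input_ids"].numel():],
                              skip_special_tokens=True)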
