From 49a12d5368aefb0f56f744ee479466135befc7ef Mon Sep 17 00:00:00 2001
From: Irina Efode
Date: Tue, 29 Oct 2024 12:16:57 +0400
Subject: [PATCH] [ CB ] Fix for many requests for Speculative decoding scenario (#1056)

The issue was reproducible only with request_number > 5 and several
`multisequence` requests: requests that were still `waiting` were
incorrectly marked as `out_of_memory`.

Additional changes:
* Refactored and improved the code structure of `_try_finish_request` in `Sampler`
* Removed extra generation of `draft_model` candidates after the end of a sequence

Validated:
* CB speculative decoding sample with 50 requests on `tiny llama + llama2-7b`
---
 src/cpp/src/continuous_batching_impl.cpp      |  7 +-
 src/cpp/src/continuous_batching_impl.hpp      |  2 +-
 src/cpp/src/sampler.cpp                       |  8 +--
 src/cpp/src/sequence_group.hpp                | 10 +--
 ...batching_for_speculative_decoding_impl.cpp | 64 ++++++++-----------
 ...batching_for_speculative_decoding_impl.hpp |  5 +-
 .../speculative_decoding_impl.cpp             | 26 +++++---
 .../speculative_decoding_impl.hpp             |  3 +
 tests/cpp/speculative_decoding.cpp            |  1 +
 .../continuous_batching_benchmark.cpp         | 26 ++++++--
 10 files changed, 82 insertions(+), 70 deletions(-)

diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 09abbe29ab..c56d02afef 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -123,7 +123,6 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
     static ManualTimer step_timer("step()");
     step_timer.start();

-    // Pull awaiting requests
     _pull_awaiting_requests();

     m_pipeline_metrics.requests = m_requests.size();
@@ -148,8 +147,10 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
         if (scheduler_output.m_total_num_scheduled_tokens == 0) {
             for (size_t i = 0; i < m_requests.size(); ++i) {
                 SequenceGroup::Ptr sequence_group = m_requests[i];
-                sequence_group->set_out_of_memory();
-                sequence_group->notify_handle();
+                if (!sequence_group->is_waiting()) {
+                    sequence_group->set_out_of_memory();
+                    sequence_group->notify_handle();
+                }
             }
             _free_non_running_requests();
             return;
diff --git a/src/cpp/src/continuous_batching_impl.hpp b/src/cpp/src/continuous_batching_impl.hpp
index 78e92d6c76..8276edb36b 100644
--- a/src/cpp/src/continuous_batching_impl.hpp
+++ b/src/cpp/src/continuous_batching_impl.hpp
@@ -49,7 +49,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
                            const DeviceConfig& device_config,
                            ov::Core& core);

-    void _pull_awaiting_requests();
+    virtual void _pull_awaiting_requests();

     void _fill_prompt_log_probs(std::vector& sequence_groups, ov::Tensor& logits);

 public:
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 621d0d4a33..38deb74186 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -630,8 +630,6 @@ stop_sample_tokens(Sequence::Ptr running_sequence,
                    size_t& max_removed_tokens_per_request) {
     running_sequence->remove_last_tokens(token_idx);
     max_removed_tokens_per_request = std::max(max_removed_tokens_per_request, token_idx);
-    running_sequence->set_status(SequenceStatus::FINISHED);
-    running_sequence->set_finish_reason(GenerationFinishReason::STOP);
 }

 void
@@ -798,6 +796,7 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups,
                 }
                 min_generated_len = std::min(min_generated_len, running_sequence->get_generated_len());
             }
+            align_all_sequence_len(sequence_group, min_generated_len, logit_processor);
             for (const auto& dropped_seq_id : _try_finish_generation(sequence_group)) {
                 sampler_output.m_dropped_sequences.push_back(dropped_seq_id);
             }
@@ -822,7 +821,9 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups,
             }
             // Notify handle after sampling is done.
             // For non-streaming this is effective only when the generation is finished.
-            sequence_group->notify_handle();
+            OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
+            size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
+            sequence_group->notify_handle(num_output_token_to_push);
         } else {
             // we are in prompt processing phase when prompt is split into chunks and processed step by step
         }
@@ -833,7 +834,6 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups,
         sequence_group->finish_iteration();
         // decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model
         if (max_removed_tokens_per_request) {
-            align_all_sequence_len(sequence_group, min_generated_len, logit_processor);
             auto min_processed_tokens = sequence_group->get_prompt_len() + min_generated_len - 1;
             sequence_group->update_processed_tokens_num(min_processed_tokens);
             logit_processor.update_generated_len(min_processed_tokens);
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index b2532b220c..c5be82f0f2 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -609,7 +609,7 @@ class SequenceGroup {
         m_generation_stream->push(std::move(outputs));
     }

-    void notify_handle() {
+    void notify_handle(size_t num_output_token_to_push = 0) {
         if (out_of_memory()) {
             set_generation_status(GenerationStatus::IGNORED);
         } else if (has_finished()) {
@@ -625,12 +625,8 @@ class SequenceGroup {
             // (after stop string is detected its tokens are already sent)
             if (num_total_seqs() == 1 &&
                 (m_sampling_params.stop_strings.empty() || m_sampling_params.include_stop_str_in_output)) {
-                auto previous_step_gen_len = get_num_processed_tokens() > 0 ? get_num_processed_tokens() - get_prompt_len() + 1 : 0;
-                auto generation_len = m_sequences.front()->get_generated_len();
-                if (previous_step_gen_len < generation_len) {
-                    auto token_to_print = generation_len - previous_step_gen_len;
-                    push_partial_outputs(token_to_print);
-                }
+                if (num_output_token_to_push)
+                    push_partial_outputs(num_output_token_to_push);
             } else if (has_finished() || out_of_memory()) {
                 push_outputs();
             }
diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
index fd30e9f608..c649c544a6 100644
--- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
@@ -20,23 +20,16 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::Contin
 void
 ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::finish_request(SequenceGroup::Ptr request) {
-
-    for (const auto& sequence : request->get_sequences()) {
-        m_scheduler->free_sequence(sequence->get_id());
+    for (const auto& sequence: request->get_sequences()) {
+        if (m_scheduler->has_block_table(sequence->get_id())) {
+            m_scheduler->free_sequence(sequence->get_id());
+        }
     }
     m_sampler->clear_request_info(request->get_request_id());
+    request->set_generation_status(GenerationStatus::DROPPED_BY_HANDLE);
 }

 void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::finish_request(int64_t request_id) {
-    // finish all request s in case of -1
-    if (request_id == -1) {
-        while (!m_requests.empty()) {
-            const auto& request = *m_requests.rbegin();
-            finish_request(request);
-            m_requests.pop_back();
-        }
-        return;
-    }
     for (size_t i = 0; i < m_requests.size(); ++i) {
         auto& request = m_requests[i];
         if (request->get_request_id() != request_id) {
@@ -50,8 +43,6 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::f
 GeneratedRequests
 ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::get_generated_requests() {
-    _pull_awaiting_requests();
-
     GeneratedRequests result;
     for (const auto& request : m_requests) {
         const auto& request_id = request->get_request_id();
@@ -197,8 +188,6 @@ UpdateRequestResult
 ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::init_request_by_candidate(
     uint64_t request_id,
     const GeneratedSequences& candidates) {
-    _pull_awaiting_requests();
-
     for (auto& request : m_requests) {
         if (request->get_request_id() != request_id) {
             continue;
@@ -218,8 +207,6 @@ UpdateRequestResult
 ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update_request(uint64_t request_id,
                                                                                          const GeneratedSequences& candidates,
                                                                                          bool is_update_logit_processor) {
-    _pull_awaiting_requests();
-
     UpdateRequestResult result{0, 0};
     for (auto& request : m_requests) {
         if (request_id != request->get_request_id()) {
@@ -227,14 +214,9 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
         }

         std::vector running_sequences = request->get_running_sequences();
+        OPENVINO_ASSERT(running_sequences.size() > 0);
         size_t min_generated_tokens, min_candidate_len;
-        if (request->get_context_len() == 0 && !request->get_num_tokens_to_validate()) {
-            if (candidates.begin()->second.log_probs.empty()) {
-                // lock generation in case on empty generation
-                request->pause_generation(true);
-                return result;
-            }
-            // init request by sequences in case the pipeline was not started
+        if (running_sequences.front()->get_generated_len() == 0 && !request->get_num_tokens_to_validate()) {
             m_sampler->create_logit_processor(request_id, request->get_sampling_parameters(), request->get_prompt_ids());
             auto& logit_processor = m_sampler->get_logit_processor(request_id);
             result.inserted_tokens_cnt = init_request(request, candidates, logit_processor, is_update_logit_processor);
@@ -270,11 +252,21 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
         // update request context information to provide correct scheduling phase
         const size_t num_processed_tokens = request->get_num_processed_tokens(),
                      prompt_len = request->get_prompt_len(),
-                     updated_context_len = min_candidate_len + prompt_len;
-        if (num_processed_tokens > 0)
+                     updated_context_len = min_candidate_len + prompt_len,
+                     max_new_tokens = request->get_sampling_parameters().max_new_tokens;
+        size_t generated_len = request->get_context_len() - request->get_prompt_len();
+        if (num_processed_tokens > 0) {
             request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt);
+            generated_len -= result.removed_tokens_cnt;
+        }
         request->set_num_validated_tokens(result.inserted_tokens_cnt);
         request->pause_generation(false);
+        generated_len += result.inserted_tokens_cnt;
+
+        // to pause `draft_model` generation in case of `generated_len >= max_new_tokens - 1` to generate last token by `main_model`
+        if (!m_is_validation_mode_enabled && (generated_len >= max_new_tokens - 1 || result.inserted_tokens_cnt == 0)) {
+            request->pause_generation(true);
+        }
         break;
     }

@@ -282,13 +274,8 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
 }

 void
-ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::unlock_next_request_generation() {
-    for (auto& request : m_requests) {
-        if (!request->has_finished() && !request->can_generate_tokens()) {
-            request->pause_generation(false);
-            return;
-        }
-    }
+ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::pull_awaiting_requests() {
+    ContinuousBatchingImpl::_pull_awaiting_requests();
 }

 void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::multistep() {
@@ -308,13 +295,16 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
                 request->pause_generation(true);
             } else if (request->get_num_processed_tokens() == 0 && sampling_params.num_return_sequences > 1) {
                 request->pause_generation(true);
-            } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt) {
+            } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt && sampling_params.assistant_confidence_threshold == 0.f) {
                 request->pause_generation(true);
-            } else if (request->get_num_processed_tokens() - request->get_prompt_len() + 1 >= sampling_params.max_new_tokens - 1) {
+            } else if (request->get_context_len() >= request->get_prompt_len() &&
+                       (request->get_context_len() - request->get_prompt_len()) >= sampling_params.max_new_tokens - 1) {
+                request->pause_generation(true);
+            } else if (sampling_params.max_new_tokens == 0) {
                 request->pause_generation(true);
             }
             to_generate |= request->can_generate_tokens();
         }
     }
 }
-}
\ No newline at end of file
+}
diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp
index a75a160f14..0040708b4b 100644
--- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp
+++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp
@@ -23,9 +23,9 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl :
                                                   bool is_validation_mode_enabled);

     void multistep();
-    void finish_request(int64_t request_id = -1);
-    void unlock_next_request_generation();
+    void finish_request(int64_t request_id = -1);
+    void pull_awaiting_requests();
     GeneratedRequests get_generated_requests();
     UpdateRequestResult update_request(uint64_t request_id, const GeneratedSequences& candidates, bool is_update_logit_processor);
@@ -33,5 +33,6 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl :
 protected:
     void finish_request(SequenceGroup::Ptr request);
+    void _pull_awaiting_requests() override {};
 };
 }
\ No newline at end of file
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 2008f1fb9a..864646d5cd 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -82,7 +82,8 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const ov::Tensor& input_ids,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_draft_pipeline->add_request(request_id, input_ids, sampling_params);
+    std::lock_guard lock(m_draft_generations_mutex);
+    m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, input_ids, sampling_params)});
     return m_main_pipeline->add_request(request_id, input_ids, sampling_params);
 };

@@ -90,7 +91,8 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const std::string& prompt,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_draft_pipeline->add_request(request_id, prompt, sampling_params);
+    std::lock_guard lock(m_draft_generations_mutex);
+    m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, prompt, sampling_params)});
     return m_main_pipeline->add_request(request_id, prompt, sampling_params);
 }

@@ -112,12 +114,18 @@ void print_generated_request(const ov::genai::GeneratedRequests& requests) {
 }

 void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
+    // this blocks adding new requests during step as it may break coherence between main and draft models
+    std::lock_guard lock{m_draft_generations_mutex};
+    m_draft_pipeline->pull_awaiting_requests();
+    m_main_pipeline->pull_awaiting_requests();
+
     // generate candidates by draft model
     ManualTimer draft_timer("speculative_decoding: draft_model: multistep()");
     draft_timer.start();
     m_draft_pipeline->multistep();
     draft_timer.end();
     m_sd_metrics.draft_duration += draft_timer.get_duration();
+    m_pipeline_metrics = m_main_pipeline->get_metrics();

     // to generate num_matches statistic
     std::map update_sequence_info;
@@ -133,6 +141,7 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
     m_main_pipeline->step();
     main_timer.end();
     m_sd_metrics.main_duration += main_timer.get_duration();
+    m_pipeline_metrics = m_main_pipeline->get_metrics();

     auto main_generated_requests = m_main_pipeline->get_generated_requests();
     for (const auto& checked_sequence : main_generated_requests) {
@@ -145,8 +154,8 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
         auto request_id = draft_request.first;
         if (!main_generated_requests.count(request_id)) {
             m_draft_pipeline->finish_request(request_id);
-            // in case of some requests not to started, unlock generation of next request
-            m_draft_pipeline->unlock_next_request_generation();
+            // remove draft_generation_handle from queue
+            m_draft_generations.erase(request_id);
         }
         auto updated_seq_info = update_sequence_info[request_id];
         float acceptance_rate = 1 - static_cast(updated_seq_info.removed_tokens_cnt) / updated_seq_info.inserted_tokens_cnt;
@@ -175,18 +184,16 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
         }
     }, streamer);

-    std::vector main_generations, draft_generations;
+    std::vector main_generations;
     for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
         OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
         main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id]));

         auto draft_sampling_params = sampling_params[request_id];
         // set the parameters do not stop draft generation without stopping of the same request for main pipeline
-        draft_sampling_params.max_new_tokens = draft_sampling_params.max_new_tokens + 1;
-        draft_sampling_params.min_new_tokens = draft_sampling_params.min_new_tokens + 1;
         draft_sampling_params.ignore_eos = true;
-        draft_generations.push_back(m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params));
-        // decrease generation len to generate last token by main model
+        std::lock_guard lock(m_draft_generations_mutex);
+        m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params)});
     }

     std::vector results;
@@ -210,7 +217,6 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     if (streamer_ptr) {
         streamer_ptr->end();
     }
-    draft_generations.clear();

     for (size_t generation_idx = 0; generation_idx < main_generations.size(); ++generation_idx) {
         const auto& generation = main_generations[generation_idx];
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp
index b427e311b4..f854713b5e 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp
@@ -30,6 +30,9 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat
 protected:
     std::shared_ptr m_main_pipeline, m_draft_pipeline;
     SpeculativeDecodingMetrics m_sd_metrics;
+    // Mutex protecting access to m_draft_generations, so add_request and step methods can be called from different threads
+    std::mutex m_draft_generations_mutex;
+    std::map m_draft_generations;

 public:
     SpeculativeDecodingImpl(const std::filesystem::path& main_models_path,
diff --git a/tests/cpp/speculative_decoding.cpp b/tests/cpp/speculative_decoding.cpp
index 08ce6aaf66..bb10c2cc8f 100644
--- a/tests/cpp/speculative_decoding.cpp
+++ b/tests/cpp/speculative_decoding.cpp
@@ -28,6 +28,7 @@ class CBForSDTest : public testing::Test, public ov::genai::ContinuousBatchingPi
             std::lock_guard lock{m_awaiting_requests_mutex};
             m_awaiting_requests.push_back(sequence_group);
         }
+        pull_awaiting_requests();

         return std::make_shared(sequence_group->get_generation_stream(), sampling_params);
     };
diff --git a/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp b/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp
index 7c3e75eafa..27c64d04a8 100644
--- a/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp
+++ b/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp
@@ -256,8 +256,15 @@ class GenerationInfoCollector {
         this->start_time = start_time;
     }

-    void add_generation(ov::genai::ContinuousBatchingPipeline* pipe, Dataset* dataset, size_t request_id) {
-        ov::genai::GenerationHandle generation_handle = pipe->add_request(request_id, dataset->m_prompts[request_id], dataset->m_sampling_params[request_id]);
+    void add_generation(ov::genai::ContinuousBatchingPipeline* pipe, Dataset* dataset, size_t request_id, bool is_speculative_decoding_enabled) {
+        auto sampling_params = dataset->m_sampling_params[request_id];
+        if (is_speculative_decoding_enabled) {
+            // to enable static speculative decoding
+            sampling_params.num_assistant_tokens = 5;
+            // to enable dynamic speculative decoding
+            // sampling_params.assistant_confidence_threshold = 0.4f;
+        }
+        ov::genai::GenerationHandle generation_handle = pipe->add_request(request_id, dataset->m_prompts[request_id], sampling_params);
         std::lock_guard lock(mutex);
         generations_info.emplace_back(std::move(generation_handle), dataset->m_input_lens[request_id]);
     }
@@ -306,7 +313,7 @@ class GenerationInfoCollector {
     }
 };

-void trafficSimulator(ov::genai::ContinuousBatchingPipeline* pipe, Dataset* dataset, std::string request_rate, GenerationInfoCollector* generation_info_collector) {
+void trafficSimulator(ov::genai::ContinuousBatchingPipeline* pipe, Dataset* dataset, std::string request_rate, GenerationInfoCollector* generation_info_collector, bool is_speculative_decoding_enabled) {
     double numeric_request_rate;
     std::random_device rd;
     std::mt19937 gen(rd());
@@ -333,7 +340,7 @@ void trafficSimulator(ov::genai::ContinuousBatchingPipeline* pipe, Dataset* data
     generation_info_collector->set_start_time(std::chrono::steady_clock::now());
     for (size_t request_id = 0; request_id < dataset->size(); ++request_id) {
         std::cout << "Traffic thread adding request to the queue..." << std::endl;
-        generation_info_collector->add_generation(pipe, dataset, request_id);
+        generation_info_collector->add_generation(pipe, dataset, request_id, is_speculative_decoding_enabled);
         if (numeric_request_rate > 0)
             std::this_thread::sleep_for(std::chrono::milliseconds(int(distribution(gen) * 1000)));
     }
@@ -434,6 +441,7 @@ int main(int argc, char* argv[]) try {
     ("b,max_batch_size", "A maximum number of batched tokens", cxxopts::value()->default_value("256"))
     ("dynamic_split_fuse", "Whether to use dynamic split-fuse or vLLM scheduling", cxxopts::value()->default_value("true"))
     ("m,model", "Path to model and tokenizers base directory", cxxopts::value()->default_value("."))
+    ("draft_model", "Path to assistant model directory", cxxopts::value()->default_value(""))
     ("dataset", "Path to dataset .json file", cxxopts::value()->default_value("./ShareGPT_V3_unfiltered_cleaned_split.json"))
     ("max_input_len", "Max input length take from dataset", cxxopts::value()->default_value("1024"))
     ("max_output_len", "Max output length", cxxopts::value()->default_value("2048"))
@@ -462,6 +470,7 @@ int main(int argc, char* argv[]) try {
     const size_t max_batch_size = result["max_batch_size"].as();
     const bool dynamic_split_fuse = result["dynamic_split_fuse"].as();
     const std::string models_path = result["model"].as();
+    const std::string draft_model_path = result["draft_model"].as();
     const std::string dataset_path = result["dataset"].as();
     const size_t max_input_len = result["max_input_len"].as();
     const size_t max_output_len = result["max_output_len"].as();
@@ -471,6 +480,8 @@ int main(int argc, char* argv[]) try {
     const size_t cache_size = result["cache_size"].as();
     const bool use_cache_eviction = result["use_cache_eviction"].as();

+    bool is_speculative_decoding_enabled = !draft_model_path.empty();
+
     // Create requests for generation
     Dataset dataset = filtered_dataset(models_path, dataset_path, num_prompts, max_input_len, max_output_len);

@@ -509,6 +520,9 @@ int main(int argc, char* argv[]) try {
     std::cout << "\tPlugin configuration JSON: " << device_config << std::endl;

     ov::AnyMap device_config_map = {};
+    if (is_speculative_decoding_enabled) {
+        device_config_map.insert({ ov::genai::draft_model(draft_model_path) });
+    }
     if (!parse_plugin_config_string(device_config, device_config_map)) {
         std::cout << "ERROR: Wrong json parameter in device_config." << std::endl;
         return EXIT_FAILURE;
     }
@@ -524,14 +538,14 @@ int main(int argc, char* argv[]) try {
     std::atomic finishGenerationThread{false};

     if (request_rate == "inf") {
-        std::thread trafficSimulatorThread(trafficSimulator, &pipe, &dataset, request_rate, &generation_info_collector);
+        std::thread trafficSimulatorThread(trafficSimulator, &pipe, &dataset, request_rate, &generation_info_collector, is_speculative_decoding_enabled);
         trafficSimulatorThread.join();
     }

     std::thread lmmEngineThread(llmEngineLoop, &pipe, &dataset, &finishGenerationThread);
     std::thread statisticsReporterThread(statisticsReporter, &generation_info_collector, num_prompts);

     if (request_rate != "inf") {
-        std::thread trafficSimulatorThread(trafficSimulator, &pipe, &dataset, request_rate, &generation_info_collector);
+        std::thread trafficSimulatorThread(trafficSimulator, &pipe, &dataset, request_rate, &generation_info_collector, is_speculative_decoding_enabled);
         trafficSimulatorThread.join();
     }
     statisticsReporterThread.join();
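
Note on the token accounting introduced in `Sampler::sample`: assuming `num_tokens_to_process` candidate tokens were scheduled for a request and `max_removed_tokens_per_request` of them were rejected during validation, the handle is now notified with `num_tokens_to_process - max_removed_tokens_per_request + 1` tokens, which appears to correspond to the accepted candidates plus the one token the model samples itself. For example, with 5 scheduled candidates and 2 rejections, 5 - 2 + 1 = 4 tokens are pushed to the generation stream. On the draft side, `update_request` now pauses generation once `generated_len >= max_new_tokens - 1` (or when no candidates were inserted), so the last token of a request is produced by the main model rather than the draft model.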
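
Design note: `_pull_awaiting_requests()` becomes virtual in `ContinuousBatchingImpl` and is overridden as a no-op in `ContinuousBatchingForSpeculativeDecodingImpl`. Instead, `SpeculativeDecodingImpl::step()` takes `m_draft_generations_mutex` and calls the public `pull_awaiting_requests()` on both the draft and the main pipeline, so new requests are admitted to the two pipelines at the same point in the step and a concurrent `add_request()` cannot desynchronize them.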
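
For reference, a minimal, hypothetical way to drive the speculative-decoding path outside the benchmark could look like the sketch below. The model paths and the cache size are placeholders; `num_assistant_tokens = 5` and the commented-out `assistant_confidence_threshold = 0.4f` mirror the values the benchmark sets when `--draft_model` is passed.

#include <iostream>
#include "openvino/genai/continuous_batching_pipeline.hpp"

int main() {
    ov::genai::SchedulerConfig scheduler_config;
    scheduler_config.cache_size = 8;  // KV-cache size in GB (placeholder value)

    // The draft (assistant) model is passed through the pipeline properties,
    // as the benchmark now does when --draft_model is set.
    ov::AnyMap properties = { ov::genai::draft_model("/path/to/draft_model") };

    ov::genai::ContinuousBatchingPipeline pipe("/path/to/main_model", scheduler_config, "CPU", properties);

    ov::genai::GenerationConfig sampling_params = ov::genai::greedy();
    sampling_params.max_new_tokens = 128;
    sampling_params.num_assistant_tokens = 5;                  // static speculative decoding
    // sampling_params.assistant_confidence_threshold = 0.4f;  // dynamic speculative decoding instead

    auto results = pipe.generate({"What is OpenVINO?"}, {sampling_params});
    std::cout << results[0].m_generation_ids[0] << std::endl;
    return 0;
}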