[ CB ] Fix for many requests for Speculative decoding scenario (#1056)
The issue was reproducible only with request_number > 5 and several
`multisequence` requests: requests that were still `waiting` were
incorrectly marked as `out_of_memory`.

Additional changes:
* Refactored and improved the code structure of `_try_finish_request` in
`Sampler`
* Removed extra generation of `draft_model` candidates when the end of
sequence is reached

Validated:
* CB Speculative decoding sample with 50 requests on `tiny llama +
llama2-7b`
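
The change in `continuous_batching_impl.cpp` (first hunk below) is the heart of the fix: when the scheduler cannot schedule any tokens, only non-waiting sequence groups are flagged as out of memory. A minimal, runnable sketch of that behavior, with `SequenceGroup` reduced to a hypothetical stand-in struct:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-in for the real SequenceGroup; only the fields needed to
// illustrate the guard are modeled here.
struct SequenceGroup {
    bool waiting = false;
    bool out_of_memory = false;
    bool is_waiting() const { return waiting; }
    void set_out_of_memory() { out_of_memory = true; }
    void notify_handle() { /* would push the IGNORED status to the handle */ }
};

int main() {
    std::vector<std::shared_ptr<SequenceGroup>> requests = {
        std::make_shared<SequenceGroup>(),   // running, cannot fit into KV-cache
        std::make_shared<SequenceGroup>()    // still waiting to be scheduled
    };
    requests[1]->waiting = true;

    size_t total_num_scheduled_tokens = 0;   // the scheduler produced nothing this step
    if (total_num_scheduled_tokens == 0) {
        for (auto& group : requests) {
            if (!group->is_waiting()) {      // the guard added by this commit
                group->set_out_of_memory();
                group->notify_handle();
            }
        }
    }
    // prints "1 0": only the running group is marked out of memory
    std::cout << requests[0]->out_of_memory << " " << requests[1]->out_of_memory << "\n";
}
```
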
iefode authored Oct 29, 2024
1 parent 418aece commit 49a12d5
Showing 10 changed files with 82 additions and 70 deletions.
7 changes: 4 additions & 3 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -123,7 +123,6 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
static ManualTimer step_timer("step()");
step_timer.start();

// Pull awaiting requests
_pull_awaiting_requests();

m_pipeline_metrics.requests = m_requests.size();
@@ -148,8 +147,10 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
if (scheduler_output.m_total_num_scheduled_tokens == 0) {
for (size_t i = 0; i < m_requests.size(); ++i) {
SequenceGroup::Ptr sequence_group = m_requests[i];
sequence_group->set_out_of_memory();
sequence_group->notify_handle();
if (!sequence_group->is_waiting()) {
sequence_group->set_out_of_memory();
sequence_group->notify_handle();
}
}
_free_non_running_requests();
return;
2 changes: 1 addition & 1 deletion src/cpp/src/continuous_batching_impl.hpp
@@ -49,7 +49,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
const DeviceConfig& device_config,
ov::Core& core);

void _pull_awaiting_requests();
virtual void _pull_awaiting_requests();

void _fill_prompt_log_probs(std::vector<SequenceGroup::Ptr>& sequence_groups, ov::Tensor& logits);
public:
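
Making `_pull_awaiting_requests()` virtual lets the speculative-decoding subclass disable the implicit pull inside `step()` and expose an explicit `pull_awaiting_requests()` instead, which the orchestrating implementation calls for the draft and main pipelines together (see the later hunks). A simplified sketch of this pattern, with hypothetical class names standing in for the real pipeline types:

```cpp
#include <iostream>

// Base pipeline pulls awaiting requests automatically inside step();
// the speculative-decoding subclass overrides the hook with a no-op and
// re-exposes it as a public method to be called explicitly.
class ContinuousBatchingBase {
protected:
    virtual void _pull_awaiting_requests() { std::cout << "base pull\n"; }
public:
    virtual ~ContinuousBatchingBase() = default;
    void step() { _pull_awaiting_requests(); /* schedule, infer, sample ... */ }
};

class SpeculativeDecodingPipeline : public ContinuousBatchingBase {
protected:
    void _pull_awaiting_requests() override {}  // disabled inside step()
public:
    void pull_awaiting_requests() { ContinuousBatchingBase::_pull_awaiting_requests(); }
};

int main() {
    SpeculativeDecodingPipeline p;
    p.step();                    // no implicit pull happens
    p.pull_awaiting_requests();  // explicit pull, prints "base pull"
}
```
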
8 changes: 4 additions & 4 deletions src/cpp/src/sampler.cpp
@@ -630,8 +630,6 @@ stop_sample_tokens(Sequence::Ptr running_sequence,
size_t& max_removed_tokens_per_request) {
running_sequence->remove_last_tokens(token_idx);
max_removed_tokens_per_request = std::max(max_removed_tokens_per_request, token_idx);
running_sequence->set_status(SequenceStatus::FINISHED);
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
}

void
@@ -798,6 +796,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
}
min_generated_len = std::min(min_generated_len, running_sequence->get_generated_len());
}
align_all_sequence_len(sequence_group, min_generated_len, logit_processor);
for (const auto& dropped_seq_id : _try_finish_generation(sequence_group)) {
sampler_output.m_dropped_sequences.push_back(dropped_seq_id);
}
@@ -822,7 +821,9 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
}
// Notify handle after sampling is done.
// For non-streaming this is effective only when the generation is finished.
sequence_group->notify_handle();
OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
sequence_group->notify_handle(num_output_token_to_push);
} else {
// we are in prompt processing phase when prompt is split into chunks and processed step by step
}
@@ -833,7 +834,6 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
sequence_group->finish_iteration();
// decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model
if (max_removed_tokens_per_request) {
align_all_sequence_len(sequence_group, min_generated_len, logit_processor);
auto min_processed_tokens = sequence_group->get_prompt_len() + min_generated_len - 1;
sequence_group->update_processed_tokens_num(min_processed_tokens);
logit_processor.update_generated_len(min_processed_tokens);
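
In `Sampler::sample`, the number of tokens pushed to the stream is now computed by the sampler itself: tokens processed this step, minus draft candidates rejected by the main model, plus the token sampled now. A small sketch of that arithmetic (the function name and the example numbers are illustrative, not the library API):

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>

// tokens streamed this step = validated tokens - rejected draft tokens + the
// newly sampled token (matches the expression used in Sampler::sample above)
size_t output_tokens_to_push(size_t num_tokens_to_process,
                             size_t max_removed_tokens_per_request) {
    assert(num_tokens_to_process >= max_removed_tokens_per_request);
    return num_tokens_to_process - max_removed_tokens_per_request + 1;
}

int main() {
    // e.g. 5 draft candidates were processed and 2 were rejected by the main
    // model, so 4 tokens (3 accepted + 1 newly sampled) reach the stream
    std::cout << output_tokens_to_push(5, 2) << "\n";   // prints 4
}
```
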
10 changes: 3 additions & 7 deletions src/cpp/src/sequence_group.hpp
@@ -609,7 +609,7 @@ class SequenceGroup {
m_generation_stream->push(std::move(outputs));
}

void notify_handle() {
void notify_handle(size_t num_output_token_to_push = 0) {
if (out_of_memory()) {
set_generation_status(GenerationStatus::IGNORED);
} else if (has_finished()) {
@@ -625,12 +625,8 @@
// (after stop string is detected its tokens are already sent)
if (num_total_seqs() == 1 &&
(m_sampling_params.stop_strings.empty() || m_sampling_params.include_stop_str_in_output)) {
auto previous_step_gen_len = get_num_processed_tokens() > 0 ? get_num_processed_tokens() - get_prompt_len() + 1 : 0;
auto generation_len = m_sequences.front()->get_generated_len();
if (previous_step_gen_len < generation_len) {
auto token_to_print = generation_len - previous_step_gen_len;
push_partial_outputs(token_to_print);
}
if (num_output_token_to_push)
push_partial_outputs(num_output_token_to_push);
} else if (has_finished() || out_of_memory()) {
push_outputs();
}
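
`notify_handle()` previously re-derived the number of tokens to stream from the processed-token counters, which broke when several tokens were removed in one speculative step; now the sampler passes the count in explicitly. A condensed, hedged sketch of the new behavior (the struct below is a stand-in, not the real `SequenceGroup`):

```cpp
#include <cstddef>
#include <iostream>

// Stand-in for SequenceGroup: the caller (Sampler) now tells the group how many
// tokens to stream instead of the group deriving it from internal counters.
struct SequenceGroupSketch {
    bool finished = false;
    bool out_of_memory = false;
    bool streaming_single_sequence = true;  // num_total_seqs() == 1 and no stop strings

    void push_partial_outputs(size_t n) { std::cout << "stream " << n << " token(s)\n"; }
    void push_outputs() { std::cout << "stream final outputs\n"; }

    void notify_handle(size_t num_output_token_to_push = 0) {
        if (streaming_single_sequence) {
            if (num_output_token_to_push)
                push_partial_outputs(num_output_token_to_push);
        } else if (finished || out_of_memory) {
            push_outputs();
        }
    }
};

int main() {
    SequenceGroupSketch group;
    group.notify_handle(3);  // the sampler accepted 3 new tokens this step
    group.notify_handle();   // nothing accepted -> nothing streamed
}
```
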
@@ -20,23 +20,16 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::Contin

void
ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::finish_request(SequenceGroup::Ptr request) {

for (const auto& sequence : request->get_sequences()) {
m_scheduler->free_sequence(sequence->get_id());
for (const auto& sequence: request->get_sequences()) {
if (m_scheduler->has_block_table(sequence->get_id())) {
m_scheduler->free_sequence(sequence->get_id());
}
}
m_sampler->clear_request_info(request->get_request_id());
request->set_generation_status(GenerationStatus::DROPPED_BY_HANDLE);
}

void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::finish_request(int64_t request_id) {
// finish all request s in case of -1
if (request_id == -1) {
while (!m_requests.empty()) {
const auto& request = *m_requests.rbegin();
finish_request(request);
m_requests.pop_back();
}
return;
}
for (size_t i = 0; i < m_requests.size(); ++i) {
auto& request = m_requests[i];
if (request->get_request_id() != request_id) {
@@ -50,8 +43,6 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::f

GeneratedRequests
ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::get_generated_requests() {
_pull_awaiting_requests();

GeneratedRequests result;
for (const auto& request : m_requests) {
const auto& request_id = request->get_request_id();
@@ -197,8 +188,6 @@ UpdateRequestResult
ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::init_request_by_candidate(
uint64_t request_id,
const GeneratedSequences& candidates) {
_pull_awaiting_requests();

for (auto& request : m_requests) {
if (request->get_request_id() != request_id) {
continue;
@@ -218,23 +207,16 @@ UpdateRequestResult
ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update_request(uint64_t request_id,
const GeneratedSequences& candidates,
bool is_update_logit_processor) {
_pull_awaiting_requests();

UpdateRequestResult result{0, 0};
for (auto& request : m_requests) {
if (request_id != request->get_request_id()) {
continue;
}

std::vector<Sequence::Ptr> running_sequences = request->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() > 0);
size_t min_generated_tokens, min_candidate_len;
if (request->get_context_len() == 0 && !request->get_num_tokens_to_validate()) {
if (candidates.begin()->second.log_probs.empty()) {
// lock generation in case on empty generation
request->pause_generation(true);
return result;
}
// init request by sequences in case the pipeline was not started
if (running_sequences.front()->get_generated_len() == 0 && !request->get_num_tokens_to_validate()) {
m_sampler->create_logit_processor(request_id, request->get_sampling_parameters(), request->get_prompt_ids());
auto& logit_processor = m_sampler->get_logit_processor(request_id);
result.inserted_tokens_cnt = init_request(request, candidates, logit_processor, is_update_logit_processor);
@@ -270,25 +252,30 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
// update request context information to provide correct scheduling phase
const size_t num_processed_tokens = request->get_num_processed_tokens(),
prompt_len = request->get_prompt_len(),
updated_context_len = min_candidate_len + prompt_len;
if (num_processed_tokens > 0)
updated_context_len = min_candidate_len + prompt_len,
max_new_tokens = request->get_sampling_parameters().max_new_tokens;
size_t generated_len = request->get_context_len() - request->get_prompt_len();
if (num_processed_tokens > 0) {
request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt);
generated_len -= result.removed_tokens_cnt;
}
request->set_num_validated_tokens(result.inserted_tokens_cnt);
request->pause_generation(false);
generated_len += result.inserted_tokens_cnt;

// pause `draft_model` generation when `generated_len >= max_new_tokens - 1`, so that the last token is generated by `main_model`
if (!m_is_validation_mode_enabled && (generated_len >= max_new_tokens - 1 || result.inserted_tokens_cnt == 0)) {
request->pause_generation(true);
}
break;
}

return result;
}

void
ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::unlock_next_request_generation() {
for (auto& request : m_requests) {
if (!request->has_finished() && !request->can_generate_tokens()) {
request->pause_generation(false);
return;
}
}
ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::pull_awaiting_requests() {
ContinuousBatchingImpl::_pull_awaiting_requests();
}

void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::multistep() {
@@ -308,13 +295,16 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
request->pause_generation(true);
} else if (request->get_num_processed_tokens() == 0 && sampling_params.num_return_sequences > 1) {
request->pause_generation(true);
} else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt) {
} else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt && sampling_params.assistant_confidence_threshold == 0.f) {
request->pause_generation(true);
} else if (request->get_num_processed_tokens() - request->get_prompt_len() + 1 >= sampling_params.max_new_tokens - 1) {
} else if (request->get_context_len() >= request->get_prompt_len() &&
(request->get_context_len() - request->get_prompt_len()) >= sampling_params.max_new_tokens - 1) {
request->pause_generation(true);
} else if (sampling_params.max_new_tokens == 0) {
request->pause_generation(true);
}
to_generate |= request->can_generate_tokens();
}
}
}
}
}
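
The `multistep()` hunk above tightens when the draft model is paused: multi-sequence requests wait for the main model, a fixed `num_assistant_tokens` window stops speculation once it is filled (unless a confidence threshold drives the window dynamically), and the last token is always left to the main model. A standalone sketch of that decision, using a hypothetical state struct in place of the real request object:

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical snapshot of a draft-pipeline request; field names mirror the diff.
struct DraftRequestState {
    size_t num_processed_tokens = 0;
    size_t prompt_len = 0;
    size_t context_len = 0;
    size_t generated_tokens_cnt = 0;
    size_t num_return_sequences = 1;
    size_t num_assistant_tokens = 0;
    float  assistant_confidence_threshold = 0.f;
    size_t max_new_tokens = 0;
};

bool should_pause_draft_generation(const DraftRequestState& r) {
    if (r.max_new_tokens == 0)
        return true;                                   // nothing to generate at all
    if (r.num_processed_tokens == 0 && r.num_return_sequences > 1)
        return true;                                   // multi-sequence: main model goes first
    if (r.num_assistant_tokens <= r.generated_tokens_cnt &&
        r.assistant_confidence_threshold == 0.f)
        return true;                                   // static speculation window exhausted
    if (r.context_len >= r.prompt_len &&
        r.context_len - r.prompt_len >= r.max_new_tokens - 1)
        return true;                                   // leave the last token to the main model
    return false;
}

int main() {
    DraftRequestState r;
    r.prompt_len = 10; r.context_len = 13; r.generated_tokens_cnt = 3;
    r.num_assistant_tokens = 5; r.max_new_tokens = 20;
    std::cout << std::boolalpha << should_pause_draft_generation(r) << "\n";  // false
}
```
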
@@ -23,15 +23,16 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl :
bool is_validation_mode_enabled);

void multistep();
void finish_request(int64_t request_id = -1);
void unlock_next_request_generation();

void finish_request(int64_t request_id = -1);
void pull_awaiting_requests();
GeneratedRequests get_generated_requests();
UpdateRequestResult update_request(uint64_t request_id, const GeneratedSequences& candidates, bool is_update_logit_processor);

UpdateRequestResult init_request_by_candidate(uint64_t request_id, const GeneratedSequences& candidates);

protected:
void finish_request(SequenceGroup::Ptr request);
void _pull_awaiting_requests() override {};
};
}
26 changes: 16 additions & 10 deletions src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -82,15 +82,17 @@ GenerationHandle
ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
const ov::Tensor& input_ids,
ov::genai::GenerationConfig sampling_params) {
m_draft_pipeline->add_request(request_id, input_ids, sampling_params);
std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, input_ids, sampling_params)});
return m_main_pipeline->add_request(request_id, input_ids, sampling_params);
};

GenerationHandle
ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
const std::string& prompt,
ov::genai::GenerationConfig sampling_params) {
m_draft_pipeline->add_request(request_id, prompt, sampling_params);
std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, prompt, sampling_params)});
return m_main_pipeline->add_request(request_id, prompt, sampling_params);
}

@@ -112,12 +114,18 @@ void print_generated_request(const ov::genai::GeneratedRequests& requests) {
}

void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
// this blocks adding new requests during step as it may break coherence between main and draft models
std::lock_guard<std::mutex> lock{m_draft_generations_mutex};
m_draft_pipeline->pull_awaiting_requests();
m_main_pipeline->pull_awaiting_requests();

// generate candidates by draft model
ManualTimer draft_timer("speculative_decoding: draft_model: multistep()");
draft_timer.start();
m_draft_pipeline->multistep();
draft_timer.end();
m_sd_metrics.draft_duration += draft_timer.get_duration();
m_pipeline_metrics = m_main_pipeline->get_metrics();

// to generate num_matches statistic
std::map<int64_t, UpdateRequestResult> update_sequence_info;
@@ -133,6 +141,7 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
m_main_pipeline->step();
main_timer.end();
m_sd_metrics.main_duration += main_timer.get_duration();
m_pipeline_metrics = m_main_pipeline->get_metrics();

auto main_generated_requests = m_main_pipeline->get_generated_requests();
for (const auto& checked_sequence : main_generated_requests) {
@@ -145,8 +154,8 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
auto request_id = draft_request.first;
if (!main_generated_requests.count(request_id)) {
m_draft_pipeline->finish_request(request_id);
// in case of some requests not to started, unlock generation of next request
m_draft_pipeline->unlock_next_request_generation();
// remove draft_generation_handle from queue
m_draft_generations.erase(request_id);
}
auto updated_seq_info = update_sequence_info[request_id];
float acceptance_rate = 1 - static_cast<float>(updated_seq_info.removed_tokens_cnt) / updated_seq_info.inserted_tokens_cnt;
@@ -175,18 +184,16 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
}
}, streamer);

std::vector<GenerationHandle> main_generations, draft_generations;
std::vector<GenerationHandle> main_generations;
for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id]));

auto draft_sampling_params = sampling_params[request_id];
// set the parameters do not stop draft generation without stopping of the same request for main pipeline
draft_sampling_params.max_new_tokens = draft_sampling_params.max_new_tokens + 1;
draft_sampling_params.min_new_tokens = draft_sampling_params.min_new_tokens + 1;
draft_sampling_params.ignore_eos = true;
draft_generations.push_back(m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params));
// decrease generation len to generate last token by main model
std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params)});
}

std::vector<EncodedGenerationResult> results;
@@ -210,7 +217,6 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
if (streamer_ptr) {
streamer_ptr->end();
}
draft_generations.clear();

for (size_t generation_idx = 0; generation_idx < main_generations.size(); ++generation_idx) {
const auto& generation = main_generations[generation_idx];
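
`SpeculativeDecodingImpl` now keeps the draft generation handles in a mutex-protected map so that `add_request()` and `step()` can run on different threads: a handle is inserted when a request is added and erased once the corresponding request disappears from the main pipeline. A reduced sketch of that bookkeeping (the handle type and method names below are simplified placeholders):

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <set>

// GenerationHandle stands in for the real handle type.
using GenerationHandle = std::shared_ptr<int>;

class DraftGenerationBook {
    // protects m_draft_generations so add_request and step can race safely
    std::mutex m_draft_generations_mutex;
    std::map<uint64_t, GenerationHandle> m_draft_generations;

public:
    // mirrors SpeculativeDecodingImpl::add_request(): insert under the lock
    void on_add_request(uint64_t request_id, GenerationHandle handle) {
        std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
        m_draft_generations.insert({request_id, std::move(handle)});
    }

    // mirrors the cleanup in step(): drop handles whose request no longer
    // exists in the main pipeline (in the real code the lock is already held
    // for the whole step)
    void drop_finished(const std::set<uint64_t>& main_request_ids) {
        std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
        for (auto it = m_draft_generations.begin(); it != m_draft_generations.end();) {
            if (main_request_ids.count(it->first) == 0)
                it = m_draft_generations.erase(it);
            else
                ++it;
        }
    }
};

int main() {
    DraftGenerationBook book;
    book.on_add_request(0, std::make_shared<int>(0));
    book.on_add_request(1, std::make_shared<int>(1));
    book.drop_finished({1});   // request 0 finished in the main pipeline -> erased
}
```
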
@@ -30,6 +30,9 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat
protected:
std::shared_ptr<ContinuousBatchingForSpeculativeDecodingImpl> m_main_pipeline, m_draft_pipeline;
SpeculativeDecodingMetrics m_sd_metrics;
// Mutex protecting access to m_draft_generations, so add_request and step methods can be called from different threads
std::mutex m_draft_generations_mutex;
std::map<uint64_t, GenerationHandle> m_draft_generations;

public:
SpeculativeDecodingImpl(const std::filesystem::path& main_models_path,
1 change: 1 addition & 0 deletions tests/cpp/speculative_decoding.cpp
@@ -28,6 +28,7 @@ class CBForSDTest : public testing::Test, public ov::genai::ContinuousBatchingPi
std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
m_awaiting_requests.push_back(sequence_group);
}
pull_awaiting_requests();
return std::make_shared<ov::genai::GenerationHandleImpl>(sequence_group->get_generation_stream(), sampling_params);
};
