Generic fixes for CB integration via LLMPipeline (#961)
ilya-lavrenov authored Oct 14, 2024
1 parent ec90baa commit 56b05c2
Showing 1 changed file with 34 additions and 6 deletions.
src/continuous_batching_impl.cpp: 40 changes (34 additions, 6 deletions)
@@ -4,6 +4,7 @@
 #include "text_callback_streamer.hpp"
 #include "continuous_batching_impl.hpp"
 #include "paged_attention_transformations.hpp"
+#include "utils.hpp"
 
 namespace ov::genai {
 template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
@@ -18,15 +19,18 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
     m_tokenizer = tokenizer;
     ov::Core core;
 
+    auto [core_plugin_config, compile_plugin_config] = ov::genai::utils::split_core_complile_config(plugin_config);
+    core.set_property(core_plugin_config);
+
     // The model can be compiled for GPU as well
     std::shared_ptr<ov::Model> model = core.read_model(models_path + "/openvino_model.xml");
 
-    DeviceConfig device_config(core, scheduler_config, device, plugin_config);
+    DeviceConfig device_config(core, scheduler_config, device, compile_plugin_config);
 
     bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
     apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
 
-    ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), plugin_config).create_infer_request();
+    ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), compile_plugin_config).create_infer_request();
 
     // setup KV caches
     m_cache_manager = std::make_shared<CacheManager>(device_config, core);
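
For reference, the new helper splits the caller's plugin_config into properties that must be applied to ov::Core itself before the model is read, and properties that are forwarded to compile_model(). A minimal sketch of such a split, assuming CACHE_DIR is the only core-level key (the real key list lives in utils.hpp and may differ):

    // Hypothetical sketch, not the repository's implementation: partition an
    // ov::AnyMap into core-scoped and compile-scoped properties.
    #include <utility>
    #include "openvino/runtime/core.hpp"

    std::pair<ov::AnyMap, ov::AnyMap> split_core_compile_config_sketch(const ov::AnyMap& plugin_config) {
        ov::AnyMap core_config, compile_config;
        for (const auto& [key, value] : plugin_config) {
            if (key == "CACHE_DIR") {
                core_config[key] = value;     // applied via core.set_property()
            } else {
                compile_config[key] = value;  // passed to core.compile_model()
            }
        }
        return {core_config, compile_config};
    }

With such a split, core-scoped options take effect before read_model(), while device options still reach both DeviceConfig and the compiled model.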
@@ -69,7 +73,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
     sampling_params.set_eos_token_id(m_tokenizer.get_eos_token_id());
     sampling_params.validate();
     SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, input_ids,
-                                                                        sampling_params,
+                                                                        sampling_params,
                                                                         m_scheduler->get_config().block_size,
                                                                         m_scheduler->get_config().enable_prefix_caching);
     sequence_group->set_sequence_group_ptr(sequence_group);
@@ -87,7 +91,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
 GenerationHandle
 ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request_id,
                                                                 const std::string& prompt,
-                                                                ov::genai::GenerationConfig sampling_params) {
+                                                                ov::genai::GenerationConfig sampling_params) {
     static ManualTimer timer("tokenize");
     timer.start();
     ov::Tensor input_ids = m_tokenizer.encode(prompt).input_ids;
Expand Down Expand Up @@ -255,20 +259,44 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
std::vector<EncodedGenerationResult> results;
results.reserve(m_awaiting_requests.size());

bool continue_generation = true;
auto drop_requests = [&] () {
for (const std::shared_ptr<ov::genai::SequenceGroup> request : m_requests) {
for (const auto& sequence: request->get_sequences()) {
if (m_scheduler->has_block_table(sequence->get_id())) {
m_scheduler->free_sequence(sequence->get_id());
}
}
m_sampler->clear_beam_search_info(request->get_request_id());
}
m_requests.clear();
};

bool continue_generation = true, step_throws_exception = false;
while (has_non_finished_requests() && continue_generation) {
step();
try {
step();
} catch (...) {
drop_requests();
throw;
}
if (streamer_ptr && generations.at(0)->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
OPENVINO_ASSERT(1 == token.size());
OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size());
continue_generation = !streamer_ptr->put(token.begin()->second.generated_ids.at(0));
}
}

if (streamer_ptr) {
streamer_ptr->end();
}

if (!continue_generation) {
drop_requests();
} else {
OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
}

for (size_t generation_idx = 0; generation_idx < generations.size(); ++generation_idx) {
const auto& generation = generations[generation_idx];
EncodedGenerationResult result;
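
The try/catch around step(), together with the final !continue_generation branch, releases scheduler block tables, beam-search state, and queued requests on every exit path: exception, cancellation by the streamer, or normal completion. A generic sketch (not from this repository) of the same cleanup-on-failure guarantee written as an RAII scope guard:

    // Generic C++17 sketch: run a cleanup action only when the scope is
    // left because a new exception is propagating.
    #include <exception>
    #include <utility>

    template <typename F>
    class ScopeFail {
        F m_cleanup;
        int m_exceptions = std::uncaught_exceptions();
    public:
        explicit ScopeFail(F cleanup) : m_cleanup(std::move(cleanup)) {}
        ~ScopeFail() {
            if (std::uncaught_exceptions() > m_exceptions)
                m_cleanup();
        }
    };

    // Hypothetical use at the call site: ScopeFail guard([&] { drop_requests(); }); step();

The explicit try/catch chosen in the commit keeps the control flow obvious for a single call site and avoids introducing a guard type.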
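On the caller side, the continue_generation = !streamer_ptr->put(...) line means a streaming callback cancels generation by returning true, and with this commit the cancelled requests are now also dropped instead of lingering in the scheduler. A hypothetical end-to-end sketch through the public LLMPipeline wrapper (model directory and cut-off value are made up):

    #include <functional>
    #include <iostream>
    #include "openvino/genai/llm_pipeline.hpp"

    int main() {
        // Assumption: "model_dir" holds a model exported for OpenVINO GenAI.
        ov::genai::LLMPipeline pipe("model_dir", "CPU");

        ov::genai::GenerationConfig config;
        config.max_new_tokens = 100;

        size_t tokens_seen = 0;
        std::function<bool(std::string)> streamer = [&](std::string subword) {
            std::cout << subword << std::flush;
            // Returning true requests cancellation; internally this flips
            // continue_generation to false and drop_requests() frees state.
            return ++tokens_seen >= 16;
        };

        pipe.generate("Why is the sky blue?", config, streamer);
        std::cout << std::endl;
    }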