Re-use CB sampler
TolyaTalamanov committed Jan 2, 2025
1 parent a0e1577 commit 28435c8
Showing 4 changed files with 66 additions and 81 deletions.
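
In short: the commit removes the static pipeline's hand-rolled sampling path (sample_next_token(), LogitProcessor, the m_rng_engine member, and the explicit stop-string matching in the decode loop) and re-uses the continuous-batching (CB) Sampler together with SequenceGroup and GenerationHandle for both prefill and per-token decoding. The sketch below condenses the new flow as it appears in the diff; member names (m_sampler, m_prefill_request, m_kvcache_request) are taken from this change, while prefill inference, KV-cache bookkeeping, perf counters, and error handling are elided, so read it as an outline rather than the full implementation.

// Sketch only: condensed from this commit's diff, not the complete implementation.
auto sequence_group = std::make_shared<SequenceGroup>(
    0 /* request_id */, padded_input_ids, config, 1 /* block_size */);
GenerationHandle handle = std::make_shared<GenerationHandleImpl>(
    sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters());

// Prefill: sample from the prompt logits (the real code first calls
// update_processed_tokens_num() and schedule_tokens() on the sequence group).
m_sampler.sample({sequence_group}, m_prefill_request.get_tensor("logits"));
stream_generated_tokens(streamer_ptr, handle);

// Decode: one token per iteration until the sequence group stops running
// (EOS, max_new_tokens, or the KV-cache being marked out of memory).
while (sequence_group->is_running()) {
    sequence_group->schedule_tokens(1);
    // ... feed get_last_token(sequence_group) into m_kvcache_request and infer ...
    m_sampler.sample({sequence_group}, m_kvcache_request.get_tensor("logits"));
    stream_generated_tokens(streamer_ptr, handle);
}

// Collect the single finished sequence and drop the sampler's per-request state.
auto sequence = sequence_group->get_finished_sequences().front();
results.tokens[0] = sequence->get_generated_ids();
results.scores[0] = sequence->get_cumulative_log_prob();
m_sampler.clear_request_info(sequence_group->get_request_id());
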
137 changes: 61 additions & 76 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -646,7 +646,8 @@ StaticLLMPipeline::StaticLLMPipeline(
const std::string& device,
const ov::AnyMap& config
) : LLMPipelineImplBase(tokenizer,
utils::from_config_json_if_exists(models_path)) {
utils::from_config_json_if_exists(models_path)),
m_sampler(m_tokenizer) {
auto properties = config;
/* NB: Static LLM pipeline consists of two models,
first to process the input prompt (prefill),
@@ -675,6 +676,8 @@ StaticLLMPipeline::StaticLLMPipeline(
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}

m_sampler.set_seed(m_generation_config.rng_seed);
};

StaticLLMPipeline::StaticLLMPipeline(
@@ -691,7 +694,7 @@ StaticLLMPipeline::StaticLLMPipeline(
const std::string& device,
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config
) : LLMPipelineImplBase(tokenizer, generation_config) {
) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
bool use_blobs = false;
auto anyopt = get_option<bool>(properties, "USE_BLOBS");
if (anyopt.has_value()) {
@@ -710,6 +713,8 @@ StaticLLMPipeline::StaticLLMPipeline(
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}

m_sampler.set_seed(m_generation_config.rng_seed);
}

void StaticLLMPipeline::setupAndCompileModels(
@@ -940,28 +945,29 @@ DecodedResults StaticLLMPipeline::generate(
return decoded_results;
}

int64_t sample_next_token(const ov::Tensor& logits,
const GenerationConfig& config,
std::mt19937& rng_engine,
LogitProcessor& logit_processor) {
const size_t vocab_size = logits.get_shape()[2];
const size_t seq_len_size = logits.get_shape()[1];
const size_t offset = (seq_len_size - 1) * vocab_size;
// NB: Slice out and take probabilities only for the last token
Logits logit_vector(logits.data<float>() + offset, vocab_size);
logit_processor.apply(logit_vector);
int64_t last_token = -1;
if (config.is_greedy_decoding()) {
last_token = ov::genai::greedy_sample(logit_vector, config.logprobs).m_index;
} else if (config.is_multinomial()) {
last_token = ov::genai::multinomial_sample(logit_vector, 1u, rng_engine)[0].m_index;
} else {
// NB: Only greedy and multinomial supported,
// the appropriate check is performed before
OPENVINO_ASSERT(false);
void stream_generated_tokens(std::shared_ptr<StreamerBase> streamer_ptr,
GenerationHandle& handle) {
if (streamer_ptr && handle->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = handle->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
handle->drop();
break;
}
}
}
logit_processor.register_new_generated_token(last_token);
return last_token;
}

int64_t get_last_token(SequenceGroup::Ptr sequence_group) {
const auto running_sequences = sequence_group->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() == 1u);
const auto sequence = running_sequences.front();

size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
OPENVINO_ASSERT(num_scheduled_tokens == 1u);

const auto num_processed_tokens = sequence_group->get_num_processed_tokens();
return sequence->get_generated_ids()[num_processed_tokens - sequence_group->get_prompt_len()];
}

EncodedResults StaticLLMPipeline::generate(
@@ -981,7 +987,10 @@ EncodedResults StaticLLMPipeline::generate(
attention_mask = data->attention_mask;
}

if (input_ids.get_shape().at(0) > 1u) {
ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];

if (batch_size > 1u) {
OPENVINO_THROW("Currently only batch size=1 is supported");
}

@@ -1004,25 +1013,6 @@
OPENVINO_THROW("Currently only greedy and multinomial decoding are supported");
}

// FIXME:...
if ( streamer_ptr &&
!config.stop_strings.empty() &&
!config.include_stop_str_in_output) {
OPENVINO_THROW("Static LLM pipeline doesn't support "
"\"include_stop_str_in_output: false\" when a streamer is provided");
}

std::vector<int64_t> input_ids_vec;
input_ids_vec.reserve(input_ids.get_size());
std::copy_n(input_ids.data<int64_t>(), input_ids.get_size(), std::back_inserter(input_ids_vec));
LogitProcessor logit_processor(config, input_ids_vec);
m_rng_engine.seed(config.rng_seed);

const auto processed_stop_strings =
process_stop_strings(config.stop_strings, m_tokenizer);

ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];
ov::genai::EncodedResults results;
auto& raw_perf_counters = results.perf_metrics.raw_metrics;
// NB: Only batch=1 is supported now
@@ -1060,12 +1050,20 @@ EncodedResults StaticLLMPipeline::generate(
// NB: Now there are prompt_len tokens in KV-cache
m_kvcache_desc.num_stored_tokens += static_cast<uint32_t>(prompt_len);

auto last_token = sample_next_token(
m_prefill_request.get_tensor("logits"), config, m_rng_engine, logit_processor);
results.tokens[0].push_back(last_token);
if (streamer_ptr && streamer_ptr->put(last_token)) {
return results;
}
auto logits = m_prefill_request.get_tensor("logits");
int64_t output_sequence_len = logits.get_shape().at(1);

auto sequence_group = std::make_shared<SequenceGroup>(
0 /* request_id */, padded_input_ids, config, 1 /* block_size */);
sequence_group->update_processed_tokens_num(m_kvcache_desc.max_prompt_size - output_sequence_len);
sequence_group->schedule_tokens(output_sequence_len);

// NB: Controls what tokens are ready to be pushed into the streamer
GenerationHandle handle = std::make_shared<GenerationHandleImpl>(
sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters());

SamplerOutput sampler_output = m_sampler.sample({sequence_group}, logits);
stream_generated_tokens(streamer_ptr, handle);

// Outputs: logits, ...
const auto kStartOutputKVCacheLayers = 1u;
@@ -1106,46 +1104,27 @@ EncodedResults StaticLLMPipeline::generate(
std::fill(attention_mask_data, attention_mask_data + m_kvcache_desc.num_stored_tokens - 1u, 1u);
attention_mask_data[m_kvcache_desc.total_size - 1] = 1u;

const size_t max_tokens = config.get_max_new_tokens(prompt_len);
for (int i = 0; i < max_tokens - 1; ++i) {
while (sequence_group->is_running()) {
sequence_group->schedule_tokens(1);
int64_t last_token = get_last_token(sequence_group);

input_ids_data[0] = last_token;
position_ids_data[0] = m_kvcache_desc.num_stored_tokens;
attention_mask_data[m_kvcache_desc.num_stored_tokens - 1] = 1u;

m_kvcache_request.infer();
m_kvcache_desc.num_stored_tokens += 1;

last_token = sample_next_token(
m_kvcache_request.get_tensor("logits"), config, m_rng_engine, logit_processor);
results.tokens[0].push_back(last_token);

raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);

bool met_stop_str = false;
if (!config.stop_strings.empty()) {
auto match_result = match_stop_string(m_tokenizer,
results.tokens[0],
processed_stop_strings,
config.include_stop_str_in_output);
if (match_result.is_matched) {
met_stop_str = true;
results.tokens[0].erase(
results.tokens[0].end() - match_result.to_remove, results.tokens[0].end());
}
}

if (streamer_ptr && streamer_ptr->put(last_token)) {
break;
}

if (met_stop_str || (last_token == config.eos_token_id && !config.ignore_eos)) {
break;
}
SamplerOutput sampler_output = m_sampler.sample(
{sequence_group}, m_kvcache_request.get_tensor("logits"));
stream_generated_tokens(streamer_ptr, handle);

// NB: KV-cache is full, further generation is impossible
if (m_kvcache_desc.num_stored_tokens == m_kvcache_desc.total_size) {
break;
sequence_group->set_out_of_memory();
}

// NB: Write KV-cache for the new token to the correct input position for the next iteration
@@ -1168,6 +1147,12 @@ EncodedResults StaticLLMPipeline::generate(
streamer_ptr->end();
}

OPENVINO_ASSERT(sequence_group->get_finished_sequences().size() == 1u);
auto sequence = sequence_group->get_finished_sequences().front();
results.tokens[0] = sequence->get_generated_ids();
results.scores[0] = sequence->get_cumulative_log_prob();
m_sampler.clear_request_info(sequence_group->get_request_id());

auto stop_time = std::chrono::steady_clock::now();
// If is called without tokenization then that stat will not be reported.
auto& metrics = results.perf_metrics;
6 changes: 3 additions & 3 deletions src/cpp/src/llm_pipeline_static.hpp
@@ -7,6 +7,7 @@
#include <random>

#include "llm_pipeline_base.hpp"
#include "sampler.hpp"

namespace ov {
namespace genai {
@@ -78,15 +79,14 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
bool v_tensors_transposed;
};

Sampler m_sampler;

KVCacheDesc m_kvcache_desc;
ov::InferRequest m_kvcache_request;
ov::InferRequest m_prefill_request;

bool m_is_chat_conversation = false;
ChatHistory m_history;

// NB: For multinomial sampling
std::mt19937 m_rng_engine;
};

} // namespace genai
2 changes: 1 addition & 1 deletion src/cpp/src/sampler.cpp
Expand Up @@ -746,7 +746,7 @@ process_stop_strings(const std::set<std::string>& stop_strings, Tokenizer& token
return result;
}

SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr> & sequence_groups,
ov::Tensor logits,
bool is_validation_mode_enabled) {
const float * logits_data = logits.data<float>();
2 changes: 1 addition & 1 deletion src/cpp/src/sampler.hpp
@@ -88,7 +88,7 @@ class Sampler {
Sampler() = default;
Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};

SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
SamplerOutput sample(const std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
void set_seed(size_t new_seed) {
rng_engine.seed(new_seed);
seed = new_seed;
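
The Sampler change itself is minimal: sample() now takes the sequence-group vector by const reference, which lets call sites such as the pipeline code above pass a temporary built from an initializer list (m_sampler.sample({sequence_group}, logits)) instead of a named non-const vector. The fragment below summarizes how the pipeline wires the shared sampler in; it assumes the surrounding objects from the generate() implementation shown earlier and is not a standalone program.

// Pipeline-side wiring, condensed from this diff (fragment, not standalone code).
Sampler sampler(tokenizer);                     // member m_sampler, built from the pipeline tokenizer
sampler.set_seed(generation_config.rng_seed);   // replaces reseeding the removed m_rng_engine per call

// Per decode step: a temporary vector now binds to the const& parameter.
SamplerOutput out = sampler.sample({sequence_group}, request.get_tensor("logits"));

// Once the request finishes, release the sampler's per-request bookkeeping.
sampler.clear_request_info(sequence_group->get_request_id());
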
