add_special_tokens = false for chat (#962)
Chat in the continuous batching pipeline and in the static pipeline should match the stateful pipeline and HF: the chat template already inserts the model's special tokens, so the templated history is encoded with add_special_tokens(false).

https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1884-L1893
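
For context, a minimal sketch (not part of this commit) of the behaviour being aligned, written against the ov::genai identifiers visible in the diffs below. The header path, the ChatHistory alias, the Tokenizer constructor argument, and the model directory are assumptions; apply_chat_template, encode, add_special_tokens, and input_ids are taken from the changed code.

    // Sketch only, not part of the commit. Header name, ChatHistory alias and the
    // model directory path are assumptions; the calls mirror the changed pipelines.
    #include <iostream>
    #include <string>

    #include "openvino/genai/tokenizer.hpp"  // assumed header for ov::genai::Tokenizer

    int main() {
        ov::genai::Tokenizer tokenizer("path/to/exported_model_dir");  // placeholder path

        ov::genai::ChatHistory history;  // assumed alias for a vector of {"role"/"content"} maps
        history.push_back({{"role", "user"}, {"content", "What is OpenVINO?"}});

        constexpr bool add_generation_prompt = true;
        std::string templated = tokenizer.apply_chat_template(history, add_generation_prompt);

        // The chat template already contains the model's special tokens, so the
        // templated string is encoded without adding them again. This mirrors HF,
        // where apply_chat_template tokenizes with add_special_tokens=False.
        auto with_specials    = tokenizer.encode(templated);
        auto without_specials = tokenizer.encode(templated, ov::genai::add_special_tokens(false));

        std::cout << "token count, default encode:                   " << with_specials.input_ids.get_size() << "\n"
                  << "token count, add_special_tokens(false) encode: " << without_specials.input_ids.get_size() << "\n";
        return 0;
    }
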

---------

Co-authored-by: Vladimir Zlobin <[email protected]>
pavel-esir and Wovchena authored Oct 14, 2024
1 parent bb6e307 commit 510fcd3
Showing 3 changed files with 11 additions and 5 deletions.
src/cpp/src/continuous_batching_impl.cpp (3 changes: 2 additions & 1 deletion)

@@ -304,7 +304,8 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<s
         constexpr bool add_generation_prompt = true;
         std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
         timer.start();
-        input_ids.push_back(m_tokenizer.encode(history).input_ids);
+        // ov::genai::add_special_tokens(false) is aligned with stateful pipeline
+        input_ids.push_back(m_tokenizer.encode(history, ov::genai::add_special_tokens(false)).input_ids);
         timer.end();
     } else {
         input_ids.reserve(prompts.size());
src/cpp/src/llm_pipeline.cpp (7 changes: 4 additions & 3 deletions)

@@ -132,12 +132,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         m_history.push_back({{"role", "user"}, {"content", prompt}});
         constexpr bool add_generation_prompt = true;
         auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
-        bool add_special_tokens_ = false; // Do not add special tokens is chat scenario.
-        auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(add_special_tokens_));
+        // Do not add special tokens in chat scenario to be aligned with HF.
+        bool add_special_tokens = false;
+        auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(add_special_tokens));
         if (m_is_cache_empty) {
             encoded_input = new_chat_tokens;
         } else {
-            auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(add_special_tokens_));
+            auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(add_special_tokens));
             encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
         }
         m_templated_chat_history = new_templated_chat_history;
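
Side note on the hunk above: utils::subtract_chat_tokenized_inputs is not touched by this commit and its body is not shown here. The sketch below only illustrates the prefix-subtraction idea it presumably implements, and why both encode() calls must use the same add_special_tokens setting; the function name and token type are placeholders, not the library's API.

    // Sketch only: if one encoding added a leading special token and the other did
    // not, the previous history would no longer be a prefix of the new one and the
    // subtraction would return the wrong suffix.
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> subtract_prefix(const std::vector<int64_t>& new_tokens,
                                         const std::vector<int64_t>& prev_tokens) {
        std::size_t common = 0;
        while (common < prev_tokens.size() && common < new_tokens.size() &&
               prev_tokens[common] == new_tokens[common]) {
            ++common;
        }
        // Keep only the tokens appended by the latest user turn
        // (plus the generation prompt from the chat template).
        return {new_tokens.begin() + common, new_tokens.end()};
    }
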
src/cpp/src/llm_pipeline_static.cpp (6 changes: 5 additions & 1 deletion)

@@ -478,12 +478,16 @@ DecodedResults StaticLLMPipeline::generate(
         prompt = std::get<std::string>(inputs);
     }
 
+    ov::genai::TokenizedInputs tokenized_input;
     if (m_is_chat_conversation) {
         m_history.push_back({{"role", "user"}, {"content", prompt}});
         constexpr bool add_generation_prompt = true;
         prompt = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
+        // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF
+        tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false));
+    } else {
+        tokenized_input = m_tokenizer.encode(prompt);
     }
-    auto tokenized_input = m_tokenizer.encode(prompt);
 
     auto encode_stop_time = std::chrono::steady_clock::now();
     auto encoded_results = generate(tokenized_input, config, streamer);
