diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index c98b571179..8f504e9643 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -209,6 +209,13 @@ std::shared_ptr cvt_value_tensors_layout(std::shared_ptr m return ppp.build(); } +void unroll_sdpa(std::shared_ptr model) { + ov::pass::GraphRewrite rewr; + rewr.add_matcher(); + rewr.run_on_model(model); + ov::pass::Validate().run_on_model(model); +} + bool optimize_value_tensors(std::shared_ptr model) { ov::pass::GraphRewrite rewr; rewr.add_matcher(); @@ -1049,6 +1056,8 @@ void StatelessLLMPipeline::setupAndCompileModels( m_kvcache_desc.v_tensors_transposed = true; prefill_model = cvt_value_tensors_layout(prefill_model); } + } else { + unroll_sdpa(kvcache_model); } // (7) Replace KV-cache tensors for the entire cache to tensors only for new token (before concat) kvcache_model = redirect_new_kv_to_output(kvcache_model);