diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp
index 3f50d30ec9..942f8f2076 100644
--- a/src/cpp/src/llm_pipeline_static.cpp
+++ b/src/cpp/src/llm_pipeline_static.cpp
@@ -5,6 +5,16 @@
 #include "openvino/opsets/opset13.hpp"
+
+#include "openvino/pass/validate.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/util/op_types.hpp"
+#include "openvino/pass/pattern/op/label.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+
 #include "text_callback_streamer.hpp"
 #include "utils.hpp"
@@ -117,6 +127,72 @@ ov::AnyMap extract_config_or_default(const ov::AnyMap& config, const std::string
     return stage_cfg;
 }
 
+// Extract the high nibble of a packed byte (bits 7..4 moved down to bits 3..0).
+inline int8_t hi4(int8_t x) {
+    return ((x & (1 << 7)) >> 4) | ((x & (1 << 6)) >> 4) | ((x & (1 << 5)) >> 4) | ((x & (1 << 4)) >> 4);
+}
+
+// Extract the low nibble of a packed byte (bits 3..0).
+inline int8_t lo4(int8_t x) {
+    return (x & (1 << 3)) | (x & (1 << 2)) | (x & (1 << 1)) | (x & (1 << 0));
+}
+
+// Sign-extend a 4-bit value to int8_t.
+inline int8_t upc(int8_t h) {
+    return h | (-((h & (1 << 3)) >> 3) & (-8));
+}
+
+// Repack a tensor of signed i4 values as u4: adding the zero point 8 to
+// every nibble maps the range [-8, 7] onto [0, 15].
+void cvt(const ov::Tensor& src, ov::Tensor& dst) {
+    int8_t const* pSrc = static_cast<int8_t const*>(src.data());
+    int8_t* pDst = static_cast<int8_t*>(dst.data());
+    for (std::size_t i = 0; i < src.get_size() / 2; i++) {
+        uint8_t a0 = upc(lo4(*pSrc)) + 8;
+        uint8_t a1 = upc(hi4(*pSrc)) + 8;
+        *pDst = (a1 << 4) | a0;
+        pSrc++;
+        pDst++;
+    }
+}
+
+// Rewrites Multiply(Convert(i4 weight), scale) into
+// Multiply(Subtract(Convert(u4 weight), Convert(u4 zero point)), scale).
+struct DQMM1 : public ov::pass::MatcherPass {
+    DQMM1() {
+        namespace opp = ov::pass::pattern;
+
+        auto w = opp::wrap_type<ov::op::v0::Constant>();
+        auto s = opp::wrap_type<ov::op::v0::Constant>();
+        auto cvtw = opp::wrap_type<ov::op::v0::Convert>({w});
+        auto mply = opp::wrap_type<ov::op::v1::Multiply>({cvtw, s});
+
+        auto cb = [=](ov::pass::pattern::Matcher& m) {
+            auto& node_to_output = m.get_pattern_value_map();
+            auto mw_const = std::static_pointer_cast<ov::op::v0::Constant>(node_to_output.at(w).get_node_shared_ptr());
+            auto mupscale = node_to_output.at(mply).get_node_shared_ptr();
+            if (ov::element::i4 == mw_const->get_element_type()) {
+                // Repack the i4 weights as u4 with a zero point of 8
+                ov::Tensor src(mw_const->get_element_type(), mw_const->get_shape(), const_cast<void*>(mw_const->get_data_ptr()));
+                ov::Tensor dst(ov::element::u4, mw_const->get_shape());
+                cvt(src, dst);
+
+                auto new_w = std::make_shared<ov::op::v0::Constant>(dst);
+
+                // Make the zero point an explicit u4 constant...
+                ov::Tensor zp(ov::element::u4, ov::Shape{1});
+                *static_cast<uint8_t*>(zp.data()) = 8;
+                auto new_z = std::make_shared<ov::op::v0::Constant>(zp);
+
+                // ...and subtract it back before the scale is applied
+                auto mply_type = mupscale->input(1).get_element_type();
+                auto new_wcvt = std::make_shared<ov::op::v0::Convert>(new_w, mply_type);
+                auto new_zcvt = std::make_shared<ov::op::v0::Convert>(new_z, mply_type);
+                auto new_sub = std::make_shared<ov::op::v1::Subtract>(new_wcvt, new_zcvt);
+
+                mupscale->input(0).replace_source_output(new_sub);
+            }
+            // Only the Multiply's input was re-wired, so the root node is unchanged
+            return false;
+        };
+        register_matcher(std::make_shared<opp::Matcher>(mply, "DQMM1"), cb);
+    }
+};
+
 } // anonymous namespace
 
 namespace ov {
@@ -144,26 +220,33 @@ StaticLLMPipeline::StaticLLMPipeline(
     */
     ov::Core core;
     // (1) Read the template model - this will be kvcache model
-    auto kvcache_model = core.read_model(path / "openvino_model.xml");
+    m_kvcache_model = core.read_model(path / "openvino_model.xml");
+
+    // (1.5) Apply graph rewrites: repack i4 weights as u4 (see DQMM1 above)
+    ov::pass::GraphRewrite rewr;
+    rewr.add_matcher<DQMM1>();
+    rewr.run_on_model(m_kvcache_model);
+    ov::pass::Validate().run_on_model(m_kvcache_model);
+
     // (2) Expose KV-cache input and output layers from kvcache model
-    ov::pass::StatefulToStateless().run_on_model(kvcache_model);
+    ov::pass::StatefulToStateless().run_on_model(m_kvcache_model);
     // (3) Clone the model - this will be prefill
-    auto prefill_model = kvcache_model->clone();
-    prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
+    m_prefill_model = m_kvcache_model->clone();
+    m_prefill_model->set_friendly_name(m_kvcache_model->get_friendly_name() + "_prefill");
     // (4) Reshape both models to static shape
     m_kvcache_desc = KVCacheDesc { 1024u, 0u };
     const uint32_t max_prompt_size = m_kvcache_desc.total_size;
     const uint32_t max_kvcache_size = m_kvcache_desc.total_size;
-    reshape_to_static(prefill_model, max_prompt_size, max_kvcache_size);
-    reshape_to_static(kvcache_model, 1u, max_kvcache_size);
+    reshape_to_static(m_prefill_model, max_prompt_size, max_kvcache_size);
+    reshape_to_static(m_kvcache_model, 1u, max_kvcache_size);
     // (5) Add slices to kvcache model
-    kvcache_model = add_slices_to_kvcache_inputs(kvcache_model);
+    m_kvcache_model = add_slices_to_kvcache_inputs(m_kvcache_model);
     // (6) Compile both models
     m_prefill_request = core.compile_model(
-        prefill_model, device, extract_config_or_default(config, "PREFILL_CONFIG")
+        m_prefill_model, device, extract_config_or_default(config, "PREFILL_CONFIG")
     ).create_infer_request();
     m_kvcache_request = core.compile_model(
-        kvcache_model, device, extract_config_or_default(config, "GENERATE_CONFIG")
+        m_kvcache_model, device, extract_config_or_default(config, "GENERATE_CONFIG")
     ).create_infer_request();
     // (7) Initialize tensors
     prepare_for_new_conversation();
diff --git a/src/cpp/src/llm_pipeline_static.hpp b/src/cpp/src/llm_pipeline_static.hpp
index 85488e1880..7560b7e336 100644
--- a/src/cpp/src/llm_pipeline_static.hpp
+++ b/src/cpp/src/llm_pipeline_static.hpp
@@ -46,6 +46,10 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
         uint32_t num_stored_tokens;
     };
 
+    // FIXME: Ideally, we don't need to keep those
+    std::shared_ptr<ov::Model> m_kvcache_model;
+    std::shared_ptr<ov::Model> m_prefill_model;
+
     KVCacheDesc m_kvcache_desc;
     ov::InferRequest m_kvcache_request;
     ov::InferRequest m_prefill_request;