Skip to content

Commit

Permalink
Support unfixed kv heads number
Browse files Browse the repository at this point in the history
  • Loading branch information
mangguo321 committed Dec 26, 2024
1 parent 812163a commit db866ee
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 55 deletions.
41 changes: 20 additions & 21 deletions src/cpp/src/cache_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ class CacheManager {
}
OPENVINO_ASSERT(m_key_cache.size() == m_value_cache.size());
m_num_allocated_kv_blocks = num_kv_blocks;
ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(), num_kv_blocks);
ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(), num_kv_blocks);

const std::string device_name = m_device_config.get_device();

Expand All @@ -56,6 +54,8 @@ class CacheManager {

if (device_name.find("GPU") == std::string::npos) {// Allocate KV caches
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Tensor key_cache(m_device_config.get_cache_precision(), key_cache_shape);
ov::Tensor value_cache(m_device_config.get_cache_precision(), value_cache_shape);

Expand Down Expand Up @@ -104,6 +104,8 @@ class CacheManager {
} else {
auto remote_context = m_core.get_default_context(device_name);
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Tensor key_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
key_cache_shape);
ov::Tensor value_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
Expand Down Expand Up @@ -142,30 +144,27 @@ class CacheManager {
}

void copy_blocks(const std::map<size_t, std::list<size_t>>& block_copy_map) {
ov::Shape key_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(), m_num_allocated_kv_blocks);
ov::Shape value_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(), m_num_allocated_kv_blocks);

ov::Coordinate key_src_start_roi(key_shape.size(), 0);
ov::Coordinate key_src_end_roi = key_shape;
ov::Coordinate key_dst_start_roi(key_shape.size(), 0);
ov::Coordinate key_dst_end_roi = key_shape;

ov::Coordinate value_src_start_roi(value_shape.size(), 0);
ov::Coordinate value_src_end_roi = value_shape;
ov::Coordinate value_dst_start_roi(value_shape.size(), 0);
ov::Coordinate value_dst_end_roi = value_shape;

for (const auto & blocks_pair : block_copy_map) {
size_t src_block_id = blocks_pair.first;
key_src_end_roi[0] = (key_src_start_roi[0] = src_block_id) + 1;
value_src_end_roi[0] = (value_src_start_roi[0] = src_block_id) + 1;

const std::list<size_t>& dst_block_ids = blocks_pair.second;
for (size_t dst_block_id : dst_block_ids) {
key_dst_end_roi[0] = (key_dst_start_roi[0] = dst_block_id) + 1;
value_dst_end_roi[0] = (value_dst_start_roi[0] = dst_block_id) + 1;

for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Shape key_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
ov::Shape value_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
ov::Coordinate key_src_start_roi(key_shape.size(), 0);
ov::Coordinate key_src_end_roi = key_shape;
ov::Coordinate key_dst_start_roi(key_shape.size(), 0);
ov::Coordinate key_dst_end_roi = key_shape;

ov::Coordinate value_src_start_roi(value_shape.size(), 0);
ov::Coordinate value_src_end_roi = value_shape;
ov::Coordinate value_dst_start_roi(value_shape.size(), 0);
ov::Coordinate value_dst_end_roi = value_shape;
key_src_end_roi[0] = (key_src_start_roi[0] = src_block_id) + 1;
value_src_end_roi[0] = (value_src_start_roi[0] = src_block_id) + 1;
key_dst_end_roi[0] = (key_dst_start_roi[0] = dst_block_id) + 1;
value_dst_end_roi[0] = (value_dst_start_roi[0] = dst_block_id) + 1;

ov::Tensor key_src_cache_roi(m_key_cache[decoder_layer_id], key_src_start_roi, key_src_end_roi);
ov::Tensor key_dst_cache_roi(m_key_cache[decoder_layer_id], key_dst_start_roi, key_dst_end_roi);

Expand Down
61 changes: 40 additions & 21 deletions src/cpp/src/device_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
namespace ov::genai {
class DeviceConfig {
ov::element::Type m_kv_cache_type;
ov::PartialShape m_key_cache_shape, m_value_cache_shape;
ov::Shape::value_type m_num_kv_heads, m_head_size, m_num_decoder_layers;
std::vector<ov::PartialShape> m_key_cache_shape, m_value_cache_shape;
std::vector<ov::Shape::value_type> m_num_kv_heads;
ov::Shape::value_type m_head_size, m_num_decoder_layers;
size_t m_num_kv_blocks = 0;
size_t m_block_size = 0;
size_t m_cache_size = 0;
Expand Down Expand Up @@ -88,11 +89,14 @@ class DeviceConfig {
}
}

void set_model_params(size_t num_kv_heads, size_t head_size, size_t num_decoder_layers) {
m_num_kv_heads = num_kv_heads;
void set_model_params(std::vector<size_t> num_kv_heads, size_t head_size, size_t num_decoder_layers) {
m_head_size = head_size;
m_num_decoder_layers = num_decoder_layers;

m_num_kv_heads.assign(num_kv_heads.begin(), num_kv_heads.end());
m_key_cache_shape.reserve(m_num_decoder_layers);
m_value_cache_shape.reserve(m_num_decoder_layers);

if (m_device == "CPU") {
// Scale, zero point and quantized data will be stored together.
// The layout for per token per head:
Expand All @@ -104,21 +108,32 @@ class DeviceConfig {
}

if (m_num_kv_blocks == 0 && m_cache_size > 0) {
size_t block_size = 0;
size_t size_in_bytes = m_cache_size * 1024 * 1024 * 1024;
m_num_kv_blocks = size_in_bytes / (m_num_decoder_layers * 2 * m_num_kv_heads * m_block_size * m_head_size * m_kv_cache_type.size());
for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) {
block_size += 2 * m_num_kv_heads[layer_id] * m_block_size * m_head_size * m_kv_cache_type.size();
}
m_num_kv_blocks = size_in_bytes / block_size;
}

m_key_cache_shape = m_value_cache_shape = ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(m_num_kv_heads),
ov::Dimension(m_block_size),
ov::Dimension(m_head_size)};

if (m_device.find("GPU") != std::string::npos) {
// Update key shape, as the key's shape is different from the value's shape
m_key_cache_shape = ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(m_num_kv_heads),
ov::Dimension(m_head_size),
ov::Dimension(m_block_size)};
for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) {
m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(m_num_kv_heads[layer_id]),
ov::Dimension(m_block_size),
ov::Dimension(m_head_size)});

m_value_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(m_num_kv_heads[layer_id]),
ov::Dimension(m_block_size),
ov::Dimension(m_head_size)});

if (m_device.find("GPU") != std::string::npos) {
// Update key shape, as the key's shape is different from the value's shape
m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(m_num_kv_heads[layer_id]),
ov::Dimension(m_head_size),
ov::Dimension(m_block_size)});
}
}
}

Expand All @@ -134,14 +149,14 @@ class DeviceConfig {
return m_num_decoder_layers;
}

ov::PartialShape get_key_cache_shape() const {
ov::PartialShape get_key_cache_shape(size_t id) const {
OPENVINO_ASSERT(m_key_cache_shape.size());
return m_key_cache_shape;
return m_key_cache_shape[id];
}

ov::PartialShape get_value_cache_shape() const {
ov::PartialShape get_value_cache_shape(size_t id) const {
OPENVINO_ASSERT(m_value_cache_shape.size());
return m_value_cache_shape;
return m_value_cache_shape[id];
}

size_t get_num_kv_blocks() const {
Expand All @@ -153,7 +168,11 @@ class DeviceConfig {
}

size_t get_block_size_in_bytes() const {
return m_num_decoder_layers * 2 * m_num_kv_heads * m_block_size * m_head_size * get_cache_precision().size();
size_t block_size = 0;
for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) {
block_size += 2 * m_num_kv_heads[layer_id] * m_block_size * m_head_size * get_cache_precision().size();
}
return block_size;
}
};
}
20 changes: 13 additions & 7 deletions src/cpp/src/utils/paged_attention_transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,21 @@ void set_kv_cache_type_and_shape(std::shared_ptr<ov::Model> model, DeviceConfig&
OPENVINO_ASSERT(key_cache_params.count(key_cache_param_name) != 0, "key_cache.0 tensor not found among model parameters");
ov::PartialShape k_shape = key_cache_params[key_cache_param_name]->get_partial_shape();
OPENVINO_ASSERT(k_shape.rank().get_length() == 3, "KV cache shape is expected to have rank 3, while shape is ", k_shape);
size_t num_kv_heads = k_shape[1].get_length(), head_size = k_shape[2].get_length();

size_t head_size = k_shape[2].get_length();
std::vector<size_t> num_kv_heads(num_layers);
for (size_t idx = 0; idx < num_layers; idx++) {
size_t num_heads = key_cache_params[std::string("key_cache.") + std::to_string(idx)]->get_partial_shape()[1].get_length();
num_kv_heads[idx] = num_heads;
}
device_config.set_model_params(num_kv_heads, head_size, num_layers);

for (auto it_k = key_cache_params.begin(), it_v = value_cache_params.begin(); it_k != key_cache_params.end();++it_k, ++it_v) {
it_k->second->set_element_type(device_config.get_cache_precision());
it_v->second->set_element_type(device_config.get_cache_precision());
it_k->second->set_partial_shape(device_config.get_key_cache_shape());
it_v->second->set_partial_shape(device_config.get_value_cache_shape());
for (size_t idx = 0; idx < num_layers; idx++) {
auto k = key_cache_params[std::string("key_cache.") + std::to_string(idx)];
auto v = value_cache_params[std::string("value_cache.") + std::to_string(idx)];
k->set_element_type(device_config.get_cache_precision());
v->set_element_type(device_config.get_cache_precision());
k->set_partial_shape(device_config.get_key_cache_shape(idx));
v->set_partial_shape(device_config.get_value_cache_shape(idx));
}

model->validate_nodes_and_infer_types();
Expand Down
13 changes: 9 additions & 4 deletions tests/cpp/cache_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ TEST(TestCacheManager, test_cache_size_param) {
const std::string device = "CPU";
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
size_t num_decoder_layers = 12;
device_config.set_model_params(12, 64, num_decoder_layers);
std::vector<size_t> num_kv_heads(12, 12);
device_config.set_model_params(num_kv_heads, 64, num_decoder_layers);

ov::InferRequest request = core.compile_model(get_dummy_model(num_decoder_layers)).create_infer_request();
auto cache_manager = std::make_shared<ov::genai::CacheManager>(device_config, request, core);
Expand All @@ -76,7 +77,8 @@ TEST(TestCacheManager, test_kv_blocks_param) {
const std::string device = "CPU";
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
size_t num_decoder_layers = 12;
device_config.set_model_params(12, 64, num_decoder_layers);
std::vector<size_t> num_kv_heads(12, 12);
device_config.set_model_params(num_kv_heads, 64, num_decoder_layers);

ov::InferRequest request = core.compile_model(get_dummy_model(num_decoder_layers)).create_infer_request();
auto cache_manager = std::make_shared<ov::genai::CacheManager>(device_config, request, core);
Expand All @@ -97,9 +99,12 @@ TEST(TestCacheManager, test_dynamic_cache_increase) {
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
size_t num_decoder_layers = 12;
size_t head_size = 64;
size_t num_kv_heads = 12;
std::vector<size_t> num_kv_heads(12, 12);
device_config.set_model_params(num_kv_heads, head_size, num_decoder_layers);
size_t block_size_in_bytes = num_decoder_layers * 2 * num_kv_heads * device_config.get_block_size() * head_size * device_config.get_cache_precision().size();
size_t block_size_in_bytes = 0;
for (size_t layer_id = 0; layer_id < num_decoder_layers; layer_id++) {
block_size_in_bytes += 2 * num_kv_heads[layer_id] * device_config.get_block_size() * head_size * device_config.get_cache_precision().size();
}


ov::InferRequest request = core.compile_model(get_dummy_model(num_decoder_layers)).create_infer_request();
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/device_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ TEST(TestDeviceConfig, kv_cache_precision_u8) {
const std::string device = "CPU";
size_t num_decoder_layers = 12;
size_t head_size = 64, head_size_u8 = head_size + 8;
size_t num_kv_heads = 12;
std::vector<size_t> num_kv_heads(12, 12);

ov::genai::DeviceConfig device_config_default(core, scheduler_config, "CPU");
device_config_default.set_model_params(num_kv_heads, head_size_u8, num_decoder_layers);
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ std::shared_ptr<CacheManager> init_cache_manager(SchedulerConfig scheduler_confi
size_t num_decoder_layers = 12;
ov::InferRequest request = core.compile_model(get_model(num_decoder_layers)).create_infer_request();
size_t head_size = 64, head_size_u8 = head_size + 8;
size_t num_kv_heads = 12;
std::vector<size_t> num_kv_heads(12, 12);
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
device_config.set_model_params(num_kv_heads, head_size_u8, num_decoder_layers);
return std::make_shared<CacheManager>(device_config, request, core);
Expand Down

0 comments on commit db866ee

Please sign in to comment.