Support multiple images for LLaVA & LLaVA Next models (#1080)
Ticket: CVS-155384
yatarkan authored Oct 29, 2024
1 parent fa324cf commit 418aece
Showing 2 changed files with 163 additions and 99 deletions.
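
For context, pipelines can now accept several images per generate() call. A minimal usage sketch (the model directory, image files, and load_image helper are placeholders; load_image mirrors the utility shipped with the repo's visual-language samples):

#include <iostream>
#include <vector>

#include "load_image.hpp"  // sample helper, assumed available
#include "openvino/genai/visual_language/pipeline.hpp"

int main() {
    ov::genai::VLMPipeline pipe("./MiniCPM-V-2_6/", "CPU");  // hypothetical model dir
    std::vector<ov::Tensor> rgbs{
        utils::load_image("cat.png"),  // hypothetical inputs
        utils::load_image("dog.png")
    };
    std::cout << pipe.generate("Compare these two images.",
                               ov::genai::images(rgbs),
                               ov::genai::max_new_tokens(100));
}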
src/cpp/src/visual_language/input_embedder.cpp: 148 additions & 93 deletions
@@ -130,6 +130,46 @@ class InputsEmbedder::IInputsEmbedder {
         }
         return encoded_input_ids;
     }
+
+    /**
+     * @brief Unpads an image tensor of a padded and resized image.
+     * Used for packing image features of llava_next models.
+     *
+     * @param tensor An image tensor with a shape (embed_dim, height, width)
+     * @param original_size A size of the original image
+     * @return An unpadded image tensor with a shape (embed_dim, new_height, new_width)
+     */
+
+    /**
+     * @brief Converts a vector of batched images ([NHWC]) into a vector of individual image tensors ([1HWC]).
+     *
+     * @param images A vector of tensors representing the images. Each tensor can have a shape of either [NHWC] or [HWC].
+     * @return A vector of tensors where each tensor represents a single image with a shape of [1, H, W, C].
+     */
+    std::vector<ov::Tensor> to_single_image_tensors(const std::vector<ov::Tensor>& images) {
+        std::vector<ov::Tensor> single_image_tensors;
+        for (const auto& image : images) {
+            ov::Tensor reshaped_image = image;
+            ov::Shape image_shape = image.get_shape();
+            switch (image_shape.size()) {
+                case 3:
+                    reshaped_image.set_shape({1, image_shape.at(0), image_shape.at(1), image_shape.at(2)});
+                    break;
+                case 4: break;
+                default: OPENVINO_THROW("Input image must have [NHWC] or [HWC] layout");
+            }
+            ov::Shape reshaped_image_shape = reshaped_image.get_shape();
+            for (size_t batch_idx = 0; batch_idx < reshaped_image_shape.at(0); ++batch_idx) {
+                ov::Tensor single_image{
+                    ov::element::u8,
+                    {1, reshaped_image_shape.at(1), reshaped_image_shape.at(2), reshaped_image_shape.at(3)},
+                    reshaped_image.data<uint8_t>() + batch_idx * reshaped_image_shape.at(1) * reshaped_image_shape.at(2) * reshaped_image_shape.at(3)
+                };
+                single_image_tensors.push_back(std::move(single_image));
+            }
+        }
+        return single_image_tensors;
+    }
 };
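
A note on the new to_single_image_tensors helper: the per-image tensors it returns are views over the caller's buffer (the ov::Tensor constructor that takes a host pointer does not copy), so they remain valid only while the source tensor is alive. A small illustration with hypothetical shapes:

// A batch of two 448x448 RGB images in one [2, 448, 448, 3] u8 tensor
// splits into two [1, 448, 448, 3] views sharing the original memory.
ov::Tensor batch(ov::element::u8, {2, 448, 448, 3});
std::vector<ov::Tensor> singles = to_single_image_tensors({batch});
// singles.size() == 2
// singles[0].data<uint8_t>() == batch.data<uint8_t>()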

@@ -161,49 +201,35 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {
     virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images) override {
         std::string images_prompt;
         std::vector<EncodedImage> embeds;
-        for (const ov::Tensor& rgb : images) {
-            ov::Tensor reshaped = rgb;
-            ov::Shape rgb_shape = rgb.get_shape();
-            switch (rgb_shape.size()) {
-                case 3:
-                    reshaped.set_shape({1, rgb_shape.at(0), rgb_shape.at(1), rgb_shape.at(2)});
-                    break;
-                case 4: break;
-                default: OPENVINO_THROW("Input image must have [NHWC] or [HWC] layout");
-            }
-            ov::Shape reshaped_shape = reshaped.get_shape();
-            for (size_t batch_idx = 0; batch_idx < reshaped_shape.at(0); ++batch_idx) {
-                ov::Tensor single_image{
-                    ov::element::u8,
-                    {1, reshaped_shape.at(1), reshaped_shape.at(2), reshaped_shape.at(3)},
-                    reshaped.data<uint8_t>() + batch_idx * reshaped_shape.at(1) * reshaped_shape.at(1) * reshaped_shape.at(1)
-                };
-                EncodedImage encoded_image = m_vision_encoder.encode(single_image);
-                if (m_vlm_config.use_image_id) {
-                    images_prompt += m_vlm_config.im_id_start + std::to_string(m_image_id) + m_vlm_config.im_id_end;
-                    ++m_image_id;
-                }
-                std::string unk64;
-                for (size_t idx = 0; idx < m_vlm_config.query_num; ++idx) {
-                    unk64 += m_vlm_config.unk;
-                }
-                images_prompt += m_vlm_config.im_start + unk64 + m_vlm_config.im_end;
-                if (encoded_image.slices) {
-                    ov::Shape slices_shape = encoded_image.slices.get_shape();
-                    for (size_t row_idx = 0; row_idx < slices_shape.at(0); ++row_idx) {
-                        for (size_t col_idx = 0; col_idx < slices_shape.at(1); ++col_idx) {
-                            images_prompt += m_vlm_config.slice_start + unk64 + m_vlm_config.slice_end;
-                        }
-                        images_prompt += '\n';
-                    }
-                }
-                if ('\n' != *(images_prompt.end() - 1)) {
-                    // Image wasn't sliced, add \n to the end of image anyway.
-                    // Strangely, \n isn't placed between </image><slice>.
-                    images_prompt += '\n';
-                }
-                embeds.push_back(std::move(encoded_image));
-            }
-        }
+        std::vector<ov::Tensor> single_images = to_single_image_tensors(images);
+
+        for (const ov::Tensor& image : single_images) {
+            EncodedImage encoded_image = m_vision_encoder.encode(image);
+            if (m_vlm_config.use_image_id) {
+                images_prompt += m_vlm_config.im_id_start + std::to_string(m_image_id) + m_vlm_config.im_id_end;
+                ++m_image_id;
+            }
+            std::string unk64;
+            for (size_t idx = 0; idx < m_vlm_config.query_num; ++idx) {
+                unk64 += m_vlm_config.unk;
+            }
+            images_prompt += m_vlm_config.im_start + unk64 + m_vlm_config.im_end;
+            if (encoded_image.slices) {
+                ov::Shape slices_shape = encoded_image.slices.get_shape();
+                for (size_t row_idx = 0; row_idx < slices_shape.at(0); ++row_idx) {
+                    for (size_t col_idx = 0; col_idx < slices_shape.at(1); ++col_idx) {
+                        images_prompt += m_vlm_config.slice_start + unk64 + m_vlm_config.slice_end;
+                    }
+                    images_prompt += '\n';
+                }
+            }
+            if ('\n' != *(images_prompt.end() - 1)) {
+                // Image wasn't sliced, add \n to the end of image anyway.
+                // Strangely, \n isn't placed between </image><slice>.
+                images_prompt += '\n';
+            }
+            embeds.push_back(std::move(encoded_image));
+        }
         images_prompt += prompt;
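
For illustration, with use_image_id enabled and two unsliced images, the loop above assembles a prefix of roughly this shape before the user prompt is appended (tag spellings such as <image_id> and <image> come from VLMConfig and are model-specific; unk is repeated query_num times):

<image_id>0</image_id><image>unk...unk</image>
<image_id>1</image_id><image>unk...unk</image>

A sliced image additionally contributes one slice_start...slice_end group per grid cell, with a newline after each row.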
@@ -461,69 +487,86 @@ class InputsEmbedderLLaVA : public InputsEmbedder::IInputsEmbedder {
 
     virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images) override {
         std::string image_token = m_vlm_config.im_start;
-        std::string formatted_prompt = images.empty() ? prompt : image_token + "\n" + prompt;
-
-        // std::string chat_template_fallback = m_templated_chat_history + " USER: " + formatted_prompt + " ASSISTANT: ";
-        // chat_template_fallback = chat_template_fallback.erase(0, chat_template_fallback.find_first_not_of(' '));
 
         // Adapted from llava-1.5-7b-hf chat_template.json
         std::string chat_template_fallback = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'USER: ' + message['content'] + ' ' }}{% else %}{{ 'ASSISTANT: ' + message['content'] + ' ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}";
-        ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt, chat_template_fallback);
 
-        if (images.empty()) {
-            return m_embedding.infer(input_ids);
-        } else {
-            OPENVINO_ASSERT(1 == images.size(), "Only a single image allowed");
-            EncodedImage encoded_image = m_vision_encoder.encode(images.at(0));
-            ov::Tensor image_embeds = encoded_image.resized_source;
-
-            ov::Tensor text_embeds = m_embedding.infer(input_ids);
-
-            ov::Tensor encoded_image_token = m_tokenizer.encode(image_token, ov::genai::add_special_tokens(false)).input_ids;
-            int64_t image_token_id = encoded_image_token.data<int64_t>()[encoded_image_token.get_size() - 1];
-
-            return merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_embeds, image_token_id);
-        }
+        std::vector<ov::Tensor> single_images = to_single_image_tensors(images);
+
+        std::string formatted_prompt;
+        std::vector<ov::Tensor> image_embeds;
+        image_embeds.reserve(single_images.size());
+
+        for (const auto& image : single_images) {
+            EncodedImage encoded_image = m_vision_encoder.encode(image);
+            image_embeds.push_back(std::move(encoded_image.resized_source));
+            formatted_prompt += image_token + "\n";
+        }
+        formatted_prompt += prompt;
+
+        ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt, chat_template_fallback);
+        ov::Tensor text_embeds = m_embedding.infer(input_ids);
+
+        if (images.empty()) {
+            return text_embeds;
+        }
+
+        ov::Tensor encoded_image_token = m_tokenizer.encode(m_vlm_config.im_start, ov::genai::add_special_tokens(false)).input_ids;
+        int64_t image_token_id = encoded_image_token.data<int64_t>()[encoded_image_token.get_size() - 1];
+
+        return merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_embeds, image_token_id);
     }
 
 protected:
     ov::Tensor merge_text_and_image_embeddings_llava(
         const ov::Tensor& input_ids,
         const ov::Tensor& text_embeds,
-        const ov::Tensor& image_embeds,
+        const std::vector<ov::Tensor>& image_embeds,
         int64_t image_token_id
     ) {
         auto text_embeds_shape = text_embeds.get_shape();
-        auto image_embeds_shape = image_embeds.get_shape();
+        size_t text_embeds_seq_length = text_embeds_shape[1];
+        size_t hidden_size = text_embeds_shape[2];
+
+        const int64_t* input_ids_data = input_ids.data<const int64_t>();
+        const float* text_embeds_data = text_embeds.data<const float>();
+
+        size_t num_image_tokens = 0;
+        for (size_t s = 0; s < text_embeds_seq_length; ++s) {
+            if (input_ids_data[s] == image_token_id) {
+                num_image_tokens++;
+            }
+        }
+        auto num_images = image_embeds.size();
         OPENVINO_ASSERT(
-            text_embeds_shape[2] == image_embeds_shape[2],
-            "Incompatible shapes between text_embeds and image_embeds"
+            num_image_tokens == num_images,
+            "Number of image tokens in input_ids different from num_images."
         );
 
-        size_t text_embeds_seq_length = text_embeds_shape[1];
-        size_t hidden_size = text_embeds_shape[2];
-        size_t image_embeds_seq_length = image_embeds_shape[1];
-
-        size_t merged_seq_length = text_embeds_seq_length + (image_embeds_seq_length - 1);
+        size_t total_image_seq_length = 0;
+        for (const auto& single_image_embeds : image_embeds) {
+            OPENVINO_ASSERT(
+                text_embeds_shape[2] == single_image_embeds.get_shape().at(2),
+                "Incompatible shapes between text_embeds and image_embeds"
+            );
+            total_image_seq_length += single_image_embeds.get_shape().at(1);
+        }
+        size_t merged_seq_length = text_embeds_seq_length + total_image_seq_length - num_image_tokens;
 
         ov::Tensor merged_embeds(text_embeds.get_element_type(), {BATCH_SIZE, merged_seq_length, hidden_size});
 
-        const int64_t* input_ids_data = input_ids.data<const int64_t>();
-        const float* text_embeds_data = text_embeds.data<const float>();
-        const float* image_embeds_data = image_embeds.data<const float>();
         float* merged_data = merged_embeds.data<float>();
 
         size_t merged_idx = 0;
+        size_t image_idx = 0;
         for (size_t s = 0; s < text_embeds_seq_length; ++s) {
             if (input_ids_data[s] == image_token_id) {
-                for (size_t i = 0; i < image_embeds_seq_length; ++i) {
-                    std::copy_n(image_embeds_data + i * hidden_size,
-                                hidden_size,
-                                merged_data + merged_idx * hidden_size);
-                    merged_idx++;
-                }
+                const float* image_embeds_data = image_embeds[image_idx].data<const float>();
+                size_t image_seq_length = image_embeds[image_idx].get_shape()[1];
+
+                std::copy_n(image_embeds_data,
+                            image_seq_length * hidden_size,
+                            merged_data + merged_idx * hidden_size);
+                merged_idx += image_seq_length;
+                image_idx++;
             } else {
                 std::copy_n(text_embeds_data + s * hidden_size,
                             hidden_size,
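
To make the bookkeeping concrete: if input_ids holds 25 tokens of which 2 are image tokens, and the two images encode to 576 and 588 patch embeddings respectively, then merged_seq_length = 25 + (576 + 588) - 2 = 1187. Each image-token position is spliced out and replaced, in order, by that image's full block of embeddings.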
@@ -547,35 +590,47 @@ class InputsEmbedderLLaVANext : public InputsEmbedderLLaVA {
 
     virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images) override {
        std::string image_token = m_vlm_config.im_start;
-        std::string formatted_prompt = images.empty() ? prompt : image_token + "\n" + prompt;
 
         // Adapted from llava-1.5-7b-hf chat_template.json
         std::string chat_template_fallback = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'USER: ' + message['content'] + ' ' }}{% else %}{{ 'ASSISTANT: ' + message['content'] + ' ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}";
-        ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt, chat_template_fallback);
 
-        if (images.empty()) {
-            return m_embedding.infer(input_ids);
-        } else {
-            OPENVINO_ASSERT(1 == images.size(), "Only a single image allowed");
-            EncodedImage encoded_image = m_vision_encoder.encode(images.at(0));
-
-            // Create image_newline tensor with data from config
-            size_t embed_dim = encoded_image.resized_source.get_shape().at(2);
-            ov::Tensor image_newline(encoded_image.resized_source.get_element_type(), {embed_dim});
-            float* image_newline_data = image_newline.data<float>();
-            std::copy(m_vlm_config.image_newline.begin(), m_vlm_config.image_newline.end(), image_newline_data);
-
-            ImageSize original_image_size{images.at(0).get_shape().at(1), images.at(0).get_shape().at(2)}; // [height, width]
-
-            ov::Tensor image_features = pack_image_features_llava_next(encoded_image, original_image_size, image_newline);
-
-            ov::Tensor text_embeds = m_embedding.infer(input_ids);
-
-            ov::Tensor encoded_image_token = m_tokenizer.encode(image_token, ov::genai::add_special_tokens(false)).input_ids;
-            int64_t image_token_id = encoded_image_token.data<int64_t>()[encoded_image_token.get_size() - 1];
-
-            return merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_features, image_token_id);
-        }
+        std::vector<ov::Tensor> single_images = to_single_image_tensors(images);
+
+        std::string formatted_prompt;
+        std::vector<ov::Tensor> image_embeds;
+        image_embeds.reserve(single_images.size());
+
+        ov::Tensor image_newline;
+
+        for (const auto& image : single_images) {
+            EncodedImage encoded_image = m_vision_encoder.encode(image);
+
+            if (!image_newline) {
+                // Create image_newline tensor with data from config
+                size_t embed_dim = encoded_image.resized_source.get_shape().at(2);
+                image_newline = ov::Tensor(encoded_image.resized_source.get_element_type(), {embed_dim});
+                float* image_newline_data = image_newline.data<float>();
+                std::copy(m_vlm_config.image_newline.begin(), m_vlm_config.image_newline.end(), image_newline_data);
+            }
+
+            ImageSize original_image_size{image.get_shape().at(1), image.get_shape().at(2)}; // [height, width]
+
+            ov::Tensor packed_features = pack_image_features_llava_next(encoded_image, original_image_size, image_newline);
+
+            image_embeds.push_back(std::move(packed_features));
+            formatted_prompt += image_token + "\n";
+        }
+        formatted_prompt += prompt;
+
+        ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt, chat_template_fallback);
+        ov::Tensor text_embeds = m_embedding.infer(input_ids);
+
+        if (images.empty()) {
+            return text_embeds;
+        }
+
+        ov::Tensor encoded_image_token = m_tokenizer.encode(m_vlm_config.im_start, ov::genai::add_special_tokens(false)).input_ids;
+        int64_t image_token_id = encoded_image_token.data<int64_t>()[encoded_image_token.get_size() - 1];
+
+        return merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_embeds, image_token_id);
     }
 
 private:
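
Two details in the LLaVA-Next path are easy to miss: image_newline is created lazily inside the loop because its length (embed_dim) is only known once the first image has been encoded, and pack_image_features_llava_next yields a different sequence length for each image, since the patch grid depends on that image's resolution. This is why merge_text_and_image_embeddings_llava reads the length from each embeddings tensor instead of assuming a shared constant.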
src/cpp/src/visual_language/vision_encoder.cpp: 15 additions & 6 deletions
@@ -657,10 +657,13 @@ EncodedImage VisionEncoder::encode_llava(const ov::Tensor& image, const ProcessorConfig& config) {
     m_vision_encoder.set_tensor("pixel_values", pixel_values);
     m_vision_encoder.infer();
 
-    ov::Tensor image_features = m_vision_encoder.get_output_tensor();
+    const ov::Tensor& infer_output = m_vision_encoder.get_output_tensor();
+    ov::Tensor image_features(infer_output.get_element_type(), infer_output.get_shape());
+    std::memcpy(image_features.data(), infer_output.data(), infer_output.get_byte_size());
+
     ImageSize resized_source_size{config.crop_size_height / config.patch_size, config.crop_size_width / config.patch_size};
 
-    return {image_features, resized_source_size};
+    return {std::move(image_features), resized_source_size};
 }
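
Why the added copy matters: get_output_tensor() returns a handle to the request's internal output buffer, which the next infer() call overwrites, so with several images encoded in a loop every EncodedImage would otherwise alias the data of the last inference. The same pattern as a standalone sketch (the helper name is ours, not part of this codebase):

// Deep-copy an infer request's output so it survives subsequent infer() calls.
ov::Tensor copy_output(ov::InferRequest& request) {
    const ov::Tensor& out = request.get_output_tensor();        // non-owning handle
    ov::Tensor owned(out.get_element_type(), out.get_shape());  // allocates its own buffer
    std::memcpy(owned.data(), out.data(), out.get_byte_size());
    return owned;
}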
 
@@ -669,7 +672,10 @@ EncodedImage VisionEncoder::encode_llava_next(const ov::Tensor& image, const ProcessorConfig& config) {
     m_vision_encoder.set_tensor("pixel_values", pixel_values);
     m_vision_encoder.infer();
 
-    ov::Tensor image_features = m_vision_encoder.get_output_tensor();
+    const ov::Tensor& infer_output = m_vision_encoder.get_output_tensor();
+    ov::Tensor image_features(infer_output.get_element_type(), infer_output.get_shape());
+    std::memcpy(image_features.data(), infer_output.data(), infer_output.get_byte_size());
+
     ImageSize resized_source_size{config.crop_size_height / config.patch_size, config.crop_size_width / config.patch_size};
 
     // Get number of patches
@@ -679,7 +685,7 @@ EncodedImage VisionEncoder::encode_llava_next(const ov::Tensor& image, const ProcessorConfig& config) {
     int num_patches_h = best_resolution.second / config.size_shortest_edge;
 
     EncodedImage encoded_image;
-    encoded_image.resized_source = image_features;
+    encoded_image.resized_source = std::move(image_features);
     encoded_image.resized_source_size = resized_source_size;
     encoded_image.patches_grid = {num_patches_h, num_patches_w};
     return encoded_image;
@@ -691,8 +697,11 @@ EncodedImage VisionEncoder::encode_internvl(const ov::Tensor& image, const ProcessorConfig& config) {
     m_vision_encoder.set_tensor("pixel_values", pixel_values);
     m_vision_encoder.infer();
 
-    ov::Tensor image_features = m_vision_encoder.get_output_tensor();
+    const ov::Tensor& infer_output = m_vision_encoder.get_output_tensor();
+    ov::Tensor image_features(infer_output.get_element_type(), infer_output.get_shape());
+    std::memcpy(image_features.data(), infer_output.data(), infer_output.get_byte_size());
+
     ImageSize resized_source_size{config.crop_size_height / config.patch_size, config.crop_size_width / config.patch_size};
 
-    return {image_features, resized_source_size};
+    return {std::move(image_features), resized_source_size};
 }
