From 682362d7ae937e9888cf5467fed00943fc13ec92 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 16 Oct 2024 14:08:31 +0400 Subject: [PATCH] nanollava support --- optimum/exporters/openvino/model_configs.py | 169 ++++++++++++++- optimum/exporters/openvino/model_patcher.py | 18 ++ optimum/exporters/openvino/utils.py | 2 +- optimum/intel/openvino/modeling_decoder.py | 2 +- .../openvino/modeling_visual_language.py | 196 +++++++++++++++++- 5 files changed, 373 insertions(+), 14 deletions(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index a2bd9d4342..c238627eca 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from packaging import version -from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel from transformers.utils import is_tf_available from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig @@ -69,6 +69,7 @@ JaisModelPatcher, LlamaModelPatcher, LlavaImageEmbeddingModelPatcher, + LlavaQwen2ImageEmbeddingsModelPatcher, MiniCPMVImageEmbeddingsModelPatcher, MiniCPMVResamplerModelPatcher, MistralModelPatcher, @@ -1218,8 +1219,8 @@ def patch_model_for_export( class LlavaConfigBehavior(str, enum.Enum): - LANGUAGE = "language" VISION_EMBEDDINGS = "vision_embeddings" + LANGUAGE = "language" TEXT_EMBEDDINGS = "text_embeddings" @@ -1380,6 +1381,166 @@ class LlavaNextOpenVINOConfig(LlavaOpenVINOConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.40.0") +@register_in_tasks_manager( + "llava-qwen2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers" +) +class LlavaQwen2OpenVINOConfig(OnnxConfig): + SUPPORTS_PAST = True + MIN_TRANSFORMERS_VERSION = version.parse("4.40.0") + SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaConfigBehavior] + NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,) + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + behavior: LlavaConfigBehavior = LlavaConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + use_past: bool = False, + ): + self._behavior = behavior + self._orig_config = config + if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS: + config = AutoConfig.from_pretrained(config.mm_vision_tower, trust_remote_code=True) + if hasattr(config, "vision_config"): + config = config.vision_config + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + ) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS: + return {} + return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}} + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS: + return {} + return {"last_hidden_state": {0: "batch_size"}} + + def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]): + if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior): + behavior = LlavaConfigBehavior(behavior) + + if behavior == 
LlavaConfigBehavior.LANGUAGE: + model.forward = super(type(model), model).forward + return model + + if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS: + return model + + if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS: + text_embedding = model.model.embed_tokens + text_embedding.config = model.model.config + return text_embedding + + def with_behavior( + self, + behavior: Union[str, LlavaConfigBehavior], + ): + """ + Creates a config for different behaviour. + + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + """ + if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior): + behavior = LlavaConfigBehavior(behavior) + + if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS: + model_type = self._orig_config.model_type.replace("llava-", "") + model_type = model_type.replace("_", "-") + if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: + raise ValueError( + f"Unsupported language model type provided `{model_type}`. Please define custom export config" + ) + + if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: + raise ValueError( + f"Export config for text generation for `{model_type}` is not available. Please define custom export config" + ) + internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ + "text-generation-with-past" + ] + internal_export_config = internal_export_config_class( + self._orig_config, + use_past=True, + use_past_in_inputs=True, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + ) + InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS + export_config = InputEmbedOpenvVINOConfig( + self._orig_config, + task="feature-extraction", + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + ) + return export_config + + if behavior == LlavaConfigBehavior.LANGUAGE: + model_type = self._orig_config.model_type.replace("llava-", "") + model_type = model_type.replace("_", "-") + + if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: + raise ValueError( + f"Unsupported language model type provided `{model_type}`. Please define custom export config" + ) + + if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: + raise ValueError( + f"Export config for text generation for `{model_type}` is not available. 
Please define custom export config" + ) + internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ + "text-generation-with-past" + ] + internal_export_config = internal_export_config_class( + self._orig_config, + use_past=True, + use_past_in_inputs=True, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + ) + export_config = LMInputEmbedsConfigHelper(internal_export_config) + export_config._normalized_config = internal_export_config._normalized_config + return export_config + + if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ): + model_kwargs = model_kwargs or {} + if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS: + return super().patch_model_for_export(model, model_kwargs) + return LlavaQwen2ImageEmbeddingsModelPatcher(self, model, model_kwargs) + + def rename_ambiguous_inputs(self, inputs): + if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS: + model_inputs = {} + model_inputs["images"] = inputs["pixel_values"] + return model_inputs + return super().rename_ambiguous_inputs(inputs) + + class InternVLChatConfigBehavior(str, enum.Enum): LANGUAGE = "language" VISION_EMBEDDINGS = "vision_embeddings" @@ -1508,8 +1669,8 @@ def with_behavior( preprocessors=self._preprocessors, ) - def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]): - if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior): + def get_model_for_behavior(self, model, behavior: Union[str, InternVLChatConfigBehavior]): + if isinstance(behavior, str) and not isinstance(behavior, InternVLChatConfigBehavior): behavior = InternVLChatConfigBehavior(behavior) if behavior == InternVLChatConfigBehavior.LANGUAGE: diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 2a3b98bd54..ff2a2cec62 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -2936,3 +2936,21 @@ def forward(self, input): def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward + + +class LlavaQwen2ImageEmbeddingsModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + model.__orig_forward = model.forward + model.forward = model.encode_images + super().__init__(config, model, model_kwargs) + if not self._model.get_vision_tower().is_loaded: + self._model.get_vision_tower().load_model() + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index 2045d4302f..663aaae5c2 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -208,4 +208,4 @@ def get_submodels(model): return custom_export, fn_get_submodels -MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "internvl-chat", "minicpmv"] +MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv"] 
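With the export config, the LlavaQwen2ImageEmbeddingsModelPatcher, and the "llava-qwen2" entry in MULTI_MODAL_TEXT_GENERATION_MODELS above, a nanoLLaVA-style checkpoint can be converted through the standard optimum-intel entry point. A minimal sketch, assuming a checkpoint whose config.model_type is "llava-qwen2" and which ships its own remote modeling code; the model id below is a placeholder and is not part of this patch:

    # Sketch: export a llava-qwen2 / nanoLLaVA-style checkpoint to OpenVINO IR.
    from optimum.intel import OVModelForVisualCausalLM

    model = OVModelForVisualCausalLM.from_pretrained(
        "<llava-qwen2-checkpoint>",  # placeholder model id
        export=True,                 # convert vision, text-embedding and language submodels to IR
        trust_remote_code=True,      # the checkpoint provides custom modeling/config code
    )
    model.save_pretrained("nanollava_ov")
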
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 733f5a4119..3ee49c9965 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -504,7 +504,7 @@ def prepare_inputs( else: position_ids = np.cumsum(attention_mask, axis=1) - 1 position_ids[attention_mask == 0] = 1 - if past_key_values: + if past_key_values is not None: position_ids = position_ids[:, -input_ids.shape[1] :] inputs["position_ids"] = position_ids diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index 49d5026bf3..588235c1ec 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -15,7 +15,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling from ...exporters.openvino import main_export -from ...exporters.openvino.stateful import ensure_stateful_is_available +from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name from .configuration import OVConfig, OVWeightQuantizationConfig from .modeling_base import OVBaseModel, OVModelPart from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM @@ -140,8 +140,8 @@ def prepare_inputs( else: position_ids = np.cumsum(attention_mask, axis=1) - 1 position_ids[attention_mask == 0] = 1 - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if past_len: + position_ids = position_ids[:, -inputs_embeds.shape[1] :] inputs["position_ids"] = position_ids @@ -191,13 +191,14 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None: self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)} self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)} self.hidden_states_output_names = [] + self._main_input = "images" if model_has_input_output_name(self.model, "images") else "pixel_values" if len(self.model.outputs) > 2: self.hidden_states_output_names = [ key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name() ] def forward(self, pixel_values, **kwargs): - inputs = {"pixel_values": pixel_values} + inputs = {self._main_input: pixel_values} if len(self.input_names) > 1: for name in self.input_names: if name in kwargs: @@ -586,7 +587,7 @@ def half(self): def forward( self, input_ids, - pixel_values, + pixel_values=None, past_key_values=None, inputs_embeds=None, image_sizes=None, @@ -594,8 +595,11 @@ def forward( position_ids=None, image_bound=None, tgt_sizes=None, + images=None, **kwargs, ): + if pixel_values is None and images is not None: + pixel_values = images inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings( input_ids, pixel_values, @@ -665,8 +669,9 @@ def prepare_inputs_for_generation( # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + if attention_mask is not None and past_length + 1 > input_ids.shape[1]: + input_discount = max(attention_mask.shape[1] - past_length, 1) + input_ids = input_ids[:, -input_discount:] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard # input_ids based on the past_length.llava elif past_length < input_ids.shape[1]: @@ -679,7 +684,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: + if past_key_values is not None: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step @@ -698,6 +703,7 @@ def prepare_inputs_for_generation( "image_sizes": image_sizes, "image_bound": kwargs.get("image_bound"), "tgt_sizes": kwargs.get("tgt_sizes"), + "images": kwargs.get("images"), } ) return model_inputs @@ -1388,9 +1394,183 @@ def merge_vision_text_embeddings( return vllm_embedding, attention_mask, position_ids +class _OVNanoLlavaForCausalLM(OVModelForVisualCausalLM): + def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs): + if input_ids is not None and input_ids.shape[1] == 1: + return None + if isinstance(pixel_values, list) or pixel_values.ndim == 5: + concat_images = torch.cat([image for image in pixel_values], dim=0) + image_features = torch.from_numpy(self.vision_embeddings(concat_images).last_hidden_state) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + image_features = [x.flatten(0, 1).to(self.device) for x in image_features] + else: + image_features = self.vision_embeddings(pixel_values).last_hidden_state + + return image_features + + def get_multimodal_embeddings( + self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, **kwargs + ): + vision_embeds = None + IGNORE_INDEX = -100 + IMAGE_TOKEN_INDEX = -200 + if pixel_values is not None: + vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs) + if vision_embeds is None: + inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids)) + if kwargs.get("past_key_values") is not None: + past_len = self.language_model._get_past_length(kwargs.get("past_key_values")) + if attention_mask is not None and attention_mask.shape[1] < past_len + input_ids.shape[1]: + attention_mask = torch.cat( + [ + attention_mask, + torch.ones(attention_mask.shape[0], past_len + input_ids.shape[1] - attention_mask.shape[1]), + ], + dim=1, + ) + position_ids = None + return inputs_embeds, attention_mask, position_ids + + vision_embeds = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- TODO: double check + input_ids = [ + cur_input_ids[cur_attention_mask] + for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask.bool()) + ] + labels = [ + cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask.bool()) + ] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_images == 0: + cur_image_features = vision_embeds[cur_image_idx] + cur_input_embeds_1 = torch.from_numpy(self.get_text_embeddings(cur_input_ids.unsqueeze(0))[0]) + cur_input_embeds = 
torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = ( + [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + ) + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = torch.from_numpy( + self.get_text_embeddings(torch.cat(cur_input_ids_noim).unsqueeze(0))[0] + ) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + cur_image_features = vision_embeds[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append( + torch.full( + (cur_image_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype, + ) + ) + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full( + (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device + ) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + cur_new_embed, + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds = 
torch.stack(new_input_embeds_padded, dim=0) + + return new_input_embeds, attention_mask, position_ids + + MODEL_TYPE_TO_CLS_MAPPING = { "llava": _OVLlavaForCausalLM, "llava_next": _OVLlavaNextForCausalLM, "internvl_chat": _OvInternVLForCausalLM, "minicpmv": _OVMiniCPMVForCausalLM, + "llava-qwen2": _OVNanoLlavaForCausalLM, }
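
On the modeling side, the new "llava-qwen2" entry in MODEL_TYPE_TO_CLS_MAPPING routes these checkpoints to _OVNanoLlavaForCausalLM, whose get_multimodal_embeddings splices the vision features into the text embeddings at the IMAGE_TOKEN_INDEX (-200) positions. A minimal generation sketch against an exported IR, assuming the directory produced by the export sketch earlier; the token ids, image-placeholder position, and pixel_values shape are placeholders, since real inputs come from the checkpoint's own tokenizer and image processor rather than from this patch:

    # Sketch: run generation with an exported llava-qwen2 model.
    import torch
    from optimum.intel import OVModelForVisualCausalLM

    model = OVModelForVisualCausalLM.from_pretrained("nanollava_ov")

    # Placeholder inputs: real token ids come from the checkpoint's tokenizer, with -200
    # marking where the image features are inserted; real pixel_values come from its
    # image processor (the spatial size depends on the vision tower).
    input_ids = torch.tensor([[1, 2, 3, -200, 4, 5]])
    attention_mask = torch.ones_like(input_ids)
    pixel_values = torch.zeros((1, 3, 384, 384))

    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        max_new_tokens=32,
    )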