Skip to content

Commit

Permalink
fix config saving when check on misplaced args broken
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 24, 2024
1 parent 86598a6 commit f10cf64
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 3 deletions.
7 changes: 6 additions & 1 deletion optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
_get_open_clip_submodels_fn_and_export_configs,
clear_class_registry,
remove_none_from_dummy_inputs,
save_config,
)


Expand Down Expand Up @@ -658,7 +659,11 @@ def export_from_model(
files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]
elif library_name != "diffusers":
if is_transformers_version(">=", "4.44.99"):
misplaced_generation_parameters = model.config._get_non_default_generation_parameters()
# some model configs may have issues with loading without parameters initialization
try:
misplaced_generation_parameters = model.config._get_non_default_generation_parameters()
except Exception:
misplaced_generation_parameters = {}
if isinstance(model, GenerationMixin) and len(misplaced_generation_parameters) > 0:
logger.warning(
"Moving the following attributes in the config to the generation config: "
Expand Down
11 changes: 11 additions & 0 deletions optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import inspect
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path

from transformers.utils import is_torch_available

Expand Down Expand Up @@ -209,3 +210,13 @@ def get_submodels(model):


MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "internvl-chat"]


def save_config(config, save_dir):
try:
config.save_pretrained(save_dir)
except Exception:
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True)
output_config_file = Path(save_dir / "config.json")
config.to_json_file(output_config_file, use_diff=True)
6 changes: 5 additions & 1 deletion optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ def __init__(
self.generation_config = generation_config or GenerationConfig.from_model_config(config)

if is_transformers_version(">=", "4.44.99"):
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()
# some model configs may have issues with loading without parameters initialization
try:
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()
except Exception:
misplaced_generation_parameters = {}
if len(misplaced_generation_parameters) > 0:
logger.warning(
"Moving the following attributes in the config to the generation config: "
Expand Down
6 changes: 5 additions & 1 deletion optimum/intel/openvino/modeling_base_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ def __init__(
self.generation_config = generation_config or GenerationConfig.from_model_config(config)

if is_transformers_version(">=", "4.44.99"):
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()
# some model configs may have issues with loading without parameters initialization
try:
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()
except Exception:
misplaced_generation_parameters = {}
if len(misplaced_generation_parameters) > 0:
logger.warning(
"Moving the following attributes in the config to the generation config: "
Expand Down
8 changes: 8 additions & 0 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from transformers.modeling_outputs import BaseModelOutputWithPooling

from ...exporters.openvino import main_export
from ...exporters.openvino.utils import save_config
from ...exporters.openvino.stateful import ensure_stateful_is_available
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel, OVModelPart
Expand Down Expand Up @@ -272,6 +273,13 @@ def compile(self):
if part_model is not None:
part_model._compile()

def _save_config(self, save_directory):
"""
Saves a model configuration into a directory, so that it can be re-loaded using the
[`from_pretrained`] class method.
"""
save_config(self.config, save_directory)

def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the model to the OpenVINO IR format so that it can be re-loaded using the
Expand Down

0 comments on commit f10cf64

Please sign in to comment.