Skip to content

Commit

Permalink
Disable graph check while tracing (#1103)
Browse files Browse the repository at this point in the history
* Disable graph check while tracing

* Fix style

* Apply suggestions from code review

* Move decoder input closer to the usage. For cases when torch is not available in openvino

---------

Co-authored-by: Ilyas Moutawwakil <[email protected]>
  • Loading branch information
mvafin and IlyasMoutawwakil authored Jan 14, 2025
1 parent f28aabc commit fe55db5
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
compare_versions,
is_diffusers_version,
is_openvino_tokenizers_version,
is_openvino_version,
is_tokenizers_version,
is_transformers_version,
)
Expand Down Expand Up @@ -366,6 +367,7 @@ def export_pytorch(
import torch
from torch.utils._pytree import tree_map

from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from optimum.exporters.utils import check_dummy_inputs_are_allowed

logger.info(f"Using framework PyTorch: {torch.__version__}")
Expand Down Expand Up @@ -428,15 +430,20 @@ def ts_patched_forward(*args, **kwargs):

patcher.patched_forward = ts_patched_forward

ts_decoder_kwargs = {}
if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}

with patcher:
if patch_16bit_model:
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

__make_16bit_traceable(model)
check_dummy_inputs_are_allowed(model, dummy_inputs)
input_info = _get_input_info(model, config, dummy_inputs)
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
ov_model = convert_model(
model,
ts_decoder,
example_input=dummy_inputs,
input=[(item.shape, item.type) for item in input_info],
)
Expand Down

0 comments on commit fe55db5

Please sign in to comment.