Skip to content

Commit

Permalink
move patch model to init
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Yi A <[email protected]>
  • Loading branch information
sywangyi committed Oct 10, 2024
1 parent 8b574d0 commit 541a236
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/ipex/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.batch_size = max_batch_size
self.kv_cache = []

self._seen_tokens = max_batch_size * [
Expand Down
7 changes: 2 additions & 5 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def __init__(

self.input_names = set(inspect.signature(model.forward).parameters)

if self._is_ipex_exported:
model = _patch_model(model)
# Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
# a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
AutoConfig.register(self.base_model_prefix, AutoConfig)
Expand Down Expand Up @@ -238,11 +240,6 @@ def _from_pretrained(
_commit_hash=commit_hash,
**model_kwargs,
)
if is_torch_xpu_available(check_device=True):
model.to("xpu:0")

if _is_patched_with_ipex(model, task):
model = _patch_model(model)
return cls(model, config=config, export=True, **kwargs)

def _save_pretrained(self, save_directory: Union[str, Path]):
Expand Down

0 comments on commit 541a236

Please sign in to comment.