Revert "Revert "Use opset15 version of Str Pack/Unpack (#351)" (#374)" (
Browse files Browse the repository at this point in the history
#381)

This reverts commit e8da805.
pavel-esir authored Jan 17, 2025
1 parent 2a59579 commit 2e59c96
Showing 13 changed files with 8,421 additions and 8,437 deletions.
19 changes: 17 additions & 2 deletions python/openvino_tokenizers/__init__.py
@@ -66,7 +66,6 @@ def new_fe_init(self, *args, **kwargs):


openvino.runtime.Core.__init__ = new_core_init
openvino.runtime.utils.node_factory.NodeFactory.__init__ = new_factory_init
openvino.frontend.frontend.FrontEnd.__init__ = new_fe_init


@@ -76,6 +75,22 @@ def _get_factory_callable() -> Callable[[], NodeFactory]:
def inner(opset_version: Optional[str] = None) -> NodeFactory:
nonlocal factory
if opset_version not in factory:
openvino.runtime.utils.node_factory.NodeFactory.__init__ = new_factory_init
factory[opset_version] = NodeFactory() if opset_version is None else NodeFactory(opset_version)

return factory[opset_version]

return inner


def _get_opset_factory_callable() -> Callable[[], NodeFactory]:
# factory without extensions
factory = {}

def inner(opset_version: Optional[str] = None) -> NodeFactory:
nonlocal factory
if opset_version not in factory:
openvino.runtime.utils.node_factory.NodeFactory.__init__ = old_factory_init
factory[opset_version] = NodeFactory() if opset_version is None else NodeFactory(opset_version)

return factory[opset_version]
@@ -84,10 +99,10 @@ def inner(opset_version: Optional[str] = None) -> NodeFactory:


_get_factory = _get_factory_callable()
_get_opset_factory = _get_opset_factory_callable()

# some files uses _get_factory function
from .__version__ import __version__ # noqa
from .build_tokenizer import build_rwkv_tokenizer # noqa
from .convert_tokenizer import convert_tokenizer # noqa
from .str_pack import pack_strings, unpack_strings # noqa
from .utils import add_greedy_decoding, connect_models # noqa
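For orientation, a minimal usage sketch (not part of the commit) of the split this change introduces: core opset15 string ops are created through the new _get_opset_factory("opset15"), while the tokenizer extension ops keep using _get_factory(). The calls mirror the hunks in build_tokenizer.py and hf_parser.py below; names and shapes are illustrative only.

from openvino import PartialShape, Type
from openvino.runtime import op
from openvino_tokenizers import _get_factory, _get_opset_factory

# A dynamic 1D string input, unpacked into (begins, ends, symbols) via opset15.
string_input = op.Parameter(Type.string, PartialShape(["?"]))
unpacked = _get_opset_factory("opset15").create("StringTensorUnpack", string_input.outputs()).outputs()

# Custom extension ops (SpecialTokensSplit, VocabDecoder, ...) still come from
# the extension-aware factory, e.g.:
# tokens = _get_factory().create("SpecialTokensSplit", [...]).outputs()

# The (begins, ends, symbols) triple is packed back into a string tensor via opset15.
packed = _get_opset_factory("opset15").create("StringTensorPack", unpacked).outputs()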
8 changes: 4 additions & 4 deletions python/openvino_tokenizers/build_tokenizer.py
@@ -21,12 +21,12 @@ def build_rwkv_tokenizer(
tokenizer_output_type: Type = Type.i64,
detokenizer_input_type: Type = Type.i64,
) -> Tuple[Model, Model]:
from openvino_tokenizers import _get_factory
from openvino_tokenizers import _get_factory, _get_opset_factory

input_node = op.Parameter(Type.string, PartialShape(["?"]))
input_node.set_friendly_name("string_input")

output = _get_factory().create("StringTensorUnpack", input_node.outputs()).outputs()
output = _get_opset_factory("opset15").create("StringTensorUnpack", input_node.outputs()).outputs()
trie_node = TrieTokenizerStep.from_rwkv_vocab(rwkv_vocab)
output = trie_node.get_ov_subgraph(TokenizerPipeline.add_ragged_dimension(output))

@@ -56,7 +56,7 @@ def build_rwkv_tokenizer(
_get_factory()
.create(
"VocabDecoder",
[*detokenizer_input.outputs(), *BasePipelineStep.create_string_constant_node(trie_node.vocab).outputs()],
[*detokenizer_input.outputs(), *BasePipelineStep.create_string_constant_node(trie_node.vocab)],
)
.outputs()
)
@@ -65,7 +65,7 @@ def build_rwkv_tokenizer(
if clean_up_tokenization_spaces:
RegexDecodingStep.clean_up_tokenization_spaces().get_ov_subgraph(detokenizer_output)

detokenizer_output = _get_factory().create("StringTensorPack", detokenizer_output).outputs()
detokenizer_output = _get_opset_factory("opset15").create("StringTensorPack", detokenizer_output).outputs()
detokenizer_output[0].tensor.add_names({STRING_OUTPUT_NAME})

detokenizer = Model(detokenizer_output, [detokenizer_input], DETOKENIZER_NAME)
6 changes: 3 additions & 3 deletions python/openvino_tokenizers/hf_parser.py
@@ -20,7 +20,7 @@
from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import import_protobuf

from . import _get_factory
from . import _get_factory, _get_opset_factory
from .constants import (
ATTENTION_MASK_INPUT_NAME,
DETOKENIZER_NAME,
@@ -810,7 +810,7 @@ def convert_sentencepiece_model_tokenizer(
if params.handle_special_tokens_with_re:
tokens, ids = zip(*sorted(((token, id) for id, token in add_tokens.items()), reverse=True))
added_inputs = [
*BasePipelineStep.create_string_constant_node(tokens).outputs(),
*BasePipelineStep.create_string_constant_node(tokens),
make_constant_node(np.array(ids, dtype=np.int32), Type.i32).output(0),
]
else:
@@ -1013,7 +1013,7 @@ def get_sp_detokenizer(
if params.utf8_replace_mode is not None and params.utf8_replace_mode != UTF8ReplaceMode.DISABLE:
last_sinks = UTF8ValidateStep(params.utf8_replace_mode).get_ov_subgraph(detokenizer)

string_output = _get_factory().create("StringTensorPack", last_sinks).outputs()
string_output = _get_opset_factory("opset15").create("StringTensorPack", last_sinks).outputs()
string_output[0].tensor.add_names({STRING_OUTPUT_NAME})
tokenizer_detokenizer = Model(string_output, [model_input], DETOKENIZER_NAME)
tokenizer_detokenizer.validate_nodes_and_infer_types()
62 changes: 0 additions & 62 deletions python/openvino_tokenizers/str_pack.py

This file was deleted.
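The deleted pack_string/pack_strings helpers encoded strings into a single packed u8 buffer that was wrapped in a Constant and fed to the extension's StringTensorUnpack (see the old lines in tokenizer_pipeline.py below). Their role is taken over by create_unpacked_string, now imported from .utils. A rough sketch of what such a helper can look like, assuming the (begins, ends, symbols) layout that opset15 StringTensorUnpack produces; the actual implementation in utils.py may differ in details.

from typing import Iterable, List

import numpy as np
from openvino.runtime import Output, op


def create_unpacked_string(strings: Iterable[str]) -> List[Output]:
    # Decompose a batch of strings into the three outputs StringTensorUnpack
    # would yield: begin offsets, end offsets, and the concatenated UTF-8 bytes.
    begins, ends, symbols = [], [], b""
    for string in strings:
        data = string.encode("utf-8")
        begins.append(len(symbols))
        symbols += data
        ends.append(len(symbols))
    return [
        op.Constant(np.array(begins, dtype=np.int32)).output(0),
        op.Constant(np.array(ends, dtype=np.int32)).output(0),
        op.Constant(np.frombuffer(symbols, dtype=np.uint8)).output(0),
    ]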

69 changes: 40 additions & 29 deletions python/openvino_tokenizers/tokenizer_pipeline.py
@@ -19,7 +19,7 @@
from openvino.runtime.exceptions import OVTypeError, UserInputError
from openvino.runtime.utils.types import as_node, make_constant_node

from . import _get_factory
from . import _get_factory, _get_opset_factory
from .constants import (
ATTENTION_MASK_INPUT_NAME,
DETOKENIZER_NAME,
@@ -31,8 +31,13 @@
VOCAB_SIZE_CACHE_PROPORTION,
UTF8ReplaceMode,
)
from .str_pack import pack_string, pack_strings
from .utils import apply_unicode_to_bytes, generate_tokens_with_space_symbols, has_incompatible_re2_op, quote_meta
from .utils import (
apply_unicode_to_bytes,
create_unpacked_string,
generate_tokens_with_space_symbols,
has_incompatible_re2_op,
quote_meta,
)


logger = logging.getLogger(__name__)
@@ -66,15 +71,15 @@ def get_ov_subgraph(self, *input_nodes: List[Output]) -> List[Output]:
raise NotImplementedError

@staticmethod
def create_string_constant_node(value: Union[str, Iterable[str]]) -> op.Constant:
def create_string_constant_node(value: Union[str, Iterable[str]]) -> List[Output]:
if isinstance(value, str):
# string scalar
ps = pack_string(value)
return op.Constant(ps)
else:
return op.Constant(np.frombuffer(bytes(value, "utf-8"), dtype=np.uint8)).outputs()
elif isinstance(value, Iterable):
# support only 1D strings for now
ps = pack_strings(value)
return _get_factory().create("StringTensorUnpack", op.Constant(ps).outputs())
return create_unpacked_string(value)
else:
raise ValueError(f"Unsupported value type {type(value)}")

def finalize(self) -> None:
"""Called after the entire pipeline has been built"""
@@ -144,7 +149,7 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
return list(input_nodes)

split_pattern = "|".join(token.regex_repr() for token in self.special_tokens)
input_nodes.extend(self.create_string_constant_node(split_pattern).outputs())
input_nodes.extend(self.create_string_constant_node(split_pattern))

return _get_factory().create("SpecialTokensSplit", input_nodes).outputs()

@@ -233,10 +238,10 @@ def del_control_chars_regex(cls) -> "RegexNormalizationStep":

def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
input_nodes.extend(
(
self.create_string_constant_node(self.regex_search_pattern),
self.create_string_constant_node(self.replace_term),
)
[
*self.create_string_constant_node(self.regex_search_pattern),
*self.create_string_constant_node(self.replace_term),
]
)
return (
_get_factory().create("RegexNormalization", input_nodes, {"global_replace": self.global_replace}).outputs()
@@ -357,7 +362,7 @@ def punctuation_splitter(cls, behaviour="isolate") -> "RegexSplitStep":
)

def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
input_nodes.extend(self.create_string_constant_node(self.split_pattern).outputs())
input_nodes.extend(self.create_string_constant_node(self.split_pattern))
return (
_get_factory()
.create(
@@ -423,7 +428,7 @@ def get_vocab_node_outputs(self) -> Optional[List[Output]]:

def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
pipeline = self.get_pipeline()
pipeline.vocab_node_outputs = self.create_string_constant_node(self.vocab).outputs()
pipeline.vocab_node_outputs = self.create_string_constant_node(self.vocab)

ragged_dims, other_dims = [], input_nodes
if len(input_nodes) > 4:
@@ -475,7 +480,7 @@ def from_rwkv_vocab(cls, vocab_file_strings: Iterable[str]) -> TrieTokenizerStep
def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
input_nodes.extend(
(
*self.create_string_constant_node(self.vocab).outputs(),
*self.create_string_constant_node(self.vocab),
make_constant_node(np.array(self.indices, dtype=np.int32), Type.i32),
)
)
@@ -511,7 +516,7 @@ def from_hf_json(cls, tokenizer_json: Dict[str, Any]) -> "WordPieceTokenizationS
def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
input_nodes.extend(
(
*self.create_string_constant_node(self.vocab).outputs(),
*self.create_string_constant_node(self.vocab),
*as_node(self.unk_token_id).outputs(),
)
)
@@ -643,10 +648,10 @@ def merges_are_pairs(self) -> bool:

def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
pipeline = self.get_pipeline()
pipeline.vocab_node_outputs = self.create_string_constant_node(self.vocab).outputs()
pipeline.vocab_node_outputs = self.create_string_constant_node(self.vocab)

if self.added_tokens:
special_tokens_outputs = self.create_string_constant_node(self.added_tokens).outputs()
special_tokens_outputs = self.create_string_constant_node(self.added_tokens)
else:
special_tokens_outputs = []

@@ -659,12 +664,12 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
left_merges, right_merges = zip(*self.merges)
input_nodes.extend(
(
*self.create_string_constant_node(left_merges).outputs(),
*self.create_string_constant_node(right_merges).outputs(),
*self.create_string_constant_node(left_merges),
*self.create_string_constant_node(right_merges),
)
)
else:
input_nodes.extend(self.create_string_constant_node(self.merges).outputs())
input_nodes.extend(self.create_string_constant_node(self.merges))

if special_tokens_outputs:
input_nodes.extend(
@@ -1035,7 +1040,13 @@ def finalize(self) -> None:
self.skip_tokens = pipeline.skip_tokens or []

@classmethod
def from_hf_json(cls, tokenizer_json: Dict[str, Any], pipeline_vocab: Optional[List[str]], skip_tokens: Optional[List[int]] = None, do_skip_tokens: bool = True) -> "VocabDecoderStep":
def from_hf_json(
cls,
tokenizer_json: Dict[str, Any],
pipeline_vocab: Optional[List[str]],
skip_tokens: Optional[List[int]] = None,
do_skip_tokens: bool = True,
) -> "VocabDecoderStep":
model_type = tokenizer_json["model"]["type"]

if pipeline_vocab is not None and model_type == "WordLevel":
@@ -1057,7 +1068,7 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
if self.vocab is None:
vocab_outputs = self.get_vocab_node_outputs()
else:
vocab_outputs = self.create_string_constant_node(self.vocab).outputs()
vocab_outputs = self.create_string_constant_node(self.vocab)
input_nodes.extend(vocab_outputs)

# Put constant with skip tokens even if do_skip_tokens=False, so that it can be switched on/off at runtime.
@@ -1178,8 +1189,8 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:

input_nodes.extend(
(
*self.create_string_constant_node(self.regex_search_pattern).outputs(),
*self.create_string_constant_node(self.replace_term).outputs(),
*self.create_string_constant_node(self.regex_search_pattern),
*self.create_string_constant_node(self.replace_term),
)
)
return ragged_dims + _get_factory().create("RegexNormalization", input_nodes).outputs()
@@ -1234,7 +1245,7 @@ def get_tokenizer_ov_subgraph(self) -> Model:

processing_outputs = []
for input_node in string_inputs:
input_node = _get_factory().create("StringTensorUnpack", input_node.outputs()).outputs()
input_node = _get_opset_factory("opset15").create("StringTensorUnpack", input_node.outputs()).outputs()

ragged = []
if isinstance(self.steps[0], SpecialTokensSplit):
Expand Down Expand Up @@ -1307,7 +1318,7 @@ def create_decoding_pipeline(self, input_nodes: List[Output]) -> List[Output]:
pipeline_step = step.get_ov_subgraph(input_nodes)
input_nodes = pipeline_step

return _get_factory().create("StringTensorPack", input_nodes).outputs()
return _get_opset_factory("opset15").create("StringTensorPack", input_nodes).outputs()

def get_detokenizer_ov_subgraph(self) -> Model:
self.finalize()
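A side effect visible throughout the tokenizer_pipeline.py hunks: create_string_constant_node now returns a list of Output objects (begins, ends, symbols) instead of an op.Constant, so callers spread the result directly rather than appending .outputs(). A small, self-contained sketch of the new calling convention, mirroring the added_inputs hunk in hf_parser.py above; the vocabulary and ids are made up for illustration.

import numpy as np
from openvino import Type
from openvino.runtime.utils.types import make_constant_node

from openvino_tokenizers.tokenizer_pipeline import BasePipelineStep

vocab = ["<unk>", "hello", "world"]
ids = [0, 1, 2]

# Three Output objects (begins, ends, symbols) plus an i32 ids constant,
# ready to be extended onto an op's input list.
decoder_inputs = [
    *BasePipelineStep.create_string_constant_node(vocab),
    make_constant_node(np.array(ids, dtype=np.int32), Type.i32).output(0),
]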