Skip to content

Commit

Permalink
Utf8 turn on (#326)
Browse files Browse the repository at this point in the history
* turn on UTF8Validate.REPLACE by default

* allow disabling from cli

* add correct comparison

* add default_factory
  • Loading branch information
pavel-esir authored Nov 25, 2024
1 parent 5397aa3 commit 7ce02af
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 18 deletions.
7 changes: 5 additions & 2 deletions python/openvino_tokenizers/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,14 @@ def get_parser() -> ArgumentParser:
"--utf8_replace_mode",
choices=list(UTF8ReplaceMode),
type=UTF8ReplaceMode, # enum with 'ignore', 'replace' values.
default=None,
default=UTF8ReplaceMode.REPLACE,
required=False,
help=(
"If specified then resulting strings during decoding are checked if sequence of bytes is a valid UTF-8 sequence. "
f"If mode is '{UTF8ReplaceMode.REPLACE}' then invalid characters are replaced with �, if mode is '{UTF8ReplaceMode.IGNORE}' then invalid character are skipped."
f"If mode is '{UTF8ReplaceMode.DISABLE}' then UTF8 validation is not performed at all. "
f"Two other regimes are identical to python decode method error handling parameter. "
f"If mode is '{UTF8ReplaceMode.REPLACE}' then invalid characters are replaced with �. "
f"if mode is '{UTF8ReplaceMode.IGNORE}' then invalid character are skipped and instead of them empty substring is added."
),
)
return parser
Expand Down
10 changes: 10 additions & 0 deletions python/openvino_tokenizers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@
class UTF8ReplaceMode(Enum):
IGNORE: str = "ignore"
REPLACE: str = "replace"
DISABLE: str = "disable"

def __str__(self):
return self.value

def __eq__(self, other):
if isinstance(other, (UTF8ReplaceMode)):
# UTF8ReplaceMode is a singleton, so we can compare them by reference
return self is other
elif isinstance(other, str):
return self.value == other
else:
return False
2 changes: 1 addition & 1 deletion python/openvino_tokenizers/convert_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def convert_tokenizer(
use_max_padding: bool = False,
handle_special_tokens_with_re: Optional[bool] = None,
use_sentencepiece_backend: bool = False,
utf8_replace_mode: Optional[UTF8ReplaceMode] = None,
utf8_replace_mode: Optional[UTF8ReplaceMode] = UTF8ReplaceMode.REPLACE,
) -> Union[Model, Tuple[Model, Model]]:
"""
Converts a given tokenizer object into an OpenVINO-compatible model.
Expand Down
14 changes: 7 additions & 7 deletions python/openvino_tokenizers/hf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ def decoding(self) -> None:
self.pipeline.add_steps(CharsToBytesStep())
else:
self.pipeline.add_steps(FuseStep())

if self.utf8_replace_mode is not None:
if self.utf8_replace_mode is not None and (self.utf8_replace_mode != UTF8ReplaceMode.DISABLE):
self.pipeline.add_steps(UTF8ValidateStep(mode=self.utf8_replace_mode))

if self.clean_up_tokenization_spaces is None:
Expand Down Expand Up @@ -981,12 +981,12 @@ def get_sp_detokenizer(

if params.clean_up_tokenization_spaces:
detokenizer = RegexDecodingStep.clean_up_tokenization_spaces().get_ov_subgraph(detokenizer)

last_sinks = detokenizer
if params.utf8_replace_mode is not None and params.utf8_replace_mode != UTF8ReplaceMode.DISABLE:
last_sinks = UTF8ValidateStep(params.utf8_replace_mode).get_ov_subgraph(detokenizer)

if params.utf8_replace_mode is not None:
replace_mode = True if params.utf8_replace_mode is UTF8ReplaceMode.REPLACE else False
UTF8ValidateStep(mode=replace_mode).get_ov_subgraph(detokenizer)

string_output = _get_factory().create("StringTensorPack", detokenizer).outputs()
string_output = _get_factory().create("StringTensorPack", last_sinks).outputs()
string_output[0].tensor.add_names({STRING_OUTPUT_NAME})
tokenizer_detokenizer = Model(string_output, [model_input], DETOKENIZER_NAME)
tokenizer_detokenizer.validate_nodes_and_infer_types()
Expand Down
4 changes: 2 additions & 2 deletions python/openvino_tokenizers/tokenizer_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,10 +1043,10 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:

@dataclass
class UTF8ValidateStep(DecodingStep):
mode: UTF8ReplaceMode = UTF8ReplaceMode.IGNORE
mode: UTF8ReplaceMode = field(default_factory=lambda: UTF8ReplaceMode.IGNORE)

def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
replace_mode = True if self.mode is UTF8ReplaceMode.REPLACE else False
replace_mode = True if self.mode == UTF8ReplaceMode.REPLACE else False
return _get_factory().create("UTF8Validate", input_nodes, {"replace_mode": replace_mode}).outputs()


Expand Down
6 changes: 3 additions & 3 deletions python/openvino_tokenizers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
import re
from dataclasses import dataclass, fields
from dataclasses import dataclass, fields, field
from functools import lru_cache
from typing import Any, Dict, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -57,7 +57,7 @@ class TokenzierConversionParams:
utf8_replace_mode : Optional[UTF8ReplaceMode]
Specifies the UTF-8 replacement mode during tokenization.
Allowed values are UTF8ReplaceMode.IGNORE and UTF8ReplaceMode.REPLACE. Default is None.
Allowed values are UTF8ReplaceMode.DISABLE, UTF8ReplaceMode.IGNORE and UTF8ReplaceMode.REPLACE. Default is UTF8ReplaceMode.REPLACE.
"""

with_detokenizer: bool = False
Expand All @@ -70,7 +70,7 @@ class TokenzierConversionParams:
use_max_padding: bool = False
handle_special_tokens_with_re: Optional[bool] = None
use_sentencepiece_backend: bool = False
utf8_replace_mode: Optional[UTF8ReplaceMode] = None
utf8_replace_mode: Optional[UTF8ReplaceMode] = field(default_factory=lambda: UTF8ReplaceMode.REPLACE)
add_attention_mask: bool = True
add_prefix_space: Optional[bool] = None
number_of_inputs: int = 1
Expand Down
4 changes: 1 addition & 3 deletions tests/layer_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def create_normalization_model(layer: Union[NormalizationStep, DecodingStep]) ->
@pytest.mark.parametrize("test_string", utf8_validate_strings)
@pytest.mark.parametrize("replace_mode", ["ignore", "replace"])
def test_utf8_validate(test_string, replace_mode):
utf_validation_node = UTF8ValidateStep(
UTF8ReplaceMode.REPLACE if replace_mode == "replace" else UTF8ReplaceMode.IGNORE
)
utf_validation_node = UTF8ValidateStep(UTF8ReplaceMode(replace_mode))
compiled_model = create_normalization_model(utf_validation_node)
res_ov = compiled_model([test_string])[0]
res_py = test_string.decode(errors=replace_mode)
Expand Down

0 comments on commit 7ce02af

Please sign in to comment.