From bae418c02617fbb5b2ae2a7fe860e814c946ae60 Mon Sep 17 00:00:00 2001 From: Mathieu Bernard Date: Fri, 28 Jun 2024 19:19:10 +0200 Subject: [PATCH] fixing issue #169 --- phonemizer/backend/espeak/espeak.py | 2 +- phonemizer/backend/espeak/words_mismatch.py | 19 +++++++++++++----- test/test_espeak_word_mismatch.py | 22 +++++++++++++++++++-- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/phonemizer/backend/espeak/espeak.py b/phonemizer/backend/espeak/espeak.py index 34e3faf..766af5b 100644 --- a/phonemizer/backend/espeak/espeak.py +++ b/phonemizer/backend/espeak/espeak.py @@ -150,7 +150,7 @@ def _phonemize_postprocess(self, phonemized, punctuation_marks, separator: Separ text = phonemized[0] switches = phonemized[1] - self._words_mismatch.count_phonemized(text) + self._words_mismatch.count_phonemized(text, separator) self._lang_switch.warning(switches) phonemized = super()._phonemize_postprocess(text, punctuation_marks, separator, strip) diff --git a/phonemizer/backend/espeak/words_mismatch.py b/phonemizer/backend/espeak/words_mismatch.py index bc5be0c..30f206b 100644 --- a/phonemizer/backend/espeak/words_mismatch.py +++ b/phonemizer/backend/espeak/words_mismatch.py @@ -19,7 +19,10 @@ from logging import Logger from typing import List, Tuple -from typing_extensions import TypeAlias, Literal +from typing_extensions import TypeAlias, Literal, Union + +from phonemizer.separator import Separator + WordMismatch: TypeAlias = Literal["warn", "ignore"] @@ -58,10 +61,16 @@ def __init__(self, logger: Logger): self._count_phn = [] @classmethod - def _count_words(cls, text: List[str]) -> List[int]: + def _count_words( + cls, + text: List[str], + wordsep: Union[str, re.Pattern] = _RE_SPACES) -> List[int]: """Return the number of words contained in each line of `text`""" + if not isinstance(wordsep, re.Pattern): + wordsep = re.escape(wordsep) + return [ - len([w for w in cls._RE_SPACES.split(line.strip()) if w]) + len([w for w in re.split(wordsep, line.strip()) if w]) for line in text] def _mismatched_lines(self) -> List[Tuple[int, int, int]]: @@ -93,9 +102,9 @@ def count_text(self, text: List[str]): """Stores the number of words in each input line""" self._count_txt = self._count_words(text) - def count_phonemized(self, text: List[str]): + def count_phonemized(self, text: List[str], separator: Separator): """Stores the number of words in each output line""" - self._count_phn = self._count_words(text) + self._count_phn = self._count_words(text, separator.word) @abc.abstractmethod def process(self, text: List[str]) -> List[str]: diff --git a/test/test_espeak_word_mismatch.py b/test/test_espeak_word_mismatch.py index 5d77014..37ae0d3 100644 --- a/test/test_espeak_word_mismatch.py +++ b/test/test_espeak_word_mismatch.py @@ -5,8 +5,11 @@ import pytest -from phonemizer.backend.espeak.words_mismatch import Ignore +import re + from phonemizer import phonemize +from phonemizer.backend.espeak.words_mismatch import Ignore +from phonemizer.separator import Separator, default_separator @pytest.fixture @@ -16,7 +19,8 @@ def text(): def test_count_words(): # pylint: disable=protected-access - count_words = Ignore._count_words + count_words = lambda phn: Ignore._count_words( + phn, wordsep=default_separator.word) assert count_words(['']) == [0] assert count_words(['a']) == [1] assert count_words(['aaa']) == [1] @@ -59,3 +63,17 @@ def test_mismatch(caplog, text, mode): 'words count mismatch on line 3 (expected 4 words but get 3)' in messages) assert 'words count mismatch on 67.0% of the lines (2/3)' in messages + + +# from https://github.com/bootphon/phonemizer/issues/169 +def test_custom_separator(caplog): + phn = phonemize( + 'try', + backend='espeak', + language='en-us', + separator=Separator(word='|', phone=' '), + words_mismatch='warn') + + assert phn == 't ɹ aɪ |' + messages = [msg[2] for msg in caplog.record_tuples] + assert len(messages) == 0