diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/string.cc b/fairseq2n/python/src/fairseq2n/bindings/data/string.cc
index 15303b898..9a9c43c25 100644
--- a/fairseq2n/python/src/fairseq2n/bindings/data/string.cc
+++ b/fairseq2n/python/src/fairseq2n/bindings/data/string.cc
@@ -75,6 +75,12 @@ def_string(py::module_ &data_module)
             {
                 return py::bytes(static_cast<std::string_view>(self));
             })
+        .def(
+            "strip",
+            [](const immutable_string &self)
+            {
+                return rtrim(ltrim(self));
+            })
         .def(
             "lstrip",
             [](const immutable_string &self)
diff --git a/fairseq2n/src/fairseq2n/data/audio/waveform_to_fbank_converter.cc b/fairseq2n/src/fairseq2n/data/audio/waveform_to_fbank_converter.cc
index 860c409e3..ad3b3a52c 100644
--- a/fairseq2n/src/fairseq2n/data/audio/waveform_to_fbank_converter.cc
+++ b/fairseq2n/src/fairseq2n/data/audio/waveform_to_fbank_converter.cc
@@ -111,7 +111,7 @@ waveform_to_fbank_converter::find_waveform(data_dict &dict)
 
     if (waveform.dim() != 2)
         throw_<std::invalid_argument>(
-            "The input waveform must be two dimensional, but has {} dimensions instead.", waveform.dim());
+            "The input waveform must be two dimensional, but has {} dimension(s) instead.", waveform.dim());
 
     return waveform;
 }
diff --git a/fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.cc b/fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.cc
index cdbad02ba..b597ba254 100644
--- a/fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.cc
+++ b/fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.cc
@@ -35,7 +35,7 @@ class sp_decoder_op {
     explicit
     sp_decoder_op(const sp_decoder *decoder, const sp_processor *processor, at::Tensor &&tensor);
 
-    data_list &&
+    immutable_string &&
    run() &&;
 
 private:
@@ -50,26 +50,22 @@ class sp_decoder_op {
     const sp_decoder *decoder_;
     const sp_processor *processor_;
     at::Tensor tensor_;
-    data_list sentences_{};
+    immutable_string sentence_{};
 };
 
 sp_decoder_op::sp_decoder_op(
     const sp_decoder *decoder, const sp_processor *processor, at::Tensor &&tensor)
   : decoder_{decoder}, processor_{processor}, tensor_{std::move(tensor)}
-{
-    auto batch_size = static_cast<std::size_t>(tensor_.size(0));
-
-    sentences_.reserve(batch_size);
-}
+{}
 
-data_list &&
+immutable_string &&
 sp_decoder_op::run() &&
 {
     tensor_ = tensor_.to(at::kCPU);
 
     decode();
 
-    return std::move(sentences_);
+    return std::move(sentence_);
 }
 
 void
@@ -98,31 +94,25 @@ template <typename T>
 void
 sp_decoder_op::decode()
 {
-    std::int64_t seq_len = tensor_.size(1);
+    std::int64_t seq_len = tensor_.size(0);
 
     std::vector<std::string_view> tokens{};
 
    tokens.reserve(static_cast<std::size_t>(seq_len));
 
-    auto tensor_data = tensor_.accessor<T, 2>();
-
-    for (std::int64_t i = 0; i < tensor_.size(0); ++i) {
-        tokens.clear();
+    auto tensor_data = tensor_.accessor<T, 1>();
 
-        for (std::int64_t j = 0; j < seq_len; j++) {
-            T token_idx = tensor_data[i][decoder_->reverse_ ? seq_len - 1 - j : j];
+    for (std::int64_t j = 0; j < seq_len; j++) {
+        T token_idx = tensor_data[decoder_->reverse_ ? seq_len - 1 - j : j];
 
-            auto token_idx_32bit = conditional_cast<std::int32_t>(token_idx);
+        auto token_idx_32bit = conditional_cast<std::int32_t>(token_idx);
 
-            std::string_view token = processor_->index_to_token(token_idx_32bit);
+        std::string_view token = processor_->index_to_token(token_idx_32bit);
 
-            tokens.push_back(token);
-        }
-
-        std::string sentence = processor_->decode(tokens);
-
-        sentences_.emplace_back(std::move(sentence));
+        tokens.push_back(token);
     }
+
+    sentence_ = processor_->decode(tokens);
 }
 
 } // namespace detail
@@ -140,17 +130,14 @@ sp_decoder::operator()(data &&d) const
 
     at::Tensor tensor = d.as_tensor();
 
-    if (tensor.dim() == 0 || tensor.dim() > 2)
+    if (tensor.dim() != 1)
         throw_<std::invalid_argument>(
-            "The input tensor must be one or two dimensional, but has {} dimensions instead.", tensor.dim());
-
-    if (tensor.dim() == 1)
-        tensor = tensor.unsqueeze(0);
+            "The input tensor must be one dimensional, but has {} dimension(s) instead.", tensor.dim());
 
     return decode(std::move(tensor));
 }
 
-data_list
+immutable_string
 sp_decoder::decode(at::Tensor &&tensor) const
 {
     return sp_decoder_op{this, model_->processor_.get(), std::move(tensor)}.run();
diff --git a/fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.h b/fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.h
index dde193cf3..a72cdd7b6 100644
--- a/fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.h
+++ b/fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.h
@@ -13,6 +13,7 @@
 
 #include "fairseq2n/api.h"
 #include "fairseq2n/data/data.h"
+#include "fairseq2n/data/immutable_string.h"
 
 namespace fairseq2n {
 namespace detail {
@@ -34,7 +35,7 @@ class FAIRSEQ2_API sp_decoder final {
     operator()(data &&d) const;
 
 private:
-    data_list
+    immutable_string
     decode(at::Tensor &&tensor) const;
 
 private:
diff --git a/src/fairseq2/data/cstring.py b/src/fairseq2/data/cstring.py
index f292537ab..ff0988085 100644
--- a/src/fairseq2/data/cstring.py
+++ b/src/fairseq2/data/cstring.py
@@ -45,6 +45,9 @@ def __hash__(self) -> int:
     def bytes(self) -> bytes:
         """Return a copy of this string as :class:`bytes`."""
 
+    def strip(self) -> "CString":
+        """Return a copy of this string with no whitespace at the beginning and end."""
+
     def lstrip(self) -> "CString":
         """Return a copy of this string with no whitespace at the beginning."""
diff --git a/src/fairseq2/data/text/sentencepiece.py b/src/fairseq2/data/text/sentencepiece.py
index 327748a27..e3086ad82 100644
--- a/src/fairseq2/data/text/sentencepiece.py
+++ b/src/fairseq2/data/text/sentencepiece.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import TYPE_CHECKING, List, Optional, Sequence, final
+from typing import TYPE_CHECKING, Optional, Sequence, final
 
 from torch import Tensor
 
@@ -87,7 +87,7 @@ def __init__(self, model: SentencePieceModel, reverse: bool = False) -> None:
             ...
 
         @finaloverride
-        def __call__(self, token_indices: Tensor) -> List[StringLike]:
+        def __call__(self, token_indices: Tensor) -> StringLike:
             ...
 
 else:
diff --git a/src/fairseq2/data/text/text_tokenizer.py b/src/fairseq2/data/text/text_tokenizer.py
index 87f00170b..1adc1204c 100644
--- a/src/fairseq2/data/text/text_tokenizer.py
+++ b/src/fairseq2/data/text/text_tokenizer.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Optional
 
 from torch import Tensor
 
@@ -92,7 +92,7 @@ class TextTokenDecoder(ABC):
     """Decodes sentences from token indices."""
 
     @abstractmethod
-    def __call__(self, token_indices: Tensor) -> List[StringLike]:
+    def __call__(self, token_indices: Tensor) -> StringLike:
         """
         :param token_indices:
             The token indices to decode from.
diff --git a/src/fairseq2/generation/text.py b/src/fairseq2/generation/text.py
index 9d477c637..5bebe73f2 100644
--- a/src/fairseq2/generation/text.py
+++ b/src/fairseq2/generation/text.py
@@ -77,7 +77,7 @@ def _do_generate(
             encoder_output, encoder_padding_mask, source_seq_len=source_seqs.size(1)
         )
 
-        sentences = [self.token_decoder(b[0].seq)[0] for b in gen_output.results]
+        sentences = [self.token_decoder(b[0].seq) for b in gen_output.results]
 
         return SequenceToTextOutput(
             sentences, gen_output, encoder_output, encoder_padding_mask
diff --git a/tests/unit/data/audio/test_waveform_to_fbank_converter.py b/tests/unit/data/audio/test_waveform_to_fbank_converter.py
index b2a5bdaa5..89c074999 100644
--- a/tests/unit/data/audio/test_waveform_to_fbank_converter.py
+++ b/tests/unit/data/audio/test_waveform_to_fbank_converter.py
@@ -16,11 +16,15 @@
     WaveformToFbankConverter,
 )
 from fairseq2.memory import MemoryBlock
-from tests.common import assert_equal, device
+from tests.common import assert_equal, device, python_devel_only
 
 TEST_OGG_PATH: Final = Path(__file__).parent.joinpath("test.ogg")
 
 
+@pytest.mark.skipif(
+    python_devel_only(),
+    reason="New fairseq2n API in Python-only installation. Skipping till v0.2.",
+)
 class TestWaveformToFbankConverter:
     def test_call_works(self) -> None:
         audio = self.get_audio()
@@ -185,7 +189,7 @@ def test_call_raises_error_when_waveform_is_not_two_dimensional(
 
         with pytest.raises(
             ValueError,
-            match=rf"^The input waveform must be two dimensional, but has {len(shape)} dimensions instead\.$",
+            match=rf"^The input waveform must be two dimensional, but has {len(shape)} dimension\(s\) instead\.$",
         ):
             converter({"waveform": waveform, "sample_rate": 16000.0})
diff --git a/tests/unit/data/text/test_sentencepiece.py b/tests/unit/data/text/test_sentencepiece.py
index 3db6e5d77..11a3f1e4b 100644
--- a/tests/unit/data/text/test_sentencepiece.py
+++ b/tests/unit/data/text/test_sentencepiece.py
@@ -11,32 +11,33 @@
 import pytest
 import torch
 
+from fairseq2.data import CString
 from fairseq2.data.text import (
     SentencePieceDecoder,
     SentencePieceEncoder,
     SentencePieceModel,
 )
 from fairseq2.typing import DataType
-from tests.common import assert_equal, device
+from tests.common import assert_equal, device, python_devel_only
 
 TEST_SPM_PATH: Final = Path(__file__).parent.joinpath("test.spm")
 
 
+@pytest.mark.skipif(
+    python_devel_only(),
+    reason="New fairseq2n API in Python-only installation. Skipping till v0.2.",
+)
 class TestSentencePieceModel:
-    sentences: ClassVar[List[str]]
-    token_indices: ClassVar[List[List[int]]]
+    sentence: ClassVar[str]
+    token_indices: ClassVar[List[int]]
 
     @classmethod
     def setup_class(cls) -> None:
-        cls.sentences = [
-            "Hello world! How are you?",
-            "What's up? Hope you are doing well today.",
-        ]
+        cls.sentence = "What's up? Hope you are doing well today."
 
         # fmt: off
         cls.token_indices = [
-            [132, 30, 131, 114, 52, 418, 68, 166, 106, 40, 11],
-            [169, 87, 5, 227, 11, 424, 294, 40, 106, 120, 26, 597, 19, 303, 4]
+            169, 87, 5, 227, 11, 424, 294, 40, 106, 120, 26, 597, 19, 303, 4
         ]
         # fmt: on
 
@@ -78,19 +79,17 @@ def test_encode_decode_work(self) -> None:
         encoder = SentencePieceEncoder(model, device=device)
         decoder = SentencePieceDecoder(model)
 
-        indices = encoder(self.sentences[0])
+        indices = encoder(self.sentence)
 
         # Assert encoder.
-        assert_equal(indices, self.token_indices[0])
+        assert_equal(indices, self.token_indices)
 
-        sentences = decoder(indices)
+        sentence = decoder(indices)
 
         # Assert decoder
-        assert isinstance(sentences, list)
-
-        assert len(sentences) == 1
+        assert isinstance(sentence, CString)
 
-        assert sentences[0] == self.sentences[0]
+        assert sentence == self.sentence
 
     def test_encode_decode_work_when_reverse_is_true(self) -> None:
         model = self.build_model()
@@ -98,19 +97,17 @@ def test_encode_decode_work_when_reverse_is_true(self) -> None:
         encoder = SentencePieceEncoder(model, device=device, reverse=True)
         decoder = SentencePieceDecoder(model, reverse=True)
 
-        indices = encoder(self.sentences[0])
+        indices = encoder(self.sentence)
 
         # Assert encoder.
-        assert_equal(indices, self.token_indices[0][::-1])
+        assert_equal(indices, self.token_indices[::-1])
 
-        sentences = decoder(indices)
+        sentence = decoder(indices)
 
         # Assert decoder.
-        assert isinstance(sentences, list)
+        assert isinstance(sentence, CString)
 
-        assert len(sentences) == 1
-
-        assert sentences[0] == self.sentences[0]
+        assert sentence == self.sentence
 
     def test_decode_works_when_control_symbols_are_specified(self) -> None:
         model = self.build_model(control_symbols=["<foo>"])
@@ -118,10 +115,10 @@ def test_decode_works_when_control_symbols_are_specified(self) -> None:
         encoder = SentencePieceEncoder(model, device=device)
         decoder = SentencePieceDecoder(model)
 
-        indices = encoder(self.sentences[0])
+        indices = encoder(self.sentence)
 
         # Assert encoder.
-        assert_equal(indices, self.token_indices[0])
+        assert_equal(indices, self.token_indices)
 
         # We inject a dummy <foo> token to the returned tokens.
         foo_idx = model.token_to_index("<foo>")
@@ -131,14 +128,12 @@ def test_decode_works_when_control_symbols_are_specified(self) -> None:
         indices = torch.cat([indices[:2], foo, indices[2:]])
 
         # We expect the decoder to ignore the <foo> tokens.
-        sentences = decoder(indices)
+        sentence = decoder(indices)
 
         # Assert decoder.
-        assert isinstance(sentences, list)
-
-        assert len(sentences) == 1
+        assert isinstance(sentence, CString)
 
-        assert sentences[0] == self.sentences[0]
+        assert sentence == self.sentence
 
     def test_encode_works_when_prefix_and_suffix_tokens_are_specified(self) -> None:
         model = self.build_model(control_symbols=["<foo1>", "<foo2>", "<foo3>"])
@@ -151,45 +146,25 @@ def test_encode_works_when_prefix_and_suffix_tokens_are_specified(self) -> None:
         )
         decoder = SentencePieceDecoder(model)
 
-        indices = encoder(self.sentences[0])
+        indices = encoder(self.sentence)
 
         # Assert encoder.
         foo1_idx = model.token_to_index("<foo1>")
         foo2_idx = model.token_to_index("<foo2>")
         foo3_idx = model.token_to_index("<foo3>")
 
-        e = (
-            [foo1_idx, model.bos_idx]
-            + self.token_indices[0]
-            + [foo2_idx, model.eos_idx, foo3_idx]
-        )
+        bos_idx = model.bos_idx
+        eos_idx = model.eos_idx
+
+        e = [foo1_idx, bos_idx] + self.token_indices + [foo2_idx, eos_idx, foo3_idx]
 
         assert_equal(indices, e)
 
         # We expect the decoder to ignore the prefix and suffix tokens.
-        sentences = decoder(indices)
+        sentence = decoder(indices)
 
         # Assert decoder.
-        assert sentences[0] == self.sentences[0]
-
-    @pytest.mark.parametrize("dtype", [torch.int16, torch.int32, torch.int64])
-    def test_decode_works_when_input_is_batched(self, dtype: DataType) -> None:
-        model = self.build_model()
-
-        decoder = SentencePieceDecoder(model)
-
-        indices1 = torch.tensor(self.token_indices[0], device=device, dtype=dtype)
-        indices2 = torch.tensor(self.token_indices[1], device=device, dtype=dtype)
-
-        batch = torch.nn.utils.rnn.pad_sequence([indices1, indices2], batch_first=True)
-
-        sentences = decoder(batch)
-
-        assert isinstance(sentences, list)
-
-        assert len(sentences) == 2
-
-        assert sentences == self.sentences
+        assert sentence == self.sentence
 
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int8])
     def test_decode_raises_error_when_data_type_is_not_supported(
@@ -208,7 +183,7 @@ def test_decode_raises_error_when_data_type_is_not_supported(
             decoder(indices)
 
     @pytest.mark.parametrize("shape", [(), (4, 4, 4)])
-    def test_decode_raises_error_when_input_has_more_than_2_dimensions(
+    def test_decode_raises_error_when_input_has_more_than_1_dimension(
         self, shape: Sequence[int]
     ) -> None:
         model = self.build_model()
@@ -219,7 +194,7 @@
 
         with pytest.raises(
             ValueError,
-            match=rf"^The input tensor must be one or two dimensional, but has {len(shape)} dimensions instead\.$",
+            match=rf"^The input tensor must be one dimensional, but has {len(shape)} dimension\(s\) instead\.$",
        ):
             decoder(indices)
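--
A minimal usage sketch of the decoder API after this change. The model path
and the input sentence below are placeholders, not assets referenced by this
patch; the point is only that SentencePieceDecoder now takes a one-dimensional
index tensor and returns a single string rather than a list:

    from fairseq2.data.text import (
        SentencePieceDecoder,
        SentencePieceEncoder,
        SentencePieceModel,
    )

    model = SentencePieceModel("test.spm")  # placeholder path

    encoder = SentencePieceEncoder(model)
    decoder = SentencePieceDecoder(model)

    # Encode a sentence into a 1-D tensor of token indices.
    indices = encoder("Hello world!")

    # Before this change decoder(indices) returned List[StringLike];
    # it now returns a single StringLike (a CString) and raises
    # ValueError for batched (2-D) input.
    sentence = decoder(indices)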