Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept only a non-batched tensor in TextTokenDecoder #118

Merged
merged 2 commits into the base branch on Oct 31, 2023 (source/target branch names lost in page extraction)
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fairseq2n/python/src/fairseq2n/bindings/data/string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def_string(py::module_ &data_module)
return py::bytes(static_cast<std::string_view>(self));
})

.def(
"strip",
[](const immutable_string &self)
{
return rtrim(ltrim(self));
})
.def(
"lstrip",
[](const immutable_string &self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ waveform_to_fbank_converter::find_waveform(data_dict &dict)

if (waveform.dim() != 2)
throw_<std::invalid_argument>(
"The input waveform must be two dimensional, but has {} dimensions instead.", waveform.dim());
"The input waveform must be two dimensional, but has {} dimension(s) instead.", waveform.dim());

return waveform;
}
Expand Down
47 changes: 17 additions & 30 deletions fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class sp_decoder_op {
explicit
sp_decoder_op(const sp_decoder *decoder, const sp_processor *processor, at::Tensor &&tensor);

data_list &&
immutable_string &&
run() &&;

private:
Expand All @@ -50,26 +50,22 @@ class sp_decoder_op {
const sp_decoder *decoder_;
const sp_processor *processor_;
at::Tensor tensor_;
data_list sentences_{};
immutable_string sentence_{};
};

sp_decoder_op::sp_decoder_op(
const sp_decoder *decoder, const sp_processor *processor, at::Tensor &&tensor)
: decoder_{decoder}, processor_{processor}, tensor_{std::move(tensor)}
{
auto batch_size = static_cast<std::size_t>(tensor_.size(0));

sentences_.reserve(batch_size);
}
{}

data_list &&
immutable_string &&
sp_decoder_op::run() &&
{
tensor_ = tensor_.to(at::kCPU);

decode();

return std::move(sentences_);
return std::move(sentence_);
}

void
Expand Down Expand Up @@ -98,31 +94,25 @@ template <typename T>
void
sp_decoder_op::decode()
{
std::int64_t seq_len = tensor_.size(1);
std::int64_t seq_len = tensor_.size(0);

std::vector<std::string_view> tokens{};

tokens.reserve(static_cast<std::size_t>(seq_len));

auto tensor_data = tensor_.accessor<T, 2>();

for (std::int64_t i = 0; i < tensor_.size(0); ++i) {
tokens.clear();
auto tensor_data = tensor_.accessor<T, 1>();

for (std::int64_t j = 0; j < seq_len; j++) {
T token_idx = tensor_data[i][decoder_->reverse_ ? seq_len - 1 - j : j];
for (std::int64_t j = 0; j < seq_len; j++) {
T token_idx = tensor_data[decoder_->reverse_ ? seq_len - 1 - j : j];

auto token_idx_32bit = conditional_cast<std::int32_t>(token_idx);
auto token_idx_32bit = conditional_cast<std::int32_t>(token_idx);

std::string_view token = processor_->index_to_token(token_idx_32bit);
std::string_view token = processor_->index_to_token(token_idx_32bit);

tokens.push_back(token);
}

std::string sentence = processor_->decode(tokens);

sentences_.emplace_back(std::move(sentence));
tokens.push_back(token);
}

sentence_ = processor_->decode(tokens);
}

} // namespace detail
Expand All @@ -140,17 +130,14 @@ sp_decoder::operator()(data &&d) const

at::Tensor tensor = d.as_tensor();

if (tensor.dim() == 0 || tensor.dim() > 2)
if (tensor.dim() != 1)
throw_<std::invalid_argument>(
"The input tensor must be one or two dimensional, but has {} dimensions instead.", tensor.dim());

if (tensor.dim() == 1)
tensor = tensor.unsqueeze(0);
"The input tensor must be one dimensional, but has {} dimension(s) instead.", tensor.dim());

return decode(std::move(tensor));
}

data_list
immutable_string
sp_decoder::decode(at::Tensor &&tensor) const
{
return sp_decoder_op{this, model_->processor_.get(), std::move(tensor)}.run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "fairseq2n/api.h"
#include "fairseq2n/data/data.h"
#include "fairseq2n/data/immutable_string.h"

namespace fairseq2n {
namespace detail {
Expand All @@ -34,7 +35,7 @@ class FAIRSEQ2_API sp_decoder final {
operator()(data &&d) const;

private:
data_list
immutable_string
decode(at::Tensor &&tensor) const;

private:
Expand Down
3 changes: 3 additions & 0 deletions src/fairseq2/data/cstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __hash__(self) -> int:
def bytes(self) -> bytes:
"""Return a copy of this string as :class:`bytes`."""

def strip(self) -> "CString":
"""Return a copy of this string with no whitespace at the beginning and end."""

def lstrip(self) -> "CString":
"""Return a copy of this string with no whitespace at the beginning."""

Expand Down
4 changes: 2 additions & 2 deletions src/fairseq2/data/text/sentencepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import TYPE_CHECKING, List, Optional, Sequence, final
from typing import TYPE_CHECKING, Optional, Sequence, final

from torch import Tensor

Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, model: SentencePieceModel, reverse: bool = False) -> None:
...

@finaloverride
def __call__(self, token_indices: Tensor) -> List[StringLike]:
def __call__(self, token_indices: Tensor) -> StringLike:
...

else:
Expand Down
4 changes: 2 additions & 2 deletions src/fairseq2/data/text/text_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Optional

from torch import Tensor

Expand Down Expand Up @@ -92,7 +92,7 @@ class TextTokenDecoder(ABC):
"""Decodes sentences from token indices."""

@abstractmethod
def __call__(self, token_indices: Tensor) -> List[StringLike]:
def __call__(self, token_indices: Tensor) -> StringLike:
"""
:param token_indices:
The token indices to decode from.
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/generation/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _do_generate(
encoder_output, encoder_padding_mask, source_seq_len=source_seqs.size(1)
)

sentences = [self.token_decoder(b[0].seq)[0] for b in gen_output.results]
sentences = [self.token_decoder(b[0].seq) for b in gen_output.results]

return SequenceToTextOutput(
sentences, gen_output, encoder_output, encoder_padding_mask
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/data/audio/test_waveform_to_fbank_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
WaveformToFbankConverter,
)
from fairseq2.memory import MemoryBlock
from tests.common import assert_equal, device
from tests.common import assert_equal, device, python_devel_only

TEST_OGG_PATH: Final = Path(__file__).parent.joinpath("test.ogg")


@pytest.mark.skipif(
python_devel_only(),
reason="New fairseq2n API in Python-only installation. Skipping till v0.2.",
)
class TestWaveformToFbankConverter:
def test_call_works(self) -> None:
audio = self.get_audio()
Expand Down Expand Up @@ -185,7 +189,7 @@ def test_call_raises_error_when_waveform_is_not_two_dimensional(

with pytest.raises(
ValueError,
match=rf"^The input waveform must be two dimensional, but has {len(shape)} dimensions instead\.$",
match=rf"^The input waveform must be two dimensional, but has {len(shape)} dimension\(s\) instead\.$",
):
converter({"waveform": waveform, "sample_rate": 16000.0})

Expand Down
Loading
Loading