From 115c349dd4bd752d5c0a56a4b83e8b0f06580457 Mon Sep 17 00:00:00 2001
From: Kaushik Ram Sadagopan
Date: Tue, 17 Oct 2023 13:05:51 -0700
Subject: [PATCH 1/2] Make padding_mask Optional in SequenceBatch.

---
 src/fairseq2/models/sequence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/fairseq2/models/sequence.py b/src/fairseq2/models/sequence.py
index cbafd9257..02729267c 100644
--- a/src/fairseq2/models/sequence.py
+++ b/src/fairseq2/models/sequence.py
@@ -37,7 +37,7 @@ class SequenceBatch:
     size, :math:`S` is the sequence length, and :math:`*` is any number of
     sequence-specific dimensions including none."""
 
-    padding_mask: PaddingMask
+    padding_mask: Optional[PaddingMask]
     """The padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where :math:`N` is
     the batch size and :math:`S` is the sequence length."""
 

From 5ca38aeb5b22f821be2860a33e6d8968aca13473 Mon Sep 17 00:00:00 2001
From: Kaushik Ram Sadagopan
Date: Tue, 17 Oct 2023 13:12:17 -0700
Subject: [PATCH 2/2] Make compute_num_tokens() in SequenceBatch to account for the corner case.

---
 src/fairseq2/models/sequence.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/fairseq2/models/sequence.py b/src/fairseq2/models/sequence.py
index 02729267c..e67a59451 100644
--- a/src/fairseq2/models/sequence.py
+++ b/src/fairseq2/models/sequence.py
@@ -51,6 +51,9 @@ def batch_size(self) -> int:
 
     def compute_num_tokens(self) -> Tensor:
         """Compute the number of tokens in this batch."""
+        if self.padding_mask is None:
+            return torch.full((), self.seqs.numel(), device=self.seqs.device)
+
         return self.padding_mask.seq_lens.sum()
 
 