Revise LogitsProcessor

cbalioglu committed Nov 5, 2023
1 parent 83b57fd commit 01ee16c
Showing 8 changed files with 199 additions and 297 deletions.
2 changes: 1 addition & 1 deletion src/fairseq2/generation/__init__.py
@@ -7,7 +7,7 @@
 from fairseq2.generation.beam_search import BeamSearch as BeamSearch
 from fairseq2.generation.beam_search import StandardBeamSearch as StandardBeamSearch
 from fairseq2.generation.logits_processor import (
-    BannedSequenceLogitsProcessor as BannedSequenceLogitsProcessor,
+    BannedSequenceProcessor as BannedSequenceProcessor,
 )
 from fairseq2.generation.logits_processor import LogitsProcessor as LogitsProcessor
 from fairseq2.generation.sequence_generator import Hypothesis as Hypothesis
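With this change, downstream code imports the processor under its new name. A minimal before/after sketch:

    # New name, as re-exported above:
    from fairseq2.generation import BannedSequenceProcessor

    # Old name, removed by this commit:
    # from fairseq2.generation import BannedSequenceLogitsProcessor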
219 changes: 87 additions & 132 deletions src/fairseq2/generation/logits_processor.py
@@ -4,172 +4,127 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sys
 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, final
 
 import torch
 from torch import Tensor
+from torch.nn.functional import pad
 
-from fairseq2.data.text.text_tokenizer import TextTokenEncoder
-from fairseq2.data.typing import StringLike
-from fairseq2.typing import Device
+from fairseq2.typing import finaloverride
 
 
 class LogitsProcessor(ABC):
-    """Abstracte base class for updating scores in place"""
+    """Processes next-step probabilities during sequence generation."""
 
     @abstractmethod
-    def __call__(self, seqs: Tensor, lprobs: Tensor) -> None:
-        """Update next-step log probabilities inplace based on given token sequence.
-
+    def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:
+        """
         :param seqs:
-            The sequence of tokens generated in current beam search step.
-            :math:`(N,B,S)`, where :math:`N` is the batch size, :math:`B` is
-            the number of beams, and :math:`S` is the size of the sequence.
-        :param lprobs:
-            The next-step log probability of each vocabulary entry. *Shape:*
-            :math:`(N,B,V)`, where :math:`N` is the batch size, :math:`B` is
-            the number of beams, and :math:`V` is the size of the vocabulary.
-        :returns:
-            None
+            The sequences that are in process of being generated. *Shape:*
+            :math:`(N,S)`, where :math:`N` is the batch size and :math:`S` is
+            the sequence length generated so far.
+        :param probs:
+            The next-step probabilities of ``seqs``. *Shape:* :math:`(N,V)`,
+            where :math:`N` is the batch size and :math:`V` is the size of the
+            target vocabulary.
+        :param lprob:
+            If ``True``, ``probs`` contains log probabilities.
         """
 
 
-class BannedSequenceLogitsProcessor(LogitsProcessor):
-    """Processor used to penalize scores of multiple banned sequences of words."""
-
-    banned_tokens: Tensor
-    """Vector of shape (nb_banned_sequences, 1) containing last token of each sequence to ban."""
-
-    banned_prefix: Tensor
-    """Matrix of shape (nb_banned_sequences, max_banned_tokens_len - 1) padded with 0s on the left."""
-
-    banned_prefix_mask: Tensor
-    """mask of 0s and 1s based on each banned token sequence and max prefix len."""
-
-    max_prefix_len: int
-    """length of biggest banned sequence - 1."""
-
-    pad_idx: int
-    """Padding index used for encoding banned sequences."""
-
-    device: Device
-    """device used for all inner tensors."""
-
-    def __init__(self, banned_seqs: List[Tensor], pad_idx: int, device: Device) -> None:
-        """
-        :param banned_seqs:
-            list of token sequences to ban.
-        :param pad_idx:
-            padding index used for encoding banned sequences.
-        :param device:
-            device
-        """
-        if len(banned_seqs) == 0:
-            raise ValueError("`banned_seqs` should contain at least one element.")
-        if any([t.ndim != 1 for t in banned_seqs]):
-            raise ValueError(
-                "`banned_seqs` should contain only one dimensional tensors."
-            )
-
-        self.pad_idx = pad_idx
-        self.device = device
-
-        self.max_prefix_len = max([len(t) - 1 for t in banned_seqs])
-        self.banned_prefix = self._create_pad_tensor(
-            size=(len(banned_seqs), self.max_prefix_len)
-        )
-        self.banned_tokens = torch.empty(
-            size=(len(banned_seqs), 1), dtype=torch.int64, device=self.device
-        )
-        for i, seq in enumerate(banned_seqs):
-            if (len(seq)) > 1:
-                self.banned_prefix[i, -len(seq) + 1 :] = seq[:-1]
-            self.banned_tokens[i] = seq[-1]
-
-        self.banned_prefix_mask = torch.where(
-            self.banned_prefix == self.pad_idx, 0, 1
-        ).to(device=self.device)
-
-    def __call__(self, seqs: Tensor, lprobs: Tensor) -> None:
-        """Apply score penalty of banend tokens inplace"""
-        seqs = self._pad_left_short_sequence(seqs)
-
-        if self.max_prefix_len == 0:
-            lprobs[:, :, self.banned_tokens] = -torch.inf
-        else:
-            prefix_diff = (
-                seqs[:, :, -self.max_prefix_len :].unsqueeze(2)
-                * self.banned_prefix_mask
-                - self.banned_prefix
-            )
-            batch_idx, beam_idx, match_idx = (prefix_diff.sum(dim=-1) == 0).nonzero(
-                as_tuple=True
-            )
-            if len(batch_idx) > 0:
-                lprobs[batch_idx, beam_idx, self.banned_tokens[match_idx]] = -torch.inf
-
-    def _pad_left_short_sequence(self, tokens: Tensor) -> Tensor:
-        batch_size, beam_size, seq_len = tokens.shape
-        if seq_len < self.max_prefix_len:
-            tmp = self._create_pad_tensor(
-                size=(batch_size, beam_size, self.max_prefix_len)
-            )
-            tmp[:, :, -seq_len:] = tokens
-            tokens = tmp
-
-        return tokens
-
-    def _create_pad_tensor(self, size: Tuple[int, ...]) -> Tensor:
-        return torch.full(
-            size=size,
-            fill_value=self.pad_idx,
-            dtype=torch.int64,
-            device=self.device,
-        )
-
-    # This is not the best place but the whole file needs a refactoring
-    # We need target decoder to create this tensor
-    @staticmethod
-    def compute_banned_words_seqs(
-        banned_strings: Sequence[StringLike],
-        token_encoder: TextTokenEncoder,
-    ) -> List[Tensor]:
-        """Compute sequences of tokens to ban from encoder and banned strings
-        :params banned_strings:
-            The list of strings to ban in sequence generation.
-        :params token_encoder:
-            Encoder to use for tokenizing input strings.
-        :returns:
-            List of token sequences to ban.
-        """
-        if not banned_strings:
-            return []
-
-        control_tokens = BannedSequenceLogitsProcessor._concat_optional_tensors(
-            [token_encoder.prefix_indices, token_encoder.suffix_indices]
-        )
-
-        def encode(s: StringLike) -> torch.Tensor:
-            seq = token_encoder(s)
-            if control_tokens is None:
-                return seq
-
-            mask = torch.isin(seq, control_tokens, invert=True)
-            return seq[mask]
-
-        return [encode(x) for x in banned_strings]
-
-    @staticmethod
-    def _concat_optional_tensors(tensors: List[Optional[Tensor]]) -> Optional[Tensor]:
-        not_none_tensors = [t for t in tensors if t is not None]
-
-        result: Optional[Tensor] = None
-        if len(not_none_tensors) > 0:
-            result = torch.cat(not_none_tensors).unique()
-
-        return result
+@final
+class BannedSequenceProcessor(LogitsProcessor):
+    """Prevents a provided list of banned sequences from being generated."""
+
+    _banned_seqs: Optional[Tensor]
+    _banned_mask: Optional[Tensor]
+
+    def __init__(self, banned_seqs: Sequence[Tensor]) -> None:
+        """
+        :param banned_seqs:
+            The list of banned sequences.
+        """
+        batch_size = len(banned_seqs)
+
+        if batch_size == 0:
+            self._banned_seqs = None
+            self._banned_mask = None
+
+            return
+
+        max_seq_len = 0
+        min_seq_len = sys.maxsize
+
+        seq_lens: List[int] = []
+
+        for idx, seq in enumerate(banned_seqs):
+            seq_len = len(seq)
+            if seq_len == 0:
+                raise ValueError(f"`banned_seqs[{idx}]` must not be empty.")
+
+            seq_lens.append(seq_len)
+
+            max_seq_len = max(seq_len, max_seq_len)
+            min_seq_len = min(seq_len, min_seq_len)
+
+        device = banned_seqs[0].device
+
+        # (N, S)
+        self._banned_seqs = torch.zeros(
+            (batch_size, max_seq_len), device=device, dtype=torch.int64
+        )
+
+        if max_seq_len != min_seq_len:
+            # (N, S)
+            self._banned_mask = torch.full(
+                (batch_size, max_seq_len), True, device=device
+            )
+        else:
+            self._banned_mask = None
+
+        for row, seq in enumerate(banned_seqs):
+            if self._banned_mask is None:
+                self._banned_seqs[row] = seq
+            else:
+                self._banned_seqs[row, -seq_lens[row] :] = seq
+                self._banned_mask[row, -seq_lens[row] :] = False
+
+    @finaloverride
+    def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:
+        if self._banned_seqs is None:
+            return
+
+        ban_value = -torch.inf if lprob else 0
+
+        banned_prefix_len = self._banned_seqs.size(1) - 1
+        if banned_prefix_len == 0:
+            probs[:, self._banned_seqs[:, 0]] = ban_value
+
+            return
+
+        if (len_delta := banned_prefix_len - seqs.size(1)) > 0:
+            # (N, S) -> (N, S_pre)
+            seqs = pad(seqs, (len_delta, 0))
+        elif len_delta < 0:
+            # (N, S) -> (N, S_pre)
+            seqs = seqs[:, -banned_prefix_len:]
+
+        # (N, S_pre) -> (N, 1, S_pre)
+        seqs = seqs.unsqueeze(1)
+
+        # (N, 1, S_pre) - (B, S_pre) -> (N, B, S_pre)
+        seqs = seqs - self._banned_seqs[:, :-1]
+
+        if self._banned_mask is not None:
+            seqs.masked_fill_(self._banned_mask[:, :-1], 0)
+
+        # (N, B, S_pre) -> (N, B)
+        banned_prefix_matches = seqs.sum(dim=-1)
+
+        # (N, B) -> (N), (B)
+        batch_indices, banned_indices = torch.where(banned_prefix_matches == 0)
+
+        if len(batch_indices) > 0:
+            probs[batch_indices, self._banned_seqs[:, -1][banned_indices]] = ban_value
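A minimal usage sketch of the revised processor (token IDs and vocabulary size are made up for illustration). Passing lprob=True tells the processor that the scores are log probabilities, so matching tokens are set to -inf rather than 0:

    import torch

    from fairseq2.generation import BannedSequenceProcessor

    # Two banned sequences of different lengths (hypothetical token IDs).
    processor = BannedSequenceProcessor([torch.tensor([5, 9, 2]), torch.tensor([7])])

    seqs = torch.tensor([[3, 5, 9]])  # (N, S): one partial hypothesis
    lprobs = torch.randn(1, 32)       # (N, V): next-step log probabilities

    processor(seqs, lprobs, lprob=True)

    # Token 2 is banned because the last two generated tokens (5, 9) match its
    # banned prefix; the single-token sequence [7] is banned unconditionally.
    assert lprobs[0, 2] == -torch.inf
    assert lprobs[0, 7] == -torch.inf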
5 changes: 2 additions & 3 deletions src/fairseq2/generation/sequence_generator.py
@@ -271,11 +271,10 @@ def __call__(
         if self.unk_idx is not None:
             lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
 
-        # update scores in place using logits_processor
+        # Update `lprobs` in-place if requested.
         if self.logits_processor is not None:
             self.logits_processor(
-                seqs.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
-                lprobs.view(num_searches, beam_size, -1),
+                seqs[:, : step_nr + 1], lprobs.squeeze(1), lprob=True
             )
 
         # Determine candidates for the next step.
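Note the new calling convention: the processor now receives flat (N, S) sequences and (N, V) scores instead of the previous (N, B, S) beam views. A hypothetical custom processor written against the revised interface (illustrative, not part of fairseq2):

    import torch
    from torch import Tensor

    from fairseq2.generation import LogitsProcessor

    class SuppressTokenProcessor(LogitsProcessor):
        """A toy processor that always suppresses a single token."""

        def __init__(self, token_idx: int) -> None:
            self.token_idx = token_idx

        def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:
            # With log probabilities, "impossible" is -inf rather than 0.
            probs[:, self.token_idx] = -torch.inf if lprob else 0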
18 changes: 9 additions & 9 deletions src/fairseq2/nn/utils/mask.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Optional, Tuple, cast
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -48,7 +48,8 @@ def compute_row_mask(
         might be smaller. The implementation also guarantees that there is
         always at least one unmasked element in each row.
     :param row_lens:
-        The length of each row.
+        The length of each row. *Shape:* :math:`(R)`, where :math:`R` is the
+        number of rows.
     :param min_num_spans:
         The minimum number of mask spans per row.
     :param device:
@@ -63,7 +64,7 @@
     # We only mask rows that are longer than the mask span length.
     if span_len >= max_row_len:
         raise ValueError(
-            f"The size of the second dimension of `shape` must be greater than {span_len}, but is {max_row_len} instead."
+            f"The size of the second dimension of `shape` must be greater than `span_len` ({span_len}), but is {max_row_len} instead."
         )
 
     row_lens = torch.full(
@@ -75,11 +76,10 @@
     # We only mask rows that are longer than the mask span length.
     if (span_len >= row_lens).any():
         raise ValueError(
-            f"All lengths in `row_lens` must be greater than {span_len}, but at least one length is smaller. row_lens: {row_lens}"
+            f"All lengths in `row_lens` must be greater than `span_len` ({span_len}), but at least one length is smaller. row_lens: {row_lens}"
         )
 
     indices = _compute_mask_spans(row_lens, span_len, max_mask_prob, min_num_spans)
-
     if indices is None:
         return row_lens.new_empty((0, 0))
@@ -92,7 +92,7 @@ def _compute_mask_spans(
     """Compute random mask spans of the specified shape."""
     device, dtype = row_lens.device, row_lens.dtype
 
-    num_rows = row_lens.size(0)
+    num_rows = len(row_lens)
     if num_rows == 0:
         return None
@@ -101,7 +101,7 @@
     num_spans_per_row = (max_mask_prob / span_len) * (row_lens - 1)
 
     # Require the same number of mask spans for all rows.
-    num_spans = cast(int, num_spans_per_row.type(dtype).min().item())
+    num_spans = int(num_spans_per_row.to(dtype).min())
 
     if min_num_spans > num_spans:
         raise ValueError(
@@ -129,7 +129,7 @@
     # The following ops convert the mask span offsets (i.e. start indices) to
    # mask spans (i.e. index ranges).
     # (R x N) -> (R, N)
-    span_offsets = span_offsets.type(dtype).view(num_rows, -1)
+    span_offsets = span_offsets.to(dtype).view(num_rows, -1)
 
     # (R, N) -> (R, N x L)
     span_offsets = repeat_interleave(span_offsets, dim=-1, repeat=span_len)
@@ -153,7 +153,7 @@ def _generate_mask(indices: Tensor, max_row_len: int) -> Tensor:
     # Since mask spans may overlap, rows might have varying number of masked
     # elements; therefore, we have to randomly unmask some of the elements to
     # ensure that all rows have the same amount of masking.
-    min_num_masked = cast(int, torch.count_nonzero(float_mask, dim=-1).min().item())
+    min_num_masked = int(torch.count_nonzero(float_mask, dim=-1).min())
 
     # We randomly pick `min_num_masked` masked elements from each row, which
     # effectively unmasks the remaining elements.
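For reference, a sketch of calling compute_row_mask with the parameters named in the docstring above; the exact signature is assumed from the diff context rather than shown by it:

    import torch

    from fairseq2.nn.utils.mask import compute_row_mask

    # Mask a (4, 100) batch with spans of 10 elements, masking at most ~65%
    # of each row; the last two rows are shorter than the full width.
    row_lens = torch.tensor([100, 100, 80, 60])

    mask = compute_row_mask(
        shape=(4, 100), span_len=10, max_mask_prob=0.65, row_lens=row_lens
    )

    # `mask` is a (4, 100) boolean tensor in which True marks masked positions.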