Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Fix subword_text_tokenizer to make it invertible again. This breaks e…
Browse files Browse the repository at this point in the history
…xisting models and vocabularies. Change criteria for which characters are parts of words and which are separators - we now consider unicode letters and numbers to be parts of words.

PiperOrigin-RevId: 160690718
  • Loading branch information
nshazeer authored and lukaszkaiser committed Jul 1, 2017
1 parent f3e5859 commit 98be812
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 151 deletions.
207 changes: 127 additions & 80 deletions tensor2tensor/data_generators/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,20 @@
# Dependency imports

import six
from six import PY2
from six.moves import xrange # pylint: disable=redefined-builtin
from tensor2tensor.data_generators import tokenizer

import tensorflow as tf


# Conversion between Unicode and UTF-8, if required (on Python2)
_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)


_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)


# Reserved tokens for things like padding and EOS symbols.
PAD = "<pad>"
EOS = "<EOS>"
Expand Down Expand Up @@ -162,15 +171,36 @@ def _load_vocab_from_file(self, filename):


class SubwordTextEncoder(TextEncoder):
"""Class for breaking tokens into subtokens.
"""Class for invertibly encoding text using a limited vocabulary.
Invertibly encodes a string as a sequence of subtokens from a limited
Invertibly encodes a native string as a sequence of subtokens from a limited
vocabulary.
A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
the corpus), and stored to a file. See text_encoder_build_subword.py.
It can then be loaded and used to encode/decode any text.
Encoding has four phases:
1. Tokenize into a list of tokens. Each token is a unicode string of either
all alphanumeric characters or all non-alphanumeric characters. We drop
tokens consisting of a single space that are between two alphanumeric
tokens.
2. Escape each token. This escapes away special and out-of-vocabulary
characters, and makes sure that each token ends with an underscore, and
has no other underscores.
3. Represent each escaped token as the concatenation of a list of subtokens
from the limited vocabulary. Subtoken selection is done greedily from
beginning to end. That is, we construct the list in order, always picking
the longest subtoken in our vocabulary that matches a prefix of the
remaining portion of the encoded token.
4. Concatenate these lists. This concatenation is invertible due to the
fact that the trailing underscores indicate when one list is finished.
"""

def __init__(self, filename=None, num_reserved_ids=2):
Expand All @@ -182,24 +212,26 @@ def __init__(self, filename=None, num_reserved_ids=2):
super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)

def encode(self, raw_text):
"""Converts a string to a list of subtoken ids.
"""Converts a native string to a list of subtoken ids.
Args:
raw_text: a string.
raw_text: a native string.
Returns:
a list of integers in the range [0, vocab_size)
"""
return self._tokens_to_subtokens(self._tokenizer.encode(raw_text))
return self._tokens_to_subtokens(self._tokenizer.encode(
_native_to_unicode(raw_text)))

def decode(self, subtokens):
"""Converts a sequence of subtoken ids to a string.
"""Converts a sequence of subtoken ids to a native string.
Args:
subtokens: a list of integers in the range [0, vocab_size)
Returns:
a string
a native string
"""
return self._tokenizer.decode(self._subtokens_to_tokens(subtokens))
return _unicode_to_native(self._tokenizer.decode(
self._subtokens_to_tokens(subtokens)))

@property
def vocab_size(self):
Expand Down Expand Up @@ -239,8 +271,8 @@ def subtoken_to_subtoken_string(self, subtoken):
if subtoken_string:
return subtoken_string
if 0 <= subtoken < self._num_reserved_ids:
return "%s_" % RESERVED_TOKENS[subtoken]
return "ID%d_" % subtoken
return u"%s_" % RESERVED_TOKENS[subtoken]
return u"ID%d_" % subtoken

def _escaped_token_to_subtokens(self, escaped_token):
"""Converts an escaped token string to a list of subtokens.
Expand All @@ -260,27 +292,11 @@ def _escaped_token_to_subtokens(self, escaped_token):
if subtoken != -1:
break
end -= 1
if end > pos:
ret.append(subtoken)
pos = end
else:
# No subtoken in the vocabulary matches escaped_token[pos].
# This can happen if the token contains a Unicode character
# that did not occur in the vocabulary training set.
# The id self.vocab_size - 1 is decoded as Unicode uFFFD,
# REPLACEMENT_CHARACTER.
ret.append(self.vocab_size - 1)
# Ensure that the outer loop continues
pos += 1
return ret
assert end > pos
ret.append(subtoken)
pos = end

@classmethod
def alphabet(cls, token_counts):
"""Return the set of Unicode characters that appear in the tokens."""
alphabet_set = set()
for token in six.iterkeys(token_counts):
alphabet_set |= set(token)
return alphabet_set
return ret

@classmethod
def build_to_target_size(cls,
Expand All @@ -304,17 +320,12 @@ def build_to_target_size(cls,
Returns:
a SubwordTextEncoder instance.
"""
# Calculate the alphabet, i.e. the set of all Unicode characters
# that appear in the tokens.
alphabet_set = cls.alphabet(token_counts)
tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))

def bisect(min_val, max_val):
"""Bisection to find the right size."""
present_count = (max_val + min_val) // 2
tf.logging.info("Trying min_count %d" % present_count)
subtokenizer = cls()
subtokenizer.build_from_token_counts(token_counts, alphabet_set,
subtokenizer.build_from_token_counts(token_counts,
present_count, num_iterations)
if min_val >= max_val or subtokenizer.vocab_size == target_size:
return subtokenizer
Expand All @@ -333,17 +344,29 @@ def bisect(min_val, max_val):

def build_from_token_counts(self,
token_counts,
alphabet_set,
min_count,
num_iterations=4):
"""Train a SubwordTextEncoder based on a dictionary of word counts.
Args:
token_counts: a dictionary of Unicode strings to int.
alphabet_set: the set of Unicode characters that appear in the tokens.
min_count: an integer - discard subtokens with lower counts.
num_iterations: an integer. how many iterations of refinement.
"""
# first determine the alphabet to include all characters with count at
# least min_count in the dataset.
char_counts = defaultdict(int)
for token, count in six.iteritems(token_counts):
for c in token:
char_counts[c] += count
self._alphabet = set()
for c, count in six.iteritems(char_counts):
if count >= min_count:
self._alphabet.add(c)
# Make sure all characters needed for escaping are included
for c in u"\\_;0123456789":
self._alphabet.add(c)

# We build iteratively. On each iteration, we segment all the words,
# then count the resulting potential subtokens, keeping the ones
# with high enough counts for our new vocabulary.
Expand All @@ -367,43 +390,36 @@ def build_from_token_counts(self,
for end in xrange(start + 1, len(escaped_token) + 1):
subtoken_string = escaped_token[start:end]
counts[subtoken_string] += count
# Make sure all characters needed for escaping are included
for c in self._alphabet:
counts[c] += min_count
# Array of sets of candidate subtoken strings, by length
len_to_subtoken_strings = []
for subtoken_string, count in six.iteritems(counts):
lsub = len(subtoken_string)
# All subtoken strings of length 1 are automatically included
# later, so we don't need to consider them here
if count < min_count or lsub <= 1:
continue
# Add this subtoken string to its length set
while len(len_to_subtoken_strings) <= lsub:
len_to_subtoken_strings.append(set())
len_to_subtoken_strings[lsub].add(subtoken_string)
if count >= min_count:
# Add this subtoken string to its length set
while len(len_to_subtoken_strings) <= lsub:
len_to_subtoken_strings.append(set())
len_to_subtoken_strings[lsub].add(subtoken_string)
new_subtoken_strings = []
# consider the candidates longest to shortest, so that if we accept
# a longer subtoken string, we can decrement the counts of its prefixes.
for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
for lsub in reversed(range(1, len(len_to_subtoken_strings))):
subtoken_strings = len_to_subtoken_strings[lsub]
for subtoken_string in subtoken_strings:
count = counts[subtoken_string]
if count < min_count:
continue
new_subtoken_strings.append((count, subtoken_string))
for l in xrange(1, len(subtoken_string)):
counts[subtoken_string[:l]] -= count
# Sort what we've got so far in decreasing order by count
if count >= min_count:
new_subtoken_strings.append((count, subtoken_string))
for l in xrange(1, lsub):
counts[subtoken_string[:l]] -= count
# Sort in decreasing order by count
new_subtoken_strings.sort(reverse=True)
# Add the alphabet set at the end of the vocabulary list
for char in alphabet_set:
new_subtoken_strings.append((0, char))
# Also include the Unicode REPLACEMENT CHARACTER to use
# when encountering previously unseen Unicode characters
# in the input (i.e. input external to the tokenizer training
# set, which may thus contain characters not in the alphabet_set).
# This must be the last entry in the subtoken vocabulary list.
new_subtoken_strings.append((0, u"\uFFFD"))
# Now we have a candidate vocabulary
old_alphabet = self._alphabet
self._init_from_list([u""] * self._num_reserved_ids +
[p[1] for p in new_subtoken_strings])
assert old_alphabet == self._alphabet
tf.logging.info("vocab_size = %d" % self.vocab_size)

original = "This sentence was encoded by the SubwordTextEncoder."
Expand All @@ -426,46 +442,77 @@ def _init_from_list(self, subtoken_strings):
self._all_subtoken_strings = subtoken_strings
self._subtoken_string_to_id = {
s: i for i, s in enumerate(subtoken_strings) if s}
self._alphabet = set([c for c in subtoken_strings if len(c) == 1])

def _load_from_file(self, filename):
"""Load from a file."""
subtoken_strings = []
with tf.gfile.Open(filename) as f:
for line in f:
if six.PY2:
subtoken_strings.append(line.strip()[1:-1].decode("utf-8"))
else:
subtoken_strings.append(line.strip()[1:-1])
subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
self._init_from_list(subtoken_strings)

def store_to_file(self, filename):
with tf.gfile.Open(filename, "w") as f:
for subtoken_string in self._all_subtoken_strings:
if six.PY2:
f.write("'" + subtoken_string.encode("utf-8") + "'\n")
else:
f.write("'" + subtoken_string + "'\n")
f.write("'" + _unicode_to_native(subtoken_string) + "'\n")

def _escape_token(self, token):
r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
r"""Escape away underscores and OOV characters and append '_'.
This allows the token to be expressed as the concatenation of a list
of subtokens from the vocabulary. The underscore acts as a sentinel
which allows us to invertibly concatenate multiple such lists.
Args:
token: a string
token: a unicode string
Returns:
escaped_token: a string
escaped_token: a unicode string
"""
return token.replace("\\", "\\\\").replace("_", "\\u") + "_"
token = token.replace("\\", "\\\\").replace("_", "\\u") + "_"
ret = u""
for c in token:
if c in self._alphabet:
ret += c
else:
ret += u"\\%d;" % ord(c)
return ret

def _unescape_token(self, escaped_token):
r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.
r"""Inverse of _escape_token().
Args:
escaped_token: a string
escaped_token: a unicode string
Returns:
token: a string
token: a unicode string
"""
assert escaped_token[-1] == "_"
return escaped_token[:-1].replace("\\u", "_").replace("\\\\", "\\")
ret = u""
escaped_token = escaped_token[:-1]
pos = 0
while pos < len(escaped_token):
c = escaped_token[pos]
if c == "\\":
pos += 1
c = escaped_token[pos]
if c == u"u":
ret += u"_"
pos += 1
elif c == "\\":
ret += u"_"
pos += 1
else:
semicolon_pos = escaped_token.find(u";", pos)
if semicolon_pos == -1:
continue
try:
ret += unichr(int(escaped_token[pos:semicolon_pos]))
pos = semicolon_pos + 1
except (ValueError, OverflowError) as _:
pass
else:
ret += c
pos += 1
return ret

@classmethod
def get_token_counts(cls, text_filepattern, corpus_max_lines):
Expand All @@ -477,7 +524,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
with tf.gfile.Open(text_filename) as f:
for line in f:
# The tokenizer updates token_counts in encode()
tok.encode(line.strip())
tok.encode(_native_to_unicode(line.strip()))
lines_read += 1
if corpus_max_lines > 0 and lines_read > corpus_max_lines:
return tok.token_counts
Expand Down
3 changes: 1 addition & 2 deletions tensor2tensor/data_generators/text_encoder_build_subword.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def main(unused_argv):
raise ValueError('Must provide --corpus_filepattern')
token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
FLAGS.corpus_filepattern, FLAGS.corpus_max_lines)
alphabet_set = text_encoder.SubwordTextEncoder.alphabet(token_counts)
gs.build_from_token_counts(token_counts, alphabet_set,
gs.build_from_token_counts(token_counts,
FLAGS.min_count,
FLAGS.num_iterations)
gs.store_to_file(FLAGS.output_fn)
Expand Down
Loading

1 comment on commit 98be812

@vthorsteinsson
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! 👍

Please sign in to comment.