Skip to content

Commit

Permalink
Remove pad_with_end_token argument.
Browse files · Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jan 18, 2025
1 parent 2ac95a9 commit 8be7564
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 21 deletions.
4 changes: 2 additions & 2 deletions keras_hub/src/models/clip/clip_preprocessor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_preprocessor_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output={
"token_ids": [[5, 1, 2, 1, 3, 4, 0, 0]],
"token_ids": [[5, 1, 2, 1, 3, 4, 4, 4]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
},
)
Expand All @@ -39,7 +39,7 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x = preprocessor(input_data)
self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * 4)
self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 4, 4, 4, 4]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)

def test_sequence_length_override(self):
Expand Down
22 changes: 10 additions & 12 deletions keras_hub/src/models/clip/clip_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class CLIPTokenizer(BytePairTokenizer):
it should be the file path to merge rules. The merge rule file
should have one merge rule per line. Every merge rule contains
merge entities separated by a space.
pad_with_end_token: bool. Whether to pad the output with `end_token`.
Examples:
Expand All @@ -64,13 +63,17 @@ def __init__(
self,
vocabulary=None,
merges=None,
pad_with_end_token=False,
**kwargs,
):
self._add_special_token("<|startoftext|>", "start_token")
self._add_special_token("<|endoftext|>", "end_token")
self.pad_token_id = 0
self.pad_with_end_token = pad_with_end_token
self._add_special_token("<|endoftext|>", "pad_token")

        # TODO: Remove this in the future.
        # For backward compatibility, silently accept and discard the
        # deprecated `pad_with_end_token` arg if a caller still passes it.
if "pad_with_end_token" in kwargs:
kwargs.pop("pad_with_end_token")

super().__init__(
vocabulary=vocabulary,
Expand All @@ -81,8 +84,6 @@ def __init__(

def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if self.pad_with_end_token:
self.pad_token_id = self.end_token_id

def _bpe_merge_and_update_cache(self, tokens):
"""Process unseen tokens and add to cache."""
Expand Down Expand Up @@ -161,7 +162,9 @@ def process_unseen_tokens():
if self.sequence_length:
output_shape = tokens.shape.as_list()
output_shape[-1] = self.sequence_length
tokens = tokens.to_tensor(shape=output_shape)
tokens = tokens.to_tensor(
default_value=self.pad_token_id, shape=output_shape
)

# Convert to a dense output if input in scalar
if unbatched:
Expand Down Expand Up @@ -194,11 +197,6 @@ def detokenize(self, inputs):

def get_config(self):
config = super().get_config()
config.update(
{
"pad_with_end_token": self.pad_with_end_token,
}
)
# In the constructor, we pass the list of special tokens to the
# `unsplittable_tokens` arg of the superclass' constructor. Hence, we
# delete it from the config here.
Expand Down
6 changes: 0 additions & 6 deletions keras_hub/src/models/clip/clip_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ def test_tokenizer_basics(self):
expected_detokenize_output=["airplane", "airport"],
)

def test_pad_with_end_token(self):
init_kwargs = self.init_kwargs.copy()
init_kwargs["pad_with_end_token"] = True
tokenizer = CLIPTokenizer(**init_kwargs)
self.assertEqual(tokenizer.pad_token_id, tokenizer.end_token_id)

def test_errors_missing_special_tokens(self):
with self.assertRaises(ValueError):
CLIPTokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ def test_generate_preprocess(self):
self.assertIn("clip_l", x)
self.assertIn("clip_g", x)
self.assertAllEqual(x["clip_l"][0], [4, 0, 1, 3, 3, 3, 3, 3])
self.assertAllEqual(x["clip_g"][0], [4, 0, 1, 3, 0, 0, 0, 0])
self.assertAllEqual(x["clip_g"][0], [4, 0, 1, 3, 3, 3, 3, 3])

0 comments on commit 8be7564

Please sign in to comment.