Skip to content

Commit

Permalink
Fixes part 2
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Dec 27, 2024
1 parent 7c9a6f5 commit 202b69e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions tests/python_tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def get_beam_search() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_beam_groups = 3
generation_config.num_beams = 6
generation_config.diversity_penalty = 1
generation_config.max_new_tokens = 30
generation_config.num_return_sequences = 3
generation_config.num_return_sequences = generation_config.num_beams
Expand All @@ -82,6 +83,7 @@ def get_beam_search_min_and_max_tokens() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_beam_groups = 3
generation_config.num_beams = 6
generation_config.diversity_penalty = 1
generation_config.min_new_tokens = 15
generation_config.max_new_tokens = 30
generation_config.num_return_sequences = 3
Expand All @@ -92,6 +94,7 @@ def get_beam_search_with_single_stop_string() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_beam_groups = 3
generation_config.num_beams = 6
generation_config.diversity_penalty = 1
generation_config.max_new_tokens = 50
generation_config.num_return_sequences = generation_config.num_beams
generation_config.stop_strings = {"open sour"} # expected match on "open source"
Expand All @@ -102,6 +105,7 @@ def get_beam_search_with_multiple_stop_strings() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_beam_groups = 3
generation_config.num_beams = 6
generation_config.diversity_penalty = 1
generation_config.max_new_tokens = 50
generation_config.num_return_sequences = generation_config.num_beams
generation_config.stop_strings = {".", "software", "Intel"}
Expand All @@ -112,6 +116,7 @@ def get_beam_search_with_multiple_stop_strings_no_match() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_beam_groups = 3
generation_config.num_beams = 6
generation_config.diversity_penalty = 1
generation_config.max_new_tokens = 30
generation_config.num_return_sequences = generation_config.num_beams
generation_config.stop_strings = {"Einstein", "sunny", "geothermal"}
Expand Down
1 change: 1 addition & 0 deletions tests/python_tests/test_continuous_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def get_beam_search_seq_len_300() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_beam_groups = 3
generation_config.num_beams = 6
generation_config.diversity_penalty = 1
generation_config.max_new_tokens = 300
generation_config.num_return_sequences = generation_config.num_beams
return generation_config
Expand Down
2 changes: 1 addition & 1 deletion tests/python_tests/test_kv_cache_eviction.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, t

def get_greedy_seq_len_300() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_return_sequences = 3
generation_config.max_new_tokens = 300
return generation_config

def get_beam_search_seq_len_300() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_beam_groups = 3
generation_config.num_beams = 6
generation_config.diversity_penalty = 1
generation_config.max_new_tokens = 300
generation_config.num_return_sequences = generation_config.num_beams
return generation_config
Expand Down

0 comments on commit 202b69e

Please sign in to comment.