Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 15, 2025
1 parent e4cc078 commit b57a1c8
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,21 @@ def test_compare_to_transformers(self, model_arch):
transformers_outputs = transformers_model(**tokens, **decoder_inputs)
# Compare tensor outputs
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
gen_config = GenerationConfig(
max_new_tokens=10,
min_new_tokens=10,
num_beams=2,
do_sample=False,
eos_token_id=None,
)

set_seed(SEED)
generated_tokens = transformers_model.generate(**tokens, generation_config=gen_config)
set_seed(SEED)
ov_generated_tokens = ov_model.generate(**tokens, generation_config=gen_config)

self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))

del transformers_model
del ov_model

Expand Down Expand Up @@ -2355,12 +2370,12 @@ def test_compare_to_transformers(self, model_arch):

processor = get_preprocessor(model_id)
data = self._generate_random_audio_data()
features = processor.feature_extractor(data, return_tensors="pt")
pt_features = processor.feature_extractor(data, return_tensors="pt")
decoder_start_token_id = transformers_model.config.decoder_start_token_id
decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}

with torch.no_grad():
transformers_outputs = transformers_model(**features, **decoder_inputs)
transformers_outputs = transformers_model(**pt_features, **decoder_inputs)

for input_type in ["pt", "np"]:
features = processor.feature_extractor(data, return_tensors=input_type)
Expand All @@ -2373,6 +2388,21 @@ def test_compare_to_transformers(self, model_arch):
# Compare tensor outputs
self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3))

gen_config = GenerationConfig(
max_new_tokens=10,
min_new_tokens=10,
num_beams=2,
do_sample=False,
eos_token_id=None,
)

set_seed(SEED)
generated_tokens = transformers_model.generate(**pt_features, generation_config=gen_config)
set_seed(SEED)
ov_generated_tokens = ov_model.generate(**pt_features, generation_config=gen_config)

self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))

del transformers_model
del ov_model
gc.collect()
Expand Down

0 comments on commit b57a1c8

Please sign in to comment.