
[WWB]: Add support for text-to-image GenAI pipelines #1052

Merged: 10 commits, Oct 24, 2024
64 changes: 63 additions & 1 deletion tools/who_what_benchmark/tests/test_cli_image.py
@@ -3,6 +3,7 @@
import shutil
import pytest
import logging
import tempfile


logging.basicConfig(level=logging.INFO)
@@ -42,9 +43,10 @@ def test_image_model_types(model_id, model_type, backend):
]
if backend == "hf":
wwb_args.append("--hf")
elif backend == "genai":
wwb_args.append("--genai")

result = run_wwb(wwb_args)
print(f"WWB result: {result}, {result.stderr}")

try:
os.remove(GT_FILE)
@@ -58,6 +60,64 @@ def test_image_model_types(model_id, model_type, backend):
assert "## Reference text" not in result.stderr


@pytest.mark.parametrize(
("model_id", "model_type"),
[
("echarlaix/tiny-random-stable-diffusion-xl", "text-to-image"),
],
)
def test_image_model_genai(model_id, model_type):
GT_FILE = "test_sd.json"
MODEL_PATH = tempfile.TemporaryDirectory().name

result = subprocess.run(["optimum-cli", "export",
"openvino", "-m", model_id,
MODEL_PATH], capture_output=True, text=True)
assert result.returncode == 0

wwb_args = [
"--base-model",
MODEL_PATH,
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
]
result = run_wwb(wwb_args)
assert result.returncode == 0

wwb_args = [
"--target-model",
MODEL_PATH,
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
"--genai",
]
result = run_wwb(wwb_args)

try:
os.remove(GT_FILE)
except OSError:
pass
shutil.rmtree("reference", ignore_errors=True)
shutil.rmtree("target", ignore_errors=True)
shutil.rmtree(MODEL_PATH, ignore_errors=True)

assert result.returncode == 0
assert "Metrics for model" in result.stderr
assert "## Reference text" not in result.stderr


@pytest.mark.parametrize(
("model_id", "model_type", "backend"),
[
@@ -84,6 +144,8 @@ def test_image_custom_dataset(model_id, model_type, backend):
]
if backend == "hf":
wwb_args.append("--hf")
elif backend == "genai":
wwb_args.append("--genai")

result = run_wwb(wwb_args)

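For context, a rough standalone sketch of the flow the new test_image_model_genai test exercises, run outside pytest. It assumes the wwb console entry point and an optimum-cli build with OpenVINO export support; the model id, file names, and flags mirror the test above.

import subprocess
import tempfile

model_id = "echarlaix/tiny-random-stable-diffusion-xl"   # tiny model used by the test
model_path = tempfile.mkdtemp()                          # persists until removed explicitly
gt_file = "test_sd.json"

# Export the model to OpenVINO IR, as the test does before scoring.
subprocess.run(["optimum-cli", "export", "openvino", "-m", model_id, model_path],
               check=True)

# 1) Build reference data with the base (Optimum) pipeline.
subprocess.run(["wwb", "--base-model", model_path, "--num-samples", "1",
                "--gt-data", gt_file, "--device", "CPU",
                "--model-type", "text-to-image"], check=True)

# 2) Score the same model through the GenAI pipeline against that reference.
subprocess.run(["wwb", "--target-model", model_path, "--num-samples", "1",
                "--gt-data", gt_file, "--device", "CPU",
                "--model-type", "text-to-image", "--genai"], check=True)

One detail worth noting: the test derives MODEL_PATH from tempfile.TemporaryDirectory().name, which deletes the directory as soon as the wrapper object is garbage collected; the export step recreates the path, whereas tempfile.mkdtemp (or keeping the TemporaryDirectory object alive) does not rely on that behavior.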
4 changes: 4 additions & 0 deletions tools/who_what_benchmark/whowhatbench/registry.py
@@ -35,3 +35,7 @@ def score(self, model, **kwargs):
@abstractmethod
def worst_examples(self, top_k: int = 5, metric="similarity"):
pass

@abstractmethod
def get_generation_fn(self):
raise NotImplementedError("generation_fn should be returned")
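Since get_generation_fn is now abstract on BaseEvaluator, every registered evaluator has to expose the callable it uses for generation. A minimal sketch of that obligation for a hypothetical subclass (the class and registry name are illustrative only; the abstract methods visible in this file are stubbed, and the real base class may require more):

from whowhatbench.registry import BaseEvaluator, register_evaluator

@register_evaluator("dummy")          # hypothetical name, not part of this PR
class DummyEvaluator(BaseEvaluator):
    def __init__(self, gen_fn=None):
        self.generation_fn = gen_fn   # whatever callable produces outputs

    def score(self, model, **kwargs):
        return {}, None               # stub

    def worst_examples(self, top_k: int = 5, metric="similarity"):
        return []                     # stub

    def get_generation_fn(self):
        # Callers can retrieve and reuse the generation callable directly.
        return self.generation_fn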
43 changes: 34 additions & 9 deletions tools/who_what_benchmark/whowhatbench/text2image_evaluator.py
@@ -5,6 +5,7 @@
from tqdm import tqdm
from transformers import set_seed
import torch
import openvino_genai

from .registry import register_evaluator, BaseEvaluator

@@ -26,6 +27,17 @@
}


class Generator(openvino_genai.Generator):
def __init__(self, seed, rng, mu=0.0, sigma=1.0):
openvino_genai.Generator.__init__(self)
self.mu = mu
self.sigma = sigma
self.rng = rng

def next(self):
return torch.normal(torch.tensor(self.mu), self.sigma, generator=self.rng)


@register_evaluator("text-to-image")
class Text2ImageEvaluator(BaseEvaluator):
def __init__(
@@ -41,6 +53,7 @@ def __init__(
num_samples=None,
gen_image_fn=None,
seed=42,
is_genai=False,
) -> None:
assert (
base_model is not None or gt_data is not None
@@ -57,17 +70,25 @@ def __init__(
self.similarity = ImageSimilarity(similarity_model_id)
self.last_cmp = None
self.gt_dir = os.path.dirname(gt_data)
self.generation_fn = gen_image_fn
self.is_genai = is_genai

if base_model:
base_model.resolution = self.resolution
self.gt_data = self._generate_data(
base_model, gen_image_fn, os.path.join(self.gt_dir, "reference")
)
else:
self.gt_data = pd.read_csv(gt_data, keep_default_na=False)

def get_generation_fn(self):
return self.generation_fn

def dump_gt(self, csv_name: str):
self.gt_data.to_csv(csv_name)

def score(self, model, gen_image_fn=None):
model.resolution = self.resolution
predictions = self._generate_data(
model, gen_image_fn, os.path.join(self.gt_dir, "target")
)
@@ -100,12 +121,13 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):

def _generate_data(self, model, gen_image_fn=None, image_dir="reference"):
if hasattr(model, "reshape") and self.resolution is not None:
model.reshape(
batch_size=1,
height=self.resolution[0],
width=self.resolution[1],
num_images_per_prompt=1,
)
if gen_image_fn is None:
model.reshape(
batch_size=1,
height=self.resolution[0],
width=self.resolution[1],
num_images_per_prompt=1,
)

def default_gen_image_fn(model, prompt, num_inference_steps, generator=None):
output = model(
@@ -118,7 +140,7 @@ def default_gen_image_fn(model, prompt, num_inference_steps, generator=None):
)
return output.images[0]

gen_image_fn = gen_image_fn or default_gen_image_fn
generation_fn = gen_image_fn or default_gen_image_fn

if self.test_data:
if isinstance(self.test_data, str):
@@ -144,13 +166,16 @@ def default_gen_image_fn(model, prompt, num_inference_steps, generator=None):

if not os.path.exists(image_dir):
os.makedirs(image_dir)

print(gen_image_fn)
for i, prompt in tqdm(enumerate(prompts), desc="Evaluate pipeline"):
set_seed(self.seed)
image = gen_image_fn(
rng = rng.manual_seed(self.seed)
image = generation_fn(
model,
prompt,
self.num_inference_steps,
generator=rng.manual_seed(self.seed),
generator=Generator(self.seed, rng) if self.is_genai else rng
)
image_path = os.path.join(image_dir, f"{i}.png")
image.save(image_path)
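The Generator wrapper above exists so that a GenAI pipeline draws its latent noise from the same torch RNG the Optimum/HF path uses, which keeps reference and target images comparable. A small standalone sketch of that idea (the class body mirrors the one added in this file; openvino_genai must be importable):

import torch
import openvino_genai

class Generator(openvino_genai.Generator):
    # Adapter that feeds torch-generated normal samples to a GenAI pipeline.
    def __init__(self, seed, rng, mu=0.0, sigma=1.0):
        openvino_genai.Generator.__init__(self)
        self.mu = mu
        self.sigma = sigma
        self.rng = rng

    def next(self):
        return torch.normal(torch.tensor(self.mu), self.sigma, generator=self.rng)

rng = torch.Generator(device="cpu")

# Re-seeding the shared torch generator reproduces the same noise stream,
# which is what _generate_data relies on when it calls rng.manual_seed(self.seed)
# before generating each prompt.
gen = Generator(42, rng.manual_seed(42))
first = [float(gen.next()) for _ in range(3)]

gen = Generator(42, rng.manual_seed(42))
second = [float(gen.next()) for _ in range(3)]
assert first == second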
4 changes: 4 additions & 0 deletions tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -121,6 +121,7 @@ def __init__(
self.generation_config = generation_config
self.generation_config_base = generation_config
self.seqs_per_request = seqs_per_request
self.generation_fn = gen_answer_fn
if self.generation_config is not None:
assert self.seqs_per_request is not None

@@ -151,6 +152,9 @@ def __init__(

self.last_cmp = None

def get_generation_fn(self):
return self.generation_fn

def dump_gt(self, csv_name: str):
self.gt_data.to_csv(csv_name)
