diff --git a/.github/workflows/llm_bench-python.yml b/.github/workflows/llm_bench-python.yml index 0ac47d1aa0..77f26d33a0 100644 --- a/.github/workflows/llm_bench-python.yml +++ b/.github/workflows/llm_bench-python.yml @@ -66,28 +66,28 @@ jobs: python ./tools/llm_bench/benchmark.py -m tiny-random-qwen -d cpu -n 1 -f pt env: GIT_LFS_SKIP_SMUDGE: 0 - - name: Test tiny-random-baichuan2 on Linux + - name: Test tiny-random-baichuan2 on Linux Optimum Intel run: | optimum-cli export openvino --model katuni4ka/tiny-random-baichuan2 --trust-remote-code --weight-format fp16 ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16 - python ./tools/llm_bench/benchmark.py -m ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16/ -d cpu -n 1 - - name: Test tiny-stable-diffusion on Linux + python ./tools/llm_bench/benchmark.py -m ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16/ -d cpu -n 1 --optimum + - name: Test tiny-stable-diffusion on Linux Optimum Intel run: | optimum-cli export openvino --model segmind/tiny-sd --trust-remote-code --weight-format fp16 ./ov_models/tiny-sd/pytorch/dldt/FP16/ - python ./tools/llm_bench/benchmark.py -m ./ov_models/tiny-sd/pytorch/dldt/FP16/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 + python ./tools/llm_bench/benchmark.py -m ./ov_models/tiny-sd/pytorch/dldt/FP16/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --optimum - name: Test dreamlike-anime on Linux with GenAI run: | optimum-cli export openvino --model dreamlike-art/dreamlike-anime-1.0 --task stable-diffusion --weight-format fp16 ov_models/dreamlike-art-dreamlike-anime-1.0/FP16 - python ./tools/llm_bench/benchmark.py -m ./ov_models/dreamlike-art-dreamlike-anime-1.0/FP16/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --genai + python ./tools/llm_bench/benchmark.py -m ./ov_models/dreamlike-art-dreamlike-anime-1.0/FP16/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 - name: Test dreamlike-anime on Linux with GenAI and LoRA run: | wget -O ./ov_models/soulcard.safetensors https://civitai.com/api/download/models/72591 - python ./tools/llm_bench/benchmark.py -m ./ov_models/dreamlike-art-dreamlike-anime-1.0/FP16/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --genai --lora ./ov_models/soulcard.safetensors --lora_alphas 0.7 + python ./tools/llm_bench/benchmark.py -m ./ov_models/dreamlike-art-dreamlike-anime-1.0/FP16/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --lora ./ov_models/soulcard.safetensors --lora_alphas 0.7 - name: Test TinyLlama-1.1B-Chat-v1.0 in Speculative Deconding mode on Linux run: | optimum-cli export openvino --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --trust-remote-code --weight-format fp16 ov_models/TinyLlama-1.1B-Chat-v1.0/FP16 optimum-cli export openvino --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --trust-remote-code --weight-format int8 ov_models/TinyLlama-1.1B-Chat-v1.0/INT8 - python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" -d cpu --draft_device cpu -n 1 --genai --assistant_confidence_threshold 0.4 - python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" 
-d cpu --draft_device cpu -n 1 --genai --num_assistant_tokens 5 + python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" -d cpu --draft_device cpu -n 1 --assistant_confidence_threshold 0.4 + python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" -d cpu --draft_device cpu -n 1 --num_assistant_tokens 5 - name: Test whisper-tiny on Linux run: | GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 --branch main --single-branch https://huggingface.co/datasets/facebook/multilingual_librispeech @@ -97,8 +97,8 @@ jobs: tar zxvf data/mls_polish/train/audio/3283_1447_000.tar.gz -C data/mls_polish/train/audio/3283_1447_000/ cd .. optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny ./ov_models/whisper-tiny + python ./tools/llm_bench/benchmark.py -m ./ov_models/whisper-tiny --media multilingual_librispeech/data/mls_polish/train/audio/3283_1447_000/3283_1447_000000.flac -d cpu -n 1 --optimum python ./tools/llm_bench/benchmark.py -m ./ov_models/whisper-tiny --media multilingual_librispeech/data/mls_polish/train/audio/3283_1447_000/3283_1447_000000.flac -d cpu -n 1 - python ./tools/llm_bench/benchmark.py -m ./ov_models/whisper-tiny --media multilingual_librispeech/data/mls_polish/train/audio/3283_1447_000/3283_1447_000000.flac -d cpu -n 1 --genai - name: WWB Tests run: | GIT_CLONE_PROTECTION_ACTIVE=false pip install -r ${{ env.WWB_PATH }}/requirements.txt diff --git a/tools/llm_bench/README.md b/tools/llm_bench/README.md index d3f643b58f..87f6e91271 100755 --- a/tools/llm_bench/README.md +++ b/tools/llm_bench/README.md @@ -161,11 +161,10 @@ For example, `--load_config config.json` as following will result in streams.num ## 6. Execution on CPU device -OpenVINO is by default built with [oneTBB](https://github.com/oneapi-src/oneTBB/) threading library, while Torch uses [OpenMP](https://www.openmp.org/). Both threading libraries have ['busy-wait spin'](https://gcc.gnu.org/onlinedocs/libgomp/GOMP_005fSPINCOUNT.html) by default. When running LLM pipeline on CPU device, there is threading overhead in the switching between inference on CPU with OpenVINO (oneTBB) and postprocessing (For example: greedy search or beam search) with Torch (OpenMP). +OpenVINO is by default built with the [oneTBB](https://github.com/oneapi-src/oneTBB/) threading library, while Torch uses [OpenMP](https://www.openmp.org/). Both threading libraries use ['busy-wait spin'](https://gcc.gnu.org/onlinedocs/libgomp/GOMP_005fSPINCOUNT.html) by default. When running an LLM pipeline on the CPU device, there is threading overhead in switching between inference on CPU with OpenVINO (oneTBB) and postprocessing (for example, greedy search or beam search) with Torch (OpenMP). The default benchmarking scenario uses OpenVINO GenAI, which implements its own postprocessing API without additional dependencies. **Alternative solutions** -1. Use --genai option which uses OpenVINO genai API instead of optimum-intel API. In this case postprocessing is executed with OpenVINO genai API. -2. Without --genai option which uses optimum-intel API, set environment variable [OMP_WAIT_POLICY](https://gcc.gnu.org/onlinedocs/libgomp/OMP_005fWAIT_005fPOLICY.html) to PASSIVE which will disable OpenMP 'busy-wait', and benchmark.py will limit the Torch thread number by default to avoid using CPU cores which is in 'busy-wait' by OpenVINO inference.
Users can also set the number with --set_torch_thread option. +1. With the --optimum option, which uses the Optimum Intel API, set the environment variable [OMP_WAIT_POLICY](https://gcc.gnu.org/onlinedocs/libgomp/OMP_005fWAIT_005fPOLICY.html) to PASSIVE, which disables the OpenMP 'busy-wait'; benchmark.py will then limit the number of Torch threads by default to avoid using CPU cores that are kept in 'busy-wait' by OpenVINO inference. Users can also set the thread number with the --set_torch_thread option. ## 7. Additional Resources diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py index d652c8b48f..fe5068b009 100644 --- a/tools/llm_bench/benchmark.py +++ b/tools/llm_bench/benchmark.py @@ -130,7 +130,8 @@ def get_argprser(): ) parser.add_argument('-od', '--output_dir', help='Save the input text and generated text, images to files') llm_bench_utils.model_utils.add_stateful_model_arguments(parser) - parser.add_argument("--genai", action="store_true", help="Use OpenVINO GenAI optimized pipelines for benchmarking") + parser.add_argument("--genai", action="store_true", help="[DEPRECATED] Use OpenVINO GenAI optimized pipelines for benchmarking. Enabled by default") + parser.add_argument("--optimum", action="store_true", help="Use Optimum Intel pipelines for benchmarking") parser.add_argument( "--lora", nargs='*', diff --git a/tools/llm_bench/llm_bench_utils/config_class.py b/tools/llm_bench/llm_bench_utils/config_class.py index 2f6cd95664..12385d2879 100644 --- a/tools/llm_bench/llm_bench_utils/config_class.py +++ b/tools/llm_bench/llm_bench_utils/config_class.py @@ -7,9 +7,7 @@ from optimum.intel.openvino import ( OVModelForCausalLM, OVModelForSeq2SeqLM, - OVStableDiffusionPipeline, - OVLatentConsistencyModelPipeline, - OVStableDiffusionXLPipeline, + OVDiffusionPipeline, OVModelForSpeechSeq2Seq ) from llm_bench_utils.ov_model_classes import OVMPTModel, OVLDMSuperResolutionPipeline, OVChatGLMModel @@ -22,19 +20,14 @@ 'falcon': AutoTokenizer, } +IMAGE_GEN_CLS = OVDiffusionPipeline + OV_MODEL_CLASSES_MAPPING = { 'decoder': OVModelForCausalLM, 't5': OVModelForSeq2SeqLM, 'blenderbot': OVModelForSeq2SeqLM, 'falcon': OVModelForCausalLM, 'mpt': OVMPTModel, - 'stable-diffusion-xl': OVStableDiffusionXLPipeline, - 'sdxl': OVStableDiffusionXLPipeline, - 'lcm-sdxl': OVStableDiffusionXLPipeline, - 'ssd-': OVStableDiffusionXLPipeline, - 'lcm-ssd-': OVStableDiffusionXLPipeline, - 'stable_diffusion': OVStableDiffusionPipeline, - 'lcm': OVLatentConsistencyModelPipeline, 'replit': OVMPTModel, 'codet5': OVModelForSeq2SeqLM, 'codegen2': OVModelForCausalLM, @@ -57,7 +50,7 @@ } USE_CASES = { - 'image_gen': ['stable-diffusion-', 'ssd-', 'deepfloyd-if', 'tiny-sd', 'small-sd', 'lcm-', 'sdxl', 'dreamlike'], + 'image_gen': ['stable-diffusion-', 'ssd-', 'tiny-sd', 'small-sd', 'lcm-', 'sdxl', 'dreamlike', 'flux'], 'speech2text': ['whisper'], 'image_cls': ['vit'], 'code_gen': ['replit', 'codegen2', 'codegen', 'codet5', "stable-code"], diff --git a/tools/llm_bench/llm_bench_utils/metrics_print.py b/tools/llm_bench/llm_bench_utils/metrics_print.py index de9d0126f8..73e83dc672 100644 --- a/tools/llm_bench/llm_bench_utils/metrics_print.py +++ b/tools/llm_bench/llm_bench_utils/metrics_print.py @@ -97,12 +97,18 @@ def print_stable_diffusion_infer_latency(iter_str, iter_data, stable_diffusion, prefix = f'[{iter_str}][P{prompt_idx}]' log.info(f"{prefix} First step of unet latency: {iter_data['first_token_latency']:.2f} ms/step, " f"other steps of unet latency: {iter_data['other_tokens_avg_latency']:.2f} ms/step",) - log.info(f"{prefix} Text
encoder latency: {stable_diffusion.get_text_encoder_latency():.2f} ms/step, " - f"unet latency: {stable_diffusion.get_unet_latency():.2f} ms/step, " - f"vae decoder latency: {stable_diffusion.get_vae_decoder_latency():.2f} ms/step, " - f"text encoder step count: {stable_diffusion.get_text_encoder_step_count()}, " - f"unet step count: {stable_diffusion.get_unet_step_count()}, " - f"vae decoder step count: {stable_diffusion.get_vae_decoder_step_count()}",) + has_text_encoder_time = stable_diffusion.get_text_encoder_step_count() != -1 + text_encoder_latency = f"{stable_diffusion.get_text_encoder_latency():.2f} ms/step" if has_text_encoder_time else "N/A" + log_str = ( + f"{prefix} Text encoder latency: {text_encoder_latency}, " + f"unet latency: {stable_diffusion.get_unet_latency():.2f} ms/step, " + f"vae decoder latency: {stable_diffusion.get_vae_decoder_latency():.2f} ms/step, ") + if has_text_encoder_time: + log_str += f"text encoder step count: {stable_diffusion.get_text_encoder_step_count()}, " + log_str += ( + f"unet step count: {stable_diffusion.get_unet_step_count()}, " + f"vae decoder step count: {stable_diffusion.get_vae_decoder_step_count()}") + log.info(log_str) def print_ldm_unet_vqvae_infer_latency(iter_num, iter_data, tms=None, warm_up=False, prompt_idx=-1): diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 6539bef232..f72557b6c5 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -95,6 +95,13 @@ def analyze_args(args): model_args['torch_compile_input_module'] = args.torch_compile_input_module model_args['media'] = args.media + optimum = args.optimum + + if optimum and args.genai: + raise RuntimeError("`--genai` and `--optimum` cannot be selected at the same time") + model_args["optimum"] = optimum + model_args["genai"] = not optimum + has_torch_compile_options = any([args.torch_compile_options is not None, args.torch_compile_options is not None, args.torch_compile_dynamic]) if model_args["torch_compile_backend"] is None and has_torch_compile_options: log.warning("torch.compile configuration options provided, but backend is not selected, openvino backend will be used") @@ -102,7 +109,6 @@ def analyze_args(args): model_args['convert_tokenizer'] = args.convert_tokenizer model_args['subsequent'] = args.subsequent model_args['output_dir'] = args.output_dir - model_args['genai'] = args.genai model_args['lora'] = args.lora model_args['lora_alphas'] = args.lora_alphas model_args["use_cb"] = args.use_cb @@ -135,7 +141,7 @@ def analyze_args(args): model_args['model_type'] = get_model_type(model_name, use_case, model_framework) model_args['model_name'] = model_name - if (args.use_cb or args.draft_model) and not args.genai: + if (args.use_cb or args.draft_model) and optimum: raise RuntimeError("Continuous batching mode supported only via OpenVINO GenAI") cb_config = None if args.cb_config: @@ -169,6 +175,11 @@ def get_use_case(model_name_or_path): config = json.loads(config_file.read_text()) except Exception: config = None + if (Path(model_name_or_path) / "model_index.json").exists(): + diffusers_config = json.loads((Path(model_name_or_path) / "model_index.json").read_text()) + pipe_type = diffusers_config.get("_class_name") + if pipe_type in ["StableDiffusionPipeline", "StableDiffusionXLPipeline", "StableDiffusion3Pipeline", "FluxPipeline", "LatentConsistencyModelPipeline"]: + return "image_gen", pipe_type.replace("Pipeline", "") if config is not None: for case, model_ids in
USE_CASES.items(): diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py index cf0d0d831c..9ebd1363e3 100644 --- a/tools/llm_bench/llm_bench_utils/ov_utils.py +++ b/tools/llm_bench/llm_bench_utils/ov_utils.py @@ -11,7 +11,7 @@ import json import types from llm_bench_utils.hook_common import get_bench_hook -from llm_bench_utils.config_class import OV_MODEL_CLASSES_MAPPING, TOKENIZE_CLASSES_MAPPING, DEFAULT_MODEL_CLASSES +from llm_bench_utils.config_class import OV_MODEL_CLASSES_MAPPING, TOKENIZE_CLASSES_MAPPING, DEFAULT_MODEL_CLASSES, IMAGE_GEN_CLS import openvino.runtime.opset13 as opset from transformers import pipeline @@ -171,11 +171,13 @@ def create_text_gen_model(model_path, device, **kwargs): if not model_path_existed: raise RuntimeError(f'==Failure ==: model path:{model_path} does not exist') else: - if kwargs.get("genai", False) and is_genai_available(log_msg=True): + if kwargs.get("genai", True) and is_genai_available(log_msg=True): if model_class not in [OV_MODEL_CLASSES_MAPPING[default_model_type], OV_MODEL_CLASSES_MAPPING["mpt"], OV_MODEL_CLASSES_MAPPING["chatglm"]]: log.warning("OpenVINO GenAI based benchmarking is not available for {model_type}. Will be switched to default benchmarking") else: + log.info("Selected OpenVINO GenAI for benchmarking") return create_genai_text_gen_model(model_path, device, ov_config, **kwargs) + log.info("Selected Optimum Intel for benchmarking") remote_code = False try: model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=False) @@ -295,23 +297,23 @@ def convert_ov_tokenizer(tokenizer_path): def create_image_gen_model(model_path, device, **kwargs): - default_model_type = DEFAULT_MODEL_CLASSES[kwargs['use_case']] - model_type = kwargs.get('model_type', default_model_type) - model_class = OV_MODEL_CLASSES_MAPPING[model_type] + model_class = IMAGE_GEN_CLS model_path = Path(model_path) ov_config = kwargs['config'] if not Path(model_path).exists(): raise RuntimeError(f'==Failure ==: model path:{model_path} does not exist') else: - if kwargs.get("genai", False) and is_genai_available(log_msg=True): + if kwargs.get("genai", True) and is_genai_available(log_msg=True): + log.info("Selected OpenVINO GenAI for benchmarking") return create_genai_image_gen_model(model_path, device, ov_config, **kwargs) + log.info("Selected Optimum Intel for benchmarking") start = time.perf_counter() ov_model = model_class.from_pretrained(model_path, device=device, ov_config=ov_config) end = time.perf_counter() from_pretrained_time = end - start log.info(f'From pretrained time: {from_pretrained_time:.2f}s') - return ov_model, from_pretrained_time, False + return ov_model, from_pretrained_time, False, None def get_genai_clip_text_encoder(model_index_data, model_path, device, ov_config): @@ -350,6 +352,51 @@ def get_genai_unet_model(model_index_data, model_path, device, ov_config): def create_genai_image_gen_model(model_path, device, ov_config, **kwargs): import openvino_genai + class PerfCollector: + def __init__(self) -> types.NoneType: + self.iteration_time = [] + self.start_time = time.perf_counter() + self.duration = -1 + + def __call__(self, step, latents): + self.iteration_time.append(time.perf_counter() - self.start_time) + self.start_time = time.perf_counter() + return False + + def reset(self): + self.iteration_time = [] + self.start_time = time.perf_counter() + self.duration = -1 + + def get_1st_unet_latency(self): + return self.iteration_time[0] * 1000 if len(self.iteration_time) > 0 else 0 + + def 
get_2nd_unet_latency(self): + return sum(self.iteration_time[1:]) / (len(self.iteration_time) - 1) * 1000 if len(self.iteration_time) > 1 else 0 + + def get_unet_latency(self): + return (sum(self.iteration_time) / len(self.iteration_time)) * 1000 if len(self.iteration_time) > 0 else 0 + + def get_vae_decoder_latency(self): + if self.duration != -1: + vae_time = self.duration - sum(self.iteration_time) + return vae_time * 1000 + return 0 + + def get_text_encoder_latency(self): + return -1 + + def get_text_encoder_step_count(self): + return -1 + + def get_unet_step_count(self): + return len(self.iteration_time) + + def get_vae_decoder_step_count(self): + return 1 + + callback = PerfCollector() + adapter_config = get_lora_config(kwargs.get("lora", None), kwargs.get("lora_alphas", [])) if adapter_config: ov_config['adapters'] = adapter_config @@ -393,7 +440,7 @@ def create_genai_image_gen_model(model_path, device, ov_config, **kwargs): end = time.perf_counter() log.info(f'Pipeline initialization time: {end - start:.2f}s') - return t2i_pipe, end - start, True + return t2i_pipe, end - start, True, callback def create_ldm_super_resolution_model(model_path, device, **kwargs): @@ -414,7 +461,7 @@ def create_ldm_super_resolution_model(model_path, device, **kwargs): def create_genai_speech_2_txt_model(model_path, device, **kwargs): import openvino_genai as ov_genai - if kwargs.get("genai", False) is False: + if kwargs.get("genai", True) is False: raise RuntimeError('==Failure the command line does not set --genai ==') if is_genai_available(log_msg=True) is False: raise RuntimeError('==Failure genai is not enable ==') @@ -442,11 +489,13 @@ def create_speech_2txt_model(model_path, device, **kwargs): if not model_path_existed: raise RuntimeError(f'==Failure ==: model path:{model_path} does not exist') else: - if kwargs.get("genai", False) and is_genai_available(log_msg=True): + if kwargs.get("genai", True) and is_genai_available(log_msg=True): if model_class not in [OV_MODEL_CLASSES_MAPPING[default_model_type]]: log.warning("OpenVINO GenAI based benchmarking is not available for {model_type}. 
Will be switched to default bencmarking") else: + log.info("Selected OpenVINO GenAI for benchmarking") return create_genai_speech_2_txt_model(model_path, device, **kwargs) + log.info("Selected Optimum Intel for benchmarking") start = time.perf_counter() ov_model = model_class.from_pretrained( model_path, diff --git a/tools/llm_bench/task/image_generation.py b/tools/llm_bench/task/image_generation.py index b6260568bf..f227898ef6 100644 --- a/tools/llm_bench/task/image_generation.py +++ b/tools/llm_bench/task/image_generation.py @@ -41,7 +41,7 @@ def collects_input_args(image_param, model_type, model_name): return input_args -def run_image_generation(image_param, num, image_id, pipe, args, iter_data_list, proc_id, mem_consumption): +def run_image_generation(image_param, num, image_id, pipe, args, iter_data_list, proc_id, mem_consumption, callback=None): set_seed(args['seed']) input_text = image_param['prompt'] input_args = collects_input_args(image_param, args['model_type'], args['model_name']) @@ -104,7 +104,7 @@ def run_image_generation(image_param, num, image_id, pipe, args, iter_data_list, stable_diffusion_hook.clear_statistics() -def run_image_generation_genai(image_param, num, image_id, pipe, args, iter_data_list, proc_id, mem_consumption): +def run_image_generation_genai(image_param, num, image_id, pipe, args, iter_data_list, proc_id, mem_consumption, callback=None): set_seed(args['seed']) input_text = image_param['prompt'] input_args = collects_input_args(image_param, args['model_type'], args['model_name']) @@ -125,9 +125,11 @@ def run_image_generation_genai(image_param, num, image_id, pipe, args, iter_data if num == 0 and args["output_dir"] is not None: for bs_idx, in_text in enumerate(input_text_list): llm_bench_utils.output_file.output_image_input_text(in_text, args, image_id, bs_idx, proc_id) + callback.reset() start = time.perf_counter() - res = pipe.generate(input_text, **input_args).data + res = pipe.generate(input_text, **input_args, callback=callback).data end = time.perf_counter() + callback.duration = end - start if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2: mem_consumption.end_collect_momory_consumption() max_rss_mem_consumption, max_shared_mem_consumption, max_uss_mem_consumption = mem_consumption.get_max_memory_consumption() @@ -155,7 +157,7 @@ def run_image_generation_genai(image_param, num, image_id, pipe, args, iter_data max_rss_mem=max_rss_mem_consumption, max_shared_mem=max_shared_mem_consumption, max_uss_mem=max_uss_mem_consumption, - stable_diffusion=None, + stable_diffusion=callback, prompt_idx=image_id ) metrics_print.print_generated(num, warm_up=(num == 0), generated=rslt_img_fn, prompt_idx=image_id) @@ -163,7 +165,7 @@ def run_image_generation_genai(image_param, num, image_id, pipe, args, iter_data def run_image_generation_benchmark(model_path, framework, device, args, num_iters, mem_consumption): - pipe, pretrain_time, use_genai = FW_UTILS[framework].create_image_gen_model(model_path, device, **args) + pipe, pretrain_time, use_genai, callback = FW_UTILS[framework].create_image_gen_model(model_path, device, **args) iter_data_list = [] input_image_list = get_image_prompt(args) if framework == "ov" and not use_genai: @@ -198,7 +200,7 @@ def run_image_generation_benchmark(model_path, framework, device, args, num_iter for image_id, image_param in enumerate(image_list): p_idx = prompt_idx_list[image_id] iter_timestamp[num][p_idx]['start'] = datetime.datetime.now().isoformat() - image_gen_fn(image_param, num, 
prompt_idx_list[image_id], pipe, args, iter_data_list, proc_id, mem_consumption) + image_gen_fn(image_param, num, prompt_idx_list[image_id], pipe, args, iter_data_list, proc_id, mem_consumption, callback) iter_timestamp[num][p_idx]['end'] = datetime.datetime.now().isoformat() prefix = '[warm-up]' if num == 0 else '[{}]'.format(num) log.info(f"{prefix}[P{p_idx}] start: {iter_timestamp[num][p_idx]['start']}, end: {iter_timestamp[num][p_idx]['end']}")
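
For reference, a minimal usage sketch of the behaviour changed by this patch (the model path is reused from the workflow steps above, and an installed openvino_genai package is assumed): benchmark.py now selects the OpenVINO GenAI pipeline by default, --optimum switches to the Optimum Intel pipeline, and the deprecated --genai flag is redundant on its own and rejected when combined with --optimum.

# GenAI pipeline (new default, no flag required)
python ./tools/llm_bench/benchmark.py -m ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16/ -d cpu -n 1
# Optimum Intel pipeline, with OpenMP busy-wait disabled as recommended in README section 6
OMP_WAIT_POLICY=PASSIVE python ./tools/llm_bench/benchmark.py -m ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16/ -d cpu -n 1 --optimum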