diff --git a/models/turbine_models/custom_models/torchbench/README.md b/models/turbine_models/custom_models/torchbench/README.md
new file mode 100644
index 000000000..890f90704
--- /dev/null
+++ b/models/turbine_models/custom_models/torchbench/README.md
@@ -0,0 +1,70 @@
+# SHARK torchbench exports and benchmarks
+
+## Overview
+
+This directory contains scripts and utilities for running a suite of benchmarked inference tasks, demonstrating functional and performance parity between SHARK/IREE and native torch.compile workflows. It is currently under development, and benchmark numbers should not be treated as the best results achievable with the current state of IREE compiler optimizations.
+
+Eventually, we want this to plug into the upstream torchbench process by exposing the IREE methodology shown here as a compile/runtime backend for the torch benchmark classes. For now, it is set up for developers as a way to get preliminary results and establish blanket functionality for the models listed in export.py.
+
+The setup instructions provided here, in a few cases, use "gfx942" as the IREE/LLVM HIP target. This is the target for MI300X accelerators -- you can find a mapping of AMD GPU products to their LLVM target architectures [here](https://llvm.org/docs/AMDGPUUsage.html#amdgpu-architecture-table), and replace "gfx942" in the following documentation with your desired target.
+
+## Setup (docker)
+
+Use the provided dockerfile with the following build/run commands to execute in docker.
+These commands assume a few things about your machine/distro, so please read them and make sure they do what you want.
+
+```shell
+docker build --platform linux/amd64 --tag shark_torchbench --file shark_torchbench.dockerfile .
+```
+```shell
+docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v ./shark_torchbench_outputs:/SHARK-Turbine/models/turbine_models/custom_models/torchbench/outputs -w /SHARK-Turbine/models/turbine_models/custom_models/torchbench shark_torchbench:latest
+```
+```shell
+python3 ./export.py --target=gfx942 --device=rocm --compile_to=vmfb --performance --inference --precision=fp16 --float16 --external_weights=safetensors --external_weights_dir=./torchbench_weights/ --output_csv=./outputs/torchbench_results_SHARK.csv
+```
+
+## Setup (source)
+
+### Setup source code and prerequisites
+
+ - Install the torch+rocm nightly packages:
+```shell
+pip install torch==2.5.0.dev20240801+rocm6.1 torchvision==0.20.0.dev20240801+rocm6.1 torchaudio==2.4.0.dev20240801+rocm6.1 --index-url https://download.pytorch.org/whl/nightly/rocm6.1
+```
+ - Work around the amdsmi error in pre-release pytorch+rocm:
+```shell
+sudo apt install amd-smi-lib
+sudo chown -R $USER:$USER /opt/rocm/share/amd_smi
+python3 -m pip install /opt/rocm/share/amd_smi
+```
+ - Clone torch and expose its benchmarking code as a relative module:
+```shell
+git clone https://github.com/pytorch/pytorch
+cd pytorch/benchmarks
+touch __init__.py
+cd ../..
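+# The empty __init__.py makes pytorch/benchmarks importable as a module path;
+# export.py's pytorch.benchmarks.dynamo imports depend on this.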
+``` + - Clone and install pytorch benchmark modules: +```shell +git clone https://github.com/pytorch/benchmark +cd benchmark +python3 install.py --models BERT_pytorch Background_Matting LearningToPaint alexnet dcgan densenet121 hf_Albert hf_Bart hf_Bert hf_GPT2 hf_T5 mnasnet1_0 mobilenet_v2 mobilenet_v3_large nvidia_deeprecommender pytorch_unet resnet18 resnet50 resnet50_32x4d shufflenet_v2_x1_0 squeezenet1_1 timm_nfnet timm_efficientnet timm_regnet timm_resnest timm_vision_transformer timm_vovnet vgg16 +pip install -e . +cd .. +``` + +### Export and compile + +```shell +python ./export.py --target=gfx942 --device=rocm --compile_to=vmfb --performance --inference --precision=fp16 --float16 --external_weights=safetensors --external_weights_dir=./torchbench_weights/ +``` + +### Example of manual benchmark using export and IREE runtime CLI (mobilenet_v3_large) + +```shell + python ./export.py --target=gfx942 --device=rocm --compile_to=vmfb --performance --inference --precision=fp16 --float16 --external_weights=safetensors --external_weights_dir=./torchbench_weights/ --model_id=mobilenet_v3_large + +iree-benchmark-module --module=generated/mobilenet_v3_large_256_fp16_gfx942.vmfb --input=@generated/mobilenet_v3_large_input0.npy --parameters=model=./torchbench_weights/mobilenet_v3_large_fp16.irpa --device=hip://0 --device_allocator=caching --function=main --benchmark_repetitions=10 +``` \ No newline at end of file diff --git a/models/turbine_models/custom_models/torchbench/__init__.py b/models/turbine_models/custom_models/torchbench/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/models/turbine_models/custom_models/torchbench/cmd_opts.py b/models/turbine_models/custom_models/torchbench/cmd_opts.py new file mode 100644 index 000000000..7166293d8 --- /dev/null +++ b/models/turbine_models/custom_models/torchbench/cmd_opts.py @@ -0,0 +1,159 @@ +import argparse +import os +from pathlib import Path + + +def path_expand(s): + return Path(s).expanduser().resolve() + + +def is_valid_file(arg): + if not os.path.exists(arg): + return None + else: + return arg + + +# Note: this is where command-line options for the scripts in this directory +# are defined along with their defaults. Thus, they should not be referenced +# within modelling or inference code, only at the entry point to the script. + +# We should consider separating out the options that are "model configs" from +# the options that control the compiler, runtime, and script behavior, +# when applicable, as the former would best be kept in a separate +# config or imported from huggingface. + +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +############################################################################## +# general options +############################################################################## + +p.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging Face auth token, if required", + default=None, +) +p.add_argument( + "--model_id", + type=str, + help="model ID as it appears in the torchbench models text file lists, or 'all' for batch export", + default="all", +) +p.add_argument( + "--model_lists", + type=Path, + nargs="*", + help="path to a JSON list of models to benchmark. One or more paths.", + default=["torchbench_models.json", "timm_models.json", "torchvision_models.json"], +) +p.add_argument( + "--external_weights_dir", + type=str, + default="", + help="Path to external weights file, for jobs with one weights filepath. 
When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from.", +) +p.add_argument( + "--vmfbs_dir", type=str, default="", help="path to vmfb containing compiled module" +) +p.add_argument( + "--benchmark", + type=str, + default=None, + help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.", +) +p.add_argument( + "--save_outputs", + type=str, + default=None, + help="A comma-separated list of submodel IDs for which to save output .npys for, or 'all' for all components.", +) +p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb") +p.add_argument( + "--external_weights", + type=str, + default="irpa", + choices=["safetensors", "irpa", "gguf", None], + help="Externalizes model weights from the torch dialect IR and its successors", +) +p.add_argument( + "--run_benchmark", + type=bool, + default=True, +) +p.add_argument( + "--num_iters", + type=int, + default=10, +) +p.add_argument( + "--output_csv", + type=str, + default="./benchmark_results.csv", +) + +############################################################################## +# Modeling and Export Options +# These options are used to control model defining parameters. +# These are MLIR - changing variables! If you change them, you will need +# to import/download and recompile the model. +############################################################################## + +p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") +p.add_argument( + "--precision", + type=str, + default="fp16", + help="Precision of Stable Diffusion weights and graph.", +) +p.add_argument( + "--decomp_attn", + default=False, + action="store_true", + help="Decompose attention at fx graph level", +) + +# See --external_weight_path and --external_weight_dir to specify where to save the model weights. + +p.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +p.add_argument( + "--input_mlir", + type=str, + default=None, + help="Path to input mlir file to compile. Comma-separate paths to provide more than one input to pipelines.", +) + + +############################################################################## +# IREE Compiler Options +############################################################################## + +p.add_argument( + "--device", + type=str, + default="local-task", + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) +p.add_argument( + "--target", + type=str, + default="gfx942", + help="Usually a rocm chip arch or llvmcpu target triple, e.g. gfx942 or x86_64-linux-gnu.", +) +p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options") +p.add_argument( + "--attn_spec", + type=str, + default=None, + help="extra iree-compile options for models with sdpa ops.", +) + + +args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/torchbench/export.py b/models/turbine_models/custom_models/torchbench/export.py new file mode 100644 index 000000000..b311a865d --- /dev/null +++ b/models/turbine_models/custom_models/torchbench/export.py @@ -0,0 +1,457 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +import gc +import time + +from iree.compiler.ir import Context +from iree import runtime as ireert +import numpy as np +from shark_turbine.aot import * +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) +from turbine_models.custom_models.torchbench import utils +import torch +import torch._dynamo as dynamo +from huggingface_hub import hf_hub_download +from safetensors import safe_open +import argparse +from turbine_models.turbine_tank import turbine_tank +from turbine_models.model_runner import vmfbRunner + +from pytorch.benchmarks.dynamo.common import parse_args +from pytorch.benchmarks.dynamo.torchbench import ( + TorchBenchmarkRunner, + setup_torchbench_cwd, +) + +import csv + +torchbench_models_all = { + # "BERT_pytorch": { + # "dim": 128, + # }, # Dynamo Export Issue + # "Background_Matting": { + # "dim": 16, + # }, # Transpose Bubbling Pattern Failed + "LearningToPaint": { + "dim": 1024, + }, + "alexnet": { + "dim": 1024, + }, + "densenet121": { + "dim": 64, + }, + # "hf_Albert": {"dim": 32, "buffer_prefix": "albert"}, + # "hf_Bart": { + # "dim": 16, + # }, + # "hf_Bert": { + # "dim": 16, + # "buffer_prefix": "bert" + # }, + # "hf_GPT2": { + # "dim": 16, + # "buffer_prefix": "gpt2" + # }, + # "hf_T5": { + # "dim": 4, + # "buffer_prefix": "t5" + # }, + "mnasnet1_0": { + "dim": 256, + }, + "mobilenet_v2": { + "dim": 128, + }, + "mobilenet_v3_large": { + "dim": 256, + }, + # "nvidia_deeprecommender": { + # "dim": 1024, + # }, + "pytorch_unet": { + "dim": 8, + }, + "resnet18": { + "dim": 512, + }, + "resnet50": { + "dim": 128, + }, + "resnext50_32x4d": { + "dim": 128, + }, + "shufflenet_v2_x1_0": { + "dim": 512, + }, + "squeezenet1_1": { + "dim": 512, + }, + # "timm_nfnet": { + # "dim": 256, + # }, + "timm_efficientnet": { + "dim": 128, + }, + "timm_regnet": { + "dim": 128, + }, + "timm_resnest": { + "dim": 256, + }, + # "timm_vision_transformer": { + # "dim": 256, + # "decomp_attn": True, + # }, + "timm_vovnet": { + "dim": 128, + }, + # "vgg16": { + # "dim": 128, + # }, +} + + +# Adapted from pytorch.benchmarks.dynamo.common.main() +def get_runner(tb_dir, tb_args): + if tb_dir: + os.chdir(tb_dir) + runner = TorchBenchmarkRunner() + runner.args = parse_args(tb_args) + runner.setup_amp() + runner.model_iter_fn = runner.forward_pass + return runner + + +def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args, get_baseline=False): + runner = get_runner(tb_dir, tb_args) + _, model_name, model, forward_args, _ = runner.load_model( + "cuda:0", + model_id, + batch_size=batch_size, + ) + match get_baseline: + case True: + start_t = time.time() + res = runner.forward_pass(model, forward_args, collect_outputs=True) + baseline = time.time() - start_t + return model_name, model, forward_args, res, baseline + case False: + return model_name, model, forward_args + + +""" +Imports models from torchbench model tooling, exports them with turbine AOT, and does simple benchmarking. 
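+
+Rough flow per model: get_model_and_inputs() loads the model and example inputs
+through the torchbench runner, benchmark_torchbench_model() exports the model
+with FxProgramsBuilder and compiles it via utils.compile_to_vmfb, and
+run_benchmark() runs the resulting vmfb with vmfbRunner and appends a row to
+the output CSV.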
+""" + + +@torch.no_grad() +def benchmark_torchbench_model( + model_id, + tb_dir, + tb_args, + precision, + batch_size=1, + compile_to="vmfb", + external_weights=None, + external_weights_dir=None, + device=None, + target=None, + ireec_flags=None, + decomp_attn=False, + exit_on_vmfb=False, + attn_spec=None, + input_mlir=None, + weights_only=False, + upload_ir=False, + compare_vs_eager=False, +): + static_dim = torchbench_models_dict[model_id]["dim"] + dtype = torch.float16 if precision == "fp16" else torch.float32 + np_dtype = "float16" if precision == "fp16" else "float32" + safe_name = utils.create_safe_name( + model_id, + f"_{static_dim}_{precision}", + ) + safe_name = os.path.join("generated", safe_name) + if decomp_attn: + safe_name += "_decomp_attn" + + if not os.path.exists("generated"): + os.mkdir("generated") + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + if compare_vs_eager: + model_name, model, forward_args, golden, baseline = get_model_and_inputs( + model_id, batch_size, tb_dir, tb_args, get_baseline=True + ) + else: + model_name, model, forward_args = get_model_and_inputs( + model_id, batch_size, tb_dir, tb_args + ) + golden = None + baseline = None + + if dtype == torch.float16: + model = model.half() + model.to("cuda:0") + + if not isinstance(forward_args, dict): + forward_args = [i.type(dtype) for i in forward_args] + for idx, i in enumerate(forward_args): + np.save( + os.path.join("generated", f"{model_id}_input{idx}"), + i.clone().detach().cpu(), + ) + else: + for idx, i in enumerate(forward_args.values()): + np.save(f"{model_id}_input{idx}", i.clone().detach().cpu()) + + mapper = {} + if external_weights_dir is not None: + if not os.path.exists(external_weights_dir): + os.mkdir(external_weights_dir) + external_weight_path = os.path.join( + external_weights_dir, f"{model_id}_{precision}.irpa" + ) + else: + external_weight_path = None + + decomp_list = [torch.ops.aten.reflection_pad2d] + if decomp_attn == True or torchbench_models_dict[model_id].get("decomp_attn"): + print("decomposing attention for: " + model_id) + decomp_list.extend( + [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention, + torch.ops.aten.scaled_dot_product_attention, + ] + ) + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + if "hf" in model_id: + + class HF_M(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.mod = model + + def forward(self, inp): + return self.mod(**inp) + + if "Bart" not in model_id: + # In some transformers models, the position ids buffer is registered as non-persistent, + # which makes it fail to globalize in the FX import. + # Add them manually to the state dict here. 
+ + prefix = torchbench_models_dict[model_id]["buffer_prefix"] + getattr(model, prefix).embeddings.register_buffer( + "position_ids", + getattr(model, prefix).embeddings.position_ids, + persistent=True, + ) + fxb = FxProgramsBuilder(HF_M(model)) + + @fxb.export_program(args=(forward_args,)) + def _forward(module: HF_M(model), inputs): + return module(inputs) + + else: + fxb = FxProgramsBuilder(model) + + @fxb.export_program(args=(forward_args,)) + def _forward(module, inputs): + return module(*inputs) + + class CompiledTorchbenchModel(CompiledModule): + main = _forward + + if external_weights: + externalize_module_parameters(model) + save_module_parameters(external_weight_path, model) + + inst = CompiledTorchbenchModel(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + model.to("cpu") + del model + if compile_to != "vmfb": + return str(module) + else: + vmfb_path = utils.compile_to_vmfb( + str(module), + device, + target, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path, external_weight_path, forward_args, golden, baseline + + +def _run_iter(runner, inputs): + start = time.time() + res = runner.ctx.modules.compiled_torchbench_model["main"](*inputs) + return res, time.time() - start + + +def do_compare(shark_results, shark_latency, golden_results, golden_latency): + numerics_pass_fail = np.allclose( + shark_results.to_host(), + golden_results.clone().cpu().numpy(), + rtol=1e-4, + atol=1e-4, + ) + speedup = golden_latency / shark_latency + return speedup, numerics_pass_fail + + +def run_benchmark( + device, + vmfb_path, + weights_path, + example_args, + model_id, + csv_path, + iters, + golden=None, + baseline=None, +): + if "rocm" in device: + device = "hip" + device.split("rocm")[-1] + mod_runner = vmfbRunner(device, vmfb_path, weights_path) + inputs = torch_to_iree(mod_runner, example_args) + iter_latencies = [] + for i in range(iters): + results, iter_latency = _run_iter(mod_runner, inputs) + iter_latencies.append(iter_latency) + avg_latency = sum(iter_latencies) / len(iter_latencies) + it_per_sec = 1 / avg_latency + + if golden is not None and baseline is not None: + speedup, numerics_pass_fail = do_compare(results, avg_latency, golden, baseline) + else: + speedup, numerics_pass_fail = ("N/A", "N/A") + + needs_header = True + if os.path.exists(csv_path): + needs_header = False + with open(csv_path, "a") as csvfile: + fieldnames = [ + "model", + "avg_latency", + "avg_iter_per_sec", + "speedup_over_eager", + "numerics", + ] + data = [ + { + "model": model_id, + "avg_latency": avg_latency, + "avg_iter_per_sec": it_per_sec, + "speedup_over_eager": speedup, + "numerics": numerics_pass_fail, + } + ] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + if needs_header: + writer.writeheader() + writer.writerows(data) + print(data) + + +def torch_to_iree(iree_runner, example_args): + if isinstance(example_args, dict): + iree_args = [ + ireert.asdevicearray(iree_runner.config.device, i.clone().detach().cpu()) + for i in example_args.values() + ] + else: + iree_args = [ + ireert.asdevicearray(iree_runner.config.device, i.clone().detach().cpu()) + for i in example_args + ] + return iree_args + + +def run_main(model_id, args, tb_dir, tb_args): + print(f"exporting {model_id}") + mod_str, weights_path, example_args, golden, baseline = benchmark_torchbench_model( + model_id, + tb_dir, + tb_args, + precision=args.precision, + batch_size=args.batch_size, + compile_to=args.compile_to, + 
external_weights=args.external_weights,
+        external_weights_dir=args.external_weights_dir,
+        device=args.device,
+        target=args.target,
+        ireec_flags=args.ireec_flags,
+        decomp_attn=args.decomp_attn,
+        attn_spec=args.attn_spec,
+        input_mlir=args.input_mlir,
+        compare_vs_eager=args.compare_vs_torch,
+    )
+    if args.compile_to in ["torch", "mlir"]:
+        safe_name = utils.create_safe_name(
+            model_id,
+            f"_{torchbench_models_dict[model_id]['dim']}_{args.precision}",
+        )
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(mod_str)
+        print("Saved to", safe_name + ".mlir")
+    elif args.run_benchmark:
+        run_benchmark(
+            args.device,
+            mod_str,
+            weights_path,
+            example_args,
+            model_id,
+            args.output_csv,
+            args.num_iters,
+            golden,
+            baseline,
+        )
+
+    gc.collect()
+
+
+if __name__ == "__main__":
+    from turbine_models.custom_models.torchbench.cmd_opts import args, unknown
+    import json
+
+    # Merge every requested model list into one dict; skip lists that are
+    # missing or empty (e.g. timm_models.json is currently a placeholder).
+    torchbench_models_dict = {}
+    for model_list in args.model_lists:
+        if not os.path.exists(model_list) or os.path.getsize(model_list) == 0:
+            continue
+        with open(model_list, "r") as f:
+            torchbench_models_dict.update(json.load(f))
+
+    tb_dir = setup_torchbench_cwd()
+    if args.model_id.lower() == "all":
+        for name in torchbench_models_dict.keys():
+            run_main(name, args, tb_dir, unknown)
+    else:
+        run_main(args.model_id, args, tb_dir, unknown)
diff --git a/models/turbine_models/custom_models/torchbench/shark_torchbench.dockerfile b/models/turbine_models/custom_models/torchbench/shark_torchbench.dockerfile
new file mode 100644
index 000000000..d93eb7009
--- /dev/null
+++ b/models/turbine_models/custom_models/torchbench/shark_torchbench.dockerfile
@@ -0,0 +1,53 @@
+FROM rocm/dev-ubuntu-22.04:6.1.2
+
+# ######################################################
+# # Install system dependencies and build tools
+# ######################################################
+ENV DEBIAN_FRONTEND=noninteractive
+
+# apt dependencies
+RUN apt-get update && apt-get install -y \
+    ffmpeg libsm6 libxext6 git wget unzip \
+    software-properties-common git \
+    build-essential curl cmake ninja-build clang lld vim nano python3.10-dev python3.10-venv && \
+    apt-get clean && rm -rf /var/lib/apt/lists/*
+RUN pip install --upgrade pip setuptools wheel && \
+    pip install pybind11 'nanobind<2' numpy==1.* pandas && \
+    pip install hip-python hip-python-as-cuda -i https://test.pypi.org/simple
+
+# Rust requirements
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+SHELL ["/bin/bash", "-c"]
+
+# Disable apt-key parse warning
+ARG APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1
+
+######################################################
+# Install SHARK-Turbine
+######################################################
+RUN pip3 install torch==2.4.0+rocm6.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
+RUN pip3 install --pre iree-compiler==20240920.1022 iree-runtime==20240920.1022 -f https://iree.dev/pip-release-links.html
+
+RUN apt-get update && apt-get install -y amd-smi-lib && chown -R root:root /opt/rocm/share/amd_smi && python3 -m pip install /opt/rocm/share/amd_smi
+# Install turbine-models, where the export is implemented.
+
+ENV TB_SHARK_DIR=/SHARK-Turbine/models/turbine_models/custom_models/torchbench
+
+RUN git clone https://github.com/nod-ai/SHARK-Turbine -b torchbench \
+    && cd SHARK-Turbine \
+    && pip install --pre --upgrade -e models -r models/requirements.txt \
+    && cd $TB_SHARK_DIR \
+    && git clone https://github.com/pytorch/pytorch \
+    && cd pytorch/benchmarks \
+    && touch __init__.py && cd ../..
\ + && git clone https://github.com/pytorch/benchmark && cd benchmark \ + && python3 install.py --models BERT_pytorch Background_Matting LearningToPaint alexnet dcgan densenet121 hf_Albert hf_Bart hf_Bert hf_GPT2 hf_T5 mnasnet1_0 mobilenet_v2 mobilenet_v3_large nvidia_deeprecommender pytorch_unet resnet18 resnet50 resnet50_32x4d shufflenet_v2_x1_0 squeezenet1_1 timm_nfnet timm_efficientnet timm_regnet timm_resnest timm_vision_transformer timm_vovnet vgg16 \ + && pip install -e . + +ENV HF_HOME=/models/huggingface/ + +# initialization settings for CPX mode +ENV HSA_USE_SVM=0 +ENV HSA_XNACK=0 \ No newline at end of file diff --git a/models/turbine_models/custom_models/torchbench/timm_models.json b/models/turbine_models/custom_models/torchbench/timm_models.json new file mode 100644 index 000000000..e69de29bb diff --git a/models/turbine_models/custom_models/torchbench/torchbench_models.json b/models/turbine_models/custom_models/torchbench/torchbench_models.json new file mode 100644 index 000000000..46d4f06cf --- /dev/null +++ b/models/turbine_models/custom_models/torchbench/torchbench_models.json @@ -0,0 +1,50 @@ +{ + "LearningToPaint": { + "dim": 1024 + }, + "alexnet": { + "dim": 1024 + }, + "densenet121": { + "dim": 64 + }, + "mnasnet1_0": { + "dim": 256 + }, + "mobilenet_v2": { + "dim": 128 + }, + "mobilenet_v3_large": { + "dim": 256 + }, + "pytorch_unet": { + "dim": 8 + }, + "resnet18": { + "dim": 512 + }, + "resnet50": { + "dim": 128 + }, + "resnext50_32x4d": { + "dim": 128 + }, + "shufflenet_v2_x1_0": { + "dim": 512 + }, + "squeezenet1_1": { + "dim": 512 + }, + "timm_efficientnet": { + "dim": 128 + }, + "timm_regnet": { + "dim": 128 + }, + "timm_resnest": { + "dim": 256 + }, + "timm_vovnet": { + "dim": 128 + } +} \ No newline at end of file diff --git a/models/turbine_models/custom_models/torchbench/utils.py b/models/turbine_models/custom_models/torchbench/utils.py new file mode 100644 index 000000000..325fc0229 --- /dev/null +++ b/models/turbine_models/custom_models/torchbench/utils.py @@ -0,0 +1,506 @@ +from urllib.request import urlopen +import iree.compiler as ireec +import numpy as np +import os +import safetensors +import safetensors.numpy as safe_numpy +import safetensors.torch as safe_torch +import re +import glob + +# If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument. 
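+# For one-off experiments, extra flags can instead be passed on the command line, e.g.:
+#   python export.py ... --ireec_flags="--iree-opt-data-tiling=false,--iree-opt-const-eval=false"
+# compile_to_vmfb() splits --ireec_flags on commas and merges them with the
+# defaults below, letting the user-provided values take precedence.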
+MI_flags = { + "all": [ + "--iree-global-opt-propagate-transposes=true", + "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", + "--iree-dispatch-creation-enable-aggressive-fusion=true", + "--iree-opt-aggressively-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-opt-data-tiling=false", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-codegen-llvmgpu-use-vector-distribution", + "--iree-hip-waves-per-eu=2", + "--iree-execution-model=async-external", + ], + "preprocess_default": [], +} +GFX11_flags = { + "all": [ + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-opt-data-tiling=false", + "--iree-opt-const-eval=false", + "--iree-opt-aggressively-propagate-transposes=true", + "--iree-dispatch-creation-enable-aggressive-fusion", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-codegen-llvmgpu-use-vector-distribution=true", + ], + "preprocess_default": [ + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics)", + "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", + ], +} +znver4_flags = { + "all": [ + "--iree-llvmcpu-target-cpu=znver4", + "--iree-opt-const-eval=false", + "--iree-llvmcpu-enable-ukernels=mmt4d,pack,unpack", + "--iree-dispatch-creation-collapse-reduction-dims", + "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000", + "--iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops", + ], + "bf16": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-demote-contraction-inputs-to-bf16))", + ], + "winograd": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))" + ], +} + +_IREE_DRIVER_MAP = { + "cpu": "local-task", + "cpu-task": "local-task", + "cpu-sync": "local-sync", + "cuda": "cuda", + "vulkan": "vulkan", + "metal": "metal", + "rocm": "hip", + "rocm-legacy": "rocm", + "hip": "hip", + "intel-gpu": "level_zero", +} + +_IREE_BACKEND_MAP = { + "cpu": "llvm-cpu", + "local-task": "llvm-cpu", + "local-sync": "llvm-cpu", + "rocm": "rocm", + "rocm-legacy": "rocm", + "hip": "rocm", + "cuda": "cuda", + "vulkan": "vulkan-spirv", + "metal": "metal", +} + + +def iree_device_map(device): + uri_parts = device.split("://", 2) + iree_driver = ( + _IREE_DRIVER_MAP[uri_parts[0]] + if uri_parts[0] in _IREE_DRIVER_MAP + else uri_parts[0] + ) + if len(uri_parts) == 1: + return iree_driver + else: + return f"{iree_driver}://{uri_parts[1]}" + + +def iree_backend_map(device): + uri_parts = device.split("://", 2) + iree_device = ( + _IREE_BACKEND_MAP[uri_parts[0]] + if uri_parts[0] in _IREE_BACKEND_MAP + else uri_parts[0] + ) + return iree_device + + +def replace_with_tk_kernels(tk_kernels_dir, flow_dialect_ir, batch_size): + kernels = glob.glob(tk_kernels_dir + "/bs" + str(batch_size) + "/*") + + # Replace all calls to old kernel with new kernel + print("Inserting kernels and updating calls to kernels...") + kernel_name = {} + for kernel in kernels: + kernel_name[kernel] = kernel.split("/")[-1].split(".")[0] + kernel_map = {} + prefix_map = {} + + base = flow_dialect_ir.split("\n") + new_base = [] + for line 
in base: + for kernel in kernels: + suffix = kernel.split(".")[0].split("_")[-1] + if "bias" in suffix: + suffix = kernel.split(".")[0].split("_")[-2] + B, M, N, K = suffix.split("x") + old_kernel = f"matmul_like_{B}x{M}x{N}x{K}" + if not old_kernel in line: + continue + if old_kernel in line and "func.func" in line: + num_args = line.count("arg") + with open(kernel, "r") as f: + data = f.readlines() + idx_with_kernel_args = [ + idx for idx, s in enumerate(data) if "func.func" in s + ][0] + kernel_args = data[idx_with_kernel_args].count("arg") + if num_args != kernel_args: + continue + kernel_map[kernel] = line.strip().split(" ")[1][1:-7] + prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-1] + if ( + old_kernel in line + and "flow.dispatch" in line + and not "func.func" in line + ): + line = line.replace(kernel_map[kernel], kernel_name[kernel]) + line = line.replace(prefix_map[kernel], kernel_name[kernel]) + new_base.append(line) + # Insert kernels in appropriate locations + final_ir = [] + for line in new_base: + for kernel in kernels: + if ( + prefix_map[kernel] + " {" in line + and "flow.executable" in line + and "private" in line + ): + with open(kernel, "r") as f: + data = f.readlines() + translation_info = data[0].split("#translation = ")[1].strip() + extract = "".join(data[2:-2]) + extract = extract.replace("#translation", translation_info) + final_ir += extract + final_ir.append(line) + + print("tk kernels added") + return final_ir + + +def compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags=[""], + safe_name="model", + return_path=False, + const_expr_hoisting=True, + mlir_source="str", + max_alloc="4294967296", + save_mlir=True, + attn_spec=None, + winograd=False, + flagset_keywords=[], + debug=False, + add_tk_kernels=False, + tk_kernels_dir=None, + batch_size=1, +): + if ireec_flags is not None and "masked_attention" in ireec_flags: + flagset_keywords = ["masked_attention"] + ireec_flags = "".join(ireec_flags.split("masked_attention")) + masked_attention = True + else: + masked_attention = False + if ireec_flags is not None and "winograd" in ireec_flags: + winograd = True + ireec_flags = "".join(ireec_flags.split("winograd")) + if batch_size != 1 and batch_size != 8: + add_tk_kernels = False + flags = [] + if mlir_source == "file" and not isinstance(module_str, str): + module_str = str(module_str) + if target_triple in ["", None]: + if device == "cpu": + target_triple = "x86_64-linux-gnu" + else: + raise ValueError( + "target_triple must be set. Usually this can be fixed by setting --iree_target_triple in the CLI." 
+ ) + if device in ["cpu", "llvm-cpu"]: + if target_triple == "znver4": + flags.extend(znver4_flags["all"]) + if winograd: + flags.extend(znver4_flags["winograd"]) + else: + flags.extend( + [ + "--iree-llvmcpu-target-triple=" + target_triple, + "--iree-llvmcpu-target-cpu-features=host", + "--iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false", + "--iree-llvmcpu-distribution-size=32", + "--iree-opt-const-eval=false", + "--iree-llvmcpu-enable-ukernels=all", + "--iree-global-opt-enable-quantized-matmul-reassociation", + ] + ) + device = "llvm-cpu" + elif device in ["vulkan", "vulkan-spirv"]: + flags.extend( + [ + "--iree-hal-target-backends=vulkan-spirv", + "--iree-vulkan-target-triple=" + target_triple, + "--iree-stream-resource-max-allocation-size=" + max_alloc, + "--iree-stream-resource-index-bits=64", + "--iree-vm-target-index-bits=64", + "--iree-dispatch-creation-inline-constants-max-byte-length=1", + ] + ) + device = "vulkan-spirv" + elif device in ["rocm", "hip"]: + flags.extend( + [ + "--iree-hal-target-backends=rocm", + "--iree-hip-target=" + target_triple, + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ] + ) + elif device == "cuda": + flags.extend( + [ + "--iree-hal-target-backends=cuda", + "--iree-hal-cuda-llvm-target-arch=" + target_triple, + "--iree-vm-target-truncate-unsupported-floats", + ] + ) + else: + print("incorrect device: ", device) + if isinstance(ireec_flags, str): + if ireec_flags != "": + ireec_flags = ireec_flags.split(",") + elif ireec_flags == None: + ireec_flags = [] + + if debug: + flags.extend( + ["--iree-hal-dump-executable-files-to=" + safe_name + "_dispatches"] + ) + + if target_triple in ["gfx940", "gfx941", "gfx942", "gfx90a"]: + flags.extend(MI_flags["all"]) + flags.extend(MI_flags["preprocess_default"]) + + if "gfx11" in target_triple: + flags.extend(GFX11_flags["all"]) + flags.extend(GFX11_flags["preprocess_default"]) + + # Currently, we need a transform dialect script to be applied to the compilation through IREE in certain cases. + # This 'attn_spec' handles a linalg_ext.attention op lowering to mfma instructions for capable targets. + # This is a temporary solution, and should be removed or largely disabled once the functionality of + # the TD spec is implemented in C++. 
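+    # attn_spec values of "default"/"mfma"/"punet" download the MFMA spec,
+    # "wmma" (or a gfx11* target with no spec given) downloads the WMMA spec,
+    # and any other non-empty value is passed through as a path to a local spec file.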
+ + if attn_spec in ["default", "mfma", "punet"]: + # if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + use_punet = True if attn_spec in ["punet", "i8"] else False + attn_spec = get_mfma_spec_path( + target_triple, + os.path.dirname(safe_name), + use_punet=use_punet, + masked_attention=masked_attention, + ) + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + + elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): + attn_spec = get_wmma_spec_path( + target_triple, os.path.dirname(safe_name), masked_attention=masked_attention + ) + if attn_spec: + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + elif attn_spec and attn_spec != "None": + if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + + for i, flag in enumerate(ireec_flags): + k = flag.strip().split("=")[0] + for idx, default in enumerate(flags): + if default == None: + flags.pop(idx) + continue + elif k == default.split("=")[0]: + flags[idx] = flag if flag.split("=")[-1] not in ["None", ""] else None + flag = None + if flags[idx] == None: + flags.pop(idx) + continue + if flag not in [None, "", " "] and flag.split("=")[-1] not in ["None", ""]: + flags.append(flag) + + for idx, flag in enumerate(flags): + if flag is None: + flags.pop(idx) + input_ir_type = "torch" + if add_tk_kernels: + print("Adding tk kernels") + flags.extend(["--compile-to=flow"]) + if mlir_source == "file": + flatbuffer_blob = ireec.compile_file( + module_str, + target_backends=[device], + input_type=input_ir_type, + extra_args=flags, + ) + elif mlir_source == "str": + flatbuffer_blob = ireec.compile_str( + module_str, + target_backends=[device], + input_type=input_ir_type, + extra_args=flags, + ) + + flow_ir = flatbuffer_blob.decode("utf-8") + + flow_ir_tk = replace_with_tk_kernels(tk_kernels_dir, flow_ir, batch_size) + module_str = "\n".join(flow_ir_tk) + flags.pop() + flags.extend(["--compile-from=flow"]) + mlir_source = "str" + input_ir_type = "auto" + + print("Compiling to", device, "with flags:", flags) + + # Forces a standard for naming files: + # If safe_name has target triple in it, get rid of target triple in mlir name + # + if target_triple not in safe_name: + safe_vmfb_name = safe_name + "_" + target_triple + safe_mlir_name = safe_name + else: + safe_vmfb_name = safe_name + safe_mlir_name = "".join(safe_name.split(target_triple)) + + if mlir_source == "file": + flatbuffer_blob = ireec.compile_file( + module_str, + target_backends=[device], + input_type=input_ir_type, + extra_args=flags, + ) + elif mlir_source == "str": + if save_mlir: + with open(f"{safe_mlir_name}.mlir", "w+") as f: + f.write(module_str) + print("Saved to", safe_mlir_name + ".mlir") + flatbuffer_blob = ireec.compile_str( + module_str, + target_backends=[device], + input_type=input_ir_type, + extra_args=flags, + ) + else: + raise ValueError("mlir_source must be either 'file' or 'str'") + with open(f"{safe_vmfb_name}.vmfb", "wb+") as f: + f.write(flatbuffer_blob) + print(f"Saved to {safe_vmfb_name}.vmfb") + if return_path == True: + return safe_vmfb_name + ".vmfb" + + +def create_safe_name(hf_model_name, model_name_str=""): + if not model_name_str: + model_name_str = "" + if model_name_str != "" and (not model_name_str.startswith("_")): + model_name_str = "_" + model_name_str + + safe_name = hf_model_name.split("/")[-1].strip() + model_name_str + safe_name = re.sub("-", "_", safe_name) + safe_name = 
re.sub("\.", "_", safe_name) + return safe_name + + +def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet=False): + if use_punet: + suffix = "_punet" + url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir" + elif not masked_attention: + suffix = "" + url = "https://raw.githubusercontent.com/iree-org/iree/refs/heads/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir" + else: + suffix = "_pad" + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir" + attn_spec = urlopen(url).read().decode("utf-8") + spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_mfma{suffix}.mlir") + with open(spec_path, "w") as f: + f.write(attn_spec) + return spec_path + + +def get_wmma_spec_path(target_chip, save_dir, masked_attention=False): + if not masked_attention: + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_wmma.mlir" + elif target_chip == "gfx1100": + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1100.mlir" + elif target_chip in ["gfx1103", "gfx1150"]: + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1150.mlir" + else: + return None + attn_spec = urlopen(url).read().decode("utf-8") + suffix = "masked" if masked_attention else "" + spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_wmma{suffix}.mlir") + with open(spec_path, "w") as f: + f.write(attn_spec) + return spec_path + + +def save_external_weights( + mapper, + model, + external_weights=None, + external_weight_file=None, + force_format=False, + vae_harness=False, +): + if external_weights is not None: + if external_weights in ["safetensors", "irpa"]: + mod_params = dict(model.named_parameters()) + mod_buffers = dict(model.named_buffers()) + mod_params.update(mod_buffers) + vae_params = {} + for name in mod_params: + if vae_harness: + vae_params[name.replace("vae.", "")] = mod_params[name] + mapper["params." + name] = name + if vae_harness: + mod_params = vae_params + if external_weight_file and not os.path.isfile(external_weight_file): + if not force_format: + safe_torch.save_file(mod_params, external_weight_file) + else: + for x in mod_params.keys(): + mod_params[x] = mod_params[x].numpy() + safe_numpy.save_file(mod_params, external_weight_file) + print("Saved params to", external_weight_file) + + +def largest_error(array1, array2): + absolute_diff = np.abs(array1 - array2) + max_error = np.max(absolute_diff) + print("Max error:", max_error) + return max_error + + +def get_schedulers(model_id): + # TODO: Robust scheduler setup on pipeline creation -- if we don't + # set batch_size here, the SHARK schedulers will + # compile with batch size = 1 regardless of whether the model + # outputs latents of a larger batch size, e.g. SDXL. + # However, obviously, searching for whether the base model ID + # contains "xl" is not very robust. 
+
+    # These scheduler classes come from diffusers; import here so the rest of
+    # utils.py does not require it.
+    from diffusers import (
+        EulerAncestralDiscreteScheduler,
+        EulerDiscreteScheduler,
+        PNDMScheduler,
+    )
+
+    batch_size = 2 if "xl" in model_id.lower() else 1
+
+    schedulers = dict()
+    schedulers["PNDM"] = PNDMScheduler.from_pretrained(
+        model_id,
+        subfolder="scheduler",
+    )
+    schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
+        model_id,
+        subfolder="scheduler",
+    )
+    schedulers["EulerAncestralDiscrete"] = (
+        EulerAncestralDiscreteScheduler.from_pretrained(
+            model_id,
+            subfolder="scheduler",
+        )
+    )
+    # schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained(
+    #     model_id,
+    #     subfolder="scheduler",
+    # )
+    return schedulers
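+
+
+# Example usage (the model id below is illustrative and assumes a
+# diffusers-format checkpoint; SDXL-style ids hit the batch_size=2 branch):
+#   schedulers = get_schedulers("stabilityai/stable-diffusion-xl-base-1.0")
+#   scheduler = schedulers["EulerDiscrete"]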