diff --git a/src/anomalib/deploy/export.py b/src/anomalib/deploy/export.py
index 87066c9ef9..aae359c035 100644
--- a/src/anomalib/deploy/export.py
+++ b/src/anomalib/deploy/export.py
@@ -36,6 +36,18 @@ class ExportType(str, Enum):
 class CompressionType(str, Enum):
     """Model compression type when exporting to OpenVINO.
 
+    Attributes:
+        FP16 (str): Weight compression (FP16). All weights are converted to FP16.
+        INT8 (str): Weight compression (INT8). All weights are quantized to INT8,
+            but are dequantized to floating point before inference.
+        INT8_PTQ (str): Full integer post-training quantization (INT8).
+            All weights and operations are quantized to INT8. Inference is done
+            in INT8 precision.
+        INT8_ACQ (str): Accuracy-control quantization (INT8). Weights and
+            operations are quantized to INT8, except those that would degrade
+            the quality of the model more than is acceptable. Inference is done
+            in mixed precision.
+
     Examples:
         >>> from anomalib.deploy import CompressionType
         >>> CompressionType.INT8_PTQ
@@ -43,20 +55,9 @@ class CompressionType(str, Enum):
     """
 
     FP16 = "fp16"
-    """
-    Weight compression (FP16)
-    All weights are converted to FP16.
-    """
     INT8 = "int8"
-    """
-    Weight compression (INT8)
-    All weights are quantized to INT8, but are dequantized to floating point before inference.
-    """
     INT8_PTQ = "int8_ptq"
-    """
-    Full integer post-training quantization (INT8)
-    All weights and operations are quantized to INT8. Inference is done in INT8 precision.
-    """
+    INT8_ACQ = "int8_acq"
 
 
 class InferenceModel(nn.Module):
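Because `CompressionType` subclasses `str`, CLI values map straight onto members, which is what lets `--compression_type "int8_acq"` work without extra parsing. A minimal sketch of that behavior (illustrative only, not part of the diff):

```python
from anomalib.deploy import CompressionType

# str-based enum members compare equal to their string values, so the
# CLI string "int8_acq" resolves directly to the new member.
assert CompressionType("int8_acq") is CompressionType.INT8_ACQ
assert CompressionType.INT8_PTQ == "int8_ptq"
```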
diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py
index 60c5364e15..05b1d1d6af 100644
--- a/src/anomalib/engine/engine.py
+++ b/src/anomalib/engine/engine.py
@@ -14,6 +14,7 @@
 from lightning.pytorch.trainer import Trainer
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
 from torch.utils.data import DataLoader, Dataset
+from torchmetrics import Metric
 from torchvision.transforms.v2 import Transform
 
 from anomalib import LearningType, TaskType
@@ -871,6 +872,7 @@ def export(
         transform: Transform | None = None,
         compression_type: CompressionType | None = None,
         datamodule: AnomalibDataModule | None = None,
+        metric: Metric | str | None = None,
         ov_args: dict[str, Any] | None = None,
         ckpt_path: str | Path | None = None,
     ) -> Path | None:
@@ -889,7 +891,12 @@
             compression_type (CompressionType | None, optional): Compression type for OpenVINO exporting only.
                 Defaults to ``None``.
             datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
-                Must be provided if CompressionType.INT8_PTQ is selected.
+                Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected
+                (OpenVINO export only).
+                Defaults to ``None``.
+            metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
+                Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return a higher
+                value for better model performance (OpenVINO export only).
                 Defaults to ``None``.
             ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer.
                 Defaults to None.
@@ -915,12 +922,12 @@
             3. To export as an OpenVINO ``.xml`` and ``.bin`` file you can run the following command.
                 ```python
                 anomalib export --model Padim --export_mode openvino --ckpt_path \
-                --input_size "[256,256]"
+                --input_size "[256,256]" --compression_type "fp16"
                 ```
-            4. You can also override OpenVINO model optimizer by adding the ``--ov_args.`` arguments.
+            4. You can also quantize the OpenVINO model with the following command.
                 ```python
                 anomalib export --model Padim --export_mode openvino --ckpt_path \
-                --input_size "[256,256]" --ov_args.compress_to_fp16 False
+                --input_size "[256,256]" --compression_type "int8_ptq" --data MVTec
                 ```
         """
         export_type = ExportType(export_type)
@@ -954,6 +961,7 @@
                 task=self.task,
                 compression_type=compression_type,
                 datamodule=datamodule,
+                metric=metric,
                 ov_args=ov_args,
             )
         else:
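The new `metric` argument only matters for `CompressionType.INT8_ACQ`, where it drives NNCF's accuracy control. A hedged end-to-end sketch of the Python API; the `MVTec`/`Padim` pairing and the `"F1Score"` name are illustrative choices, not mandated by the diff:

```python
from anomalib.data import MVTec
from anomalib.deploy import CompressionType, ExportType
from anomalib.engine import Engine
from anomalib.models import Padim

# Placeholder setup: any Anomalib model/datamodule pair follows the same flow.
datamodule = MVTec()
model = Padim()
engine = Engine(task="segmentation")
engine.fit(model, datamodule=datamodule)

# INT8_ACQ needs calibration/validation data plus a higher-is-better metric;
# a metric name given as a string is resolved via create_metric_collection.
engine.export(
    model,
    export_type=ExportType.OPENVINO,
    compression_type=CompressionType.INT8_ACQ,
    datamodule=datamodule,
    metric="F1Score",
)
```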
diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py
index 9b0c2d41e2..3d6f5088da 100644
--- a/src/anomalib/models/components/base/export_mixin.py
+++ b/src/anomalib/models/components/base/export_mixin.py
@@ -6,7 +6,7 @@
 import json
 import logging
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any
@@ -14,16 +14,23 @@
 import numpy as np
 import torch
 from torch import nn
+from torchmetrics import Metric
 from torchvision.transforms.v2 import Transform
 
 from anomalib import TaskType
 from anomalib.data import AnomalibDataModule
 from anomalib.deploy.export import CompressionType, ExportType, InferenceModel
+from anomalib.metrics import create_metric_collection
 from anomalib.utils.exceptions import try_import
 
 if TYPE_CHECKING:
+    from importlib.util import find_spec
+
     from torch.types import Number
 
+    if find_spec("openvino") is not None:
+        from openvino import CompiledModel
+
 logger = logging.getLogger(__name__)
@@ -159,6 +166,7 @@ def to_openvino(
         transform: Transform | None = None,
         compression_type: CompressionType | None = None,
         datamodule: AnomalibDataModule | None = None,
+        metric: Metric | str | None = None,
         ov_args: dict[str, Any] | None = None,
         task: TaskType | None = None,
     ) -> Path:
@@ -174,7 +182,11 @@
             compression_type (CompressionType, optional): Compression type for better inference performance.
                 Defaults to ``None``.
             datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
-                Must be provided if CompressionType.INT8_PTQ is selected.
+                Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected.
+                Defaults to ``None``.
+            metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
+                Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return a higher
+                value for better model performance. Defaults to ``None``.
             ov_args (dict | None): Model optimizer arguments for OpenVINO model conversion.
                 Defaults to ``None``.
@@ -206,6 +218,20 @@
                 ...     task=datamodule.test_data.task
                 ... )
 
+        Export and Quantize the Model (OpenVINO IR):
+            This example demonstrates how to export and quantize the model to OpenVINO IR.
+
+            >>> from anomalib.models import Patchcore
+            >>> from anomalib.data import Visa
+            >>> from anomalib.deploy import CompressionType
+            >>> datamodule = Visa()
+            >>> model = Patchcore()
+            >>> model.to_openvino(
+            ...     export_root="path/to/export",
+            ...     compression_type=CompressionType.INT8_PTQ,
+            ...     datamodule=datamodule,
+            ...     task=datamodule.test_data.task
+            ... )
+
         Using Custom Transforms:
             This example shows how to use a custom ``Transform`` object for the ``transform`` argument.
@@ -221,11 +247,7 @@
         if not try_import("openvino"):
             logger.exception("Could not find OpenVINO. Please check OpenVINO installation.")
             raise ModuleNotFoundError
-        if not try_import("nncf"):
-            logger.exception("Could not find NNCF. Please check NNCF installation.")
-            raise ModuleNotFoundError
 
-        import nncf
         import openvino as ov
 
         with TemporaryDirectory() as onnx_directory:
@@ -235,20 +257,8 @@
             ov_args = {} if ov_args is None else ov_args
 
             model = ov.convert_model(model_path, **ov_args)
-            if compression_type == CompressionType.INT8:
-                model = nncf.compress_weights(model)
-            elif compression_type == CompressionType.INT8_PTQ:
-                if datamodule is None:
-                    msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
-                    raise ValueError(msg)
-
-                dataloader = datamodule.val_dataloader()
-                if len(dataloader.dataset) < 300:
-                    logger.warning(
-                        f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
-                    )
-                calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
-                model = nncf.quantize(model, calibration_dataset)
+            if compression_type and compression_type != CompressionType.FP16:
+                model = self._compress_ov_model(model, compression_type, datamodule, metric, task)
 
             # fp16 compression is enabled by default
             compress_to_fp16 = compression_type == CompressionType.FP16
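With this dispatch, FP16 stays off the NNCF path entirely: it falls through to the `compress_to_fp16` flag above, which is plain OpenVINO behavior applied at save time. A minimal sketch of that fast path (file names are placeholders):

```python
import openvino as ov

# FP16 weight compression happens when the IR is written; NNCF is never imported.
ov_model = ov.convert_model("model.onnx")
ov.save_model(ov_model, "model.xml", compress_to_fp16=True)
```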
@@ -257,6 +267,145 @@
 
         return ov_model_path
 
+    def _compress_ov_model(
+        self,
+        model: "CompiledModel",
+        compression_type: CompressionType | None = None,
+        datamodule: AnomalibDataModule | None = None,
+        metric: Metric | str | None = None,
+        task: TaskType | None = None,
+    ) -> "CompiledModel":
+        """Compress OpenVINO model with NNCF.
+
+        Args:
+            model (CompiledModel): Model already exported to OpenVINO format.
+            compression_type (CompressionType, optional): Compression type for better inference performance.
+                Defaults to ``None``.
+            datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
+                Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected.
+                Defaults to ``None``.
+            metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
+                Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return a higher
+                value for better model performance. Defaults to ``None``.
+            task (TaskType | None): Task type.
+                Defaults to ``None``.
+
+        Returns:
+            model (CompiledModel): Model in the OpenVINO format compressed with NNCF.
+        """
+        if not try_import("nncf"):
+            logger.exception("Could not find NNCF. Please check NNCF installation.")
+            raise ModuleNotFoundError
+
+        import nncf
+
+        if compression_type == CompressionType.INT8:
+            model = nncf.compress_weights(model)
+        elif compression_type == CompressionType.INT8_PTQ:
+            model = self._post_training_quantization_ov(model, datamodule)
+        elif compression_type == CompressionType.INT8_ACQ:
+            model = self._accuracy_control_quantization_ov(model, datamodule, metric, task)
+        else:
+            msg = f"Unrecognized compression type: {compression_type}"
+            raise ValueError(msg)
+
+        return model
+
+    def _post_training_quantization_ov(
+        self,
+        model: "CompiledModel",
+        datamodule: AnomalibDataModule | None = None,
+    ) -> "CompiledModel":
+        """Post-Training Quantization with NNCF.
+
+        Args:
+            model (CompiledModel): Model already exported to OpenVINO format.
+            datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
+                Must be provided if ``CompressionType.INT8_PTQ`` is selected.
+                Defaults to ``None``.
+
+        Returns:
+            model (CompiledModel): Quantized model.
+        """
+        import nncf
+
+        if datamodule is None:
+            msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
+            raise ValueError(msg)
+
+        model_input = model.input(0)
+
+        if model_input.partial_shape[0].is_static:
+            datamodule.train_batch_size = model_input.shape[0]
+
+        dataloader = datamodule.train_dataloader()
+        if len(dataloader.dataset) < 300:
+            logger.warning(
+                f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
+            )
+
+        calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
+        return nncf.quantize(model, calibration_dataset)
+
+    def _accuracy_control_quantization_ov(
+        self,
+        model: "CompiledModel",
+        datamodule: AnomalibDataModule | None = None,
+        metric: Metric | str | None = None,
+        task: TaskType | None = None,
+    ) -> "CompiledModel":
+        """Accuracy-Control Quantization with NNCF.
+
+        Args:
+            model (CompiledModel): Model already exported to OpenVINO format.
+            datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
+                Must be provided if ``CompressionType.INT8_ACQ`` is selected.
+                Defaults to ``None``.
+            metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
+                Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return a higher
+                value for better model performance. Defaults to ``None``.
+            task (TaskType | None): Task type.
+                Defaults to ``None``.
+
+        Returns:
+            model (CompiledModel): Quantized model.
+        """
+        import nncf
+
+        if datamodule is None:
+            msg = "Datamodule must be provided for OpenVINO INT8_ACQ compression"
+            raise ValueError(msg)
+        if metric is None:
+            msg = "Metric must be provided for OpenVINO INT8_ACQ compression"
+            raise ValueError(msg)
+
+        model_input = model.input(0)
+
+        if model_input.partial_shape[0].is_static:
+            datamodule.train_batch_size = model_input.shape[0]
+            datamodule.eval_batch_size = model_input.shape[0]
+
+        dataloader = datamodule.train_dataloader()
+        if len(dataloader.dataset) < 300:
+            logger.warning(
+                f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
+            )
+
+        calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
+        validation_dataset = nncf.Dataset(datamodule.val_dataloader())
+
+        if isinstance(metric, str):
+            metric = create_metric_collection([metric])[metric]
+
+        # validation function to evaluate the quality loss after quantization
+        def val_fn(nncf_model: "CompiledModel", validation_data: Iterable) -> float:
+            metric.reset()  # NNCF calls this repeatedly; do not accumulate state across calls
+            for batch in validation_data:
+                preds = torch.from_numpy(nncf_model(batch["image"])[0])
+                target = batch["label"] if task == TaskType.CLASSIFICATION else batch["mask"][:, None, :, :]
+                metric.update(preds, target)
+            return metric.compute()
+
+        return nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn)
+
     def _get_metadata(
         self,
         task: TaskType | None = None,
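Since `metric` may also be a ready `torchmetrics.Metric` instance, the string lookup via `create_metric_collection` can be skipped entirely. A hedged usage sketch of the mixin API; the dataset, export root, and `AUROC` choice are placeholders, and the model is assumed to be trained already:

```python
from anomalib.data import MVTec
from anomalib.deploy import CompressionType
from anomalib.metrics import AUROC
from anomalib.models import Padim

datamodule = MVTec()
datamodule.setup()  # makes datamodule.test_data.task available
model = Padim()
# ... fit the model before exporting ...

model.to_openvino(
    export_root="exports",
    compression_type=CompressionType.INT8_ACQ,
    datamodule=datamodule,
    metric=AUROC(),  # passed through unchanged; compute() must be higher-is-better
    task=datamodule.test_data.task,
)
```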