🔥 Removal of Pre-Trained Model Support (#300)
This PR removes support for pre-trained models from the MQT Predictor
framework.

Up until mqt.predictor v2.0.0, pre-trained models were provided.
However, this is no longer feasible due to the increasing number of
devices and figures of merit.

Instead, we now provide detailed documentation on how to train and
set up the MQT Predictor framework.

---------

Signed-off-by: Nils Quetschlich <[email protected]>
Co-authored-by: Lukas Burgholzer <[email protected]>
nquetschlich and burgholzer authored Oct 21, 2024
1 parent 003df9e commit 90f1547
Showing 12 changed files with 104 additions and 284 deletions.
62 changes: 0 additions & 62 deletions .github/workflows/pretrained_model.yml

This file was deleted.

10 changes: 8 additions & 2 deletions README.md
@@ -60,7 +60,7 @@ from mqt.predictor import qcompile
from mqt.bench import get_benchmark

# get a benchmark circuit on algorithmic level representing the GHZ state with 5 qubits from [MQT Bench](https://github.com/cda-tum/mqt-bench)
qc_uncompiled = get_benchmark(benchmark_name="dj", level="alg", circuit_size=5)
qc_uncompiled = get_benchmark(benchmark_name="ghz", level="alg", circuit_size=5)

# compile it using the MQT Predictor
qc_compiled, compilation_information, quantum_device = qcompile(qc_uncompiled)
@@ -72,7 +72,7 @@ print(quantum_device, compilation_information)
print(qc_compiled.draw())
```

**Detailed documentation and examples are available at [ReadTheDocs](https://mqt.readthedocs.io/projects/predictor).**
> [!NOTE]
> To execute the code, the respective machine learning models must be trained first.
> Up until mqt.predictor v2.0.0, pre-trained models were provided. However, this is no longer feasible due to the
> increasing number of devices and figures of merit. Instead, we now provide detailed documentation on how to train
> and set up the MQT Predictor framework.
**Further documentation and examples are available at [ReadTheDocs](https://mqt.readthedocs.io/projects/predictor).**

## References

2 changes: 1 addition & 1 deletion docs/Compilation.rst
@@ -52,4 +52,4 @@ We trained one RL model for each currently :ref:`supported quantum device <suppo
Training Data
-------------
To train the model, sufficient training data must be provided as qasm files in the `respective directory <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/training_data/training_circuits>`_.
We provide the training data used for the pre-trained models which are stored `here <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/training_data/trained_model>`_.
We provide the training data used for the initial performance evaluation of this framework, which is stored `here <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/training_data/trained_model>`_.
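
For illustration, additional training circuits could be generated with `MQT Bench <https://github.com/cda-tum/mqt-bench>`_ and exported to that directory as qasm files. The following is a minimal sketch only; the target path and the use of ``qiskit.qasm2`` for serialization are assumptions:

.. code-block:: python

    from pathlib import Path

    from mqt.bench import get_benchmark
    from qiskit import qasm2

    # Assumed location of the RL training circuits; adjust to your checkout.
    target_dir = Path("src/mqt/predictor/rl/training_data/training_circuits")

    # Export GHZ benchmarks of increasing size as qasm training circuits.
    for size in range(3, 8):
        qc = get_benchmark(benchmark_name="ghz", level="alg", circuit_size=size)
        qasm2.dump(qc, target_dir / f"ghz_{size}.qasm")
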
2 changes: 1 addition & 1 deletion docs/DeviceSelection.rst
@@ -63,7 +63,7 @@ Training Data
-------------

To train the model, sufficient training data must be provided as qasm files in the `respective directory <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/ml/training_data/training_circuits>`_.
We provide the training data used for the pre-trained model.
We provide the training data used in the initial performance evaluation of this framework.

After the adjustment is finished, the following methods need to be called to generate the training data:

71 changes: 71 additions & 0 deletions docs/Usage.rst
@@ -37,3 +37,74 @@ For that, the repository must be cloned and installed:
pip install .
Afterwards, the package can be used as described :ref:`above <pip_usage>`.

MQT Predictor Framework Setup
=============================
To run ``qcompile``, the MQT Predictor framework must be set up. How to do this properly is described next.

First, the quantum devices to be considered must be included in the framework.
Currently, all devices supported by `MQT Bench <https://github.com/cda-tum/mqt-bench>`_ are natively supported.
If another device shall be considered, it can be added using a format similar to that of MQT Bench. It does not need to
be added to the MQT Bench repository, since it can be added directly to the MQT Predictor framework as follows:

- Modify the line in `mqt/predictor/rl/predictorenv.py <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/predictorenv.py>`_ where ``mqt.bench.devices.get_device_by_name`` is used.
- Modify the lines in `mqt/predictor/ml/predictor.py <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/ml/predictor.py>`_ where ``mqt.bench.devices.*`` is used.
- Follow the same data format as defined in `mqt.bench.devices.device.py <https://github.com/cda-tum/mqt-bench/tree/main/src/mqt/bench/devices/device.py>`_ (a hypothetical sketch is shown below).
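
For illustration only, a custom device description could look roughly like the following. The field names (``name``, ``num_qubits``, ``basis_gates``, ``coupling_map``) are assumptions for this sketch; the authoritative format is the one defined in ``mqt.bench.devices.device.py``:

.. code-block:: python

    # Hypothetical description of a 5-qubit device with linear connectivity.
    custom_device = {
        "name": "my_custom_device",  # assumed identifier
        "num_qubits": 5,
        # Native gate set of the device.
        "basis_gates": ["id", "rz", "sx", "x", "cx", "measure"],
        # Qubit connectivity as directed (control, target) pairs.
        "coupling_map": [[0, 1], [1, 2], [2, 3], [3, 4]],
    }
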

Second, a reinforcement learning model must be trained for each supported device. This is done by running
the following code based on the training data in the form of quantum circuits provided as qasm files in
`mqt/predictor/rl/training_data/training_circuits <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/training_data/training_circuits>`_:

.. code-block:: python

    import mqt.predictor

    rl_pred = mqt.predictor.rl.Predictor(
        figure_of_merit="expected_fidelity", device_name="ibm_washington"
    )
    rl_pred.train_model(timesteps=100000, model_name="sample_model_rl")

This will train a reinforcement learning model for the ``ibm_washington`` device with the expected fidelity as figure of merit.
In addition to the expected fidelity, the critical depth is provided as another figure of merit.
Further figures of merit can be added in `mqt.predictor.reward.py <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/reward.py>`_ (a hypothetical sketch follows below).
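
For illustration only, a new figure of merit could be a function that maps a (compiled) quantum circuit to a score to be maximized. The function name and signature below are assumptions for this sketch; the existing reward functions in ``mqt.predictor.reward.py`` define the actual interface:

.. code-block:: python

    from qiskit import QuantumCircuit


    def gate_count_reward(qc: QuantumCircuit) -> float:
        """Hypothetical figure of merit: fewer gates yield a reward closer to 1."""
        num_gates = sum(qc.count_ops().values())
        return 1.0 / (1.0 + num_gates)
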

Third, after the reinforcement learning models used for the respective compilations have been trained, the
supervised machine learning model that predicts the device selection must be trained.
This is done by first creating the necessary training data (based on the training data in the form of quantum circuits provided as qasm files in
`mqt/predictor/ml/training_data/training_circuits <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/ml/training_data/training_circuits>`_) and then running the following commands:

.. code-block:: python

    ml_pred = mqt.predictor.ml.Predictor()
    ml_pred.generate_compiled_circuits(timeout=600)  # timeout in seconds
    training_data, name_list, scores_list = ml_pred.generate_trainingdata_from_qasm_files(
        figure_of_merit="expected_fidelity"
    )
    mqt.predictor.ml.helper.save_training_data(
        training_data, name_list, scores_list, figure_of_merit="expected_fidelity"
    )

This will compile all provided uncompiled training circuits for all available devices and figures of merit.
Afterwards, the training data is generated individually for the chosen figure of merit.
This training data can then be saved and used to train the supervised machine learning model:

.. code-block:: python

    ml_pred.train_random_forest_classifier(figure_of_merit="expected_fidelity")

Finally, the MQT Predictor framework is fully set up and can be used to predict the most
suitable device for a given quantum circuit using supervised machine learning and to compile
the circuit for the predicted device using reinforcement learning by running:

.. code-block:: python

    from mqt.predictor import qcompile
    from mqt.bench import get_benchmark

    qc_uncompiled = get_benchmark(benchmark_name="ghz", level="alg", circuit_size=5)
    compiled_qc, compilation_information, device = qcompile(
        qc_uncompiled, figure_of_merit="expected_fidelity"
    )

This returns the compiled quantum circuit for the predicted device together with additional information about the compilation procedure.
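
As a brief usage note (using the variable names from the snippet above, analogous to the README example), the returned objects can be inspected directly:

.. code-block:: python

    # Predicted device and the applied compilation steps.
    print(device)
    print(compilation_information)

    # The circuit compiled for that device.
    print(compiled_qc.draw())
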
1 change: 1 addition & 0 deletions docs/index.rst
@@ -20,6 +20,7 @@ The MQT Predictor framework is based on two main components:
- A :doc:`Device-Specific Circuit Compilation <Compilation>` component that compiles a given quantum circuit for a given device.

Combining these two components, the framework can be used to automatically compile a given quantum circuit for the most suitable device optimizing a :doc:`customizable figure of merit<FigureOfMerit>`.
How to use the framework is described in the :doc:`Usage <Usage>` section.

If you are interested in the theory behind MQT Predictor, have a look at the publications in the :doc:`references list <References>`.

76 changes: 2 additions & 74 deletions src/mqt/predictor/ml/predictor.py
@@ -40,79 +40,6 @@ def set_classifier(self, clf: RandomForestClassifier) -> None:
"""Sets the classifier to the given classifier."""
self.clf = clf

def compile_all_circuits_circuitwise(
self,
figure_of_merit: reward.figure_of_merit,
timeout: int,
source_path: Path | None = None,
target_path: Path | None = None,
logger_level: int = logging.INFO,
) -> None:
"""Compiles all circuits in the given directory with the given timeout and saves them in the given directory.
Arguments:
figure_of_merit: The figure of merit to be used for compilation.
timeout: The timeout in seconds for the compilation of a single circuit.
source_path: The path to the directory containing the circuits to be compiled. Defaults to None.
target_path: The path to the directory where the compiled circuits should be saved. Defaults to None.
logger_level: The level of the logger. Defaults to logging.INFO.
"""
logger.setLevel(logger_level)

if source_path is None:
source_path = ml.helper.get_path_training_circuits()

if target_path is None:
target_path = ml.helper.get_path_training_circuits_compiled()

Parallel(n_jobs=-1, verbose=100)(
delayed(self.generate_compiled_circuits_for_single_training_circuit)(
filename, timeout, source_path, target_path, figure_of_merit
)
for filename in source_path.iterdir()
)

def generate_compiled_circuits_for_single_training_circuit(
self,
filename: Path,
timeout: int,
source_path: Path,
target_path: Path,
figure_of_merit: reward.figure_of_merit,
) -> None:
"""Compiles a single circuit with the given timeout and saves it in the given directory.
Arguments:
filename: The path to the circuit to be compiled.
timeout: The timeout in seconds for the compilation of the circuit.
source_path: The path to the directory containing the circuit to be compiled.
target_path: The path to the directory where the compiled circuit should be saved.
figure_of_merit: The figure of merit to be used for compilation.
"""
try:
qc = QuantumCircuit.from_qasm_file(Path(source_path) / filename)
if filename.suffix != ".qasm":
return

for i, dev in enumerate(self.devices):
target_filename = str(filename).split("/")[-1].split(".qasm")[0] + "_" + figure_of_merit + "_" + str(i)
if (Path(target_path) / (target_filename + ".qasm")).exists() or qc.num_qubits > dev.num_qubits:
continue
try:
res = utils.timeout_watcher(rl.qcompile, [qc, figure_of_merit, dev.name], timeout)
if isinstance(res, tuple):
compiled_qc = res[0]
with Path(target_path / (target_filename + ".qasm")).open("w", encoding="utf-8") as f:
dump(compiled_qc, f)

except Exception as e:
print(e, filename, "inner")

except Exception as e:
print(e, filename, "outer")

def compile_all_circuits_devicewise(
self,
device_name: str,
@@ -570,7 +497,8 @@ def predict_probs(self, qc: Path | QuantumCircuit, figure_of_merit: reward.figur
self.clf = load(path)

if self.clf is None:
error_msg = "Classifier is neither trained nor saved."
error_msg = "The ML model is not trained yet. Please train the model before using it."
logger.error(error_msg)
raise FileNotFoundError(error_msg)

feature_dict = ml.helper.create_feature_dict(qc) # type: ignore[unreachable]
69 changes: 5 additions & 64 deletions src/mqt/predictor/rl/helper.py
@@ -3,15 +3,12 @@
from __future__ import annotations

import logging
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
import requests
from bqskit import MachineModel
from packaging import version
from pytket.architecture import Architecture
from pytket.circuit import Circuit, Node, Qubit
from pytket.passes import (
@@ -67,14 +64,9 @@
from mqt.bench.devices import Device


if TYPE_CHECKING or sys.version_info >= (3, 10, 0):
from importlib import metadata, resources
else:
import importlib_metadata as metadata
import importlib_resources as resources

import operator
import zipfile
from importlib import resources

from bqskit import compile as bqskit_compile
from bqskit.ir import gates
@@ -429,63 +421,12 @@ def load_model(model_name: str) -> MaskablePPO:
The loaded model.
"""
path = get_path_trained_model()

if Path(path / (model_name + ".zip")).exists():
if Path(path / (model_name + ".zip")).is_file():
return MaskablePPO.load(path / (model_name + ".zip"))
logger.info("Model does not exist. Try to retrieve suitable Model from GitHub...")
try:
mqtpredictor_module_version = metadata.version("mqt.predictor")
except ModuleNotFoundError:
error_msg = (
"Could not retrieve version of mqt.predictor. Please run 'pip install . or pip install mqt.predictor'."
)
raise RuntimeError(error_msg) from None

headers = None
if "GITHUB_TOKEN" in os.environ:
headers = {"Authorization": f"token {os.environ['GITHUB_TOKEN']}"}

version_found = False
response = requests.get("https://api.github.com/repos/cda-tum/mqt-predictor/tags", headers=headers)

if not response:
error_msg = "Querying the GitHub API failed. One reasons could be that the limit of 60 API calls per hour and IP address is exceeded."
raise RuntimeError(error_msg)

available_versions = [elem["name"] for elem in response.json()]

for possible_version in available_versions:
if version.parse(mqtpredictor_module_version) >= version.parse(possible_version):
url = "https://api.github.com/repos/cda-tum/mqt-predictor/releases/tags/" + possible_version
response = requests.get(url, headers=headers)
if not response:
error_msg = "Suitable trained models cannot be downloaded since the GitHub API failed. One reasons could be that the limit of 60 API calls per hour and IP address is exceeded."
raise RuntimeError(error_msg)

response_json = response.json()
if "assets" in response_json:
assets = response_json["assets"]
elif "asset" in response_json:
assets = [response_json["asset"]]
else:
assets = []

for asset in assets:
if model_name in asset["name"]:
version_found = True
download_url = asset["browser_download_url"]
logger.info("Downloading model from: " + download_url)
handle_downloading_model(download_url, model_name)
break

if version_found:
break

if not version_found:
error_msg = "No suitable model found on GitHub. Please update your mqt.predictor package using 'pip install -U mqt.predictor'."
raise RuntimeError(error_msg) from None

return MaskablePPO.load(path / model_name)
error_msg = "The RL model is not trained yet. Please train the model before using it."
logger.error(error_msg)
raise FileNotFoundError(error_msg)


def handle_downloading_model(download_url: str, model_name: str) -> None: