From 52492d45ce7beea6f37deb62e274fb67fec09ee9 Mon Sep 17 00:00:00 2001 From: MarcSerraPeralta <43704266+MarcSerraPeralta@users.noreply.github.com> Date: Fri, 4 Oct 2024 23:20:18 +0200 Subject: [PATCH] Add debugging utils (#40) --- qec_util/__init__.py | 7 +- qec_util/circuits/__init__.py | 8 ++ qec_util/circuits/circuits.py | 95 +++++++++++++++ qec_util/dems/__init__.py | 16 +++ qec_util/dems/dems.py | 210 ++++++++++++++++++++++++++++++++ qec_util/dems/util.py | 26 ++++ tests/circuits/test_circuits.py | 164 +++++++++++++++++++++++++ tests/dems/test_dems.py | 193 +++++++++++++++++++++++++++++ tests/test_tests.py | 19 +-- 9 files changed, 726 insertions(+), 12 deletions(-) create mode 100644 qec_util/circuits/__init__.py create mode 100644 qec_util/circuits/circuits.py create mode 100644 qec_util/dems/__init__.py create mode 100644 qec_util/dems/dems.py create mode 100644 qec_util/dems/util.py create mode 100644 tests/circuits/test_circuits.py create mode 100644 tests/dems/test_dems.py diff --git a/qec_util/__init__.py b/qec_util/__init__.py index 58a89ae..ddb7cec 100644 --- a/qec_util/__init__.py +++ b/qec_util/__init__.py @@ -1,7 +1,4 @@ -"""Main surface-sim module.""" +from . import performance, util, distance, dems, circuits __version__ = "0.1.0" - -from . import performance, util, distance - -__all__ = ["performance", "util", "distance"] +__all__ = ["performance", "util", "distance", "dems", "circuits"] diff --git a/qec_util/circuits/__init__.py b/qec_util/circuits/__init__.py new file mode 100644 index 0000000..d165e97 --- /dev/null +++ b/qec_util/circuits/__init__.py @@ -0,0 +1,8 @@ +from .circuits import ( + remove_gauge_detectors, + remove_detectors_except, + logicals_to_detectors, +) + + +__all__ = ["remove_gauge_detectors", "remove_detectors_except", "logicals_to_detectors"] diff --git a/qec_util/circuits/circuits.py b/qec_util/circuits/circuits.py new file mode 100644 index 0000000..2bb90e8 --- /dev/null +++ b/qec_util/circuits/circuits.py @@ -0,0 +1,95 @@ +from collections.abc import Sequence + +import stim + + +def remove_gauge_detectors(circuit: stim.Circuit) -> stim.Circuit: + """Removes the gauge detectors from the given circuit.""" + if not isinstance(circuit, stim.Circuit): + raise TypeError(f"'circuit' is not a stim.Circuit, but a {type(circuit)}.") + + dem = circuit.detector_error_model(allow_gauge_detectors=True) + gauge_dets = [] + for dem_instr in dem.flattened(): + if dem_instr.type == "error" and dem_instr.args_copy() == [0.5]: + if len(dem_instr.targets_copy()) != 1: + raise ValueError("There exist 'composed' gauge detector: {dem_instr}.") + gauge_dets.append(dem_instr.targets_copy()[0].val) + + if len(gauge_dets) == 0: + return circuit + + current_det = -1 + new_circuit = stim.Circuit() + for instr in circuit.flattened(): + if instr.name == "DETECTOR": + current_det += 1 + if current_det in gauge_dets: + continue + + new_circuit.append(instr) + + return new_circuit + + +def remove_detectors_except( + circuit: stim.Circuit, det_ids_exception: Sequence[int] = [] +) -> stim.Circuit: + """Removes all detectors from a circuit except the specified ones. + Useful for plotting individual detectors with ``stim.Circuit.diagram``. + + Parameters + ---------- + circuit + Stim circuit. + det_ids_exception + Index of the detectors to not be removed. + + Returns + ------- + new_circuit + Stim circuit without detectors except the ones in ``det_ids_exception``. + """ + if not isinstance(circuit, stim.Circuit): + raise TypeError(f"'circuit' is not a stim.Circuit, but a {type(circuit)}.") + if not isinstance(det_ids_exception, Sequence): + raise TypeError( + f"'det_ids_exception' is not a Sequence, but a {type(det_ids_exception)}." + ) + if any([not isinstance(i, int) for i in det_ids_exception]): + raise TypeError( + "'det_ids_exception' is not a sequence of ints, " + f"{det_ids_exception} was given." + ) + + new_circuit = stim.Circuit() + current_det_id = -1 + for instr in circuit.flattened(): + if instr.name != "DETECTOR": + new_circuit.append(instr) + continue + + current_det_id += 1 + if current_det_id in det_ids_exception: + new_circuit.append(instr) + + return new_circuit + + +def logicals_to_detectors(circuit: stim.Circuit) -> stim.Circuit: + """Converts the logical observables of a circuit to detectors.""" + if not isinstance(circuit, stim.Circuit): + raise TypeError(f"'circuit' is not a stim.Circuit, but a {type(circuit)}.") + + new_circuit = stim.Circuit() + for instr in circuit.flattened(): + if instr.name != "OBSERVABLE_INCLUDE": + new_circuit.append(instr) + continue + + targets = instr.targets_copy() + args = instr.gate_args_copy() + new_instr = stim.CircuitInstruction("DETECTOR", gate_args=args, targets=targets) + new_circuit.append(new_instr) + + return new_circuit diff --git a/qec_util/dems/__init__.py b/qec_util/dems/__init__.py new file mode 100644 index 0000000..be3d8d9 --- /dev/null +++ b/qec_util/dems/__init__.py @@ -0,0 +1,16 @@ +from .dems import ( + remove_gauge_detectors, + dem_difference, + is_instr_in_dem, + get_max_weight_hyperedge, + disjoint_graphs, +) + + +__all__ = [ + "remove_gauge_detectors", + "dem_difference", + "is_instr_in_dem", + "get_max_weight_hyperedge", + "disjoint_graphs", +] diff --git a/qec_util/dems/dems.py b/qec_util/dems/dems.py new file mode 100644 index 0000000..2c8b9cd --- /dev/null +++ b/qec_util/dems/dems.py @@ -0,0 +1,210 @@ +import stim +import networkx as nx + +from .util import sorting_index + + +def remove_gauge_detectors(dem: stim.DetectorErrorModel) -> stim.DetectorErrorModel: + """Remove the gauge detectors from a DEM.""" + if not isinstance(dem, stim.DetectorErrorModel): + raise TypeError(f"'dem' is not a stim.DetectorErrorModel, but a {type(dem)}.") + + new_dem = stim.DetectorErrorModel() + gauge_dets = set() + + for dem_instr in dem.flattened(): + if dem_instr.type != "error": + new_dem.append(dem_instr) + + if dem_instr.args_copy() == [0.5]: + det = dem_instr.targets_copy() + if len(det) != 1: + raise ValueError("There exist 'composed' gauge detector: {dem_instr}.") + gauge_dets.add(det[0]) + continue + + if dem_instr.args_copy() != [0.5]: + if len([i for i in dem_instr.targets_copy() if i in gauge_dets]) != 0: + raise ValueError( + "A gauge detector is present in the following error:\n" + f"{dem_instr}\nGauge detectors = {gauge_dets}" + ) + new_dem.append(dem_instr) + + return new_dem + + +def dem_difference( + dem_1: stim.DetectorErrorModel, dem_2: stim.DetectorErrorModel +) -> tuple[stim.DetectorErrorModel, stim.DetectorErrorModel]: + """Returns the the DEM error instructions in the first DEM that are not present + in the second DEM and vice versa. Note that this does not take into account + the decomposition of errors. + + Parameters + ---------- + dem_1 + First detector error model. + dem_2 + Second detector error model. + + Returns + ------- + diff_1 + DEM instructions present in ``dem_1`` that are not present in ``dem_2``. + diff_2 + DEM instructions present in ``dem_2`` that are not present in ``dem_1``. + """ + if not isinstance(dem_1, stim.DetectorErrorModel): + raise TypeError( + f"'dem_1' is not a stim.DetectorErrorModel, but a {type(dem_1)}." + ) + if not isinstance(dem_2, stim.DetectorErrorModel): + raise TypeError( + f"'dem_2' is not a stim.DetectorErrorModel, but a {type(dem_2)}." + ) + + dem_1_ordered = stim.DetectorErrorModel() + num_dets = dem_1.num_detectors + for dem_instr in dem_1.flattened(): + if dem_instr.type != "error": + continue + + # remove separators + targets = [t for t in dem_instr.targets_copy() if not t.is_separator()] + + targets = sorted(targets, key=lambda x: sorting_index(x, num_dets)) + prob = dem_instr.args_copy()[0] + dem_1_ordered.append("error", prob, targets) + + dem_2_ordered = stim.DetectorErrorModel() + num_dets = dem_2.num_detectors + for dem_instr in dem_2.flattened(): + if dem_instr.type != "error": + continue + + # remove separators + targets = [t for t in dem_instr.targets_copy() if not t.is_separator()] + + targets = sorted(targets, key=lambda x: sorting_index(x, num_dets)) + prob = dem_instr.args_copy()[0] + dem_2_ordered.append("error", prob, targets) + + diff_1 = stim.DetectorErrorModel() + for dem_instr in dem_1_ordered: + if dem_instr not in dem_2_ordered: + diff_1.append(dem_instr) + + diff_2 = stim.DetectorErrorModel() + for dem_instr in dem_2_ordered: + if dem_instr not in dem_1_ordered: + diff_2.append(dem_instr) + + return diff_1, diff_2 + + +def is_instr_in_dem( + dem_instr: stim.DemInstruction, dem: stim.DetectorErrorModel +) -> bool: + """Checks if the DEM error instruction and its undecomposed form are present + in the given DEM. + """ + if not isinstance(dem_instr, stim.DemInstruction): + raise TypeError( + f"'dem_instr' must be a stim.DemInstruction, but {type(dem_instr)} was given." + ) + if dem_instr.type != "error": + raise TypeError(f"'dem_instr' is not an error, but a {dem_instr.type}.") + if not isinstance(dem, stim.DetectorErrorModel): + raise TypeError( + f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given." + ) + + num_dets = dem.num_detectors + prob = dem_instr.args_copy()[0] + targets = [t for t in dem_instr.targets_copy() if not t.is_separator()] + targets = sorted(targets, key=lambda x: sorting_index(x, num_dets)) + + for instr in dem.flattened(): + if instr.type != "error": + continue + if instr.args_copy()[0] != prob: + continue + + other_targets = [t for t in instr.targets_copy() if not t.is_separator()] + other_targets = sorted(other_targets, key=lambda x: sorting_index(x, num_dets)) + if other_targets == targets: + return True + + return False + + +def get_max_weight_hyperedge( + dem: stim.DetectorErrorModel, +) -> tuple[int, stim.DemInstruction]: + """Return the weight and hyperedges corresponding to the max-weight hyperedge. + + Parameters + ---------- + dem + Stim detector error model. + + Returns + ------- + weight + Weight of the max-weight hyperedge in ``dem``. + hyperedge + Hyperedge with the max-weight in ``dem``. + """ + if not isinstance(dem, stim.DetectorErrorModel): + raise TypeError( + f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given." + ) + + max_weight = 0 + hyperedge = stim.DemInstruction(type="error", args=[0], targets=[]) + for dem_instr in dem.flattened(): + if dem_instr.type != "error": + continue + + targets = dem_instr.targets_copy() + targets = [t for t in targets if t.is_relative_detector_id()] + if len(targets) > max_weight: + max_weight = len(targets) + hyperedge = dem_instr + + return max_weight, hyperedge + + +def disjoint_graphs(dem: stim.DetectorErrorModel) -> list[list[int]]: + """Return the nodes in the disjoint subgraphs that the DEM (or decoding + graph) can be split into.""" + if not isinstance(dem, stim.DetectorErrorModel): + raise TypeError( + f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given." + ) + + # convert stim.DetectorErrorModel to nx.Graph + # to use the functionality 'nx.connected_components(G)' + g = nx.Graph() + for dem_instr in dem.flattened(): + if dem_instr.type != "error": + continue + + targets = dem_instr.targets_copy() + targets = [t.val for t in targets if t.is_relative_detector_id()] + + if len(targets) == 1: + g.add_node(targets[0]) + + # hyperedges cannot be added to nx.Graph, but if we are just checking + # the number of disjoint graphs, we can transform the hyperedge to a + # sequence of edges which keeps the same connectiviy between nodes. + # For example, hyperedge (0,2,5,6) can be expressed as edges (0,2), + # (2,5) and (5,6). + for start, end in zip(targets[:-1], targets[1:]): + g.add_edge(start, end) + + subgraphs = [list(c) for c in nx.connected_components(g)] + + return subgraphs diff --git a/qec_util/dems/util.py b/qec_util/dems/util.py new file mode 100644 index 0000000..5d7c0ac --- /dev/null +++ b/qec_util/dems/util.py @@ -0,0 +1,26 @@ +def sorting_index(t, num_dets) -> int: + """Function to sort the logical and detector targets inside a DEM instruction. + + Parameters + ---------- + t + stim.DemTarget. + num_dets + Number of detectors in the DEM. + + Returns + ------- + Sorting index associated with ``t``. + + Notes + ----- + ``dem_instr1 == dem_instr2`` is only true if the argument is the same + and the targets are sorted also in the same way. For example + ``"error(0.1) D0 D1"`` is different than ``"error(0.1) D1 D0"``. + """ + if t.is_logical_observable_id(): + return num_dets + t.val + if t.is_relative_detector_id(): + return t.val + else: + raise NotImplemented(f"{t} is not a logical or a detector.") diff --git a/tests/circuits/test_circuits.py b/tests/circuits/test_circuits.py new file mode 100644 index 0000000..2206e37 --- /dev/null +++ b/tests/circuits/test_circuits.py @@ -0,0 +1,164 @@ +import stim +import pytest + +from qec_util.circuits import ( + remove_gauge_detectors, + remove_detectors_except, + logicals_to_detectors, +) + + +def test_remove_gauge_detectors(): + circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + DETECTOR(0) rec[-4] + DETECTOR(3) rec[-3] rec[-1] + X 0 + CNOT 1 0 + """ + ) + + new_circuit = remove_gauge_detectors(circuit) + + expected_circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + DETECTOR(3) rec[-3] rec[-1] + X 0 + CNOT 1 0 + """ + ) + + assert new_circuit == expected_circuit + + circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + DETECTOR(0) rec[-4] + DETECTOR(3) rec[-3] rec[-1] + DETECTOR(9) rec[-4] rec[-2] + X 0 + CNOT 1 0 + """ + ) + + # the DEM looks like "error(0.5) D0 D2" + with pytest.raises(ValueError): + _ = remove_gauge_detectors(circuit) + + return + + +def test_remove_detectors_except(): + circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + DETECTOR(0) rec[-4] + DETECTOR(3) rec[-3] rec[-1] + X 0 + CNOT 1 0 + DETECTOR(9) rec[-4] rec[-2] + """ + ) + + new_circuit = remove_detectors_except(circuit) + + expected_circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + X 0 + CNOT 1 0 + """ + ) + + assert new_circuit == expected_circuit + + circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + DETECTOR(0) rec[-4] + DETECTOR(3) rec[-3] rec[-1] + X 0 + CNOT 1 0 + DETECTOR(9) rec[-4] rec[-2] + """ + ) + + new_circuit = remove_detectors_except(circuit, [0, 2, 1000]) + + expected_circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + DETECTOR(0) rec[-4] + X 0 + CNOT 1 0 + DETECTOR(9) rec[-4] rec[-2] + """ + ) + + assert new_circuit == expected_circuit + + with pytest.raises(TypeError): + _ = remove_detectors_except(circuit, [1.2]) + + return + + +def test_logicals_to_detectors(): + circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + DETECTOR(0) rec[-4] + DETECTOR(3) rec[-3] rec[-1] + X 0 + CNOT 1 0 + DETECTOR(9) rec[-4] rec[-2] + OBSERVABLE_INCLUDE(0) rec[-1] + """ + ) + + new_circuit = logicals_to_detectors(circuit) + + expected_circuit = stim.Circuit( + """ + R 0 1 2 3 + X_ERROR(0.1) 0 1 2 3 + MX 0 + MZ 1 2 3 + DETECTOR(0) rec[-4] + DETECTOR(3) rec[-3] rec[-1] + X 0 + CNOT 1 0 + DETECTOR(9) rec[-4] rec[-2] + DETECTOR(0) rec[-1] + """ + ) + + assert new_circuit == expected_circuit + + return diff --git a/tests/dems/test_dems.py b/tests/dems/test_dems.py new file mode 100644 index 0000000..a0eec7c --- /dev/null +++ b/tests/dems/test_dems.py @@ -0,0 +1,193 @@ +import pytest +import stim + +from qec_util.dems import ( + remove_gauge_detectors, + dem_difference, + is_instr_in_dem, + get_max_weight_hyperedge, + disjoint_graphs, +) + + +def test_remove_gauge_detectors(): + dem = stim.DetectorErrorModel( + """ + error(0.1) D0 + error(0.5) D4 + error(0.2) D1 D2 + """ + ) + + new_dem = remove_gauge_detectors(dem) + + expected_dem = stim.DetectorErrorModel( + """ + error(0.1) D0 + error(0.2) D1 D2 + """ + ) + + assert new_dem == expected_dem + + dem = stim.DetectorErrorModel( + """ + error(0.1) D0 + error(0.5) D1 D2 D3 + error(0.2) D1 D2 + """ + ) + with pytest.raises(ValueError): + _ = remove_gauge_detectors(dem) + + dem = stim.DetectorErrorModel( + """ + error(0.1) D0 + error(0.5) D1 + error(0.2) D1 D2 + """ + ) + with pytest.raises(ValueError): + _ = remove_gauge_detectors(dem) + + return + + +def test_dem_difference(): + dem_1 = stim.DetectorErrorModel( + """ + error(0.1) L0 D0 + error(0.2) D1 ^ D2 + error(0.3) D3 D4 D1 + """ + ) + dem_2 = stim.DetectorErrorModel( + """ + error(0.1) D0 L0 + error(0.2) D1 D2 + error(0.3) D1 D3 D4 + """ + ) + + diff_1, diff_2 = dem_difference(dem_1, dem_2) + + assert len(diff_1) == 0 + assert len(diff_2) == 0 + + dem_2 = stim.DetectorErrorModel( + """ + error(0.2) D1 D2 + error(0.3) D1 D3 D4 + error(0.5) D0 + """ + ) + + diff_1, diff_2 = dem_difference(dem_1, dem_2) + + assert diff_1 == stim.DetectorErrorModel("error(0.1) D0 L0") + assert diff_2 == stim.DetectorErrorModel("error(0.5) D0") + + return + + +def test_is_instr_in_dem(): + dem = stim.DetectorErrorModel( + """ + error(0.1) L0 D0 + error(0.2) D1 ^ D2 + error(0.3) D3 D4 D1 + error(0.5) D1 L1 + """ + ) + dem_instr = stim.DemInstruction( + "error", + [0.1], + [stim.target_relative_detector_id(0), stim.target_logical_observable_id(0)], + ) + assert is_instr_in_dem(dem_instr, dem) + + dem_instr = stim.DemInstruction( + "error", + [0.2], + [stim.target_relative_detector_id(0), stim.target_logical_observable_id(0)], + ) + assert not is_instr_in_dem(dem_instr, dem) + + dem_instr = stim.DemInstruction( + "error", + [0.5], + [stim.target_relative_detector_id(1), stim.target_logical_observable_id(1)], + ) + assert is_instr_in_dem(dem_instr, dem) + + dem_instr = stim.DemInstruction( + "detector", + [0], + [stim.target_relative_detector_id(1)], + ) + with pytest.raises(TypeError): + is_instr_in_dem(dem_instr, dem) + + return + + +def test_get_max_weight_hyperedge(): + dem = stim.DetectorErrorModel( + """ + error(0.1) L0 D0 + error(0.2) D1 ^ D2 + error(0.3) D3 D4 D1 + error(0.5) D1 L1 + """ + ) + + max_weight, hyperedge = get_max_weight_hyperedge(dem) + + expected_hyperedge = stim.DemInstruction( + "error", + args=[0.3], + targets=[ + stim.target_relative_detector_id(3), + stim.target_relative_detector_id(4), + stim.target_relative_detector_id(1), + ], + ) + + assert max_weight == 3 + assert hyperedge == expected_hyperedge + + dem = stim.DetectorErrorModel() + + max_weight, hyperedge = get_max_weight_hyperedge(dem) + + expected_hyperedge = stim.DemInstruction( + "error", + args=[0.0], + targets=[], + ) + + assert max_weight == 0 + assert hyperedge == expected_hyperedge + + return + + +def test_disjoint_graphs(): + dem = stim.DetectorErrorModel( + """ + error(0.1) L0 D0 + error(0.2) D1 ^ D2 + error(0.3) D3 D4 D1 + error(0.5) D5 D6 + detector(0) D0 + """ + ) + + subgraphs = disjoint_graphs(dem) + subgraphs = set(tuple(sorted(s)) for s in subgraphs) + + expected_subgraphs = set([(0,), (1, 2, 3, 4), (5, 6)]) + + assert subgraphs == expected_subgraphs + + return diff --git a/tests/test_tests.py b/tests/test_tests.py index 4e2844f..2c89fc1 100644 --- a/tests/test_tests.py +++ b/tests/test_tests.py @@ -3,7 +3,7 @@ DIR_EXCEPTIONS = ["__pycache__"] -FILE_EXCEPTIONS = ["__init__.py"] +FILE_EXCEPTIONS = ["__init__.py", "qec_util/dems/util.py"] def test_tests(): @@ -16,14 +16,19 @@ def test_tests(): for file in files: if file[-3:] != ".py" or file[0] == "_": continue + if os.path.basename(os.path.normpath(path)) in DIR_EXCEPTIONS: + continue + if (file in FILE_EXCEPTIONS) or ( + os.path.join(path, file) in FILE_EXCEPTIONS + ): + continue # change root dir from "qec_util" to test_dir relpath = os.path.relpath(path, mod_dir) testpath = os.path.join(test_dir, relpath) - if file not in FILE_EXCEPTIONS: - if not os.path.exists(os.path.join(testpath, "test_" + file)): - raise ValueError( - f"test file for {os.path.join(mod_dir, relpath, file)}" - " does not exist" - ) + if not os.path.exists(os.path.join(testpath, "test_" + file)): + raise ValueError( + f"test file for {os.path.join(mod_dir, relpath, file)}" + " does not exist" + ) return