Skip to content

Commit

Permalink
Add debugging utils (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcSerraPeralta authored Oct 4, 2024
1 parent fe4e512 commit 52492d4
Show file tree
Hide file tree
Showing 9 changed files with 726 additions and 12 deletions.
7 changes: 2 additions & 5 deletions qec_util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Main surface-sim module."""
from . import performance, util, distance, dems, circuits

__version__ = "0.1.0"

from . import performance, util, distance

__all__ = ["performance", "util", "distance"]
__all__ = ["performance", "util", "distance", "dems", "circuits"]
8 changes: 8 additions & 0 deletions qec_util/circuits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .circuits import (
remove_gauge_detectors,
remove_detectors_except,
logicals_to_detectors,
)


__all__ = ["remove_gauge_detectors", "remove_detectors_except", "logicals_to_detectors"]
95 changes: 95 additions & 0 deletions qec_util/circuits/circuits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from collections.abc import Sequence

import stim


def remove_gauge_detectors(circuit: stim.Circuit) -> stim.Circuit:
"""Removes the gauge detectors from the given circuit."""
if not isinstance(circuit, stim.Circuit):
raise TypeError(f"'circuit' is not a stim.Circuit, but a {type(circuit)}.")

dem = circuit.detector_error_model(allow_gauge_detectors=True)
gauge_dets = []
for dem_instr in dem.flattened():
if dem_instr.type == "error" and dem_instr.args_copy() == [0.5]:
if len(dem_instr.targets_copy()) != 1:
raise ValueError("There exist 'composed' gauge detector: {dem_instr}.")
gauge_dets.append(dem_instr.targets_copy()[0].val)

if len(gauge_dets) == 0:
return circuit

current_det = -1
new_circuit = stim.Circuit()
for instr in circuit.flattened():
if instr.name == "DETECTOR":
current_det += 1
if current_det in gauge_dets:
continue

new_circuit.append(instr)

return new_circuit


def remove_detectors_except(
circuit: stim.Circuit, det_ids_exception: Sequence[int] = []
) -> stim.Circuit:
"""Removes all detectors from a circuit except the specified ones.
Useful for plotting individual detectors with ``stim.Circuit.diagram``.
Parameters
----------
circuit
Stim circuit.
det_ids_exception
Index of the detectors to not be removed.
Returns
-------
new_circuit
Stim circuit without detectors except the ones in ``det_ids_exception``.
"""
if not isinstance(circuit, stim.Circuit):
raise TypeError(f"'circuit' is not a stim.Circuit, but a {type(circuit)}.")
if not isinstance(det_ids_exception, Sequence):
raise TypeError(
f"'det_ids_exception' is not a Sequence, but a {type(det_ids_exception)}."
)
if any([not isinstance(i, int) for i in det_ids_exception]):
raise TypeError(
"'det_ids_exception' is not a sequence of ints, "
f"{det_ids_exception} was given."
)

new_circuit = stim.Circuit()
current_det_id = -1
for instr in circuit.flattened():
if instr.name != "DETECTOR":
new_circuit.append(instr)
continue

current_det_id += 1
if current_det_id in det_ids_exception:
new_circuit.append(instr)

return new_circuit


def logicals_to_detectors(circuit: stim.Circuit) -> stim.Circuit:
"""Converts the logical observables of a circuit to detectors."""
if not isinstance(circuit, stim.Circuit):
raise TypeError(f"'circuit' is not a stim.Circuit, but a {type(circuit)}.")

new_circuit = stim.Circuit()
for instr in circuit.flattened():
if instr.name != "OBSERVABLE_INCLUDE":
new_circuit.append(instr)
continue

targets = instr.targets_copy()
args = instr.gate_args_copy()
new_instr = stim.CircuitInstruction("DETECTOR", gate_args=args, targets=targets)
new_circuit.append(new_instr)

return new_circuit
16 changes: 16 additions & 0 deletions qec_util/dems/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .dems import (
remove_gauge_detectors,
dem_difference,
is_instr_in_dem,
get_max_weight_hyperedge,
disjoint_graphs,
)


__all__ = [
"remove_gauge_detectors",
"dem_difference",
"is_instr_in_dem",
"get_max_weight_hyperedge",
"disjoint_graphs",
]
210 changes: 210 additions & 0 deletions qec_util/dems/dems.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import stim
import networkx as nx

from .util import sorting_index


def remove_gauge_detectors(dem: stim.DetectorErrorModel) -> stim.DetectorErrorModel:
"""Remove the gauge detectors from a DEM."""
if not isinstance(dem, stim.DetectorErrorModel):
raise TypeError(f"'dem' is not a stim.DetectorErrorModel, but a {type(dem)}.")

new_dem = stim.DetectorErrorModel()
gauge_dets = set()

for dem_instr in dem.flattened():
if dem_instr.type != "error":
new_dem.append(dem_instr)

if dem_instr.args_copy() == [0.5]:
det = dem_instr.targets_copy()
if len(det) != 1:
raise ValueError("There exist 'composed' gauge detector: {dem_instr}.")
gauge_dets.add(det[0])
continue

if dem_instr.args_copy() != [0.5]:
if len([i for i in dem_instr.targets_copy() if i in gauge_dets]) != 0:
raise ValueError(
"A gauge detector is present in the following error:\n"
f"{dem_instr}\nGauge detectors = {gauge_dets}"
)
new_dem.append(dem_instr)

return new_dem


def dem_difference(
dem_1: stim.DetectorErrorModel, dem_2: stim.DetectorErrorModel
) -> tuple[stim.DetectorErrorModel, stim.DetectorErrorModel]:
"""Returns the the DEM error instructions in the first DEM that are not present
in the second DEM and vice versa. Note that this does not take into account
the decomposition of errors.
Parameters
----------
dem_1
First detector error model.
dem_2
Second detector error model.
Returns
-------
diff_1
DEM instructions present in ``dem_1`` that are not present in ``dem_2``.
diff_2
DEM instructions present in ``dem_2`` that are not present in ``dem_1``.
"""
if not isinstance(dem_1, stim.DetectorErrorModel):
raise TypeError(
f"'dem_1' is not a stim.DetectorErrorModel, but a {type(dem_1)}."
)
if not isinstance(dem_2, stim.DetectorErrorModel):
raise TypeError(
f"'dem_2' is not a stim.DetectorErrorModel, but a {type(dem_2)}."
)

dem_1_ordered = stim.DetectorErrorModel()
num_dets = dem_1.num_detectors
for dem_instr in dem_1.flattened():
if dem_instr.type != "error":
continue

# remove separators
targets = [t for t in dem_instr.targets_copy() if not t.is_separator()]

targets = sorted(targets, key=lambda x: sorting_index(x, num_dets))
prob = dem_instr.args_copy()[0]
dem_1_ordered.append("error", prob, targets)

dem_2_ordered = stim.DetectorErrorModel()
num_dets = dem_2.num_detectors
for dem_instr in dem_2.flattened():
if dem_instr.type != "error":
continue

# remove separators
targets = [t for t in dem_instr.targets_copy() if not t.is_separator()]

targets = sorted(targets, key=lambda x: sorting_index(x, num_dets))
prob = dem_instr.args_copy()[0]
dem_2_ordered.append("error", prob, targets)

diff_1 = stim.DetectorErrorModel()
for dem_instr in dem_1_ordered:
if dem_instr not in dem_2_ordered:
diff_1.append(dem_instr)

diff_2 = stim.DetectorErrorModel()
for dem_instr in dem_2_ordered:
if dem_instr not in dem_1_ordered:
diff_2.append(dem_instr)

return diff_1, diff_2


def is_instr_in_dem(
dem_instr: stim.DemInstruction, dem: stim.DetectorErrorModel
) -> bool:
"""Checks if the DEM error instruction and its undecomposed form are present
in the given DEM.
"""
if not isinstance(dem_instr, stim.DemInstruction):
raise TypeError(
f"'dem_instr' must be a stim.DemInstruction, but {type(dem_instr)} was given."
)
if dem_instr.type != "error":
raise TypeError(f"'dem_instr' is not an error, but a {dem_instr.type}.")
if not isinstance(dem, stim.DetectorErrorModel):
raise TypeError(
f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given."
)

num_dets = dem.num_detectors
prob = dem_instr.args_copy()[0]
targets = [t for t in dem_instr.targets_copy() if not t.is_separator()]
targets = sorted(targets, key=lambda x: sorting_index(x, num_dets))

for instr in dem.flattened():
if instr.type != "error":
continue
if instr.args_copy()[0] != prob:
continue

other_targets = [t for t in instr.targets_copy() if not t.is_separator()]
other_targets = sorted(other_targets, key=lambda x: sorting_index(x, num_dets))
if other_targets == targets:
return True

return False


def get_max_weight_hyperedge(
dem: stim.DetectorErrorModel,
) -> tuple[int, stim.DemInstruction]:
"""Return the weight and hyperedges corresponding to the max-weight hyperedge.
Parameters
----------
dem
Stim detector error model.
Returns
-------
weight
Weight of the max-weight hyperedge in ``dem``.
hyperedge
Hyperedge with the max-weight in ``dem``.
"""
if not isinstance(dem, stim.DetectorErrorModel):
raise TypeError(
f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given."
)

max_weight = 0
hyperedge = stim.DemInstruction(type="error", args=[0], targets=[])
for dem_instr in dem.flattened():
if dem_instr.type != "error":
continue

targets = dem_instr.targets_copy()
targets = [t for t in targets if t.is_relative_detector_id()]
if len(targets) > max_weight:
max_weight = len(targets)
hyperedge = dem_instr

return max_weight, hyperedge


def disjoint_graphs(dem: stim.DetectorErrorModel) -> list[list[int]]:
"""Return the nodes in the disjoint subgraphs that the DEM (or decoding
graph) can be split into."""
if not isinstance(dem, stim.DetectorErrorModel):
raise TypeError(
f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given."
)

# convert stim.DetectorErrorModel to nx.Graph
# to use the functionality 'nx.connected_components(G)'
g = nx.Graph()
for dem_instr in dem.flattened():
if dem_instr.type != "error":
continue

targets = dem_instr.targets_copy()
targets = [t.val for t in targets if t.is_relative_detector_id()]

if len(targets) == 1:
g.add_node(targets[0])

# hyperedges cannot be added to nx.Graph, but if we are just checking
# the number of disjoint graphs, we can transform the hyperedge to a
# sequence of edges which keeps the same connectiviy between nodes.
# For example, hyperedge (0,2,5,6) can be expressed as edges (0,2),
# (2,5) and (5,6).
for start, end in zip(targets[:-1], targets[1:]):
g.add_edge(start, end)

subgraphs = [list(c) for c in nx.connected_components(g)]

return subgraphs
26 changes: 26 additions & 0 deletions qec_util/dems/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
def sorting_index(t, num_dets) -> int:
"""Function to sort the logical and detector targets inside a DEM instruction.
Parameters
----------
t
stim.DemTarget.
num_dets
Number of detectors in the DEM.
Returns
-------
Sorting index associated with ``t``.
Notes
-----
``dem_instr1 == dem_instr2`` is only true if the argument is the same
and the targets are sorted also in the same way. For example
``"error(0.1) D0 D1"`` is different than ``"error(0.1) D1 D0"``.
"""
if t.is_logical_observable_id():
return num_dets + t.val
if t.is_relative_detector_id():
return t.val
else:
raise NotImplemented(f"{t} is not a logical or a detector.")
Loading

0 comments on commit 52492d4

Please sign in to comment.