diff --git a/CHANGELOG.md b/CHANGELOG.md index 9585916a..c83782fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Pattern**: Added the `depth` attribute into `Pattern`, which represents the depth of parallel execution. +- **Pattern**: Added pattern resource/throughput metrics (`active_volume`, `volume`, `idle_times`, `throughput`). + - **Scheduler Integration**: Enhanced qompile() to support temporal scheduling with TICK commands - Added `scheduler` parameter to qompile() for custom scheduling - Automatically inserts TICK commands between time slices @@ -33,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - This optimization is equivalent to the operation of the same name in the measurement calculus, and makes the measurement pattern as parallel as possible. - The optimization is now self-contained within the feedforward module. +- **Feedforward Optimization**: Added `pauli_simplification()` to remove redundant Pauli corrections in correction maps when measuring in Pauli bases. + ### Changed - **Pattern**: Updated command sequence generation to support TICK commands diff --git a/graphqomb/feedforward.py b/graphqomb/feedforward.py index fdb0102e..ce4ccd65 100644 --- a/graphqomb/feedforward.py +++ b/graphqomb/feedforward.py @@ -19,7 +19,7 @@ import typing_extensions -from graphqomb.common import Plane +from graphqomb.common import Axis, Plane, determine_pauli_axis from graphqomb.graphstate import BaseGraphState, odd_neighbors if sys.version_info >= (3, 10): @@ -277,3 +277,61 @@ def propagate_correction_map( # noqa: C901, PLR0912 new_zflow[parent] ^= {child_z} return new_xflow, new_zflow + + +def pauli_simplification( # noqa: C901, PLR0912 + graph: BaseGraphState, + xflow: Mapping[int, AbstractSet[int]], + zflow: Mapping[int, AbstractSet[int]] | None = None, +) -> tuple[dict[int, set[int]], dict[int, set[int]]]: + r"""Simplify the correction maps by removing redundant Pauli corrections. + + Parameters + ---------- + graph : `BaseGraphState` + Underlying graph state. + xflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] + Correction map for X. + zflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None` + Correction map for Z. If `None`, it is generated from xflow by odd neighbors. + + Returns + ------- + `tuple`\[`dict`\[`int`, `set`\[`int`\]\], `dict`\[`int`, `set`\[`int`\]\]] + Updated correction maps for X and Z after simplification. + """ + if zflow is None: + zflow = {node: odd_neighbors(xflow[node], graph) - {node} for node in xflow} + + new_xflow = {k: set(vs) for k, vs in xflow.items()} + new_zflow = {k: set(vs) for k, vs in zflow.items()} + + inv_xflow: dict[int, set[int]] = {} + inv_zflow: dict[int, set[int]] = {} + for k, vs in xflow.items(): + for v in vs: + inv_xflow.setdefault(v, set()).add(k) + for k, vs in zflow.items(): + for v in vs: + inv_zflow.setdefault(v, set()).add(k) + + for node in graph.physical_nodes - graph.output_node_indices.keys(): + meas_basis = graph.meas_bases.get(node) + if meas_basis is None: + continue + meas_axis = determine_pauli_axis(meas_basis) + if meas_axis is None: + continue + + if meas_axis == Axis.X: + for parent in inv_xflow.get(node, set()): + new_xflow[parent] -= {node} + elif meas_axis == Axis.Z: + for parent in inv_zflow.get(node, set()): + new_zflow[parent] -= {node} + elif meas_axis == Axis.Y: + for parent in inv_xflow.get(node, set()) & inv_zflow.get(node, set()): + new_xflow[parent] -= {node} + new_zflow[parent] -= {node} + + return new_xflow, new_zflow diff --git a/graphqomb/pattern.py b/graphqomb/pattern.py index b9718da4..cf98f766 100644 --- a/graphqomb/pattern.py +++ b/graphqomb/pattern.py @@ -101,6 +101,76 @@ def depth(self) -> int: """ return sum(1 for cmd in self.commands if isinstance(cmd, TICK)) + @property + def active_volume(self) -> int: + """Calculate tha active volume, summation of space for each timeslice. + + Returns + ------- + `int` + Active volume of the pattern + """ + return sum(self.space) + + @property + def volume(self) -> int: + """Calculate the volume, defined as max_space * depth. + + Returns + ------- + `int` + Volume of the pattern + """ + return self.max_space * self.depth + + @property + def idle_times(self) -> dict[int, int]: + r"""Calculate the idle times for each qubit in the pattern. + + Returns + ------- + `dict`\[`int`, `int`\] + A dictionary mapping each qubit index to its idle time. + """ + idle_times: dict[int, int] = {} + prepared_time: dict[int, int] = dict.fromkeys(self.input_node_indices, 0) + + current_time = 0 + for cmd in self.commands: + if isinstance(cmd, TICK): + current_time += 1 + elif isinstance(cmd, N): + prepared_time[cmd.node] = current_time + elif isinstance(cmd, M): + idle_times[cmd.node] = current_time - prepared_time[cmd.node] + + for output_node in self.output_node_indices: + if output_node in prepared_time: + idle_times[output_node] = current_time - prepared_time[output_node] + + return idle_times + + @property + def throughput(self) -> float: + """Calculate the number of measurements per TICK in the pattern. + + Returns + ------- + `float` + Number of measurements per TICK + + Raises + ------ + ValueError + If the pattern has zero depth (no TICK commands) + """ + num_measurements = sum(1 for cmd in self.commands if isinstance(cmd, M)) + num_ticks = self.depth + if num_ticks == 0: + msg = "Cannot calculate throughput for a pattern with zero depth (no TICK commands)." + raise ValueError(msg) + return num_measurements / num_ticks + def is_runnable(pattern: Pattern) -> None: """Check if the pattern is runnable. diff --git a/tests/test_feedforward.py b/tests/test_feedforward.py index 8c68e783..b49beb64 100644 --- a/tests/test_feedforward.py +++ b/tests/test_feedforward.py @@ -4,13 +4,14 @@ import pytest from graphqomb.circuit import MBQCCircuit, circuit2graph -from graphqomb.common import Plane, PlannerMeasBasis +from graphqomb.common import Axis, AxisMeasBasis, Plane, PlannerMeasBasis, Sign from graphqomb.feedforward import ( _is_flow, _is_gflow, check_dag, check_flow, dag_from_flow, + pauli_simplification, propagate_correction_map, signal_shifting, ) @@ -336,3 +337,217 @@ def test_signal_shifting_circuit_integration() -> None: # Verify that the results match (inner product should be close to 1) assert np.isclose(np.abs(inner_product), 1.0) + + +# Tests for pauli_simplification + + +def test_pauli_simplification_x_axis_removes_x_correction() -> None: + """Test that X-axis measurement removes X corrections from the flow.""" + # Create a 3-node graph: parent -> target -> output + graphstate = GraphState() + parent = graphstate.add_physical_node() + target = graphstate.add_physical_node() + output = graphstate.add_physical_node() + graphstate.add_physical_edge(parent, target) + graphstate.add_physical_edge(target, output) + + # Set X-axis measurement basis for target + graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.X, Sign.PLUS)) + graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.X, Sign.PLUS)) + + graphstate.register_output(output, 0) + + # Define flows where parent's X correction depends on target + xflow: dict[int, set[int]] = {parent: {target}, target: {output}} + zflow: dict[int, set[int]] = {parent: {target}, target: {output}} + + new_xflow, new_zflow = pauli_simplification(graphstate, xflow, zflow) + + # X-axis measurement should remove target from parent's X corrections + assert target not in new_xflow[parent] + # Z corrections should remain unchanged + assert target in new_zflow[parent] + + +def test_pauli_simplification_z_axis_removes_z_correction() -> None: + """Test that Z-axis measurement removes Z corrections from the flow.""" + # Create a 3-node graph: parent -> target -> output + graphstate = GraphState() + parent = graphstate.add_physical_node() + target = graphstate.add_physical_node() + output = graphstate.add_physical_node() + graphstate.add_physical_edge(parent, target) + graphstate.add_physical_edge(target, output) + + # Set Z-axis measurement basis for target + graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.Z, Sign.PLUS)) + graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.Z, Sign.PLUS)) + + graphstate.register_output(output, 0) + + # Define flows where parent's Z correction depends on target + xflow: dict[int, set[int]] = {parent: {target}, target: {output}} + zflow: dict[int, set[int]] = {parent: {target}, target: {output}} + + new_xflow, new_zflow = pauli_simplification(graphstate, xflow, zflow) + + # Z-axis measurement should remove target from parent's Z corrections + assert target not in new_zflow[parent] + # X corrections should remain unchanged + assert target in new_xflow[parent] + + +def test_pauli_simplification_y_axis_removes_both_corrections() -> None: + """Test that Y-axis measurement removes both X and Z corrections from the flow.""" + # Create a 3-node graph: parent -> target -> output + graphstate = GraphState() + parent = graphstate.add_physical_node() + target = graphstate.add_physical_node() + output = graphstate.add_physical_node() + graphstate.add_physical_edge(parent, target) + graphstate.add_physical_edge(target, output) + + # Set Y-axis measurement basis for target + graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.Y, Sign.PLUS)) + graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.X, Sign.PLUS)) + + graphstate.register_output(output, 0) + + # Define flows where parent's corrections depend on target + xflow: dict[int, set[int]] = {parent: {target}, target: {output}} + zflow: dict[int, set[int]] = {parent: {target}, target: {output}} + + new_xflow, new_zflow = pauli_simplification(graphstate, xflow, zflow) + + # Y-axis measurement should remove target from both X and Z corrections + assert target not in new_xflow[parent] + assert target not in new_zflow[parent] + + +def test_pauli_simplification_y_axis_skip() -> None: + """Test that Y-axis measurement skips if not both corrections are present.""" + # Create a 3-node graph: parent -> target -> output + graphstate = GraphState() + parent = graphstate.add_physical_node() + target = graphstate.add_physical_node() + output = graphstate.add_physical_node() + graphstate.add_physical_edge(parent, target) + graphstate.add_physical_edge(target, output) + + # Set Y-axis measurement basis for target + graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.Y, Sign.PLUS)) + graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.X, Sign.PLUS)) + + graphstate.register_output(output, 0) + + # Define flows where parent's corrections depend on target + xflow: dict[int, set[int]] = {parent: {target}, target: {output}} + zflow: dict[int, set[int]] = {parent: {output}, target: {output}} + + # Skip removing X correction + new_xflow, _ = pauli_simplification(graphstate, xflow, zflow) + + assert target in new_xflow[parent] # X correction remains + + xflow = {parent: {output}, target: {output}} + zflow = {parent: {target}, target: {output}} + # Skip removing Z correction + _, new_zflow = pauli_simplification(graphstate, xflow, zflow) + + assert target in new_zflow[parent] # Z correction remains + + +def test_pauli_simplification_non_pauli_leaves_unchanged() -> None: + """Test that non-Pauli measurement angles leave corrections unchanged.""" + # Create a 3-node graph: parent -> target -> output + graphstate = GraphState() + parent = graphstate.add_physical_node() + target = graphstate.add_physical_node() + output = graphstate.add_physical_node() + graphstate.add_physical_edge(parent, target) + graphstate.add_physical_edge(target, output) + + # Set non-Pauli measurement basis for target (XY plane, angle=pi/4) + graphstate.assign_meas_basis(target, PlannerMeasBasis(Plane.XY, 0.25 * np.pi)) + graphstate.assign_meas_basis(parent, PlannerMeasBasis(Plane.XY, 0.0)) + + graphstate.register_output(output, 0) + + # Define flows + xflow: dict[int, set[int]] = {parent: {target}, target: {output}} + zflow: dict[int, set[int]] = {parent: {target}, target: {output}} + + new_xflow, new_zflow = pauli_simplification(graphstate, xflow, zflow) + + # Non-Pauli angle should leave corrections unchanged + assert target in new_xflow[parent] + assert target in new_zflow[parent] + + +def test_pauli_simplification_preserves_original_flows() -> None: + """Test that original xflow and zflow are not modified.""" + # Create a 3-node graph: parent -> target -> output + graphstate = GraphState() + parent = graphstate.add_physical_node() + target = graphstate.add_physical_node() + output = graphstate.add_physical_node() + graphstate.add_physical_edge(parent, target) + graphstate.add_physical_edge(target, output) + + # Set X-axis measurement basis for target + graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.X, Sign.PLUS)) + graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.X, Sign.PLUS)) + + graphstate.register_output(output, 0) + + # Define flows + xflow: dict[int, set[int]] = {parent: {target}, target: {output}} + zflow: dict[int, set[int]] = {parent: {target}, target: {output}} + + # Store original values + original_xflow_parent = set(xflow[parent]) + original_zflow_parent = set(zflow[parent]) + + pauli_simplification(graphstate, xflow, zflow) + + # Original flows should be unchanged + assert xflow[parent] == original_xflow_parent + assert zflow[parent] == original_zflow_parent + + +def test_pauli_simplification_circuit_integration() -> None: + """Test pauli_simplification integration with circuit compilation and simulation.""" + # Create a quantum circuit (using j for rotations, cz for entanglement) + circuit = MBQCCircuit(2) + circuit.j(0, 0.5 * np.pi) # Rotation on qubit 0 + circuit.cz(0, 1) + circuit.j(1, 0.25 * np.pi) # Rotation on qubit 1 + + # Convert circuit to graph and gflow + graphstate, gflow = circuit2graph(circuit) + + # Apply pauli simplification + xflow, zflow = pauli_simplification(graphstate, gflow) + + # Compile to pattern + pattern = qompile(graphstate, xflow, zflow) + + # Verify pattern is runnable + assert pattern is not None + assert pattern.max_space >= 0 + + # Simulate the pattern + simulator = PatternSimulator(pattern, SimulatorBackend.StateVector) + simulator.simulate() + state = simulator.state + statevec = state.state() + + # Compare with circuit simulator + circ_simulator = CircuitSimulator(circuit, SimulatorBackend.StateVector) + circ_simulator.simulate() + circ_state = circ_simulator.state.state() + inner_product = np.vdot(statevec, circ_state) + + # Verify that the results match (inner product should be close to 1) + assert np.isclose(np.abs(inner_product), 1.0) diff --git a/tests/test_pattern.py b/tests/test_pattern.py index b24dcef6..c3f844c2 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -80,3 +80,254 @@ def test_pattern_depth_is_zero_without_ticks( ) assert pattern.depth == 0 + + +# Tests for active_volume + + +def test_active_volume_sums_space_list( + pattern_components: tuple[dict[int, int], dict[int, int], PauliFrame, list[int]], +) -> None: + """Test that active_volume equals sum of space list.""" + input_nodes, output_nodes, pauli_frame, nodes = pattern_components + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + commands = ( + N(node=nodes[1]), + TICK(), + E(nodes=(nodes[0], nodes[1])), + M(node=nodes[1], meas_basis=meas_basis), + TICK(), + ) + + pattern = Pattern( + input_node_indices=input_nodes, + output_node_indices=output_nodes, + commands=commands, + pauli_frame=pauli_frame, + ) + + assert pattern.active_volume == sum(pattern.space) + + +def test_active_volume_with_multiple_ticks( + pattern_components: tuple[dict[int, int], dict[int, int], PauliFrame, list[int]], +) -> None: + """Test active_volume correctly sums space across multiple time slices.""" + input_nodes, output_nodes, pauli_frame, nodes = pattern_components + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + # Pattern: 1 input, add 2 nodes, measure 1, add 1, measure 2 + # space should be [1, 3, 2, 3, 1] + commands = ( + N(node=nodes[1]), + N(node=nodes[2]), + TICK(), + M(node=nodes[1], meas_basis=meas_basis), + TICK(), + N(node=3), # New node + TICK(), + M(node=nodes[2], meas_basis=meas_basis), + M(node=3, meas_basis=meas_basis), + TICK(), + ) + + pattern = Pattern( + input_node_indices=input_nodes, + output_node_indices=output_nodes, + commands=commands, + pauli_frame=pauli_frame, + ) + + # space = [1, 3, 2, 3, 1] -> active_volume = 10 + assert pattern.active_volume == sum(pattern.space) + + +# Tests for volume + + +def test_volume_equals_max_space_times_depth( + pattern_components: tuple[dict[int, int], dict[int, int], PauliFrame, list[int]], +) -> None: + """Test that volume equals max_space times depth.""" + input_nodes, output_nodes, pauli_frame, nodes = pattern_components + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + commands = ( + N(node=nodes[1]), + TICK(), + M(node=nodes[1], meas_basis=meas_basis), + TICK(), + ) + + pattern = Pattern( + input_node_indices=input_nodes, + output_node_indices=output_nodes, + commands=commands, + pauli_frame=pauli_frame, + ) + + assert pattern.volume == pattern.max_space * pattern.depth + + +def test_volume_is_zero_without_ticks( + pattern_components: tuple[dict[int, int], dict[int, int], PauliFrame, list[int]], +) -> None: + """Test that volume is zero when there are no TICK commands.""" + input_nodes, output_nodes, pauli_frame, nodes = pattern_components + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + commands = ( + N(node=nodes[1]), + M(node=nodes[1], meas_basis=meas_basis), + ) + + pattern = Pattern( + input_node_indices=input_nodes, + output_node_indices=output_nodes, + commands=commands, + pauli_frame=pauli_frame, + ) + + assert pattern.depth == 0 + assert pattern.volume == 0 + + +# Tests for idle_times + + +def test_idle_times_returns_dict_for_measured_qubits( + pattern_components: tuple[dict[int, int], dict[int, int], PauliFrame, list[int]], +) -> None: + """Test that idle_times returns a dict with entries for measured qubits.""" + input_nodes, output_nodes, pauli_frame, nodes = pattern_components + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + commands = ( + N(node=nodes[1]), + TICK(), + M(node=nodes[1], meas_basis=meas_basis), + TICK(), + ) + + pattern = Pattern( + input_node_indices=input_nodes, + output_node_indices=output_nodes, + commands=commands, + pauli_frame=pauli_frame, + ) + + idle_times = pattern.idle_times + # Should include measured node + assert nodes[1] in idle_times + # idle_time for nodes[1]: prepared at time 0, measured at time 1 -> idle = 1 + assert idle_times[nodes[1]] == 1 + + +def test_idle_times_input_nodes_use_zero_baseline( + pattern_components: tuple[dict[int, int], dict[int, int], PauliFrame, list[int]], +) -> None: + """Test that input nodes have idle time starting from time 0.""" + input_nodes, output_nodes, pauli_frame, nodes = pattern_components + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + # Input node measured at time 1 (after 1 TICK) + # prepared_time = 0, current_time = 1 -> idle_time = 1 + commands = ( + TICK(), + M(node=nodes[0], meas_basis=meas_basis), + TICK(), + ) + + pattern = Pattern( + input_node_indices=input_nodes, + output_node_indices=output_nodes, + commands=commands, + pauli_frame=pauli_frame, + ) + + idle_times = pattern.idle_times + # Input node prepared at 0, measured after 1 TICK (time=1) + # idle_time = 1 - 0 = 1 + assert idle_times[nodes[0]] == 1 + + +def test_idle_times_output_nodes_included_when_prepared() -> None: + """Test that output nodes are included in idle_times when they are prepared.""" + # Create a graph where output node is also an input node + graph = GraphState() + input_node = graph.add_physical_node() + output_node = graph.add_physical_node() + + graph.register_input(input_node, 0) + graph.register_input(output_node, 1) # Output node is also an input + graph.register_output(output_node, 0) + + pauli_frame = PauliFrame(graph, xflow={}, zflow={}) + + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + commands = ( + TICK(), + M(node=input_node, meas_basis=meas_basis), + TICK(), + TICK(), + ) + + pattern = Pattern( + input_node_indices=graph.input_node_indices, + output_node_indices=graph.output_node_indices, + commands=commands, + pauli_frame=pauli_frame, + ) + + idle_times = pattern.idle_times + # Output node prepared at 0, final time is 3 -> idle = 3 + assert output_node in idle_times + assert idle_times[output_node] == 3 + + +# Tests for throughput + + +def test_throughput_calculates_measurements_per_tick( + pattern_components: tuple[dict[int, int], dict[int, int], PauliFrame, list[int]], +) -> None: + """Test that throughput correctly calculates measurements per tick.""" + input_nodes, output_nodes, pauli_frame, nodes = pattern_components + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + # 2 measurements, 4 TICKs -> throughput = 0.5 + commands = ( + N(node=nodes[1]), + TICK(), + M(node=nodes[1], meas_basis=meas_basis), + TICK(), + N(node=3), + TICK(), + M(node=3, meas_basis=meas_basis), + TICK(), + ) + + pattern = Pattern( + input_node_indices=input_nodes, + output_node_indices=output_nodes, + commands=commands, + pauli_frame=pauli_frame, + ) + + assert pattern.throughput == 2 / 4 + + +def test_throughput_raises_for_zero_depth( + pattern_components: tuple[dict[int, int], dict[int, int], PauliFrame, list[int]], +) -> None: + """Test that throughput raises ValueError when pattern has no TICKs.""" + input_nodes, output_nodes, pauli_frame, nodes = pattern_components + meas_basis = PlannerMeasBasis(Plane.XY, 0.0) + commands = ( + N(node=nodes[1]), + M(node=nodes[1], meas_basis=meas_basis), + ) + + pattern = Pattern( + input_node_indices=input_nodes, + output_node_indices=output_nodes, + commands=commands, + pauli_frame=pauli_frame, + ) + + with pytest.raises(ValueError, match="Cannot calculate throughput"): + _ = pattern.throughput