Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- **Pattern**: Added the `depth` attribute into `Pattern`, which represents the depth of parallel execution.

- **Pattern**: Added pattern resource/throughput metrics (`active_volume`, `volume`, `idle_times`, `throughput`).

- **Scheduler Integration**: Enhanced qompile() to support temporal scheduling with TICK commands
- Added `scheduler` parameter to qompile() for custom scheduling
- Automatically inserts TICK commands between time slices
Expand All @@ -33,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- This optimization is equivalent to the operation of the same name in the measurement calculus, and makes the measurement pattern as parallel as possible.
- The optimization is now self-contained within the feedforward module.

- **Feedforward Optimization**: Added `pauli_simplification()` to remove redundant Pauli corrections in correction maps when measuring in Pauli bases.

### Changed

- **Pattern**: Updated command sequence generation to support TICK commands
Expand Down
60 changes: 59 additions & 1 deletion graphqomb/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import typing_extensions

from graphqomb.common import Plane
from graphqomb.common import Axis, Plane, determine_pauli_axis
from graphqomb.graphstate import BaseGraphState, odd_neighbors

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -277,3 +277,61 @@ def propagate_correction_map( # noqa: C901, PLR0912
new_zflow[parent] ^= {child_z}

return new_xflow, new_zflow


def pauli_simplification( # noqa: C901, PLR0912
graph: BaseGraphState,
xflow: Mapping[int, AbstractSet[int]],
zflow: Mapping[int, AbstractSet[int]] | None = None,
) -> tuple[dict[int, set[int]], dict[int, set[int]]]:
r"""Simplify the correction maps by removing redundant Pauli corrections.

Parameters
----------
graph : `BaseGraphState`
Underlying graph state.
xflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
Correction map for X.
zflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None`
Correction map for Z. If `None`, it is generated from xflow by odd neighbors.

Returns
-------
`tuple`\[`dict`\[`int`, `set`\[`int`\]\], `dict`\[`int`, `set`\[`int`\]\]]
Updated correction maps for X and Z after simplification.
"""
if zflow is None:
zflow = {node: odd_neighbors(xflow[node], graph) - {node} for node in xflow}

new_xflow = {k: set(vs) for k, vs in xflow.items()}
new_zflow = {k: set(vs) for k, vs in zflow.items()}

inv_xflow: dict[int, set[int]] = {}
inv_zflow: dict[int, set[int]] = {}
for k, vs in xflow.items():
for v in vs:
inv_xflow.setdefault(v, set()).add(k)
for k, vs in zflow.items():
for v in vs:
inv_zflow.setdefault(v, set()).add(k)

for node in graph.physical_nodes - graph.output_node_indices.keys():
meas_basis = graph.meas_bases.get(node)
if meas_basis is None:
continue
meas_axis = determine_pauli_axis(meas_basis)
if meas_axis is None:
continue

if meas_axis == Axis.X:
for parent in inv_xflow.get(node, set()):
new_xflow[parent] -= {node}
elif meas_axis == Axis.Z:
for parent in inv_zflow.get(node, set()):
new_zflow[parent] -= {node}
elif meas_axis == Axis.Y:
for parent in inv_xflow.get(node, set()) & inv_zflow.get(node, set()):
new_xflow[parent] -= {node}
new_zflow[parent] -= {node}

return new_xflow, new_zflow
70 changes: 70 additions & 0 deletions graphqomb/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,76 @@ def depth(self) -> int:
"""
return sum(1 for cmd in self.commands if isinstance(cmd, TICK))

@property
def active_volume(self) -> int:
"""Calculate tha active volume, summation of space for each timeslice.

Returns
-------
`int`
Active volume of the pattern
"""
return sum(self.space)

@property
def volume(self) -> int:
"""Calculate the volume, defined as max_space * depth.

Returns
-------
`int`
Volume of the pattern
"""
return self.max_space * self.depth

@property
def idle_times(self) -> dict[int, int]:
r"""Calculate the idle times for each qubit in the pattern.

Returns
-------
`dict`\[`int`, `int`\]
A dictionary mapping each qubit index to its idle time.
"""
idle_times: dict[int, int] = {}
prepared_time: dict[int, int] = dict.fromkeys(self.input_node_indices, 0)

current_time = 0
for cmd in self.commands:
if isinstance(cmd, TICK):
current_time += 1
elif isinstance(cmd, N):
prepared_time[cmd.node] = current_time
elif isinstance(cmd, M):
idle_times[cmd.node] = current_time - prepared_time[cmd.node]

for output_node in self.output_node_indices:
if output_node in prepared_time:
idle_times[output_node] = current_time - prepared_time[output_node]

return idle_times

@property
def throughput(self) -> float:
"""Calculate the number of measurements per TICK in the pattern.

Returns
-------
`float`
Number of measurements per TICK

Raises
------
ValueError
If the pattern has zero depth (no TICK commands)
"""
num_measurements = sum(1 for cmd in self.commands if isinstance(cmd, M))
num_ticks = self.depth
if num_ticks == 0:
msg = "Cannot calculate throughput for a pattern with zero depth (no TICK commands)."
raise ValueError(msg)
return num_measurements / num_ticks


def is_runnable(pattern: Pattern) -> None:
"""Check if the pattern is runnable.
Expand Down
217 changes: 216 additions & 1 deletion tests/test_feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import pytest

from graphqomb.circuit import MBQCCircuit, circuit2graph
from graphqomb.common import Plane, PlannerMeasBasis
from graphqomb.common import Axis, AxisMeasBasis, Plane, PlannerMeasBasis, Sign
from graphqomb.feedforward import (
_is_flow,
_is_gflow,
check_dag,
check_flow,
dag_from_flow,
pauli_simplification,
propagate_correction_map,
signal_shifting,
)
Expand Down Expand Up @@ -336,3 +337,217 @@ def test_signal_shifting_circuit_integration() -> None:

# Verify that the results match (inner product should be close to 1)
assert np.isclose(np.abs(inner_product), 1.0)


# Tests for pauli_simplification


def test_pauli_simplification_x_axis_removes_x_correction() -> None:
"""Test that X-axis measurement removes X corrections from the flow."""
# Create a 3-node graph: parent -> target -> output
graphstate = GraphState()
parent = graphstate.add_physical_node()
target = graphstate.add_physical_node()
output = graphstate.add_physical_node()
graphstate.add_physical_edge(parent, target)
graphstate.add_physical_edge(target, output)

# Set X-axis measurement basis for target
graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.X, Sign.PLUS))
graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.X, Sign.PLUS))

graphstate.register_output(output, 0)

# Define flows where parent's X correction depends on target
xflow: dict[int, set[int]] = {parent: {target}, target: {output}}
zflow: dict[int, set[int]] = {parent: {target}, target: {output}}

new_xflow, new_zflow = pauli_simplification(graphstate, xflow, zflow)

# X-axis measurement should remove target from parent's X corrections
assert target not in new_xflow[parent]
# Z corrections should remain unchanged
assert target in new_zflow[parent]


def test_pauli_simplification_z_axis_removes_z_correction() -> None:
"""Test that Z-axis measurement removes Z corrections from the flow."""
# Create a 3-node graph: parent -> target -> output
graphstate = GraphState()
parent = graphstate.add_physical_node()
target = graphstate.add_physical_node()
output = graphstate.add_physical_node()
graphstate.add_physical_edge(parent, target)
graphstate.add_physical_edge(target, output)

# Set Z-axis measurement basis for target
graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.Z, Sign.PLUS))
graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.Z, Sign.PLUS))

graphstate.register_output(output, 0)

# Define flows where parent's Z correction depends on target
xflow: dict[int, set[int]] = {parent: {target}, target: {output}}
zflow: dict[int, set[int]] = {parent: {target}, target: {output}}

new_xflow, new_zflow = pauli_simplification(graphstate, xflow, zflow)

# Z-axis measurement should remove target from parent's Z corrections
assert target not in new_zflow[parent]
# X corrections should remain unchanged
assert target in new_xflow[parent]


def test_pauli_simplification_y_axis_removes_both_corrections() -> None:
"""Test that Y-axis measurement removes both X and Z corrections from the flow."""
# Create a 3-node graph: parent -> target -> output
graphstate = GraphState()
parent = graphstate.add_physical_node()
target = graphstate.add_physical_node()
output = graphstate.add_physical_node()
graphstate.add_physical_edge(parent, target)
graphstate.add_physical_edge(target, output)

# Set Y-axis measurement basis for target
graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.Y, Sign.PLUS))
graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.X, Sign.PLUS))

graphstate.register_output(output, 0)

# Define flows where parent's corrections depend on target
xflow: dict[int, set[int]] = {parent: {target}, target: {output}}
zflow: dict[int, set[int]] = {parent: {target}, target: {output}}

new_xflow, new_zflow = pauli_simplification(graphstate, xflow, zflow)

# Y-axis measurement should remove target from both X and Z corrections
assert target not in new_xflow[parent]
assert target not in new_zflow[parent]


def test_pauli_simplification_y_axis_skip() -> None:
"""Test that Y-axis measurement skips if not both corrections are present."""
# Create a 3-node graph: parent -> target -> output
graphstate = GraphState()
parent = graphstate.add_physical_node()
target = graphstate.add_physical_node()
output = graphstate.add_physical_node()
graphstate.add_physical_edge(parent, target)
graphstate.add_physical_edge(target, output)

# Set Y-axis measurement basis for target
graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.Y, Sign.PLUS))
graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.X, Sign.PLUS))

graphstate.register_output(output, 0)

# Define flows where parent's corrections depend on target
xflow: dict[int, set[int]] = {parent: {target}, target: {output}}
zflow: dict[int, set[int]] = {parent: {output}, target: {output}}

# Skip removing X correction
new_xflow, _ = pauli_simplification(graphstate, xflow, zflow)

assert target in new_xflow[parent] # X correction remains

xflow = {parent: {output}, target: {output}}
zflow = {parent: {target}, target: {output}}
# Skip removing Z correction
_, new_zflow = pauli_simplification(graphstate, xflow, zflow)

assert target in new_zflow[parent] # Z correction remains


def test_pauli_simplification_non_pauli_leaves_unchanged() -> None:
"""Test that non-Pauli measurement angles leave corrections unchanged."""
# Create a 3-node graph: parent -> target -> output
graphstate = GraphState()
parent = graphstate.add_physical_node()
target = graphstate.add_physical_node()
output = graphstate.add_physical_node()
graphstate.add_physical_edge(parent, target)
graphstate.add_physical_edge(target, output)

# Set non-Pauli measurement basis for target (XY plane, angle=pi/4)
graphstate.assign_meas_basis(target, PlannerMeasBasis(Plane.XY, 0.25 * np.pi))
graphstate.assign_meas_basis(parent, PlannerMeasBasis(Plane.XY, 0.0))

graphstate.register_output(output, 0)

# Define flows
xflow: dict[int, set[int]] = {parent: {target}, target: {output}}
zflow: dict[int, set[int]] = {parent: {target}, target: {output}}

new_xflow, new_zflow = pauli_simplification(graphstate, xflow, zflow)

# Non-Pauli angle should leave corrections unchanged
assert target in new_xflow[parent]
assert target in new_zflow[parent]


def test_pauli_simplification_preserves_original_flows() -> None:
"""Test that original xflow and zflow are not modified."""
# Create a 3-node graph: parent -> target -> output
graphstate = GraphState()
parent = graphstate.add_physical_node()
target = graphstate.add_physical_node()
output = graphstate.add_physical_node()
graphstate.add_physical_edge(parent, target)
graphstate.add_physical_edge(target, output)

# Set X-axis measurement basis for target
graphstate.assign_meas_basis(target, AxisMeasBasis(Axis.X, Sign.PLUS))
graphstate.assign_meas_basis(parent, AxisMeasBasis(Axis.X, Sign.PLUS))

graphstate.register_output(output, 0)

# Define flows
xflow: dict[int, set[int]] = {parent: {target}, target: {output}}
zflow: dict[int, set[int]] = {parent: {target}, target: {output}}

# Store original values
original_xflow_parent = set(xflow[parent])
original_zflow_parent = set(zflow[parent])

pauli_simplification(graphstate, xflow, zflow)

# Original flows should be unchanged
assert xflow[parent] == original_xflow_parent
assert zflow[parent] == original_zflow_parent


def test_pauli_simplification_circuit_integration() -> None:
"""Test pauli_simplification integration with circuit compilation and simulation."""
# Create a quantum circuit (using j for rotations, cz for entanglement)
circuit = MBQCCircuit(2)
circuit.j(0, 0.5 * np.pi) # Rotation on qubit 0
circuit.cz(0, 1)
circuit.j(1, 0.25 * np.pi) # Rotation on qubit 1

# Convert circuit to graph and gflow
graphstate, gflow = circuit2graph(circuit)

# Apply pauli simplification
xflow, zflow = pauli_simplification(graphstate, gflow)

# Compile to pattern
pattern = qompile(graphstate, xflow, zflow)

# Verify pattern is runnable
assert pattern is not None
assert pattern.max_space >= 0

# Simulate the pattern
simulator = PatternSimulator(pattern, SimulatorBackend.StateVector)
simulator.simulate()
state = simulator.state
statevec = state.state()

# Compare with circuit simulator
circ_simulator = CircuitSimulator(circuit, SimulatorBackend.StateVector)
circ_simulator.simulate()
circ_state = circ_simulator.state.state()
inner_product = np.vdot(statevec, circ_state)

# Verify that the results match (inner product should be close to 1)
assert np.isclose(np.abs(inner_product), 1.0)
Loading