diff --git a/docs/source/graphstate.rst b/docs/source/graphstate.rst new file mode 100644 index 000000000..24bf72e6c --- /dev/null +++ b/docs/source/graphstate.rst @@ -0,0 +1,26 @@ +GraphState +========== + +:mod:`graphix_zx.graphstate` module ++++++++++++++++++++++++++++++++++++ + +.. automodule:: graphix_zx.graphstate + +Graph State Classes +------------------- + +.. autoclass:: graphix_zx.graphstate.BaseGraphState + :members: + :member-order: bysource + +.. autoclass:: graphix_zx.graphstate.GraphState + :members: + :show-inheritance: + :member-order: bysource + +Functions +--------- + +.. autofunction:: graphix_zx.graphstate.compose_sequentially +.. autofunction:: graphix_zx.graphstate.compose_in_parallel +.. autofunction:: graphix_zx.graphstate.bipartite_edges diff --git a/docs/source/references.rst b/docs/source/references.rst index 9a6dcb462..22d9dbe22 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -7,3 +7,4 @@ Module reference common euler matrix + graphstate diff --git a/graphix_zx/graphstate.py b/graphix_zx/graphstate.py new file mode 100644 index 000000000..9024d0473 --- /dev/null +++ b/graphix_zx/graphstate.py @@ -0,0 +1,802 @@ +"""Graph State classes for Measurement-based Quantum Computing. + +This module provides: + +- `BaseGraphState`: Abstract base class for Graph State. +- `GraphState`: Minimal implementation of Graph State. +- `compose_sequentially`: Function to compose two graph states sequentially. +- `compose_in_parallel`: Function to compose two graph states in parallel. +- `bipartite_edges`: Function to create a complete bipartite graph between two sets of nodes. +""" + +from __future__ import annotations + +import operator +import weakref +from abc import ABC, abstractmethod +from collections.abc import Iterator, MutableMapping +from itertools import product +from types import MappingProxyType +from typing import TYPE_CHECKING + +import typing_extensions + +from graphix_zx.common import MeasBasis, Plane, PlannerMeasBasis +from graphix_zx.euler import update_lc_basis, update_lc_lc + +if TYPE_CHECKING: + from collections.abc import Set as AbstractSet + + from graphix_zx.euler import LocalClifford + + +class _MeasBasesDict(MutableMapping): + _owner: BaseGraphState + _field_name: str + _store: dict[int, MeasBasis] + + def __init__(self, owner: BaseGraphState, field_name: str) -> None: + self._owner = weakref.proxy(owner) + self._field_name = field_name + self._store = {} + + def __getitem__(self, node: int) -> MeasBasis: + return self._store[node] + + def __iter__(self) -> Iterator[int]: + return iter(self._store) + + def __len__(self) -> int: + return len(self._store) + + def __delitem__(self, node: int) -> None: + del self._store[node] + + def __setitem__(self, node: int, meas_basis: MeasBasis) -> None: + if node not in self._owner.physical_nodes: + msg = f"Node {node} does not exist in the graph state." + raise ValueError(msg) + if node in self._owner.output_node_indices: + msg = "Cannot set measurement basis for output node." + raise ValueError(msg) + self._store[node] = meas_basis + + +class MeasBasesField: + def __set_name__(self, owner, name: str) -> None: + self._private_name = f"_{name}" + + def __get__(self, obj, owner) -> _MeasBasesDict | MeasBasesField: + if obj is None: + return self + mapping = getattr(obj, self._private_name, None) + if mapping is None: + mapping = _MeasBasesDict(obj, self._private_name) + setattr(obj, self._private_name, mapping) + return mapping + + def __set__(self, obj, value: _MeasBasesDict) -> None: + msg = f"Cannot set {self._private_name} directly." + raise AttributeError(msg) + + +class BaseGraphState(ABC): + """Abstract base class for Graph State.""" + + meas_bases = MeasBasesField() + + @property + @abstractmethod + def input_node_indices(self) -> MappingProxyType[int, int]: + r"""Return map of input nodes to logical qubit indices. + + Returns + ------- + `types.MappingProxyType`\[`int`, `int`\] + qubit indices map of input nodes. + """ + + @property + @abstractmethod + def output_node_indices(self) -> MappingProxyType[int, int]: + r"""Return map of output nodes to logical qubit indices. + + Returns + ------- + `types.MappingProxyType`\[`int`, `int`\] + qubit indices map of output nodes. + """ + + @property + @abstractmethod + def physical_nodes(self) -> frozenset[int]: + r"""Return set of physical nodes. + + Returns + ------- + `frozenset`\[`int`\] + set of physical nodes. + """ + + @property + @abstractmethod + def physical_edges(self) -> frozenset[tuple[int, int]]: + r"""Return set of physical edges. + + Returns + ------- + `frozenset`\[`tuple`\[`int`, `int`\]` + set of physical edges. + """ + + @abstractmethod + def add_physical_node( + self, + ) -> int: + """Add a physical node to the graph state. + + Returns + ------- + `int` + The node index intenally generated + """ + + @abstractmethod + def add_physical_edge(self, node1: int, node2: int) -> None: + """Add a physical edge to the graph state. + + Parameters + ---------- + node1 : `int` + node index + node2 : `int` + node index + """ + + @abstractmethod + def register_input(self, node: int) -> int: + """Mark the node as an input node. + + Parameters + ---------- + node : `int` + node index + + Returns + ------- + `int` + logical qubit index + """ + + @abstractmethod + def register_output(self, node: int, q_index: int) -> None: + """Mark the node as an output node. + + Parameters + ---------- + node : `int` + node index + q_index : `int` + logical qubit index + """ + + @abstractmethod + def neighbors(self, node: int) -> frozenset[int]: + r"""Return the neighbors of the node. + + Parameters + ---------- + node : `int` + node index + + Returns + ------- + `frozenset`\[`int`\] + set of neighboring nodes + """ + + @abstractmethod + def is_canonical_form(self) -> bool: + r"""Check if the graph state is in canonical form. + + Returns + ------- + `bool` + `True` if the graph state is in canonical form, `False` otherwise. + """ + + +class GraphState(BaseGraphState): + """Minimal implementation of GraphState.""" + + __input_node_indices: dict[int, int] + __output_node_indices: dict[int, int] + __physical_nodes: set[int] + __physical_edges: dict[int, set[int]] + __meas_bases: dict[int, MeasBasis] + __local_cliffords: dict[int, LocalClifford] + + __node_counter: int + + def __init__(self) -> None: + self.__input_node_indices = {} + self.__output_node_indices = {} + self.__physical_nodes = set() + self.__physical_edges = {} + self.__meas_bases = {} + self.__local_cliffords = {} + + self.__node_counter = 0 + + @property + @typing_extensions.override + def input_node_indices(self) -> MappingProxyType[int, int]: + r"""Return map of input nodes to logical qubit indices. + + Returns + ------- + `types.MappingProxyType`\[`int`, `int`\] + qubit indices map of input nodes. + """ + return MappingProxyType(self.__input_node_indices) + + @property + @typing_extensions.override + def output_node_indices(self) -> MappingProxyType[int, int]: + r"""Return map of output nodes to logical qubit indices. + + Returns + ------- + `types.MappingProxyType`\[`int`, `int`\] + qubit indices map of output nodes. + """ + return MappingProxyType(self.__output_node_indices) + + @property + @typing_extensions.override + def physical_nodes(self) -> frozenset[int]: + r"""Return set of physical nodes. + + Returns + ------- + `frozenset`\[`int`\] + set of physical nodes. + """ + return frozenset(self.__physical_nodes) + + @property + @typing_extensions.override + def physical_edges(self) -> frozenset[tuple[int, int]]: + r"""Return set of physical edges. + + Returns + ------- + `frozenset`\[`tuple`\[`int`, `int`\] + set of physical edges. + """ + edges = set() + for node1 in self.__physical_edges: + for node2 in self.__physical_edges[node1]: + if node1 < node2: + edges |= {(node1, node2)} + return frozenset(edges) + + @property + def local_cliffords(self) -> MappingProxyType[int, LocalClifford]: + r"""Return local clifford nodes. + + Returns + ------- + `types.MappingProxyType`\[`int`, `LocalClifford`\] + local clifford nodes. + """ + return MappingProxyType(self.__local_cliffords) + + def _check_meas_basis(self) -> None: + """Check if the measurement basis is set for all physical nodes except output nodes. + + Raises + ------ + ValueError + If the measurement basis is not set for a node or the measurement plane is invalid. + """ + for v in self.physical_nodes - set(self.output_node_indices): + if self.meas_bases.get(v) is None: + msg = f"Measurement basis not set for node {v}" + raise ValueError(msg) + + def _ensure_node_exists(self, node: int) -> None: + """Ensure that the node exists in the graph state. + + Raises + ------ + ValueError + If the node does not exist in the graph state. + """ + if node not in self.__physical_nodes: + msg = f"Node does not exist {node=}" + raise ValueError(msg) + + @typing_extensions.override + def add_physical_node( + self, + ) -> int: + """Add a physical node to the graph state. + + Returns + ------- + `int` + The node index internally generated. + """ + node = self.__node_counter + self.__physical_nodes |= {node} + self.__physical_edges[node] = set() + self.__node_counter += 1 + + return node + + @typing_extensions.override + def add_physical_edge(self, node1: int, node2: int) -> None: + """Add a physical edge to the graph state. + + Parameters + ---------- + node1 : `int` + node index + node2 : `int` + node index + + Raises + ------ + ValueError + 1. If the node does not exist. + 2. If the edge already exists. + 3. If the edge is a self-loop. + """ + self._ensure_node_exists(node1) + self._ensure_node_exists(node2) + if node1 in self.__physical_edges[node2] or node2 in self.__physical_edges[node1]: + msg = f"Edge already exists {node1=}, {node2=}" + raise ValueError(msg) + if node1 == node2: + msg = "Self-loops are not allowed" + raise ValueError(msg) + self.__physical_edges[node1] |= {node2} + self.__physical_edges[node2] |= {node1} + + def remove_physical_node(self, node: int) -> None: + """Remove a physical node from the graph state. + + Parameters + ---------- + node : `int` + node index to be removed + + Raises + ------ + ValueError + If the input node is specified + """ + self._ensure_node_exists(node) + if node in self.input_node_indices: + msg = "The input node cannot be removed" + raise ValueError(msg) + self.__physical_nodes -= {node} + for neighbor in self.__physical_edges[node]: + self.__physical_edges[neighbor] -= {node} + del self.__physical_edges[node] + + if node in self.output_node_indices: + del self.__output_node_indices[node] + self.__meas_bases.pop(node, None) + self.__local_cliffords.pop(node, None) + + def remove_physical_edge(self, node1: int, node2: int) -> None: + """Remove a physical edge from the graph state. + + Parameters + ---------- + node1 : `int` + node index + node2 : `int` + node index + + Raises + ------ + ValueError + If the edge does not exist. + """ + self._ensure_node_exists(node1) + self._ensure_node_exists(node2) + if node1 not in self.__physical_edges[node2] or node2 not in self.__physical_edges[node1]: + msg = "Edge does not exist" + raise ValueError(msg) + self.__physical_edges[node1] -= {node2} + self.__physical_edges[node2] -= {node1} + + @typing_extensions.override + def register_input(self, node: int) -> int: + """Mark the node as an input node. + + Parameters + ---------- + node : `int` + node index + + Returns + ------- + `int` + logical qubit index + + Raises + ------ + ValueError + If the node is already registered as an input node. + """ + self._ensure_node_exists(node) + if node in self.__input_node_indices: + msg = "The node is already registered as an input node." + raise ValueError(msg) + q_index = len(self.__input_node_indices) + self.__input_node_indices[node] = q_index + return q_index + + @typing_extensions.override + def register_output(self, node: int, q_index: int) -> None: + """Mark the node as an output node. + + Parameters + ---------- + node : `int` + node index + q_index : `int` + logical qubit index + + Raises + ------ + ValueError + 1. If the node is already registered as an output node. + 2. If the node has a measurement basis. + 3. If the invalid q_index specified. + 4. If the q_index already exists in output qubit indices. + """ + self._ensure_node_exists(node) + if node in self.__output_node_indices: + msg = "The node is already registered as an output node." + raise ValueError(msg) + if self.meas_bases.get(node) is not None: + msg = "Cannot set output node with measurement basis." + raise ValueError(msg) + if q_index >= len(self.input_node_indices): + msg = "The q_index does not exist in input qubit indices" + raise ValueError(msg) + if q_index in self.output_node_indices.values(): + msg = "The q_index already exists in output qubit indices" + raise ValueError(msg) + self.__output_node_indices[node] = q_index + + def apply_local_clifford(self, node: int, lc: LocalClifford) -> None: + """Apply a local clifford to the node. + + Parameters + ---------- + node : `int` + node index + lc : `LocalClifford` + local clifford operator + """ + self._ensure_node_exists(node) + if node in self.input_node_indices or node in self.output_node_indices: + original_lc = self._pop_local_clifford(node) + if original_lc is not None: + new_lc = update_lc_lc(lc, original_lc) + self.__local_cliffords[node] = new_lc + else: + self.__local_cliffords[node] = lc + else: + self._check_meas_basis() + new_meas_basis = update_lc_basis(lc.conjugate(), self.meas_bases[node]) + self.assign_meas_basis(node, new_meas_basis) + + @typing_extensions.override + def neighbors(self, node: int) -> frozenset[int]: + r"""Return the neighbors of the node. + + Parameters + ---------- + node : `int` + node index + + Returns + ------- + `frozenset`\[`int`\] + set of neighboring nodes + """ + self._ensure_node_exists(node) + return frozenset(self.__physical_edges[node]) + + @typing_extensions.override + def is_canonical_form(self) -> bool: + r"""Check if the graph state is in canonical form. + + The definition of canonical form is: + 1. No Clifford operators applied. + 2. All non-output nodes have measurement basis. + + Returns + ------- + `bool` + `True` if the graph state is in canonical form, `False` otherwise. + """ + if self.__local_cliffords: + return False + for node in self.physical_nodes - set(self.output_node_indices): + if self.meas_bases.get(node) is None: + return False + return True + + def expand_local_cliffords(self) -> tuple[dict[int, tuple[int, int, int]], dict[int, tuple[int, int, int]]]: + r"""Expand local Clifford operators applied on the input and output nodes. + + Returns + ------- + `tuple`\[`dict`\[`int`, `tuple`\[`int`, `int`, `int`\]\], `dict`\[`int`, `tuple`\[`int`, `int`, `int`\]\]\] + A tuple of dictionaries mapping input and output node indices to the new node indices created. + """ + input_node_map = self._expand_input_local_cliffords() + output_node_map = self._expand_output_local_cliffords() + return input_node_map, output_node_map + + def _pop_local_clifford(self, node: int) -> LocalClifford | None: + """Pop local clifford of the node. + + Parameters + ---------- + node : `int` + node index to remove local clifford. + + Returns + ------- + `LocalClifford` | `None` + removed local clifford + """ + return self.__local_cliffords.pop(node, None) + + def _expand_input_local_cliffords(self) -> dict[int, tuple[int, int, int]]: + r"""Expand local Clifford operators applied on the input nodes. + + Returns + ------- + `dict`\[`int`, `tuple`\[`int`, `int`, `int`\]\] + A dictionary mapping input node indices to the new node indices created. + """ + node_index_addition_map = {} + new_input_indices = [] + for input_node, _ in sorted(self.input_node_indices.items(), key=operator.itemgetter(1)): + lc = self._pop_local_clifford(input_node) + if lc is None: + continue + + new_node_index0 = self.add_physical_node() + new_input_indices.append(new_node_index0) + new_node_index1 = self.add_physical_node() + new_node_index2 = self.add_physical_node() + + self.add_physical_edge(new_node_index0, new_node_index1) + self.add_physical_edge(new_node_index1, new_node_index2) + self.add_physical_edge(new_node_index2, input_node) + + self.assign_meas_basis(new_node_index0, PlannerMeasBasis(Plane.XY, lc.alpha)) + self.assign_meas_basis(new_node_index1, PlannerMeasBasis(Plane.XY, lc.beta)) + self.assign_meas_basis(new_node_index2, PlannerMeasBasis(Plane.XY, lc.gamma)) + + node_index_addition_map[input_node] = (new_node_index0, new_node_index1, new_node_index2) + + self.__input_node_indices = {} + for new_input_index in new_input_indices: + self.register_input(new_input_index) + + return node_index_addition_map + + def _expand_output_local_cliffords(self) -> dict[int, tuple[int, int, int]]: + r"""Expand local Clifford operators applied on the output nodes. + + Returns + ------- + `dict`\[`int`, `tuple`\[`int`, `int`, `int`\]\] + A dictionary mapping output node indices to the new node indices created. + """ + node_index_addition_map = {} + new_output_indices = [] + for output_node, _ in sorted(self.output_node_indices.items(), key=operator.itemgetter(1)): + lc = self._pop_local_clifford(output_node) + if lc is None: + continue + + new_node_index0 = self.add_physical_node() + new_node_index1 = self.add_physical_node() + new_node_index2 = self.add_physical_node() + new_output_indices.append(new_node_index2) + + self.add_physical_edge(output_node, new_node_index0) + self.add_physical_edge(new_node_index0, new_node_index1) + self.add_physical_edge(new_node_index1, new_node_index2) + + self.assign_meas_basis(output_node, PlannerMeasBasis(Plane.XY, lc.alpha)) + self.assign_meas_basis(new_node_index0, PlannerMeasBasis(Plane.XY, lc.beta)) + self.assign_meas_basis(new_node_index1, PlannerMeasBasis(Plane.XY, lc.gamma)) + + node_index_addition_map[output_node] = (new_node_index0, new_node_index1, new_node_index2) + + self.__output_node_indices = {} + for new_output_index in new_output_indices: + self.register_output(new_output_index, len(self.__output_node_indices)) + + return node_index_addition_map + + +def compose_sequentially( # noqa: C901 + graph1: BaseGraphState, graph2: BaseGraphState +) -> tuple[BaseGraphState, dict[int, int], dict[int, int]]: + r"""Compose two graph states sequentially. + + Parameters + ---------- + graph1 : `BaseGraphState` + first graph state + graph2 : `BaseGraphState` + second graph state + + Returns + ------- + `tuple`\[`BaseGraphState`, `dict`\[`int`, `int`\], `dict`\[`int`, `int`\]\] + composed graph state, node map for graph1, node map for graph2 + + Raises + ------ + ValueError + 1. If graph1 or graph2 is not in canonical form. + 2. If the logical qubit indices of output nodes in graph1 do not match input nodes in graph2. + """ + if not graph1.is_canonical_form(): + msg = "graph1 must be in canonical form." + raise ValueError(msg) + if not graph2.is_canonical_form(): + msg = "graph2 must be in canonical form." + raise ValueError(msg) + if set(graph1.output_node_indices.values()) != set(graph2.input_node_indices.values()): + msg = "Logical qubit indices of output nodes in graph1 must match input nodes in graph2." + raise ValueError(msg) + node_map1 = {} + node_map2 = {} + composed_graph = GraphState() + + for node in graph1.physical_nodes - graph1.output_node_indices.keys(): + node_index = composed_graph.add_physical_node() + meas_basis = graph1.meas_bases.get(node, None) + if meas_basis is not None: + composed_graph.assign_meas_basis(node_index, meas_basis) + node_map1[node] = node_index + + for node in graph2.physical_nodes: + node_index = composed_graph.add_physical_node() + meas_basis = graph2.meas_bases.get(node, None) + if meas_basis is not None: + composed_graph.assign_meas_basis(node_index, meas_basis) + node_map2[node] = node_index + + for input_node, _ in sorted(graph1.input_node_indices.items(), key=operator.itemgetter(1)): + composed_graph.register_input(node_map1[input_node]) + + for output_node, q_index in graph2.output_node_indices.items(): + composed_graph.register_output(node_map2[output_node], q_index) + + # overlapping node process + q_index2output_node_index1 = { + q_index: output_node_index1 for output_node_index1, q_index in graph1.output_node_indices.items() + } + for input_node_index2, q_index in graph2.input_node_indices.items(): + node_map1[q_index2output_node_index1[q_index]] = node_map2[input_node_index2] + + for u, v in graph1.physical_edges: + composed_graph.add_physical_edge(node_map1[u], node_map1[v]) + for u, v in graph2.physical_edges: + composed_graph.add_physical_edge(node_map2[u], node_map2[v]) + + return composed_graph, node_map1, node_map2 + + +def compose_in_parallel( # noqa: C901 + graph1: BaseGraphState, graph2: BaseGraphState +) -> tuple[BaseGraphState, dict[int, int], dict[int, int]]: + r"""Compose two graph states in parallel. + + Parameters + ---------- + graph1 : `BaseGraphState` + first graph state + graph2 : `BaseGraphState` + second graph state + + Returns + ------- + `tuple`\[`BaseGraphState`, `dict`\[`int`, `int`\], `dict`\[`int`, `int`\]\] + composed graph state, node map for graph1, node map for graph2 + + Raises + ------ + ValueError + If graph1 or graph2 is not in canonical form. + """ + if not graph1.is_canonical_form(): + msg = "graph1 must be in canonical form." + raise ValueError(msg) + if not graph2.is_canonical_form(): + msg = "graph2 must be in canonical form." + raise ValueError(msg) + node_map1 = {} + node_map2 = {} + composed_graph = GraphState() + + for node in graph1.physical_nodes: + node_index = composed_graph.add_physical_node() + meas_basis = graph1.meas_bases.get(node, None) + if meas_basis is not None: + composed_graph.assign_meas_basis(node_index, meas_basis) + node_map1[node] = node_index + + for node in graph2.physical_nodes: + node_index = composed_graph.add_physical_node() + meas_basis = graph2.meas_bases.get(node, None) + if meas_basis is not None: + composed_graph.assign_meas_basis(node_index, meas_basis) + node_map2[node] = node_index + + q_index_map1 = {} + q_index_map2 = {} + for input_node, old_q_index in sorted(graph1.input_node_indices.items(), key=operator.itemgetter(1)): + new_q_index = composed_graph.register_input(node_map1[input_node]) + q_index_map1[old_q_index] = new_q_index + + for input_node, old_q_index in sorted(graph2.input_node_indices.items(), key=operator.itemgetter(1)): + new_q_index = composed_graph.register_input(node_map2[input_node]) + q_index_map2[old_q_index] = new_q_index + + for output_node, q_index in graph1.output_node_indices.items(): + composed_graph.register_output(node_map1[output_node], q_index_map1[q_index]) + + for output_node, q_index in graph2.output_node_indices.items(): + composed_graph.register_output(node_map2[output_node], q_index_map2[q_index]) + + for u, v in graph1.physical_edges: + composed_graph.add_physical_edge(node_map1[u], node_map1[v]) + for u, v in graph2.physical_edges: + composed_graph.add_physical_edge(node_map2[u], node_map2[v]) + + return composed_graph, node_map1, node_map2 + + +def bipartite_edges(node_set1: AbstractSet[int], node_set2: AbstractSet[int]) -> set[tuple[int, int]]: + r"""Return a set of edges for the complete bipartite graph between two sets of nodes. + + Parameters + ---------- + node_set1 : `collections.abc.Set`\[`int`\] + set of nodes + node_set2 : `collections.abc.Set`\[`int`\] + set of nodes + + Returns + ------- + `set`\[`tuple`\[`int`, `int`\] + set of edges for the complete bipartite graph + + Raises + ------ + ValueError + If the two sets of nodes are not disjoint. + """ + if not node_set1.isdisjoint(node_set2): + msg = "The two sets of nodes must be disjoint." + raise ValueError(msg) + return {(min(a, b), max(a, b)) for a, b in product(node_set1, node_set2)} diff --git a/graphix_zx/matrix.py b/graphix_zx/matrix.py index 12812e3be..8ff6f1416 100644 --- a/graphix_zx/matrix.py +++ b/graphix_zx/matrix.py @@ -7,28 +7,22 @@ from __future__ import annotations -import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, TypeVar import numpy as np if TYPE_CHECKING: from numpy.typing import NDArray -if sys.version_info >= (3, 10): - Numeric = np.number -else: - from typing import Union +T = TypeVar("T", bound=np.number[Any]) # can be removed >= 3.10 - Numeric = Union[np.int64, np.float64, np.complex128] - -def is_unitary(mat: NDArray[Numeric]) -> bool: +def is_unitary(mat: NDArray[T]) -> bool: r"""Check if a matrix is unitary. Parameters ---------- - mat : `numpy.typing.NDArray`\[`numpy.number`\] + mat : `numpy.typing.NDArray`\[T\] matrix to check Returns diff --git a/graphix_zx/random_objects.py b/graphix_zx/random_objects.py new file mode 100644 index 000000000..f0f2b9e2d --- /dev/null +++ b/graphix_zx/random_objects.py @@ -0,0 +1,84 @@ +"""Random object generator. + +This module provides: + +- `get_random_flow_graph`: Generate a random flow graph. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from graphix_zx.common import default_meas_basis +from graphix_zx.graphstate import GraphState + +if TYPE_CHECKING: + from numpy.random import Generator + + +def get_random_flow_graph( + width: int, + depth: int, + edge_p: float = 0.5, + rng: Generator | None = None, +) -> tuple[GraphState, dict[int, set[int]]]: + r"""Generate a random flow graph. + + Parameters + ---------- + width : `int` + The width of the graph. + depth : `int` + The depth of the graph. + edge_p : `float`, optional + The probability of adding an edge between two adjacent nodes. + Default is 0.5. + rng : `numpy.random.Generator`, optional + The random number generator. + Default is `None`. + + Returns + ------- + `GraphState` + The generated graph. + `dict`\[`int`, `set`\[`int`\]\] + The flow of the graph. + """ + graph = GraphState() + flow: dict[int, set[int]] = {} + q_indices = [] + + if rng is None: + rng = np.random.default_rng() + + # input nodes + for _ in range(width): + node_index = graph.add_physical_node() + q_index = graph.register_input(node_index) + graph.assign_meas_basis(node_index, default_meas_basis()) + q_indices.append(q_index) + + # internal nodes + for _ in range(depth - 2): + node_indices_layer = [] + for _ in range(width): + node_index = graph.add_physical_node() + graph.assign_meas_basis(node_index, default_meas_basis()) + graph.add_physical_edge(node_index - width, node_index) + flow[node_index - width] = {node_index} + node_indices_layer.append(node_index) + + for w in range(width - 1): + if rng.random() < edge_p: + graph.add_physical_edge(node_indices_layer[w], node_indices_layer[w + 1]) + + # output nodes + for qi in q_indices: + node_index = graph.add_physical_node() + graph.register_output(node_index, qi) + graph.add_physical_edge(node_index - width, node_index) + flow[node_index - width] = {node_index} + + return graph, flow diff --git a/pyproject.toml b/pyproject.toml index 86fa371b6..7d89a5f45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ docstring-code-format = true [tool.ruff.lint.per-file-ignores] "tests/*.py" = [ "S101", # `assert` detected + "SLF001", # private method "PLR2004", # magic value in test(should be removed) "D100", "D103", diff --git a/tests/test_graphstate.py b/tests/test_graphstate.py new file mode 100644 index 000000000..be82870b8 --- /dev/null +++ b/tests/test_graphstate.py @@ -0,0 +1,216 @@ +"""Tests for the GraphState class.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from graphix_zx.common import Plane, PlannerMeasBasis +from graphix_zx.graphstate import GraphState, bipartite_edges + + +@pytest.fixture +def graph() -> GraphState: + """Generate an empty GraphState object. + + Returns + ------- + GraphState: An empty GraphState object. + """ + return GraphState() + + +def test_add_physical_node(graph: GraphState) -> None: + """Test adding a physical node to the graph.""" + node_index = graph.add_physical_node() + assert node_index in graph.physical_nodes + assert len(graph.physical_nodes) == 1 + + +def test_add_physical_node_input_output(graph: GraphState) -> None: + """Test adding a physical node as input and output.""" + node_index = graph.add_physical_node() + q_index = graph.register_input(node_index) + graph.register_output(node_index, q_index) + assert node_index in graph.input_node_indices + assert node_index in graph.output_node_indices + assert graph.input_node_indices[node_index] == q_index + assert graph.output_node_indices[node_index] == q_index + + +def test_ensure_node_exists_raises(graph: GraphState) -> None: + """Test ensuring a node exists in the graph.""" + with pytest.raises(ValueError, match="Node does not exist node=1"): + graph._ensure_node_exists(1) + + +def test_ensure_node_exists(graph: GraphState) -> None: + """Test ensuring a node exists in the graph.""" + node_index = graph.add_physical_node() + graph._ensure_node_exists(node_index) + + +def test_neighbors(graph: GraphState) -> None: + """Test getting the neighbors of a node in the graph.""" + node_index1 = graph.add_physical_node() + node_index2 = graph.add_physical_node() + node_index3 = graph.add_physical_node() + graph.add_physical_edge(node_index1, node_index2) + graph.add_physical_edge(node_index2, node_index3) + assert graph.neighbors(node_index1) == {node_index2} + assert graph.neighbors(node_index2) == {node_index1, node_index3} + assert graph.neighbors(node_index3) == {node_index2} + + +def test_add_physical_edge(graph: GraphState) -> None: + """Test adding a physical edge to the graph.""" + node_index1 = graph.add_physical_node() + node_index2 = graph.add_physical_node() + graph.add_physical_edge(node_index1, node_index2) + assert (node_index1, node_index2) in graph.physical_edges or (node_index2, node_index1) in graph.physical_edges + assert len(graph.physical_edges) == 1 + + +def test_add_duplicate_physical_edge(graph: GraphState) -> None: + """Test adding a duplicate physical edge to the graph.""" + node_index1 = graph.add_physical_node() + node_index2 = graph.add_physical_node() + graph.add_physical_edge(node_index1, node_index2) + with pytest.raises(ValueError, match=f"Edge already exists node1={node_index1}, node2={node_index2}"): + graph.add_physical_edge(node_index1, node_index2) + + +def test_add_edge_with_nonexistent_node(graph: GraphState) -> None: + """Test adding an edge with a nonexistent node to the graph.""" + node_index1 = graph.add_physical_node() + with pytest.raises(ValueError, match="Node does not exist node=2"): + graph.add_physical_edge(node_index1, 2) + + +def test_remove_physical_node_with_nonexistent_node(graph: GraphState) -> None: + """Test removing a nonexistent physical node from the graph.""" + with pytest.raises(ValueError, match="Node does not exist node=1"): + graph.remove_physical_node(1) + + +def test_remove_physical_node_with_input_removal(graph: GraphState) -> None: + """Test removing an input node from the graph""" + node_index = graph.add_physical_node() + graph.register_input(node_index) + with pytest.raises(ValueError, match="The input node cannot be removed"): + graph.remove_physical_node(node_index) + + +def test_remove_physical_node(graph: GraphState) -> None: + """Test removing a physical node from the graph.""" + node_index = graph.add_physical_node() + graph.remove_physical_node(node_index) + assert node_index not in graph.physical_nodes + assert len(graph.physical_nodes) == 0 + + +def test_remove_physical_node_from_minimal_graph(graph: GraphState) -> None: + """Test removing a physical node from the graph with edges.""" + node_index1 = graph.add_physical_node() + node_index2 = graph.add_physical_node() + graph.add_physical_edge(node_index1, node_index2) + graph.remove_physical_node(node_index1) + assert node_index1 not in graph.physical_nodes + assert node_index2 in graph.physical_nodes + assert len(graph.physical_nodes) == 1 + assert len(graph.physical_edges) == 0 + + +def test_remove_physical_node_from_3_nodes_graph(graph: GraphState) -> None: + """Test removing a physical node from the graph with 3 nodes and edges.""" + node_index1 = graph.add_physical_node() + node_index2 = graph.add_physical_node() + node_index3 = graph.add_physical_node() + graph.add_physical_edge(node_index1, node_index2) + graph.add_physical_edge(node_index2, node_index3) + q_index = graph.register_input(node_index1) + graph.register_output(node_index3, q_index) + graph.remove_physical_node(node_index2) + assert graph.physical_nodes == {node_index1, node_index3} + assert len(graph.physical_nodes) == 2 + assert len(graph.physical_edges) == 0 + assert graph.input_node_indices == {node_index1: q_index} + assert graph.output_node_indices == {node_index3: q_index} + + +def test_remove_physical_edge_with_nonexistent_nodes(graph: GraphState) -> None: + """Test removing an edge with nonexistent nodes from the graph.""" + with pytest.raises(ValueError, match="Node does not exist"): + graph.remove_physical_edge(1, 2) + + +def test_remove_physical_edge_with_nonexistent_edge(graph: GraphState) -> None: + """Test removing a nonexistent edge from the graph.""" + node_index1 = graph.add_physical_node() + node_index2 = graph.add_physical_node() + with pytest.raises(ValueError, match="Edge does not exist"): + graph.remove_physical_edge(node_index1, node_index2) + + +def test_remove_physical_edge(graph: GraphState) -> None: + """Test removing a physical edge from the graph.""" + node_index1 = graph.add_physical_node() + node_index2 = graph.add_physical_node() + graph.add_physical_edge(node_index1, node_index2) + graph.remove_physical_edge(node_index1, node_index2) + assert (node_index1, node_index2) not in graph.physical_edges + assert (node_index2, node_index1) not in graph.physical_edges + assert len(graph.physical_edges) == 0 + + +def test_register_output_raises_1(graph: GraphState) -> None: + with pytest.raises(ValueError, match="Node does not exist node=1"): + graph.register_output(1, 0) + + +def test_register_output_raises_2(graph: GraphState) -> None: + node_index = graph.add_physical_node() + graph.meas_bases[node_index] = PlannerMeasBasis(Plane.XY, 0.5 * np.pi) + with pytest.raises(ValueError, match=r"Cannot set output node with measurement basis."): + graph.register_output(node_index, 0) + + +def test_assign_meas_basis(graph: GraphState) -> None: + """Test setting the measurement basis of a physical node.""" + node_index = graph.add_physical_node() + meas_basis = PlannerMeasBasis(Plane.XZ, 0.5 * np.pi) + graph.meas_bases[node_index] = meas_basis + assert graph.meas_bases[node_index].plane == Plane.XZ + assert graph.meas_bases[node_index].angle == 0.5 * np.pi + + +def test_check_meas_raises_value_error(graph: GraphState) -> None: + """Test if measurement planes and angles are set improperly.""" + node_index = graph.add_physical_node() + with pytest.raises(ValueError, match=f"Measurement basis not set for node {node_index}"): + graph._check_meas_basis() + + +def test_check_meas_basis_success(graph: GraphState) -> None: + """Test if measurement planes and angles are set properly.""" + graph._check_meas_basis() + node_index1 = graph.add_physical_node() + q_index = graph.register_input(node_index1) + meas_basis = PlannerMeasBasis(Plane.XY, 0.5 * np.pi) + graph.meas_bases[node_index1] = meas_basis + graph._check_meas_basis() + + node_index2 = graph.add_physical_node() + graph.add_physical_edge(node_index1, node_index2) + graph.register_output(node_index2, q_index) + graph._check_meas_basis() + + +def test_bipartite_edges() -> None: + """Test the function that generate complete bipartite edges""" + assert bipartite_edges(set(), set()) == set() + assert bipartite_edges({1, 2}, {3, 4}) == {(1, 3), (1, 4), (2, 3), (2, 4)} + + +if __name__ == "__main__": + pytest.main()