diff --git a/tropical_in_new/README.md b/tropical_in_new/README.md new file mode 100644 index 0000000..087aca1 --- /dev/null +++ b/tropical_in_new/README.md @@ -0,0 +1,51 @@ +# Tropical Tensor Network for MPE + +This folder contains an independent implementation of tropical tensor network +contraction for Most Probable Explanation (MPE). It uses `omeco` for contraction +order optimization and does not depend on the `bpdecoderplus` package; all code +lives under `tropical_in_new/src`. + +`omeco` provides high-quality contraction order heuristics (greedy and +simulated annealing). Install it alongside Torch to run the examples and tests. + +## Structure + +``` +tropical_in_new/ +├── README.md +├── requirements.txt +├── src/ +│ ├── __init__.py +│ ├── primitives.py +│ ├── network.py +│ ├── contraction.py +│ ├── mpe.py +│ └── utils.py +├── tests/ +│ ├── test_primitives.py +│ ├── test_contraction.py +│ ├── test_mpe.py +│ └── test_utils.py +├── examples/ +│ └── asia_network/ +│ ├── main.py +│ └── model.uai +└── docs/ + ├── mathematical_description.md + ├── api_reference.md + └── usage_guide.md +``` + +## Quick Start + +```bash +pip install -r tropical_in_new/requirements.txt +python tropical_in_new/examples/asia_network/main.py +``` + +## Notes on omeco + +`omeco` is a Rust-backed Python package. If a prebuilt wheel is not available +for your Python version, you will need a Rust toolchain with `cargo` on PATH to +build it from source. See the omeco repository for details: +https://github.com/GiggleLiu/omeco diff --git a/tropical_in_new/docs/api_reference.md b/tropical_in_new/docs/api_reference.md new file mode 100644 index 0000000..a137122 --- /dev/null +++ b/tropical_in_new/docs/api_reference.md @@ -0,0 +1,52 @@ +## Tropical Tensor Network API Reference + +Public APIs exported from `tropical_in_new/src`. + +### UAI Parsing + +- `read_model_file(path, factor_eltype=torch.float64) -> UAIModel` + Parse a UAI `.uai` model file. + +- `read_model_from_string(content, factor_eltype=torch.float64) -> UAIModel` + Parse a UAI model from an in-memory string. + +### Data Structures + +- `Factor(vars: Tuple[int, ...], values: torch.Tensor)` + Container for a factor scope and its tensor. + +- `UAIModel(nvars: int, cards: List[int], factors: List[Factor])` + Holds all model metadata for MPE. + +### Primitives + +- `safe_log(tensor: torch.Tensor) -> torch.Tensor` + Convert potentials to log-domain; maps zeros to `-inf`. + +- `tropical_einsum(a, b, index_map, track_argmax=True)` + Binary contraction in max-plus semiring; returns `(values, backpointer)`. + +- `argmax_trace(backpointer, assignment) -> Dict[int, int]` + Decode assignments for eliminated variables from backpointer metadata. + +### Network + Contraction + +- `build_network(factors: Iterable[Factor]) -> list[TensorNode]` + Convert factors into log-domain tensors with scopes. + +- `choose_order(nodes, heuristic="omeco") -> list[int]` + Select variable elimination order using `omeco`. + +- `build_contraction_tree(order, nodes) -> ContractionTree` + Construct a contraction plan from order and nodes. + +- `contract_tree(tree, einsum_fn, track_argmax=True) -> TreeNode` + Execute contractions and return the root node with backpointers. + +### MPE + +- `mpe_tropical(model, evidence=None, order=None)` + Return `(assignment_dict, score, info)` where `score` is log-domain. + +- `recover_mpe_assignment(root) -> Dict[int, int]` + Recover assignments from a contraction tree root. 
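+
+### Example
+
+A minimal sketch of how these functions fit together. The model path below is
+the bundled example; `mpe_tropical` wraps the lower-level calls and also
+performs a final reduction and backpointer decoding that are omitted here.
+
+```python
+from tropical_in_new.src import (
+    build_contraction_tree,
+    build_network,
+    build_tropical_factors,
+    choose_order,
+    contract_tree,
+    mpe_tropical,
+    read_model_file,
+    tropical_einsum,
+)
+
+model = read_model_file("tropical_in_new/examples/asia_network/model.uai")
+
+# High-level API: MPE assignment, log-domain score, and contraction metadata.
+assignment, score, info = mpe_tropical(model)
+
+# Lower-level pipeline used internally by mpe_tropical:
+nodes = build_network(build_tropical_factors(model))   # factors -> log-domain tensors
+order = choose_order(nodes, heuristic="omeco")          # elimination order via omeco
+tree = build_contraction_tree(order, nodes)             # contraction plan
+root = contract_tree(tree, tropical_einsum)             # execute max-plus contractions
+```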
diff --git a/tropical_in_new/docs/mathematical_description.md b/tropical_in_new/docs/mathematical_description.md
new file mode 100644
index 0000000..ae18385
--- /dev/null
+++ b/tropical_in_new/docs/mathematical_description.md
@@ -0,0 +1,32 @@
+## Tropical Tensor Network for MPE
+
+We compute the Most Probable Explanation (MPE) by treating the factor
+graph as a tensor network in the tropical semiring (max-plus).
+
+### Tropical Semiring
+
+For log-potentials, multiplication becomes addition and summation becomes
+maximization:
+
+- `a ⊗ b = a + b`
+- `a ⊕ b = max(a, b)`
+
+### MPE Objective
+
+Let `x` be a joint assignment to the variables and `phi_f(x_f)` the factor
+potentials. We maximize the log-score:
+
+`x* = argmax_x sum_f log(phi_f(x_f))`
+
+### Contraction
+
+Each factor becomes a tensor over its variable scope. Contraction combines
+two tensors by adding their log-values and reducing (max) over eliminated
+variables. A greedy elimination order (computed with `omeco`) controls
+the intermediate tensor sizes.
+
+### Backpointers
+
+During each reduction, we store the argmax index for eliminated variables.
+Traversing the contraction tree with these backpointers recovers the MPE
+assignment.
diff --git a/tropical_in_new/docs/usage_guide.md b/tropical_in_new/docs/usage_guide.md
new file mode 100644
index 0000000..c34130c
--- /dev/null
+++ b/tropical_in_new/docs/usage_guide.md
@@ -0,0 +1,39 @@
+## Tropical Tensor Network Usage
+
+This guide shows how to parse a UAI model and compute MPE using the
+tropical tensor network implementation.
+
+### Install Dependencies
+
+```bash
+pip install -r tropical_in_new/requirements.txt
+```
+
+`omeco` provides contraction order optimization. If a prebuilt wheel is not
+available for your Python version, you may need a Rust toolchain installed.
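+
+As a quick, optional sanity check that both dependencies import:
+
+```bash
+python -c "import torch, omeco; print('ok')"
+```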
+ +### Quick Start + +```python +from tropical_in_new.src import mpe_tropical, read_model_file + +model = read_model_file("tropical_in_new/examples/asia_network/model.uai") +assignment, score, info = mpe_tropical(model) +print(assignment, score, info) +``` + +### Evidence + +```python +from tropical_in_new.src import mpe_tropical, read_model_file + +model = read_model_file("tropical_in_new/examples/asia_network/model.uai") +evidence = {1: 0} # variable index is 1-based +assignment, score, info = mpe_tropical(model, evidence=evidence) +``` + +### Running the Example + +```bash +python tropical_in_new/examples/asia_network/main.py +``` diff --git a/tropical_in_new/examples/asia_network/main.py b/tropical_in_new/examples/asia_network/main.py new file mode 100644 index 0000000..08fe0eb --- /dev/null +++ b/tropical_in_new/examples/asia_network/main.py @@ -0,0 +1,15 @@ +"""Run tropical MPE on a small UAI model.""" + +from tropical_in_new.src import mpe_tropical, read_model_file + + +def main() -> None: + model = read_model_file("tropical_in_new/examples/asia_network/model.uai") + assignment, score, info = mpe_tropical(model) + print("MPE assignment:", assignment) + print("MPE log-score:", score) + print("Info:", info) + + +if __name__ == "__main__": + main() diff --git a/tropical_in_new/examples/asia_network/model.uai b/tropical_in_new/examples/asia_network/model.uai new file mode 100644 index 0000000..e60d689 --- /dev/null +++ b/tropical_in_new/examples/asia_network/model.uai @@ -0,0 +1,19 @@ +MARKOV +3 +2 2 2 +5 +1 0 +1 1 +1 2 +2 0 1 +2 1 2 +2 +0.6 0.4 +2 +0.5 0.5 +2 +0.7 0.3 +4 +1.2 0.2 0.2 1.2 +4 +1.1 0.3 0.3 1.1 diff --git a/tropical_in_new/requirements.txt b/tropical_in_new/requirements.txt new file mode 100644 index 0000000..89b7694 --- /dev/null +++ b/tropical_in_new/requirements.txt @@ -0,0 +1,2 @@ +torch>=2.0.0 +omeco diff --git a/tropical_in_new/src/__init__.py b/tropical_in_new/src/__init__.py new file mode 100644 index 0000000..2522ab8 --- /dev/null +++ b/tropical_in_new/src/__init__.py @@ -0,0 +1,33 @@ +"""Tropical tensor network tools for MPE (independent package).""" + +from .contraction import build_contraction_tree, choose_order, contract_tree +from .mpe import mpe_tropical, recover_mpe_assignment +from .network import TensorNode, build_network +from .primitives import argmax_trace, safe_log, tropical_einsum +from .utils import ( + Factor, + UAIModel, + build_tropical_factors, + read_evidence_file, + read_model_file, + read_model_from_string, +) + +__all__ = [ + "Factor", + "TensorNode", + "UAIModel", + "argmax_trace", + "build_contraction_tree", + "build_network", + "build_tropical_factors", + "choose_order", + "contract_tree", + "mpe_tropical", + "read_evidence_file", + "read_model_file", + "read_model_from_string", + "recover_mpe_assignment", + "safe_log", + "tropical_einsum", +] diff --git a/tropical_in_new/src/contraction.py b/tropical_in_new/src/contraction.py new file mode 100644 index 0000000..042db6e --- /dev/null +++ b/tropical_in_new/src/contraction.py @@ -0,0 +1,191 @@ +"""Contraction ordering and binary contraction tree execution.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Tuple + +import torch + +import omeco + +from .network import TensorNode +from .primitives import Backpointer, tropical_reduce_max +from .utils import build_index_map + + +@dataclass +class ContractNode: + vars: Tuple[int, ...] + values: torch.Tensor + left: "TreeNode" + right: "TreeNode" + elim_vars: Tuple[int, ...] 
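+    # backpointer stores argmax metadata for elim_vars; it is None when this
+    # contraction step eliminated nothing or argmax tracking was disabled.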
+ backpointer: Backpointer | None + + +@dataclass +class ReduceNode: + vars: Tuple[int, ...] + values: torch.Tensor + child: "TreeNode" + elim_vars: Tuple[int, ...] + backpointer: Backpointer | None + + +TreeNode = TensorNode | ContractNode | ReduceNode + + +@dataclass(frozen=True) +class ContractionTree: + order: Tuple[int, ...] + nodes: Tuple[TensorNode, ...] + + +def _infer_var_sizes(nodes: Iterable[TensorNode]) -> dict[int, int]: + sizes: dict[int, int] = {} + for node in nodes: + for var, dim in zip(node.vars, node.values.shape): + if var in sizes and sizes[var] != dim: + raise ValueError( + f"Variable {var} has inconsistent sizes: {sizes[var]} vs {dim}." + ) + sizes[var] = int(dim) + return sizes + + +def _extract_leaf_index(node_dict: dict) -> int | None: + for key in ("leaf", "leaf_index", "index", "tensor"): + if key in node_dict: + value = node_dict[key] + if isinstance(value, int): + return value + return None + + +def _elim_order_from_tree_dict(tree_dict: dict, ixs: list[list[int]]) -> list[int]: + total_counts: dict[int, int] = {} + for vars in ixs: + for var in vars: + total_counts[var] = total_counts.get(var, 0) + 1 + + eliminated: set[int] = set() + + def visit(node: dict) -> tuple[dict[int, int], list[int]]: + leaf_index = _extract_leaf_index(node) + if leaf_index is not None: + counts: dict[int, int] = {} + for var in ixs[leaf_index]: + counts[var] = counts.get(var, 0) + 1 + return counts, [] + + children = node.get("children", []) + if not isinstance(children, list) or not children: + return {}, [] + + counts: dict[int, int] = {} + order: list[int] = [] + for child in children: + child_counts, child_order = visit(child) + order.extend(child_order) + for var, count in child_counts.items(): + counts[var] = counts.get(var, 0) + count + + newly_eliminated = [ + var + for var, count in counts.items() + if count == total_counts.get(var, 0) and var not in eliminated + ] + for var in sorted(newly_eliminated): + eliminated.add(var) + order.append(var) + return counts, order + + _, order = visit(tree_dict) + remaining = sorted([var for var in total_counts if var not in eliminated]) + return order + remaining + + +def choose_order(nodes: list[TensorNode], heuristic: str = "omeco") -> list[int]: + """Select elimination order over variable indices using omeco.""" + if heuristic != "omeco": + raise ValueError("Only the 'omeco' heuristic is supported.") + ixs = [list(node.vars) for node in nodes] + sizes = _infer_var_sizes(nodes) + method = omeco.GreedyMethod() if hasattr(omeco, "GreedyMethod") else None + tree = ( + omeco.optimize_code(ixs, [], sizes, method) + if method is not None + else omeco.optimize_code(ixs, [], sizes) + ) + tree_dict = tree.to_dict() if hasattr(tree, "to_dict") else tree + if not isinstance(tree_dict, dict): + raise ValueError("omeco.optimize_code did not return a usable tree.") + return _elim_order_from_tree_dict(tree_dict, ixs) + + +def build_contraction_tree(order: Iterable[int], nodes: list[TensorNode]) -> ContractionTree: + """Prepare a contraction plan from order and leaf nodes.""" + return ContractionTree(order=tuple(order), nodes=tuple(nodes)) + + +def contract_tree( + tree: ContractionTree, + einsum_fn, + track_argmax: bool = True, +) -> TreeNode: + """Contract along the tree using the tropical einsum.""" + active_nodes: list[TreeNode] = list(tree.nodes) + for var in tree.order: + bucket = [node for node in active_nodes if var in node.vars] + if not bucket: + continue + bucket_ids = {id(node) for node in bucket} + active_nodes = [node for node in 
active_nodes if id(node) not in bucket_ids] + combined: TreeNode = bucket[0] + for i, other in enumerate(bucket[1:]): + is_last = i == len(bucket) - 2 + elim_vars = (var,) if is_last else () + index_map = build_index_map(combined.vars, other.vars, elim_vars=elim_vars) + values, backpointer = einsum_fn( + combined.values, other.values, index_map, + track_argmax=track_argmax if is_last else False, + ) + combined = ContractNode( + vars=index_map.out_vars, + values=values, + left=combined, + right=other, + elim_vars=elim_vars, + backpointer=backpointer, + ) + if var in combined.vars: + # Single-node bucket: eliminate via reduce + values, backpointer = tropical_reduce_max( + combined.values, combined.vars, (var,), track_argmax=track_argmax + ) + combined = ReduceNode( + vars=tuple(v for v in combined.vars if v != var), + values=values, + child=combined, + elim_vars=(var,), + backpointer=backpointer, + ) + active_nodes.append(combined) + while len(active_nodes) > 1: + left = active_nodes.pop(0) + right = active_nodes.pop(0) + index_map = build_index_map(left.vars, right.vars, elim_vars=()) + values, _ = einsum_fn(left.values, right.values, index_map, track_argmax=False) + combined = ContractNode( + vars=index_map.out_vars, + values=values, + left=left, + right=right, + elim_vars=(), + backpointer=None, + ) + active_nodes.append(combined) + if not active_nodes: + raise ValueError("Contraction produced no nodes.") + return active_nodes[0] diff --git a/tropical_in_new/src/mpe.py b/tropical_in_new/src/mpe.py new file mode 100644 index 0000000..8f64aeb --- /dev/null +++ b/tropical_in_new/src/mpe.py @@ -0,0 +1,102 @@ +"""Top-level MPE API for tropical tensor networks.""" + +from __future__ import annotations + +from typing import Dict, Iterable + +from .contraction import ContractNode, ReduceNode, build_contraction_tree, choose_order +from .contraction import contract_tree as _contract_tree +from .network import TensorNode, build_network +from .primitives import argmax_trace, tropical_einsum, tropical_reduce_max +from .utils import UAIModel, build_tropical_factors + + +def _unravel_argmax(values, vars: Iterable[int]) -> Dict[int, int]: + vars = tuple(vars) + if not vars: + return {} + flat = int(values.reshape(-1).argmax().item()) + shape = list(values.shape) + assignments = [] + for size in reversed(shape): + assignments.append(flat % size) + flat //= size + assignments = list(reversed(assignments)) + return {var: int(val) for var, val in zip(vars, assignments)} + + +def recover_mpe_assignment(root) -> Dict[int, int]: + """Recover MPE assignment from a contraction tree with backpointers.""" + assignment: Dict[int, int] = {} + + def require_vars(required: Iterable[int], available: Dict[int, int]) -> None: + missing = [v for v in required if v not in available] + if missing: + raise KeyError( + "Missing assignment values for variables: " + f"{missing}. 
Provided assignment keys: {sorted(available.keys())}" + ) + + def traverse(node, out_assignment: Dict[int, int]) -> None: + assignment.update(out_assignment) + if isinstance(node, TensorNode): + return + if isinstance(node, ReduceNode): + elim_assignment = ( + argmax_trace(node.backpointer, out_assignment) if node.backpointer else {} + ) + combined = {**out_assignment, **elim_assignment} + require_vars(node.child.vars, combined) + child_assignment = {v: combined[v] for v in node.child.vars} + traverse(node.child, child_assignment) + return + if isinstance(node, ContractNode): + elim_assignment = ( + argmax_trace(node.backpointer, out_assignment) if node.backpointer else {} + ) + combined = {**out_assignment, **elim_assignment} + require_vars(node.left.vars, combined) + left_assignment = {v: combined[v] for v in node.left.vars} + require_vars(node.right.vars, combined) + right_assignment = {v: combined[v] for v in node.right.vars} + traverse(node.left, left_assignment) + traverse(node.right, right_assignment) + + initial = _unravel_argmax(root.values, root.vars) + traverse(root, initial) + return assignment + + +def mpe_tropical( + model: UAIModel, + evidence: Dict[int, int] | None = None, + order: Iterable[int] | None = None, +) -> tuple[Dict[int, int], float, Dict[str, int | tuple[int, ...]]]: + """Return MPE assignment, score, and contraction metadata.""" + evidence = evidence or {} + factors = build_tropical_factors(model, evidence) + nodes = build_network(factors) + if order is None: + order = choose_order(nodes, heuristic="omeco") + tree = build_contraction_tree(order, nodes) + root = _contract_tree(tree, einsum_fn=tropical_einsum) + if root.vars: + values, backpointer = tropical_reduce_max( + root.values, root.vars, tuple(root.vars), track_argmax=True + ) + root = ReduceNode( + vars=(), + values=values, + child=root, + elim_vars=tuple(root.vars), + backpointer=backpointer, + ) + assignment = recover_mpe_assignment(root) + assignment.update({int(k): int(v) for k, v in evidence.items()}) + score = float(root.values.item()) + info = { + "order": tuple(order), + "num_nodes": len(nodes), + "num_elims": len(tuple(order)), + } + return assignment, score, info diff --git a/tropical_in_new/src/network.py b/tropical_in_new/src/network.py new file mode 100644 index 0000000..aa98b7a --- /dev/null +++ b/tropical_in_new/src/network.py @@ -0,0 +1,26 @@ +"""Tensor network construction for tropical MPE.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Tuple + +import torch + +from .primitives import safe_log + + +@dataclass(frozen=True) +class TensorNode: + """Node representing a tensor factor in log-domain.""" + + vars: Tuple[int, ...] + values: torch.Tensor + + +def build_network(factors: Iterable) -> list[TensorNode]: + """Convert factors into tensor nodes (log-domain).""" + nodes: list[TensorNode] = [] + for factor in factors: + nodes.append(TensorNode(vars=tuple(factor.vars), values=safe_log(factor.values))) + return nodes diff --git a/tropical_in_new/src/primitives.py b/tropical_in_new/src/primitives.py new file mode 100644 index 0000000..46fb332 --- /dev/null +++ b/tropical_in_new/src/primitives.py @@ -0,0 +1,149 @@ +"""Tropical semiring primitives and backpointer helpers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Iterable, Tuple + +import torch + +try: # Optional accelerator; falls back to pure torch. 
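+    # tropical_gemm is consulted opportunistically in tropical_einsum; the
+    # pure-torch implementation below is always available as a fallback.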
+ import tropical_gemm # type: ignore +except Exception: # pragma: no cover - optional dependency + tropical_gemm = None + + +@dataclass(frozen=True) +class IndexMap: + """Index mapping for binary tropical contractions.""" + + a_vars: Tuple[int, ...] + b_vars: Tuple[int, ...] + out_vars: Tuple[int, ...] + elim_vars: Tuple[int, ...] + + +@dataclass +class Backpointer: + """Stores argmax metadata for eliminated variables.""" + + elim_vars: Tuple[int, ...] + elim_shape: Tuple[int, ...] + out_vars: Tuple[int, ...] + argmax_flat: torch.Tensor + + +def safe_log(tensor: torch.Tensor) -> torch.Tensor: + """Convert potentials to log domain; zeros map to -inf.""" + neg_inf = torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device) + return torch.where(tensor > 0, torch.log(tensor), neg_inf) + + +def _align_tensor( + tensor: torch.Tensor, tensor_vars: Tuple[int, ...], target_vars: Tuple[int, ...] +) -> torch.Tensor: + if not target_vars: + return tensor.reshape(()) + if not tensor_vars: + return tensor.reshape((1,) * len(target_vars)) + present = [v for v in target_vars if v in tensor_vars] + perm = [tensor_vars.index(v) for v in present] + aligned = tensor if perm == list(range(len(tensor_vars))) else tensor.permute(perm) + shape = [] + p = 0 + for var in target_vars: + if var in tensor_vars: + shape.append(aligned.shape[p]) + p += 1 + else: + shape.append(1) + return aligned.reshape(tuple(shape)) + + +def tropical_reduce_max( + tensor: torch.Tensor, + vars: Tuple[int, ...], + elim_vars: Iterable[int], + track_argmax: bool = True, +) -> tuple[torch.Tensor, Backpointer | None]: + elim_vars = tuple(elim_vars) + if not elim_vars: + return tensor, None + target_vars = tuple(vars) + missing_elim_vars = [v for v in elim_vars if v not in target_vars] + if missing_elim_vars: + raise ValueError( + "tropical_reduce_max: elim_vars " + f"{missing_elim_vars} are not present in vars {target_vars}." 
+ ) + elim_axes = [target_vars.index(v) for v in elim_vars] + keep_axes = [i for i in range(len(target_vars)) if i not in elim_axes] + perm = keep_axes + elim_axes + permuted = tensor if perm == list(range(len(target_vars))) else tensor.permute(perm) + out_shape = [tensor.shape[i] for i in keep_axes] + elim_shape = [tensor.shape[i] for i in elim_axes] + flat = permuted.reshape(*out_shape, -1) + values, argmax_flat = torch.max(flat, dim=-1) + if not track_argmax: + return values, None + backpointer = Backpointer( + elim_vars=elim_vars, + elim_shape=tuple(elim_shape), + out_vars=tuple(target_vars[i] for i in keep_axes), + argmax_flat=argmax_flat, + ) + return values, backpointer + + +def tropical_einsum( + a: torch.Tensor, + b: torch.Tensor, + index_map: IndexMap, + track_argmax: bool = True, +) -> tuple[torch.Tensor, Backpointer | None]: + """Binary tropical contraction: add in log-space, max over elim_vars.""" + if tropical_gemm is not None and hasattr(tropical_gemm, "einsum"): + try: # pragma: no cover - optional dependency + return tropical_gemm.einsum(a, b, index_map, track_argmax=track_argmax) + except Exception: + pass + + target_vars = tuple(dict.fromkeys(index_map.a_vars + index_map.b_vars)) + expected_out = tuple(v for v in target_vars if v not in index_map.elim_vars) + if index_map.out_vars and index_map.out_vars != expected_out: + raise ValueError("index_map.out_vars does not match contraction result ordering.") + + aligned_a = _align_tensor(a, index_map.a_vars, target_vars) + aligned_b = _align_tensor(b, index_map.b_vars, target_vars) + combined = aligned_a + aligned_b + + if not index_map.elim_vars: + return combined, None + + values, backpointer = tropical_reduce_max( + combined, target_vars, index_map.elim_vars, track_argmax=track_argmax + ) + return values, backpointer + + +def argmax_trace(backpointer: Backpointer, assignment: Dict[int, int]) -> Dict[int, int]: + """Decode eliminated variable assignments from a backpointer.""" + if not backpointer.elim_vars: + return {} + if backpointer.out_vars: + missing = [v for v in backpointer.out_vars if v not in assignment] + if missing: + raise KeyError( + "Missing assignment values for output variables: " + f"{missing}. Provided assignment keys: {sorted(assignment.keys())}" + ) + idx = tuple(assignment[v] for v in backpointer.out_vars) + flat = int(backpointer.argmax_flat[idx].item()) + else: + flat = int(backpointer.argmax_flat.item()) + values = [] + for size in reversed(backpointer.elim_shape): + values.append(flat % size) + flat //= size + values = list(reversed(values)) + return {var: int(val) for var, val in zip(backpointer.elim_vars, values)} diff --git a/tropical_in_new/src/utils.py b/tropical_in_new/src/utils.py new file mode 100644 index 0000000..fd7519d --- /dev/null +++ b/tropical_in_new/src/utils.py @@ -0,0 +1,167 @@ +"""Utility helpers for tropical tensor networks.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Iterable, List, Tuple + +import torch + +from .primitives import IndexMap + + +@dataclass +class Factor: + """Factor class representing a factor in the factor graph.""" + + vars: Tuple[int, ...] 
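+    # vars holds 1-based variable indices (UAI scopes are 0-based on disk and
+    # are converted during parsing).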
+ values: torch.Tensor + + +@dataclass +class UAIModel: + """UAI model containing variables, cardinalities, and factors.""" + + nvars: int + cards: List[int] + factors: List[Factor] + + +def read_model_file(filepath: str, factor_eltype=torch.float64) -> UAIModel: + """Parse UAI format model file.""" + with open(filepath, "r", encoding="utf-8") as f: + content = f.read() + return read_model_from_string(content, factor_eltype=factor_eltype) + + +def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIModel: + """Parse UAI model from string.""" + lines = [line.strip() for line in content.split("\n") if line.strip()] + if len(lines) < 4: + raise ValueError("Malformed UAI model: expected at least 4 header lines.") + network_type = lines[0] + if network_type not in ("MARKOV", "BAYES"): + raise ValueError( + f"Unsupported UAI network type: {network_type!r}. Expected 'MARKOV' or 'BAYES'." + ) + nvars = int(lines[1]) + cards = [int(x) for x in lines[2].split()] + if len(cards) != nvars: + raise ValueError(f"Expected {nvars} cardinalities, got {len(cards)}.") + ntables = int(lines[3]) + if len(lines) < 4 + ntables: + raise ValueError( + f"Malformed UAI model: expected {ntables} scope lines, got {len(lines) - 4}." + ) + + scopes: list[list[int]] = [] + for i in range(ntables): + parts = lines[4 + i].split() + scope_size = int(parts[0]) + if len(parts) - 1 != scope_size: + raise ValueError( + f"Scope size mismatch on line {4 + i}: " + f"declared {scope_size}, found {len(parts) - 1} variables." + ) + scope = [int(x) + 1 for x in parts[1:]] + assert all(1 <= v <= nvars for v in scope), ( + f"Scope variables out of range [1, {nvars}]: {scope}" + ) + scopes.append(scope) + assert len(scopes) == ntables + + idx = 4 + ntables + tokens: List[str] = [] + while idx < len(lines): + tokens.extend(lines[idx].split()) + idx += 1 + cursor = 0 + + factors: List[Factor] = [] + for scope in scopes: + if cursor >= len(tokens): + raise ValueError("Unexpected end of UAI factor table data.") + nelements = int(tokens[cursor]) + cursor += 1 + expected_size = 1 + for card in (cards[v - 1] for v in scope): + expected_size *= card + if nelements != expected_size: + raise ValueError( + f"Factor table size mismatch for scope {scope}: " + f"expected {expected_size}, got {nelements}." + ) + if cursor + nelements > len(tokens): + raise ValueError("Unexpected end of UAI factor table data.") + values = torch.tensor( + [float(x) for x in tokens[cursor : cursor + nelements]], dtype=factor_eltype + ) + cursor += nelements + shape = tuple([cards[v - 1] for v in scope]) + values = values.reshape(shape) + assert values.shape == shape + factors.append(Factor(tuple(scope), values)) + + assert len(factors) == ntables + return UAIModel(nvars, cards, factors) + + +def read_evidence_file(filepath: str) -> Dict[int, int]: + """Parse evidence file (.evid format).""" + if not filepath: + return {} + with open(filepath, "r", encoding="utf-8") as f: + lines = f.readlines() + if not lines: + return {} + last_line = lines[-1].strip() + parts = [int(x) for x in last_line.split()] + nobsvars = parts[0] + if len(parts) < 1 + 2 * nobsvars: + raise ValueError( + f"Malformed evidence line: expected {1 + 2 * nobsvars} entries, got {len(parts)}." 
+ ) + evidence: Dict[int, int] = {} + for i in range(nobsvars): + var_idx = parts[1 + 2 * i] + 1 + var_value = parts[2 + 2 * i] + assert var_idx >= 1, f"Invalid variable index: {var_idx - 1}" + evidence[var_idx] = var_value + assert len(evidence) == nobsvars + return evidence + + +def apply_evidence_to_factor(factor: Factor, evidence: Dict[int, int]) -> Factor: + """Clamp a factor with evidence and drop observed variables.""" + if not evidence: + return factor + index = [] + new_vars = [] + for var in factor.vars: + if var in evidence: + index.append(int(evidence[var])) + else: + index.append(slice(None)) + new_vars.append(var) + values = factor.values[tuple(index)] + if values.ndim == 0: + values = values.reshape(()) + return Factor(tuple(new_vars), values) + + +def build_tropical_factors(model: UAIModel, evidence: Dict[int, int] | None = None) -> list[Factor]: + """Apply evidence and return factors in original domain (log later).""" + evidence = evidence or {} + factors = [apply_evidence_to_factor(factor, evidence) for factor in model.factors] + return factors + + +def build_index_map( + a_vars: Iterable[int], b_vars: Iterable[int], elim_vars: Iterable[int] +) -> IndexMap: + a_vars = tuple(a_vars) + b_vars = tuple(b_vars) + elim_vars = tuple(elim_vars) + target_vars = tuple(dict.fromkeys(a_vars + b_vars)) + out_vars = tuple(v for v in target_vars if v not in elim_vars) + return IndexMap(a_vars=a_vars, b_vars=b_vars, out_vars=out_vars, elim_vars=elim_vars) diff --git a/tropical_in_new/tests/conftest.py b/tropical_in_new/tests/conftest.py new file mode 100644 index 0000000..fbc1e9e --- /dev/null +++ b/tropical_in_new/tests/conftest.py @@ -0,0 +1,5 @@ +from pathlib import Path +import sys + +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) diff --git a/tropical_in_new/tests/test_contraction.py b/tropical_in_new/tests/test_contraction.py new file mode 100644 index 0000000..5ff4279 --- /dev/null +++ b/tropical_in_new/tests/test_contraction.py @@ -0,0 +1,103 @@ +import pytest +import torch + +from tropical_in_new.src.contraction import ( + _elim_order_from_tree_dict, + _extract_leaf_index, + _infer_var_sizes, + build_contraction_tree, + choose_order, + contract_tree, +) +from tropical_in_new.src.network import TensorNode +from tropical_in_new.src.primitives import tropical_einsum + + +def test_choose_order_returns_all_vars(): + nodes = [ + TensorNode(vars=(1, 2), values=torch.zeros((2, 2))), + TensorNode(vars=(2, 3), values=torch.zeros((2, 2))), + ] + order = choose_order(nodes, heuristic="omeco") + assert set(order) == {1, 2, 3} + + +def test_choose_order_invalid_heuristic(): + nodes = [TensorNode(vars=(1,), values=torch.zeros((2,)))] + with pytest.raises(ValueError, match="Only the 'omeco' heuristic"): + choose_order(nodes, heuristic="invalid") + + +def test_contract_tree_reduces_to_scalar(): + nodes = [ + TensorNode(vars=(1,), values=torch.tensor([0.1, 0.0])), + TensorNode(vars=(1, 2), values=torch.tensor([[0.2, 0.3], [0.4, 0.1]])), + ] + order = [1, 2] + tree = build_contraction_tree(order, nodes) + root = contract_tree(tree, tropical_einsum) + assert root.values.numel() == 1 + + +def test_contract_tree_three_nodes_shared_var(): + """Test contraction with 3 nodes sharing a variable (uses einsum with elimination).""" + nodes = [ + TensorNode(vars=(1, 2), values=torch.tensor([[0.1, 0.2], [0.3, 0.4]])), + TensorNode(vars=(2, 3), values=torch.tensor([[0.5, 0.6], [0.7, 0.8]])), + TensorNode(vars=(2,), values=torch.tensor([0.1, 0.9])), + ] + order = [2, 1, 3] + 
tree = build_contraction_tree(order, nodes) + root = contract_tree(tree, tropical_einsum) + assert root.values.numel() == 1 + + +def test_contract_tree_partial_order(): + """Test contraction where order doesn't cover all vars (remaining merged).""" + nodes = [ + TensorNode(vars=(1, 2), values=torch.tensor([[0.1, 0.2], [0.3, 0.4]])), + TensorNode(vars=(3,), values=torch.tensor([0.5, 0.6])), + ] + order = [1] + tree = build_contraction_tree(order, nodes) + root = contract_tree(tree, tropical_einsum) + # var 2 and 3 remain + assert 2 in root.vars or 3 in root.vars + + +def test_infer_var_sizes_inconsistent(): + nodes = [ + TensorNode(vars=(1,), values=torch.zeros((2,))), + TensorNode(vars=(1,), values=torch.zeros((3,))), + ] + with pytest.raises(ValueError, match="inconsistent sizes"): + _infer_var_sizes(nodes) + + +def test_extract_leaf_index(): + assert _extract_leaf_index({"leaf": 0}) == 0 + assert _extract_leaf_index({"leaf_index": 2}) == 2 + assert _extract_leaf_index({"index": 1}) == 1 + assert _extract_leaf_index({"tensor": 3}) == 3 + assert _extract_leaf_index({"other": "abc"}) is None + assert _extract_leaf_index({"leaf": "not_int"}) is None + + +def test_elim_order_from_tree_dict(): + tree_dict = { + "children": [ + {"leaf": 0}, + {"leaf": 1}, + ] + } + ixs = [[1, 2], [2, 3]] + order = _elim_order_from_tree_dict(tree_dict, ixs) + assert set(order) == {1, 2, 3} + + +def test_elim_order_from_tree_dict_no_children(): + tree_dict = {"other_key": "value"} + ixs = [[1, 2]] + order = _elim_order_from_tree_dict(tree_dict, ixs) + # No children → remaining vars appended + assert set(order) == {1, 2} diff --git a/tropical_in_new/tests/test_mpe.py b/tropical_in_new/tests/test_mpe.py new file mode 100644 index 0000000..c28117b --- /dev/null +++ b/tropical_in_new/tests/test_mpe.py @@ -0,0 +1,151 @@ +import itertools + +import pytest +import torch + +from tropical_in_new.src.mpe import mpe_tropical, recover_mpe_assignment +from tropical_in_new.src.network import TensorNode +from tropical_in_new.src.utils import read_model_from_string + + +def _brute_force_mpe(cards, factors): + """General brute-force MPE over any number of variables.""" + best_score = float("-inf") + best_assignment = None + for combo in itertools.product(*(range(c) for c in cards)): + score = 0.0 + for vars, values in factors: + idx = tuple(combo[v - 1] for v in vars) + score += torch.log(values[idx]).item() + if score > best_score: + best_score = score + best_assignment = {i + 1: combo[i] for i in range(len(cards))} + return best_assignment, best_score + + +def test_mpe_matches_bruteforce(): + uai = "\n".join( + [ + "MARKOV", + "2", + "2 2", + "3", + "1 0", + "1 1", + "2 0 1", + "2", + "0.6 0.4", + "2", + "0.3 0.7", + "4", + "1.2 0.2 0.2 1.2", + ] + ) + model = read_model_from_string(uai, factor_eltype=torch.float64) + assignment, score, _ = mpe_tropical(model) + brute_assignment, brute_score = _brute_force_mpe( + model.cards, [(f.vars, f.values) for f in model.factors] + ) + assert assignment == brute_assignment + assert abs(score - brute_score) < 1e-8 + + +def test_mpe_three_variables(): + """Test MPE on a 3-variable model to exercise general brute-force.""" + uai = "\n".join( + [ + "MARKOV", + "3", + "2 2 2", + "3", + "2 0 1", + "2 1 2", + "1 2", + "4", + "1.0 0.2 0.3 0.9", + "4", + "0.8 0.1 0.2 0.7", + "2", + "0.4 0.6", + ] + ) + model = read_model_from_string(uai, factor_eltype=torch.float64) + assignment, score, _ = mpe_tropical(model) + brute_assignment, brute_score = _brute_force_mpe( + model.cards, [(f.vars, f.values) 
for f in model.factors] + ) + assert assignment == brute_assignment + assert abs(score - brute_score) < 1e-8 + + +def test_mpe_partial_order(): + """Test MPE with a partial elimination order (remaining vars reduced at end).""" + uai = "\n".join( + [ + "MARKOV", + "2", + "2 2", + "2", + "1 0", + "2 0 1", + "2", + "0.6 0.4", + "4", + "1.2 0.2 0.2 1.2", + ] + ) + model = read_model_from_string(uai, factor_eltype=torch.float64) + # Only eliminate var 1, leaving var 2 for final reduce + assignment, score, _ = mpe_tropical(model, order=[1]) + brute_assignment, brute_score = _brute_force_mpe( + model.cards, [(f.vars, f.values) for f in model.factors] + ) + assert assignment == brute_assignment + assert abs(score - brute_score) < 1e-8 + + +def test_recover_mpe_assignment_tensor_node(): + """Test recover_mpe_assignment directly on a TensorNode with vars.""" + node = TensorNode( + vars=(1, 2), + values=torch.tensor([[0.1, 0.9], [0.3, 0.2]]), + ) + assignment = recover_mpe_assignment(node) + # The max is at (0, 1) → var 1=0, var 2=1 + assert assignment == {1: 0, 2: 1} + + +def test_recover_mpe_assignment_bad_node(): + """Test recover_mpe_assignment raises on missing variables.""" + from tropical_in_new.src.contraction import ReduceNode + from tropical_in_new.src.primitives import Backpointer + + child = TensorNode(vars=(1, 2), values=torch.tensor([[0.1, 0.9], [0.3, 0.2]])) + bp = Backpointer( + elim_vars=(2,), elim_shape=(2,), out_vars=(99,), # bad out_vars + argmax_flat=torch.tensor([1, 0]) + ) + root = ReduceNode(vars=(), values=torch.tensor(0.9), child=child, elim_vars=(2,), backpointer=bp) + with pytest.raises(KeyError, match="Missing assignment"): + recover_mpe_assignment(root) + + +def test_mpe_with_evidence(): + uai = "\n".join( + [ + "MARKOV", + "2", + "2 2", + "2", + "1 0", + "2 0 1", + "2", + "0.8 0.2", + "4", + "1.0 0.1 0.1 1.0", + ] + ) + model = read_model_from_string(uai, factor_eltype=torch.float64) + assignment, score, _ = mpe_tropical(model, evidence={1: 1}) + assert assignment[1] == 1 + assert isinstance(score, float) diff --git a/tropical_in_new/tests/test_primitives.py b/tropical_in_new/tests/test_primitives.py new file mode 100644 index 0000000..e44d7ac --- /dev/null +++ b/tropical_in_new/tests/test_primitives.py @@ -0,0 +1,98 @@ +import pytest +import torch + +from tropical_in_new.src.primitives import ( + IndexMap, + argmax_trace, + safe_log, + tropical_einsum, + tropical_reduce_max, +) + + +def test_safe_log_zero_to_neg_inf(): + values = torch.tensor([1.0, 0.0], dtype=torch.float64) + logged = safe_log(values) + assert logged[0].item() == 0.0 + assert torch.isneginf(logged[1]) + + +def test_tropical_einsum_binary_reduce(): + # a(x) with x in {0,1}, b(x,y) with y in {0,1} + a = torch.tensor([0.2, 0.8]) + b = torch.tensor([[0.1, 0.5], [0.7, 0.3]]) + index_map = IndexMap(a_vars=(1,), b_vars=(1, 2), out_vars=(2,), elim_vars=(1,)) + values, backpointer = tropical_einsum(a, b, index_map, track_argmax=True) + + expected = torch.tensor([ + max(0.2 + 0.1, 0.8 + 0.7), + max(0.2 + 0.5, 0.8 + 0.3), + ]) + assert torch.allclose(values, expected) + recovered = argmax_trace(backpointer, {2: 0}) + assert recovered[1] in (0, 1) + + +def test_tropical_einsum_no_elim(): + a = torch.tensor([0.1, 0.2]) + b = torch.tensor([0.3, 0.4]) + index_map = IndexMap(a_vars=(1,), b_vars=(2,), out_vars=(1, 2), elim_vars=()) + values, backpointer = tropical_einsum(a, b, index_map) + assert values.shape == (2, 2) + assert backpointer is None + + +def test_tropical_einsum_out_vars_mismatch(): + a = 
torch.tensor([0.1, 0.2]) + b = torch.tensor([0.3, 0.4]) + index_map = IndexMap(a_vars=(1,), b_vars=(2,), out_vars=(2, 1), elim_vars=()) + with pytest.raises(ValueError, match="out_vars does not match"): + tropical_einsum(a, b, index_map) + + +def test_tropical_reduce_max_no_elim(): + t = torch.tensor([1.0, 2.0]) + values, bp = tropical_reduce_max(t, (1,), []) + assert torch.equal(values, t) + assert bp is None + + +def test_tropical_reduce_max_missing_var(): + t = torch.tensor([1.0, 2.0]) + with pytest.raises(ValueError, match="not present in vars"): + tropical_reduce_max(t, (1,), (99,)) + + +def test_tropical_reduce_max_no_track(): + t = torch.tensor([[1.0, 3.0], [2.0, 4.0]]) + values, bp = tropical_reduce_max(t, (1, 2), (2,), track_argmax=False) + assert values.shape == (2,) + assert bp is None + + +def test_argmax_trace_no_elim_vars(): + from tropical_in_new.src.primitives import Backpointer + bp = Backpointer(elim_vars=(), elim_shape=(), out_vars=(), argmax_flat=torch.tensor(0)) + result = argmax_trace(bp, {}) + assert result == {} + + +def test_argmax_trace_missing_key(): + from tropical_in_new.src.primitives import Backpointer + bp = Backpointer( + elim_vars=(2,), elim_shape=(3,), out_vars=(1,), + argmax_flat=torch.tensor([0, 1, 2]) + ) + with pytest.raises(KeyError, match="Missing assignment"): + argmax_trace(bp, {}) + + +def test_argmax_trace_scalar_backpointer(): + """Test argmax_trace when out_vars is empty (scalar result).""" + from tropical_in_new.src.primitives import Backpointer + bp = Backpointer( + elim_vars=(1,), elim_shape=(3,), out_vars=(), + argmax_flat=torch.tensor(2) + ) + result = argmax_trace(bp, {}) + assert result == {1: 2} diff --git a/tropical_in_new/tests/test_utils.py b/tropical_in_new/tests/test_utils.py new file mode 100644 index 0000000..d37061c --- /dev/null +++ b/tropical_in_new/tests/test_utils.py @@ -0,0 +1,63 @@ +import pytest + +from tropical_in_new.src.utils import read_evidence_file, read_model_file, read_model_from_string + + +def test_read_evidence_file(tmp_path): + content = "2 0 1 3 0\n" + filepath = tmp_path / "example.evid" + filepath.write_text(content, encoding="utf-8") + + evidence = read_evidence_file(str(filepath)) + assert evidence == {1: 1, 4: 0} + + +def test_read_evidence_file_empty_path(): + assert read_evidence_file("") == {} + + +def test_read_evidence_file_empty_content(tmp_path): + filepath = tmp_path / "empty.evid" + filepath.write_text("", encoding="utf-8") + assert read_evidence_file(str(filepath)) == {} + + +def test_read_evidence_file_malformed(tmp_path): + filepath = tmp_path / "bad.evid" + filepath.write_text("2 0 1\n", encoding="utf-8") # declares 2 obs vars but only 1 pair + with pytest.raises(ValueError, match="Malformed evidence line"): + read_evidence_file(str(filepath)) + + +def test_read_model_from_string_malformed_header(): + with pytest.raises(ValueError, match="at least 4 header lines"): + read_model_from_string("MARKOV\n2\n") + + +def test_read_model_from_string_bad_network_type(): + with pytest.raises(ValueError, match="Unsupported UAI network type"): + read_model_from_string("UNKNOWN\n2\n2 2\n1\n1 0\n2\n0.5 0.5\n") + + +def test_read_model_from_string_card_mismatch(): + with pytest.raises(ValueError, match="Expected 2 cardinalities"): + read_model_from_string("MARKOV\n2\n2 2 2\n1\n1 0\n2\n0.5 0.5\n") + + +def test_read_model_from_string_scope_size_mismatch(): + with pytest.raises(ValueError, match="Scope size mismatch"): + read_model_from_string("MARKOV\n2\n2 2\n1\n2 0\n2\n0.5 0.5\n") + + +def 
test_read_model_from_string_table_size_mismatch(): + with pytest.raises(ValueError, match="Factor table size mismatch"): + read_model_from_string("MARKOV\n2\n2 2\n1\n2 0 1\n3\n0.5 0.5 0.5\n") + + +def test_read_model_file(tmp_path): + content = "MARKOV\n2\n2 2\n1\n1 0\n2\n0.6 0.4\n" + filepath = tmp_path / "test.uai" + filepath.write_text(content, encoding="utf-8") + model = read_model_file(str(filepath)) + assert model.nvars == 2 + assert len(model.factors) == 1
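+
+
+def test_read_model_from_string_scope_one_based():
+    # Sketch of the indexing convention used throughout: scopes are 0-based in
+    # the .uai text and stored 1-based on Factor.vars.
+    content = "MARKOV\n2\n2 3\n1\n2 0 1\n6\n1 2 3 4 5 6\n"
+    model = read_model_from_string(content)
+    assert model.factors[0].vars == (1, 2)
+    assert tuple(model.factors[0].values.shape) == (2, 3)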