From bf078369a8f05a2ff1afc9c9998fa7d5a8267096 Mon Sep 17 00:00:00 2001 From: Mengkun Liu Date: Sun, 18 Jan 2026 13:17:14 +0800 Subject: [PATCH 1/8] Add PyTorch BP docs, examples, and tests --- BP_PYTORCH_IMPLEMENTATION_PLAN.md | 66 ++++ README.md | 359 ++---------------- docs/api_reference.md | 45 +++ docs/mathematical_description.md | 41 ++ docs/usage_guide.md | 38 ++ examples/evidence_example.py | 24 ++ examples/simple_example.py | 16 + examples/simple_model.evid | 1 + examples/simple_model.uai | 10 + pytorch_bp/__init__.py | 42 ++ .../__pycache__/__init__.cpython-313.pyc | Bin 0 -> 758 bytes .../belief_propagation.cpython-313.pyc | Bin 0 -> 13723 bytes .../__pycache__/uai_parser.cpython-313.pyc | Bin 0 -> 6326 bytes pytorch_bp/belief_propagation.py | 357 +++++++++++++++++ pytorch_bp/uai_parser.py | 145 +++++++ pytorch_bp/utils.py | 19 + requirements.txt | 1 + tests/__init__.py | 1 + tests/__pycache__/__init__.cpython-313.pyc | Bin 0 -> 188 bytes .../__pycache__/test_bp_basic.cpython-313.pyc | Bin 0 -> 4999 bytes .../test_integration.cpython-313.pyc | Bin 0 -> 2123 bytes .../test_uai_parser.cpython-313.pyc | Bin 0 -> 3031 bytes tests/test_bp_basic.py | 88 +++++ tests/test_integration.py | 33 ++ tests/test_uai_parser.py | 56 +++ 25 files changed, 1021 insertions(+), 321 deletions(-) create mode 100644 BP_PYTORCH_IMPLEMENTATION_PLAN.md create mode 100644 docs/api_reference.md create mode 100644 docs/mathematical_description.md create mode 100644 docs/usage_guide.md create mode 100644 examples/evidence_example.py create mode 100644 examples/simple_example.py create mode 100644 examples/simple_model.evid create mode 100644 examples/simple_model.uai create mode 100644 pytorch_bp/__init__.py create mode 100644 pytorch_bp/__pycache__/__init__.cpython-313.pyc create mode 100644 pytorch_bp/__pycache__/belief_propagation.cpython-313.pyc create mode 100644 pytorch_bp/__pycache__/uai_parser.cpython-313.pyc create mode 100644 pytorch_bp/belief_propagation.py create mode 
100644 pytorch_bp/uai_parser.py create mode 100644 pytorch_bp/utils.py create mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/__pycache__/__init__.cpython-313.pyc create mode 100644 tests/__pycache__/test_bp_basic.cpython-313.pyc create mode 100644 tests/__pycache__/test_integration.cpython-313.pyc create mode 100644 tests/__pycache__/test_uai_parser.cpython-313.pyc create mode 100644 tests/test_bp_basic.py create mode 100644 tests/test_integration.py create mode 100644 tests/test_uai_parser.py diff --git a/BP_PYTORCH_IMPLEMENTATION_PLAN.md b/BP_PYTORCH_IMPLEMENTATION_PLAN.md new file mode 100644 index 0000000..d1daaba --- /dev/null +++ b/BP_PYTORCH_IMPLEMENTATION_PLAN.md @@ -0,0 +1,66 @@ +# PyTorch Belief Propagation Implementation Plan + +## 1. Core Objectives + +1. Implement a generic Belief Propagation (BP) algorithm using PyTorch +2. Provide comprehensive mathematical description and code documentation + +## 2. Technical Analysis + +### 2.1 Algorithm Flow + +**Algorithm Flow**: Initialize → Collect Messages → Process Messages → Damping Update → Convergence Check → Compute Marginals + +### 2.2 Function Interface Summary + +| Function Name | Functionality | Input | Output | +|---------------|---------------|-------|--------| +| `read_model_file` | Parse UAI file | filepath | UAIModel | +| `read_evidence_file` | Parse evidence file | filepath | Dict[int, int] | +| `BeliefPropagation.__init__` | Construct BP object | UAIModel | BP object | +| `initial_state` | Initialize messages | BP object | BPState | +| `collect_message` | Factor→Variable message | BP, BPState | None | +| `process_message` | Variable→Factor message | BP, BPState, damping | None | +| `belief_propagate` | BP main loop | BP, parameters | (BPState, BPInfo) | +| `compute_marginals` | Compute marginal probabilities | BPState, BP | Dict[int, Tensor] | +| `apply_evidence` | Apply evidence | BP, evidence | BeliefPropagation | + +## 3. 
Project Structure and Testing + +### 3.1 Project Structure + +The Belief Propagation framework will be integrated as a submodule within the TensorInference.jl project: + +pytorch_bp_inference/ +├── README.md +├── requirements.txt +├── setup.py (optional) +├── src/ +│ ├── __init__.py +│ ├── uai_parser.py # UAI file parsing +│ ├── belief_propagation.py # BP core implementation +│ └── utils.py # Utility functions +├── tests/ +│ ├── __init__.py +│ ├── test_uai_parser.py +│ ├── test_bp.py +│ └── test_integration.py +├── examples/ +│ ├── asia_network/ +│ │ ├── main.py +│ │ └── model.uai +│ └── simple_example.py +└── docs/ + ├── mathematical_description.md + ├── api_reference.md + └── usage_guide.md + +### 3.2 Testing + +- [ ] Test parsing `examples/asia-network/model.uai` +- [ ] Test BP initialization and state creation +- [ ] Test message collection and processing +- [ ] Test convergence checking +- [ ] Test marginal computation +- [ ] Test evidence application +- [ ] Compare results with provided reference results (from test cases in TensorInference.jl) diff --git a/README.md b/README.md index dd6eb90..10e3803 100644 --- a/README.md +++ b/README.md @@ -1,353 +1,70 @@ -# BPDecoderPlus: Quantum Error Correction with Belief Propagation +# PyTorch Belief Propagation -A winter school project on circuit-level decoding of surface codes using belief propagation and integer programming decoders, with extensions for atom loss in neutral atom quantum computers. +This module provides a PyTorch implementation of belief propagation (BP) +for discrete factor graphs defined in the UAI format. The design follows +the tensor-contraction style used in TensorInference.jl. +See https://github.com/TensorBFS/TensorInference.jl for reference. 
-## Project Goals - -| Level | Task | Description | -|-------|------|-------------| -| **Basic** | MLE Decoder | Reproduce the integer programming (MLE) decoder as the baseline | -| **Challenge** | Atom Loss | Handle atom loss errors in neutral atom systems | -| **Extension** | QEC Visualization | https://github.com/nzy1997/qec-thrust | - -Note: we also want to explore the boundary of vibe coding, which may lead to a scipost paper. - -## Learning Objectives - -After completing this project, students will: -- Understand surface code structure and syndrome extraction -- Implement and compare different decoding algorithms -- Analyze decoder performance through threshold plots -- Learn about practical QEC challenges (atom loss, circuit-level noise) - -## Prerequisites - -- **Programming**: Julia basics, familiarity with Python for plotting -- **Mathematics**: Linear algebra, probability theory -- **QEC Background**: Stabilizer formalism, surface codes (helpful but not required) - -## Key Concepts - -### Detection Events - -In circuit-level quantum error correction, we don't use raw syndrome measurements directly. Instead, we use **detection events** — the XOR (difference) between consecutive syndrome measurements. - -**Why detection events instead of raw syndromes?** - -In code-capacity noise (simplified model), syndromes directly indicate errors. But in circuit-level noise: -- Measurement errors exist and can randomly flip syndrome values -- A syndrome value of 1 could mean "real data error" or "measurement error" -- Detection events localize changes in space-time - -``` -Round 1 syndrome: [0, 0, 1, 0] -Round 2 syndrome: [0, 1, 1, 0] - ─────────── -Detection event: [0, 1, 0, 0] ← Only the CHANGE matters -``` - -A detection event = 1 means "something happened in this space-time region" (data qubit error or measurement error). The decoder's job is to figure out which. 
- -### Observable Flip - -An **observable flip** indicates whether the logical qubit's value changed from initialization to final measurement. - -For a surface code doing Z-memory: -- The logical observable Z̄ is a product of Z operators along a path -- Initialize in |0⟩_L (eigenstate of Z̄ with eigenvalue +1) -- If final measurement gives Z̄ = -1, that's an observable flip → logical error - -**The decoding problem:** - -``` -Physical errors occur during circuit execution - ↓ -Input: Detection events (what we observe) - ↓ - Decoder - ↓ -Output: Predicted observable flip (0 or 1) - ↓ - Compare with actual observable flip - ↓ - Match → Success - Mismatch → Logical error -``` - -In the Detector Error Model (DEM), errors are annotated with which detectors and observables they affect: - -``` -error(0.001) D0 D1 # Triggers detectors 0,1 but NOT the observable -error(0.001) D2 D3 L0 # Triggers detectors 2,3 AND flips logical observable L0 -``` - -Errors that include `L0` form logical error chains — these are what the decoder must identify. - -## Must-Read Papers - -Before starting, please read these foundational papers: - -### 1. BP+OSD Decoder (Foundational) -**"Decoding across the quantum LDPC code landscape"** - Roffe et al. (2020) -- Introduces BP+OSD, the key algorithm for this project -- [arXiv:2005.07016](https://arxiv.org/abs/2005.07016) - -### 2. Improved BP for Surface Codes -**"Improved Belief Propagation Decoding Algorithms for Surface Codes"** - Chen et al. (2024) -- State-of-the-art BP improvements (Momentum-BP, AdaGrad-BP) -- [arXiv:2407.11523](https://arxiv.org/abs/2407.11523) - -### 3. Circuit-Level Noise Decoding -**"Exact Decoding of Repetition Code under Circuit Level Noise"** - (2025) -- Explains circuit-level noise models and exact MLE decoding -- [arXiv:2501.03582](https://arxiv.org/abs/2501.03582) - -### 4. 
Atom Loss Error Correction -**"Quantum Error Correction resilient against Atom Loss"** - (2024) -- Core paper for the atom loss extension -- [arXiv:2412.07841](https://arxiv.org/abs/2412.07841) - -### 5. Decoder Review (Optional) -**"Decoding algorithms for surface codes"** - Quantum Journal (2024) -- Comprehensive review of all surface code decoders -- [Quantum 8:1498](https://quantum-journal.org/papers/q-2024-10-10-1498/) - -## Getting Started - -### 1. Clone the Repository +## Quick Start ```bash -git clone -cd BPDecoderPlus +pip install -r requirements.txt ``` -### 2. Install Dependencies +## Environment Setup -```bash -# Install Julia dependencies -make setup-julia +Run commands from the repo root so imports resolve correctly. -# Install Python dependencies (for visualization) -make setup-python +Windows (PowerShell): -# Or install both at once -make setup +```powershell +$env:PYTHONPATH=(Get-Location).Path ``` -### 3. Run Tests +macOS/Linux: ```bash -make test +export PYTHONPATH="$(pwd)" ``` -### 4. Quick Demo +Alternative (works everywhere): install the package in editable mode. 
```bash -# Run a quick benchmark to verify everything works -make quick -``` - -Or interactively in Julia: -```julia -using BPDecoderPlus - -# Quick benchmark with IP decoder -result = quick_benchmark(distance=5, p=0.05, n_trials=100) -println("Logical error rate: ", result["logical_error_rate"]) - -# Compare decoders -compare_decoders([3, 5], [0.02, 0.05, 0.08], 100) -``` - -## Project Structure - -``` -BPDecoderPlus/ -├── README.md # This file -├── Makefile # Build automation -├── Project.toml # Julia dependencies -├── src/ -│ └── BPDecoderPlus.jl # Main module (uses TensorQEC.jl) -├── benchmark/ -│ ├── generate_data.jl # Generate benchmark data -│ ├── run_benchmarks.jl # Run decoder timing tests -│ └── data/ # Output data (JSON) -├── python/ -│ ├── requirements.txt # Python dependencies -│ └── visualize.py # Plotting scripts -├── results/ -│ └── plots/ # Generated plots -├── test/ -│ └── runtests.jl # Unit tests -└── note/ - └── belief_propagation_qec_plan.tex -``` - -## Available Decoders - -| Decoder | Symbol | Description | -|---------|--------|-------------| -| IP (MLE) | `:IP` | Integer programming decoder - finds minimum weight error | -| BP | `:BP` | Belief propagation without post-processing | -| BP+OSD | `:BPOSD` | BP with Ordered Statistics Decoding post-processing | -| Matching | `:Matching` | Minimum weight perfect matching (via TensorQEC) | - -## Tasks - -### Basic Task: Reproduce MLE Decoder - -1. Understand how surface codes work using TensorQEC -2. Run the IP decoder on different code distances -3. Generate threshold plots (logical vs physical error rate) -4. 
Analyze how performance scales with code distance - -```julia -using BPDecoderPlus - -# Create surface code -code = SurfaceCode(5, 5) -tanner = CSSTannerGraph(code) - -# Create error model -em = iid_error(0.03, 0.03, 0.03, 25) - -# Decode with IP -decoder = IPDecoder() -compiled = compile(decoder, tanner) - -ep = random_error_pattern(em) -syn = syndrome_extraction(ep, tanner) -result = decode(compiled, syn) - -# Check for logical error -x_err, z_err, y_err = check_logical_error(tanner, ep, result.error_pattern) +pip install -e . ``` -### Challenge Task: Implement and Compare BP Decoder +```python +from pytorch_bp import read_model_file, BeliefPropagation, belief_propagate, compute_marginals -1. Run the BP decoder and compare with IP -2. Understand when BP fails and when OSD helps -3. Compare threshold performance -4. Analyze timing differences - -```julia -# Compare decoders -results = compare_decoders( - [3, 5, 7], # distances - 0.01:0.02:0.15, # error rates - 1000; # trials - decoders=[:IP, :BP, :BPOSD] -) +model = read_model_file("examples/simple_model.uai") +bp = BeliefPropagation(model) +state, info = belief_propagate(bp) +print(info) +print(compute_marginals(state, bp)) ``` -### Extension: Handle Atom Loss - -1. Understand the atom loss model -2. Compare naive vs loss-aware decoding -3. 
Analyze threshold degradation with loss +## Examples -```julia -# Simulate atom loss -loss_model = AtomLossModel(0.02) # 2% loss rate -tanner, lost_qubits = apply_atom_loss(tanner, loss_model) +Run from the repo root so `pytorch_bp` is on the import path: -# Decode with loss information -result = decode_with_atom_loss(tanner, syn, decoder, lost_qubits) -``` - -## Running Benchmarks - -### Full Benchmark (takes longer, more accurate) ```bash -make benchmark -make visualize +python examples/simple_example.py +python examples/evidence_example.py ``` -### Quick Benchmark (for testing) -```bash -make quick -``` +## Tests -### Individual Decoder Test ```bash -make benchmark-ip -make benchmark-bp -make benchmark-bposd +python -m unittest discover -s tests ``` -## Makefile Targets - -| Target | Description | -|--------|-------------| -| `make setup` | Install all dependencies | -| `make test` | Run unit tests | -| `make benchmark` | Generate full benchmark data | -| `make quick` | Quick benchmark (fewer trials) | -| `make visualize` | Generate plots from data | -| `make all` | Full pipeline: test -> benchmark -> visualize | -| `make clean` | Remove generated files | -| `make help` | Show available targets | - -## Output Plots - -After running benchmarks and visualization, you'll find: - -- `results/plots/threshold_ip.png` - IP decoder threshold curve -- `results/plots/threshold_bp.png` - BP decoder threshold curve -- `results/plots/threshold_bposd.png` - BP+OSD decoder threshold curve -- `results/plots/decoder_comparison.png` - Side-by-side comparison -- `results/plots/timing_comparison.png` - Decoding speed comparison -- `results/plots/atom_loss.png` - Effect of atom loss -- `results/plots/scalability.png` - Scalability analysis - -## Evaluation Criteria - -Your submission will be evaluated on: - -1. **Correctness** (40%): Do your decoders produce valid corrections? -2. **Analysis** (30%): Quality of threshold plots and performance analysis -3. 
**Code Quality** (20%): Clean, documented, well-tested code -4. **Extension** (10%): Atom loss handling or other improvements - -## Resources - -### Core Library -- [TensorQEC.jl](https://github.com/nzy1997/TensorQEC.jl) - QEC library we build on - -### Reference Implementations -- [bp_osd](https://github.com/quantumgizmos/bp_osd) - Python BP+OSD implementation -- [ldpc](https://github.com/quantumgizmos/ldpc) - LDPC decoder library - -### Documentation -- [TensorQEC Documentation](https://nzy1997.github.io/TensorQEC.jl/dev/) -- [Error Correction Zoo](https://errorcorrectionzoo.org/) - -## Troubleshooting - -### TensorQEC fails to precompile -This is a known issue with YaoSym. It still works at runtime: -```julia -# Ignore precompilation warnings and proceed -using TensorQEC -``` - -### Out of memory on large codes -Reduce code distance or number of trials: -```julia -# Use smaller codes for testing -results = quick_benchmark(distance=3, n_trials=100) -``` - -### Visualization fails -Ensure Python dependencies are installed: -```bash -pip install -r python/requirements.txt -``` - -## License - -MIT License - See LICENSE file for details. - -## Acknowledgments +What each unit test covers: -This project is built on [TensorQEC.jl](https://github.com/nzy1997/TensorQEC.jl) by nzy1997. 
+| Test file | Test case | Functions under test | +| --- | --- | --- | +| `tests/test_uai_parser.py` | `test_read_model_from_string` | `read_model_from_string` | +| `tests/test_uai_parser.py` | `test_read_evidence_file` | `read_evidence_file` | +| `tests/test_bp_basic.py` | `test_bp_matches_exact_tree` | `belief_propagate`, `compute_marginals` | +| `tests/test_bp_basic.py` | `test_apply_evidence` | `apply_evidence`, `belief_propagate`, `compute_marginals` | +| `tests/test_integration.py` | `test_example_file_runs` | `read_model_file`, `BeliefPropagation`, `belief_propagate`, `compute_marginals` | +| `tests/test_integration.py` | `test_example_with_evidence` | `read_evidence_file`, `apply_evidence` + BP pipeline | diff --git a/docs/api_reference.md b/docs/api_reference.md new file mode 100644 index 0000000..7210cdc --- /dev/null +++ b/docs/api_reference.md @@ -0,0 +1,45 @@ +## PyTorch BP API Reference + +This reference documents the public API exported from `pytorch_bp`. + +### UAI Parsing + +- `read_model_file(path, factor_eltype=torch.float64) -> UAIModel` + Parse a UAI `.uai` model file. + +- `read_model_from_string(content, factor_eltype=torch.float64) -> UAIModel` + Parse a UAI model from an in-memory string. + +- `read_evidence_file(path) -> Dict[int, int]` + Parse a UAI `.evid` file and return evidence as 1-based indices. + +### Data Structures + +- `Factor(vars: List[int], values: torch.Tensor)` + Container for a factor scope and its tensor. + +- `UAIModel(nvars: int, cards: List[int], factors: List[Factor])` + Holds all model metadata for BP. + +### Belief Propagation + +- `BeliefPropagation(uai_model: UAIModel)` + Builds factor graph adjacency for BP. + +- `initial_state(bp: BeliefPropagation) -> BPState` + Initialize messages to uniform vectors. + +- `collect_message(bp, state, normalize=True)` + Update factor-to-variable messages in place. + +- `process_message(bp, state, normalize=True, damping=0.2)` + Update variable-to-factor messages in place. 
+ +- `belief_propagate(bp, max_iter=100, tol=1e-6, damping=0.2, normalize=True)` + Run the full BP loop and return `(BPState, BPInfo)`. + +- `compute_marginals(state, bp) -> Dict[int, torch.Tensor]` + Compute marginal distributions after convergence. + +- `apply_evidence(bp, evidence: Dict[int, int]) -> BeliefPropagation` + Return a new BP object with evidence applied to factor tensors. diff --git a/docs/mathematical_description.md b/docs/mathematical_description.md new file mode 100644 index 0000000..91d6c1c --- /dev/null +++ b/docs/mathematical_description.md @@ -0,0 +1,41 @@ +## Belief Propagation (BP) Overview + +This document summarizes the BP message-passing rules implemented in +`pytorch_bp/belief_propagation.py` for discrete factor graphs. The approach +mirrors the tensor-contraction perspective used in TensorInference.jl. +See https://github.com/TensorBFS/TensorInference.jl for the Julia reference. + +### Factor Graph Notation + +- Variables are indexed by x_i with domain size d_i. +- Factors are indexed by f and connect a subset of variables. +- Each factor has a tensor (potential) phi_f defined over its variables. + +### Messages + +Factor to variable message: + +mu_{f->x}(x) = sum_{all y in ne(f), y != x} phi_f(x, y, ...) * product_{y != x} mu_{y->f}(y) + +Variable to factor message: + +mu_{x->f}(x) = product_{g in ne(x), g != f} mu_{g->x}(x) + +### Damping + +To improve stability on loopy graphs, a damping update is applied: + +mu_new = damping * mu_old + (1 - damping) * mu_candidate + +### Convergence + +We use an L1 difference threshold between consecutive factor->variable +messages to determine convergence. + +### Marginals + +After convergence, variable marginals are computed as: + +b(x) = (1 / Z) * product_{f in ne(x)} mu_{f->x}(x) + +The normalization constant Z is obtained by summing the unnormalized vector. 
diff --git a/docs/usage_guide.md b/docs/usage_guide.md new file mode 100644 index 0000000..4ce3cfd --- /dev/null +++ b/docs/usage_guide.md @@ -0,0 +1,38 @@ +## PyTorch Belief Propagation Usage + +This guide shows how to parse a UAI file, run BP, and apply evidence. +The implementation follows the tensor-contraction viewpoint in +TensorInference.jl: https://github.com/TensorBFS/TensorInference.jl + +### Quick Start + +```python +from pytorch_bp import read_model_file, BeliefPropagation, belief_propagate, compute_marginals + +model = read_model_file("examples/simple_model.uai") +bp = BeliefPropagation(model) +state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1) +print(info) + +marginals = compute_marginals(state, bp) +print(marginals[1]) +``` + +### Evidence + +```python +from pytorch_bp import read_model_file, read_evidence_file, apply_evidence +from pytorch_bp import BeliefPropagation, belief_propagate, compute_marginals + +model = read_model_file("examples/simple_model.uai") +evidence = read_evidence_file("examples/simple_model.evid") +bp = apply_evidence(BeliefPropagation(model), evidence) +state, info = belief_propagate(bp) +marginals = compute_marginals(state, bp) +``` + +### Tips + +- For loopy graphs, use damping between 0.1 and 0.5. +- Normalize messages to avoid numerical underflow. +- Use float64 for consistent comparisons in tests. 
diff --git a/examples/evidence_example.py b/examples/evidence_example.py new file mode 100644 index 0000000..9276e93 --- /dev/null +++ b/examples/evidence_example.py @@ -0,0 +1,24 @@ +from pytorch_bp import ( + read_model_file, + read_evidence_file, + BeliefPropagation, + belief_propagate, + compute_marginals, + apply_evidence, +) + + +def main(): + model = read_model_file("examples/simple_model.uai") + evidence = read_evidence_file("examples/simple_model.evid") + bp = apply_evidence(BeliefPropagation(model), evidence) + state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1) + print(info) + + marginals = compute_marginals(state, bp) + for var_idx, marginal in marginals.items(): + print(f"Variable {var_idx} marginal: {marginal}") + + +if __name__ == "__main__": + main() diff --git a/examples/simple_example.py b/examples/simple_example.py new file mode 100644 index 0000000..1858f52 --- /dev/null +++ b/examples/simple_example.py @@ -0,0 +1,16 @@ +from pytorch_bp import read_model_file, BeliefPropagation, belief_propagate, compute_marginals + + +def main(): + model = read_model_file("examples/simple_model.uai") + bp = BeliefPropagation(model) + state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1) + print(info) + + marginals = compute_marginals(state, bp) + for var_idx, marginal in marginals.items(): + print(f"Variable {var_idx} marginal: {marginal}") + + +if __name__ == "__main__": + main() diff --git a/examples/simple_model.evid b/examples/simple_model.evid new file mode 100644 index 0000000..722ff38 --- /dev/null +++ b/examples/simple_model.evid @@ -0,0 +1 @@ +1 0 1 diff --git a/examples/simple_model.uai b/examples/simple_model.uai new file mode 100644 index 0000000..694d568 --- /dev/null +++ b/examples/simple_model.uai @@ -0,0 +1,10 @@ +MARKOV +2 +2 2 +2 +1 0 +2 0 1 +2 +0.6 0.4 +4 +0.9 0.1 0.2 0.8 diff --git a/pytorch_bp/__init__.py b/pytorch_bp/__init__.py new file mode 100644 index 0000000..3b88f48 --- /dev/null +++ 
b/pytorch_bp/__init__.py @@ -0,0 +1,42 @@ +""" +PyTorch Belief Propagation (BP) submodule for approximate inference. +""" + +from .uai_parser import ( + read_model_file, + read_model_from_string, + read_evidence_file, + UAIModel, + Factor +) + +from .belief_propagation import ( + BeliefPropagation, + BPState, + BPInfo, + initial_state, + collect_message, + process_message, + belief_propagate, + compute_marginals, + apply_evidence +) + +__all__ = [ + # UAI parsing + 'read_model_file', + 'read_model_from_string', + 'read_evidence_file', + 'UAIModel', + 'Factor', + # Belief Propagation + 'BeliefPropagation', + 'BPState', + 'BPInfo', + 'initial_state', + 'collect_message', + 'process_message', + 'belief_propagate', + 'compute_marginals', + 'apply_evidence', +] diff --git a/pytorch_bp/__pycache__/__init__.cpython-313.pyc b/pytorch_bp/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc67cce76350d6a07fe5641498eb692c9f16f754 GIT binary patch literal 758 zcmb7?zmC&D5XRR|>^M%Gf5quamzd(HXb?iW2u`{^fs6tYMWoecvq@H3d#$kpgpP-x z;1PJGZK>#h=yWZ!35QNd6m0R+`15$?+ZpeZBqn%13$20{Az$6Y?eTvH`!^VTA~OPH z=FR*okNV)#7PM#^+BATGh7hLiUF)htBZz1hx-^EE_Mk`m(5DF`bN~Z7gdrWlC?yZb zto_o1@jA%D4;OL1du2du(y7!+uF@P##n-%1rcCcm^DM2mi^43oTBa)lDX%Ip?^VGY znW}OnL6$->MYZTD3!1#|a zS{&}mP(R#N+OaT2wQVFTIILC4b!`WDRlQ@k7PA51jtI0J0*BC{;}AJ?9o#~p=g@aZ z90m?UhY{jMHnAPXmPRmU<1JUL;!sQYXLURI%Pnr^zdE{hTx-U@kUbh8_TjTf^LMIj zB&4<27^COY{HYW;dC2v)o>x1Zz1XltHD}Dtma(&HXQRi*|KJInVimXW`aa@&%k#V+ fzVEefqR{iN3F2lHBM9Q!;pgeN=YQ50TW9dQjBEYrRa$dN{V^SaRaGEnBv1S2S%VQp%$wN~S}R$}Z*D zO5CV8v}M}lD!Fk*rMa3$4rduHV5PvJ?xDaXhvM4)(`u_!Vz;?FqczaW0e7HcyFu@- z`@Y!+DVcWCD{w<_c6MGnJNwP=`+eWc{rr3jfe;RcLxb&v{2p)AP^cvC*XjtlMm&Tk z9>$|P&h+YdU9X`l=FZ2`*8v8lg&AP}_ zJgCcEVt{&{lEGr=CuhqJv&T;RF7edS=amlTd&&mwo&t{^%F40q0>hWql3L=aUst1I zu#DEScpOTta;Vebsf4jSVKj6^&d*V&3i}_l!ziki{wr`4bR>;*BsK6}SPMr&$2`am z<_{Kh=7)>fjB%$g zpNaCr=i4kFXz;FzD!=BW@DvA4pkhn$$3D1Bi>pTpv_vm=mqaQSM=sYZa6{E*6 
zXo8;^r;SdX&SRt{c~D}8dLY0sB^Kq*Oz*6oJl-~Daa*Ngc@DCs;A0!~G}V%0ZlXgM zg2ORxG|F?Mbky7!ADB2Vv`KlBfslV3UualaDfAVL5bNQ8-wf}4jd+=Un!P^qlCYKm zHL}|)`A`GrjRvTpPLG-eSE!-2^O9VB*vVO)plc_B&Q6Xx2#M$hOlk*SUB6jPXRe*f zGf(UKoP<~j2W=xdcoNPs4e+FVHJ!P3DoLoRJz9t6S^%N1Z{e|xzhn&2##dxSx&mmk|A;tvL%!JSp>;29Ng5goV00u1?mFD=#;Gxpq!`*`~ha#~c z&k4ilqtWo-fzyYA!|<8;)8R>BaN;sb?s@;v#NZHJN&bnoPEQRs^C^DTw;X_ZfvQtS~WjyB%NJ9Jn@4Q%l7?0 zII+<5)}ciC^{V-*mD0xdPKfsXP~DueSIq65-MegWh_i2+|MZETlpbIN6vq#AoA&Qz zB_khf- z@ilsMt)MOO@%?!l9tyW!z4~ORze^r(-$D86)ir>;Q<-HzX7y8b2T*c2kfl&yC>(5! zMOzhd;l=|K6M&0tJgz|=my+8c8TdeCEXWtZos<`hOpXWn07z09-W#DoFf!sc@~{$# z#Q6CNyjtWA{h^UdQXUGepUSTulQ?Z4lMr}Vjf5}92Gu#FN)zx}s3A;3G)?YRammt~ z&Ku4Tt6G)|T7Om5G86a)FIKh0^-BemuTA@@BM8BH=|G4OJ{^% z+g*u~>lfxPtlC|o-IXe>NCf6hvnN(dw}_=%7Az~JJMJ5ZWAmDxG*%725ih%2LgvvzJx81zHLC^s|^UKwk>K48T0e1@rV3svda-Ri1=4H$#PmcS9%HW6sAxY1O1nxks&4*ow_jg-w&Y zh4ybwy*{;MuHCG&y{ua{kl< z0=MZBouMY-c|hQjHtP(|gzzb(opR6E*D(CGphD373y_tYUkn<6XrWTr5KyY|?_r_F zXVd}eYoKhy>@o6rV3Ay@7Z`x%D67FZ>9AA8R&fgoW!Vb(-Eg6w4A<#cGTlhao5nwQW!^7}Q z{s=_sH`Fq6e8WO{pxP${{jQ=@skt)}7!_SIr zQ()*uz^L^y1+Y53IY}^-7{sfWIZb}ZvH4d z$gPw0wGkKXq#)q(Q}i*SCm)>P0TMH>EN9^K2ng!aB*it{EVx-PTeWy9ZcG(B66FbV zGJj#@cJx+ssddktz*6D9CG$R6Xz+0!qrGXv$*ThpAk$d=LvTLU9q07oaRR zrdb`JYC{^3vT^Ysf@kAH0MGr5{CsR6>4qlw8c3}BI!Np3*ZB!(zH6kNQa&&!P!H0O zd-heb1;p+)5Ia4zW8rUr+Uj~xJAkWKk3s--1*mI;Ix1>EQYT^>K+Z)&KhV96p^kX7 zYoI#V5IjtmE@Jkw{i*{ctK4huM}@f|;H+a%--#n$!(Y<64~QDlWian@je*U2jxJ{n zL#9UGdvad`c6p4143w!UFQeA;>QrY9tOwoH8s&G&iTraOV`w&ZVlYE9RE93Q}7Cz)5WO zw{|6Zm+e~?s+O&tOKc|v04ZNFM`XS%6^9h2pv{_85DS3Li6J}1IbEk!GzcG}Adv+Q zng~YpUcHCu_5p+!!Gb6K2r3(by3iV+Ji=yKlA7(N+KIk?+7p0sAAo?b3?E`gR<#pzkK`;^s){8HFcbB9b=@y3b}5X> zYXG|#NF}w4%>aC6O@)CS2=HUw@bYjwpm|ulW_9U$^j%DZ^|G3UY5)tBo&0;4XmG5I zPXMu29aLHLOzF2~an8M0T1f zPo~RXI3Jh@a;}l5TDXzzEgbX=iHMsr%}3;|n?pY%8h&b#`c6U%A`|t@1p~w9HOdMV zm6@PDt`}G2z9Yd)*q>aGY3AD2IurmmlRMMF7Qj8l_UF(n-vDaBgkj#*@hA`}L!A1K ze)%aaz)?|OQEve_NA)A2akq&_QQ*&Fgoq~@Xe)`m7z(~D>4nL0V6Ny`1w?xRD}50o 
zKdl*o4g~b*2$Ny>mRK-=(GZrJp{XA$KMN37Smu+>^hwqwfGt6I4r3$N83+u$fy5+)#C36EfrwXd3k9<^EI^+GDSL6DW zwP4xG{l-?BD%d0zG(bji$&7v`IBSg?Qe}?0=VqTvyfix)H@t316<1t$z3oaKUUD{z z#m!$guH}=W>IWr+YfO6Id3Cw)@Y}B@^otR({m`6#{?+NDONEF3%~S}@=A5(6zcVDy zzgsxtTq@oM)XG{s-S?+YOr+=tBcQ(jU}xTe7WSX@Mu=f*(dLr!BFK_a->t(fjocUk zNj|@a0#N8bRNoaPHEb?Y1GRc6odH}z&lF%%Zen9X^m?+-3J8lK7h(W~$U$+uv_~&> zldCuaUaw;cRRHLa1E5e2GN!hG@xP?J<+r`sI}B))1zH1=Zwy|8H?LE{_K3x+(@+mA zEw80tqbZqoD(Xk9UaLpf{T65l7?tmPt$>iw!(T9Gz8n3endhh}b#_n(^a0^Sgzo!K z`3id_?BwsE0I)S2#tnc#WQ3cX7(v5bB@mA{nscJ+1{{4Hem;>@Aq}CD48%+JDsV1v z+k{3#;2sCun+OMngEEq8lHSdoJ0GNY4vIm`)!rftm`(!q;^3uaZxK$J?4vY<5Hl&; zqJ})h2+4Rc5S|PQ{0P=VBj0V4iPr_pco8EcQ+y>x6uKHAB^3aoB6090YLCYd1e1E$estxz#~k@;BRu>6fwj>5;_Ip2i8@DL(^uyV4=wV=Cg zxn)_Y-2qG}Zn;}s3m&qRy>f2%Tf2c^)o{tD=Fi1X-mR>O&Qc;x*lv9W)}<^keWy6HgyscA}%-hA=Ki!0UJ zUq2D=(Yn1?)s-qLpUa=kPZe!W6;<3fvDF3f6MrlqMU^RM(<*?Uvw7LsGGj{Fory1e zYj>)?U99h#IRVfLR5AC6cK58>szqCM;)|lKW#Rc{+phcdq-Mt_Wn@!_Oy*1@fpK(F z4`;7#4eLt&T&(+VSclncEcc3J5yi5w zr;O}%-&Zuo1}N%#DJP!MamZLIGt0g@87nqmo9T(zX7O6Gux*NBn+ec67e2{PG*GR8 zcF=~9!Asw691*xd^r&CFCcJBf+E4D6;-0ABURuFD)@_)q2E|~3&Cy|SOgV5!ng z5h2Upr{LZt$_f=W350R%cky^62={|RFb00dNr4-c`Qw<%JF^+&|2qOAmMBZ5n>!@0 z8I?@q#;0c7PtnGwkvg*{JIRUtW`h^a&ha6+f}k=1Ja zD^TqY{Du1v0dUq4*hJFZG~Y0Pw>@6)QB(V(?uW+rjDUUHj;u8GfJGYby;oi}cj4-V zWXVm(4aY*8xVh^mWk0R?phhe|74NyXbMO1QpBg_fzW*h0-`SO&=ftAzOGRADUjJ+R zQ>og^af(yjc7E)t)s2dD_SP(L&L2^YyLsTfftqfyUZKv9^7t z2cWTvTXk*~om&@9iq744dPQf?sL&zC(jZfMhx3B4l;McNQ183aY&m#7IQ$|P8dZf{;LZmw1~dNAbq*gzocu5`-Q2-Ro`*vs z8Kx-f6BfzJmohN7_UH4p3h$}q3cVU|A|GQe*Qp6v}L*BO2ioZX|D@wI@3 z85V*ZT^28ggrMcf891Gyp9)Rz`(Sb;ePBrDC~hkdCP zmhJ7U_8p>q$FjXMekfIFe{)KzsJ@>^Ht%C# zyLHPU2E04AB3Vq$<5yeHh?!;5(u+w+|XmdCk zomeO6pQ(b>x$oxs0+&MLljB)C)S8UTn`cjuZ4{)(C?2|oG`9X(9DcpCm7c~z7Y7{d z@}S+@0@ea?~=XuSGL}@GE&}{b`V?fTTRK% zMW0x*XUV!}iQPk2Uhd;w4iL9%uyeR>PbBBjE_DULii^OC(+{xlYBc0pZ{UbzaycmF zy^^f9pcys>N{n$|ESGaY-D`pG8!`9}Q&P+J!J0-bk$Yc6(=J9>j~NsF2&J*C?*O=V z^|(s6z`qa33+)LG)qpW=!PTfbD#yWv9D;L#91xJ9z)&a*`vI`WKsTXP7#+bK32N)Y 
z${do(r0c2n=T$G{s7fkjCaHNavzeq~DyxbGyQJ%ZF0Lo8@%0M|4>{t&M*gyT?kD3H zq_#2VdP;R3f{~r$^)7QG;6J1XAfTg-6cScZhC%Lw$$Kns1g9Y`E4t{UJoQ0>Oc%lJ zthfP@0SdTq6;&l`s$$jimuZJ)awU8k;VY1U%?l?fwof0s=iHQ`efN|diT>-pd0#Rj z)^z^Z`%f?Y=moLlc-(Z??pU?gi}w2D$q(&2z>+SnnH~Rn4=AhB4mcZRcg}5{-TJLo z*!{4U%zWV+yP;ex-xBY+>uR5AOFku*G%TAN*D7(}qMF#sme^87xdC4o!JX}Oi_Hv! zq{{6AwOzsI24@qIjc*LFw1C3mUPHe|sWNr?vlXV2t%)|9bw*O3YvE-r--Jrsd;w(k zRCS>RS&k0?1*s`RkGB8O&Sa^>C-zSlq}g}LYry3+pxM2-tWy99ZM!g%m)Qh_Wt_{h zq&8*Nd$ho4ExwzY6%0@UTcA(2Ilm{5%cJw=4WNsM2Cr4^l|n$93<17l3b?$jLPCF) z+9Gouz&&Rg;H%-6-Vopom2as&1N?OlQEod!2w4&T0UYqnX!$5TOdhvm&z_hPZP6+rkB_kaZoFyI)2ro)ZNV3Wqe!QoTl1$hiJYT?d zsi3TKk-@eAr{l6;ogMyfpcI)=g;S_I<|_A zt;>$qcs}rzjsC>Kk~d#V)iln(Fk?Z8X1}vFQS*KK!r9vcw+5EoU1DX|j4@T$GV}5* zav(>|+-p}~OI}&7=$v6cs;o_1S>3c<+_Zi1X>rr;mC9#k_NOWw*Ir9CwV==P-SLDq z)xIs+BRX1AJ312VJKp5(<(l?o2kt97p9RZfbDPL@T+LtG1a8K39my+|H?OrJuzR`S>FJ((*6PH$WozRS+epFQ z&75)T&ivu+c?>__>t_D&=70S|^3#7=+~a0H9>EWE9i#6;bjDpp&vi*gI4J<0O&+xY z&tim-%cDKc@52aeUqX)q@YgVfQplsw(36JpVS$5qOL^{oOrZv#o^AdVrjYDP271yU zi}SvLMW{6C_uO!ROfo+spP}5tzYA61I}>3A;T{g&ANl~n$L zxIQ&*W}IsTqEC-9R%Xi@f#}ofJZ3+$Mleo0oXpOp@-1ls*M*)3cu99RG9|!V=bo8; zCas6`-HNKj;p->nPly%nv;i`}Ia20Gn=oZ2M(fuGzdZOL5AN1ZGfm74-0W!po% zrniwsS88)(3gV{bwLFg5yHwenCU8yFZcgjrdbhGRZGbBbr?x(A!jzeC4QZ?kscoH$ z=kAOy`o;Fc4{6pBriLj?oJ$k9EH%RsKDfTmzVG>2<1br&-m-M~thn#o1I$eCV_KM| er3zP?z;*G^9s8Z(JC65G(%YpT-$Tr%vira7+U%bI literal 0 HcmV?d00001 diff --git a/pytorch_bp/__pycache__/uai_parser.cpython-313.pyc b/pytorch_bp/__pycache__/uai_parser.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f6196828d055b86e235fe375a85e854cfbea562 GIT binary patch literal 6326 zcmahtTWlLwc6T^4d<@?bMJcjqj~rPR6PcD0C$`l(8#{g|erO#rLYxfLPi}kc-+hHC*j#=6X9$)$&0hpJIhCXizAY9_*Ho(lhx#z9GzDS za$co|M)6HPqGdDMjC_{n^6IQw$mH~xbQkf9@}_V)V-!r`XeL!K#n%_|S#2WKhB^`I zFZWNv;3JYClmlkj;vD7UHY&vVgqW}+LUAGCcr5WkTx5v4;=n8_#U0c=>x{Zg`*Af@ z$kDql*u1#kdXZCE)i7kL<*8<9dLg6F%BnoW7Rni2F3bUa*4r#q^K&s%xCnHb_KRwE 
zK{HZ(9b0_h82|T2p!9Jz|!1*Y%K9^MlXQe zjnrH&mrWcwdsIv1(i%ORT`&^)%fRr|Tyi>}SWq)bOGRS&%ce7#%;=dyGHLoPLBtvW z4;*cTpevFG9@5oY8Mt<;B)spch5J5m-v>BJHCYHRs0obeXJG(##VKSGMq8VY#BDIs zWO!?SvY~J-5P{YF^Wp%B3KNSS3%?;?m~m*)IczcVkjX`(e2XKd$oPAcqh2`2Kn!K- zgB~eG?b<}C1?a++gU}U8Ef}e!Kh7*?DpRY@TcOXleYWk}(Ie}lN7hDO`FeECe{9Wp zj0Rz8)W$@@K0Fci;OV+Lufge3GC7}1FJy7-NhaS}P_y+Je=<3fp++H_(KS5>(~e{^ zol5~;$T*V+`GL$)PNv;bAqw7$&@j$+C6h)0OgNQP3k8~)UMOfrGI^6QY#4>qfs3(2 zjbQ}+4fzk@@P^1&rr_Uf2*7s#W#}I7 z0KoL^36VJga0Uqn^G|j(ckoZnxC6}I6?al;))kdZCpfAzppn_TPr-qUhtY2Vr?hO{ z_ER~%pn@{C8HF({Z!(KvS=G~ZqcUQqsG~BNBCat^hh=O=s$G##(U@8eF(S?fvhcJT z*nLxh&yQ^J0T-&6xEULhAs+@9><%Z2GR(IZ|OEu6EoQ0Z_|YPV3E-YU6E|~NO;@(Qmv~Oy6){C zc$m;PSujRM#;OSzJlTS$a2;!4%kAv0LmO_k-G2QI^oAHe4=uWfWxY=KA&OYmX*c<- zrZd_pv&@zaTb9sHr)9v%Cj%LT1}v|HKJERacg3@M>eu`iN54A#<>~5ySJ(ER{B8f5 z|MZ&kv}M11i^Yt$6Hzzqfg>m?C+)><8+OP;8o>@}wrs3vl~hJpKX&L8EHiu>N7&&> z#vbU<3{PzcE^$XCwL##qI91e>GW!6b55t^zrf!Y6-*jSF&#MLS2UII8 zP(9V|H=FYMD_FZJ%z^i9WF|JE+Z5X*;=ZN5P=D7Uoa3*tLnc#b#;m=;E06~rv<;Ac zGEC5fu%$sFGd7v*ATsQCBPO+(RD-ZF`7Dqp z7jkk8;u1?^+6*MAWeb<{+L(M)EvPaMnLSxb0~UqRRvt=gGwMRNFb1c9>x|80b86wm zeNjo8WIU?}@`m=wx#_pHR6#O1IuYeeuXR}cOBi8uc}+J3s;Ox+j6HgDEt{O7xq0xw z6xM*28!ti!Q8i2?Uumww-*)w_aJ7!k z8-q)O*M>@tM?8@t-?{r&c7G9Bcb`}jPCOC;@;g`Ge;H2L`ay7nza_oopm@$El|aXj zqP{;3M)?r|;{;J{ZJ-==XeT*mGi-Z_Ve2I0Q9^P%pQI>F@q-ec0o{JGTAj6>0b1Rl zAUY2!2ug-pr3)?P0vM051~$3~5CI{N!@8lafU~`cGX!?s-gctBaTa_X(sA~6$EFp0 zoQ12#I^DUK=&q?os3VHAF#{aA0H;fUlU%2c zAVdX?RQ!|P#%$|({-qzIwsDS36dzyXv5Cm zICbpTfE{2xg|3#hcvCVO*T!eR64<0e9ZH~%O<<~7kWB)uTiZ^cTnWIh139%tE|y+l z^kGeD%B7`Oh7Bab#a#z&dM7&tIS`oYZtn4dq(e9pv(eGn@r`P`Yqd zs6n~jrF2eWQZ*%!|Uk{E_wjAXqmtS|1V^(y)dbJLbRinOM!i;89K^;@Yosrk_CoGwfRAZUEl|LUrCg><*Rt9? 
"""
Belief Propagation (BP) algorithm implementation using PyTorch.
"""

from typing import Dict, List, Tuple

import torch

from .uai_parser import UAIModel, Factor


class BeliefPropagation:
    """Belief Propagation object for factor graphs.

    Factor scopes use 1-based variable indices (matching the UAI parser);
    the internal adjacency lists are indexed 0-based.
    """

    def __init__(self, uai_model: UAIModel):
        """
        Construct BP object from UAI model.

        Args:
            uai_model: Parsed UAI model
        """
        self.nvars = uai_model.nvars
        self.factors = uai_model.factors
        self.cards = uai_model.cards

        # t2v: factor index -> 1-based variable indices in that factor's scope.
        # v2t: 0-based variable index -> indices of incident factors.
        self.t2v = [list(factor.vars) for factor in self.factors]
        self.v2t = self._build_v2t()

    def _build_v2t(self) -> List[List[int]]:
        """Build variable-to-factors mapping (0-based variable index)."""
        v2t: List[List[int]] = [[] for _ in range(self.nvars)]
        for factor_idx, scope in enumerate(self.t2v):
            for var in scope:
                if 0 < var <= self.nvars:  # silently skip out-of-range indices
                    v2t[var - 1].append(factor_idx)  # convert to 0-based
        return v2t

    def num_tensors(self) -> int:
        """Return number of factors (tensors)."""
        return len(self.t2v)

    def num_variables(self) -> int:
        """Return number of variables."""
        return self.nvars


class BPState:
    """BP state storing messages.

    ``message_in[v][k]``  is the factor->variable message for 0-based
    variable ``v`` from its k-th incident factor (position in
    ``BeliefPropagation.v2t[v]``); ``message_out[v][k]`` is the matching
    variable->factor message.
    """

    def __init__(self, message_in: List[List[torch.Tensor]],
                 message_out: List[List[torch.Tensor]]):
        """
        Args:
            message_in: Incoming messages from factors to variables
            message_out: Outgoing messages from variables to factors
        """
        self.message_in = message_in
        self.message_out = message_out


class BPInfo:
    """BP convergence information."""

    def __init__(self, converged: bool, iterations: int):
        self.converged = converged
        self.iterations = iterations

    def __repr__(self):
        status = "converged" if self.converged else "not converged"
        return f"BPInfo({status}, iterations={self.iterations})"


def _clone_messages(messages: List[List[torch.Tensor]]) -> List[List[torch.Tensor]]:
    """Deep-copy a nested message structure via per-tensor ``clone()``.

    Equivalent to (but much cheaper than) ``copy.deepcopy`` for nested
    lists of tensors.
    """
    return [[msg.clone() for msg in var_msgs] for var_msgs in messages]


def initial_state(bp: BeliefPropagation) -> BPState:
    """
    Initialize BP message state with all ones vectors.

    Args:
        bp: BeliefPropagation object

    Returns:
        BPState with initialized messages
    """
    message_in: List[List[torch.Tensor]] = []
    message_out: List[List[torch.Tensor]] = []

    for var_idx in range(bp.nvars):
        card = bp.cards[var_idx]
        # One fresh tensor per incident factor (v2t order).  message_in and
        # message_out never alias, so no extra deep copy is needed.
        message_in.append(
            [torch.ones(card, dtype=torch.float64) for _ in bp.v2t[var_idx]]
        )
        message_out.append(
            [torch.ones(card, dtype=torch.float64) for _ in bp.v2t[var_idx]]
        )

    return BPState(message_in, message_out)


def _compute_factor_to_var_message(
    factor_tensor: torch.Tensor,
    incoming_messages: List[torch.Tensor],
    target_var_idx: int
) -> torch.Tensor:
    """
    Compute factor to variable message using tensor contraction.

    μ_{f→x}(x) = Σ_{other vars} [φ_f(...) * Π_{y≠x} μ_{y→f}]

    Args:
        factor_tensor: Factor tensor with shape (d1, d2, ..., dn)
        incoming_messages: List of incoming messages, one for each variable in factor
        target_var_idx: Index of target variable (0-based) in factor's variable list

    Returns:
        Output message vector with shape (d_target,)
    """
    ndims = len(incoming_messages)

    # Unary factor: nothing to sum out, the message is the factor itself.
    if ndims == 1:
        return factor_tensor.clone()

    # Multiply each non-target message into its own axis via broadcasting.
    result = factor_tensor
    for dim in range(ndims):
        if dim == target_var_idx:
            continue
        msg = incoming_messages[dim]
        shape = [1] * ndims
        shape[dim] = msg.shape[0]
        result = result * msg.view(*shape)

    # Marginalize out every axis except the target variable's.
    sum_dims = [dim for dim in range(ndims) if dim != target_var_idx]
    if sum_dims:
        result = result.sum(dim=tuple(sum_dims))
    return result


def collect_message(bp: BeliefPropagation, state: BPState, normalize: bool = True) -> None:
    """
    Collect and update messages from factors to variables.

    μ_{f→x}(x) = Σ[φ_f(...) * Π μ_{y→f}]

    Args:
        bp: BeliefPropagation object
        state: BPState (modified in place)
        normalize: Whether to normalize messages
    """
    for factor_idx, factor in enumerate(bp.factors):
        # Position of this factor in each scope variable's v2t list,
        # computed once per variable instead of once per message lookup.
        positions = [bp.v2t[var - 1].index(factor_idx) for var in factor.vars]

        # Incoming variable->factor messages, in scope order.
        incoming_messages = [
            state.message_out[var - 1][pos]
            for var, pos in zip(factor.vars, positions)
        ]

        # Outgoing factor->variable message for each variable in the scope.
        for var_pos, (var, pos) in enumerate(zip(factor.vars, positions)):
            outgoing_msg = _compute_factor_to_var_message(
                factor.values,
                incoming_messages,
                var_pos
            )

            # Normalization keeps message magnitudes bounded.
            if normalize:
                msg_sum = outgoing_msg.sum()
                if msg_sum > 0:
                    outgoing_msg = outgoing_msg / msg_sum

            state.message_in[var - 1][pos] = outgoing_msg


def process_message(
    bp: BeliefPropagation,
    state: BPState,
    normalize: bool = True,
    damping: float = 0.2,
) -> None:
    r"""
    Process and update messages from variables to factors.

    μ_{x→f}(x) = Π_{g∈ne(x)\setminus f} μ_{g→x}(x)

    Args:
        bp: BeliefPropagation object
        state: BPState (modified in place)
        normalize: Whether to normalize messages
        damping: Damping factor; the stored message becomes
            damping * old + (1 - damping) * new
    """
    for var_idx_0based in range(bp.nvars):
        for factor_pos in range(len(bp.v2t[var_idx_0based])):
            # Product of all incoming factor->variable messages except the
            # one from the factor we are about to send to.
            product = torch.ones(bp.cards[var_idx_0based], dtype=torch.float64)
            for other_factor_pos in range(len(bp.v2t[var_idx_0based])):
                if other_factor_pos != factor_pos:
                    product = product * state.message_in[var_idx_0based][other_factor_pos]

            if normalize:
                msg_sum = product.sum()
                if msg_sum > 0:
                    product = product / msg_sum

            # Damped update (stabilizes loopy BP).  The arithmetic below
            # allocates a new tensor, so no defensive clone is needed.
            old_message = state.message_out[var_idx_0based][factor_pos]
            state.message_out[var_idx_0based][factor_pos] = (
                damping * old_message + (1 - damping) * product
            )


def _check_convergence(message_new: List[List[torch.Tensor]],
                       message_old: List[List[torch.Tensor]],
                       tol: float = 1e-6) -> bool:
    """
    Check if messages have converged.

    Args:
        message_new: Current iteration messages
        message_old: Previous iteration messages
        tol: Convergence tolerance

    Returns:
        True if converged, False otherwise
    """
    for var_msgs_new, var_msgs_old in zip(message_new, message_old):
        for msg_new, msg_old in zip(var_msgs_new, var_msgs_old):
            # L1 distance between corresponding messages.
            diff = torch.abs(msg_new - msg_old).sum()
            if diff > tol:
                return False
    return True


def belief_propagate(bp: BeliefPropagation,
                     max_iter: int = 100,
                     tol: float = 1e-6,
                     damping: float = 0.2,
                     normalize: bool = True) -> Tuple[BPState, BPInfo]:
    """
    Run Belief Propagation algorithm main loop.

    Alternates factor->variable and variable->factor sweeps until the
    factor->variable messages change by less than ``tol`` (L1 distance)
    or ``max_iter`` sweeps have run.

    Args:
        bp: BeliefPropagation object
        max_iter: Maximum number of iterations
        tol: Convergence tolerance
        damping: Damping factor
        normalize: Whether to normalize messages

    Returns:
        Tuple of (BPState, BPInfo)
    """
    state = initial_state(bp)

    for iteration in range(max_iter):
        # Snapshot messages for the convergence test with cheap tensor
        # clones instead of copy.deepcopy.
        prev_message_in = _clone_messages(state.message_in)

        collect_message(bp, state, normalize=normalize)
        process_message(bp, state, normalize=normalize, damping=damping)

        if _check_convergence(state.message_in, prev_message_in, tol=tol):
            return state, BPInfo(converged=True, iterations=iteration + 1)

    return state, BPInfo(converged=False, iterations=max_iter)


def compute_marginals(state: BPState, bp: BeliefPropagation) -> Dict[int, torch.Tensor]:
    """
    Compute marginal probabilities from converged BP state.

    b(x) = (1/Z) * Π_{f∈ne(x)} μ_{f→x}(x)

    Args:
        state: Converged BPState
        bp: BeliefPropagation object

    Returns:
        Dictionary mapping variable index (1-based) to marginal probability distribution
    """
    marginals: Dict[int, torch.Tensor] = {}

    for var_idx_0based in range(bp.nvars):
        # Product of all incoming factor->variable messages.
        product = torch.ones(bp.cards[var_idx_0based], dtype=torch.float64)
        for msg in state.message_in[var_idx_0based]:
            product = product * msg

        # Normalize to a probability distribution.
        msg_sum = product.sum()
        if msg_sum > 0:
            product = product / msg_sum

        marginals[var_idx_0based + 1] = product  # convert to 1-based indexing

    return marginals


def apply_evidence(bp: BeliefPropagation, evidence: Dict[int, int]) -> BeliefPropagation:
    """
    Apply evidence constraints to BP object.

    Modifies factor tensors to zero out non-evidence assignments.

    Args:
        bp: Original BeliefPropagation object
        evidence: Dictionary mapping variable index (1-based) to value (0-based)

    Returns:
        New BeliefPropagation object with evidence applied
    """
    new_factors = []

    for factor in bp.factors:
        factor_tensor = factor.values.clone()

        for var_pos, var in enumerate(factor.vars):
            if var in evidence:
                evid_value = evidence[var]
                # Keep only the slice consistent with the evidence value:
                # a zero mask with a single 1-slice along this axis replaces
                # the per-value masking loop of the original implementation.
                mask = torch.zeros_like(factor_tensor)
                keep = [slice(None)] * len(factor.vars)
                keep[var_pos] = evid_value
                mask[tuple(keep)] = 1
                factor_tensor = factor_tensor * mask

        new_factors.append(Factor(factor.vars, factor_tensor))

    # UAIModel is already imported at module top; no local re-import needed.
    return BeliefPropagation(UAIModel(bp.nvars, bp.cards, new_factors))
+""" + +from typing import List, Dict, Tuple +import torch + + +class Factor: + """Factor class representing a factor in the factor graph.""" + + def __init__(self, vars: List[int], values: torch.Tensor): + """ + Args: + vars: List of variable indices (1-based) + values: Tensor of factor values with shape matching variable cardinalities + """ + self.vars = tuple(vars) + self.values = values + + def __repr__(self): + return f"Factor(vars={self.vars}, shape={self.values.shape})" + + +class UAIModel: + """UAI model class containing variables, cardinalities, and factors.""" + + def __init__(self, nvars: int, cards: List[int], factors: List[Factor]): + """ + Args: + nvars: Number of variables + cards: List of cardinalities for each variable + factors: List of factors + """ + self.nvars = nvars + self.cards = cards + self.factors = factors + + def __repr__(self): + return f"UAIModel(nvars={self.nvars}, nfactors={len(self.factors)})" + + +def read_model_file(filepath: str, factor_eltype=torch.float64) -> UAIModel: + """ + Parse UAI format model file. + + Args: + filepath: Path to .uai file + factor_eltype: Data type for factor values (default: torch.float64) + + Returns: + UAIModel object + """ + with open(filepath, 'r') as f: + content = f.read() + return read_model_from_string(content, factor_eltype=factor_eltype) + + +def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIModel: + """ + Parse UAI model from string. 
+ + Args: + content: UAI file content as string + factor_eltype: Data type for factor values + + Returns: + UAIModel object + """ + lines = [line.strip() for line in content.split('\n') if line.strip()] + + # Parse header + network_type = lines[0] # MARKOV or BAYES + nvars = int(lines[1]) + cards = [int(x) for x in lines[2].split()] + ntables = int(lines[3]) + + # Parse factor scopes + scopes = [] + for i in range(ntables): + parts = lines[4 + i].split() + scope_size = int(parts[0]) + scope = [int(x) + 1 for x in parts[1:]] # Convert to 1-based + scopes.append(scope) + + # Parse factor tables + idx = 4 + ntables + tokens: List[str] = [] + while idx < len(lines): + tokens.extend(lines[idx].split()) + idx += 1 + cursor = 0 + + factors: List[Factor] = [] + for scope in scopes: + if cursor >= len(tokens): + raise ValueError("Unexpected end of UAI factor table data.") + nelements = int(tokens[cursor]) + cursor += 1 + values = torch.tensor( + [float(x) for x in tokens[cursor:cursor + nelements]], + dtype=factor_eltype + ) + cursor += nelements + + # Reshape according to cardinalities in original scope order + shape = tuple([cards[v - 1] for v in scope]) + values = values.reshape(shape) + factors.append(Factor(scope, values)) + + return UAIModel(nvars, cards, factors) + + +def read_evidence_file(filepath: str) -> Dict[int, int]: + """ + Parse evidence file (.evid format). 
+ + Args: + filepath: Path to .evid file + + Returns: + Dictionary mapping variable index (1-based) to observed value (0-based) + """ + if not filepath: + return {} + + with open(filepath, 'r') as f: + lines = f.readlines() + + if not lines: + return {} + + # Parse last line + last_line = lines[-1].strip() + parts = [int(x) for x in last_line.split()] + + nobsvars = parts[0] + evidence = {} + + for i in range(nobsvars): + var_idx = parts[1 + 2*i] + 1 # Convert to 1-based + var_value = parts[2 + 2*i] + evidence[var_idx] = var_value + + return evidence diff --git a/pytorch_bp/utils.py b/pytorch_bp/utils.py new file mode 100644 index 0000000..eccbf92 --- /dev/null +++ b/pytorch_bp/utils.py @@ -0,0 +1,19 @@ +""" +Utility functions for Belief Propagation. +""" + +from typing import List +import torch + + +def deep_copy_messages(messages: List[List[torch.Tensor]]) -> List[List[torch.Tensor]]: + """ + Deep copy message structure. + + Args: + messages: Nested list of message tensors + + Returns: + Deep copy of messages + """ + return [[msg.clone() for msg in var_msgs] for var_msgs in messages] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..12c6d5d --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +torch diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8869551 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package for pytorch_bp.""" diff --git a/tests/__pycache__/__init__.cpython-313.pyc b/tests/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62929ae21c8ce455c0d49920b90f7d45abefb24f GIT binary patch literal 188 zcmey&%ge<81T&L!GF5=|V-N=h7@>^M96-iYhG2#whIB?vrYf0`)Z!9_g2d$P#Pn2! 
zwEQB4g36NoqU4PDqyjxZO~za7@$o77$?@?k89sxIxMkpK6;qy>SCU$!P@J5RpPv)s z6yTDYoS%|f6p&L|98&@`uQ(TZlX-=wL5gX6|kVA?=j`+aL O$jEq$L8*uZ$N>OG=`wEs literal 0 HcmV?d00001 diff --git a/tests/__pycache__/test_bp_basic.cpython-313.pyc b/tests/__pycache__/test_bp_basic.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd95ade5d54a57c1824cb4a3de0bd0f0b91f023d GIT binary patch literal 4999 zcmb7ITWlN06`kGX@+H0`>Sg&w$(C%!5hW{<5fH6E+Q*vDt$JlfX5hml$kmL@h`#G$DZ>LnF2N66I$PS7_s^ zF`HoAJ&KT6gAm9zm1Xo+3+gF0Jt)dV8LxSXz6ZQKsmGEy9!@03rFeJ(Clg^s#W6Wb zstGOo>rZjU{E`I-zxiQ^h`sR7aADR9cn72@ywQvKUv0LrkUOlVRy}Y+RBf zQiRoYfG71F)q15L>Yq3-p&+x5*YldP($f^<4B|4Ep zuS}6GjOh4#n7g*U5 z+N{qz+TByc56n?lm?Pno`Qy%@mD(1x)2N}i^H(D@`j2+lQ`zAMzfe>JZENjQ&@Aw~ zFG8#V9!LM#4%j%#JhXR%MY}z^79%ZstNaAy+CpWCt>qEDB^Y_Sj3l_B$mBV zNtVu}@P5)Foe?8ynNECr2z@S(3+zX#$*_#2}-lQ*ntHU~A!JBAyggVubBYOA6tXbb?qUIh~NOsKP{N zV=4F`i6bY8QI%vRiD4r8NpH8;MR@H9A!eQU2@@ez6(Av|m!(+rqzYOT8pK%8f^Za6 zFtG}-vXl-I48)Cyf!c!GXp`6UB8CJ^T5v7(FW|1SJgNDaI9I_4?*>sip)wAj{&ouG zoeeXu%yvy3UaYCVIC-~W`=<@tZv;OIeGtk$a^!>14d!<}3)R1W;$|rO$dOwOL%Eva zso}iCHTBfJ`sR#vvC47VA$+|3uf9L~ZXX=UIfT1aLbgiC+uSpO>A>t`w{4q0x7FsW zYcs}t{rWc(7ZaJm#hQlO4Sfr}@9%kUPp+XaSJRg{l(*Gflx7F7`rq=;Z^>@(Rhftwz^w}+VAeSqa8zxLfK`ZZg8V50*z9kltRl2 z-O@{tt#<&PyVO!U22ra`wt zV;Pb#W<*~U-u2P}DcCcNetZN`yd)Ea2o4a&>a`N1|G>!b(Ps(Mb(ZV$bm0T=^;vGG zr}He|<>~b7EOeMUeZ8JeUpHa8&pJDOyPypsK9KM|;bn>8rDRMdrbrTKQC7VSF(^`e zLZRz<3Si|GNe!lm=ZB5-X%dgXCkj2Z0lp3E-<-TKd7rJaJ61S*qcvlA!?prL54(b9 zJ(?@pV7W`EoL50+Q9K#xH+s?C0T~$GJ*@;wPJ%4TRO*Et6Ig*047&{4vC<7QmU1jz zO3QU9F|I25Ivy0FRH>EGF8!>Q`A{i8s#x6^~bci1RggpVotEiNO4KaBliF;`;uW_74lZqKSbQ&yxs)?Sp z^ps7AYUHG(gbR$QVo4%B{{?A3{28TCWfH0>wCHM^Q{O&!?Oe{)o#7Yln`YBFduxVW ztZ$pc3kPoueDYYf{^)d7hR>UAGtOz}tl_HTEyr#1w#C}kIVo4$nXxRoHqY@nS6hbv zqQR4SDsOYm9GyOTW#DSya$tT}cE`RO>6@?IIzF1+@cbRy3->wH*o+Uu>JP;_1(yc! 
zL4fm*0XV>y>0|)XbaS*)0qzpy;j&N+6Pc?>JAtYw^aZ{M{s91F>5`z&2jJ&NXhN{6 zltEX9`RJk2I>r!96twdHhHH?e$)Xk(YiJQF1;bib14u-T0z;ipAy;2ipyENVqJpV| zt&m#d${?6`l{{9q6TDNp1eB+&NdJw3g=#QTO5Dp*`jtwb4X(#lp=uWg>G_&Iuh5Ui z?W&?#tF09@71k7Ng1wRmfc+f-Eycn^@G%~y6`ct`4V9O{KokhWT>@ch{}P4A5)EP? zc4I0oMkECvqodEzK34@5Q0^@A0T$@%aaui00XRtecF+pA4v}iec7Qhy#1lzH9XOR1 zONXAL61qbT6FD+53JRNBrHj`L?HbL4E(+yXSJA{!h+iU85yVS+rHp z1YQs1UF&9>FW9a~SL2uC^M@AP^SeI`e)Qta7qgp&S2))0c)%Gf)@3WQ*zcM*WzCze zr01TWKXq;FZfj4rwI|p5cy99(H@18%-lC(2i-dKBzv{T`n44UvzIJXQcw^_gqaXKY z8;^Wu9@0n%fI#^@;58F-I4p|^DI6x&F#Q`P9j9$uI867dIN}V4Ct_Gp<1ty1li@HP zgpctHv=X5rbz%6|1^#1+ghdtpeK`UDs3;&s*+W~daCz8EUui{C9);c_eva<54&Hj- zW8p*0*Y!?*m?_9pDXLjQjLG4!mkD?s#GIC6D&2HUqp@(_SW}Z)LJtQ;0ji|~s{l(4 zw3PUXnIhwm2w2qAh|oSr#bfd+Es1$=!3q0u0%T~KqtHlN<`{;#hn)A2>mIUwV=yx8 z*9fWyHH}Q&gWes?Ff-S9x4A3Z-1Pv_x8*Sg8El!eIc{CvVEcvtRsU;8mpNo;(k%QB D0Xzu# literal 0 HcmV?d00001 diff --git a/tests/__pycache__/test_integration.cpython-313.pyc b/tests/__pycache__/test_integration.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f17763de94ba6578d84b074da4bd4234fa3f9e42 GIT binary patch literal 2123 zcmb`I&2Jk;6u@Ucy-w`fY0{9mY3mjx#Xv)WMuZR*HE9mGYKXkC6y&I*jWbDBtk>r4 zn$|rb1SORaaEX*iNbRw`<;b7W$VIJ&oZx_Ri<+WJoS3)v+KmZtVI;qKGxPS%^LxL2 zdk~Lz5iz!kjytO80_1vRUImeoPeMj(s($FEg*kx)H3xrt52@$saZJSU#N5-0JnygBzQbymy^=wAx z31HSTEtQvzY51XTR%b!JLqP2H6W&9o7cZ|};5ut+2?o3g2x z>ibGnCZuvHPjV6Au4ClYoYAo*mZA(_5fa{&CWY?MX5pPN;SGnY36B_h$0#Jijl7oY z3HA0I7BKW~>Sp9^IrJ*ZN;1ampM&g{RjN45d&{$NZWMO3YT0|Ig|FWyzPiQmDEfAL(Fc^k}qh<5h5F#=cF*>eZT)BCy#sO6 zKM*`5D)HY0n{r5l`x3#07O!o>^DRPG#5ODfD=immuR_h# zJ6Lb=yD%M6>FUDc5iA#aLoazfcOul=^WRvHm_`Td(OJNk=yvL8;IZQXmJ={6!1fss z4tO*436be;5H7kk0eqL+tOJGw*hT}o;T6;dPz(x@i9ibVlR|JKY;7=vd#6y~iw0+e z@URdKrqhrXI!_(q?xyXpwR$6e{a+b9WQ(DBH6adj+B^JcWUAgfy>9HrPd*;H@PO5a zE^V*X;~)O-=)$Yw_J`sF65$4=u2o``@-B+)rYgD)if&4A%YY&p8Cy&%p+kM(nQs5FCuc7{yeC{o1)Nc}Y`KMbCqK&OKlE z-t+D`=cXMF8v<|!BjIZngdWg|2||`gUj*U`l90soA)LVsW?;_tvBx;f9pf?YLUpJa 
# ---------------------------------------------------------------------------
# tests/test_bp_basic.py
# ---------------------------------------------------------------------------
import itertools
import unittest

import torch

from pytorch_bp import (
    read_model_from_string,
    BeliefPropagation,
    belief_propagate,
    compute_marginals,
    apply_evidence,
)


def exact_marginals(model, evidence=None):
    """Reference marginals via brute-force enumeration of all assignments."""
    evidence = evidence or {}
    assignments = list(itertools.product(*(range(c) for c in model.cards)))

    def weight_of(assignment):
        # Assignments inconsistent with the evidence carry zero weight.
        if any(assignment[var_idx - 1] != val for var_idx, val in evidence.items()):
            return 0.0
        w = 1.0
        for factor in model.factors:
            w *= float(factor.values[tuple(assignment[v - 1] for v in factor.vars)])
        return w

    weights = [weight_of(a) for a in assignments]
    total = sum(weights)

    marginals = {}
    for var0, card in enumerate(model.cards):
        mass = [0.0] * card
        for assignment, w in zip(assignments, weights):
            mass[assignment[var0]] += w
        probs = [m / total if total > 0 else 0.0 for m in mass]
        marginals[var0 + 1] = torch.tensor(probs, dtype=torch.float64)
    return marginals


class TestBeliefPropagationBasic(unittest.TestCase):
    def setUp(self):
        # Two binary variables: a unary prior on x1 and a pairwise factor.
        self.content = "\n".join(
            [
                "MARKOV",
                "2",
                "2 2",
                "2",
                "1 0",
                "2 0 1",
                "2",
                "0.6 0.4",
                "4",
                "0.9 0.1 0.2 0.8",
            ]
        )

    def test_bp_matches_exact_tree(self):
        model = read_model_from_string(self.content)
        engine = BeliefPropagation(model)
        state, info = belief_propagate(engine, max_iter=50, tol=1e-10, damping=0.0)
        self.assertTrue(info.converged)

        approx = compute_marginals(state, engine)
        reference = exact_marginals(model)
        for var_idx in approx:
            self.assertTrue(
                torch.allclose(approx[var_idx], reference[var_idx], atol=1e-6)
            )

    def test_apply_evidence(self):
        model = read_model_from_string(self.content)
        evidence = {1: 1}
        engine = apply_evidence(BeliefPropagation(model), evidence)
        state, info = belief_propagate(engine, max_iter=50, tol=1e-10, damping=0.0)
        self.assertTrue(info.converged)

        approx = compute_marginals(state, engine)
        reference = exact_marginals(
            read_model_from_string(self.content), evidence=evidence
        )
        # Observed variable is pinned to its evidence value...
        self.assertTrue(
            torch.allclose(approx[1], torch.tensor([0.0, 1.0], dtype=torch.float64))
        )
        # ...and the other marginal is a proper distribution matching brute force.
        self.assertAlmostEqual(float(approx[2].sum()), 1.0, places=6)
        self.assertTrue(torch.allclose(approx[2], reference[2], atol=1e-6))


if __name__ == "__main__":
    unittest.main()


# ---------------------------------------------------------------------------
# tests/test_integration.py
# ---------------------------------------------------------------------------
import unittest

from pytorch_bp import (
    read_model_file,
    read_evidence_file,
    BeliefPropagation,
    belief_propagate,
    compute_marginals,
    apply_evidence,
)


class TestIntegration(unittest.TestCase):
    def test_example_file_runs(self):
        model = read_model_file("examples/simple_model.uai")
        engine = BeliefPropagation(model)
        state, info = belief_propagate(engine, max_iter=30, tol=1e-8)
        self.assertTrue(info.iterations > 0)
        marginals = compute_marginals(state, engine)
        self.assertEqual(set(marginals.keys()), {1, 2})

    def test_example_with_evidence(self):
        model = read_model_file("examples/simple_model.uai")
        evidence = read_evidence_file("examples/simple_model.evid")
        engine = apply_evidence(BeliefPropagation(model), evidence)
        state, info = belief_propagate(engine, max_iter=30, tol=1e-8)
        self.assertTrue(info.iterations > 0)
        marginals = compute_marginals(state, engine)
        self.assertEqual(set(marginals.keys()), {1, 2})


if __name__ == "__main__":
    unittest.main()


# ---------------------------------------------------------------------------
# tests/test_uai_parser.py
# ---------------------------------------------------------------------------
import unittest

import torch

from pytorch_bp import read_model_from_string, read_evidence_file


class TestUAIParser(unittest.TestCase):
    def test_read_model_from_string(self):
        content = "\n".join(
            [
                "MARKOV",
                "2",
                "2 2",
                "2",
                "1 0",
                "2 0 1",
                "2",
                "0.6 0.4",
                "4",
                "0.9 0.1 0.2 0.8",
            ]
        )
        model = read_model_from_string(content)
        self.assertEqual(model.nvars, 2)
        self.assertEqual(model.cards, [2, 2])
        self.assertEqual(len(model.factors), 2)

        unary, pairwise = model.factors[0], model.factors[1]
        self.assertEqual(unary.vars, (1,))
        self.assertEqual(pairwise.vars, (1, 2))
        self.assertEqual(tuple(unary.values.shape), (2,))
        self.assertEqual(tuple(pairwise.values.shape), (2, 2))
        self.assertTrue(
            torch.allclose(
                unary.values, torch.tensor([0.6, 0.4], dtype=torch.float64)
            )
        )
        self.assertTrue(
            torch.allclose(
                pairwise.values,
                torch.tensor([[0.9, 0.1], [0.2, 0.8]], dtype=torch.float64),
            )
        )

    def test_read_evidence_file(self):
        with open("examples/simple_model.evid", "r") as f:
            self.assertEqual(f.read().strip(), "1 0 1")

        self.assertEqual(read_evidence_file("examples/simple_model.evid"), {1: 1})


if __name__ == "__main__":
    unittest.main()
2/8] Ignore Python cache files --- .gitignore | 2 ++ pytorch_bp/__pycache__/__init__.cpython-313.pyc | Bin 758 -> 0 bytes .../belief_propagation.cpython-313.pyc | Bin 13723 -> 0 bytes .../__pycache__/uai_parser.cpython-313.pyc | Bin 6326 -> 0 bytes tests/__pycache__/__init__.cpython-313.pyc | Bin 188 -> 0 bytes tests/__pycache__/test_bp_basic.cpython-313.pyc | Bin 4999 -> 0 bytes .../test_integration.cpython-313.pyc | Bin 2123 -> 0 bytes .../__pycache__/test_uai_parser.cpython-313.pyc | Bin 3031 -> 0 bytes 8 files changed, 2 insertions(+) delete mode 100644 pytorch_bp/__pycache__/__init__.cpython-313.pyc delete mode 100644 pytorch_bp/__pycache__/belief_propagation.cpython-313.pyc delete mode 100644 pytorch_bp/__pycache__/uai_parser.cpython-313.pyc delete mode 100644 tests/__pycache__/__init__.cpython-313.pyc delete mode 100644 tests/__pycache__/test_bp_basic.cpython-313.pyc delete mode 100644 tests/__pycache__/test_integration.cpython-313.pyc delete mode 100644 tests/__pycache__/test_uai_parser.cpython-313.pyc diff --git a/.gitignore b/.gitignore index d33eb84..6a685ea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +__pycache__/ +*.pyc .DS_Store .ipynb_checkpoints/ Manifest.toml diff --git a/pytorch_bp/__pycache__/__init__.cpython-313.pyc b/pytorch_bp/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index bc67cce76350d6a07fe5641498eb692c9f16f754..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 758 zcmb7?zmC&D5XRR|>^M%Gf5quamzd(HXb?iW2u`{^fs6tYMWoecvq@H3d#$kpgpP-x z;1PJGZK>#h=yWZ!35QNd6m0R+`15$?+ZpeZBqn%13$20{Az$6Y?eTvH`!^VTA~OPH z=FR*okNV)#7PM#^+BATGh7hLiUF)htBZz1hx-^EE_Mk`m(5DF`bN~Z7gdrWlC?yZb zto_o1@jA%D4;OL1du2du(y7!+uF@P##n-%1rcCcm^DM2mi^43oTBa)lDX%Ip?^VGY znW}OnL6$->MYZTD3!1#|a zS{&}mP(R#N+OaT2wQVFTIILC4b!`WDRlQ@k7PA51jtI0J0*BC{;}AJ?9o#~p=g@aZ z90m?UhY{jMHnAPXmPRmU<1JUL;!sQYXLURI%Pnr^zdE{hTx-U@kUbh8_TjTf^LMIj zB&4<27^COY{HYW;dC2v)o>x1Zz1XltHD}Dtma(&HXQRi*|KJInVimXW`aa@&%k#V+ 
fzVEefqR{iN3F2lHBM9Q!;pgeN=YQ50TW9dQjBEYrRa$dN{V^SaRaGEnBv1S2S%VQp%$wN~S}R$}Z*D zO5CV8v}M}lD!Fk*rMa3$4rduHV5PvJ?xDaXhvM4)(`u_!Vz;?FqczaW0e7HcyFu@- z`@Y!+DVcWCD{w<_c6MGnJNwP=`+eWc{rr3jfe;RcLxb&v{2p)AP^cvC*XjtlMm&Tk z9>$|P&h+YdU9X`l=FZ2`*8v8lg&AP}_ zJgCcEVt{&{lEGr=CuhqJv&T;RF7edS=amlTd&&mwo&t{^%F40q0>hWql3L=aUst1I zu#DEScpOTta;Vebsf4jSVKj6^&d*V&3i}_l!ziki{wr`4bR>;*BsK6}SPMr&$2`am z<_{Kh=7)>fjB%$g zpNaCr=i4kFXz;FzD!=BW@DvA4pkhn$$3D1Bi>pTpv_vm=mqaQSM=sYZa6{E*6 zXo8;^r;SdX&SRt{c~D}8dLY0sB^Kq*Oz*6oJl-~Daa*Ngc@DCs;A0!~G}V%0ZlXgM zg2ORxG|F?Mbky7!ADB2Vv`KlBfslV3UualaDfAVL5bNQ8-wf}4jd+=Un!P^qlCYKm zHL}|)`A`GrjRvTpPLG-eSE!-2^O9VB*vVO)plc_B&Q6Xx2#M$hOlk*SUB6jPXRe*f zGf(UKoP<~j2W=xdcoNPs4e+FVHJ!P3DoLoRJz9t6S^%N1Z{e|xzhn&2##dxSx&mmk|A;tvL%!JSp>;29Ng5goV00u1?mFD=#;Gxpq!`*`~ha#~c z&k4ilqtWo-fzyYA!|<8;)8R>BaN;sb?s@;v#NZHJN&bnoPEQRs^C^DTw;X_ZfvQtS~WjyB%NJ9Jn@4Q%l7?0 zII+<5)}ciC^{V-*mD0xdPKfsXP~DueSIq65-MegWh_i2+|MZETlpbIN6vq#AoA&Qz zB_khf- z@ilsMt)MOO@%?!l9tyW!z4~ORze^r(-$D86)ir>;Q<-HzX7y8b2T*c2kfl&yC>(5! 
zMOzhd;l=|K6M&0tJgz|=my+8c8TdeCEXWtZos<`hOpXWn07z09-W#DoFf!sc@~{$# z#Q6CNyjtWA{h^UdQXUGepUSTulQ?Z4lMr}Vjf5}92Gu#FN)zx}s3A;3G)?YRammt~ z&Ku4Tt6G)|T7Om5G86a)FIKh0^-BemuTA@@BM8BH=|G4OJ{^% z+g*u~>lfxPtlC|o-IXe>NCf6hvnN(dw}_=%7Az~JJMJ5ZWAmDxG*%725ih%2LgvvzJx81zHLC^s|^UKwk>K48T0e1@rV3svda-Ri1=4H$#PmcS9%HW6sAxY1O1nxks&4*ow_jg-w&Y zh4ybwy*{;MuHCG&y{ua{kl< z0=MZBouMY-c|hQjHtP(|gzzb(opR6E*D(CGphD373y_tYUkn<6XrWTr5KyY|?_r_F zXVd}eYoKhy>@o6rV3Ay@7Z`x%D67FZ>9AA8R&fgoW!Vb(-Eg6w4A<#cGTlhao5nwQW!^7}Q z{s=_sH`Fq6e8WO{pxP${{jQ=@skt)}7!_SIr zQ()*uz^L^y1+Y53IY}^-7{sfWIZb}ZvH4d z$gPw0wGkKXq#)q(Q}i*SCm)>P0TMH>EN9^K2ng!aB*it{EVx-PTeWy9ZcG(B66FbV zGJj#@cJx+ssddktz*6D9CG$R6Xz+0!qrGXv$*ThpAk$d=LvTLU9q07oaRR zrdb`JYC{^3vT^Ysf@kAH0MGr5{CsR6>4qlw8c3}BI!Np3*ZB!(zH6kNQa&&!P!H0O zd-heb1;p+)5Ia4zW8rUr+Uj~xJAkWKk3s--1*mI;Ix1>EQYT^>K+Z)&KhV96p^kX7 zYoI#V5IjtmE@Jkw{i*{ctK4huM}@f|;H+a%--#n$!(Y<64~QDlWian@je*U2jxJ{n zL#9UGdvad`c6p4143w!UFQeA;>QrY9tOwoH8s&G&iTraOV`w&ZVlYE9RE93Q}7Cz)5WO zw{|6Zm+e~?s+O&tOKc|v04ZNFM`XS%6^9h2pv{_85DS3Li6J}1IbEk!GzcG}Adv+Q zng~YpUcHCu_5p+!!Gb6K2r3(by3iV+Ji=yKlA7(N+KIk?+7p0sAAo?b3?E`gR<#pzkK`;^s){8HFcbB9b=@y3b}5X> zYXG|#NF}w4%>aC6O@)CS2=HUw@bYjwpm|ulW_9U$^j%DZ^|G3UY5)tBo&0;4XmG5I zPXMu29aLHLOzF2~an8M0T1f zPo~RXI3Jh@a;}l5TDXzzEgbX=iHMsr%}3;|n?pY%8h&b#`c6U%A`|t@1p~w9HOdMV zm6@PDt`}G2z9Yd)*q>aGY3AD2IurmmlRMMF7Qj8l_UF(n-vDaBgkj#*@hA`}L!A1K ze)%aaz)?|OQEve_NA)A2akq&_QQ*&Fgoq~@Xe)`m7z(~D>4nL0V6Ny`1w?xRD}50o zKdl*o4g~b*2$Ny>mRK-=(GZrJp{XA$KMN37Smu+>^hwqwfGt6I4r3$N83+u$fy5+)#C36EfrwXd3k9<^EI^+GDSL6DW zwP4xG{l-?BD%d0zG(bji$&7v`IBSg?Qe}?0=VqTvyfix)H@t316<1t$z3oaKUUD{z z#m!$guH}=W>IWr+YfO6Id3Cw)@Y}B@^otR({m`6#{?+NDONEF3%~S}@=A5(6zcVDy zzgsxtTq@oM)XG{s-S?+YOr+=tBcQ(jU}xTe7WSX@Mu=f*(dLr!BFK_a->t(fjocUk zNj|@a0#N8bRNoaPHEb?Y1GRc6odH}z&lF%%Zen9X^m?+-3J8lK7h(W~$U$+uv_~&> zldCuaUaw;cRRHLa1E5e2GN!hG@xP?J<+r`sI}B))1zH1=Zwy|8H?LE{_K3x+(@+mA zEw80tqbZqoD(Xk9UaLpf{T65l7?tmPt$>iw!(T9Gz8n3endhh}b#_n(^a0^Sgzo!K z`3id_?BwsE0I)S2#tnc#WQ3cX7(v5bB@mA{nscJ+1{{4Hem;>@Aq}CD48%+JDsV1v 
z+k{3#;2sCun+OMngEEq8lHSdoJ0GNY4vIm`)!rftm`(!q;^3uaZxK$J?4vY<5Hl&; zqJ})h2+4Rc5S|PQ{0P=VBj0V4iPr_pco8EcQ+y>x6uKHAB^3aoB6090YLCYd1e1E$estxz#~k@;BRu>6fwj>5;_Ip2i8@DL(^uyV4=wV=Cg zxn)_Y-2qG}Zn;}s3m&qRy>f2%Tf2c^)o{tD=Fi1X-mR>O&Qc;x*lv9W)}<^keWy6HgyscA}%-hA=Ki!0UJ zUq2D=(Yn1?)s-qLpUa=kPZe!W6;<3fvDF3f6MrlqMU^RM(<*?Uvw7LsGGj{Fory1e zYj>)?U99h#IRVfLR5AC6cK58>szqCM;)|lKW#Rc{+phcdq-Mt_Wn@!_Oy*1@fpK(F z4`;7#4eLt&T&(+VSclncEcc3J5yi5w zr;O}%-&Zuo1}N%#DJP!MamZLIGt0g@87nqmo9T(zX7O6Gux*NBn+ec67e2{PG*GR8 zcF=~9!Asw691*xd^r&CFCcJBf+E4D6;-0ABURuFD)@_)q2E|~3&Cy|SOgV5!ng z5h2Upr{LZt$_f=W350R%cky^62={|RFb00dNr4-c`Qw<%JF^+&|2qOAmMBZ5n>!@0 z8I?@q#;0c7PtnGwkvg*{JIRUtW`h^a&ha6+f}k=1Ja zD^TqY{Du1v0dUq4*hJFZG~Y0Pw>@6)QB(V(?uW+rjDUUHj;u8GfJGYby;oi}cj4-V zWXVm(4aY*8xVh^mWk0R?phhe|74NyXbMO1QpBg_fzW*h0-`SO&=ftAzOGRADUjJ+R zQ>og^af(yjc7E)t)s2dD_SP(L&L2^YyLsTfftqfyUZKv9^7t z2cWTvTXk*~om&@9iq744dPQf?sL&zC(jZfMhx3B4l;McNQ183aY&m#7IQ$|P8dZf{;LZmw1~dNAbq*gzocu5`-Q2-Ro`*vs z8Kx-f6BfzJmohN7_UH4p3h$}q3cVU|A|GQe*Qp6v}L*BO2ioZX|D@wI@3 z85V*ZT^28ggrMcf891Gyp9)Rz`(Sb;ePBrDC~hkdCP zmhJ7U_8p>q$FjXMekfIFe{)KzsJ@>^Ht%C# zyLHPU2E04AB3Vq$<5yeHh?!;5(u+w+|XmdCk zomeO6pQ(b>x$oxs0+&MLljB)C)S8UTn`cjuZ4{)(C?2|oG`9X(9DcpCm7c~z7Y7{d z@}S+@0@ea?~=XuSGL}@GE&}{b`V?fTTRK% zMW0x*XUV!}iQPk2Uhd;w4iL9%uyeR>PbBBjE_DULii^OC(+{xlYBc0pZ{UbzaycmF zy^^f9pcys>N{n$|ESGaY-D`pG8!`9}Q&P+J!J0-bk$Yc6(=J9>j~NsF2&J*C?*O=V z^|(s6z`qa33+)LG)qpW=!PTfbD#yWv9D;L#91xJ9z)&a*`vI`WKsTXP7#+bK32N)Y z${do(r0c2n=T$G{s7fkjCaHNavzeq~DyxbGyQJ%ZF0Lo8@%0M|4>{t&M*gyT?kD3H zq_#2VdP;R3f{~r$^)7QG;6J1XAfTg-6cScZhC%Lw$$Kns1g9Y`E4t{UJoQ0>Oc%lJ zthfP@0SdTq6;&l`s$$jimuZJ)awU8k;VY1U%?l?fwof0s=iHQ`efN|diT>-pd0#Rj z)^z^Z`%f?Y=moLlc-(Z??pU?gi}w2D$q(&2z>+SnnH~Rn4=AhB4mcZRcg}5{-TJLo z*!{4U%zWV+yP;ex-xBY+>uR5AOFku*G%TAN*D7(}qMF#sme^87xdC4o!JX}Oi_Hv! 
zq{{6AwOzsI24@qIjc*LFw1C3mUPHe|sWNr?vlXV2t%)|9bw*O3YvE-r--Jrsd;w(k zRCS>RS&k0?1*s`RkGB8O&Sa^>C-zSlq}g}LYry3+pxM2-tWy99ZM!g%m)Qh_Wt_{h zq&8*Nd$ho4ExwzY6%0@UTcA(2Ilm{5%cJw=4WNsM2Cr4^l|n$93<17l3b?$jLPCF) z+9Gouz&&Rg;H%-6-Vopom2as&1N?OlQEod!2w4&T0UYqnX!$5TOdhvm&z_hPZP6+rkB_kaZoFyI)2ro)ZNV3Wqe!QoTl1$hiJYT?d zsi3TKk-@eAr{l6;ogMyfpcI)=g;S_I<|_A zt;>$qcs}rzjsC>Kk~d#V)iln(Fk?Z8X1}vFQS*KK!r9vcw+5EoU1DX|j4@T$GV}5* zav(>|+-p}~OI}&7=$v6cs;o_1S>3c<+_Zi1X>rr;mC9#k_NOWw*Ir9CwV==P-SLDq z)xIs+BRX1AJ312VJKp5(<(l?o2kt97p9RZfbDPL@T+LtG1a8K39my+|H?OrJuzR`S>FJ((*6PH$WozRS+epFQ z&75)T&ivu+c?>__>t_D&=70S|^3#7=+~a0H9>EWE9i#6;bjDpp&vi*gI4J<0O&+xY z&tim-%cDKc@52aeUqX)q@YgVfQplsw(36JpVS$5qOL^{oOrZv#o^AdVrjYDP271yU zi}SvLMW{6C_uO!ROfo+spP}5tzYA61I}>3A;T{g&ANl~n$L zxIQ&*W}IsTqEC-9R%Xi@f#}ofJZ3+$Mleo0oXpOp@-1ls*M*)3cu99RG9|!V=bo8; zCas6`-HNKj;p->nPly%nv;i`}Ia20Gn=oZ2M(fuGzdZOL5AN1ZGfm74-0W!po% zrniwsS88)(3gV{bwLFg5yHwenCU8yFZcgjrdbhGRZGbBbr?x(A!jzeC4QZ?kscoH$ z=kAOy`o;Fc4{6pBriLj?oJ$k9EH%RsKDfTmzVG>2<1br&-m-M~thn#o1I$eCV_KM| er3zP?z;*G^9s8Z(JC65G(%YpT-$Tr%vira7+U%bI diff --git a/pytorch_bp/__pycache__/uai_parser.cpython-313.pyc b/pytorch_bp/__pycache__/uai_parser.cpython-313.pyc deleted file mode 100644 index 8f6196828d055b86e235fe375a85e854cfbea562..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6326 zcmahtTWlLwc6T^4d<@?bMJcjqj~rPR6PcD0C$`l(8#{g|erO#rLYxfLPi}kc-+hHC*j#=6X9$)$&0hpJIhCXizAY9_*Ho(lhx#z9GzDS za$co|M)6HPqGdDMjC_{n^6IQw$mH~xbQkf9@}_V)V-!r`XeL!K#n%_|S#2WKhB^`I zFZWNv;3JYClmlkj;vD7UHY&vVgqW}+LUAGCcr5WkTx5v4;=n8_#U0c=>x{Zg`*Af@ z$kDql*u1#kdXZCE)i7kL<*8<9dLg6F%BnoW7Rni2F3bUa*4r#q^K&s%xCnHb_KRwE zK{HZ(9b0_h82|T2p!9Jz|!1*Y%K9^MlXQe zjnrH&mrWcwdsIv1(i%ORT`&^)%fRr|Tyi>}SWq)bOGRS&%ce7#%;=dyGHLoPLBtvW z4;*cTpevFG9@5oY8Mt<;B)spch5J5m-v>BJHCYHRs0obeXJG(##VKSGMq8VY#BDIs zWO!?SvY~J-5P{YF^Wp%B3KNSS3%?;?m~m*)IczcVkjX`(e2XKd$oPAcqh2`2Kn!K- zgB~eG?b<}C1?a++gU}U8Ef}e!Kh7*?DpRY@TcOXleYWk}(Ie}lN7hDO`FeECe{9Wp 
zj0Rz8)W$@@K0Fci;OV+Lufge3GC7}1FJy7-NhaS}P_y+Je=<3fp++H_(KS5>(~e{^ zol5~;$T*V+`GL$)PNv;bAqw7$&@j$+C6h)0OgNQP3k8~)UMOfrGI^6QY#4>qfs3(2 zjbQ}+4fzk@@P^1&rr_Uf2*7s#W#}I7 z0KoL^36VJga0Uqn^G|j(ckoZnxC6}I6?al;))kdZCpfAzppn_TPr-qUhtY2Vr?hO{ z_ER~%pn@{C8HF({Z!(KvS=G~ZqcUQqsG~BNBCat^hh=O=s$G##(U@8eF(S?fvhcJT z*nLxh&yQ^J0T-&6xEULhAs+@9><%Z2GR(IZ|OEu6EoQ0Z_|YPV3E-YU6E|~NO;@(Qmv~Oy6){C zc$m;PSujRM#;OSzJlTS$a2;!4%kAv0LmO_k-G2QI^oAHe4=uWfWxY=KA&OYmX*c<- zrZd_pv&@zaTb9sHr)9v%Cj%LT1}v|HKJERacg3@M>eu`iN54A#<>~5ySJ(ER{B8f5 z|MZ&kv}M11i^Yt$6Hzzqfg>m?C+)><8+OP;8o>@}wrs3vl~hJpKX&L8EHiu>N7&&> z#vbU<3{PzcE^$XCwL##qI91e>GW!6b55t^zrf!Y6-*jSF&#MLS2UII8 zP(9V|H=FYMD_FZJ%z^i9WF|JE+Z5X*;=ZN5P=D7Uoa3*tLnc#b#;m=;E06~rv<;Ac zGEC5fu%$sFGd7v*ATsQCBPO+(RD-ZF`7Dqp z7jkk8;u1?^+6*MAWeb<{+L(M)EvPaMnLSxb0~UqRRvt=gGwMRNFb1c9>x|80b86wm zeNjo8WIU?}@`m=wx#_pHR6#O1IuYeeuXR}cOBi8uc}+J3s;Ox+j6HgDEt{O7xq0xw z6xM*28!ti!Q8i2?Uumww-*)w_aJ7!k z8-q)O*M>@tM?8@t-?{r&c7G9Bcb`}jPCOC;@;g`Ge;H2L`ay7nza_oopm@$El|aXj zqP{;3M)?r|;{;J{ZJ-==XeT*mGi-Z_Ve2I0Q9^P%pQI>F@q-ec0o{JGTAj6>0b1Rl zAUY2!2ug-pr3)?P0vM051~$3~5CI{N!@8lafU~`cGX!?s-gctBaTa_X(sA~6$EFp0 zoQ12#I^DUK=&q?os3VHAF#{aA0H;fUlU%2c zAVdX?RQ!|P#%$|({-qzIwsDS36dzyXv5Cm zICbpTfE{2xg|3#hcvCVO*T!eR64<0e9ZH~%O<<~7kWB)uTiZ^cTnWIh139%tE|y+l z^kGeD%B7`Oh7Bab#a#z&dM7&tIS`oYZtn4dq(e9pv(eGn@r`P`Yqd zs6n~jrF2eWQZ*%!|Uk{E_wjAXqmtS|1V^(y)dbJLbRinOM!i;89K^;@Yosrk_CoGwfRAZUEl|LUrCg><*Rt9? 
zMl)96nA}B*MWBI22#V+bHY@Ad2lEuWKf~2Bup5%jFxegZbpDh1TSspX9;|v!elvKm zlK$$=FW>ws@m`|z`iE~`dGl&weehuM*qZ0$hMk0-ssxt~6i?sr_{(BhU2?v6rg*Ft z*|{>heE#FNmfyM^*;DnzzKQIqaPP;%iBYSXf?2a#3-c{k2hTfYmo`8dVf#UHy z-j4GAt8bOKTA-_}R{EE;CC|5ko$G;})xbbWywlZF>8iZ5+;iXjtGpnPk&wX;C z8rb{G3oG39&hl&5!j=B(y`NktiEDwqHK*@G?-g%Z`@OU8o{RYUAfa}9O0!pN61UG<{X3}_(e)RP z|FLg-@x&j3-6gT+4qlH}!q=v%Zuw3}-&*8EwPSqEGkz!3U9nd#F8ivX-6a=%q}+eS zU9#T^hHv!!ZC`E2u9cb9bGI_Wo z58q3~yZ49h10*=|!}kNkdjvop(!I~v37T@fXS_$a6@u+Krf zE&V&+F+MCkzDD^47&Rt1?McZ98EC)k*_;J)3iqVsMrS#2dlvkLZR^>J%}Q7%(R*)c zjr2t46@IdrPP7&^<{>Bu&7go`Mwn^_bwy}6>sFZ2MO{!tD=pIPBSd$!h0A7Gf?z7c zSx#6fvfvLGVpu~&<06i97tBFr4v9hAoGzXJRB>VXbOwZhp-6E&G1q_3%Ju7^LVyt9 zSo~Ez1ZWpCX-!W-j*iZKD2DI#G9bDs5z1m66%CEea9;#{yDejIMVKlpHm1+Wi%EgM;R?B~wF82gO@^4$Kgj;ISH7P2GBa Nb?)Z&hxp2{{9kDnVFUmG diff --git a/tests/__pycache__/__init__.cpython-313.pyc b/tests/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 62929ae21c8ce455c0d49920b90f7d45abefb24f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 188 zcmey&%ge<81T&L!GF5=|V-N=h7@>^M96-iYhG2#whIB?vrYf0`)Z!9_g2d$P#Pn2! 
zwEQB4g36NoqU4PDqyjxZO~za7@$o77$?@?k89sxIxMkpK6;qy>SCU$!P@J5RpPv)s z6yTDYoS%|f6p&L|98&@`uQ(TZlX-=wL5gX6|kVA?=j`+aL O$jEq$L8*uZ$N>OG=`wEs diff --git a/tests/__pycache__/test_bp_basic.cpython-313.pyc b/tests/__pycache__/test_bp_basic.cpython-313.pyc deleted file mode 100644 index bd95ade5d54a57c1824cb4a3de0bd0f0b91f023d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4999 zcmb7ITWlN06`kGX@+H0`>Sg&w$(C%!5hW{<5fH6E+Q*vDt$JlfX5hml$kmL@h`#G$DZ>LnF2N66I$PS7_s^ zF`HoAJ&KT6gAm9zm1Xo+3+gF0Jt)dV8LxSXz6ZQKsmGEy9!@03rFeJ(Clg^s#W6Wb zstGOo>rZjU{E`I-zxiQ^h`sR7aADR9cn72@ywQvKUv0LrkUOlVRy}Y+RBf zQiRoYfG71F)q15L>Yq3-p&+x5*YldP($f^<4B|4Ep zuS}6GjOh4#n7g*U5 z+N{qz+TByc56n?lm?Pno`Qy%@mD(1x)2N}i^H(D@`j2+lQ`zAMzfe>JZENjQ&@Aw~ zFG8#V9!LM#4%j%#JhXR%MY}z^79%ZstNaAy+CpWCt>qEDB^Y_Sj3l_B$mBV zNtVu}@P5)Foe?8ynNECr2z@S(3+zX#$*_#2}-lQ*ntHU~A!JBAyggVubBYOA6tXbb?qUIh~NOsKP{N zV=4F`i6bY8QI%vRiD4r8NpH8;MR@H9A!eQU2@@ez6(Av|m!(+rqzYOT8pK%8f^Za6 zFtG}-vXl-I48)Cyf!c!GXp`6UB8CJ^T5v7(FW|1SJgNDaI9I_4?*>sip)wAj{&ouG zoeeXu%yvy3UaYCVIC-~W`=<@tZv;OIeGtk$a^!>14d!<}3)R1W;$|rO$dOwOL%Eva zso}iCHTBfJ`sR#vvC47VA$+|3uf9L~ZXX=UIfT1aLbgiC+uSpO>A>t`w{4q0x7FsW zYcs}t{rWc(7ZaJm#hQlO4Sfr}@9%kUPp+XaSJRg{l(*Gflx7F7`rq=;Z^>@(Rhftwz^w}+VAeSqa8zxLfK`ZZg8V50*z9kltRl2 z-O@{tt#<&PyVO!U22ra`wt zV;Pb#W<*~U-u2P}DcCcNetZN`yd)Ea2o4a&>a`N1|G>!b(Ps(Mb(ZV$bm0T=^;vGG zr}He|<>~b7EOeMUeZ8JeUpHa8&pJDOyPypsK9KM|;bn>8rDRMdrbrTKQC7VSF(^`e zLZRz<3Si|GNe!lm=ZB5-X%dgXCkj2Z0lp3E-<-TKd7rJaJ61S*qcvlA!?prL54(b9 zJ(?@pV7W`EoL50+Q9K#xH+s?C0T~$GJ*@;wPJ%4TRO*Et6Ig*047&{4vC<7QmU1jz zO3QU9F|I25Ivy0FRH>EGF8!>Q`A{i8s#x6^~bci1RggpVotEiNO4KaBliF;`;uW_74lZqKSbQ&yxs)?Sp z^ps7AYUHG(gbR$QVo4%B{{?A3{28TCWfH0>wCHM^Q{O&!?Oe{)o#7Yln`YBFduxVW ztZ$pc3kPoueDYYf{^)d7hR>UAGtOz}tl_HTEyr#1w#C}kIVo4$nXxRoHqY@nS6hbv zqQR4SDsOYm9GyOTW#DSya$tT}cE`RO>6@?IIzF1+@cbRy3->wH*o+Uu>JP;_1(yc! 
zL4fm*0XV>y>0|)XbaS*)0qzpy;j&N+6Pc?>JAtYw^aZ{M{s91F>5`z&2jJ&NXhN{6 zltEX9`RJk2I>r!96twdHhHH?e$)Xk(YiJQF1;bib14u-T0z;ipAy;2ipyENVqJpV| zt&m#d${?6`l{{9q6TDNp1eB+&NdJw3g=#QTO5Dp*`jtwb4X(#lp=uWg>G_&Iuh5Ui z?W&?#tF09@71k7Ng1wRmfc+f-Eycn^@G%~y6`ct`4V9O{KokhWT>@ch{}P4A5)EP? zc4I0oMkECvqodEzK34@5Q0^@A0T$@%aaui00XRtecF+pA4v}iec7Qhy#1lzH9XOR1 zONXAL61qbT6FD+53JRNBrHj`L?HbL4E(+yXSJA{!h+iU85yVS+rHp z1YQs1UF&9>FW9a~SL2uC^M@AP^SeI`e)Qta7qgp&S2))0c)%Gf)@3WQ*zcM*WzCze zr01TWKXq;FZfj4rwI|p5cy99(H@18%-lC(2i-dKBzv{T`n44UvzIJXQcw^_gqaXKY z8;^Wu9@0n%fI#^@;58F-I4p|^DI6x&F#Q`P9j9$uI867dIN}V4Ct_Gp<1ty1li@HP zgpctHv=X5rbz%6|1^#1+ghdtpeK`UDs3;&s*+W~daCz8EUui{C9);c_eva<54&Hj- zW8p*0*Y!?*m?_9pDXLjQjLG4!mkD?s#GIC6D&2HUqp@(_SW}Z)LJtQ;0ji|~s{l(4 zw3PUXnIhwm2w2qAh|oSr#bfd+Es1$=!3q0u0%T~KqtHlN<`{;#hn)A2>mIUwV=yx8 z*9fWyHH}Q&gWes?Ff-S9x4A3Z-1Pv_x8*Sg8El!eIc{CvVEcvtRsU;8mpNo;(k%QB D0Xzu# diff --git a/tests/__pycache__/test_integration.cpython-313.pyc b/tests/__pycache__/test_integration.cpython-313.pyc deleted file mode 100644 index f17763de94ba6578d84b074da4bd4234fa3f9e42..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2123 zcmb`I&2Jk;6u@Ucy-w`fY0{9mY3mjx#Xv)WMuZR*HE9mGYKXkC6y&I*jWbDBtk>r4 zn$|rb1SORaaEX*iNbRw`<;b7W$VIJ&oZx_Ri<+WJoS3)v+KmZtVI;qKGxPS%^LxL2 zdk~Lz5iz!kjytO80_1vRUImeoPeMj(s($FEg*kx)H3xrt52@$saZJSU#N5-0JnygBzQbymy^=wAx z31HSTEtQvzY51XTR%b!JLqP2H6W&9o7cZ|};5ut+2?o3g2x z>ibGnCZuvHPjV6Au4ClYoYAo*mZA(_5fa{&CWY?MX5pPN;SGnY36B_h$0#Jijl7oY z3HA0I7BKW~>Sp9^IrJ*ZN;1ampM&g{RjN45d&{$NZWMO3YT0|Ig|FWyzPiQmDEfAL(Fc^k}qh<5h5F#=cF*>eZT)BCy#sO6 zKM*`5D)HY0n{r5l`x3#07O!o>^DRPG#5ODfD=immuR_h# zJ6Lb=yD%M6>FUDc5iA#aLoazfcOul=^WRvHm_`Td(OJNk=yvL8;IZQXmJ={6!1fss z4tO*436be;5H7kk0eqL+tOJGw*hT}o;T6;dPz(x@i9ibVlR|JKY;7=vd#6y~iw0+e z@URdKrqhrXI!_(q?xyXpwR$6e{a+b9WQ(DBH6adj+B^JcWUAgfy>9HrPd*;H@PO5a zE^V*X;~)O-=)$Yw_J`sF65$4=u2o``@-B+)rYgD)if&4A%YY&p8Cy&%p+kM(nQs5FCuc7{yeC{o1)Nc}Y`KMbCqK&OKlE z-t+D`=cXMF8v<|!BjIZngdWg|2||`gUj*U`l90soA)LVsW?;_tvBx;f9pf?YLUpJa 
zNnAaW_#xyoyqb=lctPcY**-6?3#1&YVo&d8FgE@(< zN0_I3$AZLTOGt2Ah~t#1$NjzUoKUc?;uR}c1xz2*6+Iz6QTi7cAv%n2y_9MIcp*k7 znVzgfV+OM^K**oxvVP<}tY0p}~ruW3?};dyn1CMKrC3Qc%7 zgXt}Hvj)PB49h^QF*AlxmjMqBg}$Gq@aG{nLxhp|`Is8|<97=Gn4;?@@taAk9CAAe ze@;=}`=xB2gv2x7#fcURNVQVhuz!kj05$ zlxB0Oc9ZDPX~w%KkXMfGA37HXb?njuqtR$&=0 zz@FH7(m*BhCJn${P_R8uQKjWkAA)~$2`yLFOuMJt57|;jS(^7^-5EwvG6K^ub^tu z*1ZAVXV-VI!4gv}<-0Nmy{U+$pObirZ)fr=gwQs$2kp$d7=cp-Sr=+TF8BuCILI}j zAmbJWChFBslrhLR-K&Qw&b(nf^ulmaym_|QLs;S##%sJK<>S>RHR z?w^_En%e2$RPejeX>Ceds%f2k>oJGQdhX|Ma&kvX-1S#g%dOVKOI1fE#oX7<%Gb+KU&L_Cl=-Rt-O=+KnV_rb8rF${ATmET47zftd+rH$GB&D(dj mv@C9ESwk>c{g^?#eR5)n+q%r#zv%z8|Fh$(9OCPZP5%!gomCtF From 8dee52e53a98aed8c420b4cb62f4c1d88a681655 Mon Sep 17 00:00:00 2001 From: Mengkun Liu Date: Sun, 18 Jan 2026 13:41:15 +0800 Subject: [PATCH 3/8] Add extra BP tests and update README --- BP_PYTORCH_IMPLEMENTATION_PLAN.md | 66 ----------------- README.md | 3 + tests/testcase.py | 114 ++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 66 deletions(-) delete mode 100644 BP_PYTORCH_IMPLEMENTATION_PLAN.md create mode 100644 tests/testcase.py diff --git a/BP_PYTORCH_IMPLEMENTATION_PLAN.md b/BP_PYTORCH_IMPLEMENTATION_PLAN.md deleted file mode 100644 index d1daaba..0000000 --- a/BP_PYTORCH_IMPLEMENTATION_PLAN.md +++ /dev/null @@ -1,66 +0,0 @@ -# PyTorch Belief Propagation Implementation Plan - -## 1. Core Objectives - -1. Implement a generic Belief Propagation (BP) algorithm using PyTorch -2. Provide comprehensive mathematical description and code documentation - -## 2. 
Technical Analysis - -### 2.1 Algorithm Flow - -**Algorithm Flow**: Initialize → Collect Messages → Process Messages → Damping Update → Convergence Check → Compute Marginals - -### 2.2 Function Interface Summary - -| Function Name | Functionality | Input | Output | -|---------------|---------------|-------|--------| -| `read_model_file` | Parse UAI file | filepath | UAIModel | -| `read_evidence_file` | Parse evidence file | filepath | Dict[int, int] | -| `BeliefPropagation.__init__` | Construct BP object | UAIModel | BP object | -| `initial_state` | Initialize messages | BP object | BPState | -| `collect_message` | Factor→Variable message | BP, BPState | None | -| `process_message` | Variable→Factor message | BP, BPState, damping | None | -| `belief_propagate` | BP main loop | BP, parameters | (BPState, BPInfo) | -| `compute_marginals` | Compute marginal probabilities | BPState, BP | Dict[int, Tensor] | -| `apply_evidence` | Apply evidence | BP, evidence | BeliefPropagation | - -## 3. Project Structure and Testing - -### 3.1 Project Structure - -The Belief Propagation framework will be integrated as a submodule within the TensorInference.jl project: - -pytorch_bp_inference/ -├── README.md -├── requirements.txt -├── setup.py (optional) -├── src/ -│ ├── __init__.py -│ ├── uai_parser.py # UAI file parsing -│ ├── belief_propagation.py # BP core implementation -│ └── utils.py # Utility functions -├── tests/ -│ ├── __init__.py -│ ├── test_uai_parser.py -│ ├── test_bp.py -│ └── test_integration.py -├── examples/ -│ ├── asia_network/ -│ │ ├── main.py -│ │ └── model.uai -│ └── simple_example.py -└── docs/ - ├── mathematical_description.md - ├── api_reference.md - └── usage_guide.md - -### 3.2 Testing - -- [ ] Test parsing `examples/asia-network/model.uai` -- [ ] Test BP initialization and state creation -- [ ] Test message collection and processing -- [ ] Test convergence checking -- [ ] Test marginal computation -- [ ] Test evidence application -- [ ] Compare results with 
provided reference results (from test cases in TensorInference.jl) diff --git a/README.md b/README.md index 10e3803..1712d46 100644 --- a/README.md +++ b/README.md @@ -68,3 +68,6 @@ What each unit test covers: | `tests/test_bp_basic.py` | `test_apply_evidence` | `apply_evidence`, `belief_propagate`, `compute_marginals` | | `tests/test_integration.py` | `test_example_file_runs` | `read_model_file`, `BeliefPropagation`, `belief_propagate`, `compute_marginals` | | `tests/test_integration.py` | `test_example_with_evidence` | `read_evidence_file`, `apply_evidence` + BP pipeline | +| `tests/testcase.py` | `test_unary_factor_marginal` | `belief_propagate`, `compute_marginals` (unary factor) | +| `tests/testcase.py` | `test_chain_three_vars_exact` | `belief_propagate`, `compute_marginals` (tree exactness) | +| `tests/testcase.py` | `test_message_normalization` | `collect_message`, `process_message` (normalization) | diff --git a/tests/testcase.py b/tests/testcase.py new file mode 100644 index 0000000..a94d2a3 --- /dev/null +++ b/tests/testcase.py @@ -0,0 +1,114 @@ +import unittest +import itertools +import torch + +from pytorch_bp import ( + read_model_from_string, + BeliefPropagation, + belief_propagate, + compute_marginals, + initial_state, + collect_message, + process_message, +) + + +def exact_marginals(model, evidence=None): + evidence = evidence or {} + assignments = list(itertools.product(*[range(c) for c in model.cards])) + weights = [] + for assignment in assignments: + if any(assignment[var_idx - 1] != val for var_idx, val in evidence.items()): + weights.append(0.0) + continue + weight = 1.0 + for factor in model.factors: + idx = tuple(assignment[v - 1] for v in factor.vars) + weight *= float(factor.values[idx]) + weights.append(weight) + total = sum(weights) + marginals = {} + for var_idx, card in enumerate(model.cards): + values = [] + for value in range(card): + mass = 0.0 + for assignment, weight in zip(assignments, weights): + if assignment[var_idx] == 
value: + mass += weight + values.append(mass / total if total > 0 else 0.0) + marginals[var_idx + 1] = torch.tensor(values, dtype=torch.float64) + return marginals + + +class TestBPAdditionalCases(unittest.TestCase): + def test_unary_factor_marginal(self): + content = "\n".join( + [ + "MARKOV", + "1", + "3", + "1", + "1 0", + "3", + "0.2 0.3 0.5", + ] + ) + model = read_model_from_string(content) + bp = BeliefPropagation(model) + state, info = belief_propagate(bp, max_iter=20, tol=1e-10, damping=0.0) + self.assertTrue(info.converged) + marginals = compute_marginals(state, bp) + expected = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float64) + self.assertTrue(torch.allclose(marginals[1], expected, atol=1e-6)) + + def test_chain_three_vars_exact(self): + content = "\n".join( + [ + "MARKOV", + "3", + "2 2 2", + "2", + "2 0 1", + "2 1 2", + "4", + "0.9 0.1 0.2 0.8", + "4", + "0.3 0.7 0.6 0.4", + ] + ) + model = read_model_from_string(content) + bp = BeliefPropagation(model) + state, info = belief_propagate(bp, max_iter=50, tol=1e-10, damping=0.0) + self.assertTrue(info.converged) + marginals = compute_marginals(state, bp) + exact = exact_marginals(model) + for var_idx in marginals: + self.assertTrue(torch.allclose(marginals[var_idx], exact[var_idx], atol=1e-6)) + + def test_message_normalization(self): + content = "\n".join( + [ + "MARKOV", + "2", + "2 2", + "1", + "2 0 1", + "4", + "0.9 0.1 0.2 0.8", + ] + ) + model = read_model_from_string(content) + bp = BeliefPropagation(model) + state = initial_state(bp) + collect_message(bp, state, normalize=True) + process_message(bp, state, normalize=True, damping=0.0) + for var_msgs in state.message_in: + for msg in var_msgs: + self.assertAlmostEqual(float(msg.sum()), 1.0, places=6) + for var_msgs in state.message_out: + for msg in var_msgs: + self.assertAlmostEqual(float(msg.sum()), 1.0, places=6) + + +if __name__ == "__main__": + unittest.main() From fc89e7f4e09c800b4cf3c3a5b0d4795b53229e6e Mon Sep 17 00:00:00 2001 From: 
Mengkun Liu Date: Mon, 19 Jan 2026 16:51:05 +0800 Subject: [PATCH 4/8] Apply BP fixes and update tests --- .../pytorch_bp/belief_propagation.py | 36 +++++++++---------- src/bpdecoderplus/pytorch_bp/uai_parser.py | 11 +++++- src/bpdecoderplus/pytorch_bp/utils.py | 19 ---------- tests/_path.py | 6 ++-- tests/test_bp_basic.py | 35 +++++------------- tests/test_integration.py | 11 ++++-- tests/test_uai_parser.py | 7 ++++ tests/testcase.py | 34 +++++------------- 8 files changed, 64 insertions(+), 95 deletions(-) delete mode 100644 src/bpdecoderplus/pytorch_bp/utils.py diff --git a/src/bpdecoderplus/pytorch_bp/belief_propagation.py b/src/bpdecoderplus/pytorch_bp/belief_propagation.py index dcef8a5..b774e9e 100644 --- a/src/bpdecoderplus/pytorch_bp/belief_propagation.py +++ b/src/bpdecoderplus/pytorch_bp/belief_propagation.py @@ -2,7 +2,7 @@ Belief Propagation (BP) algorithm implementation using PyTorch. """ -from typing import List, Dict, Tuple, Optional +from typing import List, Dict, Tuple import torch from copy import deepcopy @@ -88,7 +88,7 @@ def initial_state(bp: BeliefPropagation) -> BPState: var_messages_in = [] var_messages_out = [] - for factor_idx in bp.v2t[var_idx]: + for _ in bp.v2t[var_idx]: card = bp.cards[var_idx] msg = torch.ones(card, dtype=torch.float64) var_messages_in.append(msg.clone()) @@ -97,7 +97,7 @@ def initial_state(bp: BeliefPropagation) -> BPState: message_in.append(var_messages_in) message_out.append(var_messages_out) - return BPState(deepcopy(message_in), message_out) + return BPState(message_in, message_out) def _compute_factor_to_var_message( @@ -124,7 +124,7 @@ def _compute_factor_to_var_message( return factor_tensor.clone() # Multiply factor tensor by incoming messages (excluding target) and sum out dims. 
- result = factor_tensor + result = factor_tensor.clone() for dim in range(ndims): if dim == target_var_idx: continue @@ -154,11 +154,13 @@ def collect_message(bp: BeliefPropagation, state: BPState, normalize: bool = Tru for factor_idx, factor in enumerate(bp.factors): # Get incoming messages from variables to this factor incoming_messages = [] + var_factor_positions = [] for var in factor.vars: var_idx_0based = var - 1 # Find position of this factor in v2t[var_idx_0based] factor_pos = bp.v2t[var_idx_0based].index(factor_idx) incoming_messages.append(state.message_out[var_idx_0based][factor_pos]) + var_factor_positions.append(factor_pos) # Compute outgoing message to each variable for var_pos, var in enumerate(factor.vars): @@ -177,7 +179,7 @@ def collect_message(bp: BeliefPropagation, state: BPState, normalize: bool = Tru outgoing_msg = outgoing_msg / msg_sum # Update message_in - factor_pos = bp.v2t[var_idx_0based].index(factor_idx) + factor_pos = var_factor_positions[var_pos] state.message_in[var_idx_0based][factor_pos] = outgoing_msg @@ -334,19 +336,17 @@ def apply_evidence(bp: BeliefPropagation, evidence: Dict[int, int]) -> BeliefPro for var_pos, var in enumerate(factor.vars): if var in evidence: evid_value = evidence[var] - # Create slice that zeros out non-evidence values - slices = [slice(None)] * len(factor.vars) - slices[var_pos] = evid_value - - # Zero out all non-evidence assignments - mask = torch.ones_like(factor_tensor) - for i in range(factor_tensor.shape[var_pos]): - if i != evid_value: - slices_mask = slices.copy() - slices_mask[var_pos] = i - mask[tuple(slices_mask)] = 0 - - factor_tensor = factor_tensor * mask + dim_size = factor_tensor.shape[var_pos] + if 0 <= evid_value < dim_size: + all_indices = torch.arange(dim_size, device=factor_tensor.device) + zero_indices = all_indices[all_indices != evid_value] + if zero_indices.numel() > 0: + factor_tensor = factor_tensor.index_fill( + var_pos, zero_indices, 0 + ) + else: + factor_tensor = 
torch.zeros_like(factor_tensor) + break new_factors.append(Factor(factor.vars, factor_tensor)) diff --git a/src/bpdecoderplus/pytorch_bp/uai_parser.py b/src/bpdecoderplus/pytorch_bp/uai_parser.py index 5077d8e..6318978 100644 --- a/src/bpdecoderplus/pytorch_bp/uai_parser.py +++ b/src/bpdecoderplus/pytorch_bp/uai_parser.py @@ -2,7 +2,7 @@ UAI file format parser for Belief Propagation. """ -from typing import List, Dict, Tuple +from typing import List, Dict import torch @@ -71,6 +71,10 @@ def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIMode # Parse header network_type = lines[0] # MARKOV or BAYES + if network_type not in ("MARKOV", "BAYES"): + raise ValueError( + f"Unsupported UAI network type: {network_type!r}. Expected 'MARKOV' or 'BAYES'." + ) nvars = int(lines[1]) cards = [int(x) for x in lines[2].split()] ntables = int(lines[3]) @@ -80,6 +84,11 @@ def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIMode for i in range(ntables): parts = lines[4 + i].split() scope_size = int(parts[0]) + if len(parts) - 1 != scope_size: + raise ValueError( + f"Scope size mismatch on line {4 + i}: " + f"declared {scope_size}, found {len(parts) - 1} variables." + ) scope = [int(x) + 1 for x in parts[1:]] # Convert to 1-based scopes.append(scope) diff --git a/src/bpdecoderplus/pytorch_bp/utils.py b/src/bpdecoderplus/pytorch_bp/utils.py deleted file mode 100644 index eccbf92..0000000 --- a/src/bpdecoderplus/pytorch_bp/utils.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Utility functions for Belief Propagation. -""" - -from typing import List -import torch - - -def deep_copy_messages(messages: List[List[torch.Tensor]]) -> List[List[torch.Tensor]]: - """ - Deep copy message structure. 
- - Args: - messages: Nested list of message tensors - - Returns: - Deep copy of messages - """ - return [[msg.clone() for msg in var_msgs] for var_msgs in messages] diff --git a/tests/_path.py b/tests/_path.py index beddd36..2194164 100644 --- a/tests/_path.py +++ b/tests/_path.py @@ -4,5 +4,7 @@ def add_project_root_to_path(): project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) - if project_root not in sys.path: - sys.path.insert(0, project_root) + src_root = os.path.join(project_root, "src") + for path in (src_root, project_root): + if path not in sys.path: + sys.path.insert(0, path) diff --git a/tests/test_bp_basic.py b/tests/test_bp_basic.py index 9fde27f..be18a00 100644 --- a/tests/test_bp_basic.py +++ b/tests/test_bp_basic.py @@ -1,7 +1,13 @@ import unittest -import itertools import torch +try: + from ._path import add_project_root_to_path +except ImportError: + from _path import add_project_root_to_path + +add_project_root_to_path() + from bpdecoderplus.pytorch_bp import ( read_model_from_string, BeliefPropagation, @@ -10,32 +16,7 @@ apply_evidence, ) - -def exact_marginals(model, evidence=None): - evidence = evidence or {} - assignments = list(itertools.product(*[range(c) for c in model.cards])) - weights = [] - for assignment in assignments: - if any(assignment[var_idx - 1] != val for var_idx, val in evidence.items()): - weights.append(0.0) - continue - weight = 1.0 - for factor in model.factors: - idx = tuple(assignment[v - 1] for v in factor.vars) - weight *= float(factor.values[idx]) - weights.append(weight) - total = sum(weights) - marginals = {} - for var_idx, card in enumerate(model.cards): - values = [] - for value in range(card): - mass = 0.0 - for assignment, weight in zip(assignments, weights): - if assignment[var_idx] == value: - mass += weight - values.append(mass / total if total > 0 else 0.0) - marginals[var_idx + 1] = torch.tensor(values, dtype=torch.float64) - return marginals +from tests.test_utils import 
exact_marginals class TestBeliefPropagationBasic(unittest.TestCase): diff --git a/tests/test_integration.py b/tests/test_integration.py index 8110ed7..80ff83e 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,5 +1,12 @@ import unittest +try: + from ._path import add_project_root_to_path +except ImportError: + from _path import add_project_root_to_path + +add_project_root_to_path() + from bpdecoderplus.pytorch_bp import ( read_model_file, read_evidence_file, @@ -15,7 +22,7 @@ def test_example_file_runs(self): model = read_model_file("examples/simple_model.uai") bp = BeliefPropagation(model) state, info = belief_propagate(bp, max_iter=30, tol=1e-8) - self.assertTrue(info.iterations > 0) + self.assertGreater(info.iterations, 0) marginals = compute_marginals(state, bp) self.assertEqual(set(marginals.keys()), {1, 2}) @@ -24,7 +31,7 @@ def test_example_with_evidence(self): evidence = read_evidence_file("examples/simple_model.evid") bp = apply_evidence(BeliefPropagation(model), evidence) state, info = belief_propagate(bp, max_iter=30, tol=1e-8) - self.assertTrue(info.iterations > 0) + self.assertGreater(info.iterations, 0) marginals = compute_marginals(state, bp) self.assertEqual(set(marginals.keys()), {1, 2}) diff --git a/tests/test_uai_parser.py b/tests/test_uai_parser.py index 99ac091..316953c 100644 --- a/tests/test_uai_parser.py +++ b/tests/test_uai_parser.py @@ -1,6 +1,13 @@ import unittest import torch +try: + from ._path import add_project_root_to_path +except ImportError: + from _path import add_project_root_to_path + +add_project_root_to_path() + from bpdecoderplus.pytorch_bp import read_model_from_string, read_evidence_file diff --git a/tests/testcase.py b/tests/testcase.py index 238d9fb..7840f0e 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -1,7 +1,13 @@ import unittest -import itertools import torch +try: + from ._path import add_project_root_to_path +except ImportError: + from _path import add_project_root_to_path + 
+add_project_root_to_path() + from bpdecoderplus.pytorch_bp import ( read_model_from_string, BeliefPropagation, @@ -13,31 +19,7 @@ ) -def exact_marginals(model, evidence=None): - evidence = evidence or {} - assignments = list(itertools.product(*[range(c) for c in model.cards])) - weights = [] - for assignment in assignments: - if any(assignment[var_idx - 1] != val for var_idx, val in evidence.items()): - weights.append(0.0) - continue - weight = 1.0 - for factor in model.factors: - idx = tuple(assignment[v - 1] for v in factor.vars) - weight *= float(factor.values[idx]) - weights.append(weight) - total = sum(weights) - marginals = {} - for var_idx, card in enumerate(model.cards): - values = [] - for value in range(card): - mass = 0.0 - for assignment, weight in zip(assignments, weights): - if assignment[var_idx] == value: - mass += weight - values.append(mass / total if total > 0 else 0.0) - marginals[var_idx + 1] = torch.tensor(values, dtype=torch.float64) - return marginals +from tests.test_utils import exact_marginals class TestBPAdditionalCases(unittest.TestCase): From 6b3f4e2de319c0c7fc388c0daf032ab53c08a851 Mon Sep 17 00:00:00 2001 From: Mengkun Liu Date: Mon, 19 Jan 2026 17:08:54 +0800 Subject: [PATCH 5/8] Add tests for BP edge cases and UAI errors --- tests/test_uai_parser.py | 40 ++++++++++++++++++++++++++++++++++++++++ tests/testcase.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/tests/test_uai_parser.py b/tests/test_uai_parser.py index 316953c..044a3da 100644 --- a/tests/test_uai_parser.py +++ b/tests/test_uai_parser.py @@ -58,6 +58,46 @@ def test_read_evidence_file(self): evidence = read_evidence_file("examples/simple_model.evid") self.assertEqual(evidence, {1: 1}) + def test_invalid_network_type(self): + content = "\n".join( + [ + "INVALID", + "1", + "2", + "0", + ] + ) + with self.assertRaises(ValueError): + read_model_from_string(content) + + def test_scope_size_mismatch(self): + content = "\n".join( + 
[ + "MARKOV", + "2", + "2 2", + "1", + "2 0", + "2", + "0.5 0.5", + ] + ) + with self.assertRaises(ValueError): + read_model_from_string(content) + + def test_missing_table_entries(self): + content = "\n".join( + [ + "MARKOV", + "1", + "2", + "1", + "1 0", + ] + ) + with self.assertRaises(ValueError): + read_model_from_string(content) + if __name__ == "__main__": unittest.main() diff --git a/tests/testcase.py b/tests/testcase.py index 7840f0e..7b3f99b 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -16,6 +16,7 @@ initial_state, collect_message, process_message, + apply_evidence, ) @@ -91,6 +92,45 @@ def test_message_normalization(self): for msg in var_msgs: self.assertAlmostEqual(float(msg.sum()), 1.0, places=6) + def test_zero_message_handling(self): + content = "\n".join( + [ + "MARKOV", + "1", + "2", + "2", + "1 0", + "1 0", + "2", + "0.0 0.0", + "2", + "0.7 0.3", + ] + ) + model = read_model_from_string(content) + bp = BeliefPropagation(model) + state = initial_state(bp) + collect_message(bp, state, normalize=True) + process_message(bp, state, normalize=True, damping=0.0) + self.assertAlmostEqual(float(state.message_in[0][0].sum()), 0.0, places=6) + self.assertAlmostEqual(float(state.message_out[0][1].sum()), 0.0, places=6) + + def test_evidence_out_of_range_zeros_factor(self): + content = "\n".join( + [ + "MARKOV", + "1", + "2", + "1", + "1 0", + "2", + "0.4 0.6", + ] + ) + model = read_model_from_string(content) + bp = apply_evidence(BeliefPropagation(model), {1: 5}) + self.assertAlmostEqual(float(bp.factors[0].values.sum()), 0.0, places=6) + if __name__ == "__main__": unittest.main() From ab9eff0d92cc25c5c3888cc73b6e81916f7a817d Mon Sep 17 00:00:00 2001 From: Mengkun Liu Date: Fri, 23 Jan 2026 16:56:45 +0800 Subject: [PATCH 6/8] Add independent tropical tensor network package Includes standalone src, tests, docs, and example assets. 
--- tropical_in_new/README.md | 38 +++++ tropical_in_new/docs/api_reference.md | 52 ++++++ .../docs/mathematical_description.md | 32 ++++ tropical_in_new/docs/usage_guide.md | 30 ++++ tropical_in_new/examples/asia_network/main.py | 21 +++ .../examples/asia_network/model.uai | 19 +++ tropical_in_new/requirements.txt | 2 + tropical_in_new/src/__init__.py | 25 +++ tropical_in_new/src/contraction.py | 161 ++++++++++++++++++ tropical_in_new/src/mpe.py | 91 ++++++++++ tropical_in_new/src/network.py | 26 +++ tropical_in_new/src/primitives.py | 150 ++++++++++++++++ tropical_in_new/src/utils.py | 137 +++++++++++++++ tropical_in_new/tests/conftest.py | 5 + tropical_in_new/tests/test_contraction.py | 25 +++ tropical_in_new/tests/test_mpe.py | 69 ++++++++ tropical_in_new/tests/test_primitives.py | 26 +++ 17 files changed, 909 insertions(+) create mode 100644 tropical_in_new/README.md create mode 100644 tropical_in_new/docs/api_reference.md create mode 100644 tropical_in_new/docs/mathematical_description.md create mode 100644 tropical_in_new/docs/usage_guide.md create mode 100644 tropical_in_new/examples/asia_network/main.py create mode 100644 tropical_in_new/examples/asia_network/model.uai create mode 100644 tropical_in_new/requirements.txt create mode 100644 tropical_in_new/src/__init__.py create mode 100644 tropical_in_new/src/contraction.py create mode 100644 tropical_in_new/src/mpe.py create mode 100644 tropical_in_new/src/network.py create mode 100644 tropical_in_new/src/primitives.py create mode 100644 tropical_in_new/src/utils.py create mode 100644 tropical_in_new/tests/conftest.py create mode 100644 tropical_in_new/tests/test_contraction.py create mode 100644 tropical_in_new/tests/test_mpe.py create mode 100644 tropical_in_new/tests/test_primitives.py diff --git a/tropical_in_new/README.md b/tropical_in_new/README.md new file mode 100644 index 0000000..7898009 --- /dev/null +++ b/tropical_in_new/README.md @@ -0,0 +1,38 @@ +# Tropical Tensor Network for MPE + +This 
folder contains an independent implementation of tropical tensor network +contraction for Most Probable Explanation (MPE). It does not depend on the +`bpdecoderplus` package; all code lives under `tropical_in_new/src`. + +## Structure + +``` +tropical_in_new/ +├── README.md +├── requirements.txt +├── src/ +│ ├── __init__.py +│ ├── primitives.py +│ ├── network.py +│ ├── contraction.py +│ ├── mpe.py +│ └── utils.py +├── tests/ +│ ├── test_primitives.py +│ ├── test_contraction.py +│ └── test_mpe.py +├── examples/ +│ └── asia_network/ +│ ├── main.py +│ └── model.uai +└── docs/ + ├── mathematical_description.md + ├── api_reference.md + └── usage_guide.md +``` + +## Quick Start + +```bash +python tropical_in_new/examples/asia_network/main.py +``` diff --git a/tropical_in_new/docs/api_reference.md b/tropical_in_new/docs/api_reference.md new file mode 100644 index 0000000..f3358eb --- /dev/null +++ b/tropical_in_new/docs/api_reference.md @@ -0,0 +1,52 @@ +## Tropical Tensor Network API Reference + +Public APIs exported from `tropical_in_new/src`. + +### UAI Parsing + +- `read_model_file(path, factor_eltype=torch.float64) -> UAIModel` + Parse a UAI `.uai` model file. + +- `read_model_from_string(content, factor_eltype=torch.float64) -> UAIModel` + Parse a UAI model from an in-memory string. + +### Data Structures + +- `Factor(vars: Tuple[int, ...], values: torch.Tensor)` + Container for a factor scope and its tensor. + +- `UAIModel(nvars: int, cards: List[int], factors: List[Factor])` + Holds all model metadata for MPE. + +### Primitives + +- `safe_log(tensor: torch.Tensor) -> torch.Tensor` + Convert potentials to log-domain; maps zeros to `-inf`. + +- `tropical_einsum(a, b, index_map, track_argmax=True)` + Binary contraction in max-plus semiring; returns `(values, backpointer)`. + +- `argmax_trace(backpointer, assignment) -> Dict[int, int]` + Decode assignments for eliminated variables from backpointer metadata. 
+ +### Network + Contraction + +- `build_network(factors: Iterable[Factor]) -> list[TensorNode]` + Convert factors into log-domain tensors with scopes. + +- `choose_order(nodes, heuristic="min_fill") -> list[int]` + Select variable elimination order (min-fill / min-degree). + +- `build_contraction_tree(order, nodes) -> ContractionTree` + Construct a contraction plan from order and nodes. + +- `contract_tree(tree, einsum_fn, track_argmax=True) -> TreeNode` + Execute contractions and return the root node with backpointers. + +### MPE + +- `mpe_tropical(model, evidence=None, order=None)` + Return `(assignment_dict, score, info)` where `score` is log-domain. + +- `recover_mpe_assignment(root) -> Dict[int, int]` + Recover assignments from a contraction tree root. diff --git a/tropical_in_new/docs/mathematical_description.md b/tropical_in_new/docs/mathematical_description.md new file mode 100644 index 0000000..ae18385 --- /dev/null +++ b/tropical_in_new/docs/mathematical_description.md @@ -0,0 +1,32 @@ +## Tropical Tensor Network for MPE + +We compute the Most Probable Explanation (MPE) by treating the factor +graph as a tensor network in the tropical semiring (max-plus). + +### Tropical Semiring + +For log-potentials, multiplication becomes addition and summation becomes +maximization: + +- `a ⊗ b = a + b` +- `a ⊕ b = max(a, b)` + +### MPE Objective + +Let `x` be the set of variables and `phi_f(x_f)` the factor potentials. +We maximize the log-score: + +`x* = argmax_x sum_f log(phi_f(x_f))` + +### Contraction + +Each factor becomes a tensor over its variable scope. Contraction combines +two tensors by adding their log-values and reducing (max) over eliminated +variables. A greedy elimination order (min-fill or min-degree) controls +the intermediate tensor sizes. + +### Backpointers + +During each reduction, we store the argmax index for eliminated variables. +Traversing the contraction tree with these backpointers recovers the MPE +assignment. 
diff --git a/tropical_in_new/docs/usage_guide.md b/tropical_in_new/docs/usage_guide.md new file mode 100644 index 0000000..280b091 --- /dev/null +++ b/tropical_in_new/docs/usage_guide.md @@ -0,0 +1,30 @@ +## Tropical Tensor Network Usage + +This guide shows how to parse a UAI model and compute MPE using the +tropical tensor network implementation. + +### Quick Start + +```python +from src import mpe_tropical, read_model_file + +model = read_model_file("tropical_in_new/examples/asia_network/model.uai") +assignment, score, info = mpe_tropical(model) +print(assignment, score, info) +``` + +### Evidence + +```python +from src import mpe_tropical, read_model_file + +model = read_model_file("tropical_in_new/examples/asia_network/model.uai") +evidence = {1: 0} # variable index is 1-based +assignment, score, info = mpe_tropical(model, evidence=evidence) +``` + +### Running the Example + +```bash +python tropical_in_new/examples/asia_network/main.py +``` diff --git a/tropical_in_new/examples/asia_network/main.py b/tropical_in_new/examples/asia_network/main.py new file mode 100644 index 0000000..4bfcef0 --- /dev/null +++ b/tropical_in_new/examples/asia_network/main.py @@ -0,0 +1,21 @@ +"""Run tropical MPE on a small UAI model.""" + +from pathlib import Path +import sys + +ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(ROOT)) + +from src import mpe_tropical, read_model_file # noqa: E402 + + +def main() -> None: + model = read_model_file("tropical_in_new/examples/asia_network/model.uai") + assignment, score, info = mpe_tropical(model) + print("MPE assignment:", assignment) + print("MPE log-score:", score) + print("Info:", info) + + +if __name__ == "__main__": + main() diff --git a/tropical_in_new/examples/asia_network/model.uai b/tropical_in_new/examples/asia_network/model.uai new file mode 100644 index 0000000..e60d689 --- /dev/null +++ b/tropical_in_new/examples/asia_network/model.uai @@ -0,0 +1,19 @@ +MARKOV +3 +2 2 2 +5 +1 0 +1 1 +1 2 +2 0 1 +2 1 2 +2 
+0.6 0.4 +2 +0.5 0.5 +2 +0.7 0.3 +4 +1.2 0.2 0.2 1.2 +4 +1.1 0.3 0.3 1.1 diff --git a/tropical_in_new/requirements.txt b/tropical_in_new/requirements.txt new file mode 100644 index 0000000..5259e95 --- /dev/null +++ b/tropical_in_new/requirements.txt @@ -0,0 +1,2 @@ +torch>=2.0.0 +numpy>=1.24.0 diff --git a/tropical_in_new/src/__init__.py b/tropical_in_new/src/__init__.py new file mode 100644 index 0000000..3689064 --- /dev/null +++ b/tropical_in_new/src/__init__.py @@ -0,0 +1,25 @@ +"""Tropical tensor network tools for MPE (independent package).""" + +from .contraction import build_contraction_tree, choose_order, contract_tree +from .mpe import mpe_tropical, recover_mpe_assignment +from .network import TensorNode, build_network +from .primitives import argmax_trace, safe_log, tropical_einsum +from .utils import Factor, UAIModel, build_tropical_factors, read_model_file, read_model_from_string + +__all__ = [ + "Factor", + "TensorNode", + "UAIModel", + "argmax_trace", + "build_contraction_tree", + "build_network", + "build_tropical_factors", + "choose_order", + "contract_tree", + "mpe_tropical", + "read_model_file", + "read_model_from_string", + "recover_mpe_assignment", + "safe_log", + "tropical_einsum", +] diff --git a/tropical_in_new/src/contraction.py b/tropical_in_new/src/contraction.py new file mode 100644 index 0000000..b1b0a2e --- /dev/null +++ b/tropical_in_new/src/contraction.py @@ -0,0 +1,161 @@ +"""Contraction ordering and binary contraction tree execution.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Tuple + +import torch + +try: # Optional heuristic provider + import omeco # type: ignore +except Exception: # pragma: no cover - optional dependency + omeco = None + +from .network import TensorNode +from .primitives import Backpointer, tropical_reduce_max +from .utils import build_index_map + + +@dataclass +class ContractNode: + vars: Tuple[int, ...] 
+ values: torch.Tensor + left: "TreeNode" + right: "TreeNode" + elim_vars: Tuple[int, ...] + backpointer: Backpointer | None + + +@dataclass +class ReduceNode: + vars: Tuple[int, ...] + values: torch.Tensor + child: "TreeNode" + elim_vars: Tuple[int, ...] + backpointer: Backpointer | None + + +TreeNode = TensorNode | ContractNode | ReduceNode + + +@dataclass(frozen=True) +class ContractionTree: + order: Tuple[int, ...] + nodes: Tuple[TensorNode, ...] + + +def _build_var_graph(nodes: Iterable[TensorNode]) -> dict[int, set[int]]: + graph: dict[int, set[int]] = {} + for node in nodes: + vars = list(node.vars) + for var in vars: + graph.setdefault(var, set()).update(v for v in vars if v != var) + return graph + + +def _min_fill_order(graph: dict[int, set[int]]) -> list[int]: + order: list[int] = [] + graph = {k: set(v) for k, v in graph.items()} + while graph: + best_var = None + best_fill = None + best_degree = None + for var, neighbors in graph.items(): + fill = 0 + neighbor_list = list(neighbors) + for i in range(len(neighbor_list)): + for j in range(i + 1, len(neighbor_list)): + if neighbor_list[j] not in graph[neighbor_list[i]]: + fill += 1 + degree = len(neighbors) + if best_fill is None or (fill, degree) < (best_fill, best_degree): + best_var = var + best_fill = fill + best_degree = degree + if best_var is None: + break + neighbors = list(graph[best_var]) + for i in range(len(neighbors)): + for j in range(i + 1, len(neighbors)): + graph[neighbors[i]].add(neighbors[j]) + graph[neighbors[j]].add(neighbors[i]) + for neighbor in neighbors: + graph[neighbor].discard(best_var) + graph.pop(best_var, None) + order.append(best_var) + return order + + +def choose_order(nodes: list[TensorNode], heuristic: str = "min_fill") -> list[int]: + """Select elimination order over variable indices.""" + if heuristic == "omeco" and omeco is not None: + if hasattr(omeco, "min_fill_order"): + return list(omeco.min_fill_order([node.vars for node in nodes])) + graph = 
_build_var_graph(nodes) + if heuristic in ("min_fill", "omeco"): + return _min_fill_order(graph) + if heuristic == "min_degree": + return sorted(graph, key=lambda v: len(graph[v])) + raise ValueError(f"Unknown heuristic: {heuristic!r}") + + +def build_contraction_tree(order: Iterable[int], nodes: list[TensorNode]) -> ContractionTree: + """Prepare a contraction plan from order and leaf nodes.""" + return ContractionTree(order=tuple(order), nodes=tuple(nodes)) + + +def contract_tree( + tree: ContractionTree, + einsum_fn, + track_argmax: bool = True, +) -> TreeNode: + """Contract along the tree using the tropical einsum.""" + active_nodes: list[TreeNode] = list(tree.nodes) + for var in tree.order: + bucket = [node for node in active_nodes if var in node.vars] + if not bucket: + continue + bucket_ids = {id(node) for node in bucket} + active_nodes = [node for node in active_nodes if id(node) not in bucket_ids] + combined: TreeNode = bucket[0] + for other in bucket[1:]: + index_map = build_index_map(combined.vars, other.vars, elim_vars=()) + values, _ = einsum_fn(combined.values, other.values, index_map, track_argmax=False) + combined = ContractNode( + vars=index_map.out_vars, + values=values, + left=combined, + right=other, + elim_vars=(), + backpointer=None, + ) + if var in combined.vars: + values, backpointer = tropical_reduce_max( + combined.values, combined.vars, (var,), track_argmax=track_argmax + ) + combined = ReduceNode( + vars=tuple(v for v in combined.vars if v != var), + values=values, + child=combined, + elim_vars=(var,), + backpointer=backpointer, + ) + active_nodes.append(combined) + while len(active_nodes) > 1: + left = active_nodes.pop(0) + right = active_nodes.pop(0) + index_map = build_index_map(left.vars, right.vars, elim_vars=()) + values, _ = einsum_fn(left.values, right.values, index_map, track_argmax=False) + combined = ContractNode( + vars=index_map.out_vars, + values=values, + left=left, + right=right, + elim_vars=(), + backpointer=None, + ) + 
active_nodes.append(combined) + if not active_nodes: + raise ValueError("Contraction produced no nodes.") + return active_nodes[0] diff --git a/tropical_in_new/src/mpe.py b/tropical_in_new/src/mpe.py new file mode 100644 index 0000000..0ddc739 --- /dev/null +++ b/tropical_in_new/src/mpe.py @@ -0,0 +1,91 @@ +"""Top-level MPE API for tropical tensor networks.""" + +from __future__ import annotations + +from typing import Dict, Iterable + +from .contraction import ContractNode, ReduceNode, build_contraction_tree, choose_order +from .contraction import contract_tree as _contract_tree +from .network import TensorNode, build_network +from .primitives import argmax_trace, tropical_einsum, tropical_reduce_max +from .utils import UAIModel, build_tropical_factors + + +def _unravel_argmax(values, vars: Iterable[int]) -> Dict[int, int]: + vars = tuple(vars) + if not vars: + return {} + flat = int(values.reshape(-1).argmax().item()) + shape = list(values.shape) + assignments = [] + for size in reversed(shape): + assignments.append(flat % size) + flat //= size + assignments = list(reversed(assignments)) + return {var: int(val) for var, val in zip(vars, assignments)} + + +def recover_mpe_assignment(root) -> Dict[int, int]: + """Recover MPE assignment from a contraction tree with backpointers.""" + assignment: Dict[int, int] = {} + + def traverse(node, out_assignment: Dict[int, int]) -> None: + assignment.update(out_assignment) + if isinstance(node, TensorNode): + return + if isinstance(node, ReduceNode): + elim_assignment = ( + argmax_trace(node.backpointer, out_assignment) if node.backpointer else {} + ) + combined = {**out_assignment, **elim_assignment} + child_assignment = {v: combined[v] for v in node.child.vars} + traverse(node.child, child_assignment) + return + if isinstance(node, ContractNode): + elim_assignment = ( + argmax_trace(node.backpointer, out_assignment) if node.backpointer else {} + ) + combined = {**out_assignment, **elim_assignment} + left_assignment = {v: 
combined[v] for v in node.left.vars} + right_assignment = {v: combined[v] for v in node.right.vars} + traverse(node.left, left_assignment) + traverse(node.right, right_assignment) + + initial = _unravel_argmax(root.values, root.vars) + traverse(root, initial) + return assignment + + +def mpe_tropical( + model: UAIModel, + evidence: Dict[int, int] | None = None, + order: Iterable[int] | None = None, +) -> tuple[Dict[int, int], float, Dict[str, int | tuple[int, ...]]]: + """Return MPE assignment, score, and contraction metadata.""" + evidence = evidence or {} + factors = build_tropical_factors(model, evidence) + nodes = build_network(factors) + if order is None: + order = choose_order(nodes, heuristic="min_fill") + tree = build_contraction_tree(order, nodes) + root = _contract_tree(tree, einsum_fn=tropical_einsum) + if root.vars: + values, backpointer = tropical_reduce_max( + root.values, root.vars, tuple(root.vars), track_argmax=True + ) + root = ReduceNode( + vars=(), + values=values, + child=root, + elim_vars=tuple(root.vars), + backpointer=backpointer, + ) + assignment = recover_mpe_assignment(root) + assignment.update({int(k): int(v) for k, v in evidence.items()}) + score = float(root.values.item()) + info = { + "order": tuple(order), + "num_nodes": len(nodes), + "num_elims": len(tuple(order)), + } + return assignment, score, info diff --git a/tropical_in_new/src/network.py b/tropical_in_new/src/network.py new file mode 100644 index 0000000..aa98b7a --- /dev/null +++ b/tropical_in_new/src/network.py @@ -0,0 +1,26 @@ +"""Tensor network construction for tropical MPE.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Tuple + +import torch + +from .primitives import safe_log + + +@dataclass(frozen=True) +class TensorNode: + """Node representing a tensor factor in log-domain.""" + + vars: Tuple[int, ...] 
+ values: torch.Tensor + + +def build_network(factors: Iterable) -> list[TensorNode]: + """Convert factors into tensor nodes (log-domain).""" + nodes: list[TensorNode] = [] + for factor in factors: + nodes.append(TensorNode(vars=tuple(factor.vars), values=safe_log(factor.values))) + return nodes diff --git a/tropical_in_new/src/primitives.py b/tropical_in_new/src/primitives.py new file mode 100644 index 0000000..719e9b9 --- /dev/null +++ b/tropical_in_new/src/primitives.py @@ -0,0 +1,150 @@ +"""Tropical semiring primitives and backpointer helpers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Iterable, Tuple + +import torch + +try: # Optional accelerator; falls back to pure torch. + import tropical_gemm # type: ignore +except Exception: # pragma: no cover - optional dependency + tropical_gemm = None + + +@dataclass(frozen=True) +class IndexMap: + """Index mapping for binary tropical contractions.""" + + a_vars: Tuple[int, ...] + b_vars: Tuple[int, ...] + out_vars: Tuple[int, ...] + elim_vars: Tuple[int, ...] + + +@dataclass +class Backpointer: + """Stores argmax metadata for eliminated variables.""" + + elim_vars: Tuple[int, ...] + elim_shape: Tuple[int, ...] + out_vars: Tuple[int, ...] + argmax_flat: torch.Tensor + + +@dataclass(frozen=True) +class TropicalTensor: + """Lightweight wrapper for tropical (max-plus) tensors.""" + + vars: Tuple[int, ...] 
+ values: torch.Tensor + + def __add__(self, other: "TropicalTensor") -> "TropicalTensor": + if self.vars != other.vars: + raise ValueError("TropicalTensor.__add__ requires identical variable order.") + return TropicalTensor(self.vars, torch.maximum(self.values, other.values)) + + +def safe_log(tensor: torch.Tensor) -> torch.Tensor: + """Convert potentials to log domain; zeros map to -inf.""" + neg_inf = torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device) + return torch.where(tensor > 0, torch.log(tensor), neg_inf) + + +def _align_tensor( + tensor: torch.Tensor, tensor_vars: Tuple[int, ...], target_vars: Tuple[int, ...] +) -> torch.Tensor: + if not target_vars: + return tensor.reshape(()) + if not tensor_vars: + return tensor.reshape((1,) * len(target_vars)) + present = [v for v in target_vars if v in tensor_vars] + perm = [tensor_vars.index(v) for v in present] + aligned = tensor if perm == list(range(len(tensor_vars))) else tensor.permute(perm) + shape = [] + p = 0 + for var in target_vars: + if var in tensor_vars: + shape.append(aligned.shape[p]) + p += 1 + else: + shape.append(1) + return aligned.reshape(tuple(shape)) + + +def tropical_reduce_max( + tensor: torch.Tensor, + vars: Tuple[int, ...], + elim_vars: Iterable[int], + track_argmax: bool = True, +) -> tuple[torch.Tensor, Backpointer | None]: + elim_vars = tuple(elim_vars) + if not elim_vars: + return tensor, None + target_vars = tuple(vars) + elim_axes = [target_vars.index(v) for v in elim_vars] + keep_axes = [i for i in range(len(target_vars)) if i not in elim_axes] + perm = keep_axes + elim_axes + permuted = tensor if perm == list(range(len(target_vars))) else tensor.permute(perm) + out_shape = [tensor.shape[i] for i in keep_axes] + elim_shape = [tensor.shape[i] for i in elim_axes] + flat = permuted.reshape(*out_shape, -1) + values, argmax_flat = torch.max(flat, dim=-1) + if not track_argmax: + return values, None + backpointer = Backpointer( + elim_vars=elim_vars, + 
elim_shape=tuple(elim_shape), + out_vars=tuple(target_vars[i] for i in keep_axes), + argmax_flat=argmax_flat, + ) + return values, backpointer + + +def tropical_einsum( + a: torch.Tensor, + b: torch.Tensor, + index_map: IndexMap, + track_argmax: bool = True, +) -> tuple[torch.Tensor, Backpointer | None]: + """Binary tropical contraction: add in log-space, max over elim_vars.""" + if tropical_gemm is not None and hasattr(tropical_gemm, "einsum"): + try: # pragma: no cover - optional dependency + return tropical_gemm.einsum(a, b, index_map, track_argmax=track_argmax) + except Exception: + pass + + target_vars = tuple(dict.fromkeys(index_map.a_vars + index_map.b_vars)) + expected_out = tuple(v for v in target_vars if v not in index_map.elim_vars) + if index_map.out_vars and index_map.out_vars != expected_out: + raise ValueError("index_map.out_vars does not match contraction result ordering.") + + aligned_a = _align_tensor(a, index_map.a_vars, target_vars) + aligned_b = _align_tensor(b, index_map.b_vars, target_vars) + combined = aligned_a + aligned_b + + if not index_map.elim_vars: + return combined, None + + values, backpointer = tropical_reduce_max( + combined, target_vars, index_map.elim_vars, track_argmax=track_argmax + ) + return values, backpointer + + +def argmax_trace(backpointer: Backpointer, assignment: Dict[int, int]) -> Dict[int, int]: + """Decode eliminated variable assignments from a backpointer.""" + if not backpointer.elim_vars: + return {} + if backpointer.out_vars: + idx = tuple(assignment[v] for v in backpointer.out_vars) + flat = int(backpointer.argmax_flat[idx].item()) + else: + flat = int(backpointer.argmax_flat.item()) + values = [] + for size in reversed(backpointer.elim_shape): + values.append(flat % size) + flat //= size + values = list(reversed(values)) + return {var: int(val) for var, val in zip(backpointer.elim_vars, values)} diff --git a/tropical_in_new/src/utils.py b/tropical_in_new/src/utils.py new file mode 100644 index 
0000000..1f234ac --- /dev/null +++ b/tropical_in_new/src/utils.py @@ -0,0 +1,137 @@ +"""Utility helpers for tropical tensor networks.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Iterable, List, Tuple + +import torch + +from .primitives import IndexMap + + +@dataclass +class Factor: + """Factor class representing a factor in the factor graph.""" + + vars: Tuple[int, ...] + values: torch.Tensor + + +@dataclass +class UAIModel: + """UAI model containing variables, cardinalities, and factors.""" + + nvars: int + cards: List[int] + factors: List[Factor] + + +def read_model_file(filepath: str, factor_eltype=torch.float64) -> UAIModel: + """Parse UAI format model file.""" + with open(filepath, "r", encoding="utf-8") as f: + content = f.read() + return read_model_from_string(content, factor_eltype=factor_eltype) + + +def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIModel: + """Parse UAI model from string.""" + lines = [line.strip() for line in content.split("\n") if line.strip()] + network_type = lines[0] + if network_type not in ("MARKOV", "BAYES"): + raise ValueError( + f"Unsupported UAI network type: {network_type!r}. Expected 'MARKOV' or 'BAYES'." + ) + nvars = int(lines[1]) + cards = [int(x) for x in lines[2].split()] + ntables = int(lines[3]) + + scopes: list[list[int]] = [] + for i in range(ntables): + parts = lines[4 + i].split() + scope_size = int(parts[0]) + if len(parts) - 1 != scope_size: + raise ValueError( + f"Scope size mismatch on line {4 + i}: " + f"declared {scope_size}, found {len(parts) - 1} variables." 
+ ) + scope = [int(x) + 1 for x in parts[1:]] + scopes.append(scope) + + idx = 4 + ntables + tokens: List[str] = [] + while idx < len(lines): + tokens.extend(lines[idx].split()) + idx += 1 + cursor = 0 + + factors: List[Factor] = [] + for scope in scopes: + if cursor >= len(tokens): + raise ValueError("Unexpected end of UAI factor table data.") + nelements = int(tokens[cursor]) + cursor += 1 + values = torch.tensor( + [float(x) for x in tokens[cursor : cursor + nelements]], dtype=factor_eltype + ) + cursor += nelements + shape = tuple([cards[v - 1] for v in scope]) + values = values.reshape(shape) + factors.append(Factor(tuple(scope), values)) + + return UAIModel(nvars, cards, factors) + + +def read_evidence_file(filepath: str) -> Dict[int, int]: + """Parse evidence file (.evid format).""" + if not filepath: + return {} + with open(filepath, "r", encoding="utf-8") as f: + lines = f.readlines() + if not lines: + return {} + last_line = lines[-1].strip() + parts = [int(x) for x in last_line.split()] + nobsvars = parts[0] + evidence: Dict[int, int] = {} + for i in range(nobsvars): + var_idx = parts[1 + 2 * i] + 1 + var_value = parts[2 + 2 * i] + evidence[var_idx] = var_value + return evidence + + +def apply_evidence_to_factor(factor: Factor, evidence: Dict[int, int]) -> Factor: + """Clamp a factor with evidence and drop observed variables.""" + if not evidence: + return factor + index = [] + new_vars = [] + for var in factor.vars: + if var in evidence: + index.append(int(evidence[var])) + else: + index.append(slice(None)) + new_vars.append(var) + values = factor.values[tuple(index)] + if values.ndim == 0: + values = values.reshape(()) + return Factor(tuple(new_vars), values) + + +def build_tropical_factors(model: UAIModel, evidence: Dict[int, int] | None = None) -> list[Factor]: + """Apply evidence and return factors in original domain (log later).""" + evidence = evidence or {} + factors = [apply_evidence_to_factor(factor, evidence) for factor in model.factors] + 
return factors + + +def build_index_map( + a_vars: Iterable[int], b_vars: Iterable[int], elim_vars: Iterable[int] +) -> IndexMap: + a_vars = tuple(a_vars) + b_vars = tuple(b_vars) + elim_vars = tuple(elim_vars) + target_vars = tuple(dict.fromkeys(a_vars + b_vars)) + out_vars = tuple(v for v in target_vars if v not in elim_vars) + return IndexMap(a_vars=a_vars, b_vars=b_vars, out_vars=out_vars, elim_vars=elim_vars) diff --git a/tropical_in_new/tests/conftest.py b/tropical_in_new/tests/conftest.py new file mode 100644 index 0000000..4ffab65 --- /dev/null +++ b/tropical_in_new/tests/conftest.py @@ -0,0 +1,5 @@ +from pathlib import Path +import sys + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) diff --git a/tropical_in_new/tests/test_contraction.py b/tropical_in_new/tests/test_contraction.py new file mode 100644 index 0000000..2597bdd --- /dev/null +++ b/tropical_in_new/tests/test_contraction.py @@ -0,0 +1,25 @@ +import torch + +from src.contraction import build_contraction_tree, choose_order, contract_tree +from src.network import TensorNode +from src.primitives import tropical_einsum + + +def test_choose_order_returns_all_vars(): + nodes = [ + TensorNode(vars=(1, 2), values=torch.zeros((2, 2))), + TensorNode(vars=(2, 3), values=torch.zeros((2, 2))), + ] + order = choose_order(nodes, heuristic="min_fill") + assert set(order) == {1, 2, 3} + + +def test_contract_tree_reduces_to_scalar(): + nodes = [ + TensorNode(vars=(1,), values=torch.tensor([0.1, 0.0])), + TensorNode(vars=(1, 2), values=torch.tensor([[0.2, 0.3], [0.4, 0.1]])), + ] + order = [1, 2] + tree = build_contraction_tree(order, nodes) + root = contract_tree(tree, tropical_einsum) + assert root.values.numel() == 1 diff --git a/tropical_in_new/tests/test_mpe.py b/tropical_in_new/tests/test_mpe.py new file mode 100644 index 0000000..3dfb085 --- /dev/null +++ b/tropical_in_new/tests/test_mpe.py @@ -0,0 +1,69 @@ +import torch + +from src.mpe import mpe_tropical +from src.utils import 
read_model_from_string + + +def _brute_force_mpe(cards, factors): + best_score = float("-inf") + best_assignment = None + for x0 in range(cards[0]): + for x1 in range(cards[1]): + score = 0.0 + for vars, values in factors: + idx = [] + for v in vars: + idx.append(x0 if v == 1 else x1) + score += torch.log(values[tuple(idx)]).item() + if score > best_score: + best_score = score + best_assignment = {1: x0, 2: x1} + return best_assignment, best_score + + +def test_mpe_matches_bruteforce(): + uai = "\n".join( + [ + "MARKOV", + "2", + "2 2", + "3", + "1 0", + "1 1", + "2 0 1", + "2", + "0.6 0.4", + "2", + "0.3 0.7", + "4", + "1.2 0.2 0.2 1.2", + ] + ) + model = read_model_from_string(uai, factor_eltype=torch.float64) + assignment, score, _ = mpe_tropical(model) + brute_assignment, brute_score = _brute_force_mpe( + model.cards, [(f.vars, f.values) for f in model.factors] + ) + assert assignment == brute_assignment + assert abs(score - brute_score) < 1e-8 + + +def test_mpe_with_evidence(): + uai = "\n".join( + [ + "MARKOV", + "2", + "2 2", + "2", + "1 0", + "2 0 1", + "2", + "0.8 0.2", + "4", + "1.0 0.1 0.1 1.0", + ] + ) + model = read_model_from_string(uai, factor_eltype=torch.float64) + assignment, score, _ = mpe_tropical(model, evidence={1: 1}) + assert assignment[1] == 1 + assert isinstance(score, float) diff --git a/tropical_in_new/tests/test_primitives.py b/tropical_in_new/tests/test_primitives.py new file mode 100644 index 0000000..1fefb62 --- /dev/null +++ b/tropical_in_new/tests/test_primitives.py @@ -0,0 +1,26 @@ +import torch + +from src.primitives import IndexMap, argmax_trace, safe_log, tropical_einsum + + +def test_safe_log_zero_to_neg_inf(): + values = torch.tensor([1.0, 0.0], dtype=torch.float64) + logged = safe_log(values) + assert logged[0].item() == 0.0 + assert torch.isneginf(logged[1]) + + +def test_tropical_einsum_binary_reduce(): + # a(x) with x in {0,1}, b(x,y) with y in {0,1} + a = torch.tensor([0.2, 0.8]) + b = torch.tensor([[0.1, 0.5], [0.7, 
0.3]]) + index_map = IndexMap(a_vars=(1,), b_vars=(1, 2), out_vars=(2,), elim_vars=(1,)) + values, backpointer = tropical_einsum(a, b, index_map, track_argmax=True) + + expected = torch.tensor([ + max(0.2 + 0.1, 0.8 + 0.7), + max(0.2 + 0.5, 0.8 + 0.3), + ]) + assert torch.allclose(values, expected) + recovered = argmax_trace(backpointer, {2: 0}) + assert recovered[1] in (0, 1) From d3de0a05e2ef9755153a7dc9ce2ec26afa273417 Mon Sep 17 00:00:00 2001 From: Mengkun Liu Date: Sat, 24 Jan 2026 11:29:34 +0800 Subject: [PATCH 7/8] Use omeco ordering in tropical MPE. Update docs/imports and improve validation coverage. --- tropical_in_new/README.md | 19 ++- tropical_in_new/docs/api_reference.md | 4 +- tropical_in_new/docs/usage_guide.md | 13 +- tropical_in_new/examples/asia_network/main.py | 8 +- tropical_in_new/requirements.txt | 2 +- tropical_in_new/src/__init__.py | 10 +- tropical_in_new/src/contraction.py | 134 +++++++++++------- tropical_in_new/src/mpe.py | 13 +- tropical_in_new/src/primitives.py | 25 ++-- tropical_in_new/src/utils.py | 22 +++ tropical_in_new/tests/conftest.py | 4 +- tropical_in_new/tests/test_contraction.py | 8 +- tropical_in_new/tests/test_mpe.py | 5 +- tropical_in_new/tests/test_primitives.py | 2 +- tropical_in_new/tests/test_utils.py | 10 ++ 15 files changed, 185 insertions(+), 94 deletions(-) create mode 100644 tropical_in_new/tests/test_utils.py diff --git a/tropical_in_new/README.md b/tropical_in_new/README.md index 7898009..087aca1 100644 --- a/tropical_in_new/README.md +++ b/tropical_in_new/README.md @@ -1,8 +1,12 @@ # Tropical Tensor Network for MPE This folder contains an independent implementation of tropical tensor network -contraction for Most Probable Explanation (MPE). It does not depend on the -`bpdecoderplus` package; all code lives under `tropical_in_new/src`. +contraction for Most Probable Explanation (MPE). 
It uses `omeco` for contraction +order optimization and does not depend on the `bpdecoderplus` package; all code +lives under `tropical_in_new/src`. + +`omeco` provides high-quality contraction order heuristics (greedy and +simulated annealing). Install it alongside Torch to run the examples and tests. ## Structure @@ -20,7 +24,8 @@ tropical_in_new/ ├── tests/ │ ├── test_primitives.py │ ├── test_contraction.py -│ └── test_mpe.py +│ ├── test_mpe.py +│ └── test_utils.py ├── examples/ │ └── asia_network/ │ ├── main.py @@ -34,5 +39,13 @@ tropical_in_new/ ## Quick Start ```bash +pip install -r tropical_in_new/requirements.txt python tropical_in_new/examples/asia_network/main.py ``` + +## Notes on omeco + +`omeco` is a Rust-backed Python package. If a prebuilt wheel is not available +for your Python version, you will need a Rust toolchain with `cargo` on PATH to +build it from source. See the omeco repository for details: +https://github.com/GiggleLiu/omeco diff --git a/tropical_in_new/docs/api_reference.md b/tropical_in_new/docs/api_reference.md index f3358eb..a137122 100644 --- a/tropical_in_new/docs/api_reference.md +++ b/tropical_in_new/docs/api_reference.md @@ -34,8 +34,8 @@ Public APIs exported from `tropical_in_new/src`. - `build_network(factors: Iterable[Factor]) -> list[TensorNode]` Convert factors into log-domain tensors with scopes. -- `choose_order(nodes, heuristic="min_fill") -> list[int]` - Select variable elimination order (min-fill / min-degree). +- `choose_order(nodes, heuristic="omeco") -> list[int]` + Select variable elimination order using `omeco`. - `build_contraction_tree(order, nodes) -> ContractionTree` Construct a contraction plan from order and nodes. 
diff --git a/tropical_in_new/docs/usage_guide.md b/tropical_in_new/docs/usage_guide.md index 280b091..c34130c 100644 --- a/tropical_in_new/docs/usage_guide.md +++ b/tropical_in_new/docs/usage_guide.md @@ -3,10 +3,19 @@ This guide shows how to parse a UAI model and compute MPE using the tropical tensor network implementation. +### Install Dependencies + +```bash +pip install -r tropical_in_new/requirements.txt +``` + +`omeco` provides contraction order optimization. If a prebuilt wheel is not +available for your Python version, you may need a Rust toolchain installed. + ### Quick Start ```python -from src import mpe_tropical, read_model_file +from tropical_in_new.src import mpe_tropical, read_model_file model = read_model_file("tropical_in_new/examples/asia_network/model.uai") assignment, score, info = mpe_tropical(model) @@ -16,7 +25,7 @@ print(assignment, score, info) ### Evidence ```python -from src import mpe_tropical, read_model_file +from tropical_in_new.src import mpe_tropical, read_model_file model = read_model_file("tropical_in_new/examples/asia_network/model.uai") evidence = {1: 0} # variable index is 1-based diff --git a/tropical_in_new/examples/asia_network/main.py b/tropical_in_new/examples/asia_network/main.py index 4bfcef0..08fe0eb 100644 --- a/tropical_in_new/examples/asia_network/main.py +++ b/tropical_in_new/examples/asia_network/main.py @@ -1,12 +1,6 @@ """Run tropical MPE on a small UAI model.""" -from pathlib import Path -import sys - -ROOT = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(ROOT)) - -from src import mpe_tropical, read_model_file # noqa: E402 +from tropical_in_new.src import mpe_tropical, read_model_file def main() -> None: diff --git a/tropical_in_new/requirements.txt b/tropical_in_new/requirements.txt index 5259e95..89b7694 100644 --- a/tropical_in_new/requirements.txt +++ b/tropical_in_new/requirements.txt @@ -1,2 +1,2 @@ torch>=2.0.0 -numpy>=1.24.0 +omeco diff --git a/tropical_in_new/src/__init__.py 
b/tropical_in_new/src/__init__.py index 3689064..2522ab8 100644 --- a/tropical_in_new/src/__init__.py +++ b/tropical_in_new/src/__init__.py @@ -4,7 +4,14 @@ from .mpe import mpe_tropical, recover_mpe_assignment from .network import TensorNode, build_network from .primitives import argmax_trace, safe_log, tropical_einsum -from .utils import Factor, UAIModel, build_tropical_factors, read_model_file, read_model_from_string +from .utils import ( + Factor, + UAIModel, + build_tropical_factors, + read_evidence_file, + read_model_file, + read_model_from_string, +) __all__ = [ "Factor", @@ -17,6 +24,7 @@ "choose_order", "contract_tree", "mpe_tropical", + "read_evidence_file", "read_model_file", "read_model_from_string", "recover_mpe_assignment", diff --git a/tropical_in_new/src/contraction.py b/tropical_in_new/src/contraction.py index b1b0a2e..4f32a2f 100644 --- a/tropical_in_new/src/contraction.py +++ b/tropical_in_new/src/contraction.py @@ -7,10 +7,7 @@ import torch -try: # Optional heuristic provider - import omeco # type: ignore -except Exception: # pragma: no cover - optional dependency - omeco = None +import omeco from .network import TensorNode from .primitives import Backpointer, tropical_reduce_max @@ -45,59 +42,86 @@ class ContractionTree: nodes: Tuple[TensorNode, ...] -def _build_var_graph(nodes: Iterable[TensorNode]) -> dict[int, set[int]]: - graph: dict[int, set[int]] = {} +def _infer_var_sizes(nodes: Iterable[TensorNode]) -> dict[int, int]: + sizes: dict[int, int] = {} for node in nodes: - vars = list(node.vars) + for var, dim in zip(node.vars, node.values.shape): + if var in sizes and sizes[var] != dim: + raise ValueError( + f"Variable {var} has inconsistent sizes: {sizes[var]} vs {dim}." 
+ ) + sizes[var] = int(dim) + return sizes + + +def _extract_leaf_index(node_dict: dict) -> int | None: + for key in ("leaf", "leaf_index", "index", "tensor"): + if key in node_dict: + value = node_dict[key] + if isinstance(value, int): + return value + return None + + +def _elim_order_from_tree_dict(tree_dict: dict, ixs: list[list[int]]) -> list[int]: + total_counts: dict[int, int] = {} + for vars in ixs: for var in vars: - graph.setdefault(var, set()).update(v for v in vars if v != var) - return graph - - -def _min_fill_order(graph: dict[int, set[int]]) -> list[int]: - order: list[int] = [] - graph = {k: set(v) for k, v in graph.items()} - while graph: - best_var = None - best_fill = None - best_degree = None - for var, neighbors in graph.items(): - fill = 0 - neighbor_list = list(neighbors) - for i in range(len(neighbor_list)): - for j in range(i + 1, len(neighbor_list)): - if neighbor_list[j] not in graph[neighbor_list[i]]: - fill += 1 - degree = len(neighbors) - if best_fill is None or (fill, degree) < (best_fill, best_degree): - best_var = var - best_fill = fill - best_degree = degree - if best_var is None: - break - neighbors = list(graph[best_var]) - for i in range(len(neighbors)): - for j in range(i + 1, len(neighbors)): - graph[neighbors[i]].add(neighbors[j]) - graph[neighbors[j]].add(neighbors[i]) - for neighbor in neighbors: - graph[neighbor].discard(best_var) - graph.pop(best_var, None) - order.append(best_var) - return order - - -def choose_order(nodes: list[TensorNode], heuristic: str = "min_fill") -> list[int]: - """Select elimination order over variable indices.""" - if heuristic == "omeco" and omeco is not None: - if hasattr(omeco, "min_fill_order"): - return list(omeco.min_fill_order([node.vars for node in nodes])) - graph = _build_var_graph(nodes) - if heuristic in ("min_fill", "omeco"): - return _min_fill_order(graph) - if heuristic == "min_degree": - return sorted(graph, key=lambda v: len(graph[v])) - raise ValueError(f"Unknown heuristic: 
{heuristic!r}") + total_counts[var] = total_counts.get(var, 0) + 1 + + eliminated: set[int] = set() + + def visit(node: dict) -> tuple[dict[int, int], list[int]]: + leaf_index = _extract_leaf_index(node) + if leaf_index is not None: + counts: dict[int, int] = {} + for var in ixs[leaf_index]: + counts[var] = counts.get(var, 0) + 1 + return counts, [] + + children = node.get("children", []) + if not isinstance(children, list) or not children: + return {}, [] + + counts: dict[int, int] = {} + order: list[int] = [] + for child in children: + child_counts, child_order = visit(child) + order.extend(child_order) + for var, count in child_counts.items(): + counts[var] = counts.get(var, 0) + count + + newly_eliminated = [ + var + for var, count in counts.items() + if count == total_counts.get(var, 0) and var not in eliminated + ] + for var in sorted(newly_eliminated): + eliminated.add(var) + order.append(var) + return counts, order + + _, order = visit(tree_dict) + remaining = sorted([var for var in total_counts if var not in eliminated]) + return order + remaining + + +def choose_order(nodes: list[TensorNode], heuristic: str = "omeco") -> list[int]: + """Select elimination order over variable indices using omeco.""" + if heuristic != "omeco": + raise ValueError("Only the 'omeco' heuristic is supported.") + ixs = [list(node.vars) for node in nodes] + sizes = _infer_var_sizes(nodes) + method = omeco.GreedyMethod() if hasattr(omeco, "GreedyMethod") else None + tree = ( + omeco.optimize_code(ixs, [], sizes, method) + if method is not None + else omeco.optimize_code(ixs, [], sizes) + ) + tree_dict = tree.to_dict() if hasattr(tree, "to_dict") else tree + if not isinstance(tree_dict, dict): + raise ValueError("omeco.optimize_code did not return a usable tree.") + return _elim_order_from_tree_dict(tree_dict, ixs) def build_contraction_tree(order: Iterable[int], nodes: list[TensorNode]) -> ContractionTree: diff --git a/tropical_in_new/src/mpe.py b/tropical_in_new/src/mpe.py index 
0ddc739..8f64aeb 100644 --- a/tropical_in_new/src/mpe.py +++ b/tropical_in_new/src/mpe.py @@ -29,6 +29,14 @@ def recover_mpe_assignment(root) -> Dict[int, int]: """Recover MPE assignment from a contraction tree with backpointers.""" assignment: Dict[int, int] = {} + def require_vars(required: Iterable[int], available: Dict[int, int]) -> None: + missing = [v for v in required if v not in available] + if missing: + raise KeyError( + "Missing assignment values for variables: " + f"{missing}. Provided assignment keys: {sorted(available.keys())}" + ) + def traverse(node, out_assignment: Dict[int, int]) -> None: assignment.update(out_assignment) if isinstance(node, TensorNode): @@ -38,6 +46,7 @@ def traverse(node, out_assignment: Dict[int, int]) -> None: argmax_trace(node.backpointer, out_assignment) if node.backpointer else {} ) combined = {**out_assignment, **elim_assignment} + require_vars(node.child.vars, combined) child_assignment = {v: combined[v] for v in node.child.vars} traverse(node.child, child_assignment) return @@ -46,7 +55,9 @@ def traverse(node, out_assignment: Dict[int, int]) -> None: argmax_trace(node.backpointer, out_assignment) if node.backpointer else {} ) combined = {**out_assignment, **elim_assignment} + require_vars(node.left.vars, combined) left_assignment = {v: combined[v] for v in node.left.vars} + require_vars(node.right.vars, combined) right_assignment = {v: combined[v] for v in node.right.vars} traverse(node.left, left_assignment) traverse(node.right, right_assignment) @@ -66,7 +77,7 @@ def mpe_tropical( factors = build_tropical_factors(model, evidence) nodes = build_network(factors) if order is None: - order = choose_order(nodes, heuristic="min_fill") + order = choose_order(nodes, heuristic="omeco") tree = build_contraction_tree(order, nodes) root = _contract_tree(tree, einsum_fn=tropical_einsum) if root.vars: diff --git a/tropical_in_new/src/primitives.py b/tropical_in_new/src/primitives.py index 719e9b9..46fb332 100644 --- 
a/tropical_in_new/src/primitives.py +++ b/tropical_in_new/src/primitives.py @@ -33,19 +33,6 @@ class Backpointer: argmax_flat: torch.Tensor -@dataclass(frozen=True) -class TropicalTensor: - """Lightweight wrapper for tropical (max-plus) tensors.""" - - vars: Tuple[int, ...] - values: torch.Tensor - - def __add__(self, other: "TropicalTensor") -> "TropicalTensor": - if self.vars != other.vars: - raise ValueError("TropicalTensor.__add__ requires identical variable order.") - return TropicalTensor(self.vars, torch.maximum(self.values, other.values)) - - def safe_log(tensor: torch.Tensor) -> torch.Tensor: """Convert potentials to log domain; zeros map to -inf.""" neg_inf = torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device) @@ -83,6 +70,12 @@ def tropical_reduce_max( if not elim_vars: return tensor, None target_vars = tuple(vars) + missing_elim_vars = [v for v in elim_vars if v not in target_vars] + if missing_elim_vars: + raise ValueError( + "tropical_reduce_max: elim_vars " + f"{missing_elim_vars} are not present in vars {target_vars}." + ) elim_axes = [target_vars.index(v) for v in elim_vars] keep_axes = [i for i in range(len(target_vars)) if i not in elim_axes] perm = keep_axes + elim_axes @@ -138,6 +131,12 @@ def argmax_trace(backpointer: Backpointer, assignment: Dict[int, int]) -> Dict[i if not backpointer.elim_vars: return {} if backpointer.out_vars: + missing = [v for v in backpointer.out_vars if v not in assignment] + if missing: + raise KeyError( + "Missing assignment values for output variables: " + f"{missing}. 
Provided assignment keys: {sorted(assignment.keys())}" + ) idx = tuple(assignment[v] for v in backpointer.out_vars) flat = int(backpointer.argmax_flat[idx].item()) else: diff --git a/tropical_in_new/src/utils.py b/tropical_in_new/src/utils.py index 1f234ac..1438965 100644 --- a/tropical_in_new/src/utils.py +++ b/tropical_in_new/src/utils.py @@ -37,6 +37,8 @@ def read_model_file(filepath: str, factor_eltype=torch.float64) -> UAIModel: def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIModel: """Parse UAI model from string.""" lines = [line.strip() for line in content.split("\n") if line.strip()] + if len(lines) < 4: + raise ValueError("Malformed UAI model: expected at least 4 header lines.") network_type = lines[0] if network_type not in ("MARKOV", "BAYES"): raise ValueError( @@ -44,7 +46,13 @@ def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIMode ) nvars = int(lines[1]) cards = [int(x) for x in lines[2].split()] + if len(cards) != nvars: + raise ValueError(f"Expected {nvars} cardinalities, got {len(cards)}.") ntables = int(lines[3]) + if len(lines) < 4 + ntables: + raise ValueError( + f"Malformed UAI model: expected {ntables} scope lines, got {len(lines) - 4}." + ) scopes: list[list[int]] = [] for i in range(ntables): @@ -71,6 +79,16 @@ def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIMode raise ValueError("Unexpected end of UAI factor table data.") nelements = int(tokens[cursor]) cursor += 1 + expected_size = 1 + for card in (cards[v - 1] for v in scope): + expected_size *= card + if nelements != expected_size: + raise ValueError( + f"Factor table size mismatch for scope {scope}: " + f"expected {expected_size}, got {nelements}." 
+ ) + if cursor + nelements > len(tokens): + raise ValueError("Unexpected end of UAI factor table data.") values = torch.tensor( [float(x) for x in tokens[cursor : cursor + nelements]], dtype=factor_eltype ) @@ -93,6 +111,10 @@ def read_evidence_file(filepath: str) -> Dict[int, int]: last_line = lines[-1].strip() parts = [int(x) for x in last_line.split()] nobsvars = parts[0] + if len(parts) < 1 + 2 * nobsvars: + raise ValueError( + f"Malformed evidence line: expected {1 + 2 * nobsvars} entries, got {len(parts)}." + ) evidence: Dict[int, int] = {} for i in range(nobsvars): var_idx = parts[1 + 2 * i] + 1 diff --git a/tropical_in_new/tests/conftest.py b/tropical_in_new/tests/conftest.py index 4ffab65..fbc1e9e 100644 --- a/tropical_in_new/tests/conftest.py +++ b/tropical_in_new/tests/conftest.py @@ -1,5 +1,5 @@ from pathlib import Path import sys -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) diff --git a/tropical_in_new/tests/test_contraction.py b/tropical_in_new/tests/test_contraction.py index 2597bdd..9c58dbc 100644 --- a/tropical_in_new/tests/test_contraction.py +++ b/tropical_in_new/tests/test_contraction.py @@ -1,8 +1,8 @@ import torch -from src.contraction import build_contraction_tree, choose_order, contract_tree -from src.network import TensorNode -from src.primitives import tropical_einsum +from tropical_in_new.src.contraction import build_contraction_tree, choose_order, contract_tree +from tropical_in_new.src.network import TensorNode +from tropical_in_new.src.primitives import tropical_einsum def test_choose_order_returns_all_vars(): @@ -10,7 +10,7 @@ def test_choose_order_returns_all_vars(): TensorNode(vars=(1, 2), values=torch.zeros((2, 2))), TensorNode(vars=(2, 3), values=torch.zeros((2, 2))), ] - order = choose_order(nodes, heuristic="min_fill") + order = choose_order(nodes, heuristic="omeco") assert set(order) == {1, 2, 3} diff --git 
a/tropical_in_new/tests/test_mpe.py b/tropical_in_new/tests/test_mpe.py index 3dfb085..638d5d9 100644 --- a/tropical_in_new/tests/test_mpe.py +++ b/tropical_in_new/tests/test_mpe.py @@ -1,10 +1,11 @@ import torch -from src.mpe import mpe_tropical -from src.utils import read_model_from_string +from tropical_in_new.src.mpe import mpe_tropical +from tropical_in_new.src.utils import read_model_from_string def _brute_force_mpe(cards, factors): + # Helper for this two-variable test case. best_score = float("-inf") best_assignment = None for x0 in range(cards[0]): diff --git a/tropical_in_new/tests/test_primitives.py b/tropical_in_new/tests/test_primitives.py index 1fefb62..d459780 100644 --- a/tropical_in_new/tests/test_primitives.py +++ b/tropical_in_new/tests/test_primitives.py @@ -1,6 +1,6 @@ import torch -from src.primitives import IndexMap, argmax_trace, safe_log, tropical_einsum +from tropical_in_new.src.primitives import IndexMap, argmax_trace, safe_log, tropical_einsum def test_safe_log_zero_to_neg_inf(): diff --git a/tropical_in_new/tests/test_utils.py b/tropical_in_new/tests/test_utils.py new file mode 100644 index 0000000..57f9c57 --- /dev/null +++ b/tropical_in_new/tests/test_utils.py @@ -0,0 +1,10 @@ +from tropical_in_new.src.utils import read_evidence_file + + +def test_read_evidence_file(tmp_path): + content = "2 0 1 3 0\n" + filepath = tmp_path / "example.evid" + filepath.write_text(content, encoding="utf-8") + + evidence = read_evidence_file(str(filepath)) + assert evidence == {1: 1, 4: 0} From 24bd0257c2e2a15993238bf224055e6cb8bc7711 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sat, 24 Jan 2026 12:10:54 +0800 Subject: [PATCH 8/8] Resolve PR #51 review comments - Use einsum with elimination directly in contract_tree instead of separate ReduceNode for multi-node buckets - Generalize _brute_force_mpe with itertools.product for any number of variables - Add assert statements in utils.py for internal consistency checks - Add comprehensive tests for error 
paths, edge cases, and read_evidence_file/read_model_file (coverage now 98%) Co-Authored-By: Claude Opus 4.5 --- tropical_in_new/src/contraction.py | 16 +++- tropical_in_new/src/utils.py | 8 ++ tropical_in_new/tests/test_contraction.py | 80 +++++++++++++++- tropical_in_new/tests/test_mpe.py | 107 +++++++++++++++++++--- tropical_in_new/tests/test_primitives.py | 74 ++++++++++++++- tropical_in_new/tests/test_utils.py | 55 ++++++++++- 6 files changed, 319 insertions(+), 21 deletions(-) diff --git a/tropical_in_new/src/contraction.py b/tropical_in_new/src/contraction.py index 4f32a2f..042db6e 100644 --- a/tropical_in_new/src/contraction.py +++ b/tropical_in_new/src/contraction.py @@ -143,18 +143,24 @@ def contract_tree( bucket_ids = {id(node) for node in bucket} active_nodes = [node for node in active_nodes if id(node) not in bucket_ids] combined: TreeNode = bucket[0] - for other in bucket[1:]: - index_map = build_index_map(combined.vars, other.vars, elim_vars=()) - values, _ = einsum_fn(combined.values, other.values, index_map, track_argmax=False) + for i, other in enumerate(bucket[1:]): + is_last = i == len(bucket) - 2 + elim_vars = (var,) if is_last else () + index_map = build_index_map(combined.vars, other.vars, elim_vars=elim_vars) + values, backpointer = einsum_fn( + combined.values, other.values, index_map, + track_argmax=track_argmax if is_last else False, + ) combined = ContractNode( vars=index_map.out_vars, values=values, left=combined, right=other, - elim_vars=(), - backpointer=None, + elim_vars=elim_vars, + backpointer=backpointer, ) if var in combined.vars: + # Single-node bucket: eliminate via reduce values, backpointer = tropical_reduce_max( combined.values, combined.vars, (var,), track_argmax=track_argmax ) diff --git a/tropical_in_new/src/utils.py b/tropical_in_new/src/utils.py index 1438965..fd7519d 100644 --- a/tropical_in_new/src/utils.py +++ b/tropical_in_new/src/utils.py @@ -64,7 +64,11 @@ def read_model_from_string(content: str, 
factor_eltype=torch.float64) -> UAIMode f"declared {scope_size}, found {len(parts) - 1} variables." ) scope = [int(x) + 1 for x in parts[1:]] + assert all(1 <= v <= nvars for v in scope), ( + f"Scope variables out of range [1, {nvars}]: {scope}" + ) scopes.append(scope) + assert len(scopes) == ntables idx = 4 + ntables tokens: List[str] = [] @@ -95,8 +99,10 @@ def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIMode cursor += nelements shape = tuple([cards[v - 1] for v in scope]) values = values.reshape(shape) + assert values.shape == shape factors.append(Factor(tuple(scope), values)) + assert len(factors) == ntables return UAIModel(nvars, cards, factors) @@ -119,7 +125,9 @@ def read_evidence_file(filepath: str) -> Dict[int, int]: for i in range(nobsvars): var_idx = parts[1 + 2 * i] + 1 var_value = parts[2 + 2 * i] + assert var_idx >= 1, f"Invalid variable index: {var_idx - 1}" evidence[var_idx] = var_value + assert len(evidence) == nobsvars return evidence diff --git a/tropical_in_new/tests/test_contraction.py b/tropical_in_new/tests/test_contraction.py index 9c58dbc..5ff4279 100644 --- a/tropical_in_new/tests/test_contraction.py +++ b/tropical_in_new/tests/test_contraction.py @@ -1,6 +1,14 @@ +import pytest import torch -from tropical_in_new.src.contraction import build_contraction_tree, choose_order, contract_tree +from tropical_in_new.src.contraction import ( + _elim_order_from_tree_dict, + _extract_leaf_index, + _infer_var_sizes, + build_contraction_tree, + choose_order, + contract_tree, +) from tropical_in_new.src.network import TensorNode from tropical_in_new.src.primitives import tropical_einsum @@ -14,6 +22,12 @@ def test_choose_order_returns_all_vars(): assert set(order) == {1, 2, 3} +def test_choose_order_invalid_heuristic(): + nodes = [TensorNode(vars=(1,), values=torch.zeros((2,)))] + with pytest.raises(ValueError, match="Only the 'omeco' heuristic"): + choose_order(nodes, heuristic="invalid") + + def 
test_contract_tree_reduces_to_scalar(): nodes = [ TensorNode(vars=(1,), values=torch.tensor([0.1, 0.0])), @@ -23,3 +37,67 @@ def test_contract_tree_reduces_to_scalar(): tree = build_contraction_tree(order, nodes) root = contract_tree(tree, tropical_einsum) assert root.values.numel() == 1 + + +def test_contract_tree_three_nodes_shared_var(): + """Test contraction with 3 nodes sharing a variable (uses einsum with elimination).""" + nodes = [ + TensorNode(vars=(1, 2), values=torch.tensor([[0.1, 0.2], [0.3, 0.4]])), + TensorNode(vars=(2, 3), values=torch.tensor([[0.5, 0.6], [0.7, 0.8]])), + TensorNode(vars=(2,), values=torch.tensor([0.1, 0.9])), + ] + order = [2, 1, 3] + tree = build_contraction_tree(order, nodes) + root = contract_tree(tree, tropical_einsum) + assert root.values.numel() == 1 + + +def test_contract_tree_partial_order(): + """Test contraction where order doesn't cover all vars (remaining merged).""" + nodes = [ + TensorNode(vars=(1, 2), values=torch.tensor([[0.1, 0.2], [0.3, 0.4]])), + TensorNode(vars=(3,), values=torch.tensor([0.5, 0.6])), + ] + order = [1] + tree = build_contraction_tree(order, nodes) + root = contract_tree(tree, tropical_einsum) + # var 2 and 3 remain + assert 2 in root.vars or 3 in root.vars + + +def test_infer_var_sizes_inconsistent(): + nodes = [ + TensorNode(vars=(1,), values=torch.zeros((2,))), + TensorNode(vars=(1,), values=torch.zeros((3,))), + ] + with pytest.raises(ValueError, match="inconsistent sizes"): + _infer_var_sizes(nodes) + + +def test_extract_leaf_index(): + assert _extract_leaf_index({"leaf": 0}) == 0 + assert _extract_leaf_index({"leaf_index": 2}) == 2 + assert _extract_leaf_index({"index": 1}) == 1 + assert _extract_leaf_index({"tensor": 3}) == 3 + assert _extract_leaf_index({"other": "abc"}) is None + assert _extract_leaf_index({"leaf": "not_int"}) is None + + +def test_elim_order_from_tree_dict(): + tree_dict = { + "children": [ + {"leaf": 0}, + {"leaf": 1}, + ] + } + ixs = [[1, 2], [2, 3]] + order = 
_elim_order_from_tree_dict(tree_dict, ixs) + assert set(order) == {1, 2, 3} + + +def test_elim_order_from_tree_dict_no_children(): + tree_dict = {"other_key": "value"} + ixs = [[1, 2]] + order = _elim_order_from_tree_dict(tree_dict, ixs) + # No children → remaining vars appended + assert set(order) == {1, 2} diff --git a/tropical_in_new/tests/test_mpe.py b/tropical_in_new/tests/test_mpe.py index 638d5d9..c28117b 100644 --- a/tropical_in_new/tests/test_mpe.py +++ b/tropical_in_new/tests/test_mpe.py @@ -1,24 +1,25 @@ +import itertools + +import pytest import torch -from tropical_in_new.src.mpe import mpe_tropical +from tropical_in_new.src.mpe import mpe_tropical, recover_mpe_assignment +from tropical_in_new.src.network import TensorNode from tropical_in_new.src.utils import read_model_from_string def _brute_force_mpe(cards, factors): - # Helper for this two-variable test case. + """General brute-force MPE over any number of variables.""" best_score = float("-inf") best_assignment = None - for x0 in range(cards[0]): - for x1 in range(cards[1]): - score = 0.0 - for vars, values in factors: - idx = [] - for v in vars: - idx.append(x0 if v == 1 else x1) - score += torch.log(values[tuple(idx)]).item() - if score > best_score: - best_score = score - best_assignment = {1: x0, 2: x1} + for combo in itertools.product(*(range(c) for c in cards)): + score = 0.0 + for vars, values in factors: + idx = tuple(combo[v - 1] for v in vars) + score += torch.log(values[idx]).item() + if score > best_score: + best_score = score + best_assignment = {i + 1: combo[i] for i in range(len(cards))} return best_assignment, best_score @@ -49,6 +50,86 @@ def test_mpe_matches_bruteforce(): assert abs(score - brute_score) < 1e-8 +def test_mpe_three_variables(): + """Test MPE on a 3-variable model to exercise general brute-force.""" + uai = "\n".join( + [ + "MARKOV", + "3", + "2 2 2", + "3", + "2 0 1", + "2 1 2", + "1 2", + "4", + "1.0 0.2 0.3 0.9", + "4", + "0.8 0.1 0.2 0.7", + "2", + "0.4 0.6", + ] 
+ ) + model = read_model_from_string(uai, factor_eltype=torch.float64) + assignment, score, _ = mpe_tropical(model) + brute_assignment, brute_score = _brute_force_mpe( + model.cards, [(f.vars, f.values) for f in model.factors] + ) + assert assignment == brute_assignment + assert abs(score - brute_score) < 1e-8 + + +def test_mpe_partial_order(): + """Test MPE with a partial elimination order (remaining vars reduced at end).""" + uai = "\n".join( + [ + "MARKOV", + "2", + "2 2", + "2", + "1 0", + "2 0 1", + "2", + "0.6 0.4", + "4", + "1.2 0.2 0.2 1.2", + ] + ) + model = read_model_from_string(uai, factor_eltype=torch.float64) + # Only eliminate var 1, leaving var 2 for final reduce + assignment, score, _ = mpe_tropical(model, order=[1]) + brute_assignment, brute_score = _brute_force_mpe( + model.cards, [(f.vars, f.values) for f in model.factors] + ) + assert assignment == brute_assignment + assert abs(score - brute_score) < 1e-8 + + +def test_recover_mpe_assignment_tensor_node(): + """Test recover_mpe_assignment directly on a TensorNode with vars.""" + node = TensorNode( + vars=(1, 2), + values=torch.tensor([[0.1, 0.9], [0.3, 0.2]]), + ) + assignment = recover_mpe_assignment(node) + # The max is at (0, 1) → var 1=0, var 2=1 + assert assignment == {1: 0, 2: 1} + + +def test_recover_mpe_assignment_bad_node(): + """Test recover_mpe_assignment raises on missing variables.""" + from tropical_in_new.src.contraction import ReduceNode + from tropical_in_new.src.primitives import Backpointer + + child = TensorNode(vars=(1, 2), values=torch.tensor([[0.1, 0.9], [0.3, 0.2]])) + bp = Backpointer( + elim_vars=(2,), elim_shape=(2,), out_vars=(99,), # bad out_vars + argmax_flat=torch.tensor([1, 0]) + ) + root = ReduceNode(vars=(), values=torch.tensor(0.9), child=child, elim_vars=(2,), backpointer=bp) + with pytest.raises(KeyError, match="Missing assignment"): + recover_mpe_assignment(root) + + def test_mpe_with_evidence(): uai = "\n".join( [ diff --git 
a/tropical_in_new/tests/test_primitives.py b/tropical_in_new/tests/test_primitives.py index d459780..e44d7ac 100644 --- a/tropical_in_new/tests/test_primitives.py +++ b/tropical_in_new/tests/test_primitives.py @@ -1,6 +1,13 @@ +import pytest import torch -from tropical_in_new.src.primitives import IndexMap, argmax_trace, safe_log, tropical_einsum +from tropical_in_new.src.primitives import ( + IndexMap, + argmax_trace, + safe_log, + tropical_einsum, + tropical_reduce_max, +) def test_safe_log_zero_to_neg_inf(): @@ -24,3 +31,68 @@ def test_tropical_einsum_binary_reduce(): assert torch.allclose(values, expected) recovered = argmax_trace(backpointer, {2: 0}) assert recovered[1] in (0, 1) + + +def test_tropical_einsum_no_elim(): + a = torch.tensor([0.1, 0.2]) + b = torch.tensor([0.3, 0.4]) + index_map = IndexMap(a_vars=(1,), b_vars=(2,), out_vars=(1, 2), elim_vars=()) + values, backpointer = tropical_einsum(a, b, index_map) + assert values.shape == (2, 2) + assert backpointer is None + + +def test_tropical_einsum_out_vars_mismatch(): + a = torch.tensor([0.1, 0.2]) + b = torch.tensor([0.3, 0.4]) + index_map = IndexMap(a_vars=(1,), b_vars=(2,), out_vars=(2, 1), elim_vars=()) + with pytest.raises(ValueError, match="out_vars does not match"): + tropical_einsum(a, b, index_map) + + +def test_tropical_reduce_max_no_elim(): + t = torch.tensor([1.0, 2.0]) + values, bp = tropical_reduce_max(t, (1,), []) + assert torch.equal(values, t) + assert bp is None + + +def test_tropical_reduce_max_missing_var(): + t = torch.tensor([1.0, 2.0]) + with pytest.raises(ValueError, match="not present in vars"): + tropical_reduce_max(t, (1,), (99,)) + + +def test_tropical_reduce_max_no_track(): + t = torch.tensor([[1.0, 3.0], [2.0, 4.0]]) + values, bp = tropical_reduce_max(t, (1, 2), (2,), track_argmax=False) + assert values.shape == (2,) + assert bp is None + + +def test_argmax_trace_no_elim_vars(): + from tropical_in_new.src.primitives import Backpointer + bp = Backpointer(elim_vars=(), 
elim_shape=(), out_vars=(), argmax_flat=torch.tensor(0)) + result = argmax_trace(bp, {}) + assert result == {} + + +def test_argmax_trace_missing_key(): + from tropical_in_new.src.primitives import Backpointer + bp = Backpointer( + elim_vars=(2,), elim_shape=(3,), out_vars=(1,), + argmax_flat=torch.tensor([0, 1, 2]) + ) + with pytest.raises(KeyError, match="Missing assignment"): + argmax_trace(bp, {}) + + +def test_argmax_trace_scalar_backpointer(): + """Test argmax_trace when out_vars is empty (scalar result).""" + from tropical_in_new.src.primitives import Backpointer + bp = Backpointer( + elim_vars=(1,), elim_shape=(3,), out_vars=(), + argmax_flat=torch.tensor(2) + ) + result = argmax_trace(bp, {}) + assert result == {1: 2} diff --git a/tropical_in_new/tests/test_utils.py b/tropical_in_new/tests/test_utils.py index 57f9c57..d37061c 100644 --- a/tropical_in_new/tests/test_utils.py +++ b/tropical_in_new/tests/test_utils.py @@ -1,4 +1,6 @@ -from tropical_in_new.src.utils import read_evidence_file +import pytest + +from tropical_in_new.src.utils import read_evidence_file, read_model_file, read_model_from_string def test_read_evidence_file(tmp_path): @@ -8,3 +10,54 @@ def test_read_evidence_file(tmp_path): evidence = read_evidence_file(str(filepath)) assert evidence == {1: 1, 4: 0} + + +def test_read_evidence_file_empty_path(): + assert read_evidence_file("") == {} + + +def test_read_evidence_file_empty_content(tmp_path): + filepath = tmp_path / "empty.evid" + filepath.write_text("", encoding="utf-8") + assert read_evidence_file(str(filepath)) == {} + + +def test_read_evidence_file_malformed(tmp_path): + filepath = tmp_path / "bad.evid" + filepath.write_text("2 0 1\n", encoding="utf-8") # declares 2 obs vars but only 1 pair + with pytest.raises(ValueError, match="Malformed evidence line"): + read_evidence_file(str(filepath)) + + +def test_read_model_from_string_malformed_header(): + with pytest.raises(ValueError, match="at least 4 header lines"): + 
read_model_from_string("MARKOV\n2\n") + + +def test_read_model_from_string_bad_network_type(): + with pytest.raises(ValueError, match="Unsupported UAI network type"): + read_model_from_string("UNKNOWN\n2\n2 2\n1\n1 0\n2\n0.5 0.5\n") + + +def test_read_model_from_string_card_mismatch(): + with pytest.raises(ValueError, match="Expected 2 cardinalities"): + read_model_from_string("MARKOV\n2\n2 2 2\n1\n1 0\n2\n0.5 0.5\n") + + +def test_read_model_from_string_scope_size_mismatch(): + with pytest.raises(ValueError, match="Scope size mismatch"): + read_model_from_string("MARKOV\n2\n2 2\n1\n2 0\n2\n0.5 0.5\n") + + +def test_read_model_from_string_table_size_mismatch(): + with pytest.raises(ValueError, match="Factor table size mismatch"): + read_model_from_string("MARKOV\n2\n2 2\n1\n2 0 1\n3\n0.5 0.5 0.5\n") + + +def test_read_model_file(tmp_path): + content = "MARKOV\n2\n2 2\n1\n1 0\n2\n0.6 0.4\n" + filepath = tmp_path / "test.uai" + filepath.write_text(content, encoding="utf-8") + model = read_model_file(str(filepath)) + assert model.nvars == 2 + assert len(model.factors) == 1