diff --git a/README.md b/README.md index 369a681..d84ccbf 100644 --- a/README.md +++ b/README.md @@ -186,6 +186,42 @@ BPDecoderPlus/ └── belief_propagation_qec_plan.tex ``` +## PyTorch BP Module (UAI) + +This repository also includes a PyTorch implementation of belief propagation for +UAI factor graphs under `src/bpdecoderplus/pytorch_bp`. + +### Python Setup + +```bash +pip install -e . +``` + +### Quick Example + +```python +from bpdecoderplus.pytorch_bp import ( + read_model_file, + BeliefPropagation, + belief_propagate, + compute_marginals, +) + +model = read_model_file("examples/simple_model.uai") +bp = BeliefPropagation(model) +state, info = belief_propagate(bp) +print(info) +print(compute_marginals(state, bp)) +``` + +### Examples and Tests + +```bash +python examples/simple_example.py +python examples/evidence_example.py +pytest tests/test_bp_basic.py tests/test_uai_parser.py tests/test_integration.py tests/testcase.py +``` + ## Available Decoders | Decoder | Symbol | Description | diff --git a/docs/api_reference.md b/docs/api_reference.md new file mode 100644 index 0000000..157431e --- /dev/null +++ b/docs/api_reference.md @@ -0,0 +1,45 @@ +## PyTorch BP API Reference + +This reference documents the public API exported from `bpdecoderplus.pytorch_bp`. + +### UAI Parsing + +- `read_model_file(path, factor_eltype=torch.float64) -> UAIModel` + Parse a UAI `.uai` model file. + +- `read_model_from_string(content, factor_eltype=torch.float64) -> UAIModel` + Parse a UAI model from an in-memory string. + +- `read_evidence_file(path) -> Dict[int, int]` + Parse a UAI `.evid` file and return evidence as 1-based indices. + +### Data Structures + +- `Factor(vars: List[int], values: torch.Tensor)` + Container for a factor scope and its tensor. + +- `UAIModel(nvars: int, cards: List[int], factors: List[Factor])` + Holds all model metadata for BP. + +### Belief Propagation + +- `BeliefPropagation(uai_model: UAIModel)` + Builds factor graph adjacency for BP. 
+ +- `initial_state(bp: BeliefPropagation) -> BPState` + Initialize messages to uniform vectors. + +- `collect_message(bp, state, normalize=True)` + Update factor-to-variable messages in place. + +- `process_message(bp, state, normalize=True, damping=0.2)` + Update variable-to-factor messages in place. + +- `belief_propagate(bp, max_iter=100, tol=1e-6, damping=0.2, normalize=True)` + Run the full BP loop and return `(BPState, BPInfo)`. + +- `compute_marginals(state, bp) -> Dict[int, torch.Tensor]` + Compute marginal distributions after convergence. + +- `apply_evidence(bp, evidence: Dict[int, int]) -> BeliefPropagation` + Return a new BP object with evidence applied to factor tensors. diff --git a/docs/mathematical_description.md b/docs/mathematical_description.md new file mode 100644 index 0000000..d2b284b --- /dev/null +++ b/docs/mathematical_description.md @@ -0,0 +1,41 @@ +## Belief Propagation (BP) Overview + +This document summarizes the BP message-passing rules implemented in +`src/bpdecoderplus/pytorch_bp/belief_propagation.py` for discrete factor graphs. The approach +mirrors the tensor-contraction perspective used in TensorInference.jl. +See https://github.com/TensorBFS/TensorInference.jl for the Julia reference. + +### Factor Graph Notation + +- Variables are indexed by x_i with domain size d_i. +- Factors are indexed by f and connect a subset of variables. +- Each factor has a tensor (potential) phi_f defined over its variables. + +### Messages + +Factor to variable message: + +mu_{f->x}(x) = sum_{all y in ne(f), y != x} phi_f(x, y, ...) * product_{y != x} mu_{y->f}(y) + +Variable to factor message: + +mu_{x->f}(x) = product_{g in ne(x), g != f} mu_{g->x}(x) + +### Damping + +To improve stability on loopy graphs, a damping update is applied: + +mu_new = damping * mu_old + (1 - damping) * mu_candidate + +### Convergence + +We use an L1 difference threshold between consecutive factor->variable +messages to determine convergence. 
+ +### Marginals + +After convergence, variable marginals are computed as: + +b(x) = (1 / Z) * product_{f in ne(x)} mu_{f->x}(x) + +The normalization constant Z is obtained by summing the unnormalized vector. diff --git a/docs/usage_guide.md b/docs/usage_guide.md new file mode 100644 index 0000000..eef5dee --- /dev/null +++ b/docs/usage_guide.md @@ -0,0 +1,43 @@ +## PyTorch Belief Propagation Usage + +This guide shows how to parse a UAI file, run BP, and apply evidence. +The implementation follows the tensor-contraction viewpoint in +TensorInference.jl: https://github.com/TensorBFS/TensorInference.jl + +### Quick Start + +```python +from bpdecoderplus.pytorch_bp import ( + read_model_file, + BeliefPropagation, + belief_propagate, + compute_marginals, +) + +model = read_model_file("examples/simple_model.uai") +bp = BeliefPropagation(model) +state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1) +print(info) + +marginals = compute_marginals(state, bp) +print(marginals[1]) +``` + +### Evidence + +```python +from bpdecoderplus.pytorch_bp import read_model_file, read_evidence_file, apply_evidence +from bpdecoderplus.pytorch_bp import BeliefPropagation, belief_propagate, compute_marginals + +model = read_model_file("examples/simple_model.uai") +evidence = read_evidence_file("examples/simple_model.evid") +bp = apply_evidence(BeliefPropagation(model), evidence) +state, info = belief_propagate(bp) +marginals = compute_marginals(state, bp) +``` + +### Tips + +- For loopy graphs, use damping between 0.1 and 0.5. +- Normalize messages to avoid numerical underflow. +- Use float64 for consistent comparisons in tests. 
diff --git a/examples/evidence_example.py b/examples/evidence_example.py new file mode 100644 index 0000000..aea099c --- /dev/null +++ b/examples/evidence_example.py @@ -0,0 +1,24 @@ +from bpdecoderplus.pytorch_bp import ( + read_model_file, + read_evidence_file, + BeliefPropagation, + belief_propagate, + compute_marginals, + apply_evidence, +) + + +def main(): + model = read_model_file("examples/simple_model.uai") + evidence = read_evidence_file("examples/simple_model.evid") + bp = apply_evidence(BeliefPropagation(model), evidence) + state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1) + print(info) + + marginals = compute_marginals(state, bp) + for var_idx, marginal in marginals.items(): + print(f"Variable {var_idx} marginal: {marginal}") + + +if __name__ == "__main__": + main() diff --git a/examples/simple_example.py b/examples/simple_example.py new file mode 100644 index 0000000..a9c68ba --- /dev/null +++ b/examples/simple_example.py @@ -0,0 +1,21 @@ +from bpdecoderplus.pytorch_bp import ( + read_model_file, + BeliefPropagation, + belief_propagate, + compute_marginals, +) + + +def main(): + model = read_model_file("examples/simple_model.uai") + bp = BeliefPropagation(model) + state, info = belief_propagate(bp, max_iter=50, tol=1e-8, damping=0.1) + print(info) + + marginals = compute_marginals(state, bp) + for var_idx, marginal in marginals.items(): + print(f"Variable {var_idx} marginal: {marginal}") + + +if __name__ == "__main__": + main() diff --git a/examples/simple_model.evid b/examples/simple_model.evid new file mode 100644 index 0000000..722ff38 --- /dev/null +++ b/examples/simple_model.evid @@ -0,0 +1 @@ +1 0 1 diff --git a/examples/simple_model.uai b/examples/simple_model.uai new file mode 100644 index 0000000..694d568 --- /dev/null +++ b/examples/simple_model.uai @@ -0,0 +1,10 @@ +MARKOV +2 +2 2 +2 +1 0 +2 0 1 +2 +0.6 0.4 +4 +0.9 0.1 0.2 0.8 diff --git a/pyproject.toml b/pyproject.toml index 59ff2bd..85c8cbe 100644 --- a/pyproject.toml 
class BeliefPropagation:
    """Factor-graph adjacency for belief propagation over a parsed UAI model.

    Attributes:
        nvars: Number of variables in the model.
        factors: Factor objects taken from the model (each exposes ``vars``
            and ``values``).
        cards: Cardinality of each variable, indexed by 0-based position.
        t2v: For each factor, the list of 1-based variable indices in its scope.
        v2t: For each variable (0-based position), the indices of the factors
            whose scope contains it, in factor order.
    """

    def __init__(self, uai_model: "UAIModel"):
        """
        Construct BP object from UAI model.

        Args:
            uai_model: Parsed UAI model

        Raises:
            ValueError: If a factor scope references a variable outside
                ``1..nvars`` or lists the same variable more than once.
        """
        self.nvars = uai_model.nvars
        self.factors = uai_model.factors
        self.cards = uai_model.cards

        # Build mapping: t2v (factor -> variables), v2t (variable -> factors)
        self.t2v = [list(factor.vars) for factor in self.factors]
        self.v2t = self._build_v2t()

    def _build_v2t(self) -> List[List[int]]:
        """Build variable-to-factors mapping, validating every factor scope."""
        v2t: List[List[int]] = [[] for _ in range(self.nvars)]
        for factor_idx, scope in enumerate(self.t2v):
            # A repeated variable would register the factor twice for that
            # variable, and the slot lookup in collect_message could no longer
            # tell the two entries apart.
            if len(set(scope)) != len(scope):
                raise ValueError(
                    f"Factor {factor_idx} lists a variable more than once: {scope}."
                )
            for var in scope:
                # Fail here instead of silently dropping the edge: a dropped
                # edge only surfaces later as an unrelated lookup error while
                # messages are being passed.
                if not 0 < var <= self.nvars:
                    raise ValueError(
                        f"Factor {factor_idx} references variable {var}, "
                        f"expected a 1-based index in 1..{self.nvars}."
                    )
                v2t[var - 1].append(factor_idx)  # Convert to 0-based
        return v2t

    def num_tensors(self) -> int:
        """Return number of factors (tensors)."""
        return len(self.t2v)

    def num_variables(self) -> int:
        """Return number of variables."""
        return self.nvars


class BPState:
    """BP state storing messages."""

    def __init__(self, message_in: List[List[torch.Tensor]],
                 message_out: List[List[torch.Tensor]]):
        """
        Args:
            message_in: Incoming messages from factors to variables,
                indexed as ``[variable position][slot in v2t]``
            message_out: Outgoing messages from variables to factors,
                indexed the same way
        """
        self.message_in = message_in
        self.message_out = message_out


class BPInfo:
    """BP convergence information."""

    def __init__(self, converged: bool, iterations: int):
        """
        Args:
            converged: Whether the tolerance was met before ``max_iter``
            iterations: Number of iterations that were run
        """
        self.converged = converged
        self.iterations = iterations

    def __repr__(self):
        status = "converged" if self.converged else "not converged"
        return f"BPInfo({status}, iterations={self.iterations})"


def initial_state(bp: BeliefPropagation) -> BPState:
    """
    Initialize BP message state with all ones vectors.

    One float64 vector of length ``cards[v]`` is created per (variable, factor)
    edge, separately for the incoming and the outgoing direction.

    Args:
        bp: BeliefPropagation object

    Returns:
        BPState with initialized messages
    """
    message_in: List[List[torch.Tensor]] = []
    message_out: List[List[torch.Tensor]] = []

    for var_pos, neighbours in enumerate(bp.v2t):
        card = bp.cards[var_pos]
        message_in.append(
            [torch.ones(card, dtype=torch.float64) for _ in neighbours]
        )
        message_out.append(
            [torch.ones(card, dtype=torch.float64) for _ in neighbours]
        )

    return BPState(message_in, message_out)


def _compute_factor_to_var_message(
    factor_tensor: torch.Tensor,
    incoming_messages: List[torch.Tensor],
    target_var_idx: int
) -> torch.Tensor:
    """
    Compute factor to variable message using tensor contraction.

    μ_{f→x}(x) = Σ_{other vars} [φ_f(...) * Π_{y≠x} μ_{y→f}]

    Args:
        factor_tensor: Factor tensor with shape (d1, d2, ..., dn)
        incoming_messages: List of incoming messages, one for each variable in factor
        target_var_idx: Index of target variable (0-based) in factor's variable list

    Returns:
        Output message vector with shape (d_target,)
    """
    ndims = len(incoming_messages)

    # A unary factor has nothing to sum out: its message is the factor itself.
    if ndims == 1:
        return factor_tensor.clone()

    other_dims = [dim for dim in range(ndims) if dim != target_var_idx]

    # Weight the factor by every non-target message, broadcasting each vector
    # along its own axis.
    result = factor_tensor.clone()
    for dim in other_dims:
        msg = incoming_messages[dim]
        shape = [1] * ndims
        shape[dim] = msg.shape[0]
        result = result * msg.view(*shape)

    # Sum over all dimensions except target
    if other_dims:
        result = result.sum(dim=tuple(other_dims))
    return result


def collect_message(bp: BeliefPropagation, state: BPState, normalize: bool = True) -> None:
    """
    Collect and update messages from factors to variables.

    μ_{f→x}(x) = Σ[φ_f(...) * Π μ_{y→f}]

    Args:
        bp: BeliefPropagation object
        state: BPState (``message_in`` entries are replaced in place)
        normalize: Whether to normalize messages; a message whose sum is not
            positive is stored unnormalized
    """
    for factor_idx, factor in enumerate(bp.factors):
        # Slot of this factor inside each neighbouring variable's message list.
        slots = [bp.v2t[var - 1].index(factor_idx) for var in factor.vars]
        incoming_messages = [
            state.message_out[var - 1][slot]
            for var, slot in zip(factor.vars, slots)
        ]

        # Compute outgoing message to each variable
        for var_pos, (var, slot) in enumerate(zip(factor.vars, slots)):
            outgoing_msg = _compute_factor_to_var_message(
                factor.values,
                incoming_messages,
                var_pos
            )

            if normalize:
                msg_sum = outgoing_msg.sum()
                if msg_sum > 0:
                    outgoing_msg = outgoing_msg / msg_sum

            state.message_in[var - 1][slot] = outgoing_msg


def process_message(
    bp: BeliefPropagation,
    state: BPState,
    normalize: bool = True,
    damping: float = 0.2,
) -> None:
    r"""
    Process and update messages from variables to factors.

    μ_{x→f}(x) = Π_{g∈ne(x)\setminus f} μ_{g→x}(x)

    The stored message becomes
    ``damping * old_message + (1 - damping) * new_message``.

    Args:
        bp: BeliefPropagation object
        state: BPState (``message_out`` entries are replaced in place)
        normalize: Whether to normalize messages; a product whose sum is not
            positive is used unnormalized
        damping: Damping factor for message update
    """
    for var_pos in range(bp.nvars):
        inbox = state.message_in[var_pos]
        outbox = state.message_out[var_pos]

        for slot in range(len(bp.v2t[var_pos])):
            # Product of all incoming messages except the one from this factor
            product = torch.ones(bp.cards[var_pos], dtype=torch.float64)
            for other_slot, message in enumerate(inbox):
                if other_slot != slot:
                    product = product * message

            if normalize:
                msg_sum = product.sum()
                if msg_sum > 0:
                    product = product / msg_sum

            # Damping update
            outbox[slot] = damping * outbox[slot] + (1 - damping) * product


def _check_convergence(message_new: List[List[torch.Tensor]],
                       message_old: List[List[torch.Tensor]],
                       tol: float = 1e-6) -> bool:
    """
    Check if messages have converged.

    Args:
        message_new: Current iteration messages
        message_old: Previous iteration messages
        tol: Convergence tolerance on the L1 distance of each message pair

    Returns:
        True if converged, False otherwise
    """
    for var_msgs_new, var_msgs_old in zip(message_new, message_old):
        for msg_new, msg_old in zip(var_msgs_new, var_msgs_old):
            # Compute L1 distance
            if torch.abs(msg_new - msg_old).sum() > tol:
                return False
    return True


def belief_propagate(bp: BeliefPropagation,
                     max_iter: int = 100,
                     tol: float = 1e-6,
                     damping: float = 0.2,
                     normalize: bool = True) -> Tuple[BPState, BPInfo]:
    """
    Run Belief Propagation algorithm main loop.

    Each iteration updates factor→variable messages, then variable→factor
    messages, and stops as soon as no factor→variable message moved by more
    than ``tol`` (L1 distance) relative to the previous iteration.

    Args:
        bp: BeliefPropagation object
        max_iter: Maximum number of iterations
        tol: Convergence tolerance
        damping: Damping factor
        normalize: Whether to normalize messages

    Returns:
        Tuple of (BPState, BPInfo)
    """
    state = initial_state(bp)

    for iteration in range(max_iter):
        # Snapshot the previous factor->variable messages for the convergence check
        prev_message_in = [
            [msg.clone() for msg in var_msgs] for var_msgs in state.message_in
        ]

        # Update messages
        collect_message(bp, state, normalize=normalize)
        process_message(bp, state, normalize=normalize, damping=damping)

        # Check convergence
        if _check_convergence(state.message_in, prev_message_in, tol=tol):
            return state, BPInfo(converged=True, iterations=iteration + 1)

    return state, BPInfo(converged=False, iterations=max_iter)


def compute_marginals(state: BPState, bp: BeliefPropagation) -> Dict[int, torch.Tensor]:
    """
    Compute marginal probabilities from converged BP state.

    b(x) = (1/Z) * Π_{f∈ne(x)} μ_{f→x}(x)

    Args:
        state: Converged BPState
        bp: BeliefPropagation object

    Returns:
        Dictionary mapping variable index (1-based) to marginal probability
        distribution. A variable whose message product sums to zero (or less)
        is returned unnormalized.
    """
    marginals: Dict[int, torch.Tensor] = {}

    for var_pos in range(bp.nvars):
        # Product of all incoming messages
        product = torch.ones(bp.cards[var_pos], dtype=torch.float64)
        for msg in state.message_in[var_pos]:
            product = product * msg

        # Normalize to get probability distribution
        msg_sum = product.sum()
        if msg_sum > 0:
            product = product / msg_sum

        marginals[var_pos + 1] = product  # Convert to 1-based indexing

    return marginals


def apply_evidence(bp: BeliefPropagation, evidence: Dict[int, int]) -> BeliefPropagation:
    """
    Apply evidence constraints to BP object.

    Every factor tensor is copied and the entries that disagree with the
    observed values are set to zero. An observed value outside the variable's
    range zeroes the whole factor. The input ``bp`` is not modified.

    Args:
        bp: Original BeliefPropagation object
        evidence: Dictionary mapping variable index (1-based) to value (0-based)

    Returns:
        New BeliefPropagation object with evidence applied
    """
    new_factors = []

    for factor in bp.factors:
        factor_tensor = factor.values.clone()

        for var_pos, var in enumerate(factor.vars):
            if var not in evidence:
                continue

            evid_value = evidence[var]
            dim_size = factor_tensor.shape[var_pos]
            if not 0 <= evid_value < dim_size:
                # No assignment is consistent with this observation.
                factor_tensor = torch.zeros_like(factor_tensor)
                break

            all_indices = torch.arange(dim_size, device=factor_tensor.device)
            zero_indices = all_indices[all_indices != evid_value]
            if zero_indices.numel() > 0:
                factor_tensor = factor_tensor.index_fill(var_pos, zero_indices, 0)

        new_factors.append(Factor(factor.vars, factor_tensor))

    # Rebuild through UAIModel so the adjacency is derived from the new factors
    new_uai = UAIModel(bp.nvars, bp.cards, new_factors)

    return BeliefPropagation(new_uai)
class Factor:
    """Factor class representing a factor in the factor graph."""

    def __init__(self, vars: List[int], values: torch.Tensor):
        """
        Args:
            vars: List of variable indices (1-based); stored as a tuple
            values: Tensor of factor values with shape matching variable cardinalities
        """
        self.vars = tuple(vars)
        self.values = values

    def __repr__(self):
        return f"Factor(vars={self.vars}, shape={self.values.shape})"


class UAIModel:
    """UAI model class containing variables, cardinalities, and factors."""

    def __init__(self, nvars: int, cards: List[int], factors: List[Factor]):
        """
        Args:
            nvars: Number of variables
            cards: List of cardinalities for each variable
            factors: List of factors
        """
        self.nvars = nvars
        self.cards = cards
        self.factors = factors

    def __repr__(self):
        return f"UAIModel(nvars={self.nvars}, nfactors={len(self.factors)})"


def read_model_file(filepath: str, factor_eltype=torch.float64) -> UAIModel:
    """
    Parse UAI format model file.

    Args:
        filepath: Path to .uai file
        factor_eltype: Data type for factor values (default: torch.float64)

    Returns:
        UAIModel object

    Raises:
        ValueError: If the file content is not a well-formed UAI model
            (see ``read_model_from_string``)
    """
    with open(filepath, 'r') as f:
        content = f.read()
    return read_model_from_string(content, factor_eltype=factor_eltype)


def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIModel:
    """
    Parse UAI model from string.

    Blank lines are ignored. The preamble is read line by line: network type,
    variable count, cardinalities, factor count, then one scope per line. The
    factor tables that follow are read as a flat whitespace-separated token
    stream, so a table may span any number of lines.

    Args:
        content: UAI file content as string
        factor_eltype: Data type for factor values

    Returns:
        UAIModel object; factor scopes are converted to 1-based indices

    Raises:
        ValueError: If the network type is unsupported, the preamble or a
            table is truncated, a count does not match the data that follows,
            a scope references an unknown variable, or a token is not numeric
    """
    lines = [line.strip() for line in content.split('\n') if line.strip()]
    if not lines:
        raise ValueError("Empty UAI model: no network type found.")

    # Parse header
    network_type = lines[0]  # MARKOV or BAYES
    if network_type not in ("MARKOV", "BAYES"):
        raise ValueError(
            f"Unsupported UAI network type: {network_type!r}. Expected 'MARKOV' or 'BAYES'."
        )
    if len(lines) < 4:
        raise ValueError(
            "Incomplete UAI preamble: expected network type, variable count, "
            "cardinalities and factor count."
        )

    nvars = int(lines[1])
    cards = [int(x) for x in lines[2].split()]
    if len(cards) != nvars:
        raise ValueError(
            f"Cardinality count mismatch: declared {nvars} variables, "
            f"found {len(cards)} cardinalities."
        )
    ntables = int(lines[3])

    # Parse factor scopes
    scopes = []
    for i in range(ntables):
        if 4 + i >= len(lines):
            raise ValueError(
                f"Unexpected end of UAI data: expected {ntables} factor scopes, "
                f"found {i}."
            )
        parts = lines[4 + i].split()
        scope_size = int(parts[0])
        if len(parts) - 1 != scope_size:
            raise ValueError(
                f"Scope size mismatch on line {4 + i}: "
                f"declared {scope_size}, found {len(parts) - 1} variables."
            )
        scope = [int(x) + 1 for x in parts[1:]]  # Convert to 1-based
        for var in scope:
            # Without this check a negative index would silently wrap around
            # when the cardinality is looked up below.
            if not 1 <= var <= nvars:
                raise ValueError(
                    f"Scope on line {4 + i} references variable {var - 1}, "
                    f"expected a 0-based index below {nvars}."
                )
        scopes.append(scope)

    # Parse factor tables
    tokens: List[str] = []
    for line in lines[4 + ntables:]:
        tokens.extend(line.split())
    cursor = 0

    factors: List[Factor] = []
    for scope in scopes:
        if cursor >= len(tokens):
            raise ValueError("Unexpected end of UAI factor table data.")
        nelements = int(tokens[cursor])
        cursor += 1

        # Shape follows the cardinalities in original scope order
        shape = tuple(cards[v - 1] for v in scope)
        expected = 1
        for dim in shape:
            expected *= dim
        if nelements != expected:
            raise ValueError(
                f"Factor table size mismatch for scope {scope}: "
                f"declared {nelements} entries, cardinalities require {expected}."
            )

        raw_values = tokens[cursor:cursor + nelements]
        if len(raw_values) != nelements:
            raise ValueError("Unexpected end of UAI factor table data.")
        cursor += nelements

        values = torch.tensor(
            [float(x) for x in raw_values],
            dtype=factor_eltype
        ).reshape(shape)
        factors.append(Factor(scope, values))

    return UAIModel(nvars, cards, factors)


def read_evidence_file(filepath: str) -> Dict[int, int]:
    """
    Parse evidence file (.evid format).

    Only the last non-blank line is used. It must read
    ``<count> <var> <value> [<var> <value> ...]`` with 0-based variable
    indices.

    Args:
        filepath: Path to .evid file; an empty path yields no evidence

    Returns:
        Dictionary mapping variable index (1-based) to observed value (0-based).
        Empty when the path is empty or the file has no non-blank line.

    Raises:
        ValueError: If the line holds fewer (variable, value) pairs than it
            declares, or a token is not an integer
    """
    if not filepath:
        return {}

    with open(filepath, 'r') as f:
        lines = [line.strip() for line in f]

    # Skip trailing blank lines so a file ending in an empty line still parses.
    lines = [line for line in lines if line]
    if not lines:
        return {}

    # Parse last line
    parts = [int(x) for x in lines[-1].split()]

    nobsvars = parts[0]
    if len(parts) < 1 + 2 * nobsvars:
        raise ValueError(
            f"Evidence line declares {nobsvars} observations but holds only "
            f"{(len(parts) - 1) // 2} complete (variable, value) pairs."
        )

    evidence = {}
    for i in range(nobsvars):
        var_idx = parts[1 + 2 * i] + 1  # Convert to 1-based
        evidence[var_idx] = parts[2 + 2 * i]

    return evidence
class TestIntegration(unittest.TestCase):
    """End-to-end smoke tests that run BP on the bundled example files."""

    @staticmethod
    def _example_path(filename):
        """Return the absolute path of ``filename`` inside the examples/ directory.

        The path is derived from this test file's location (tests/ sits next
        to examples/), so the tests do not depend on the current working
        directory.
        """
        import os  # local import: this module's header only imports unittest

        repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(repo_root, "examples", filename)

    def test_example_file_runs(self):
        model = read_model_file(self._example_path("simple_model.uai"))
        bp = BeliefPropagation(model)
        state, info = belief_propagate(bp, max_iter=30, tol=1e-8)
        self.assertGreater(info.iterations, 0)
        marginals = compute_marginals(state, bp)
        self.assertEqual(set(marginals.keys()), {1, 2})

    def test_example_with_evidence(self):
        model = read_model_file(self._example_path("simple_model.uai"))
        evidence = read_evidence_file(self._example_path("simple_model.evid"))
        bp = apply_evidence(BeliefPropagation(model), evidence)
        state, info = belief_propagate(bp, max_iter=30, tol=1e-8)
        self.assertGreater(info.iterations, 0)
        marginals = compute_marginals(state, bp)
        self.assertEqual(set(marginals.keys()), {1, 2})
def exact_marginals(model, evidence=None):
    """Compute exact single-variable marginals by brute-force enumeration.

    Every joint assignment is visited once; its weight is the product of all
    factor entries it selects, and that weight is added to the running total
    and to the bucket of each variable's value in the assignment.
    Assignments that contradict the evidence contribute nothing.

    Args:
        model: Object exposing ``cards`` (cardinality per variable) and
            ``factors`` (each with 1-based ``vars`` and an indexable ``values``).
        evidence: Optional mapping from 1-based variable index to observed
            0-based value.

    Returns:
        Dict mapping 1-based variable index to a float64 tensor of marginal
        probabilities. All entries are 0.0 when the total weight is not positive.
    """
    evidence = evidence or {}

    # masses[v][x] accumulates the unnormalized weight of assignments with
    # variable v equal to x. Streaming this way avoids storing every
    # assignment and rescanning the full list per (variable, value).
    masses = [[0.0] * card for card in model.cards]
    total = 0.0

    for assignment in itertools.product(*(range(card) for card in model.cards)):
        if any(assignment[var_idx - 1] != val for var_idx, val in evidence.items()):
            continue

        weight = 1.0
        for factor in model.factors:
            idx = tuple(assignment[v - 1] for v in factor.vars)
            weight *= float(factor.values[idx])

        total += weight
        for var_pos, value in enumerate(assignment):
            masses[var_pos][value] += weight

    marginals = {}
    for var_pos, var_masses in enumerate(masses):
        values = [mass / total if total > 0 else 0.0 for mass in var_masses]
        marginals[var_pos + 1] = torch.tensor(values, dtype=torch.float64)
    return marginals
class TestBPAdditionalCases(unittest.TestCase):
    """Extra BP scenarios: unary factors, chains, normalization, zeros, evidence."""

    @staticmethod
    def _build(*uai_lines):
        """Parse the given UAI lines and return ``(model, bp)``."""
        parsed = read_model_from_string("\n".join(uai_lines))
        return parsed, BeliefPropagation(parsed)

    @staticmethod
    def _single_sweep(graph):
        """Run one undamped, normalized collect/process round from the initial state."""
        messages = initial_state(graph)
        collect_message(graph, messages, normalize=True)
        process_message(graph, messages, normalize=True, damping=0.0)
        return messages

    def test_unary_factor_marginal(self):
        # A single unary factor: the marginal is the factor table itself.
        _, graph = self._build(
            "MARKOV",
            "1",
            "3",
            "1",
            "1 0",
            "3",
            "0.2 0.3 0.5",
        )
        messages, report = belief_propagate(graph, max_iter=20, tol=1e-10, damping=0.0)
        self.assertTrue(report.converged)
        beliefs = compute_marginals(messages, graph)
        target = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float64)
        self.assertTrue(torch.allclose(beliefs[1], target, atol=1e-6))

    def test_chain_three_vars_exact(self):
        # Three-variable chain: BP is compared against brute-force marginals.
        parsed, graph = self._build(
            "MARKOV",
            "3",
            "2 2 2",
            "2",
            "2 0 1",
            "2 1 2",
            "4",
            "0.9 0.1 0.2 0.8",
            "4",
            "0.3 0.7 0.6 0.4",
        )
        messages, report = belief_propagate(graph, max_iter=50, tol=1e-10, damping=0.0)
        self.assertTrue(report.converged)
        beliefs = compute_marginals(messages, graph)
        reference = exact_marginals(parsed)
        for key in beliefs:
            self.assertTrue(torch.allclose(beliefs[key], reference[key], atol=1e-6))

    def test_message_normalization(self):
        # After one normalized sweep every message in both directions sums to one.
        _, graph = self._build(
            "MARKOV",
            "2",
            "2 2",
            "1",
            "2 0 1",
            "4",
            "0.9 0.1 0.2 0.8",
        )
        messages = self._single_sweep(graph)
        for direction in (messages.message_in, messages.message_out):
            for per_variable in direction:
                for vector in per_variable:
                    self.assertAlmostEqual(float(vector.sum()), 1.0, places=6)

    def test_zero_message_handling(self):
        # An all-zero factor yields an all-zero message that is left unnormalized.
        _, graph = self._build(
            "MARKOV",
            "1",
            "2",
            "2",
            "1 0",
            "1 0",
            "2",
            "0.0 0.0",
            "2",
            "0.7 0.3",
        )
        messages = self._single_sweep(graph)
        self.assertAlmostEqual(float(messages.message_in[0][0].sum()), 0.0, places=6)
        self.assertAlmostEqual(float(messages.message_out[0][1].sum()), 0.0, places=6)

    def test_evidence_out_of_range_zeros_factor(self):
        # Evidence outside the variable's range wipes the factor table.
        _, graph = self._build(
            "MARKOV",
            "1",
            "2",
            "1",
            "1 0",
            "2",
            "0.4 0.6",
        )
        constrained = apply_evidence(graph, {1: 5})
        self.assertAlmostEqual(float(constrained.factors[0].values.sum()), 0.0, places=6)