Commits
53 commits
b0327a2
feat: spras_revision
tristan-f-r Jul 9, 2025
8cec738
style: fmt
tristan-f-r Jul 9, 2025
5683392
test: summary
tristan-f-r Jul 10, 2025
af90ce0
docs(test_summary): mention preprocessing motivation
tristan-f-r Jul 10, 2025
6141874
test(analysis/summary): use input from /input instead
tristan-f-r Jul 10, 2025
440a2d4
docs(test/analysis): mention dual integration testing
tristan-f-r Jul 10, 2025
d9e852b
test(analysis/summary): use test/analysis provided gold standard
tristan-f-r Jul 10, 2025
abb0eb9
style: fmt
tristan-f-r Jul 10, 2025
60185fc
chore: don't repeat docs inside analysis configs
tristan-f-r Jul 10, 2025
e6bd6a0
feat: get working with cytoscape
tristan-f-r Jul 11, 2025
f9a3081
style: fmt
tristan-f-r Jul 11, 2025
77fc3b4
test: remove nondet from analysis
tristan-f-r Jul 11, 2025
0592850
fix: get input pathways at runtime
tristan-f-r Jul 11, 2025
0b6413d
Merge branch 'umain' into hash
tristan-f-r Aug 4, 2025
1817157
fix: rm run
tristan-f-r Aug 4, 2025
c077d91
Merge branch 'main' into hash
tristan-f-r Aug 14, 2025
50f2195
fix: correct for pydantic
tristan-f-r Aug 14, 2025
d3a088b
fix: attach spras revision inside gs_values
tristan-f-r Aug 14, 2025
8e3b898
chore: drop re import
tristan-f-r Aug 14, 2025
1ada504
Merge branch 'main' into hash
tristan-f-r Aug 27, 2025
34a40ad
fix: correct tests
tristan-f-r Aug 27, 2025
5d2c6d0
Merge branch 'main' into hash
tristan-f-r Sep 9, 2025
ef15781
Merge branch 'main' into hash
tristan-f-r Sep 24, 2025
8d5019b
fix: correct Snakefile
tristan-f-r Sep 24, 2025
9949572
fix: use correct gs variable
tristan-f-r Sep 25, 2025
6ec4f62
refactor: separate statistic computation
tristan-f-r Oct 10, 2025
9987189
fix: correct tuple assumption
tristan-f-r Oct 10, 2025
25eef5e
fix: stably use graph statistic values
tristan-f-r Oct 10, 2025
3cd25e8
Merge branch 'main' into hash
tristan-f-r Oct 24, 2025
0965a68
test: correct config
tristan-f-r Oct 25, 2025
a169505
fix: correct name again
tristan-f-r Oct 25, 2025
cb373c1
style: fmt
tristan-f-r Oct 30, 2025
47a9e26
Merge branch 'main' into lazy-stats
tristan-f-r Oct 30, 2025
898d568
style: specify zip strict
tristan-f-r Oct 30, 2025
c675ece
fix: make undirected for determining number of connected components
tristan-f-r Nov 6, 2025
eec09f2
Merge branch 'main' into hash
tristan-f-r Jan 10, 2026
a8d71bd
test: fix files
tristan-f-r Jan 10, 2026
3c81d05
Merge branch 'main' into lazy-stats
tristan-f-r Jan 13, 2026
1ca730e
feat: snakemake-based summary generation
tristan-f-r Jan 13, 2026
d67186d
fix(Snakefile): use parse_output for edgelist parsing
tristan-f-r Jan 13, 2026
fd483c3
fix: parse edgelist with rank, embed header skip inside from_edgelist
tristan-f-r Jan 13, 2026
fd5046f
style: fmt
tristan-f-r Jan 13, 2026
79cf748
chore: mention statistics_files param
tristan-f-r Jan 13, 2026
e12fc75
apply suggestions
tristan-f-r Jan 17, 2026
977bf5a
clean, fix: strip project_directory
tristan-f-r Jan 17, 2026
8500bcb
fix: correct equality on not SPRAS pyproject.toml
tristan-f-r Jan 17, 2026
112db39
chore: grammar
tristan-f-r Jan 17, 2026
c7262ed
chore: move attach_spras_revision out of Snakefile
tristan-f-r Jan 18, 2026
f69a0f3
Merge branch 'main' into hash
tristan-f-r Jan 31, 2026
72e30bf
fix: properly resolve merge conflict
tristan-f-r Jan 31, 2026
c71b652
fix: undo mistaken merge conflict
tristan-f-r Jan 31, 2026
6b941e0
chore: drop unnecessary self.datasets initialization
tristan-f-r Jan 31, 2026
339d915
Merge branch 'hash' into lazy-stats
tristan-f-r Jan 31, 2026
27 changes: 21 additions & 6 deletions Snakefile
@@ -2,10 +2,11 @@ import os
from spras import runner
import shutil
import yaml
from spras.dataset import Dataset
from spras.evaluation import Evaluation
from spras.analysis import ml, summary, cytoscape
import spras.config.config as _config
from spras.dataset import Dataset
from spras.evaluation import Evaluation
from spras.statistics import from_output_pathway, statistics_computation, statistics_options

# Snakemake updated the behavior in the 6.5.0 release https://github.com/snakemake/snakemake/pull/1037
# and using the wrong separator prevents Snakemake from matching filenames to the rules that can produce them
@@ -34,7 +35,6 @@ def get_dataset(_datasets, label):
algorithms = list(algorithm_params)
algorithms_with_params = [f'{algorithm}-params-{params_hash}' for algorithm, param_combos in algorithm_params.items() for params_hash in param_combos.keys()]
dataset_labels = list(_config.config.datasets.keys())

dataset_gold_standard_node_pairs = [f"{dataset}-{gs['label']}" for gs in _config.config.gold_standards.values() if gs['node_files'] for dataset in gs['dataset_labels']]
dataset_gold_standard_edge_pairs = [f"{dataset}-{gs['label']}" for gs in _config.config.gold_standards.values() if gs['edge_files'] for dataset in gs['dataset_labels']]

@@ -282,7 +282,7 @@ rule reconstruct:
# Original pathway reconstruction output to universal output
# Use PRRunner as a wrapper to call the algorithm-specific parse_output
rule parse_output:
input:
input:
raw_file = SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'raw-pathway.txt']),
dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle'])
output: standardized_file = SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'pathway.txt'])
@@ -310,18 +310,33 @@ rule viz_cytoscape:
run:
cytoscape.run_cytoscape(input.pathways, output.session, container_settings)

for keys, compute in statistics_computation.items():
pythonic_name = 'generate_' + '_and_'.join([key.lower().replace(' ', '_') for key in keys])
rule:
name: pythonic_name
input: pathway_file = rules.parse_output.output.standardized_file
output: [SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'statistics', f'{key}.txt']) for key in keys]
run:
(Path(input.pathway_file).parent / 'statistics').mkdir(exist_ok=True)
graph = from_output_pathway(input.pathway_file)
for computed, output_file in zip(compute(graph), output, strict=True):
Path(output_file).write_text(str(computed))

# Write a single summary table for all pathways for each dataset
rule summary_table:
input:
# Collect all pathways generated for the dataset
pathways = expand('{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params),
dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle'])
dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle']),
# Collect every computed statistic file for each pathway in the dataset
statistics = expand(
'{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}statistics{sep}{statistic}.txt',
out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, statistic=statistics_options)
output: summary_table = SEP.join([out_dir, '{dataset}-pathway-summary.txt'])
run:
# Load the node table from the pickled dataset file
node_table = Dataset.from_file(input.dataset_file).node_table
summary_df = summary.summarize_networks(input.pathways, node_table, algorithm_params, algorithms_with_params)
summary_df = summary.summarize_networks(input.pathways, node_table, algorithm_params, algorithms_with_params, input.statistics)
summary_df.to_csv(output.summary_table, sep='\t', index=False)

# Cluster the output pathways for each dataset
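The generated rules above write one file per statistic name, with each key tuple in `statistics_computation` backing a single rule. The same flow can be exercised outside Snakemake; here is a minimal sketch assuming only the `spras.statistics` API introduced later in this diff (the toy graph and local `statistics/` directory are hypothetical):

```python
from pathlib import Path

import networkx as nx

from spras.statistics import statistics_computation

# Toy stand-in for a parsed pathway graph
graph = nx.Graph()
graph.add_edge('A', 'B', Rank=1, Direction='U')
graph.add_edge('B', 'C', Rank=2, Direction='U')

out_dir = Path('statistics')
out_dir.mkdir(exist_ok=True)
for keys, compute in statistics_computation.items():
    # Each callable returns exactly one value per name in its key tuple
    for name, value in zip(keys, compute(graph), strict=True):
        (out_dir / f'{name}.txt').write_text(str(value))
```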
57 changes: 13 additions & 44 deletions spras/analysis/summary.py
@@ -1,13 +1,14 @@
import ast
from pathlib import Path
from statistics import median
from typing import Iterable

import networkx as nx
import pandas as pd

from spras.statistics import from_output_pathway


def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict],
algo_with_params: list) -> pd.DataFrame:
algo_with_params: list[str], statistics_files: list) -> pd.DataFrame:
"""
Generate a table that aggregates summary information about networks in file_paths, including which nodes are present
in node_table columns. Network directionality is ignored and all edges are treated as undirected. The order of the
@@ -17,6 +18,7 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg
@param algo_params: a nested dict mapping algorithm names to dicts that map parameter hashes to parameter
combinations.
@param algo_with_params: a list of <algorithm>-params-<params_hash> combinations
@param statistics_files: a list of files containing the computed statistic values.
@return: pandas DataFrame with summary information
"""
# Ensure that NODEID is the first column
@@ -39,52 +41,17 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg

# Iterate through each network file path
for index, file_path in enumerate(sorted(file_paths)):
with open(file_path, 'r') as f:
lines = f.readlines()[1:] # skip the header line

# directed or mixed graphs are parsed and summarized as an undirected graph
nw = nx.read_edgelist(lines, data=(('weight', float), ('Direction', str)))
nw = from_output_pathway(file_path)

# Save the network name, number of nodes, number edges, and number of connected components
nw_name = str(file_path)
number_nodes = nw.number_of_nodes()
number_edges = nw.number_of_edges()
ncc = nx.number_connected_components(nw)

# Save the max/median degree, average clustering coefficient, and density
if number_nodes == 0:
max_degree = 0
median_degree = 0.0
density = 0.0
else:
degrees = [deg for _, deg in nw.degree()]
max_degree = max(degrees)
median_degree = median(degrees)
density = nx.density(nw)

cc = list(nx.connected_components(nw))
# Save the max diameter
# Use diameter only for components with ≥2 nodes (singleton components have diameter 0)
diameters = [
nx.diameter(nw.subgraph(c).copy()) if len(c) > 1 else 0
for c in cc
]
max_diameter = max(diameters, default=0)

# Save the average path lengths
# Compute average shortest path length only for components with ≥2 nodes (undefined for singletons, set to 0.0)
avg_path_lengths = [
nx.average_shortest_path_length(nw.subgraph(c).copy()) if len(c) > 1 else 0.0
for c in cc
]

if len(avg_path_lengths) != 0:
avg_path_len = sum(avg_path_lengths) / len(avg_path_lengths)
else:
avg_path_len = 0.0

# We use literal_eval here to coerce each value to an int or a float, depending on the file contents.
graph_statistics = [ast.literal_eval(Path(file).read_text()) for file in statistics_files]

# Initialize list to store current network information
cur_nw_info = [nw_name, number_nodes, number_edges, ncc, density, max_degree, median_degree, max_diameter, avg_path_len]
cur_nw_info = [nw_name, *graph_statistics]

# Iterate through each node property and save the intersection with the current network
for node_list in nodes_by_col:
@@ -105,8 +72,10 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg
# Save the current network information to the network summary list
nw_info.append(cur_nw_info)

# Get the list of statistic names by their file names
statistics_options = [Path(file).stem for file in statistics_files]
# Prepare column names
col_names = ['Name', 'Number of nodes', 'Number of edges', 'Number of connected components', 'Density', 'Max degree', 'Median degree', 'Max diameter', 'Average path length']
col_names = ['Name', *statistics_options]
col_names.extend(nodes_by_col_labs)
col_names.append('Parameter combination')

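Each statistics file holds a single Python literal, so the `ast.literal_eval` call in `summarize_networks` recovers an `int` or `float` without committing to one numeric type up front. A small sketch of that round trip (file names hypothetical):

```python
import ast
from pathlib import Path

# Write statistics the way the generated Snakemake rules do
Path('Number of nodes.txt').write_text('5')
Path('Density.txt').write_text('0.4')

values = [ast.literal_eval(Path(name).read_text())
          for name in ('Number of nodes.txt', 'Density.txt')]
assert values == [5, 0.4]
assert [type(v) for v in values] == [int, float]
```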
74 changes: 72 additions & 2 deletions spras/config/config.py
@@ -13,7 +13,12 @@
"""

import copy as copy
import functools
import hashlib
import importlib.metadata
import itertools as it
import subprocess
import tomllib
import warnings
from pathlib import Path
from typing import Any
@@ -27,6 +32,59 @@

config = None

@functools.cache
def spras_revision() -> str:
"""
Gets the revision of the current SPRAS repository. This function is meant to be user-friendly and to warn about bad SPRAS installs.
1. If this file is inside the correct `.git` repository, we use the revision hash. This covers development in SPRAS as well as SPRAS installs via a cloned git repository.
2. If SPRAS was installed via a PyPA-compliant package manager, we use the hash of the RECORD file (https://packaging.python.org/en/latest/specifications/recording-installed-packages/#the-record-file),
which contains the hashes of all files installed with the package.
"""
clone_tip = "Make sure SPRAS is installed through the installation instructions: https://spras.readthedocs.io/en/latest/install.html."

# Check if we're inside the right git repository
try:
project_directory = subprocess.check_output(
["git", "rev-parse", "--show-toplevel"],
encoding='utf-8',
# In case the CWD is not inside the actual SPRAS directory
cwd=Path(__file__).parent.resolve()
).strip()

# We check the pyproject.toml name attribute to confirm that this is the SPRAS project. This is susceptible
# to false negatives, but we use this as a preliminary check against bad SPRAS installs.
pyproject_path = Path(project_directory, 'pyproject.toml')
try:
pyproject_toml = tomllib.loads(pyproject_path.read_text())
if "project" not in pyproject_toml or "name" not in pyproject_toml["project"]:
raise RuntimeError(f"The git top-level `{pyproject_path}` does not have the expected attributes. {clone_tip}")
if pyproject_toml["project"]["name"] != "spras":
raise RuntimeError(f"The git top-level `{pyproject_path}` is not the SPRAS pyproject.toml. {clone_tip}")

return subprocess.check_output(
["git", "rev-parse", "--short", "HEAD"],
encoding='utf-8',
cwd=project_directory
).strip()
except FileNotFoundError as err:
# pyproject.toml wasn't found during the `read_text` call
raise RuntimeError(f"The git top-level {pyproject_path} wasn't found. {clone_tip}") from err
except tomllib.TOMLDecodeError as err:
raise RuntimeError(f"The git top-level {pyproject_path} is malformed. {clone_tip}") from err
except subprocess.CalledProcessError:
try:
# `git` failed: use the truncated hash of the RECORD file in .dist-info instead.
record_path = str(importlib.metadata.distribution('spras').locate_file(f"spras-{importlib.metadata.version('spras')}.dist-info/RECORD"))
with open(record_path, 'rb', buffering=0) as f:
# Truncate to 8 characters, the length of the short git revision.
return hashlib.file_digest(f, 'sha256').hexdigest()[:8]
except importlib.metadata.PackageNotFoundError as err:
# The metadata.distribution call failed.
raise RuntimeError(f"The spras package wasn't found: {clone_tip}") from err

def attach_spras_revision(label: str) -> str:
return f"{label}_{spras_revision()}"

# This will get called in the Snakefile, instantiating the singleton with the raw config
def init_global(config_dict):
global config
@@ -117,6 +175,12 @@ def process_datasets(self, raw_config: RawConfig):
# Currently assumes all datasets have a label and the labels are unique
# When Snakemake parses the config file it loads the datasets as OrderedDicts not dicts
# Convert to dicts to simplify the yaml logging

for dataset in raw_config.datasets:
dataset.label = attach_spras_revision(dataset.label)
for gold_standard in raw_config.gold_standards:
gold_standard.label = attach_spras_revision(gold_standard.label)

for dataset in raw_config.datasets:
label = dataset.label
if label.lower() in [key.lower() for key in self.datasets.keys()]:
@@ -130,8 +194,11 @@ def process_datasets(self, raw_config: RawConfig):
dataset_labels = set(self.datasets.keys())
gold_standard_dataset_labels = {dataset_label for value in self.gold_standards.values() for dataset_label in value['dataset_labels']}
for label in gold_standard_dataset_labels:
if label not in dataset_labels:
if attach_spras_revision(label) not in dataset_labels:
raise ValueError(f"Dataset label '{label}' provided in gold standards does not exist in the existing dataset labels.")
# We attach the SPRAS revision to the individual dataset labels afterwards for a cleaner error message above.
for key, gold_standard in self.gold_standards.items():
# Materialize to a list so the labels can be iterated more than once downstream
self.gold_standards[key]["dataset_labels"] = list(map(attach_spras_revision, gold_standard["dataset_labels"]))

# Code snipped from Snakefile that may be useful for assigning default labels
# dataset_labels = [dataset.get('label', f'dataset{index}') for index, dataset in enumerate(datasets)]
@@ -187,7 +254,10 @@ def process_algorithms(self, raw_config: RawConfig):
run_dict[param] = float(value)
if isinstance(value, np.ndarray):
run_dict[param] = value.tolist()
params_hash = hash_params_sha1_base32(run_dict, self.hash_length, cls=NpHashEncoder)
# Incorporates the `spras_revision` into the hash
hash_run_dict = copy.deepcopy(run_dict)
hash_run_dict["_spras_rev"] = spras_revision()
params_hash = hash_params_sha1_base32(hash_run_dict, self.hash_length, cls=NpHashEncoder)
if params_hash in prior_params_hashes:
raise ValueError(f'Parameter hash collision detected. Increase the hash_length in the config file '
f'(current length {self.hash_length}).')
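Folding `_spras_rev` into the hashed dict means two runs with identical parameters hash differently under different SPRAS revisions, so outputs are never silently reused across versions. A rough illustration of the idea with plain `hashlib` (the real code uses `hash_params_sha1_base32` with `NpHashEncoder`, which is not reproduced here):

```python
import hashlib
import json

def toy_hash(params: dict) -> str:
    # Stand-in for hash_params_sha1_base32: stable JSON, then truncated SHA-1
    payload = json.dumps(params, sort_keys=True).encode()
    return hashlib.sha1(payload).hexdigest()[:7]

run_dict = {'k': 10}
old_hash = toy_hash(run_dict)
new_hash = toy_hash({**run_dict, '_spras_rev': 'abc1234'})
assert old_hash != new_hash  # same params, different revision

# attach_spras_revision suffixes labels the same way:
# a dataset labeled 'egfr' would become 'egfr_abc1234'
```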
70 changes: 70 additions & 0 deletions spras/statistics.py
@@ -0,0 +1,70 @@
"""
Graph statistics, used to power summary.py.

We allow for arbitrary computation of any specific statistic on some graph,
computing more than necessary if we have dependencies. See the top level
`statistics_computation` dictionary for usage.
"""

import itertools
from statistics import median
from typing import Callable

import networkx as nx


def compute_degree(graph: nx.DiGraph) -> tuple[int, float]:
"""
Computes the (max, median) degree of a `graph`.
"""
# number_of_nodes is a cheap call
if graph.number_of_nodes() == 0:
return (0, 0.0)
else:
degrees = [deg for _, deg in graph.degree()]
return max(degrees), median(degrees)

def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]:
"""
Computes the (max diameter, average path length) over the connected components of `directed_graph`, treated as undirected.
"""
graph: nx.Graph = directed_graph.to_undirected()
cc = list(nx.connected_components(graph))
# Save the max diameter
# Use diameter only for components with ≥2 nodes (singleton components have diameter 0)
diameters = [
nx.diameter(graph.subgraph(c).copy()) if len(c) > 1 else 0
for c in cc
]
max_diameter = max(diameters, default=0)

# Save the average path lengths
# Compute average shortest path length only for components with ≥2 nodes (undefined for singletons, set to 0.0)
avg_path_lengths = [
nx.average_shortest_path_length(graph.subgraph(c).copy()) if len(c) > 1 else 0.0
for c in cc
]

if len(avg_path_lengths) != 0:
avg_path_len = sum(avg_path_lengths) / len(avg_path_lengths)
else:
avg_path_len = 0.0

return max_diameter, avg_path_len

# The type signature here is imprecise: ideally, each n-tuple key would map to a callable returning exactly n outputs.
statistics_computation: dict[tuple[str, ...], Callable[[nx.DiGraph], tuple[float | int, ...]]] = {
('Number of nodes',): lambda graph : (graph.number_of_nodes(),),
('Number of edges',): lambda graph : (graph.number_of_edges(),),
('Number of connected components',): lambda graph : (nx.number_connected_components(graph.to_undirected()),),
('Density',): lambda graph : (nx.density(graph),),

('Max degree', 'Median degree'): compute_degree,
('Max diameter', 'Average path length'): compute_on_cc,
}

# All of the keys inside statistics_computation, flattened.
statistics_options: list[str] = list(itertools.chain.from_iterable(statistics_computation.keys()))

def from_output_pathway(file_path) -> nx.Graph:
"""
Parses a standardized pathway file into an undirected graph, skipping the header line.
"""
with open(file_path, 'r') as f:
lines = f.readlines()[1:]

return nx.read_edgelist(lines, data=(('Rank', int), ('Direction', str)))
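Taken together, the module supports computing every statistic for a pathway in one pass and reading the results back in declaration order. A usage sketch (the pathway path is hypothetical and must follow the header-plus-edgelist `pathway.txt` format):

```python
from spras.statistics import (
    from_output_pathway,
    statistics_computation,
    statistics_options,
)

graph = from_output_pathway('output/data0-mincostflow-params-ABCDEFG/pathway.txt')

results: dict[str, float | int] = {}
for keys, compute in statistics_computation.items():
    results.update(zip(keys, compute(graph), strict=True))

# statistics_options preserves the flattened declaration order of the keys
row = [results[name] for name in statistics_options]
```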