Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion spras/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(self, raw_config: dict[str, Any]):
self.out_dir = parsed_raw_config.reconstruction_settings.locations.reconstruction_dir
# A dictionary to store configured datasets against which SPRAS will be run
self.datasets: dict[str, DatasetSchema] = {}
# A dictionary to store dataset categories with their associated dataset labels
self.dataset_categories: dict[str, list[str]] = {}
# A dictionary to store configured gold standard data against output of SPRAS runs
self.gold_standards = None
# The hash length SPRAS will use to identify parameter combinations.
Expand Down Expand Up @@ -119,10 +121,19 @@ def process_datasets(self, raw_config: RawConfig):
# Convert to dicts to simplify the yaml logging
for dataset in raw_config.datasets:
label = dataset.label
if label.lower() in [key.lower() for key in self.datasets.keys()]:
if label.casefold() in [key.casefold() for key in self.datasets.keys()]:
raise ValueError(f"Datasets must have unique case-insensitive labels, but the label {label} appears at least twice.")
self.datasets[label] = dataset

# Extra check for conflicting categories which we don't store, yet.
category = dataset.category
if category:
if category.casefold() in [key.casefold() for key in self.datasets.keys()]:
raise ValueError(f"Dataset categories can not appear as (case-insensitive) labels, yet category {category} appears as a label.")

category_dataset_labels = self.dataset_categories.setdefault(category, [])
category_dataset_labels.append(dataset.label)

# parse gold standard information
self.gold_standards = {gold_standard.label: dict(gold_standard) for gold_standard in raw_config.gold_standards}

Expand Down
4 changes: 3 additions & 1 deletion spras/config/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, Optional

from pydantic import AfterValidator, BaseModel, ConfigDict

Expand All @@ -19,5 +19,7 @@ class DatasetSchema(BaseModel):
edge_files: list[LoosePathLike]
other_files: list[LoosePathLike]
data_dir: LoosePathLike
category: Optional[str] = None
"The dataset category, for working with dataset collections in the configuration."

model_config = ConfigDict(extra='forbid')
3 changes: 2 additions & 1 deletion spras/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import pickle as pkl
import warnings
from typing import Union
from typing import NotRequired, Union

import pandas as pd

Expand Down Expand Up @@ -93,6 +93,7 @@ def __init__(self, dataset_params: DatasetSchema):
"""

self.label = dataset_params.label
self.category = dataset_params.category

# Get file paths from config
# TODO support multiple edge files
Expand Down
24 changes: 24 additions & 0 deletions test/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ def get_test_config():
},
"datasets": [{
"label": "alg1",
"category": "category1",
"data_dir": "fake",
"edge_files": [],
"other_files": [],
"node_files": []
}, {
"label": "alg2",
"category": "category2",
"data_dir": "faux",
"edge_files": [],
"other_files": [],
Expand Down Expand Up @@ -238,6 +240,28 @@ def test_correct_dataset_label(self):
test_config["datasets"] = [test_dict]
config.init_global(test_config) # no error should be raised

def test_correct_dataset_category(self):
test_config = get_test_config()
config.init_global(test_config)
assert config.config.dataset_categories
assert len(config.config.dataset_categories["category1"]) == 1
assert len(config.config.dataset_categories["category2"]) == 1

def test_multiple_dataset_category(self):
test_config = get_test_config()
for dataset in test_config["datasets"]:
dataset["category"] = "category1"
config.init_global(test_config)
assert config.config.dataset_categories
assert len(config.config.dataset_categories["category1"]) == 2

def test_bad_dataset_category(self):
test_config = get_test_config()
for dataset in test_config["datasets"]:
dataset["category"] = "alg2"
with pytest.raises(ValueError): # categories can not match dataset labels
config.init_global(test_config)

def test_error_gs_label(self):
test_config = get_test_config()
error_labels = ["test$", "@test'"]
Expand Down
Loading