diff --git a/spras/config/config.py b/spras/config/config.py index abfddd64..ff0a7180 100644 --- a/spras/config/config.py +++ b/spras/config/config.py @@ -53,6 +53,8 @@ def __init__(self, raw_config: dict[str, Any]): self.out_dir = parsed_raw_config.reconstruction_settings.locations.reconstruction_dir # A dictionary to store configured datasets against which SPRAS will be run self.datasets: dict[str, DatasetSchema] = {} + # A dictionary to store dataset categories with their associated dataset labels + self.dataset_categories: dict[str, list[str]] = {} # A dictionary to store configured gold standard data against output of SPRAS runs self.gold_standards = None # The hash length SPRAS will use to identify parameter combinations. @@ -119,10 +121,19 @@ def process_datasets(self, raw_config: RawConfig): # Convert to dicts to simplify the yaml logging for dataset in raw_config.datasets: label = dataset.label - if label.lower() in [key.lower() for key in self.datasets.keys()]: + if label.casefold() in [key.casefold() for key in self.datasets.keys()]: raise ValueError(f"Datasets must have unique case-insensitive labels, but the label {label} appears at least twice.") self.datasets[label] = dataset + # Extra check for conflicting categories which we don't store, yet. + category = dataset.category + if category: + if category.casefold() in [key.casefold() for key in self.datasets.keys()]: + raise ValueError(f"Dataset categories can not appear as (case-insensitive) labels, yet category {category} appears as a label.") + + category_dataset_labels = self.dataset_categories.setdefault(category, []) + category_dataset_labels.append(dataset.label) + # parse gold standard information self.gold_standards = {gold_standard.label: dict(gold_standard) for gold_standard in raw_config.gold_standards} diff --git a/spras/config/dataset.py b/spras/config/dataset.py index 9af41338..6fb03ff5 100644 --- a/spras/config/dataset.py +++ b/spras/config/dataset.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Optional from pydantic import AfterValidator, BaseModel, ConfigDict @@ -19,5 +19,7 @@ class DatasetSchema(BaseModel): edge_files: list[LoosePathLike] other_files: list[LoosePathLike] data_dir: LoosePathLike + category: Optional[str] = None + "The dataset category, for working with dataset collections in the configuration." model_config = ConfigDict(extra='forbid') diff --git a/spras/dataset.py b/spras/dataset.py index 0e8f9de1..a2c74b18 100644 --- a/spras/dataset.py +++ b/spras/dataset.py @@ -1,7 +1,7 @@ import os import pickle as pkl import warnings -from typing import Union +from typing import NotRequired, Union import pandas as pd @@ -93,6 +93,7 @@ def __init__(self, dataset_params: DatasetSchema): """ self.label = dataset_params.label + self.category = dataset_params.category # Get file paths from config # TODO support multiple edge files diff --git a/test/test_config.py b/test/test_config.py index 41551c38..1015f6ee 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -39,12 +39,14 @@ def get_test_config(): }, "datasets": [{ "label": "alg1", + "category": "category1", "data_dir": "fake", "edge_files": [], "other_files": [], "node_files": [] }, { "label": "alg2", + "category": "category2", "data_dir": "faux", "edge_files": [], "other_files": [], @@ -238,6 +240,28 @@ def test_correct_dataset_label(self): test_config["datasets"] = [test_dict] config.init_global(test_config) # no error should be raised + def test_correct_dataset_category(self): + test_config = get_test_config() + config.init_global(test_config) + assert config.config.dataset_categories + assert len(config.config.dataset_categories["category1"]) == 1 + assert len(config.config.dataset_categories["category2"]) == 1 + + def test_multiple_dataset_category(self): + test_config = get_test_config() + for dataset in test_config["datasets"]: + dataset["category"] = "category1" + config.init_global(test_config) + assert config.config.dataset_categories + assert len(config.config.dataset_categories["category1"]) == 2 + + def test_bad_dataset_category(self): + test_config = get_test_config() + for dataset in test_config["datasets"]: + dataset["category"] = "alg2" + with pytest.raises(ValueError): # categories can not match dataset labels + config.init_global(test_config) + def test_error_gs_label(self): test_config = get_test_config() error_labels = ["test$", "@test'"]