diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fbfaedf..41d3ac3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,10 +29,11 @@ jobs: # Below, we first run pytest in the `tests/` folder. Because we use a `src` # layout, this will fail if the package is not installed correctly. - - name: Test package is installable + - name: Patient-specific and installation tests run: pytest --cov=lydata --cov-config=pyproject.toml tests env: COVERAGE_FILE: .coverage.is_installable + GITHUB_TOKEN: ${{ secrets.LYCOSYSTEM_READALL }} # Now, we execute all doctests in the `src` tree. This will NOT run with # the installed code, but it doesn't matter, because we already know it is diff --git a/.gitignore b/.gitignore index 0437039..bc8e0b4 100644 --- a/.gitignore +++ b/.gitignore @@ -176,5 +176,5 @@ pyrightconfig.json # End of https://www.toptal.com/developers/gitignore/api/python **/_version.py -# VS Code +## VS Code .vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b3947a..036d024 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,85 @@ All notable changes to this project will be documented in this file. +## [0.4.0] - 2025-09-04 + +### ๐Ÿš€ Features + +- Allow custom functions via `.pass_to()` of `C` objects +- Allow importing `LyDataFrame` type from root +- Add working sorting functions for `LyDataFrame` +- Add convenience `.ly.enhance()` method +- Add pydantic patient/tumor model +- Add schema for modalities +- Add working dtype casting function +- Add `.ly.cast()` to lydata accessor +- Add function to write JSON schema to file +- Add pre-/suffixes to T/N stages in schema +- Casting, validating, & enhancing during load +- Add a `.get_tnm()` helper method +- Fail more informatively when loading. Fixes [#10]. +- Add `.ly.location` to short column access + +### ๐Ÿ› Bug Fixes + +- [**breaking**] Combine mods & lvl info using probabilities over likelihoods +- Use spec/sens < 1 in `augment` +- Make `LyDataFrame` importable +- Ensure alignment of columns during combine/augment +- Change mid-level column from `info` to `core` +- Don't override superlevel when sublevels unknown +- Join using "outer" in `.ly.enhance()` +- Avoid `None`s due to index mismatch etc. 
+- Replace instead of update augmented columns +- Augment during combine for max_llh/rank +- Use default subdivisions in `.ly.enhance()` +- Make casting safer and better +- Avoid pydantic's weird `TypeError` for `pd.NaT` +- Check central info in schema +- Call `logger.error` over `exception` +- Allow MX=-1 in schema +- Allow `None` in more patient fields +- Side may be `None` when central=`True` +- Make some fields robust to uppercase strings +- Allow loading from disk using custom paths +- Get github fetch working again + +### 💼 Other + +- Don't use `or` to check for `None` arg +- [**breaking**] remove old functions to infer/combine data +- Move `C` & `Q` to own module +- [**breaking**] Update schema for new 2nd lvl cols +- Improve final sorting of tables +- [**breaking**] Rewrite validation using new schema +- [**breaking**] Start using only pydantic schema for validation +- Update mid-level cols to new `core` +- Remove typer dependency + +### 📚 Documentation + +- Add more info to augment/combine +- Update some docstrings +- Add docstrings to JS code +- Update schema & validation docstrings +- Add new modules to sphinx + +### 🧪 Testing + +- Test new combine/augment with CLB patient 17 +- Add basic `.ly.combine()` test +- Add scripts to compare augment/combine +- Check one patient with specific issue +- Add util doctest (though unnecessary) +- Add some more patient-specific checks +- Ensure basic functionality of schemas +- Cover casting with minimal checks +- Update schema test to use `core`, too +- Add another 2025-USZ patient to test cases +- Fix small issues causing tests to fail +- Update to new, cast data +- Ensure .env is loaded during all tests + ## [0.3.3] - 2025-07-22 ### 🚀 Features @@ -301,6 +380,9 @@ Initial implementation of the lyDATA library. +[0.4.0]: https://github.com/lycosystem/lydata-package/compare/0.3.3..0.4.0 +[0.3.3]: https://github.com/lycosystem/lydata-package/compare/0.3.2..0.3.3 +[0.3.2]: https://github.com/lycosystem/lydata-package/compare/0.3.1..0.3.2 [0.3.1]: https://github.com/lycosystem/lydata-package/compare/0.3.0..0.3.1 [0.3.0]: https://github.com/lycosystem/lydata-package/compare/8ae13..0.3.0 [0.2.5]: https://github.com/lycosystem/lydata/compare/0.2.4..0.2.5 @@ -321,3 +403,4 @@ Initial implementation of the lyDATA library. [#4]: https://github.com/lycosystem/lydata/issues/4 [#13]: https://github.com/lycosystem/lydata/issues/13 [#5]: https://github.com/lycosystem/lydata-package/issues/5 +[#10]: https://github.com/lycosystem/lydata-package/issues/10 diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..69dea87 --- /dev/null +++ b/conftest.py @@ -0,0 +1,5 @@ +"""Pytest configuration and fixtures for lydata tests.""" + +from dotenv import load_dotenv + +load_dotenv() diff --git a/docs/source/augmentor.rst b/docs/source/augmentor.rst new file mode 100644 index 0000000..5333954 --- /dev/null +++ b/docs/source/augmentor.rst @@ -0,0 +1,7 @@ +.. currentmodule:: lydata.augmentor + +Enhancing and Augmenting Datasets +================================= + +.. automodule:: lydata.augmentor + :members: diff --git a/docs/source/index.md b/docs/source/index.md index 4bd9338..157408a 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -9,7 +9,10 @@ :maxdepth: 2 accessor +augmentor loader +querier +schema utils validator ::: diff --git a/docs/source/querier.rst b/docs/source/querier.rst new file mode 100644 index 0000000..1ca22f0 --- /dev/null +++ b/docs/source/querier.rst @@ -0,0 +1,7 @@ +.. 
currentmodule:: lydata.querier + +Efficient and Reusable DataFrame Queries +======================================== + +.. automodule:: lydata.querier + :members: diff --git a/docs/source/schema.rst b/docs/source/schema.rst new file mode 100644 index 0000000..b09556b --- /dev/null +++ b/docs/source/schema.rst @@ -0,0 +1,7 @@ +.. currentmodule:: lydata.schema + +Formal Definition of a Patient Record +===================================== + +.. automodule:: lydata.schema + :members: diff --git a/docs/source/validator.rst b/docs/source/validator.rst index a19dbba..7cdfa0c 100644 --- a/docs/source/validator.rst +++ b/docs/source/validator.rst @@ -1,7 +1,7 @@ .. currentmodule:: lydata.validator -Pandera Schemas to Validate Datasets -==================================== +Type Casting and Validation +=========================== .. automodule:: lydata.validator :members: diff --git a/pyproject.toml b/pyproject.toml index 3411515..6a58096 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "pandera", "pydantic", "loguru", + "roman", ] [project.urls] @@ -44,6 +45,7 @@ docs = [ tests = [ "pytest", "pytest-cov", + "python-dotenv>=1.1.1", ] dev = [ "pre-commit", @@ -67,6 +69,9 @@ exclude = ["docs"] select = ["E", "F", "W", "B", "C", "R", "U", "D", "I", "S", "T", "A", "N"] ignore = ["B028", "N816", "E712"] +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["S101"] + [tool.uv] package = true diff --git a/src/lydata/__init__.py b/src/lydata/__init__.py index 42b433e..3d8476a 100644 --- a/src/lydata/__init__.py +++ b/src/lydata/__init__.py @@ -3,13 +3,13 @@ from loguru import logger import lydata._version as _version -from lydata.accessor import C, Q +from lydata.accessor import LyDataFrame from lydata.loader import ( available_datasets, load_datasets, ) -from lydata.utils import infer_and_combine_levels -from lydata.validator import validate_datasets +from lydata.querier import C, Q +from lydata.validator import is_valid __author__ = "Roman Ludwig" __email__ = "roman.ludwig@usz.ch" @@ -17,13 +17,14 @@ __version__ = _version.__version__ __all__ = [ + "LyDataFrame", "accessor", "Q", "C", "available_datasets", "load_datasets", - "validate_datasets", - "infer_and_combine_levels", + "is_valid", ] logger.disable("lydata") +logger.remove() diff --git a/src/lydata/accessor.py b/src/lydata/accessor.py index 24fc94e..c297a94 100644 --- a/src/lydata/accessor.py +++ b/src/lydata/accessor.py @@ -1,4 +1,4 @@ -"""Module containing a custom accessor and helpers for querying lyDATA. +"""Module containing a custom accessor for interacting with lyDATA tables. Because of the special three-level header of the lyDATA tables, it is sometimes cumbersome and lengthy to access the columns. While this is certainly necessary to @@ -10,405 +10,50 @@ the above mentioned functionality. That way, accessing the age of all patients is now as easy as typing ``df.ly.age``. -Beyond that, the module implements a convenient wat to query the -:py:class:`~pandas.DataFrame`: The :py:class:`Q` object, that was inspired by Django's -``Q`` object. It allows for more readable and modular queries, which can be combined -with logical operators and reused across different DataFrames. - -The :py:class:`Q` objects can be passed to the :py:meth:`LyDataAccessor.query` and -:py:meth:`LyDataAccessor.portion` methods to filter the DataFrame or compute the -:py:class:`QueryPortion` of rows that satisfy the query. 
Alternatively, any of these -:py:class:`Q` objects have a method called :py:meth:`~Q.execute` that can be called with -a :py:class:`~pandas.DataFrame` to get a boolean mask of the rows satisfying the query. - -Further, we implement methods like :py:meth:`~LyDataAccessor.combine`, -:py:meth:`~LyDataAccessor.infer_sublevels`, and -:py:meth:`~LyDataAccessor.infer_superlevels` to compute additional columns from the -lyDATA tables. This is sometimes necessary, because not all data contains all the -possibly necessary columns. E.g., in some cohorts we do have detailed sublevel -information (i.e., IIa and IIb), while in others only the superlevel (II) is reported. -In such a case, one can now simply call ``df.ly.infer_sublevels()`` to get the -additional columns. +Beyond that, we implement methods like :py:meth:`~LyDataAccessor.query` for filtering +the DataFrame using reusable query objects (see the :py:mod:`lydata.querier` module +for more information), :py:meth:`~LyDataAccessor.stats` for computing common statistics +that we use in our `LyProX`_ web app, and :py:meth:`~LyDataAccessor.combine` for +combining diagnoses from different modalities into a single column. + +.. _LyProX: https://lyprox.org/ """ from __future__ import annotations import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass -from itertools import product from typing import Any, Literal -import numpy as np import pandas as pd import pandas.api.extensions as pd_ext +from lydata.augmentor import combine_and_augment_levels +from lydata.types import CanExecute from lydata.utils import ( ModalityConfig, + _get_all_true, + _sort_all, get_default_column_map_new, get_default_column_map_old, get_default_modalities, + replace, ) -from lydata.validator import construct_schema warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning) -def _get_all_true(df: pd.DataFrame) -> pd.Series: - """Return a mask with all entries set to ``True``.""" - return pd.Series([True] * len(df)) - - -class CombineQMixin: - """Mixin class for combining queries. - - Four operators are defined for combining queries: - - 1. ``&`` for logical AND operations. - The returned object is an :py:class:`AndQ` instance and - when executed - - returns a boolean mask where both queries are satisfied. When the right-hand - side is ``None``, the left-hand side query object is returned unchanged. - 2. ``|`` for logical OR operations. - The returned object is an :py:class:`OrQ` instance and - when executed - - returns a boolean mask where either query is satisfied. When the right-hand - side is ``None``, the left-hand side query object is returned unchanged. - 3. ``~`` for inverting a query. - The returned object is a :py:class:`NotQ` instance and - when executed - - returns a boolean mask where the query is not satisfied. - 4. ``==`` for checking if two queries are equal. - Two queries are equal if their column names, operators, and values are equal. - Note that this does not check if the queries are semantically equal, i.e., if - they would return the same result when executed. 
- """ - - def __and__(self, other: QTypes | None) -> AndQ: - """Combine two queries with a logical AND.""" - return self if other is None else AndQ(self, other) - - def __or__(self, other: QTypes | None) -> OrQ: - """Combine two queries with a logical OR.""" - return self if other is None else OrQ(self, other) - - def __invert__(self) -> NotQ: - """Negate the query.""" - return NotQ(self) - - def __eq__(self, value): - """Check if two queries are equal.""" - return ( - isinstance(value, self.__class__) - and self.colname == value.colname - and self.operator == value.operator - and self.value == value.value - ) - - -class Q(CombineQMixin): - """Combinable query object for filtering a DataFrame. - - The syntax for this object is similar to Django's ``Q`` object. It can be used to - define queries in a more readable and modular way. - - .. caution:: - - The column names are not checked upon instantiation. This is only done when the - query is executed. In fact, the :py:class:`Q` object does not even know about - the :py:class:`~pandas.DataFrame` it will be applied to in the beginning. On the - flip side, this means a query may be reused for different DataFrames. - - The ``operator`` argument may be one of the following: - - - ``'=='``: Checks if ``column`` values are equal to the ``value``. - - ``'<'``: Checks if ``column`` values are less than the ``value``. - - ``'<='``: Checks if ``column`` values are less than or equal to ``value``. - - ``'>'``: Checks if ``column`` values are greater than the ``value``. - - ``'>='``: Checks if ``column`` values are greater than or equal to ``value``. - - ``'!='``: Checks if ``column`` values are not equal to the ``value``. This is - equivalent to ``~Q(column, '==', value)``. - - ``'in'``: Checks if ``column`` values are in the list of ``value``. For this, - pandas' :py:meth:`~pandas.Series.isin` method is used. - - ``'contains'``: Checks if ``column`` values contain the string ``value``. - Here, pandas' :py:meth:`~pandas.Series.str.contains` method is used. - - .. note:: - - During initialization, a private attribute ``_column_map`` is set to the - default column map returned by :py:func:`~lydata.utils.get_default_column_map`. - This is used to convert short column names to long ones. If one feels - adventurous, they may set this attribute to a custom column map containing - additional or other column short names. This could also be achieved by - subclassing the :py:class:`Q`. However, the attribute may change in the future, - and without notice. 
- """ - - _OPERATOR_MAP: dict[str, Callable[[pd.Series, Any], pd.Series]] = { - "==": lambda series, value: series == value, - "<": lambda series, value: series < value, - "<=": lambda series, value: series <= value, - ">": lambda series, value: series > value, - ">=": lambda series, value: series >= value, - "!=": lambda series, value: series != value, # same as ~Q("col", "==", value) - "in": lambda series, value: series.isin(value), # value is a list - "contains": lambda series, value: series.str.contains(value), # value is a str - } - - def __init__( - self, - column: str, - operator: Literal["==", "<", "<=", ">", ">=", "!=", "in", "contains"], - value: Any, - ) -> None: - """Create query object that can compare a ``column`` with a ``value``.""" - self.colname = column - self.operator = operator - self.value = value - - def __repr__(self) -> str: - """Return a string representation of the query.""" - return f"Q({self.colname!r}, {self.operator!r}, {self.value!r})" - - def execute(self, df: pd.DataFrame) -> pd.Series: - """Return a boolean mask where the query is satisfied for ``df``. - - >>> df = pd.DataFrame({'col1': [1, 2, 3], 'col2': ['foo', 'bar', 'baz']}) - >>> Q('col1', '<=', 2).execute(df) - 0 True - 1 True - 2 False - Name: col1, dtype: bool - >>> Q('col2', 'contains', 'ba').execute(df) - 0 False - 1 True - 2 True - Name: col2, dtype: bool - """ - column = df.ly[self.colname] - - if callable(self.value): - return self.value(column) - - return self._OPERATOR_MAP[self.operator](column, self.value) - - -class AndQ(CombineQMixin): - """Query object for combining two queries with a logical AND. - - >>> df = pd.DataFrame({'col1': [1, 2, 3], 'col2': ['foo', 'bar', 'baz']}) - >>> q1 = Q('col1', '!=', 3) - >>> q2 = Q('col2', 'contains', 'ba') - >>> and_q = q1 & q2 - >>> print(and_q) - (Q('col1', '!=', 3) & Q('col2', 'contains', 'ba')) - >>> isinstance(and_q, AndQ) - True - >>> and_q.execute(df) - 0 False - 1 True - 2 False - dtype: bool - >>> all((q1 & None).execute(df) == q1.execute(df)) - True - """ - - def __init__(self, q1: QTypes, q2: QTypes) -> None: - """Combine two queries with a logical AND.""" - self.q1 = q1 - self.q2 = q2 - - def __repr__(self) -> str: - """Return a string representation of the query.""" - return f"({self.q1!r} & {self.q2!r})" - - def execute(self, df: pd.DataFrame) -> pd.Series: - """Return a boolean mask where both queries are satisfied.""" - return self.q1.execute(df) & self.q2.execute(df) - - -class OrQ(CombineQMixin): - """Query object for combining two queries with a logical OR. - - >>> df = pd.DataFrame({'col1': [1, 2, 3]}) - >>> q1 = Q('col1', '==', 1) - >>> q2 = Q('col1', '==', 3) - >>> or_q = q1 | q2 - >>> print(or_q) - (Q('col1', '==', 1) | Q('col1', '==', 3)) - >>> isinstance(or_q, OrQ) - True - >>> or_q.execute(df) - 0 True - 1 False - 2 True - Name: col1, dtype: bool - >>> all((q1 | None).execute(df) == q1.execute(df)) - True - """ - - def __init__(self, q1: QTypes, q2: QTypes) -> None: - """Combine two queries with a logical OR.""" - self.q1 = q1 - self.q2 = q2 - - def __repr__(self) -> str: - """Return a string representation of the query.""" - return f"({self.q1!r} | {self.q2!r})" - - def execute(self, df: pd.DataFrame) -> pd.Series: - """Return a boolean mask where either query is satisfied.""" - return self.q1.execute(df) | self.q2.execute(df) - - -class NotQ(CombineQMixin): - """Query object for negating a query. 
- - >>> df = pd.DataFrame({'col1': [1, 2, 3]}) - >>> q = Q('col1', '==', 2) - >>> not_q = ~q - >>> print(not_q) - ~Q('col1', '==', 2) - >>> isinstance(not_q, NotQ) - True - >>> not_q.execute(df) - 0 True - 1 False - 2 True - Name: col1, dtype: bool - >>> print(~(Q('col1', '==', 2) & Q('col1', '!=', 3))) - ~(Q('col1', '==', 2) & Q('col1', '!=', 3)) - """ - - def __init__(self, q: QTypes) -> None: - """Negate the given query ``q``.""" - self.q = q - - def __repr__(self) -> str: - """Return a string representation of the query.""" - return f"~{self.q!r}" - - def execute(self, df: pd.DataFrame) -> pd.Series: - """Return a boolean mask where the query is not satisfied.""" - return ~self.q.execute(df) - - -class NoneQ(CombineQMixin): - """Query object that always returns the entire DataFrame. Useful as default.""" - - def __repr__(self) -> str: - """Return a string representation of the query.""" - return "NoneQ()" - - def execute(self, df: pd.DataFrame) -> pd.Series: - """Return a boolean mask with all entries set to ``True``.""" - return _get_all_true(df) - - -QTypes = Q | AndQ | OrQ | NotQ | None -"""Type for a query object or a combination of query objects.""" - - -class C: - """Wraps a column name and produces a :py:class:`Q` object upon comparison. +AggFuncType = dict[str | tuple[str, str, str], Callable[[pd.Series], pd.Series]] - This is basically a shorthand for creating a :py:class:`Q` object that avoids - writing the operator and value in quotes. Thus, it may be more readable and allows - IDEs to provide better autocompletion. - .. caution:: +@dataclass # we use a dataclass over pydantic, because it allows positional arguments +class QueryPortion: + """Dataclass for storing the portion of a query. - Just like for the :py:class:`Q` object, it is not checked upon instantiation - whether the column name is valid. This is only done when the query is executed. + An instance of this is returned by the :py:meth:`LyDataAccessor.portion` method. """ - def __init__(self, *column: str) -> None: - """Create a column object for comparison. - - For querying multi-level columns, both the syntax ``C('col1', 'col2')`` and - ``C(('col1', 'col2'))`` is valid. - - >>> (C('col1', 'col2') == 1) == (C(('col1', 'col2')) == 1) - True - """ - self.column = column[0] if len(column) == 1 else column - - def __repr__(self) -> str: - """Return a string representation of the column object. - - >>> repr(C('foo')) - "C('foo')" - >>> repr(C('foo', 'bar')) - "C(('foo', 'bar'))" - """ - return f"C({self.column!r})" - - def __eq__(self, value: Any) -> Q: - """Create a query object for comparing equality. - - >>> C('foo') == 'bar' - Q('foo', '==', 'bar') - """ - return Q(self.column, "==", value) - - def __lt__(self, value: Any) -> Q: - """Create a query object for comparing less than. - - >>> C('foo') < 42 - Q('foo', '<', 42) - """ - return Q(self.column, "<", value) - - def __le__(self, value: Any) -> Q: - """Create a query object for comparing less than or equal. - - >>> C('foo') <= 42 - Q('foo', '<=', 42) - """ - return Q(self.column, "<=", value) - - def __gt__(self, value: Any) -> Q: - """Create a query object for comparing greater than. - - >>> C('foo') > 42 - Q('foo', '>', 42) - """ - return Q(self.column, ">", value) - - def __ge__(self, value: Any) -> Q: - """Create a query object for comparing greater than or equal. - - >>> C('foo') >= 42 - Q('foo', '>=', 42) - """ - return Q(self.column, ">=", value) - - def __ne__(self, value: Any) -> Q: - """Create a query object for comparing inequality. 
- - >>> C('foo') != 'bar' - Q('foo', '!=', 'bar') - """ - return Q(self.column, "!=", value) - - def isin(self, value: list[Any]) -> Q: - """Create a query object for checking if the column values are in a list. - - >>> C('foo').isin([1, 2, 3]) - Q('foo', 'in', [1, 2, 3]) - """ - return Q(self.column, "in", value) - - def contains(self, value: str) -> Q: - """Create a query object for checking if the column values contain a string. - - >>> C('foo').contains('bar') - Q('foo', 'contains', 'bar') - """ - return Q(self.column, "contains", value) - - -@dataclass -class QueryPortion: - """Dataclass for storing the portion of a query.""" - match: int total: int @@ -452,7 +97,7 @@ def percent(self) -> float: >>> QueryPortion(2, 5).percent 40.0 """ - return self.ratio * 100 + return self.ratio * 100.0 def invert(self) -> QueryPortion: """Return the inverted portion. @@ -463,87 +108,6 @@ def invert(self) -> QueryPortion: return QueryPortion(match=self.fail, total=self.total) -def align_diagnoses( - dataset: pd.DataFrame, - modalities: list[str], -) -> list[pd.DataFrame]: - """Stack aligned diagnosis tables in ``dataset`` for each of ``modalities``.""" - diagnosis_stack = [] - for modality in modalities: - try: - this = dataset[modality].copy().drop(columns=["info"], errors="ignore") - except KeyError: - warnings.warn(f"Did not find modality {modality}, cannot align. Skipping.") # noqa - continue - - for i, other in enumerate(diagnosis_stack): - this, other = this.align(other, join="outer") - diagnosis_stack[i] = other - - diagnosis_stack.append(this) - - return diagnosis_stack - - -def _stack_to_float_matrix(diagnosis_stack: list[pd.DataFrame]) -> np.ndarray: - """Convert diagnosis stack to 3D array of floats with ``Nones`` as ``np.nan``.""" - diagnosis_matrix = np.array(diagnosis_stack) - diagnosis_matrix[pd.isna(diagnosis_matrix)] = np.nan - return np.astype(diagnosis_matrix, float) - - -def _evaluate_likelihood_ratios( - diagnosis_matrix: np.ndarray, - sensitivities: np.ndarray, - specificities: np.ndarray, - method: Literal["max_llh", "rank"], -) -> np.ndarray: - """Compare the likelihoods of true/false diagnoses using the given ``method``. - - The ``diagnosis_matrix`` is a 3D array of shape ``(n_modalities, n_patients, - n_levels)``. The ``sensitivities`` and ``specificities`` are 1D arrays of shape - ``(n_modalities,)``. When choosing the ``method="max_llh"``, the likelihood of each - diagnosis is combined into one likelihood for each patient and level. With - ``method="rank"``, the most trustworthy diagnosis is chosen for each patient and - level. - """ - true_pos = sensitivities[:, None, None] * diagnosis_matrix - false_neg = (1 - sensitivities[:, None, None]) * (1 - diagnosis_matrix) - true_neg = specificities[:, None, None] * (1 - diagnosis_matrix) - false_pos = (1 - specificities[:, None, None]) * diagnosis_matrix - - if method not in {"max_llh", "rank"}: - raise ValueError(f"Unknown method {method}") - - agg_func = np.nanprod if method == "max_llh" else np.nanmax - true_llh = agg_func(true_pos + false_neg, axis=0) - false_llh = agg_func(true_neg + false_pos, axis=0) - - return true_llh >= false_llh - - -def _expand_mapping( - short_map: dict[str, Any], - colname_map: dict[str | tuple[str, str, str], Any] | None = None, -) -> dict[tuple[str, str, str], Any]: - """Expand the column map to full column names. 
- - >>> _expand_mapping({'age': 'foo', 'hpv': 'bar'}) - {('patient', '#', 'age'): 'foo', ('patient', '#', 'hpv_status'): 'bar'} - """ - _colname_map = colname_map or get_default_column_map_old().from_short - expanded_map = {} - - for colname, func in short_map.items(): - expanded_colname = getattr(_colname_map.get(colname), "long", colname) - expanded_map[expanded_colname] = func - - return expanded_map - - -AggFuncType = dict[str | tuple[str, str, str], Callable[[pd.Series], pd.Series]] - - @pd_ext.register_dataframe_accessor("ly") class LyDataAccessor: """Custom accessor for handling lymphatic involvement data. @@ -558,6 +122,14 @@ def __init__(self, obj: pd.DataFrame) -> None: self._column_map_old = get_default_column_map_old() self._column_map_new = get_default_column_map_new() + def _get_safe_long_old(self, key: Any) -> tuple[str, str, str]: + """Get the old long column name or return the input.""" + return getattr(self._column_map_old.from_short.get(key), "long", key) + + def _get_safe_long_new(self, key: Any) -> tuple[str, str, str]: + """Get the new long column name or return the input.""" + return getattr(self._column_map_new.from_short.get(key), "long", key) + def __contains__(self, key: str) -> bool: """Check if a column is contained in the DataFrame. @@ -568,32 +140,35 @@ def __contains__(self, key: str) -> bool: False >>> ("patient", "#", "age") in df.ly True - >>> df = pd.DataFrame({("patient", "info", "age"): [61, 52, 73]}) + >>> df = pd.DataFrame({("patient", "core", "age"): [61, 52, 73]}) >>> "age" in df.ly True >>> "foo" in df.ly False - >>> ("patient", "info", "age") in df.ly + >>> ("patient", "core", "age") in df.ly True """ - _key_old = self._get_safe_long_old(key) - _key_new = self._get_safe_long_new(key) - return _key_new in self._obj or _key_old in self._obj + key_old = self._get_safe_long_old(key) + key_new = self._get_safe_long_new(key) + return key_new in self._obj or key_old in self._obj def __getitem__(self, key: str) -> pd.Series: - """Allow column access by short name, too.""" - _key_old = self._get_safe_long_old(key) - _key_new = self._get_safe_long_new(key) + """Allow column access by short name, too. - try: - return self._obj[_key_new] - except KeyError as err_from_new: - try: - return self._obj[_key_old] - except KeyError: - raise KeyError( - f"Neither '{_key_new}' nor '{_key_old}' found in DataFrame." - ) from err_from_new + >>> df = pd.DataFrame({("patient", "core", "nicotine_abuse"): [True, False]}) + >>> df.ly["smoke"] + 0 True + 1 False + Name: (patient, core, nicotine_abuse), dtype: bool + """ + key_old = self._get_safe_long_old(key) + key_new = self._get_safe_long_new(key) + + for key in (key_new, key_old): + if key in self: + return self._obj[key] + + raise KeyError(f"Neither '{key_new}' nor '{key_old}' found in DataFrame.") def __getattr__(self, name: str) -> Any: """Access columns also by short name. @@ -604,12 +179,12 @@ def __getattr__(self, name: str) -> Any: 1 52 2 73 Name: (patient, #, age), dtype: int64 - >>> df = pd.DataFrame({("patient", "info", "age"): [61, 52, 73]}) + >>> df = pd.DataFrame({("patient", "core", "age"): [61, 52, 73]}) >>> df.ly.age 0 61 1 52 2 73 - Name: (patient, info, age), dtype: int64 + Name: (patient, core, age), dtype: int64 >>> df.ly.foo Traceback (most recent call last): ... 
@@ -620,26 +195,11 @@ def __getattr__(self, name: str) -> Any: except KeyError as key_err: raise AttributeError(f"Attribute {name!r} not found.") from key_err - def _get_safe_long_old(self, key: Any) -> tuple[str, str, str]: - """Get the old long column name or return the input.""" - return getattr(self._column_map_old.from_short.get(key), "long", key) - - def _get_safe_long_new(self, key: Any) -> tuple[str, str, str]: - """Get the new long column name or return the input.""" - return getattr(self._column_map_new.from_short.get(key), "long", key) - def validate(self, modalities: list[str] | None = None) -> pd.DataFrame: - """Validate the DataFrame against the lydata schema. - - The schema is constructed by the :py:func:`construct_schema` function using - the ``modalities`` provided or it will :py:func:`get_default_modalities` if - ``None`` are provided. - """ - modalities = modalities or list(get_default_modalities().keys()) - lydata_schema = construct_schema(modalities=modalities) - return lydata_schema.validate(self._obj) + """Validate the DataFrame against the lydata schema.""" + raise NotImplementedError("Validation is not yet implemented.") - def get_modalities(self, _filter: list[str] | None = None) -> list[str]: + def get_modalities(self, ignore_cols: list[str] | None = None) -> list[str]: """Return the modalities present in this DataFrame. .. warning:: @@ -647,27 +207,74 @@ This method assumes that all top-level columns are modalities, except for some predefined non-modality columns. For some custom dataset, this may not be correct. In that case, you should provide a list of columns to - ``_filter``, i.e., the columns that are *not* modalities. + ``ignore_cols``, i.e., the columns that are *not* modalities. """ top_level_cols = self._obj.columns.get_level_values(0) modalities = top_level_cols.unique().tolist() - for non_modality_col in _filter or [ - "patient", - "tumor", - "total_dissected", - "positive_dissected", - "enbloc_dissected", - "enbloc_positive", - ]: - try: - modalities.remove(non_modality_col) - except ValueError: - pass + if ignore_cols is None: + ignore_cols = [ + "patient", + "tumor", + "total_dissected", + "positive_dissected", + "enbloc_dissected", + "enbloc_positive", + ] + + for col in ignore_cols: + if col in modalities: + modalities.remove(col) return modalities - def query(self, query: QTypes = None) -> pd.DataFrame: + def get_tnm(self) -> pd.DataFrame: + """Return the T, N, and M stage with all pre- and suffixes. + + This info is collected in three separate columns `"T"`, `"N"`, and `"M"`. + + >>> df = pd.DataFrame({ + ... ('tumor', 'core', 't_stage_prefix'): ['c', 'p'], + ... ('tumor', 'core', 't_stage'): [2 , 3 ], + ... ('tumor', 'core', 't_stage_suffix'): ['a', 'b'], + ... ('patient', 'core', 'n_stage'): [1 , 2 ], + ... ('patient', 'core', 'n_stage_suffix'): ['a', 'b'], + ... ('patient', 'core', 'm_stage'): [0 , 1 ], + ... 
}) + >>> df.ly.get_tnm() # doctest: +NORMALIZE_WHITESPACE + T N M + 0 c2a 1a 0 + 1 p3b 2b 1 + """ + empty = pd.Series([""] * len(self._obj), index=self._obj.index) + result = pd.DataFrame(index=self._obj.index) + + for stage in ("t", "n", "m"): + tmp = pd.DataFrame(index=self._obj.index) + for part in ["prefix", "", "suffix"]: + name = "_".join([stage, "stage", part]).strip("_") + try: + col = self._obj.xs(name, axis="columns", level=2).iloc[:, 0] + except KeyError: + col = empty.copy() + + tmp = pd.concat([tmp, col], axis="columns") + + result[stage.upper()] = tmp.astype(str).agg("".join, axis="columns") + + return result + + def _get_mask(self, query: CanExecute | None = None) -> pd.Series: + """Safely get a boolean mask for the DataFrame based on the query.""" + if query is None: + return _get_all_true(self._obj) + + if isinstance(query, CanExecute): + return query.execute(self._obj) + + raise TypeError(f"Cannot query with {type(query).__name__}.") + + def query(self, query: CanExecute | None = None) -> pd.DataFrame: """Return a DataFrame with rows that satisfy the ``query``. A query is a :py:class:`Q` object that can be combined with logical operators. @@ -677,6 +284,7 @@ def query(self, query: QTypes = None) -> pd.DataFrame: :py:class:`C` object as in the example below, where we query all entries where ``x`` is greater than 1 and not less than 3: + >>> from lydata import C >>> df = pd.DataFrame({'x': [1, 2, 3]}) >>> df.ly.query((C('x') > 1) & ~(C('x') < 3)) x @@ -686,24 +294,29 @@ def query(self, query: QTypes = None) -> pd.DataFrame: 0 1 2 3 """ - mask = (query or NoneQ()).execute(self._obj) + mask = self._get_mask(query) return self._obj[mask] - def portion(self, query: QTypes = None, given: QTypes = None) -> QueryPortion: + def portion( + self, + query: CanExecute | None = None, + given: CanExecute | None = None, + ) -> QueryPortion: """Compute how many rows satisfy a ``query``, ``given`` some other conditions. This returns a :py:class:`QueryPortion` object that contains the number of rows satisfying the ``query`` and ``given`` :py:class:`Q` object divided by the number of rows satisfying only the ``given`` condition. + >>> from lydata import C >>> df = pd.DataFrame({'x': [1, 2, 3]}) >>> df.ly.portion(query=C('x') == 2, given=C('x') > 1) QueryPortion(match=np.int64(1), total=np.int64(2)) >>> df.ly.portion(query=C('x') == 2, given=C('x') > 3) QueryPortion(match=np.int64(0), total=np.int64(0)) """ - given_mask = (given or NoneQ()).execute(self._obj) - query_mask = (query or NoneQ()).execute(self._obj) + given_mask = self._get_mask(given) + query_mask = self._get_mask(query) return QueryPortion( match=query_mask[given_mask].sum(), @@ -721,7 +334,7 @@ def stats( The ``agg_funcs`` argument is a mapping of column names to functions that receive a :py:class:`pd.Series` and return a :py:class:`pd.Series`. The default is a useful selection of statistics for the most common columns. E.g., for the - column ``('patient', 'info', 'age')`` (or its short column name ``age``), the + column ``('patient', 'core', 'age')`` (or its short column name ``age``), the default function returns the value counts. The ``use_shortnames`` argument determines whether the output should use the @@ -740,9 +353,9 @@ def stats( 'hpv': {True: 2, False: 1, None: 1}, 't_stage': {2: 2, 3: 1, 1: 1}} >>> df = pd.DataFrame({ - ... ('patient', 'info', 'age'): [61, 52, 73, 61], - ... ('patient', 'info', 'hpv_status'): [True, False, None, True], - ... ('tumor', 'info', 't_stage'): [2, 3, 1, 2], + ... 
('patient', 'core', 'age'): [61, 52, 73, 61], + ... ('patient', 'core', 'hpv_status'): [True, False, None, True], + ... ('tumor', 'core', 't_stage'): [2, 3, 1, 2], ... }) >>> df.ly.stats() # doctest: +NORMALIZE_WHITESPACE {'age': {61: 2, 52: 1, 73: 1}, @@ -765,12 +378,14 @@ def stats( return stats - def _filter_and_sort_modalities( + def _filter_modalities( self, modalities: dict[str, ModalityConfig] | None = None, ) -> dict[str, ModalityConfig]: - """Return only those ``modalities`` present in data and sorted as in data.""" - modalities = modalities or get_default_modalities() + """Keep only those ``modalities`` present in data.""" + if modalities is None: + modalities = get_default_modalities() + return { modality_name: modality_config for modality_name, modality_config in modalities.items() @@ -781,6 +396,7 @@ def combine( self, modalities: dict[str, ModalityConfig] | None = None, method: Literal["max_llh", "rank"] = "max_llh", + subdivisions: Mapping[str, Sequence[str]] | None = None, ) -> pd.DataFrame: """Combine diagnoses of ``modalities`` using ``method``. @@ -791,8 +407,27 @@ def combine( diagnosis is chosen for each patient and level based on the sensitivity and specificity of the given list of ``modalities``. - The result contains only the combined columns. The intended use is to - :py:meth:`~pandas.DataFrame.update` the original DataFrame with the result. + The result contains only the combined columns and no top-level header. This + means that if you want to add that to the original DataFrame, you could do so + like this: + + .. code-block:: python + + combined = data.ly.combine() + combined_full_header = pd.concat({"foo": combined}, axis="columns") + combined_full_header.index = data.index + data = pd.concat([data, combined_full_header], axis="columns") + + The method :py:func:`.enhance` is a shorthand for combining, augmenting, and + joining the results in a way similar to that example above. + + .. warning:: + + Here, the default value for ``subdivisions`` is set to an empty dictionary. + This is because on the one hand, we still want to retain the functionality + of combining and augmenting in one step (necessary in the + :py:meth:`.enhance` method), but if not explicitly chosen, we keep only + the originally provided levels. >>> df = pd.DataFrame({ ... ('CT' , 'ipsi', 'I'): [False, True , False, True, None], @@ -808,50 +443,36 @@ def combine( 3 False 4 None """ - modalities = self._filter_and_sort_modalities(modalities) - - diagnosis_stack = align_diagnoses(self._obj, list(modalities.keys())) - diagnosis_matrix = _stack_to_float_matrix(diagnosis_stack) - all_nan_mask = np.all(np.isnan(diagnosis_matrix), axis=0) - - result = _evaluate_likelihood_ratios( - diagnosis_matrix=diagnosis_matrix, - sensitivities=np.array([mod.sens for mod in modalities.values()]), - specificities=np.array([mod.spec for mod in modalities.values()]), + # We need the ability to pass the subdivisions for the `.enhance` method, + # but normally, we don't want to augment when combining. 
+ if subdivisions is None: + subdivisions = {} + + modalities = self._filter_modalities(modalities) + obj_copy = self._obj.copy() + + return combine_and_augment_levels( + diagnoses=[obj_copy[mod] for mod in modalities.keys()], + specificities=[mod.spec for mod in modalities.values()], + sensitivities=[mod.sens for mod in modalities.values()], method=method, + subdivisions=subdivisions, ) - result = np.astype(result, object) - result[all_nan_mask] = None - return pd.DataFrame(result, columns=diagnosis_stack[0].columns) - def infer_sublevels( + def augment( self, - modalities: list[str] | None = None, - sides: list[Literal["ipsi", "contra"]] | None = None, + modality: str = "max_llh", subdivisions: dict[str, list[str]] | None = None, ) -> pd.DataFrame: - """Determine involvement status of an LNL's sublevels (e.g., IIa and IIb). - - Some LNLs have sublevels, e.g., IIa and IIb. The involvement of these sublevels - is not always reported, but only the superlevel's status. This function infers - the status of the sublevels from the superlevel. + """Complete the sub- and superlevel involvement columns. - The sublevel's status is computed for the specified ``modalities``. If and what - sublevels a superlevel has, is specified in ``subdivisions``. The default - ``subdivisions`` argument looks like this: + This is useful if the intention is not to combine multiple modalities, but + rather to fill in the missing super- and sub-level involvement columns for a + single modality. - .. code-block:: python - - { - "I": ["a", "b"], - "II": ["a", "b"], - "V": ["a", "b"], - } - - The resulting DataFrame will only contain the newly inferred sublevel columns - and only for those sublevels that were not already present in the DataFrame. - Thus, one can simply :py:meth:`~pandas.DataFrame.join` the original DataFrame - with the result. + Like the :py:meth:`~LyDataAccessor.combine` method, the returned DataFrame + only has a two-level header. So, for combining this with the original data, + one has to perform additional steps. Or use the :py:meth:`.enhance` method. >>> df = pd.DataFrame({ ... ('MRI', 'ipsi' , 'I' ): [True , False, False, None], @@ -860,100 +481,99 @@ def infer_sublevels( ... ('MRI', 'ipsi' , 'IV'): [False, False, True , None], ... ('CT' , 'ipsi' , 'I' ): [True , False, False, None], ... 
}) - >>> df.ly.infer_sublevels(modalities=["MRI"]) # doctest: +NORMALIZE_WHITESPACE - MRI - ipsi contra - Ia Ib IIa IIb Ia Ib - 0 None None False False False False - 1 False False False False None None - 2 False False None None False False - 3 None None None None None None + >>> df.ly.augment(modality="MRI") # doctest: +NORMALIZE_WHITESPACE + contra ipsi + I Ia Ib I Ia Ib II IIa IIb IV + 0 False False False True None None False False False False + 1 True None None False False False False False False False + 2 False False False False False False True None None True + 3 None None None None None None None None None None """ - modalities = modalities or list(get_default_modalities().keys()) - sides = sides or ["ipsi", "contra"] - subdivisions = subdivisions or { - "I": ["a", "b"], - "II": ["a", "b"], - "V": ["a", "b"], - } + if modality not in self.get_modalities(): + raise ValueError(f"Modality {modality!r} not found in DataFrame.") - result = self._obj.copy().drop(self._obj.columns, axis=1) + obj_copy = self._obj.copy() - loop_combinations = product(modalities, sides, subdivisions.items()) - for modality, side, (superlevel, subids) in loop_combinations: - try: - is_healthy = self._obj[modality, side, superlevel] == False # noqa - except KeyError: - continue + return combine_and_augment_levels( + diagnoses=[obj_copy[modality]], + specificities=[0.9], # Numbers here don't matter, as we only "combine" + sensitivities=[0.9], # a single modality's involvement info. + subdivisions=subdivisions, + ) - for subid in subids: - sublevel = superlevel + subid - result.loc[is_healthy, (modality, side, sublevel)] = False - result.loc[~is_healthy, (modality, side, sublevel)] = None + def enhance( + self, + modalities: dict[str, ModalityConfig] | None = None, + method: Literal["max_llh", "rank"] = "max_llh", + subdivisions: Mapping[str, Sequence[str]] | None = None, + ) -> LyDataFrame: + """Shorthand for first combining ``modalities`` and then augmenting them. - return result + This first runs the :py:meth:`~LyDataAccessor.combine` method and after that + the :py:meth:`~LyDataAccessor.augment` for every modality in ``modalities`` + and the newly combined ``method`` column. + """ + if subdivisions is None: + subdivisions = { + "I": ["a", "b"], + "II": ["a", "b"], + "V": ["a", "b"], + } - def infer_superlevels( - self, - modalities: list[str] | None = None, - sides: list[Literal["ipsi", "contra"]] | None = None, - subdivisions: dict[str, list[str]] | None = None, - ) -> pd.DataFrame: - """Determine involvement status of an LNL's superlevel (e.g., II). + if modalities is None: + modalities = get_default_modalities() + + # Originally, I thought we could just combine and not augment the super- and + # sub-levels, but then we discard the involvement probability information from + # the original modalities. + combined = self.combine( + modalities=modalities, + method=method, + subdivisions=subdivisions, + ) + combined = pd.concat({method: combined}, axis="columns") + combined.index = self._obj.index + enhanced: LyDataFrame = pd.concat([self._obj, combined], axis="columns") - Some LNLs have sublevels, e.g., IIa and IIb. In real data, sometimes the - sublevels are reported, sometimes only the superlevel. This function infers the - status of the superlevel from the sublevels. + for modality in list(modalities.keys()): + if modality not in enhanced.columns: + continue - The superlevel's status is computed for the specified ``modalities``. If and - what sublevels a superlevel has, is specified in ``subdivisions``. 
+ augmented = enhanced.ly.augment( + modality=modality, + subdivisions=subdivisions, + ) + augmented = pd.concat({modality: augmented}, axis="columns") + augmented.index = enhanced.index + enhanced = replace(left=enhanced, right=augmented) - The resulting DataFrame will only contain the newly inferred superlevel columns - and only for those superlevels that were not already present in the DataFrame. - This way, it is straightforward to :py:meth:`~pandas.DataFrame.join` it with the - original DataFrame. + return _sort_all(enhanced) - >>> df = pd.DataFrame({ - ... ('MRI', 'ipsi' , 'Ia' ): [True , False, False, None, None ], - ... ('MRI', 'ipsi' , 'Ib' ): [False, True , False, None, False], - ... ('MRI', 'contra', 'IIa'): [False, False, None , None, None ], - ... ('MRI', 'contra', 'IIb'): [False, True , True , None, False], - ... ('CT' , 'ipsi' , 'I' ): [True , False, False, None, None ], - ... }) - >>> df.ly.infer_superlevels(modalities=["MRI"]) # doctest: +NORMALIZE_WHITESPACE - MRI - ipsi contra - I II - 0 True False - 1 True True - 2 False True - 3 None None - 4 None None + def cast( + self, + casters: Mapping[type, str] | None = None, + ) -> LyDataFrame: + """Cast the dtypes of the DataFrame to the expected types. + + This uses the annotations of the Pydantic schema to cast the individual columns + of the DataFrame to the expected types. It uses the ``casters`` mapping to + determine the type to cast to. By default, it uses the mapping from the + :py:func:`_get_default_casters` function. """ - modalities = modalities or list(get_default_modalities().keys()) - sides = sides or ["ipsi", "contra"] - subdivisions = subdivisions or { - "I": ["a", "b"], - "II": ["a", "b"], - "V": ["a", "b"], - } + from lydata.validator import cast_dtypes - result = self._obj.copy().drop(self._obj.columns, axis=1) + return cast_dtypes(self._obj, casters=casters) - loop_combinations = product(modalities, sides, subdivisions.items()) - for modality, side, (superlevel, subids) in loop_combinations: - sublevels = [superlevel + subid for subid in subids] - sublevel_cols = [(modality, side, sublevel) for sublevel in sublevels] - try: - is_unknown = self._obj[sublevel_cols].isna().any(axis=1) - is_any_involved = self._obj[sublevel_cols].any(axis=1) - are_all_healthy = ~is_unknown & ~is_any_involved - except KeyError: - continue +# Using the class below instead of pd.DataFrame enables IDE type hints. +class LyDataFrame(pd.DataFrame): """Subclass of a pandas DataFrame with a custom lydata accessor.""" - result.loc[are_all_healthy, (modality, side, superlevel)] = False - result.loc[is_unknown, (modality, side, superlevel)] = None - result.loc[is_any_involved, (modality, side, superlevel)] = True + ly: LyDataAccessor + """The custom lydata accessor for these DataFrame subclass instances.""" - return result + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/src/lydata/augmentor.py b/src/lydata/augmentor.py new file mode 100644 index 0000000..93391e0 --- /dev/null +++ b/src/lydata/augmentor.py @@ -0,0 +1,300 @@ +"""Provides functions for augmenting and enhancing the lyDATA tables. + +This module does the heavy lifting of inferring the most likely true involvement based +on several - possibly conflicting - diagnoses and their sensitivities and +specificities. It also resolves the sub- and super-level involvement information, +e.g. if a sublevel is involved, the superlevel is also involved, and vice-versa. 
+ +All this is achieved in the :py:func:`combine_and_augment_levels` function, which is +also used by the :py:meth:`~lydata.accessor.LyDataAccessor.combine`, +:py:meth:`~lydata.accessor.LyDataAccessor.augment`, and +:py:meth:`~lydata.accessor.LyDataAccessor.enhance` methods of the +:py:class:`~lydata.accessor.LyDataAccessor` class. +""" + +from collections.abc import Mapping, Sequence +from itertools import product +from typing import Literal + +import numpy as np +import pandas as pd + +from lydata.utils import _sort_by + + +def _keep_only_involvement(table: pd.DataFrame) -> pd.DataFrame: + """Keep only the involvement information under ``"ipsi"`` and ``"contra"``. + + >>> table = pd.DataFrame({ + ... ("ipsi", "I"): [True, False, None], + ... ("contra", "II"): [False, True, None], + ... ("foo", "bar"): [1, 2, 3], + ... }) + >>> _keep_only_involvement(table) + ipsi contra + I II + 0 True False + 1 False True + 2 None None + """ + return table.filter(regex=r"(ipsi|contra)", axis="columns") + + +def _align_tables(tables: Sequence[pd.DataFrame]) -> list[pd.DataFrame]: + """Align all columns in the sequence of ``tables``. + + >>> one = pd.DataFrame({ + ... ("x", "a"): [1, 2], + ... ("x", "b"): [3, 4], + ... ("y", "c"): [5, 6], + ... ("y", "b"): [19, 120], + ... }) + >>> two = pd.DataFrame({ + ... ("y", "c"): [91, 10], + ... ("y", "b"): [9, 10], + ... ("x", "a"): [7, 8], + ... }) + >>> three = pd.DataFrame({ + ... ("x", "c"): [71, 81], + ... ("y", "b"): [5, 6], + ... ("x", "a"): [5, 61], + ... }) + >>> aligned = _align_tables([one, two, three]) + >>> aligned[0] # doctest: +NORMALIZE_WHITESPACE + x y + a b c b c + 0 1 3 NaN 19 5 + 1 2 4 NaN 120 6 + >>> aligned[1] # doctest: +NORMALIZE_WHITESPACE + x y + a b c b c + 0 7 NaN NaN 9 91 + 1 8 NaN NaN 10 10 + >>> aligned[2] # doctest: +NORMALIZE_WHITESPACE + x y + a b c b c + 0 5 NaN 71 5 NaN + 1 61 NaN 81 6 NaN + """ + if len(tables) == 0: + return [] + + all_columns = tables[0].columns + for table in tables[1:]: + all_columns = all_columns.union(table.columns) + + return [table.reindex(columns=all_columns) for table in tables] + + +def _convert_to_float_matrix(diagnoses: Sequence[pd.DataFrame]) -> np.ndarray: + """Convert a sequence of ``diagnoses`` to a 3D float matrix. + + >>> one = pd.DataFrame({"a": [1, None], "b": [3, 4]}) + >>> two = pd.DataFrame({"a": [5, 6], "b": [7, None]}) + >>> _convert_to_float_matrix([one, two]) # doctest: +NORMALIZE_WHITESPACE + array([[[ 1., 3.], + [nan, 4.]], + [[ 5., 7.], + [ 6., nan]]]) + """ + matrix = np.array(diagnoses) + matrix[pd.isna(matrix)] = np.nan + return np.astype(matrix, float) + + +def _compute_likelihoods( + diagnosis_matrix: np.ndarray, + sensitivities: np.ndarray, + specificities: np.ndarray, + method: Literal["max_llh", "rank"], +) -> tuple[np.ndarray, np.ndarray]: + """Compute the likelihoods of true/false diagnoses using the given ``method``. + + The ``diagnosis_matrix`` is a 3D array of shape ``(n_modalities, n_patients, + n_levels)``. It should contain ``1.0`` where the diagnosis was positive and ``0.0`` + where it was negative. It may also contain ``np.nan``. + + The ``sensitivities`` and ``specificities`` are 1D arrays of shape + ``(n_modalities,)``. When choosing the ``method="max_llh"``, the likelihood of each + diagnosis is combined into one likelihood for each patient and level. With + ``method="rank"``, the likelihoods are computed for the most trustworthy diagnosis. + + Returns the likelihoods of true and false diagnoses as two separate arrays. 
+ """ + true_pos = sensitivities[:, None, None] * diagnosis_matrix + false_neg = (1 - sensitivities[:, None, None]) * (1 - diagnosis_matrix) + true_neg = specificities[:, None, None] * (1 - diagnosis_matrix) + false_pos = (1 - specificities[:, None, None]) * diagnosis_matrix + + if method not in {"max_llh", "rank"}: + raise ValueError(f"Unknown method {method}") + + agg_func = np.nanprod if method == "max_llh" else np.nanmax + true_llh = agg_func(true_pos + false_neg, axis=0) + false_llh = agg_func(true_neg + false_pos, axis=0) + return true_llh, false_llh + + +def _compute_involved_probs( + diagnosis_matrix: np.ndarray, + sensitivities: np.ndarray, + specificities: np.ndarray, + method: Literal["max_llh", "rank"], +) -> np.ndarray: + """Compute the probabilities of involvement for each diagnosis.""" + true_llhs, false_llhs = _compute_likelihoods( + diagnosis_matrix=diagnosis_matrix, + sensitivities=sensitivities, + specificities=specificities, + method=method, + ) + return true_llhs / (true_llhs + false_llhs) + + +def combine_and_augment_levels( + diagnoses: Sequence[pd.DataFrame], + specificities: Sequence[float], + sensitivities: Sequence[float], + method: Literal["max_llh", "rank"] = "max_llh", + sides: Sequence[Literal["ipsi", "contra"]] | None = None, + subdivisions: Mapping[str, Sequence[str]] | None = None, +) -> pd.DataFrame: + """Combine ``diagnoses`` and add sub-/superlevel involvement info. + + Different diagnostic modalities may conflict with each other, e.g. on MRI an + LNL may look metastatic, while FNA finds no malignancy. This function combines + available diagnoses based on their ``sensitivities`` and ``specificities`` + into a sort of consensus. When choosing ``method="max_llh"``, the most likely + diagnosis is chosen. If ``method="rank"``, the single most trustworthy + diagnosis is kept. + + Additionally, the function may add and resolve sub- and superlevel involvement + information. For example, some datasets report the overall involvement in LNL II, + while others differentiate between sublevels IIa and IIb. Now, if IIa harbors + disease, that means that the overall involvement in II is also true. By specifying + ``subdivisions``, the function consistently updates these super- and sublevel + involvement patterns. + + The returned :py:class:`~pandas.DataFrame` has a two-level column index: one level + for each of the ``sides`` and a second level for the involvement levels. This + means it is in the same format as the stack of input ``diagnoses``. + + See the accessor methods :py:meth:`~lydata.accessor.LyDataAccessor.augment` and + :py:meth:`~lydata.accessor.LyDataAccessor.combine` for some examples. 
+ """ + diagnoses = [_keep_only_involvement(table) for table in diagnoses] + diagnoses = _align_tables(diagnoses) + matrix = _convert_to_float_matrix(diagnoses) + all_nan_mask = np.all(np.isnan(matrix), axis=0) + + involved_probs = _compute_involved_probs( + diagnosis_matrix=matrix, + sensitivities=np.array(sensitivities), + specificities=np.array(specificities), + method=method, + ) + + combined = np.astype(involved_probs >= 0.5, object) + combined[all_nan_mask] = None + combined = pd.DataFrame(combined, columns=diagnoses[0].columns) + + healthy_probs = 1.0 - involved_probs + involved_probs[all_nan_mask] = np.nan + involved_probs = pd.DataFrame(involved_probs, columns=diagnoses[0].columns) + healthy_probs[all_nan_mask] = np.nan + healthy_probs = pd.DataFrame(healthy_probs, columns=diagnoses[0].columns) + + if sides is None: + sides = ["ipsi", "contra"] + + if subdivisions is None: + subdivisions = { + "I": ["a", "b"], + "II": ["a", "b"], + "V": ["a", "b"], + } + + for side, (superlvl, subids) in product(sides, subdivisions.items()): + if side not in combined.columns: + continue + + superlvl_col = (side, superlvl) + sublvls = [superlvl + subid for subid in subids] + sublvl_cols = [(side, sublvl) for sublvl in sublvls] + + if set([superlvl] + sublvls).isdisjoint(set(combined[side].columns)): + continue + + for lvl in [superlvl] + sublvls: + combined[(side, lvl)] = combined.get((side, lvl), [None] * len(combined)) + nans = [np.nan] * len(combined) + involved_probs[(side, lvl)] = involved_probs.get((side, lvl), nans) + healthy_probs[(side, lvl)] = healthy_probs.get((side, lvl), nans) + + is_super_unknown = combined[superlvl_col].isna() + is_super_healthy = combined[superlvl_col] == False + is_super_involved = combined[superlvl_col] == True + + is_any_sub_involved = combined[sublvl_cols].any(axis=1) + is_one_sub_unknown = combined[sublvl_cols].isna().sum(axis=1) == 1 + are_all_subs_healthy = (combined[sublvl_cols] == False).all(axis=1) + are_all_subs_unknown = combined[sublvl_cols].isna().all(axis=1) + + # Superlvl unknown => no conflict, use sublvl info + combined.loc[is_super_unknown & is_any_sub_involved, superlvl_col] = True + combined.loc[is_super_unknown & are_all_subs_healthy, superlvl_col] = False + + # No sublvl involved => no conflict, use superlvl info + combined.loc[~is_any_sub_involved & is_super_healthy, sublvl_cols] = False + + # Conflicts + # 1) Subs override superlvl + super_healthy_prob_from_subs = np.nanprod(healthy_probs[sublvl_cols], axis=1) + super_involved_prob_from_subs = 1.0 - super_healthy_prob_from_subs + + do_subs_determine_super_healthy = ( + is_super_involved + & ~are_all_subs_unknown + & (super_healthy_prob_from_subs > involved_probs[superlvl_col]) + ) + combined.loc[do_subs_determine_super_healthy, superlvl_col] = False + + do_subs_determine_super_involved = ( + is_super_healthy + & ~are_all_subs_unknown + & (super_involved_prob_from_subs > healthy_probs[superlvl_col]) + ) + combined.loc[do_subs_determine_super_involved, superlvl_col] = True + + # 2) Superlvl overrides subs + does_super_determine_all_subs_healthy = ( + is_any_sub_involved + & is_super_healthy + & (healthy_probs[superlvl_col] > super_involved_prob_from_subs) + ) + combined.loc[does_super_determine_all_subs_healthy, sublvl_cols] = False + + does_super_determine_subs_unknown = ( + are_all_subs_healthy + & is_super_involved + & (involved_probs[superlvl_col] > super_healthy_prob_from_subs) + ) + combined.loc[does_super_determine_subs_unknown, sublvl_cols] = None + + for sublvl in sublvls: + 
sublvl_col = (side, sublvl) + is_sub_unknown = combined[sublvl_col].isna() + does_super_determine_unknown_sub_involved = ( + is_super_involved + & is_sub_unknown + & is_one_sub_unknown + & ~is_any_sub_involved + & (involved_probs[superlvl_col] > super_healthy_prob_from_subs) + ) + # The above combination of conditions means that the current `sublvl` is + # unknown, while all others are healthy, while the superlvl is involved. + # Then below, we change the sublvl to involved. + combined.loc[does_super_determine_unknown_sub_involved, sublvl_col] = True + + combined = _sort_by(combined, which="lnl", level=1) + return _sort_by(combined, which="mid", level=0) diff --git a/src/lydata/loader.py b/src/lydata/loader.py index 7b08d2a..5a436bb 100644 --- a/src/lydata/loader.py +++ b/src/lydata/loader.py @@ -3,7 +3,7 @@ The loading itself is implemented in the :py:class:`.LyDataset` class, which is a :py:class:`pydantic.BaseModel` subclass. It validates the unique specification that identifies a dataset and then allows loading it from the disk (if present) or -from GitHub. +from GitHub (default). The :py:func:`available_datasets` function can be used to create a generator of such :py:class:`.LyDataset` instances, corresponding to all available datasets that @@ -14,9 +14,6 @@ :py:func:`available_datasets` but returns a generator of :py:class:`pandas.DataFrame` instead of :py:class:`.LyDataset`. -Lastly, with the :py:func:`join_datasets` function, one can load and concatenate all -datasets matching the given specs/pattern into a single :py:class:`pandas.DataFrame`. - The docstring of all functions contains some basic doctest examples. """ @@ -28,12 +25,22 @@ import numpy as np # noqa: F401 import pandas as pd -from github import Github, Repository +from github import BadCredentialsException, Github, Repository, UnknownObjectException from github.ContentFile import ContentFile +from github.GithubException import GithubException from loguru import logger -from pydantic import BaseModel, Field, PrivateAttr, constr - +from pydantic import ( + BaseModel, + DirectoryPath, + Field, + PrivateAttr, + RootModel, + constr, +) + +from lydata.accessor import LyDataFrame from lydata.utils import get_github_auth +from lydata.validator import cast_dtypes, is_valid _default_repo_name = "lycosystem/lydata" low_min1_str = constr(to_lower=True, min_length=1) @@ -43,6 +50,42 @@ class SkipDiskError(Exception): """Raised when the user wants to skip loading from disk.""" +def _safely_fetch_repo(gh: Github, repo_name: str) -> Repository: + """Fetch a GitHub repository, handling common errors.""" + try: + logger.debug(f"Fetching repository '{repo_name}' from GitHub...") + repo = gh.get_repo(repo_name) + except UnknownObjectException as e: + raise ValueError(f"Could not find repository '{repo_name}' on GitHub.") from e + except BadCredentialsException as e: + raise ValueError("Invalid GitHub credentials.") from e + + logger.debug(f"Fetched repository '{repo.full_name}' from GitHub.") + return repo + + +def _safely_fetch_contents( + repo: Repository, + ref: str, + path: str = ".", +) -> list[ContentFile] | ContentFile: + """Fetch contents of a GitHub ``repo`` at a specific ``ref``, handling errors.""" + try: + logger.debug(f"Fetching contents of repo '{repo.full_name}' at ref '{ref}'...") + contents = repo.get_contents(path=path, ref=ref) + except GithubException as e: + available_branches = [b.name for b in repo.get_branches()] + available_tags = [t.name for t in repo.get_tags()] + raise ValueError( + f"Could not find ref '{ref}' 
in repository '{repo.full_name}'.\n" + f"Available branches: {available_branches}.\n" + f"Available tags: {available_tags}." + ) from e + + logger.debug(f"Fetched contents of repo '{repo.full_name}' at ref '{ref}'.") + return contents + + class LyDataset(BaseModel): """Specification of a dataset.""" @@ -57,14 +100,23 @@ class LyDataset(BaseModel): subsite: low_min1_str = Field( description="Tumor subsite(s) patients in this dataset were diagnosed with.", ) - repo_name: low_min1_str = Field( + repo_name: low_min1_str | None = Field( default=_default_repo_name, description="GitHub `repository/owner`.", ) - ref: low_min1_str = Field( + ref: low_min1_str | None = Field( default="main", description="Branch/tag/commit of the repo.", ) + local_dataset_dir: DirectoryPath | None = Field( + default=None, + description=( + "Path to directory containing all the dataset subdirectories. So, e.g. if " + "`path_on_disk` is `~/datasets` and the dataset is `2023-clb-multisite`, " + "then the CSV file is expected to be at " + "`~/datasets/2023-clb-multisite/data.csv`." + ), + ) _content_file: ContentFile | None = PrivateAttr(default=None) @property @@ -77,11 +129,17 @@ def name(self) -> str: """ return f"{self.year}-{self.institution}-{self.subsite}" - @property - def path_on_disk(self) -> Path: - """Get the path to the dataset.""" - install_loc = Path(__file__).parent.parent - return install_loc / self.name / "data.csv" + def get_file_path(self) -> Path: + """Get the path to the CSV dataset.""" + if self.local_dataset_dir is None: + self.local_dataset_dir = Path(__file__).parent.parent + + dataset_path = self.local_dataset_dir / self.name / "data.csv" + if not dataset_path.exists(): + raise FileNotFoundError(f"Could not find CSV locally at '{dataset_path}'.") + + logger.info(f"Found dataset {self.name} on disk at '{dataset_path}'.") + return dataset_path def get_repo( self, @@ -108,9 +166,7 @@ def get_repo( """ auth = get_github_auth(token=token, user=user, password=password) gh = Github(auth=auth) - repo = gh.get_repo(self.repo_name) - logger.info(f"Fetched repository {repo.full_name} from GitHub.") - return repo + return _safely_fetch_repo(gh=gh, repo_name=self.repo_name) def get_content_file( self, @@ -138,7 +194,11 @@ def get_content_file( return self._content_file repo = self.get_repo(token=token, user=user, password=password) - self._content_file = repo.get_contents(f"{self.name}/data.csv", ref=self.ref) + self._content_file = _safely_fetch_contents( + repo=repo, + path=f"{self.name}/data.csv", + ref=self.ref, + ) return self._content_file def get_dataframe( @@ -148,7 +208,7 @@ def get_dataframe( user: str | None = None, password: str | None = None, **load_kwargs, - ) -> pd.DataFrame: + ) -> LyDataFrame: """Load the ``data.csv`` file from disk or from GitHub. One can also choose to ``use_github``. 
Any keyword arguments are passed to @@ -171,7 +231,7 @@ def get_dataframe( token=token, user=user, password=password ).download_url else: - from_location = self.path_on_disk + from_location = self.get_file_path() df = pd.read_csv(from_location, **kwargs) logger.info(f"Loaded dataset {self.name} from {from_location}.") @@ -186,16 +246,24 @@ def _available_datasets_on_disk( search_paths: list[Path] | None = None, ) -> Generator[LyDataset, None, None]: pattern = f"{str(year)}-{institution}-{subsite}" - search_paths = search_paths or [Path(__file__).parent.parent] + + if search_paths is None: + search_paths = [Path(__file__).parent.parent] + + search_paths = RootModel[list[DirectoryPath]].model_validate(search_paths).root for search_path in search_paths: for match in search_path.glob(pattern): if match.is_dir() and (match / "data.csv").exists(): + logger.debug(f"Found dataset directory at '{match}'.") year, institution, subsite = match.name.split("-", maxsplit=2) yield LyDataset( year=year, institution=institution, subsite=subsite, + local_dataset_dir=search_path, + repo_name=None, + ref=None, ) @@ -206,10 +274,10 @@ def _available_datasets_on_github( repo_name: str = _default_repo_name, ref: str = "main", ) -> Generator[LyDataset, None, None]: + """Generate :py:class:`.LyDataset` instances of available datasets on GitHub.""" gh = Github(auth=get_github_auth()) - - repo = gh.get_repo(repo_name) - contents = repo.get_contents(path="", ref=ref) + repo = _safely_fetch_repo(gh=gh, repo_name=repo_name) + contents = _safely_fetch_contents(repo=repo, ref=ref) matches = [] for content in contents: @@ -218,6 +286,12 @@ def _available_datasets_on_github( ): matches.append(content) + if len(matches) == 0: + raise ValueError( + f"No datasets found in repository '{repo_name}' matching " + f"'{year}-{institution}-{subsite}' at ref '{ref}'." + ) + for match in matches: year, institution, subsite = match.name.split("-", maxsplit=2) yield LyDataset( @@ -307,13 +381,25 @@ def load_datasets( use_github: bool = True, repo_name: str = _default_repo_name, ref: str = "main", + cast: bool = False, + validate: bool = False, + enhance: bool = False, **kwargs, -) -> Generator[pd.DataFrame, None, None]: - """Load matching datasets from the disk. +) -> Generator[LyDataFrame, None, None]: + """Load matching datasets from GitHub or from the disk. It loads every dataset from the :py:class:`.LyDataset` instances generated by - the :py:func:`available_datasets` function, which also receives all arguments of + the :py:func:`available_datasets` function, which also receives most arguments of this function. + + The boolean flags ``cast``, ``validate``, and ``enhance`` can be used to + automatically cast the dtypes of the loaded :py:class:`pandas.DataFrame`s, + validate them, and enhance them with additional columns. These operations are + performed using the :py:func:`~lydata.cast_dtypes`, :py:func:`~lydata.is_valid`, + the :py:func:`~lydata.LyDataAccessor.enhance` method, respectively. + + Additional keyword arguments are passed to the :py:meth:`LyDataset.get_dataframe` + method. 
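+
+    A minimal usage sketch (the dataset spec and flags below are chosen purely for
+    illustration):
+
+    .. code-block:: python
+
+        for df in load_datasets(year=2021, institution="clb", cast=True):
+            print(df.shape)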
""" dset_confs = available_datasets( year=year, @@ -325,4 +411,7 @@ def load_datasets( ref=ref, ) for dset_conf in dset_confs: - yield dset_conf.get_dataframe(use_github=use_github, **kwargs) + df: LyDataFrame = dset_conf.get_dataframe(use_github=use_github, **kwargs) + df = cast_dtypes(df) if cast else df + _ = validate and is_valid(df, fail_on_error=True) + yield df.ly.enhance() if enhance else df diff --git a/src/lydata/querier.py b/src/lydata/querier.py new file mode 100644 index 0000000..6ae4128 --- /dev/null +++ b/src/lydata/querier.py @@ -0,0 +1,389 @@ +"""Querier module for lydata package. + +This module provides the :py:class:`Q` and :py:class:`C` classes for creating and +combining reusable queries to filter :py:class:`pandas.DataFrame` objects. These +classes are inspired by Django's ``Q`` objects and allow for a more readable and modular +way to filter and query data. + +For example, we may want to keep only patient with tumors of T-category 3 or higher. +Then, we can write + +.. code-block:: python + + from lydata import C + has_t_stage = C("t_stage") >= 3 + +Now, through the equality comparison of an instance of :py:class:`C`, the +``has_t_stage`` is an instance of :py:class:`Q` that can be combined with other queries +and applied via our custom :py:class:`~lydata.accessor.LyDataAccessor` to a table: + +.. code-block:: python + + is_old = C("age") >= 65 + data.ly.query(has_t_stage & is_old) + +Internally, this works by calling the :py:meth:`Q.execute` method, which returns a +boolean mask to filter the DataFrame. So, the above example is equivalent to + +.. code-block:: python + + (has_t_stage & is_old).execute(data) +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, Literal + +import pandas as pd + +from lydata import accessor # noqa: F401 +from lydata.types import CanExecute +from lydata.utils import _get_all_true + + +class CombineQMixin: + """Mixin class for combining queries. + + Four operators are defined for combining queries: + + 1. ``&`` for logical AND operations. + The returned object is an :py:class:`AndQ` instance and - when executed - + returns a boolean mask where both queries are satisfied. When the right-hand + side is ``None``, the left-hand side query object is returned unchanged. + 2. ``|`` for logical OR operations. + The returned object is an :py:class:`OrQ` instance and - when executed - + returns a boolean mask where either query is satisfied. When the right-hand + side is ``None``, the left-hand side query object is returned unchanged. + 3. ``~`` for inverting a query. + The returned object is a :py:class:`NotQ` instance and - when executed - + returns a boolean mask where the query is not satisfied. + 4. ``==`` for checking if two queries are equal. + Two queries are equal if their column names, operators, and values are equal. + Note that this does not check if the queries are semantically equal, i.e., if + they would return the same result when executed. 
+ """ + + def __and__(self, other: CanExecute | None) -> AndQ: + """Combine two queries with a logical AND.""" + return self if other is None else AndQ(self, other) + + def __or__(self, other: CanExecute | None) -> OrQ: + """Combine two queries with a logical OR.""" + return self if other is None else OrQ(self, other) + + def __invert__(self) -> NotQ: + """Negate the query.""" + return NotQ(self) + + def __eq__(self, value): + """Check if two queries are equal.""" + return ( + isinstance(value, self.__class__) + and self.colname == value.colname + and self.operator == value.operator + and self.value == value.value + ) + + +class Q(CombineQMixin): + """Combinable query object for filtering a DataFrame. + + The syntax for this object is similar to Django's ``Q`` object. It can be used to + define queries in a more readable and modular way. + + .. caution:: + + The column names are not checked upon instantiation. This is only done when the + query is executed. In fact, the :py:class:`Q` object does not even know about + the :py:class:`~pandas.DataFrame` it will be applied to in the beginning. On the + flip side, this means a query may be reused for different DataFrames. + + The ``operator`` argument may be one of the following: + + - ``'=='``: Checks if ``column`` values are equal to the ``value``. + - ``'<'``: Checks if ``column`` values are less than the ``value``. + - ``'<='``: Checks if ``column`` values are less than or equal to ``value``. + - ``'>'``: Checks if ``column`` values are greater than the ``value``. + - ``'>='``: Checks if ``column`` values are greater than or equal to ``value``. + - ``'!='``: Checks if ``column`` values are not equal to the ``value``. This is + equivalent to ``~Q(column, '==', value)``. + - ``'in'``: Checks if ``column`` values are in the list of ``value``. For this, + pandas' :py:meth:`~pandas.Series.isin` method is used. + - ``'contains'``: Checks if ``column`` values contain the string ``value``. + Here, pandas' :py:meth:`~pandas.Series.str.contains` method is used. + - ``'pass_to'``: Passes the column values to the callable ``value``. This is useful + for custom filtering functions that may not be covered by the other operators. + """ + + _OPERATOR_MAP: dict[str, Callable[[pd.Series, Any], pd.Series]] = { + "==": lambda series, value: series == value, + "<": lambda series, value: series < value, + "<=": lambda series, value: series <= value, + ">": lambda series, value: series > value, + ">=": lambda series, value: series >= value, + "!=": lambda series, value: series != value, # same as ~Q("col", "==", value) + "in": lambda series, value: series.isin(value), # value is a list + "contains": lambda series, value: series.str.contains(value), # value is a str + "pass_to": lambda series, value: value(series), # value is a callable + } + + def __init__( + self, + column: str, + operator: Literal["==", "<", "<=", ">", ">=", "!=", "in", "contains"], + value: Any, + ) -> None: + """Create query object that can compare a ``column`` with a ``value``.""" + self.colname = column + self.operator = operator + self.value = value + + def __repr__(self) -> str: + """Return a string representation of the query.""" + return f"Q({self.colname!r}, {self.operator!r}, {self.value!r})" + + def execute(self, df: pd.DataFrame) -> pd.Series: + """Return a boolean mask where the query is satisfied for ``df``. 
+ + >>> df = pd.DataFrame({'col1': [1, 2, 3], 'col2': ['foo', 'bar', 'baz']}) + >>> Q('col1', '<=', 2).execute(df) + 0 True + 1 True + 2 False + Name: col1, dtype: bool + >>> Q('col2', 'contains', 'ba').execute(df) + 0 False + 1 True + 2 True + Name: col2, dtype: bool + >>> Q('col1', 'pass_to', lambda x: x % 2 == 0).execute(df) + 0 False + 1 True + 2 False + Name: col1, dtype: bool + """ + column = df.ly[self.colname] + return self._OPERATOR_MAP[self.operator](column, self.value) + + +class NoneQ(CombineQMixin): + """Query object that always returns the entire DataFrame. Useful as default.""" + + def __repr__(self) -> str: + """Return a string representation of the query.""" + return "NoneQ()" + + def execute(self, df: pd.DataFrame) -> pd.Series: + """Return a boolean mask with all entries set to ``True``.""" + return _get_all_true(df) + + +class AndQ(CombineQMixin): + """Query object for combining two queries with a logical AND. + + >>> df = pd.DataFrame({'col1': [1, 2, 3], 'col2': ['foo', 'bar', 'baz']}) + >>> q1 = Q('col1', '!=', 3) + >>> q2 = Q('col2', 'contains', 'ba') + >>> and_q = q1 & q2 + >>> print(and_q) + (Q('col1', '!=', 3) & Q('col2', 'contains', 'ba')) + >>> isinstance(and_q, AndQ) + True + >>> and_q.execute(df) + 0 False + 1 True + 2 False + dtype: bool + >>> all((q1 & None).execute(df) == q1.execute(df)) + True + """ + + def __init__(self, q1: CanExecute, q2: CanExecute) -> None: + """Combine two queries with a logical AND.""" + self.q1 = q1 + self.q2 = q2 + + def __repr__(self) -> str: + """Return a string representation of the query.""" + return f"({self.q1!r} & {self.q2!r})" + + def execute(self, df: pd.DataFrame) -> pd.Series: + """Return a boolean mask where both queries are satisfied.""" + return self.q1.execute(df) & self.q2.execute(df) + + +class OrQ(CombineQMixin): + """Query object for combining two queries with a logical OR. + + >>> df = pd.DataFrame({'col1': [1, 2, 3]}) + >>> q1 = Q('col1', '==', 1) + >>> q2 = Q('col1', '==', 3) + >>> or_q = q1 | q2 + >>> print(or_q) + (Q('col1', '==', 1) | Q('col1', '==', 3)) + >>> isinstance(or_q, OrQ) + True + >>> or_q.execute(df) + 0 True + 1 False + 2 True + Name: col1, dtype: bool + >>> all((q1 | None).execute(df) == q1.execute(df)) + True + """ + + def __init__(self, q1: CanExecute, q2: CanExecute) -> None: + """Combine two queries with a logical OR.""" + self.q1 = q1 + self.q2 = q2 + + def __repr__(self) -> str: + """Return a string representation of the query.""" + return f"({self.q1!r} | {self.q2!r})" + + def execute(self, df: pd.DataFrame) -> pd.Series: + """Return a boolean mask where either query is satisfied.""" + return self.q1.execute(df) | self.q2.execute(df) + + +class NotQ(CombineQMixin): + """Query object for negating a query. 
+ + >>> df = pd.DataFrame({'col1': [1, 2, 3]}) + >>> q = Q('col1', '==', 2) + >>> not_q = ~q + >>> print(not_q) + ~Q('col1', '==', 2) + >>> isinstance(not_q, NotQ) + True + >>> not_q.execute(df) + 0 True + 1 False + 2 True + Name: col1, dtype: bool + >>> print(~(Q('col1', '==', 2) & Q('col1', '!=', 3))) + ~(Q('col1', '==', 2) & Q('col1', '!=', 3)) + """ + + def __init__(self, q: CanExecute) -> None: + """Negate the given query ``q``.""" + self.q = q + + def __repr__(self) -> str: + """Return a string representation of the query.""" + return f"~{self.q!r}" + + def execute(self, df: pd.DataFrame) -> pd.Series: + """Return a boolean mask where the query is not satisfied.""" + return ~self.q.execute(df) + + +class C: + """Wraps a column name and produces a :py:class:`Q` object upon comparison. + + This is basically a shorthand for creating a :py:class:`Q` object that avoids + writing the operator and value in quotes. Thus, it may be more readable and allows + IDEs to provide better autocompletion. + + .. caution:: + + Just like for the :py:class:`Q` object, it is not checked upon instantiation + whether the column name is valid. This is only done when the query is executed. + """ + + def __init__(self, *column: str) -> None: + """Create a column object for comparison. + + For querying multi-level columns, both the syntax ``C('col1', 'col2')`` and + ``C(('col1', 'col2'))`` is valid. + + >>> (C('col1', 'col2') == 1) == (C(('col1', 'col2')) == 1) + True + """ + self.column = column[0] if len(column) == 1 else column + + def __repr__(self) -> str: + """Return a string representation of the column object. + + >>> repr(C('foo')) + "C('foo')" + >>> repr(C('foo', 'bar')) + "C(('foo', 'bar'))" + """ + return f"C({self.column!r})" + + def __eq__(self, value: Any) -> Q: + """Create a query object for comparing equality. + + >>> C('foo') == 'bar' + Q('foo', '==', 'bar') + """ + return Q(self.column, "==", value) + + def __lt__(self, value: Any) -> Q: + """Create a query object for comparing less than. + + >>> C('foo') < 42 + Q('foo', '<', 42) + """ + return Q(self.column, "<", value) + + def __le__(self, value: Any) -> Q: + """Create a query object for comparing less than or equal. + + >>> C('foo') <= 42 + Q('foo', '<=', 42) + """ + return Q(self.column, "<=", value) + + def __gt__(self, value: Any) -> Q: + """Create a query object for comparing greater than. + + >>> C('foo') > 42 + Q('foo', '>', 42) + """ + return Q(self.column, ">", value) + + def __ge__(self, value: Any) -> Q: + """Create a query object for comparing greater than or equal. + + >>> C('foo') >= 42 + Q('foo', '>=', 42) + """ + return Q(self.column, ">=", value) + + def __ne__(self, value: Any) -> Q: + """Create a query object for comparing inequality. + + >>> C('foo') != 'bar' + Q('foo', '!=', 'bar') + """ + return Q(self.column, "!=", value) + + def isin(self, value: list[Any]) -> Q: + """Create a query object for checking if the column values are in a list. + + >>> C('foo').isin([1, 2, 3]) + Q('foo', 'in', [1, 2, 3]) + """ + return Q(self.column, "in", value) + + def contains(self, value: str) -> Q: + """Create a query object for checking if the column values contain a string. + + >>> C('foo').contains('bar') + Q('foo', 'contains', 'bar') + """ + return Q(self.column, "contains", value) + + def pass_to(self, value: Callable[[pd.Series], pd.Series]) -> Q: + """Create a query object that passes the column values to a callable. + + This is useful for custom filtering functions that may not be covered by the + other operators. 
+ + >>> C('foo').pass_to(lambda x: x > 42) # doctest: +ELLIPSIS + Q('foo', 'pass_to', at ...>) + """ + return Q(self.column, "pass_to", value) diff --git a/src/lydata/schema.py b/src/lydata/schema.py new file mode 100644 index 0000000..42a0859 --- /dev/null +++ b/src/lydata/schema.py @@ -0,0 +1,459 @@ +"""Pydantic schema to define a single patient record. + +This schema is useful for casting dtypes, as done in :py:func:`validator.cast_dtypes`, +validation via :py:func:`~validator.is_valid`, and for exporting a JSON schema that +may be used for all kinds of purposes, e.g. to automatically generate HTML forms using +a `JSON-Editor`_. + +.. _JSON-Editor: https://json-editor.github.io/json-editor/ +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Annotated, Any, Literal + +import pandas as pd +from loguru import logger +from pydantic import ( + BaseModel, + BeforeValidator, + Field, + PastDate, + RootModel, + create_model, + field_validator, + model_validator, +) + +from lydata.utils import get_default_modalities + +_LNLS = [ + "I", + "Ia", + "Ib", + "II", + "IIa", + "IIb", + "III", + "IV", + "V", + "Va", + "Vb", + "VI", + "VII", + "VIII", + "IX", + "X", +] + + +def convert_nat(value: Any) -> Any: + """Convert pandas NaT to None. + + pydantic throws an unspecific ``TypeError`` when ``pd.NaT`` is passed to a field. + See [this issue on Github](https://github.com/pydantic/pydantic/issues/8039). + """ + return None if pd.isna(value) else value + + +class PatientCore(BaseModel): + """Basic required patient information. + + This includes demographic information, such as age and sex, as well as some risk + factors for head and neck cancer, including HPV status, alcohol and nicotine abuse, + etc. + """ + + id: str = Field( + description=( + "Unique but anonymized identifier for a patient. We commonly use the " + "format `YYYY--`, where `` is an abbreviation of the " + "institution (hospital) where the patient was treated." + ) + ) + institution: str = Field( + description="Name of the institution/hospital where the patient was treated." + ) + sex: Literal["male", "female"] = Field(description="Biological sex of the patient.") + age: int = Field( + ge=0, + le=120, + description="Age of the patient at the time of diagnosis in years.", + ) + diagnose_date: Annotated[PastDate, BeforeValidator(convert_nat)] = Field( + description="Date of diagnosis of the patient (format YYYY-MM-DD)." + ) + alcohol_abuse: bool | None = Field( + description="Whether the patient currently abuses alcohol." + ) + nicotine_abuse: bool | None = Field( + description="Whether the patient currently abuses nicotine." + ) + pack_years: float | None = Field( + default=None, + ge=0, + description="Number of pack years of nicotine abuse.", + ) + hpv_status: bool | None = Field( + default=None, + description="Whether the patient was infected with HPV.", + ) + neck_dissection: bool | None = Field( + description=( + "Whether the patient underwent neck dissection as part of their treatment." + ), + ) + tnm_edition: int = Field( + ge=6, + le=8, + default=8, + description="Edition of the TNM classification used for staging.", + ) + n_stage_prefix: Literal["c", "p"] | None = Field( + default=None, + description=( + "Prefix for the N stage, 'c' = clinical, 'p' = pathological. " + "This is used to distinguish between clinical and pathological staging." + ), + ) + n_stage: int = Field( + ge=-1, + le=3, + description=( + "N stage of the patient according to the TNM classification. 
The value -1 " + "is reserved for the NX stage, which means that the lymph nodes could not " + "be assessed for involvement." + ), + ) + n_stage_suffix: Literal["a", "b", "c"] | None = Field( + default=None, + description=( + "Suffix for the N-stage according to the TNM classification. " + "Can be 'a', 'b', or 'c'." + ), + ) + m_stage: int | None = Field( + default=None, + ge=-1, + le=1, + description=( + "M stage of the patient according to the TNM classification. The value -1 " + "is reserved for the MX stage, which technically doesn't exist, but it is " + "commonly used." + ), + ) + weight: float | None = Field( + default=None, + ge=0, + description="Weight of the patient in kg at the time of diagnosis.", + ) + + @field_validator( + "alcohol_abuse", + "nicotine_abuse", + "pack_years", + "hpv_status", + "n_stage_prefix", + "n_stage_suffix", + "neck_dissection", + "m_stage", + "weight", + mode="before", + ) + @classmethod + def nan_to_none(cls, value: Any) -> Any: + """Convert NaN values to None to avoid pydantic errors.""" + return None if pd.isna(value) else value + + @field_validator( + "sex", + "n_stage_prefix", + "n_stage_suffix", + mode="before", + ) + @classmethod + def to_lower(cls, value: Any) -> Any: + """Convert some string fields to lower case before validation.""" + if isinstance(value, str): + return value.lower() + + return value + + +class PatientRecord(BaseModel): + """A patient's record. + + Because the final dataset has a three-level header, this record holds only the + key ``core`` under which we store the actual patient information defined in the + :py:class:`PatientCore` model. + + Alongside ``core``, this may at some point hold additional or optional information + about the patient. + """ + + core: PatientCore = Field( + title="Core", + description="Core information about the patient.", + default_factory=PatientCore, + ) + + +class TumorCore(BaseModel): + """Information about the tumor of a patient. + + This information characterizes the primary tumor via its location, ICD-O-3 subsite, + T-category and so on. + """ + + location: str = Field(description="Primary tumor location.") + subsite: str = Field( + description="ICD-O-3 subsite of the primary tumor.", + pattern=r"C[0-9]{2}(\.[0-9X])?", + ) + central: bool | None = Field( + description="Whether the tumor is located on the mid-sagittal line.", + default=False, + ) + extension: bool | None = Field( + description="Whether the tumor extends over the mid-sagittal line.", + default=False, + ) + dist_to_midline: float | None = Field( + default=None, + ge=0, + description="Distance of the tumor to the mid-sagittal line in mm.", + ) + volume: float | None = Field( + default=None, + ge=0, + description="Estimated volume of the tumor in cmยณ.", + ) + t_stage_prefix: Literal["c", "p"] = Field( + default="c", + description="Prefix for the tumor stage, 'c' = clinical, 'p' = pathological.", + ) + t_stage: int = Field( + ge=-1, + le=4, + description=( + "T stage of the tumor according to the TNM classification. -1 is reserved " + "for the TX stage, meaning the presence of tumor could not be assessed." + ), + ) + t_stage_suffix: Literal["is", "a", "b"] | None = Field( + default=None, + description=( + "Suffix for the T-stage according to the TNM classification. " + "Can be 'a' or 'b'. The value 'is' is reserved for the Tis stage, in which " + "case the `t_stage` should be 0." 
+ ), + ) + side: Literal["left", "right", "central"] | None = Field( + default=None, + description="Side of the neck where the main tumor mass is located.", + ) + + @field_validator( + "central", + "extension", + "dist_to_midline", + "volume", + "t_stage_suffix", + "side", + mode="before", + ) + @classmethod + def nan_to_none(cls, value: Any) -> Any: + """Convert NaN values to None.""" + return None if pd.isna(value) else value + + @field_validator( + "location", + "t_stage_prefix", + "t_stage_suffix", + "side", + mode="before", + ) + @classmethod + def to_lower(cls, value: Any) -> Any: + """Convert string values to lower case.""" + if isinstance(value, str): + return value.lower() + + return value + + @model_validator(mode="after") + def check_tumor_side(self) -> TumorCore: + """Ensure tumor side information is consistent with ``central``.""" + if self.side == "central" and not self.central: + raise ValueError(f"{self.central=}, but {self.side=}.") + + return self + + @model_validator(mode="after") + def check_t_stage(self) -> TumorCore: + """Ensure T-category is valid.""" + if self.t_stage == -1 and self.t_stage_suffix is not None: + raise ValueError( + f"{self.t_stage_suffix=}, but should be `None`, since " + f"{self.t_stage=}, indicating TX stage.", + ) + + if self.t_stage_suffix == "is" and self.t_stage != 0: + raise ValueError( + f"T-stage 'Tis' is indicated by t_stage=0 and t_stage_suffix='is'. " + f"But got {self.t_stage=} and {self.t_stage_suffix=}.", + ) + + if self.t_stage_suffix in ["a", "b"] and self.t_stage not in [1, 2, 3, 4]: + raise ValueError( + f"T-stage suffix {self.t_stage_suffix=} is only valid for T-stages " + f"1, 2, 3, or 4, but got {self.t_stage=}.", + ) + + return self + + +class TumorRecord(BaseModel): + """A tumor record of a patient. + + As with the :py:class:`PatientRecord`, this holds only the key ``core`` under which + we store the actual tumor information defined in the :py:class:`TumorCore` model. + """ + + core: TumorCore = Field( + title="Core", + description="Core information about the tumor.", + default_factory=TumorCore, + ) + + +def create_lnl_field(lnl: str) -> tuple[type, Field]: + """Create a field for a specific lymph node level.""" + return ( + Annotated[bool | None, BeforeValidator(lambda v: None if pd.isna(v) else v)], + Field(default=None, description=f"LN {lnl} involvement"), + ) + + +class ModalityCore(BaseModel): + """Basic info about a diagnostic/pathological modality.""" + + date: Annotated[PastDate | None, BeforeValidator(convert_nat)] = Field( + description="Date of the diagnostic or pathological modality.", + default=None, + ) + + +UnilateralInvolvementInfo = create_model( + "UnilateralInvolvementInfo", + **{lnl: create_lnl_field(lnl) for lnl in _LNLS}, +) + + +class ModalityRecord(BaseModel): + """Involvement patterns of a diagnostic or pathological modality. + + This holds some basic information about the modality, which is currently limited to + the date its information was collected (e.g. the date of the PET/CT scan). + + Most importantly, this holds the ipsi- and contralateral lymph node level + involvement patterns under the respective keys ``ipsi`` and ``contra``. 
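+
+    A small validation sketch (the field values are invented for illustration):
+
+    .. code-block:: python
+
+        record = ModalityRecord.model_validate(
+            {"core": {"date": "2020-04-01"}, "ipsi": {"II": True, "IIa": None}}
+        )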
+ """ + + core: ModalityCore = Field( + title="Core", + default_factory=ModalityCore, + ) + ipsi: UnilateralInvolvementInfo = Field( + title="Ipsilateral Involvement", + description="Involvement patterns of the ipsilateral side.", + default_factory=UnilateralInvolvementInfo, + ) + contra: UnilateralInvolvementInfo = Field( + title="Contralateral Involvement", + description="Involvement patterns of the contralateral side.", + default_factory=UnilateralInvolvementInfo, + ) + + +def create_modality_field(modality: str) -> tuple[type, Field]: + """Create a field for a specific modality.""" + return ( + ModalityRecord, + Field( + title=modality, + description=f"Involvement patterns as observed using {modality}.", + default_factory=ModalityRecord, + ), + ) + + +class BaseRecord(BaseModel): + """A basic record of a patient. + + Contains at least the patient and tumor information in the same nested form + as the data represents it. + """ + + patient: PatientRecord = Field( + title="Patient", + description=( + "Characterizes the patient via demographic information and risk factors " + "associated with head and neck cancer. In order to achieve the three-level " + "header structure in the final table, there is a subkey `core` under which " + "the actual patient information is stored." + ), + default_factory=PatientRecord, + ) + tumor: TumorRecord = Field( + title="Tumor", + description=( + "Characterizes the primary tumor via its location, ICD-O-3 subsite, " + "T-category and so on. As with the patient record, this has a subkey " + "`core` under which the actual tumor information is stored." + ), + default_factory=TumorRecord, + ) + + +def create_full_record_model( + modalities: list[str], + model_name: str = "FullRecord", + **kwargs: dict[str, Any], +) -> type: + """Create a Pydantic model for a full record with all ``modalities``.""" + return create_model( + model_name, + __base__=BaseRecord, + **{mod: create_modality_field(mod) for mod in modalities}, + **kwargs, + ) + + +def _write_schema_to_file( + schema: type[BaseModel] | None = None, + file_path: Path = Path("schema.json"), +) -> None: + """Write the Pydantic schema to a file.""" + if schema is None: + modalities = get_default_modalities() + schema = create_full_record_model(modalities, model_name="Record") + + root_schema = RootModel[list[schema]] + + with open(file_path, "w") as f: + json_schema = root_schema.model_json_schema() + f.write(json.dumps(json_schema, indent=2)) + + logger.success(f"Schema written to {file_path}") + + +if __name__ == "__main__": + logger.enable("lydata") + logger.remove() + logger.add(sys.stderr, level="DEBUG") + _write_schema_to_file() diff --git a/src/lydata/types.py b/src/lydata/types.py new file mode 100644 index 0000000..15cd931 --- /dev/null +++ b/src/lydata/types.py @@ -0,0 +1,14 @@ +"""Protocol and type definitions for lydata package.""" + +from typing import Protocol, runtime_checkable + +import pandas as pd + + +@runtime_checkable +class CanExecute(Protocol): + """Protocol for objects that can :py:func:`execute` on a DataFrame.""" + + def execute(self, df: pd.DataFrame) -> pd.Series: + """Provide a binary mask for the ``df`` DataFrame.""" + ... 
diff --git a/src/lydata/utils.py b/src/lydata/utils.py index 487af71..9e8ef95 100644 --- a/src/lydata/utils.py +++ b/src/lydata/utils.py @@ -1,14 +1,17 @@ """Utility functions and classes.""" import os +import re from collections.abc import Callable from dataclasses import dataclass, field +from functools import cmp_to_key from typing import Any, Literal import pandas as pd from github import Auth from loguru import logger from pydantic import BaseModel, Field +from roman import fromRoman as roman_to_int # noqa: N813 def get_github_auth( @@ -16,7 +19,7 @@ def get_github_auth( user: str | None = None, password: str | None = None, ) -> Auth.Auth | None: - """Get the GitHub authentication object.""" + """Get the GitHub authentication object from arguments or environment variables.""" token = token or os.getenv("GITHUB_TOKEN") user = user or os.getenv("GITHUB_USER") password = password or os.getenv("GITHUB_PASSWORD") @@ -64,6 +67,16 @@ def update_and_expand( return result +def replace(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: + """Replace all columns in ``left`` with those from ``right``.""" + result = left.copy() + + for column in right.columns: + result[column] = right[column] + + return result + + @dataclass class _ColumnSpec: """Class for specifying column names and aggfuncs. @@ -156,6 +169,7 @@ def get_default_column_map_old() -> _ColumnMap: _ColumnSpec("m_stage", ("patient", "#", "m_stage")), _ColumnSpec("midext", ("tumor", "1", "extension")), _ColumnSpec("subsite", ("tumor", "1", "subsite")), + _ColumnSpec("location", ("tumor", "1", "location")), _ColumnSpec("volume", ("tumor", "1", "volume")), _ColumnSpec("central", ("tumor", "1", "central")), _ColumnSpec("side", ("tumor", "1", "side")), @@ -167,15 +181,15 @@ def _new_from_old(long_name: tuple[str, str, str]) -> tuple[str, str, str]: """Convert an old long key name to a new long key name. >>> _new_from_old(("patient", "#", "neck_dissection")) - ('patient', 'info', 'neck_dissection') + ('patient', 'core', 'neck_dissection') >>> _new_from_old(("tumor", "1", "t_stage")) - ('tumor', 'info', 't_stage') + ('tumor', 'core', 't_stage') >>> _new_from_old(("a", "b", "c")) ('a', 'b', 'c') """ start, middle, end = long_name if (start == "patient" and middle == "#") or (start == "tumor" and middle == "1"): - middle = "info" + middle = "core" return (start, middle, end) @@ -195,18 +209,18 @@ def get_default_column_map_new() -> _ColumnMap: >>> df = next(loader.load_datasets( ... institution="usz", ... repo_name="lycosystem/lydata.private", - ... ref="ce2ac255b8aec7443375b610e5254a46bf236a46", + ... ref="fb55afa26ff78afa78274a86b131fb3014d0ceea", ... )) >>> df.ly.surgery # doctest: +ELLIPSIS 0 False ... 286 False - Name: (patient, info, neck_dissection), Length: 287, dtype: bool + Name: (patient, core, neck_dissection), Length: 287, dtype: bool >>> df.ly.smoke # doctest: +ELLIPSIS 0 True ... 286 True - Name: (patient, info, nicotine_abuse), Length: 287, dtype: bool + Name: (patient, core, nicotine_abuse), Length: 287, dtype: bool """ return _ColumnMap.from_list( [ @@ -244,69 +258,118 @@ def get_default_modalities() -> dict[str, ModalityConfig]: } -def infer_all_levels( - dataset: pd.DataFrame, - infer_superlevels_kwargs: dict[str, Any] | None = None, - infer_sublevels_kwargs: dict[str, Any] | None = None, -) -> pd.DataFrame: - """Infer all levels of involvement for each diagnostic modality. 
+def _get_all_true(df: pd.DataFrame) -> pd.Series: + """Return a mask with all entries set to ``True``.""" + return pd.Series([True] * len(df)) - This function first infers sublevel (e.g. 'IIa', and 'IIb') involvement for each - modality using :py:meth:`~lydata.accessor.LyDataAccessor.infer_sublevels`. Then, - it infers superlevel (e.g. 'II') involvement for each modality using - :py:meth:`~lydata.accessor.LyDataAccessor.infer_superlevels`. + +def _get_numeral_with_sub_value(key: str) -> float: + """Get the value of a Roman numeral with an optional sublevel. + + >>> _get_numeral_with_sub_value("I") + 1.0 + >>> _get_numeral_with_sub_value("IIa") + 2.01 + >>> _get_numeral_with_sub_value("IXb") + 9.02 """ - infer_sublevels_kwargs = infer_sublevels_kwargs or {} - infer_superlevels_kwargs = infer_superlevels_kwargs or {} + match = re.match(r"([IVXLCDM]+)([a-z]?)", key) + if match is None: + raise ValueError(f"Invalid Roman numeral with sublevel: {key}") + numeral, sublvl = match.groups() - result = dataset.copy() + base = roman_to_int(numeral) + addition = 0.0 - result = update_and_expand( - left=result, - right=result.ly.infer_superlevels(**infer_superlevels_kwargs), - ) - return update_and_expand( - left=result, - right=result.ly.infer_sublevels(**infer_sublevels_kwargs), - ) + if len(sublvl) == 1: + addition = "abcdefghijklmnopqrstuvwxyz".index(sublvl) / 100.0 + 0.01 + + return base + addition + + +def _top_lvl_cmp(left: str, right: str) -> int: + """Compare two top-level column names.""" + if left == right: + return 0 + + if left == "patient": + return -1 + + if right == "patient": + return 1 + if left == "tumor": + return -1 -def infer_and_combine_levels( + if right == "tumor": + return 1 + + if left == "max_llh": + return -1 + + if right == "max_llh": + return 1 + + return (left > right) - (left < right) + + +def _mid_lvl_cmp(left: str, right: str) -> int: + """Compare two mid-level column names.""" + if left == right: + return 0 + + if left == "core": + return -1 + + if right == "core": + return 1 + + return (left > right) - (left < right) + + +def _lnl_cmp(left: str, right: str) -> int: + """Compare two roman numeral LNLs.""" + try: + left_value = _get_numeral_with_sub_value(left) + right_value = _get_numeral_with_sub_value(right) + return (left_value > right_value) - (left_value < right_value) + except ValueError: + if "id" in left: + return -1 + if "id" in right: + return 1 + + return (left > right) - (left < right) + + +def _sort_by( dataset: pd.DataFrame, - infer_superlevels_kwargs: dict[str, Any] | None = None, - infer_sublevels_kwargs: dict[str, Any] | None = None, - combine_kwargs: dict[str, Any] | None = None, + which: Literal["top", "mid", "lnl"], + level: int | None = None, ) -> pd.DataFrame: - """Enhance the dataset by inferring additional columns from the data. + """Sort the DataFrame columns by the specified level.""" + if level is None: + level = ["top", "mid", "lnl"].index(which) + + cmps = { + "top": _top_lvl_cmp, + "mid": _mid_lvl_cmp, + "lnl": _lnl_cmp, + } - This performs the following steps in order: + if which not in cmps: + raise ValueError(f"Invalid sorting level: {which} ('top', 'mid', or 'lnl').") - 1. Infer the superlevel involvement for each diagnostic modality using the - :py:meth:`~lydata.accessor.LyDataAccessor.infer_superlevels` method. - 2. Infer the sublevel involvement for each diagnostic modality using the - :py:meth:`~lydata.accessor.LyDataAccessor.infer_sublevels` method. This skips - all LNLs that were computed in the previous step. - 3. 
Compute the maximum likelihood estimate of the true state of the patient using - the :py:meth:`~lydata.accessor.LyDataAccessor.combine`. + if level < 0 or level > 2: + raise ValueError(f"Invalid level: {level} (must be 0, 1, or 2).") - .. important:: + columns = dataset.columns.get_level_values(level).unique() + sorted_columns = sorted(columns, key=cmp_to_key(cmps[which])) + return dataset.reindex(columns=sorted_columns, level=level) - Performing these operations in any other order may lead to the loss of some - information or even to conflicting LNL involvement information. - The result contains all LNLs of interest in the head and neck region, as well as - the best estimate of the true state of the patient under the top-level key - ``max_llh``. - """ - result = infer_all_levels( - dataset, - infer_superlevels_kwargs=infer_superlevels_kwargs, - infer_sublevels_kwargs=infer_sublevels_kwargs, - ) - combine_kwargs = combine_kwargs or {} - method = combine_kwargs.get("method", "max_llh") - max_llh = pd.concat( - {method: result.ly.combine(**combine_kwargs)}, - axis="columns", - ) - return result.join(max_llh) +def _sort_all(dataset: pd.DataFrame) -> pd.DataFrame: + """Use the custom sorting to sort the DataFrame columns by all levels.""" + dataset = _sort_by(dataset, "lnl", level=2) + dataset = _sort_by(dataset, "mid", level=1) + return _sort_by(dataset, "top", level=0) diff --git a/src/lydata/validator.py b/src/lydata/validator.py index ed46552..ea3ed78 100644 --- a/src/lydata/validator.py +++ b/src/lydata/validator.py @@ -1,177 +1,23 @@ -"""Module to transform to and validate the CSV schema of the lydata datasets. +"""Module to cast dtypes and to and validate the lyDATA datasets. -Here we define the function :py:func:`construct_schema` to dynamically create a -:py:class:`pandera.DataFrameSchema` that we can use to validate that a given -:py:class:`~pandas.DataFrame` conforms to the minimum requirements of the lyDATA -datasets. +The two main functions here are :py:func:`cast_dtypes` and :py:func:`is_valid`. The +first one can be used to cast the dtypes of the columns in a :py:class:`LyDataFrame` +to the expected types according to the schema constructed using +:py:func:`create_full_record_model`. -Currently, we only publish the :py:func:`validate_datasets` function that validates all -datasets that are found by the function :py:func:`~lydata.loader.available_datasets`. -In the future, we may want to make this more flexible. - -In this module, we also provide the :py:func:`transform_to_lyprox` function that can be -used to transform any raw data into the format that can be uploaded to the `LyProX`_ -platform database. - -.. _LyProX: https://lyprox.org +Subsequently, :py:func:`is_valid` can be used to validate every row in the table, again +using the constructed schema. 
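+
+A typical round trip might look like this (a sketch, where ``df`` is assumed to be a
+freshly loaded lyDATA table):
+
+.. code-block:: python
+
+    from lydata.validator import cast_dtypes, is_valid
+
+    df = cast_dtypes(df)
+    if is_valid(df, fail_on_error=False):
+        ...  # proceed with the correctly typed table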
""" -from typing import Any +import sys +from collections.abc import Mapping +from typing import Any, Literal -import pandas as pd from loguru import logger -from pandera import Check, Column, DataFrameSchema -from pandera.errors import SchemaError - -from lydata.loader import available_datasets - -_NULLABLE_OPTIONAL = {"required": False, "nullable": True} -_NULLABLE_OPTIONAL_BOOLEAN_COLUMN = Column( - dtype="boolean", - coerce=True, - **_NULLABLE_OPTIONAL, -) -_DATE_CHECK = Check.str_matches(r"^\d{4}-\d{2}-\d{2}$") -_LNLS = [ - "I", - "Ia", - "Ib", - "II", - "IIa", - "IIb", - "III", - "IV", - "V", - "Va", - "Vb", - "VI", - "VII", - "VIII", - "IX", - "X", -] - - -class ParsingError(Exception): - """Error while parsing the CSV file.""" - - -patient_columns = { - ("patient", "#", "institution"): Column(str), - ("patient", "#", "sex"): Column(str, Check.str_matches(r"^(male|female)$")), - ("patient", "#", "age"): Column(int), - ("patient", "#", "weight"): Column( - float, Check.greater_than(0), **_NULLABLE_OPTIONAL - ), - ("patient", "#", "diagnose_date"): Column(str, _DATE_CHECK), - ("patient", "#", "alcohol_abuse"): _NULLABLE_OPTIONAL_BOOLEAN_COLUMN, - ("patient", "#", "nicotine_abuse"): _NULLABLE_OPTIONAL_BOOLEAN_COLUMN, - ("patient", "#", "hpv_status"): _NULLABLE_OPTIONAL_BOOLEAN_COLUMN, - ("patient", "#", "neck_dissection"): _NULLABLE_OPTIONAL_BOOLEAN_COLUMN, - ("patient", "#", "tnm_edition"): Column(int, Check.in_range(7, 8)), - ("patient", "#", "n_stage"): Column(int, Check.in_range(0, 3)), - ("patient", "#", "m_stage"): Column(int, Check.in_range(-1, 1)), -} - -tumor_columns = { - ("tumor", "1", "subsite"): Column(str, Check.str_matches(r"^C\d{2}(\.\d)?$")), - ("tumor", "1", "t_stage"): Column(int, Check.in_range(0, 4)), - ("tumor", "1", "stage_prefix"): Column(str, Check.str_matches(r"^(p|c)$")), - ("tumor", "1", "volume"): Column( - float, Check.greater_than(0), **_NULLABLE_OPTIONAL - ), - ("tumor", "1", "central"): _NULLABLE_OPTIONAL_BOOLEAN_COLUMN, - ("tumor", "1", "extension"): _NULLABLE_OPTIONAL_BOOLEAN_COLUMN, -} - - -def get_modality_columns( - modality: str, - lnls: list[str] = _LNLS, -) -> dict[tuple[str, str, str], Column]: - """Get the validation columns for a given modality.""" - cols = {(modality, "info", "date"): Column(str, _DATE_CHECK, **_NULLABLE_OPTIONAL)} - - for side in ["ipsi", "contra"]: - for lnl in lnls: - cols[(modality, side, lnl)] = _NULLABLE_OPTIONAL_BOOLEAN_COLUMN - - return cols - - -def construct_schema( - modalities: list[str], - lnls: list[str] = _LNLS, -) -> DataFrameSchema: - """Construct a :py:class:`pandera.DataFrameSchema` for the lydata datasets.""" - schema = DataFrameSchema(patient_columns).add_columns(tumor_columns) - - for modality in modalities: - schema = schema.add_columns(get_modality_columns(modality, lnls)) - - return schema - - -def validate_datasets( - year: int | str = "*", - institution: str = "*", - subsite: str = "*", - use_github: bool = True, - repo: str = "lycosystem/lydata", - ref: str = "main", - **kwargs, -) -> None: - """Validate all lydata datasets. - - The arguments of this function are directly passed to the - :py:func:`available_datasets` function to determine which datasets to validate. - - Keyword arguments beyond the ones that :py:func:`available_datasets` accepts are - passed to the :py:meth:`~lydata.loader.Dataset.load` method of the - :py:class:`~lydata.loader.Dataset` instances. 
- """ - lydata_schema = construct_schema( - modalities=["pathology", "diagnostic_consensus", "PET", "CT", "FNA", "MRI"], - ) +from pydantic import BaseModel, Field, PastDate, ValidationError # noqa: F401 - for dataset in available_datasets( - year=year, - institution=institution, - subsite=subsite, - use_github=use_github, - repo_name=repo, - ref=ref, - ): - dataframe = dataset.get_dataframe(**kwargs) - try: - lydata_schema.validate(dataframe) - logger.info(f"Schema validation passed for {dataframe!r}.") - except SchemaError as schema_err: - message = f"Schema validation failed for {dataframe!r}." - logger.error(message, exc_info=schema_err) - raise Exception(message) from schema_err - - -def delete_private_keys(nested: dict) -> dict: - """Delete private keys from a nested dictionary. - - A 'private' key is a key whose name starts with an underscore. For example: - - >>> delete_private_keys({"patient": {"__doc__": "some patient info", "age": 61}}) - {'patient': {'age': 61}} - >>> delete_private_keys({"patient": {"age": 61}}) - {'patient': {'age': 61}} - """ - cleaned = {} - - if isinstance(nested, dict): - for key, value in nested.items(): - if not (isinstance(key, str) and key.startswith("_")): - cleaned[key] = delete_private_keys(value) - else: - cleaned = nested - - return cleaned +from lydata.accessor import LyDataAccessor, LyDataFrame # noqa: F401 +from lydata.schema import create_full_record_model def flatten( @@ -226,131 +72,140 @@ def unflatten(flat: dict) -> dict: return result -def get_depth( - nested_map: dict, - leaf_keys: set | None = None, -) -> int: - """Get the depth at which 'leaf' dicts sit in a nested dictionary. - - A leaf is a dictionary that contains any of the ``leaf_keys``. The default is - ``{"func", "default"}``. - - >>> nested_column_map = {"patient": {"age": {"func": int}}} - >>> get_depth(nested_column_map) - 2 - >>> flat_column_map = flatten(nested_column_map, max_depth=2) - >>> get_depth(flat_column_map) - 1 - >>> nested_column_map = {"patient": {"__doc__": "some patient info", "age": 61}} - >>> get_depth(nested_column_map) # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - ValueError: Leaf of nested map must be dict with any of ['default', 'func']. - """ - leaf_keys = leaf_keys or {"func", "default"} - - for _, value in nested_map.items(): - if not isinstance(value, dict): - raise ValueError( - f"Leaf of nested map must be dict with any of {sorted(leaf_keys)}." - ) - - is_leaf = not set(value.keys()).isdisjoint(leaf_keys) - return 1 if is_leaf else 1 + get_depth(value, leaf_keys) - - raise ValueError("Empty `nested_map`.") +def is_valid(dataset: LyDataFrame, fail_on_error: bool = True) -> bool: + """Validate the given dataset against the lyDATA schema. + Returns ``True`` if all records are valid, otherwise it either raises an error + (if ``fail_on_error`` is ``True``) or returns ``False``. + """ + modalities = dataset.ly.get_modalities() + FullRecord = create_full_record_model(modalities) # noqa: N806 + result = True -def transform_to_lyprox( - raw: pd.DataFrame, - column_map: dict[str | tuple, dict | Any], -) -> pd.DataFrame: - """Transform ``raw`` data into table that can be uploaded directly to `LyProX`_. - - To do so, it uses instructions in the ``colum_map`` dictionary, that needs to have - a particular structure: - - For each column in the final 'lyproxified' :py:class:`pd.DataFrame`, one entry must - exist in the ``column_map`` dictionary. 
E.g., for the column corresponding to a - patient's age, the dictionary should contain a key-value pair of this shape: - - .. code-block:: python - - column_map = { - ("patient", "#", "age"): { - "func": compute_age_from_raw, - "kwargs": {"randomize": False}, - "columns": ["birthday", "date of diagnosis"] - }, - } - - In this example, the function ``compute_age_from_raw`` is called with the - values of the columns ``"birthday"`` and ``"date of diagnosis"`` as positional - arguments, and the keyword argument ``"randomize"`` is set to ``False``. The - function then returns the patient's age, which is subsequently stored in the column - ``("patient", "#", "age")``. + for _i, row in dataset.iterrows(): + patient_id = row.patient.core.id + record = unflatten(row.to_dict()) - Alternatively, this dictionary can also have a nested, tree-like structure, like - this: + try: + _validated_record = FullRecord(**record) + logger.debug(f"Successful validation of {patient_id=}") + except ValidationError as e: + if fail_on_error: + raise ValueError(f"Validation error for {patient_id=}") from e + logger.error(f"{patient_id}: {e}") + result = False - .. code-block:: python + return result - column_map = { - "patient": { - "#": { - "age": { - "func": compute_age_from_raw, - "kwargs": {"randomize": False}, - "columns": ["birthday", "date of diagnosis"] - } - } - } - } - In this case it is imortant that all the leaf nodes, which are defined by having - either a ``"func"`` or a ``"default"`` key, are at the same depth. Because this - nested dictionary is flattened to look like the first example above. +def _get_field_annotations( + model: type[BaseModel], +) -> dict[str, Any]: + """Get the field annotations of a three-level nested Pydantic model. - .. _LyProX: https://lyprox.org + >>> class Foo(BaseModel): + ... bar: int = 3 + >>> class Baz(BaseModel): + ... foo: Foo = Field(default_factory=Foo) + >>> _get_field_annotations(Baz) + {'foo': {'bar': }} """ - column_map = delete_private_keys(column_map) - instruction_depth = get_depth(column_map) - - if instruction_depth > 1: - column_map = flatten(column_map, max_depth=instruction_depth) + annotations = {} + for field_name, field_info in model.model_fields.items(): + if issubclass(field_info.annotation, BaseModel): + annotations[field_name] = _get_field_annotations(field_info.annotation) + else: + annotations[field_name] = field_info.annotation + + return annotations + + +def _get_default_casters() -> Mapping[type, str]: + """Get the default dtype casters for the lyDATA schema.""" + return { + int: "Int64", + int | None: "Int64", + float: "Float64", + float | None: "Float64", + str: "string", + str | None: "string", + bool: "boolean", + bool | None: "boolean", + PastDate: "datetime64[ns]", + PastDate | None: "datetime64[ns]", + Literal["male", "female"]: "string", + Literal["c", "p"]: "string", + Literal["a", "b"] | None: "string", + Literal["a", "b", "c"] | None: "string", + Literal["left", "right"] | None: "string", + } + + +def cast_dtypes( + dataset: LyDataFrame, + casters: Mapping[type, str] | None = None, + fail_on_error: bool = True, +) -> LyDataFrame: + """Cast the dtypes of the ``dataset`` to the expected types. + + This function uses the annotations of the Pydantic schema to cast the individual + columns of the ``dataset`` to the expected types. It uses the ``casters`` mapping + to determine the type to cast to. By default, it uses the mapping from the + :py:func:`_get_default_casters` function. + + That way, pandas uses e.g. 
the nullable integer type ``Int64`` if we specify in + pydantic that a field can be an integer or None. If you want to use a different + mapping, you can pass it as the ``casters`` argument. + """ + dataset = dataset.convert_dtypes() - multi_idx = pd.MultiIndex.from_tuples(column_map.keys()) - processed = pd.DataFrame(columns=multi_idx) + if casters is None: + casters = _get_default_casters() - for multi_idx_col, instruction in column_map.items(): - if instruction == "": - continue + modalities = dataset.ly.get_modalities() + FullRecord = create_full_record_model(modalities) # noqa: N806 + annotations = _get_field_annotations(FullRecord) + annotations = flatten(annotations, max_depth=3) - if "default" in instruction: - processed[multi_idx_col] = [instruction["default"]] * len(raw) + for col in dataset.columns: + annotation = annotations.get(col, None) + old_type = dataset[col].dtype + new_type = casters.get(annotation, old_type) - elif "func" in instruction: - cols = instruction.get("columns", []) - kwargs = instruction.get("kwargs", {}) - func = instruction["func"] + if annotation is None: + logger.warning(f"No annotation found for {col=}. Using {old_type=}.") + continue - try: - processed[multi_idx_col] = [ - func(*vals, **kwargs) for vals in raw[cols].values - ] - except Exception as exc: - raise ParsingError( - f"Exception encountered while parsing column {multi_idx_col}" - ) from exc + if new_type == old_type: + logger.debug(f"Column {col=} already has expected {old_type=}. Skipping.") + continue - else: - raise ParsingError( - f"Column {multi_idx_col} has neither a `default` value nor `func` " - "describing how to fill this column." + try: + dataset = dataset.astype({col: new_type}) + logger.success(f"Cast {col=} from {old_type=} to {new_type=}.") + except TypeError as e: + msg = ( + f"Failed to cast column {col=} with ({annotation=}) to " + f"caster = `{new_type}." 
) + logger.error(msg) + if fail_on_error: + raise TypeError(msg) from e - return processed + return dataset if __name__ == "__main__": - validate_datasets() + from lydata import loader + + logger.enable("lydata") + logger.remove() + logger.add(sys.stderr, level="DEBUG") + dataset = next( + loader.load_datasets( + repo_name="lycosystem/lydata.private", + ref="e68141fd5440d4cfa6491df14ca2203ddb7946b0", + ) + ) + dataset = cast_dtypes(dataset) + print(f"{is_valid(dataset, fail_on_error=False)=}") # noqa: T201 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a8f7e60 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,50 @@ +"""Fixtures for testing lydata functionality.""" + +import pandas as pd +import pytest + +import lydata + + +@pytest.fixture(scope="session") +def clb_raw() -> pd.DataFrame: + """Load the CLB dataset.""" + return next( + lydata.load_datasets( + year=2021, + institution="clb", + subsite="oropharynx", + use_github=True, + repo_name="lycosystem/lydata.private", + ref="e68141fd5440d4cfa6491df14ca2203ddb7946b0", + cast=True, + ), + ) + + +@pytest.fixture(scope="session") +def usz_2021_df() -> pd.DataFrame: + """Load the USZ 2021 dataset.""" + return next( + lydata.load_datasets( + year=2021, + institution="usz", + repo_name="lycosystem/lydata.private", + ref="fb55afa26ff78afa78274a86b131fb3014d0ceea", + cast=True, + ) + ) + + +@pytest.fixture(scope="session") +def usz_2025_df() -> lydata.LyDataFrame: + """Fixture to load a sample DataFrame from the USZ 2025 dataset.""" + return next( + lydata.load_datasets( + year=2025, + institution="usz", + repo_name="lycosystem/lydata.private", + ref="c11011aa928fe43f18e73e42577a0fcee5652d99", + cast=True, + ) + ) diff --git a/tests/test_accessor.py b/tests/test_accessor.py new file mode 100644 index 0000000..c83ad77 --- /dev/null +++ b/tests/test_accessor.py @@ -0,0 +1,13 @@ +"""Test the ``.ly`` accessor for lyDATA DataFrames.""" + +import lydata # noqa: F401 + + +def test_enhance(usz_2021_df: lydata.LyDataFrame) -> None: + """Test the enhance method of the ly accessor.""" + enhanced = usz_2021_df.ly.enhance() + assert enhanced.shape == (287, 250) + assert "max_llh" in enhanced.columns + assert "Ia" in enhanced.max_llh.ipsi + assert "Ib" in enhanced.max_llh.ipsi + assert "I" in enhanced.max_llh.ipsi diff --git a/tests/test_augmentor.py b/tests/test_augmentor.py new file mode 100644 index 0000000..c7e77a3 --- /dev/null +++ b/tests/test_augmentor.py @@ -0,0 +1,116 @@ +"""Check that inferring sub- and super-levels works correctly.""" + +import pandas as pd + +import lydata # noqa: F401 +from lydata.augmentor import combine_and_augment_levels +from lydata.utils import ModalityConfig, get_default_modalities + + +def test_clb_patient_17(clb_raw: pd.DataFrame) -> None: + """Check the advanced combination and augmentation of diagnoses and levels.""" + modalities = get_default_modalities() + modalities = { + name: mod + for name, mod in modalities.items() + if name in clb_raw.columns.get_level_values(0) + } + clb_aug = combine_and_augment_levels( + diagnoses=[clb_raw[mod] for mod in modalities.keys()], + specificities=[mod.spec for mod in modalities.values()], + sensitivities=[mod.sens for mod in modalities.values()], + ) + assert len(clb_aug) == len(clb_raw), "Augmented data length mismatch" + assert clb_aug.iloc[16].ipsi.I == False + assert clb_aug.iloc[16].ipsi.Ia == False + assert clb_aug.iloc[16].ipsi.Ib == False + + +def test_2021_clb_001(clb_raw: pd.DataFrame) -> None: + """Check that this patient's `NaN` values
are handled correctly. + + In this patient, the sublvls are missing, therefore the superlvls should not be + overridden by the augmentor. + """ + idx = clb_raw.ly.id == "2021-CLB-001" + patient = clb_raw.loc[idx] + enhanced = patient.ly.enhance() + assert enhanced.iloc[0].pathology.ipsi.II == patient.iloc[0].pathology.ipsi.II + + +def test_2021_clb_017(clb_raw: pd.DataFrame) -> None: + """Check that this patient's `NaN` values are handled correctly. + + In this patient, pathology reports ipsi.Ib as healthy, while diagnostic consensus + reports ipsi.Ib as involved. This should correctly be combined to ipsi.Ib = False + and the superlvl should also be set to False. + """ + idx = clb_raw.ly.id == "2021-CLB-017" + patient = clb_raw.loc[idx] + enhanced = patient.ly.enhance() + assert len(patient) == len(enhanced) == 1, "Patient data length mismatch" + assert enhanced.iloc[0].max_llh.ipsi.I == False + assert enhanced.iloc[0].max_llh.ipsi.Ib == False + + +def test_2021_usz_009(usz_2021_df: pd.DataFrame) -> None: + """Check the advanced combination and augmentation of diagnoses and levels.""" + modalities = get_default_modalities() + modalities = { + name: mod + for name, mod in modalities.items() + if name in usz_2021_df.columns.get_level_values(0) + } + usz_aug = combine_and_augment_levels( + diagnoses=[usz_2021_df[mod] for mod in modalities.keys()], + specificities=[mod.spec for mod in modalities.values()], + sensitivities=[mod.sens for mod in modalities.values()], + ) + assert len(usz_aug) == len(usz_2021_df), "Augmented data length mismatch" + assert usz_aug.iloc[8].ipsi.III == False + + +def test_2025_usz_080(usz_2025_df: lydata.LyDataFrame) -> None: + """Check that this patient...""" + idx = usz_2025_df.ly.id == "2025-USZ-080" + patient = usz_2025_df.loc[idx] + enhanced = patient.ly.enhance() + assert enhanced.iloc[0].max_llh.ipsi.II == True + assert pd.isna(enhanced.iloc[0].max_llh.ipsi.IIa) + assert pd.isna(enhanced.iloc[0].max_llh.ipsi.IIb) + + +def test_2025_usz_312(usz_2025_df: lydata.LyDataFrame) -> None: + """Check that this patient...""" + idx = usz_2025_df.ly.id == "2025-USZ-312" + patient = usz_2025_df.loc[idx] + assert len(patient) == 1 + assert patient.ly.date.iloc[0].strftime("%Y-%m-%d") == "2013-06-03" + + enhanced = patient.ly.enhance() + assert len(enhanced) == 1 + assert enhanced.iloc[0].max_llh.ipsi.II == False + + +def test_2025_usz_075(usz_2025_df: lydata.LyDataFrame) -> None: + """Ensure patient 2025-USZ-075 is correctly enhanced. + + This patient has a pathologically (FNA) confirmed contra II involvement, but PET + and planning CT (pCT) are negative. Depending on the sensitivity and specificity + values, this leads to a max_llh of True or False for the contra II level. 
+ """ + idx = usz_2025_df.ly.id == "2025-USZ-075" + patient = usz_2025_df.loc[idx] + assert len(patient) == 1 + assert patient.ly.date.iloc[0].strftime("%Y-%m-%d") == "2015-11-23" + assert patient.FNA.contra.II.iloc[0] == True + + enhanced = patient.ly.enhance( + modalities={ + "PET": ModalityConfig(spec=0.86, sens=0.79), + "FNA": ModalityConfig(spec=0.98, sens=0.80, kind="pathological"), + # "pCT": ModalityConfig(spec=0.86, sens=0.81), + } + ) + assert len(enhanced) == 1 + assert enhanced.iloc[0].max_llh.contra.II == True diff --git a/tests/test_infer_levels.py b/tests/test_infer_levels.py deleted file mode 100644 index 80e38c5..0000000 --- a/tests/test_infer_levels.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Check that inferring sub- and super-levels works correctly.""" - -import pandas as pd -import pytest - -import lydata # noqa: F401 - - -@pytest.fixture -def mock_data() -> pd.DataFrame: - """Create a mock dataset for testing.""" - return pd.DataFrame({ - ("MRI", "ipsi", "Ia" ): [True , False, False, None, None ], - ("MRI", "ipsi", "Ib" ): [False, True , False, None, False], - ("MRI", "contra", "IIa"): [False, False, None , None, None ], - ("MRI", "contra", "IIb"): [False, True , True , None, False], - ("CT", "ipsi", "I" ): [True , False, False, None, None ], - }) - - -def test_infer_superlevels(mock_data: pd.DataFrame) -> None: - """Check that superlevels are inferred correctly.""" - inferred = mock_data.ly.infer_superlevels(modalities=["MRI"]) - - expected_ipsi_I = [True, True, False, None, None] - expected_contra_II = [False, True, True, None, None] - - for example in range(len(mock_data)): - assert ( - inferred.iloc[example].MRI.ipsi.I - == expected_ipsi_I[example] - ), f"{example = } mismatch for ipsi I" - assert ( - inferred.iloc[example].MRI.contra.II - == expected_contra_II[example] - ), f"{example = } mismatch for contra II" diff --git a/tests/test_installation.py b/tests/test_installation.py index d31192b..c776b73 100644 --- a/tests/test_installation.py +++ b/tests/test_installation.py @@ -1,6 +1,16 @@ """Simply ensure `lydata` is installed and pytest can proceed with doctests.""" +import os + + +def test_env_vars() -> None: + """Check that the .env file is loaded and the Github token is accessible.""" + token_env_var: str = os.environ.get("GITHUB_TOKEN", "nope") + assert "github" in token_env_var, "GITHUB_TOKEN env var not accessible" + + def test_is_installed() -> None: """Check that `lydata` can be imported (and is therefore installed).""" import lydata # noqa: F401 + assert True, "lydata is not installed or cannot be imported." 
# noqa: S101 diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..f4db1a8 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,106 @@ +"""Check that the pydantic schema for the lyDATA format works.""" + +import datetime +from typing import Any + +import pytest + +from lydata.schema import ( + BaseRecord, + PatientCore, + PatientRecord, + TumorCore, + TumorRecord, +) + + +@pytest.fixture +def patient_core_dict() -> dict[str, Any]: + """Fixture for a sample patient info.""" + return { + "id": "12345", + "institution": "Test Hospital", + "sex": "female", + "age": 42, + "diagnose_date": "2023-01-01", + "alcohol_abuse": False, + "nicotine_abuse": True, + "pack_years": 10.0, + "hpv_status": True, + "neck_dissection": True, + "tnm_edition": 8, + "n_stage": 1, + "m_stage": 0, + } + + +@pytest.fixture +def tumor_core_dict() -> dict[str, Any]: + """Fixture for a sample tumor info.""" + return { + "location": "gums", + "subsite": "C03.9", + "central": False, + "extension": True, + "t_stage_prefix": "c", + "t_stage": 2, + } + + +def test_patient_core(patient_core_dict: dict[str, Any]) -> None: + """Test the PatientCore schema.""" + patient_info = PatientCore(**patient_core_dict) + + for key, dict_value in patient_core_dict.items(): + model_value = getattr(patient_info, key) + if isinstance(model_value, datetime.date): + model_value = model_value.isoformat() + assert model_value == dict_value, f"Mismatch for {key}" + + +def test_tumor_core(tumor_core_dict: dict[str, Any]) -> None: + """Test the TumorCore schema.""" + tumor_core = TumorCore(**tumor_core_dict) + + for key, value in tumor_core_dict.items(): + assert getattr(tumor_core, key) == value, f"Mismatch for {key}" + + +@pytest.fixture +def patient_core(patient_core_dict: dict[str, Any]) -> PatientCore: + """Fixture for a sample PatientCore instance.""" + return PatientCore(**patient_core_dict) + + +@pytest.fixture +def tumor_core(tumor_core_dict: dict[str, Any]) -> TumorCore: + """Fixture for a sample TumorCore instance.""" + return TumorCore(**tumor_core_dict) + + +def test_patient_record(patient_core: PatientCore) -> None: + """Test the PatientRecord schema.""" + record = PatientRecord(core=patient_core) + + assert record.core == patient_core, "PatientRecord core does not match PatientCore" + + +def test_tumor_record(tumor_core: TumorCore) -> None: + """Test the TumorRecord schema.""" + record = TumorRecord(core=tumor_core) + + assert record.core == tumor_core, "TumorRecord core does not match TumorCore" + + +@pytest.fixture +def complete_record(patient_core: PatientCore, tumor_core: TumorCore) -> BaseRecord: + """Fixture for a sample BaseRecord instance.""" + return BaseRecord( + patient=PatientRecord(core=patient_core), + tumor=TumorRecord(core=tumor_core), + ) + + +def test_complete_record(complete_record: BaseRecord) -> None: + """Test the BaseRecord schema.""" + assert complete_record.patient.core.id == "12345", "Patient ID does not match" diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..760915e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,18 @@ +"""Test some of the utility functions in `lydata.utils`.""" + +import pandas as pd + +from lydata.utils import update_and_expand + + +def test_update_and_expand_using_p035(clb_raw: pd.DataFrame) -> None: + """Check the `update_and_expand` function with a specific patient.""" + idx = clb_raw.ly.id == "2021-CLB-017" + patient = clb_raw.loc[idx] + combined = patient.ly.combine() + combined = pd.concat({"test":
combined}, axis="columns") + augmented = combined.ly.augment(modality="test") + augmented = pd.concat({"test": augmented}, axis="columns") + result = update_and_expand(combined, augmented) + assert len(result) == 1 + assert not pd.isna(result.iloc[0].test.ipsi.I) diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..03fe144 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,15 @@ +"""Test the casting and validation of lydata datasets.""" + +import pandas as pd + +from lydata.validator import cast_dtypes + + +def test_casting(clb_raw: pd.DataFrame) -> None: + """Test the casting of a dataset.""" + clb_casted = cast_dtypes(clb_raw) + + assert clb_casted.patient.core.id.dtype == "string" + assert clb_casted.patient.core.age.dtype == "Int64" + assert clb_casted.patient.core.diagnose_date.dtype == "datetime64[ns]" + assert clb_casted.tumor.core.t_stage.dtype == "Int64" diff --git a/uv.lock b/uv.lock index 310505d..74ffcc5 100644 --- a/uv.lock +++ b/uv.lock @@ -473,6 +473,7 @@ dependencies = [ { name = "pandera" }, { name = "pydantic" }, { name = "pygithub" }, + { name = "roman" }, ] [package.optional-dependencies] @@ -491,6 +492,7 @@ docs = [ tests = [ { name = "pytest" }, { name = "pytest-cov" }, + { name = "python-dotenv" }, ] [package.metadata] @@ -507,6 +509,8 @@ requires-dist = [ { name = "pygithub" }, { name = "pytest", marker = "extra == 'tests'" }, { name = "pytest-cov", marker = "extra == 'tests'" }, + { name = "python-dotenv", marker = "extra == 'tests'", specifier = ">=1.1.1" }, + { name = "roman" }, { name = "sphinx", marker = "extra == 'docs'" }, { name = "sphinx-autodoc-typehints", marker = "extra == 'docs'" }, { name = "sphinx-book-theme", marker = "extra == 'docs'" }, @@ -1124,6 +1128,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, ] +[[package]] +name = "python-dotenv" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/b0/4bc07ccd3572a2f9df7e6782f52b0c6c90dcbb803ac4a167702d7d0dfe1e/python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab", size = 41978 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/ed/539768cf28c661b5b068d66d96a2f155c4971a5d55684a514c1a0e0dec2f/python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc", size = 20556 }, +] + [[package]] name = "pytz" version = "2025.2" @@ -1192,6 +1205,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847 }, ] +[[package]] +name = "roman" +version = "5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/86/8bdb59db4b7ea9a2bd93f8d25298981e09a4c9f4744cf4cbafa7ef6fee7b/roman-5.1.tar.gz", hash = "sha256:3a86572e9bc9183e771769601189e5fa32f1620ffeceebb9eca836affb409986", size = 8066 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/d0/27c9840ddaf331ace898c7f4aa1e1304a7acc22b844b5420fabb6d14c3a0/roman-5.1-py3-none-any.whl", hash = 
"sha256:bf595d8a9bc4a8e8b1dfa23e1d4def0251b03b494786df6b8c3d3f1635ce285a", size = 5825 }, +] + [[package]] name = "roman-numerals-py" version = "3.1.0"