diff --git a/.github/workflows/ci_linting.yml b/.github/workflows/ci_linting.yml index c7e57fa..f1b8982 100644 --- a/.github/workflows/ci_linting.yml +++ b/.github/workflows/ci_linting.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v5 - - name: Install extra dependencies for a python 3.7.17 install + - name: Install extra dependencies for a python install run: | sudo apt-get update sudo apt -y install --no-install-recommends liblzma-dev libbz2-dev libreadline-dev @@ -26,6 +26,9 @@ jobs: - name: reshim asdf run: asdf reshim + + - name: ensure poetry is using the desired python version + run: poetry env use $(asdf which python) - name: Cache Poetry virtualenv uses: actions/cache@v4 diff --git a/.github/workflows/ci_testing.yml b/.github/workflows/ci_testing.yml index 24fa5fc..f54eb67 100644 --- a/.github/workflows/ci_testing.yml +++ b/.github/workflows/ci_testing.yml @@ -13,11 +13,12 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v5 - - - name: Install extra dependencies for a python 3.7.17 install + + - name: Install extra dependencies for a python install run: | sudo apt-get update sudo apt -y install --no-install-recommends liblzma-dev libbz2-dev libreadline-dev + - name: Install asdf cli uses: asdf-vm/actions/setup@v4 @@ -26,6 +27,9 @@ jobs: - name: reshim asdf run: asdf reshim + + - name: ensure poetry is using the desired python version + run: poetry env use $(asdf which python) - name: Cache Poetry virtualenv uses: actions/cache@v4 @@ -42,7 +46,6 @@ jobs: - name: Run pytest and coverage run: | export JAVA_HOME=$(asdf where java) - echo "JAVA_HOME - $JAVA_HOME" make coverage - name: Upload Coverage Report @@ -54,5 +57,4 @@ jobs: - name: Run behave tests run: | export JAVA_HOME=$(asdf where java) - echo "JAVA_HOME - $JAVA_HOME" make behave \ No newline at end of file diff --git a/.mise.toml b/.mise.toml index f75aa25..b31659e 100644 --- a/.mise.toml +++ b/.mise.toml @@ -1,4 +1,4 @@ [tools] -python="3.7.17" -poetry="1.4.2" +python="3.11" +poetry="2.2" java="liberica-1.8.0" diff --git a/.tool-versions b/.tool-versions index 385d4e1..b23db8d 100644 --- a/.tool-versions +++ b/.tool-versions @@ -1,3 +1,3 @@ -python 3.7.17 -poetry 1.4.2 +python 3.11.14 +poetry 2.2.0 java liberica-1.8.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index 19f70c1..12c69f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ +## v0.2.0 (2025-11-12) + +### Refactor + +- ensure dve working on python 3.10 + - ensure dve working on python 3.11 + +### BREAKING CHANGE + +- Numerous typing updates that will make this codebase unusable below python 3.9 + +Note - this does not mean the package will work on python 3.9. The minimum working version is 3.10. 
+ +### Feat + +- added functionality to allow error messages in business rules t… (#8) + +### Refactor + +- bump pylint to work correctly with py3.11 and fix numerous linting issues + ## 0.1.0 (2025-11-10) +*NB - This was previously v1.0.0 and v1.1.0 but has been rolled back into a 0.1.0 release to reflect the lack of package stability.* + ### Feat - Added ability to define custom error codes and templated messages for data contract feedback messages diff --git a/Makefile b/Makefile index 7feca08..cfad520 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ activate = poetry run # dev install: poetry lock - poetry install --with dev,test + poetry install --with dev # dist wheel: @@ -27,6 +27,15 @@ coverage: $(activate) coverage report $(activate) coverage xml +# lint +pylint: + ${activate} pylint src/ + +mypy: + ${activate} mypy src/ + +lint: mypy pylint + # pre-commit pre-commit-all: ${activate} pre-commit run --all-files diff --git a/README.md b/README.md index 03b81ba..635df6a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Data Validation Engine -The Data Validation Engine (DVE) is a configuration driven data validation library built and utilised by NHS England. +The Data Validation Engine (DVE) is a configuration driven data validation library built and utilised by NHS England. Currently the package has been reverted from its v1.0.0 release to 0.x as we feel the package is not yet mature enough to be considered a 1.0.0 release. So please bear this in mind if reading through the commits and references to a v1+ release when on v0.x. As mentioned above, the DVE is "configuration driven" which means the majority of development for you as a user will be building a JSON document to describe how the data will be validated. The JSON document is known as a `dischema` file and example files can be accessed [here](./tests/testdata/). If you'd like to learn more about JSON document and how to build one from scratch, then please read the documentation [here](./docs/). @@ -21,7 +21,7 @@ Additionally, if you'd like to contribute a new backend implementation into the ## Installation and usage -The DVE is a Python package and can be installed using `pip`. As of release v0.1.0 we currently only supports Python 3.7, with Spark version 3.2.1 and DuckDB version of 1.1.0. We are currently working on upgrading the DVE to work on Python 3.11+ and this will be made available asap with version 1.0.0 release. +The DVE is a Python package and can be installed using `pip`. As of release v0.1.x we currently only support Python 3.7, with Spark version 3.2.1 and DuckDB version 1.1.0. We are currently working on upgrading the DVE to work on Python 3.10-3.11 and this will be made available with the v0.2.x release. In addition to a working Python 3.7+ installation you will need OpenJDK 11 installed if you're planning to use the Spark backend implementation. @@ -49,7 +49,7 @@ Below is a list of features that we would like to implement or have been request | Feature | Release Version | Released? 
| | ------- | --------------- | --------- | | Open source release | 0.1.0 | Yes | -| Uplift to Python 3.11 | 1.0.0 | No | +| Uplift to Python 3.11 | 0.2.0 | Yes | | Upgrade to Pydantic 2.0 | Not yet confirmed | No | | Create a more user friendly interface for building and modifying dischema files | Not yet confirmed | No | diff --git a/pyproject.toml b/pyproject.toml index eaeb708..31ce5dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "nhs_dve" -version = "0.1.0" +version = "0.2.0" description = "`nhs data validation engine` is a framework used to validate data" authors = ["NHS England "] readme = "README.md" @@ -9,58 +9,73 @@ packages = [ ] classifiers = [ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries", "Typing :: Typed", ] [tool.poetry.dependencies] -python = ">=3.7.2,<3.8" -boto3 = "1.28.47" # Boto3 will no longer support Python 3.7 starting December 13, 2023 -botocore = "1.31.47" -delta-spark = "1.1.0" +python = ">=3.10,<3.12" +boto3 = "1.34.162" +botocore = "1.34.162" +delta-spark = "2.4.0" duckdb = "1.1.0" # mitigates security vuln in < 1.1.0 formulas = "1.2.4" idna = "3.7" # Downstream dep of requests but has security vuln < 3.7 Jinja2 = "3.1.6" # mitigates security vuln in < 3.1.6 lxml = "4.9.1" openpyxl = "3.1.0" -pandas = "1.3.5" -polars = "0.17.14" -pyarrow = "7.0.0" +pandas = "2.2.2" +polars = "0.20.14" +pyarrow = "17.0.0" pydantic = "1.10.15" # Mitigates security vuln in < 1.10.13 pymongo = "4.6.3" -pyspark = "3.2.1" +pyspark = "3.4.4" pytz = "2022.1" -PyYAML = "5.4" -requests = "2.31.0" +PyYAML = "6.0.3" +requests = "2.32.4" # Mitigates security vuln in < 2.31.0 schedula = "1.2.19" sqlalchemy = "2.0.19" typing_extensions = "4.6.2" -urllib3 = "1.26.19" # Used transiently, but has security vuln < 1.26.19 +urllib3 = "2.5.0" # Mitigates security vuln in < 1.26.19 xmltodict = "0.13.0" +[tool.poetry.group.dev] +optional = true +include-groups = [ + "test", + "lint" +] + [tool.poetry.group.dev.dependencies] -commitizen = "3.9.1" # latest version to support Python 3.7.17 -pre-commit = "2.21.0" # latest version to support Python 3.7.17 +commitizen = "4.9.1" +pre-commit = "4.3.0" + +[tool.poetry.group.test] +optional = true [tool.poetry.group.test.dependencies] faker = "18.11.1" -behave = "1.2.6" -coverage = "6.4.3" -moto = {extras = ["s3"], version = "3.1.18"} +behave = "1.3.3" +coverage = "7.11.0" +moto = {extras = ["s3"], version = "4.0.13"} +Werkzeug = "3.0.6" # Dependency of moto which needs 3.0.6 for security vuln mitigation mongomock = "4.1.2" -pytest = "7.4.4" -pytest-lazy-fixture = "0.6.3" +pytest = "8.4.2" +pytest-lazy-fixtures = "1.4.0" # switched from https://github.com/TvoroG/pytest-lazy-fixture as it's no longer supported xlsx2csv = "0.8.2" +[tool.poetry.group.lint] +optional = true + [tool.poetry.group.lint.dependencies] -black = "22.6.0" -astroid = "2.11.7" +black = "24.3.0" +astroid = "2.14.2" isort = "5.11.5" -pylint = "2.14.5" -mypy = "0.982" +pylint = "2.16.4" +mypy = "0.991" boto3-stubs = {extras = ["essential"], version = "1.26.72"} botocore-stubs = "1.29.72" pandas-stubs = "1.2.0.62" @@ -112,9 +127,8 @@ source_pkgs = [ show_missing = true [tool.pylint] -# Can't add support for custom checker until running on Python 3.9+ again. 
-# init-hook = "import sys; sys.path.append('./pylint_checkers')" -# load-plugins = "check_typing_imports" +init-hook = "import sys; sys.path.append('./pylint_checkers')" +load-plugins = "check_typing_imports" [tool.pylint.main] # Analyse import fallback blocks. This can be used to support both Python 2 and 3 @@ -189,7 +203,7 @@ persistent = true # Minimum Python version to use for version dependent checks. Will default to the # version used to run pylint. -py-version = "3.7" +py-version = "3.11" # Discover python modules and packages in the file system subtree. # recursive = diff --git a/src/dve/core_engine/backends/base/auditing.py b/src/dve/core_engine/backends/base/auditing.py index 9ec6f97..37b1774 100644 --- a/src/dve/core_engine/backends/base/auditing.py +++ b/src/dve/core_engine/backends/base/auditing.py @@ -5,27 +5,14 @@ import threading import time from abc import abstractmethod +from collections.abc import Callable, Iterable from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from datetime import date, datetime, timedelta from multiprocessing import Queue as ProcessQueue from queue import Queue as ThreadQueue from types import TracebackType -from typing import ( - Any, - Callable, - ClassVar, - Dict, - Generic, - Iterable, - List, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, ClassVar, Generic, Optional, TypeVar, Union from pydantic import ValidationError, validate_arguments from typing_extensions import Literal, get_origin @@ -55,7 +42,7 @@ class FilterCriteria: field: str comparison_value: Any operator_: Callable = operator.eq - operator_mapping: ClassVar[Dict[BinaryComparator, str]] = { + operator_mapping: ClassVar[dict[BinaryComparator, str]] = { operator.eq: "=", operator.ne: "!=", operator.lt: "<", @@ -102,12 +89,12 @@ class BaseAuditor(Generic[AuditReturnType]): """Base auditor object - defines structure for implementations to use in conjunction with AuditingManager""" - def __init__(self, name: str, record_type: Type[AuditRecord]): + def __init__(self, name: str, record_type: type[AuditRecord]): self._name = name self._record_type = record_type @property - def schema(self) -> Dict[str, type]: + def schema(self) -> dict[str, type]: """Determine python schema of auditor""" return { fld: str if get_origin(mdl.type_) == Literal else mdl.type_ @@ -135,27 +122,27 @@ def conv_to_records(self, recs: AuditReturnType) -> Iterable[AuditRecord]: raise NotImplementedError() @abstractmethod - def conv_to_entity(self, recs: List[AuditRecord]) -> AuditReturnType: + def conv_to_entity(self, recs: list[AuditRecord]) -> AuditReturnType: """Convert the list of pydantic models to an entity for use in pipelines""" raise NotImplementedError() @abstractmethod - def add_records(self, records: Iterable[Dict[str, Any]]): + def add_records(self, records: Iterable[dict[str, Any]]): """Add audit records to the Auditor""" raise NotImplementedError() @abstractmethod def retrieve_records( - self, filter_criteria: List[FilterCriteria], data: Optional[AuditReturnType] = None + self, filter_criteria: list[FilterCriteria], data: Optional[AuditReturnType] = None ) -> AuditReturnType: """Retrieve audit records from the Auditor""" raise NotImplementedError() def get_most_recent_records( self, - order_criteria: List[OrderCriteria], - partition_fields: Optional[List[str]] = None, - pre_filter_criteria: Optional[List[FilterCriteria]] = None, + order_criteria: list[OrderCriteria], + partition_fields: Optional[list[str]] = None, + 
pre_filter_criteria: Optional[list[FilterCriteria]] = None, ) -> AuditReturnType: """Retrieve the most recent records, defined by the ordering criteria for each partition combination""" @@ -203,12 +190,12 @@ def combine_auditor_information( raise NotImplementedError() @staticmethod - def conv_to_iterable(recs: Union[AuditorType, AuditReturnType]) -> Iterable[Dict[str, Any]]: + def conv_to_iterable(recs: Union[AuditorType, AuditReturnType]) -> Iterable[dict[str, Any]]: """Convert AuditReturnType to iterable of dictionaries""" raise NotImplementedError() @validate_arguments - def add_processing_records(self, processing_records: List[ProcessingStatusRecord]): + def add_processing_records(self, processing_records: list[ProcessingStatusRecord]): """Add an entry to the processing_status auditor.""" if self.pool: return self._submit( @@ -220,7 +207,7 @@ def add_processing_records(self, processing_records: List[ProcessingStatusRecord ) @validate_arguments - def add_submission_statistics_records(self, sub_stats: List[SubmissionStatisticsRecord]): + def add_submission_statistics_records(self, sub_stats: list[SubmissionStatisticsRecord]): """Add an entry to the submission statistics auditor.""" if self.pool: return self._submit( @@ -230,7 +217,7 @@ def add_submission_statistics_records(self, sub_stats: List[SubmissionStatistics return self._submission_statistics.add_records(records=[dict(rec) for rec in sub_stats]) @validate_arguments - def add_transfer_records(self, transfer_records: List[TransferRecord]): + def add_transfer_records(self, transfer_records: list[TransferRecord]): """Add an entry to the transfers auditor""" if self.pool: return self._submit( @@ -241,7 +228,7 @@ def add_transfer_records(self, transfer_records: List[TransferRecord]): @validate_arguments def add_new_submissions( self, - submissions: List[SubmissionMetadata], + submissions: list[SubmissionMetadata], job_run_id: Optional[int] = None, ): """Add an entry to the submission_info auditor.""" @@ -250,8 +237,8 @@ def add_new_submissions( time_now: datetime = datetime.now() ts_info = {"time_updated": time_now, "date_updated": time_now.date()} - processing_status_recs: List[Dict[str, Any]] = [] - submission_info_recs: List[Dict[str, Any]] = [] + processing_status_recs: list[dict[str, Any]] = [] + submission_info_recs: list[dict[str, Any]] = [] for sub_info in submissions: # add processing_record - add time info @@ -311,7 +298,7 @@ def is_writing(self) -> bool: return not self.queue.empty() or locked - def mark_transform(self, submission_ids: List[str], **kwargs): + def mark_transform(self, submission_ids: list[str], **kwargs): """Update submission processing_status to file_transformation.""" recs = [ @@ -323,7 +310,7 @@ def mark_transform(self, submission_ids: List[str], **kwargs): return self.add_processing_records(recs) - def mark_data_contract(self, submission_ids: List[str], **kwargs): + def mark_data_contract(self, submission_ids: list[str], **kwargs): """Update submission processing_status to data_contract.""" recs = [ @@ -335,7 +322,7 @@ def mark_data_contract(self, submission_ids: List[str], **kwargs): return self.add_processing_records(recs) - def mark_business_rules(self, submissions: List[Tuple[str, bool]], **kwargs): + def mark_business_rules(self, submissions: list[tuple[str, bool]], **kwargs): """Update submission processing_status to business_rules.""" recs = [ @@ -352,11 +339,11 @@ def mark_business_rules(self, submissions: List[Tuple[str, bool]], **kwargs): def mark_error_report( self, - submissions: 
List[Tuple[str, SubmissionResult]], + submissions: list[tuple[str, SubmissionResult]], job_run_id: Optional[int] = None, ): """Mark the given submission as being ready for error report""" - processing_recs: List[ProcessingStatusRecord] = [] + processing_recs: list[ProcessingStatusRecord] = [] sub_id: str sub_result: str @@ -373,7 +360,7 @@ def mark_error_report( return self.add_processing_records(processing_recs) - def mark_finished(self, submissions: List[Tuple[str, SubmissionResult]], **kwargs): + def mark_finished(self, submissions: list[tuple[str, SubmissionResult]], **kwargs): """Update submission processing_status to finished.""" recs = [ @@ -388,7 +375,7 @@ def mark_finished(self, submissions: List[Tuple[str, SubmissionResult]], **kwarg return self.add_processing_records(recs) - def mark_failed(self, submissions: List[str], **kwargs): + def mark_failed(self, submissions: list[str], **kwargs): """Update submission processing_status to failed.""" recs = [ ProcessingStatusRecord( @@ -399,7 +386,7 @@ def mark_failed(self, submissions: List[str], **kwargs): return self.add_processing_records(recs) - def mark_archived(self, submissions: List[str], **kwargs): + def mark_archived(self, submissions: list[str], **kwargs): """Update submission processing_status to archived.""" recs = [ ProcessingStatusRecord( @@ -410,7 +397,7 @@ def mark_archived(self, submissions: List[str], **kwargs): return self.add_processing_records(recs) - def add_feedback_transfer_ids(self, submissions: List[Tuple[str, str]], **kwargs): + def add_feedback_transfer_ids(self, submissions: list[tuple[str, str]], **kwargs): """Adds transfer_id for error report to submission""" recs = [ TransferRecord( @@ -425,7 +412,7 @@ def add_feedback_transfer_ids(self, submissions: List[Tuple[str, str]], **kwargs return self.add_transfer_records(recs) def get_latest_processing_records( - self, filter_criteria: Optional[List[FilterCriteria]] = None + self, filter_criteria: Optional[list[FilterCriteria]] = None ) -> AuditReturnType: """Get the most recent processing record for each submission_id stored in the processing_status auditor""" @@ -441,10 +428,10 @@ def downstream_pending( max_concurrency: int = 1, run_number: int = 0, max_days_old: int = 3, - statuses_to_include: Optional[List[ProcessingStatus]] = None, + statuses_to_include: Optional[list[ProcessingStatus]] = None, ) -> bool: """Checks if there are any downstream submissions currently pending""" - steps: List[ProcessingStatus] = [ + steps: list[ProcessingStatus] = [ "received", "file_transformation", "data_contract", @@ -452,7 +439,7 @@ def downstream_pending( "error_report", ] - downstream: Set[ProcessingStatus] + downstream: set[ProcessingStatus] if statuses_to_include: downstream = {status, *statuses_to_include} else: @@ -519,7 +506,7 @@ def __enter__(self): def __exit__( self, - exc_type: Optional[Type[Exception]], + exc_type: Optional[type[Exception]], exc_value: Optional[Exception], traceback: Optional[TracebackType], ) -> None: @@ -532,7 +519,7 @@ def __exit__( def _get_status( self, - status: Union[ProcessingStatus, Set[ProcessingStatus], List[ProcessingStatus]], + status: Union[ProcessingStatus, set[ProcessingStatus], list[ProcessingStatus]], max_days_old: int, ) -> AuditReturnType: _filter = [ @@ -572,8 +559,8 @@ def get_all_error_report_submissions(self, max_days_old: int = 3): self.combine_auditor_information(subs, self._submission_info) ) - processed: List[SubmissionInfo] = [] - dodgy_info: List[Tuple[Dict, str]] = [] + processed: list[SubmissionInfo] = [] + 
dodgy_info: list[tuple[dict, str]] = [] for sub_info in sub_infos: try: diff --git a/src/dve/core_engine/backends/base/backend.py b/src/dve/core_engine/backends/base/backend.py index 883bc8c..bed2e17 100644 --- a/src/dve/core_engine/backends/base/backend.py +++ b/src/dve/core_engine/backends/base/backend.py @@ -3,7 +3,8 @@ import logging import warnings from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, Generic, Mapping, MutableMapping, Optional, Tuple, Type +from collections.abc import Mapping, MutableMapping +from typing import Any, ClassVar, Generic, Optional from pyspark.sql import DataFrame, SparkSession @@ -28,7 +29,7 @@ class BaseBackend(Generic[EntityType], ABC): """A complete implementation of a backend.""" - __entity_type__: ClassVar[Type[EntityType]] # type: ignore + __entity_type__: ClassVar[type[EntityType]] # type: ignore """ The entity type used within the backend. @@ -45,7 +46,7 @@ def __init__( # pylint: disable=unused-argument self, contract: BaseDataContract[EntityType], steps: BaseStepImplementations[EntityType], - reference_data_loader_type: Optional[Type[BaseRefDataLoader[EntityType]]], + reference_data_loader_type: Optional[type[BaseRefDataLoader[EntityType]]], logger: Optional[logging.Logger] = None, **kwargs: Any, ) -> None: @@ -76,7 +77,7 @@ def __init__( # pylint: disable=unused-argument def load_reference_data( self, - reference_entity_config: Dict[EntityName, ReferenceConfigUnion], + reference_entity_config: dict[EntityName, ReferenceConfigUnion], submission_info: Optional[SubmissionInfo], ) -> Mapping[EntityName, EntityType]: """Load the reference data as specified in the reference entity config.""" @@ -115,7 +116,7 @@ def write_entities_to_parquet( def convert_entities_to_spark( self, entities: Entities, cache_prefix: URI, _emit_deprecation_warning: bool = True - ) -> Dict[EntityName, DataFrame]: + ) -> dict[EntityName, DataFrame]: """Convert entities to Spark DataFrames. Entities may be omitted if they are blank, because Spark cannot create an @@ -151,7 +152,7 @@ def apply( contract_metadata: DataContractMetadata, rule_metadata: RuleMetadata, submission_info: Optional[SubmissionInfo] = None, - ) -> Tuple[Entities, Messages, StageSuccessful]: + ) -> tuple[Entities, Messages, StageSuccessful]: """Apply the data contract and the rules, returning the entities and all generated messages. @@ -184,7 +185,7 @@ def process( rule_metadata: RuleMetadata, cache_prefix: URI, submission_info: Optional[SubmissionInfo] = None, - ) -> Tuple[MutableMapping[EntityName, URI], Messages]: + ) -> tuple[MutableMapping[EntityName, URI], Messages]: """Apply the data contract and the rules, write the entities out to parquet and returning the entity locations and all generated messages. @@ -205,7 +206,7 @@ def process_legacy( rule_metadata: RuleMetadata, cache_prefix: URI, submission_info: Optional[SubmissionInfo] = None, - ) -> Tuple[MutableMapping[EntityName, DataFrame], Messages]: + ) -> tuple[MutableMapping[EntityName, DataFrame], Messages]: """Apply the data contract and the rules, create Spark `DataFrame`s from the entities and return the Spark entities and all generated messages. 
diff --git a/src/dve/core_engine/backends/base/contract.py b/src/dve/core_engine/backends/base/contract.py index 65aa776..338bd9f 100644 --- a/src/dve/core_engine/backends/base/contract.py +++ b/src/dve/core_engine/backends/base/contract.py @@ -2,8 +2,9 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator from inspect import ismethod -from typing import Any, ClassVar, Dict, Generic, Iterable, Iterator, Optional, Tuple, Type, TypeVar +from typing import Any, ClassVar, Generic, Optional, TypeVar from pydantic import BaseModel from typing_extensions import Protocol @@ -30,7 +31,7 @@ from dve.parser.type_hints import Extension T = TypeVar("T") -ExtensionConfig = Dict[Extension, "ReaderConfig"] +ExtensionConfig = dict[Extension, "ReaderConfig"] """Configuration options for file extensions.""" _READER_OVERRIDE_ATTR_NAME = "_implements_reader_for" """The name of the reader override function's reader override attribute.""" @@ -54,12 +55,11 @@ def __call__( # pylint: disable=bad-staticmethod-argument reader: BaseFileReader, resource: URI, entity_name: EntityName, - schema: Type[BaseModel], - ) -> T: - ... + schema: type[BaseModel], + ) -> T: ... -def reader_override(reader_type: Type[BaseFileReader]) -> WrapDecorator: +def reader_override(reader_type: type[BaseFileReader]) -> WrapDecorator: """A decorator function which wraps a `ReaderProtocol` method to add support for custom reader overrides. @@ -79,7 +79,7 @@ def reader_impl_decorator(func: ArbitraryFunction) -> ArbitraryFunction: class BaseDataContract(Generic[EntityType], ABC): """The base implementation of a data contract.""" - __entity_type__: ClassVar[Type[EntityType]] # type: ignore + __entity_type__: ClassVar[type[EntityType]] # type: ignore """ The entity type that should be requested from a reader without a specific implementation. @@ -87,7 +87,7 @@ class BaseDataContract(Generic[EntityType], ABC): This will be populated from the generic annotation at class creation time. """ - __reader_overrides__: ClassVar[Dict[Type[BaseFileReader], _UnboundReaderOverride[EntityType]]] = {} # type: ignore # pylint: disable=line-too-long + __reader_overrides__: ClassVar[dict[type[BaseFileReader], _UnboundReaderOverride[EntityType]]] = {} # type: ignore # pylint: disable=line-too-long """ A dictionary mapping implemented reader types to override functions which provide a 'local' implementation of the reader. These can provide a more optimised version @@ -134,7 +134,7 @@ def __init__( # pylint: disable=unused-argument @abstractmethod def create_entity_from_py_iterator( - self, entity_name: EntityName, records: Iterator[Dict[str, Any]], schema: Type[BaseModel] + self, entity_name: EntityName, records: Iterator[dict[str, Any]], schema: type[BaseModel] ) -> EntityType: """A fallback function to be used where no entity type specific reader implemenattions are available. @@ -146,7 +146,7 @@ def read_entity_from_py_iterator( reader: BaseFileReader, resource: URI, entity_name: EntityName, - schema: Type[BaseModel], + schema: type[BaseModel], ) -> EntityType: """A fallback function for readers that should read records with the 'read_to_py_iterator' implementation and create an entity of the correct @@ -165,7 +165,7 @@ def read_entity( reader: BaseFileReader, resource: URI, entity_name: EntityName, - schema: Type[BaseModel], + schema: type[BaseModel], ) -> EntityType: """Read an entity using the provided reader class. 
@@ -305,7 +305,7 @@ def _ensure_entity_locations_have_read_support( def read_raw_entities( self, entity_locations: EntityLocations, contract_metadata: DataContractMetadata - ) -> Tuple[Entities, Messages, StageSuccessful]: + ) -> tuple[Entities, Messages, StageSuccessful]: """Read the raw entities from the entity locations using the configured readers. These will not yet have had the data contracts applied. @@ -361,7 +361,7 @@ def read_raw_entities( @abstractmethod def apply_data_contract( self, entities: Entities, contract_metadata: DataContractMetadata - ) -> Tuple[Entities, Messages, StageSuccessful]: + ) -> tuple[Entities, Messages, StageSuccessful]: """Apply the data contract to the raw entities, returning the validated entities and any messages. @@ -372,7 +372,7 @@ def apply_data_contract( def apply( self, entity_locations: EntityLocations, contract_metadata: DataContractMetadata - ) -> Tuple[Entities, Messages, StageSuccessful]: + ) -> tuple[Entities, Messages, StageSuccessful]: """Read the entities from the provided locations according to the data contract, and return the validated entities and any messages. diff --git a/src/dve/core_engine/backends/base/core.py b/src/dve/core_engine/backends/base/core.py index 286252d..e04ecde 100644 --- a/src/dve/core_engine/backends/base/core.py +++ b/src/dve/core_engine/backends/base/core.py @@ -1,6 +1,7 @@ """Core functionality for the backend bases.""" -from typing import Any, Callable, Generic, Iterator, Mapping, MutableMapping, Optional, Tuple, Type +from collections.abc import Callable, Iterator, Mapping, MutableMapping +from typing import Any, Generic, Optional from typing_extensions import get_args, get_origin @@ -8,13 +9,13 @@ from dve.core_engine.backends.types import EntityType from dve.core_engine.type_hints import EntityName -get_original_bases: Callable[[type], Tuple[Any, ...]] +get_original_bases: Callable[[type], tuple[Any, ...]] try: # pylint: disable=ungrouped-imports from typing_extensions import get_original_bases # type: ignore except ImportError: - def get_original_bases(__cls: type) -> Tuple[Any, ...]: + def get_original_bases(__cls: type) -> tuple[Any, ...]: """A basic version of 'get_original_bases' in case it's not in typing extensions.""" try: return __cls.__orig_bases__ # type: ignore @@ -28,7 +29,7 @@ def get_original_bases(__cls: type) -> Tuple[Any, ...]: return __cls.__mro__ -def get_entity_type(child: Type, annotated_type_name: str) -> Type[EntityType]: +def get_entity_type(child: type, annotated_type_name: str) -> type[EntityType]: """Get the annotated entity type from a subclass, given the name of the parent class which must be annotated. 
@@ -80,7 +81,7 @@ def __init__( """The reference data mapping.""" @staticmethod - def _get_key_and_whether_refdata(key: str) -> Tuple[EntityName, IsRefdata]: + def _get_key_and_whether_refdata(key: str) -> tuple[EntityName, IsRefdata]: """Get the key and whether the entity is a reference data entry.""" if key.startswith("refdata_"): return key[8:], True diff --git a/src/dve/core_engine/backends/base/reader.py b/src/dve/core_engine/backends/base/reader.py index 3aafdb6..9862e7e 100644 --- a/src/dve/core_engine/backends/base/reader.py +++ b/src/dve/core_engine/backends/base/reader.py @@ -1,8 +1,9 @@ """Abstract implementation of the file parser.""" from abc import ABC, abstractmethod +from collections.abc import Iterator from inspect import ismethod -from typing import Any, ClassVar, Dict, Iterator, Optional, Type, TypeVar +from typing import Any, ClassVar, Optional, TypeVar from pydantic import BaseModel from typing_extensions import Protocol @@ -15,7 +16,7 @@ ET_co = TypeVar("ET_co", covariant=True) # This needs to be defined outside the class since otherwise mypy expects # BaseFileReader to be generic: -_ReadFunctions = Dict[Type[T], "_UnboundReadFunction[T]"] +_ReadFunctions = dict[type[T], "_UnboundReadFunction[T]"] """A convenience type indicating a mapping from type to reader function.""" _ENTITY_TYPE_ATTR_NAME = "_read_func_entity_type" """The name of the read function's entity type annotation attribute.""" @@ -29,9 +30,8 @@ def __call__( # pylint: disable=bad-staticmethod-argument self: "BaseFileReader", # This is the protocol for an _unbound_ method. resource: URI, entity_name: EntityName, - schema: Type[BaseModel], - ) -> ET_co: - ... + schema: type[BaseModel], + ) -> ET_co: ... def read_function(entity_type: T) -> WrapDecorator: @@ -87,8 +87,8 @@ def read_to_py_iterator( self, resource: URI, entity_name: EntityName, - schema: Type[BaseModel], - ) -> Iterator[Dict[str, Any]]: + schema: type[BaseModel], + ) -> Iterator[dict[str, Any]]: """Iterate through the contents of the resource, yielding dicts representing each record. @@ -101,10 +101,10 @@ def read_to_py_iterator( def read_to_entity_type( self, - entity_type: Type[EntityType], + entity_type: type[EntityType], resource: URI, entity_name: EntityName, - schema: Type[BaseModel], + schema: type[BaseModel], ) -> EntityType: """Read to the specified entity type, if supported. @@ -113,7 +113,7 @@ def read_to_entity_type( data contract. """ - if entity_name == Iterator[Dict[str, Any]]: + if entity_name == Iterator[dict[str, Any]]: return self.read_to_py_iterator(resource, entity_name, schema) # type: ignore try: @@ -127,7 +127,7 @@ def write_parquet( self, entity: EntityType, target_location: URI, - schema: Optional[Type[BaseModel]] = None, + schema: Optional[type[BaseModel]] = None, **kwargs, ) -> URI: """Write entity to parquet. 
diff --git a/src/dve/core_engine/backends/base/reference_data.py b/src/dve/core_engine/backends/base/reference_data.py index 88006ab..a9a68fa 100644 --- a/src/dve/core_engine/backends/base/reference_data.py +++ b/src/dve/core_engine/backends/base/reference_data.py @@ -1,18 +1,8 @@ """The base implementation of the reference data loader..""" from abc import ABC, abstractmethod -from typing import ( - Callable, - ClassVar, - Dict, - Generic, - Iterator, - Mapping, - Optional, - Type, - Union, - get_type_hints, -) +from collections.abc import Callable, Iterator, Mapping +from typing import ClassVar, Generic, Optional, Union, get_type_hints from pydantic import BaseModel, Field from typing_extensions import Annotated, Literal @@ -69,14 +59,14 @@ class ReferenceURI(BaseModel, frozen=True): class BaseRefDataLoader(Generic[EntityType], Mapping[EntityName, EntityType], ABC): """A reference data mapper which lazy-loads requested entities.""" - __entity_type__: ClassVar[Type[EntityType]] # type: ignore + __entity_type__: ClassVar[type[EntityType]] # type: ignore """ The entity type used for the reference data. This will be populated from the generic annotation at class creation time. """ - __step_functions__: ClassVar[Dict[Type[ReferenceConfig], Callable]] = {} + __step_functions__: ClassVar[dict[type[ReferenceConfig], Callable]] = {} """ A mapping between refdata config types and functions to call to load these configs into reference data entities @@ -112,7 +102,7 @@ class variable for the subclass. # pylint: disable=unused-argument def __init__( - self, reference_entity_config: Dict[EntityName, ReferenceConfig], **kwargs + self, reference_entity_config: dict[EntityName, ReferenceConfig], **kwargs ) -> None: self.reference_entity_config = reference_entity_config """ @@ -121,7 +111,7 @@ def __init__( some backends, and table names for others). """ - self.entity_cache: Dict[EntityName, EntityType] = {} + self.entity_cache: dict[EntityName, EntityType] = {} """A cache for already-loaded entities.""" @abstractmethod diff --git a/src/dve/core_engine/backends/base/rules.py b/src/dve/core_engine/backends/base/rules.py index 5c859cd..043f826 100644 --- a/src/dve/core_engine/backends/base/rules.py +++ b/src/dve/core_engine/backends/base/rules.py @@ -3,19 +3,8 @@ import logging from abc import ABC, abstractmethod from collections import defaultdict -from typing import ( - Any, - ClassVar, - Dict, - Generic, - Iterable, - List, - NoReturn, - Optional, - Tuple, - Type, - TypeVar, -) +from collections.abc import Iterable +from typing import Any, ClassVar, Generic, NoReturn, Optional, TypeVar from uuid import uuid4 from typing_extensions import Literal, Protocol, get_type_hints @@ -55,7 +44,7 @@ T = TypeVar("T", bound=AbstractStep) # This needs to be defined outside the class since otherwise mypy expects # BaseFileReader to be generic: -_StepFunctions = Dict[Type[T], "_UnboundStepFunction[T]"] +_StepFunctions = dict[type[T], "_UnboundStepFunction[T]"] """A convenience type indicating a mapping from config type to step method.""" Stage = Literal["Pre-filter", "Filter", "Post-filter"] """The name of a stage within a rule.""" @@ -70,8 +59,7 @@ def __call__( # pylint: disable=bad-staticmethod-argument entities: Entities, *, config: T_contra, - ) -> Messages: - ... + ) -> Messages: ... 
class BaseStepImplementations(Generic[EntityType], ABC): # pylint: disable=too-many-public-methods @@ -87,7 +75,7 @@ class BaseStepImplementations(Generic[EntityType], ABC): # pylint: disable=too- """ - __entity_type__: ClassVar[Type[EntityType]] # type: ignore + __entity_type__: ClassVar[type[EntityType]] # type: ignore """ The entity type that the steps are implemented for. @@ -148,7 +136,7 @@ def drop_row_id(entity: EntityType) -> EntityType: @classmethod def _raise_notimplemented_error( - cls, config_type: Type[AbstractStep], source: Exception + cls, config_type: type[AbstractStep], source: Exception ) -> NoReturn: """Raise a `NotImplementedError` from a provided error.""" raise NotImplementedError( @@ -175,7 +163,7 @@ def _handle_rule_error(self, error: Exception, config: AbstractStep) -> Messages """Log an error and create appropriate error messages.""" return render_error(error, self._step_metadata_to_location(config)) - def evaluate(self, entities, *, config: AbstractStep) -> Tuple[Messages, StageSuccessful]: + def evaluate(self, entities, *, config: AbstractStep) -> tuple[Messages, StageSuccessful]: """Evaluate a step definition, applying it to the entities.""" config_type = type(config) success = True @@ -357,7 +345,7 @@ def notify(self, entities: Entities, *, config: Notification) -> Messages: def apply_sync_filters( self, entities: Entities, *filters: DeferredFilter - ) -> Tuple[Messages, StageSuccessful]: + ) -> tuple[Messages, StageSuccessful]: """Apply the synchronised filters, emitting appropriate error messages for any records which do not meet the conditions. @@ -366,7 +354,7 @@ def apply_sync_filters( for that entity. """ - filters_by_entity: Dict[EntityName, List[DeferredFilter]] = defaultdict(list) + filters_by_entity: dict[EntityName, list[DeferredFilter]] = defaultdict(list) for rule in filters: filters_by_entity[rule.entity_name].append(rule) @@ -374,7 +362,7 @@ def apply_sync_filters( for entity_name, filter_rules in filters_by_entity.items(): entity = entities[entity_name] - filter_column_names: List[str] = [] + filter_column_names: list[str] = [] unmodified_entities = {entity_name: entity} modified_entities = {entity_name: entity} @@ -468,7 +456,7 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag altering the entities in-place. """ - rules_and_locals: Iterable[Tuple[Rule, TemplateVariables]] + rules_and_locals: Iterable[tuple[Rule, TemplateVariables]] if rule_metadata.templating_strategy == "upfront": rules_and_locals = [] for rule, local_variables in rule_metadata: diff --git a/src/dve/core_engine/backends/base/utilities.py b/src/dve/core_engine/backends/base/utilities.py index 3dc52f0..30efc74 100644 --- a/src/dve/core_engine/backends/base/utilities.py +++ b/src/dve/core_engine/backends/base/utilities.py @@ -2,7 +2,8 @@ import warnings from collections import deque -from typing import Deque, Optional, Sequence +from collections.abc import Sequence +from typing import Optional from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import ExpressionArray, MultiExpression @@ -32,7 +33,7 @@ def _split_multiexpr_string(expressions: MultiExpression) -> ExpressionArray: string_opened_with: Optional[str] = None # ...and need to ensure we don't split on commas inside brackets, so have to # keep track of what brackets are open and closed (except those in strings). 
- bracket_stack: Deque[str] = deque() + bracket_stack: deque[str] = deque() expression_list, slice_start = [], 0 for slice_end, char in enumerate(expressions): diff --git a/src/dve/core_engine/backends/exceptions.py b/src/dve/core_engine/backends/exceptions.py index e56a8b3..279f4ce 100644 --- a/src/dve/core_engine/backends/exceptions.py +++ b/src/dve/core_engine/backends/exceptions.py @@ -2,7 +2,8 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional, Set, Tuple +from collections.abc import Mapping +from typing import Any, Optional from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import EntityName, ErrorCategory, ErrorLocation, Messages @@ -208,9 +209,9 @@ class SchemaMismatch(ReaderErrorMixin, ValueError): def __init__( self, *args: object, - missing_fields: Optional[Set[FieldName]] = None, - extra_fields: Optional[Set[FieldName]] = None, - wrong_types: Optional[Mapping[FieldName, Tuple[ActualFieldType, ExpectedFieldType]]] = None, + missing_fields: Optional[set[FieldName]] = None, + extra_fields: Optional[set[FieldName]] = None, + wrong_types: Optional[Mapping[FieldName, tuple[ActualFieldType, ExpectedFieldType]]] = None, ): self.missing_fields = missing_fields or set() """Fields that are missing from the expected schema.""" diff --git a/src/dve/core_engine/backends/implementations/duckdb/auditing.py b/src/dve/core_engine/backends/implementations/duckdb/auditing.py index 7119548..3124c6d 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/auditing.py +++ b/src/dve/core_engine/backends/implementations/duckdb/auditing.py @@ -1,5 +1,7 @@ """Auditing definitions for duckdb backend""" -from typing import Any, Dict, Iterable, List, Optional, Type, Union + +from collections.abc import Iterable +from typing import Any, Optional, Union import polars as pl from duckdb import ColumnExpression, DuckDBPyConnection, DuckDBPyRelation, StarExpression, connect @@ -31,7 +33,7 @@ class DDBAuditor(BaseAuditor[DuckDBPyRelation]): def __init__( self, - record_type: Type[AuditRecord], + record_type: type[AuditRecord], database_uri: URI, name: str, connection: Optional[DuckDBPyConnection] = None, @@ -66,7 +68,7 @@ def ddb_create_table_sql(self) -> str: return _sql_expression @property - def polars_schema(self) -> Dict[str, PolarsType]: + def polars_schema(self) -> dict[str, PolarsType]: """Get polars dataframe schema for auditor""" return { fld: PYTHON_TYPE_TO_POLARS_TYPE.get(dtype, pl.Utf8) # type: ignore @@ -77,7 +79,7 @@ def get_relation(self) -> DuckDBPyRelation: """Get a relation to interact with the auditor duckdb table""" return self._connection.table(self._name) - def combine_filters(self, filter_criteria: List[FilterCriteria]) -> str: + def combine_filters(self, filter_criteria: list[FilterCriteria]) -> str: """Combine multiple filters to apply""" return " AND ".join([self.normalise_filter(filt) for filt in filter_criteria]) @@ -100,7 +102,7 @@ def conv_to_records(self, recs: DuckDBPyRelation) -> Iterable[AuditRecord]: """Convert the relation to an iterable of the related audit record""" return (self._record_type(**rec) for rec in recs.pl().iter_rows(named=True)) - def conv_to_entity(self, recs: List[AuditRecord]) -> DuckDBPyRelation: + def conv_to_entity(self, recs: list[AuditRecord]) -> DuckDBPyRelation: """Convert a list of audit records to a relation""" # pylint: disable=W0612 rec_df = pl.DataFrame( # type: ignore @@ -109,7 +111,7 @@ def conv_to_entity(self, recs: List[AuditRecord]) -> DuckDBPyRelation: ) 
return self._connection.sql("select * from rec_df") - def add_records(self, records: Iterable[Dict[str, Any]]) -> None: + def add_records(self, records: Iterable[dict[str, Any]]) -> None: """Add records to the underlying duckdb table""" # pylint: disable=W0612 data_pl_df = pl.DataFrame( # type: ignore @@ -124,7 +126,7 @@ def add_records(self, records: Iterable[Dict[str, Any]]) -> None: def retrieve_records( self, - filter_criteria: Optional[List[FilterCriteria]] = None, + filter_criteria: Optional[list[FilterCriteria]] = None, data: Optional[DuckDBPyRelation] = None, ) -> DuckDBPyRelation: """Get records from the underlying duckdb table""" @@ -135,9 +137,9 @@ def retrieve_records( def get_most_recent_records( self, - order_criteria: List[OrderCriteria], - partition_fields: Optional[List[str]] = None, - pre_filter_criteria: Optional[List[FilterCriteria]] = None, + order_criteria: list[OrderCriteria], + partition_fields: Optional[list[str]] = None, + pre_filter_criteria: Optional[list[FilterCriteria]] = None, ) -> DuckDBPyRelation: """Get most recent records, based on the order and partitioning, from the underlying duckdb table""" @@ -227,6 +229,6 @@ def combine_auditor_information( ) @staticmethod - def conv_to_iterable(recs: Union[DDBAuditor, DuckDBPyRelation]) -> Iterable[Dict[str, Any]]: + def conv_to_iterable(recs: Union[DDBAuditor, DuckDBPyRelation]) -> Iterable[dict[str, Any]]: recs_rel: DuckDBPyRelation = recs.get_relation() if isinstance(recs, DDBAuditor) else recs return recs_rel.pl().iter_rows(named=True) diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 97ae258..5113da5 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -2,7 +2,8 @@ # pylint: disable=R0903 import logging -from typing import Any, Dict, Iterator, List, Optional, Tuple, Type +from collections.abc import Iterator +from typing import Any, Optional from uuid import uuid4 import pandas as pd @@ -36,7 +37,7 @@ class PandasApplyHelper: def __init__(self, row_validator: RowValidator): self.row_validator = row_validator - self.errors: List[FeedbackMessage] = [] + self.errors: list[FeedbackMessage] = [] def __call__(self, row: pd.Series): self.errors.extend(self.row_validator(row.to_dict())[1]) # type: ignore @@ -76,10 +77,10 @@ def _cache_records(self, relation: DuckDBPyRelation, cache_prefix: URI) -> URI: return chunk_uri def create_entity_from_py_iterator( # pylint: disable=unused-argument - self, entity_name: URI, records: Iterator[Dict[URI, Any]], schema: Type[BaseModel] + self, entity_name: URI, records: Iterator[dict[URI, Any]], schema: type[BaseModel] ) -> DuckDBPyRelation: """Create DuckDB Relation from iterator of records""" - polars_schema: Dict[str, PolarsType] = { + polars_schema: dict[str, PolarsType] = { fld.name: get_polars_type_from_annotation(fld.type_) for fld in stringify_model(schema).__fields__.values() } @@ -99,7 +100,7 @@ def generate_ddb_cast_statement( def apply_data_contract( self, entities: DuckDBEntities, contract_metadata: DataContractMetadata - ) -> Tuple[DuckDBEntities, Messages, StageSuccessful]: + ) -> tuple[DuckDBEntities, Messages, StageSuccessful]: """Apply the data contract to the duckdb relations""" self.logger.info("Applying data contracts") all_messages: Messages = [] @@ -107,12 +108,12 @@ def apply_data_contract( successful = True for entity_name, relation in entities.items(): # get dtypes 
for all fields -> python data types or use with relation - entity_fields: Dict[str, ModelField] = contract_metadata.schemas[entity_name].__fields__ - ddb_schema: Dict[str, DuckDBPyType] = { + entity_fields: dict[str, ModelField] = contract_metadata.schemas[entity_name].__fields__ + ddb_schema: dict[str, DuckDBPyType] = { fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in entity_fields.values() } - polars_schema: Dict[str, PolarsType] = { + polars_schema: dict[str, PolarsType] = { fld.name: get_polars_type_from_annotation(fld.annotation) for fld in entity_fields.values() } @@ -133,9 +134,11 @@ def apply_data_contract( all_messages.extend(application_helper.errors) casting_statements = [ - self.generate_ddb_cast_statement(column, dtype) - if column in relation.columns - else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + ( + self.generate_ddb_cast_statement(column, dtype) + if column in relation.columns + else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + ) for column, dtype in ddb_schema.items() ] try: diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index a8656cc..ea1901e 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -2,11 +2,12 @@ # ignore: type[attr-defined] """Helper objects for duckdb data contract implementation""" +from collections.abc import Generator, Iterator from dataclasses import is_dataclass from datetime import date, datetime from decimal import Decimal from pathlib import Path -from typing import Any, ClassVar, Dict, Generator, Iterator, Set, Union +from typing import Any, ClassVar, Union from urllib.parse import urlparse import duckdb.typing as ddbtyp @@ -58,7 +59,7 @@ class DDBStruct: TYPE_TEXT = "STRUCT" - def __init__(self, sub_elements: Dict[str, DuckDBPyType]): + def __init__(self, sub_elements: dict[str, DuckDBPyType]): self._sub_elements = {**sub_elements} def add_element(self, field_name: str, data_type: DuckDBPyType): @@ -77,7 +78,7 @@ def __call__(self): return self.__str__() -PYTHON_TYPE_TO_DUCKDB_TYPE: Dict[type, DuckDBPyType] = { +PYTHON_TYPE_TO_DUCKDB_TYPE: dict[type, DuckDBPyType] = { str: ddbtyp.VARCHAR, int: ddbtyp.BIGINT, bool: ddbtyp.BOOLEAN, @@ -157,7 +158,7 @@ def get_duckdb_type_from_annotation(type_annotation: Any) -> DuckDBPyType: # Type hint is a `pydantic` model. or (type_origin is None and issubclass(type_annotation, BaseModel)) ): - fields: Dict[str, DuckDBPyType] = {} + fields: dict[str, DuckDBPyType] = {} for field_name, field_annotation in get_type_hints(type_annotation).items(): # Technically non-string keys are disallowed, but people are bad. if not isinstance(field_name, str): @@ -181,7 +182,7 @@ def get_duckdb_type_from_annotation(type_annotation: Any) -> DuckDBPyType: f"List must have type annotation (e.g. 
`List[str]`), got {type_annotation!r}" ) if type_annotation is dict or type_origin is dict: - raise ValueError(f"Dict must be `typing.TypedDict` subclass, got {type_annotation!r}") + raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") for type_ in type_annotation.mro(): duck_type = PYTHON_TYPE_TO_DUCKDB_TYPE.get(type_) @@ -224,7 +225,7 @@ def _ddb_read_parquet( def _ddb_write_parquet( # pylint: disable=unused-argument - self, entity: Union[Iterator[Dict[str, Any]], DuckDBPyRelation], target_location: URI, **kwargs + self, entity: Union[Iterator[dict[str, Any]], DuckDBPyRelation], target_location: URI, **kwargs ) -> URI: """Method to write parquet files from type cast entities following data contract application @@ -265,7 +266,7 @@ def duckdb_get_entity_count(cls): return cls -def get_all_registered_udfs(connection: DuckDBPyConnection) -> Set[str]: +def get_all_registered_udfs(connection: DuckDBPyConnection) -> set[str]: """Function to supply the names of a registered functions stored in the supplied duckdb connection. Creates the temp table used to store registered functions (if not exists). """ diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py index 3a7cb83..df43348 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py +++ b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py @@ -1,7 +1,8 @@ """A csv reader to create duckdb relations""" # pylint: disable=arguments-differ -from typing import Any, Dict, Iterator, Optional, Type +from collections.abc import Iterator +from typing import Any, Optional import duckdb as ddb import polars as pl @@ -42,26 +43,26 @@ def __init__( super().__init__() def read_to_py_iterator( - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] - ) -> Iterator[Dict[str, Any]]: + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] + ) -> Iterator[dict[str, Any]]: """Creates an iterable object of rows as dictionaries""" yield from self.read_to_relation(resource, entity_name, schema).pl().iter_rows(named=True) @read_function(DuckDBPyRelation) def read_to_relation( # pylint: disable=unused-argument - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] ) -> DuckDBPyRelation: """Returns a relation object from the source csv""" if get_content_length(resource) == 0: raise EmptyFileError(f"File at {resource} is empty.") - reader_options: Dict[str, Any] = { + reader_options: dict[str, Any] = { "header": self.header, "delimiter": self.delim, "quotechar": self.quotechar, } - ddb_schema: Dict[str, SQLType] = { + ddb_schema: dict[str, SQLType] = { fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) # type: ignore for fld in schema.__fields__.values() } @@ -80,13 +81,13 @@ class PolarsToDuckDBCSVReader(DuckDBCSVReader): @read_function(DuckDBPyRelation) def read_to_relation( # pylint: disable=unused-argument - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] ) -> DuckDBPyRelation: """Returns a relation object from the source csv""" if get_content_length(resource) == 0: raise EmptyFileError(f"File at {resource} is empty.") - reader_options: Dict[str, Any] = { + reader_options: dict[str, Any] = { "has_header": self.header, "separator": self.delim, "quote_char": self.quotechar, @@ -131,7 
+132,7 @@ class DuckDBCSVRepeatingHeaderReader(PolarsToDuckDBCSVReader): @read_function(DuckDBPyRelation) def read_to_relation( # pylint: disable=unused-argument - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] ) -> DuckDBPyRelation: entity = super().read_to_relation(resource=resource, entity_name=entity_name, schema=schema) entity = entity.distinct() diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/json.py b/src/dve/core_engine/backends/implementations/duckdb/readers/json.py index f8f3a77..a706f11 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/readers/json.py +++ b/src/dve/core_engine/backends/implementations/duckdb/readers/json.py @@ -1,7 +1,8 @@ """A csv reader to create duckdb relations""" # pylint: disable=arguments-differ -from typing import Any, Dict, Iterator, Optional, Type +from collections.abc import Iterator +from typing import Any, Optional from duckdb import DuckDBPyRelation, read_json from pydantic import BaseModel @@ -25,18 +26,18 @@ def __init__(self, json_format: Optional[str] = "array"): super().__init__() def read_to_py_iterator( - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] - ) -> Iterator[Dict[str, Any]]: + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] + ) -> Iterator[dict[str, Any]]: """Creates an iterable object of rows as dictionaries""" return self.read_to_relation(resource, entity_name, schema).pl().iter_rows(named=True) @read_function(DuckDBPyRelation) def read_to_relation( # pylint: disable=unused-argument - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] ) -> DuckDBPyRelation: """Returns a relation object from the source json""" - ddb_schema: Dict[str, SQLType] = { + ddb_schema: dict[str, SQLType] = { fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) # type: ignore for fld in schema.__fields__.values() } diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/xml.py b/src/dve/core_engine/backends/implementations/duckdb/readers/xml.py index af1147f..32b5d1d 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/readers/xml.py +++ b/src/dve/core_engine/backends/implementations/duckdb/readers/xml.py @@ -1,7 +1,7 @@ # mypy: disable-error-code="attr-defined" """An xml reader to create duckdb relations""" -from typing import Dict, Optional, Type +from typing import Optional import polars as pl from duckdb import DuckDBPyConnection, DuckDBPyRelation, default_connection @@ -23,9 +23,9 @@ def __init__(self, ddb_connection: Optional[DuckDBPyConnection] = None, **kwargs super().__init__(**kwargs) @read_function(DuckDBPyRelation) - def read_to_relation(self, resource: URI, entity_name: str, schema: Type[BaseModel]): + def read_to_relation(self, resource: URI, entity_name: str, schema: type[BaseModel]): """Returns a relation object from the source xml""" - polars_schema: Dict[str, pl.DataType] = { # type: ignore + polars_schema: dict[str, pl.DataType] = { # type: ignore fld.name: get_polars_type_from_annotation(fld.annotation) for fld in stringify_model(schema).__fields__.values() } diff --git a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py index 98e80bb..5df3f6a 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py +++ 
b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py @@ -1,6 +1,6 @@ """A reference data loader for duckdb.""" -from typing import Dict, Optional +from typing import Optional from duckdb import DuckDBPyConnection, DuckDBPyRelation @@ -32,7 +32,7 @@ class DuckDBRefDataLoader(BaseRefDataLoader[DuckDBPyRelation]): def __init__( self, - reference_entity_config: Dict[EntityName, ReferenceConfigUnion], + reference_entity_config: dict[EntityName, ReferenceConfigUnion], **kwargs, ) -> None: super().__init__(reference_entity_config, **kwargs) diff --git a/src/dve/core_engine/backends/implementations/duckdb/rules.py b/src/dve/core_engine/backends/implementations/duckdb/rules.py index f03b65c..dbc308e 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/rules.py +++ b/src/dve/core_engine/backends/implementations/duckdb/rules.py @@ -1,6 +1,7 @@ """Business rule definitions for duckdb backend""" -from typing import Callable, Dict, Set, Tuple, get_type_hints +from collections.abc import Callable +from typing import get_type_hints from uuid import uuid4 from duckdb import ( @@ -76,14 +77,14 @@ def register_udfs( # type: ignore cls, connection: DuckDBPyConnection, **kwargs ): # pylint: disable=arguments-differ """Method to register all custom dve functions for use during business rules application""" - _registered_functions: Set[str] = get_all_registered_udfs(connection) - _available_functions: Dict[str, Callable] = { + _registered_functions: set[str] = get_all_registered_udfs(connection) + _available_functions: dict[str, Callable] = { func_name: func for func_name, func in vars(functions).items() if callable(func) and func.__module__ == "dve.core_engine.functions.implementations" } - _unregistered_functions: Set[str] = set(_available_functions).difference( + _unregistered_functions: set[str] = set(_available_functions).difference( _registered_functions ) @@ -145,7 +146,7 @@ def select(self, entities: DuckDBEntities, *, config: SelectColumns) -> Messages def group_by(self, entities: DuckDBEntities, *, config: Aggregation) -> Messages: """A transformation step which performs an aggregation on an entity.""" - def _add_cnst_field(rel: DuckDBPyRelation) -> Tuple[str, DuckDBPyRelation]: + def _add_cnst_field(rel: DuckDBPyRelation) -> tuple[str, DuckDBPyRelation]: """Add a constant field for use as an index to allow for pivoting with no group""" fld_name = f"fld_{uuid4().hex[0:8]}" return fld_name, rel.select( @@ -234,7 +235,7 @@ def _resolve_join_name_conflicts( def _perform_join( self, entities: DuckDBEntities, config: AbstractConditionalJoin - ) -> Tuple[Source, Target, Joined]: + ) -> tuple[Source, Target, Joined]: """Perform a conditional join between source and target, returning the source, target and joined DataFrames. 
diff --git a/src/dve/core_engine/backends/implementations/duckdb/types.py b/src/dve/core_engine/backends/implementations/duckdb/types.py index 21e2615..a6820d5 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/types.py +++ b/src/dve/core_engine/backends/implementations/duckdb/types.py @@ -1,7 +1,7 @@ """Types used in Spark implementations.""" # pylint: disable=C0103 -from typing import MutableMapping +from collections.abc import MutableMapping from duckdb import DuckDBPyRelation from typing_extensions import Literal diff --git a/src/dve/core_engine/backends/implementations/duckdb/utilities.py b/src/dve/core_engine/backends/implementations/duckdb/utilities.py index 280a218..39e4929 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/utilities.py +++ b/src/dve/core_engine/backends/implementations/duckdb/utilities.py @@ -1,11 +1,9 @@ """Utility objects for use with duckdb backend""" -from typing import List - from dve.core_engine.backends.base.utilities import _split_multiexpr_string -def parse_multiple_expressions(expressions) -> List[str]: +def parse_multiple_expressions(expressions) -> list[str]: """Break multiple expressions into a list of expressions""" if isinstance(expressions, dict): return expr_mapping_to_columns(expressions) @@ -16,7 +14,7 @@ def parse_multiple_expressions(expressions) -> List[str]: return [] -def expr_mapping_to_columns(expressions: dict) -> List[str]: +def expr_mapping_to_columns(expressions: dict) -> list[str]: """Map duckdb expressions to column names""" columns = [] for expression, alias in expressions.items(): @@ -24,12 +22,12 @@ def expr_mapping_to_columns(expressions: dict) -> List[str]: return columns -def expr_array_to_columns(expressions: List[str]) -> List[str]: +def expr_array_to_columns(expressions: list[str]) -> list[str]: """Create list of duckdb expressions from list of expressions""" return [f"{expression}" for expression in expressions] -def multiexpr_string_to_columns(expressions: str) -> List[str]: +def multiexpr_string_to_columns(expressions: str) -> list[str]: """Split string containing multiple expressions to list of duck db column expressions """ diff --git a/src/dve/core_engine/backends/implementations/spark/auditing.py b/src/dve/core_engine/backends/implementations/spark/auditing.py index f0eaf94..c050a17 100644 --- a/src/dve/core_engine/backends/implementations/spark/auditing.py +++ b/src/dve/core_engine/backends/implementations/spark/auditing.py @@ -1,7 +1,9 @@ """Auditing definitions for spark backend""" + import operator +from collections.abc import Iterable from functools import reduce -from typing import Any, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Optional, Union from pyspark.sql import Column, DataFrame, DataFrameWriter, SparkSession from pyspark.sql.functions import col, lit, row_number @@ -28,7 +30,7 @@ SparkTableFormat = Literal["delta", "parquet"] -AUDIT_PARTITION_COLS: Dict[str, List[str]] = { +AUDIT_PARTITION_COLS: dict[str, list[str]] = { "submission_info": ["date_updated"], "transfers": ["date_updated"], "processing_status": ["date_updated"], @@ -41,7 +43,7 @@ class SparkAuditor(BaseAuditor[DataFrame]): def __init__( self, - record_type: Type[AuditRecord], + record_type: type[AuditRecord], database: str, name: str, table_format: Optional[SparkTableFormat] = "delta", @@ -49,7 +51,7 @@ def __init__( ): self._db = database self._table_format = table_format - self._partition_cols: List[str] = AUDIT_PARTITION_COLS.get(name, []) + self._partition_cols: list[str] = 
AUDIT_PARTITION_COLS.get(name, []) self._spark = spark if spark else SparkSession.builder.getOrCreate() super().__init__(name=name, record_type=record_type) if not table_exists(self._spark, f"{self._db}.{self._name}"): @@ -80,7 +82,7 @@ def get_df(self) -> DataFrame: self._spark.catalog.refreshTable(self.fq_name) return self._spark.table(self.fq_name) - def combine_filters(self, filter_criteria: List[FilterCriteria]) -> Column: + def combine_filters(self, filter_criteria: list[FilterCriteria]) -> Column: """Combine multiple filters to apply""" return reduce(lambda x, y: x & y, [self.normalise_filter(filt) for filt in filter_criteria]) @@ -111,13 +113,13 @@ def conv_to_records(self, recs: DataFrame) -> Iterable[AuditRecord]: """Convert the dataframe to an iterable of the related audit record""" return (self._record_type(**rec.asDict()) for rec in recs.toLocalIterator()) - def conv_to_entity(self, recs: List[AuditRecord]) -> DataFrame: + def conv_to_entity(self, recs: list[AuditRecord]) -> DataFrame: """Convert the dataframe to an iterable of the related audit record""" return self._spark.createDataFrame( # type: ignore [rec.dict() for rec in recs], schema=self.spark_schema ) - def add_records(self, records: Iterable[Dict[str, Any]]): + def add_records(self, records: Iterable[dict[str, Any]]): _df_writer: DataFrameWriter = ( self._spark.createDataFrame(records, schema=self.spark_schema) # type: ignore .coalesce(1) @@ -130,7 +132,7 @@ def add_records(self, records: Iterable[Dict[str, Any]]): def retrieve_records( self, - filter_criteria: Optional[List[FilterCriteria]] = None, + filter_criteria: Optional[list[FilterCriteria]] = None, data: Optional[DataFrame] = None, ) -> DataFrame: df = self.get_df() if not data else data @@ -140,9 +142,9 @@ def retrieve_records( def get_most_recent_records( self, - order_criteria: List[OrderCriteria], - partition_fields: Optional[List[str]] = None, - pre_filter_criteria: Optional[List[FilterCriteria]] = None, + order_criteria: list[OrderCriteria], + partition_fields: Optional[list[str]] = None, + pre_filter_criteria: Optional[list[FilterCriteria]] = None, ) -> DataFrame: ordering = [self.normalise_order(fld) for fld in order_criteria] df = self.get_df() @@ -223,6 +225,6 @@ def combine_auditor_information( ) @staticmethod - def conv_to_iterable(recs: Union[SparkAuditor, DataFrame]) -> Iterable[Dict[str, Any]]: + def conv_to_iterable(recs: Union[SparkAuditor, DataFrame]) -> Iterable[dict[str, Any]]: recs_df: DataFrame = recs.get_df() if isinstance(recs, SparkAuditor) else recs return iter([rw.asDict() for rw in recs_df.toLocalIterator()]) diff --git a/src/dve/core_engine/backends/implementations/spark/backend.py b/src/dve/core_engine/backends/implementations/spark/backend.py index b19757d..742e9e3 100644 --- a/src/dve/core_engine/backends/implementations/spark/backend.py +++ b/src/dve/core_engine/backends/implementations/spark/backend.py @@ -1,7 +1,7 @@ """A 'complete' implementation of the generic backend in Spark.""" import logging -from typing import Any, Optional, Type +from typing import Any, Optional from pyspark.sql import DataFrame, SparkSession @@ -26,7 +26,7 @@ def __init__( dataset_config_uri: Optional[URI] = None, contract: Optional[SparkDataContract] = None, steps: Optional[SparkStepImplementations] = None, - reference_data_loader: Optional[Type[SparkRefDataLoader]] = None, + reference_data_loader: Optional[type[SparkRefDataLoader]] = None, logger: Optional[logging.Logger] = None, spark_session: Optional[SparkSession] = None, **kwargs: Any, diff 
--git a/src/dve/core_engine/backends/implementations/spark/contract.py b/src/dve/core_engine/backends/implementations/spark/contract.py index 6b6307f..bbd2d5a 100644 --- a/src/dve/core_engine/backends/implementations/spark/contract.py +++ b/src/dve/core_engine/backends/implementations/spark/contract.py @@ -1,7 +1,8 @@ """An implementation of the data contract in Apache spark.""" import logging -from typing import Any, Dict, Iterator, Optional, Set, Tuple, Type +from collections.abc import Iterator +from typing import Any, Optional from uuid import uuid4 from pydantic import BaseModel @@ -30,7 +31,7 @@ from dve.core_engine.constants import ROWID_COLUMN_NAME from dve.core_engine.type_hints import URI, EntityName, Messages -COMPLEX_TYPES: Set[Type[DataType]] = {StructType, ArrayType, MapType} +COMPLEX_TYPES: set[type[DataType]] = {StructType, ArrayType, MapType} """Spark types indicating complex types.""" @@ -70,7 +71,7 @@ def _cache_records(self, dataframe: DataFrame, cache_prefix: URI) -> URI: return chunk_uri def create_entity_from_py_iterator( - self, entity_name: EntityName, records: Iterator[Dict[str, Any]], schema: Type[BaseModel] + self, entity_name: EntityName, records: Iterator[dict[str, Any]], schema: type[BaseModel] ) -> DataFrame: return self.spark_session.createDataFrame( # type: ignore records, @@ -79,7 +80,7 @@ def create_entity_from_py_iterator( def apply_data_contract( self, entities: SparkEntities, contract_metadata: DataContractMetadata - ) -> Tuple[SparkEntities, Messages, StageSuccessful]: + ) -> tuple[SparkEntities, Messages, StageSuccessful]: self.logger.info("Applying data contracts") all_messages: Messages = [] @@ -159,10 +160,10 @@ def read_csv_file( reader: CSVFileReader, resource: URI, entity_name: EntityName, # pylint: disable=unused-argument - schema: Type[BaseModel], + schema: type[BaseModel], ) -> DataFrame: """Read a CSV file using Apache Spark.""" - reader_args: Dict[str, Any] = { + reader_args: dict[str, Any] = { "inferSchema": False, "header": reader.header, "multiLine": True, diff --git a/src/dve/core_engine/backends/implementations/spark/readers/csv.py b/src/dve/core_engine/backends/implementations/spark/readers/csv.py index 5e21714..a95cad2 100644 --- a/src/dve/core_engine/backends/implementations/spark/readers/csv.py +++ b/src/dve/core_engine/backends/implementations/spark/readers/csv.py @@ -1,7 +1,7 @@ """A reader implementation using the Databricks Spark CSV reader.""" - -from typing import Any, Dict, Iterator, Type +from collections.abc import Iterator +from typing import Any, Optional from pydantic import BaseModel from pyspark.sql import DataFrame, SparkSession @@ -30,7 +30,7 @@ def __init__( header: bool = True, multi_line: bool = False, encoding: str = "utf-8-sig", - spark_session: SparkSession = None, + spark_session: Optional[SparkSession] = None, ) -> None: self.delimiter = delimiter @@ -39,13 +39,13 @@ def __init__( self.quote_char = quote_char self.header = header self.multi_line = multi_line - self.spark_session = spark_session if spark_session else SparkSession.builder.getOrCreate() + self.spark_session = spark_session if spark_session else SparkSession.builder.getOrCreate() # type: ignore # pylint: disable=C0301 super().__init__() def read_to_py_iterator( - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] - ) -> Iterator[Dict[URI, Any]]: + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] + ) -> Iterator[dict[URI, Any]]: df = self.read_to_dataframe(resource, entity_name, schema) yield from 
(record.asDict(True) for record in df.toLocalIterator()) @@ -54,7 +54,7 @@ def read_to_dataframe( self, resource: URI, entity_name: EntityName, # pylint: disable=unused-argument - schema: Type[BaseModel], + schema: type[BaseModel], ) -> DataFrame: """Read a CSV file directly to a Spark DataFrame.""" if get_content_length(resource) == 0: diff --git a/src/dve/core_engine/backends/implementations/spark/readers/json.py b/src/dve/core_engine/backends/implementations/spark/readers/json.py index 6c1902b..56f394e 100644 --- a/src/dve/core_engine/backends/implementations/spark/readers/json.py +++ b/src/dve/core_engine/backends/implementations/spark/readers/json.py @@ -1,7 +1,7 @@ """A reader implementation using the Databricks Spark JSON reader.""" - -from typing import Any, Dict, Iterator, Optional, Type +from collections.abc import Iterator +from typing import Any, Optional from pydantic import BaseModel from pyspark.sql import DataFrame, SparkSession @@ -31,13 +31,13 @@ def __init__( self.encoding = encoding self.multi_line = multi_line - self.spark_session = spark_session if spark_session else SparkSession.builder.getOrCreate() + self.spark_session = spark_session if spark_session else SparkSession.builder.getOrCreate() # type: ignore # pylint: disable=C0301 super().__init__() def read_to_py_iterator( - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] - ) -> Iterator[Dict[URI, Any]]: + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] + ) -> Iterator[dict[URI, Any]]: df = self.read_to_dataframe(resource, entity_name, schema) yield from (record.asDict(True) for record in df.toLocalIterator()) @@ -46,7 +46,7 @@ def read_to_dataframe( self, resource: URI, entity_name: EntityName, # pylint: disable=unused-argument - schema: Type[BaseModel], + schema: type[BaseModel], ) -> DataFrame: """Read a JSON file directly to a Spark DataFrame.""" if get_content_length(resource) == 0: diff --git a/src/dve/core_engine/backends/implementations/spark/readers/xml.py b/src/dve/core_engine/backends/implementations/spark/readers/xml.py index a2ae2c5..8d29f24 100644 --- a/src/dve/core_engine/backends/implementations/spark/readers/xml.py +++ b/src/dve/core_engine/backends/implementations/spark/readers/xml.py @@ -1,7 +1,8 @@ """A reader implementation using the Databricks Spark XML reader.""" import re -from typing import Any, Collection, Dict, Iterable, Iterator, Optional, Type +from collections.abc import Collection, Iterable, Iterator +from typing import Any, Optional from pydantic import BaseModel from pyspark.sql import DataFrame, SparkSession @@ -38,7 +39,7 @@ def read_to_dataframe( self, resource: URI, entity_name: EntityName, - schema: Type[BaseModel], + schema: type[BaseModel], ) -> DataFrame: """Stream an XML file into a Spark data frame""" if not self.spark: @@ -86,8 +87,8 @@ def __init__( super().__init__() def read_to_py_iterator( - self, resource: URI, entity_name: EntityName, schema: Type[BaseModel] - ) -> Iterator[Dict[URI, Any]]: + self, resource: URI, entity_name: EntityName, schema: type[BaseModel] + ) -> Iterator[dict[URI, Any]]: df = self.read_to_dataframe(resource, entity_name, schema) yield from (record.asDict(True) for record in df.toLocalIterator()) @@ -96,7 +97,7 @@ def read_to_dataframe( self, resource: URI, entity_name: EntityName, # pylint: disable=unused-argument - schema: Type[BaseModel], + schema: type[BaseModel], ) -> DataFrame: """Read an XML file directly to a Spark DataFrame using the Databricks XML reader package. 
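The Spark reader hunks above all make the same two fixes: `spark_session: SparkSession = None` becomes an explicit `Optional[SparkSession] = None` (current mypy no longer accepts implicit Optional defaults), and the `SparkSession.builder.getOrCreate()` fallback gains a `# type: ignore` for the builder's stubs. A minimal sketch of the corrected constructor pattern, with an illustrative class name:

from typing import Optional

from pyspark.sql import SparkSession


class ExampleSparkReader:
    """Illustrative reader skeleton; only the session-handling pattern matters here."""

    def __init__(self, spark_session: Optional[SparkSession] = None) -> None:
        # PEP 484 dropped implicit Optional, so the None default is spelled out
        # and the session is created lazily when the caller supplies none.
        self.spark_session = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )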
diff --git a/src/dve/core_engine/backends/implementations/spark/reference_data.py b/src/dve/core_engine/backends/implementations/spark/reference_data.py index a5b949d..de323d7 100644 --- a/src/dve/core_engine/backends/implementations/spark/reference_data.py +++ b/src/dve/core_engine/backends/implementations/spark/reference_data.py @@ -1,7 +1,7 @@ # pylint: disable=no-member """A reference data loader for Spark.""" -from typing import Dict, Optional +from typing import Optional from pyspark.sql import DataFrame, SparkSession @@ -28,7 +28,7 @@ class SparkRefDataLoader(BaseRefDataLoader[DataFrame]): def __init__( self, - reference_entity_config: Dict[EntityName, ReferenceConfig], + reference_entity_config: dict[EntityName, ReferenceConfig], **kwargs, ) -> None: super().__init__(reference_entity_config, **kwargs) diff --git a/src/dve/core_engine/backends/implementations/spark/rules.py b/src/dve/core_engine/backends/implementations/spark/rules.py index 9075db2..93baae9 100644 --- a/src/dve/core_engine/backends/implementations/spark/rules.py +++ b/src/dve/core_engine/backends/implementations/spark/rules.py @@ -1,6 +1,7 @@ """Step implementations in Spark.""" -from typing import Callable, Dict, List, Optional, Set, Tuple +from collections.abc import Callable +from typing import Optional from uuid import uuid4 from pyspark.sql import DataFrame, SparkSession @@ -54,13 +55,13 @@ class SparkStepImplementations(BaseStepImplementations[DataFrame]): """An implementation of transformation steps in Apache Spark.""" - def __init__(self, spark_session: SparkSession = None, **kwargs): + def __init__(self, spark_session: Optional[SparkSession] = None, **kwargs): self._spark_session = spark_session - self._registered_functions: List[str] = [] + self._registered_functions: list[str] = [] super().__init__(**kwargs) @property - def spark_session(self): + def spark_session(self) -> SparkSession: """The current spark session""" if not self._spark_session: self._spark_session = SparkSession.builder.getOrCreate() @@ -80,14 +81,14 @@ def register_udfs( """Register all function implementations as Spark UDFs.""" spark_session = spark_session or SparkSession.builder.getOrCreate() - _registered_functions: Set[str] = get_all_registered_udfs(spark_session) - _available_functions: Dict[str, Callable] = { + _registered_functions: set[str] = get_all_registered_udfs(spark_session) + _available_functions: dict[str, Callable] = { func_name: func for func_name, func in vars(functions).items() if callable(func) and func.__module__ == "dve.core_engine.functions.implementations" } - _unregistered_functions: Set[str] = set(_available_functions).difference( + _unregistered_functions: set[str] = set(_available_functions).difference( _registered_functions ) @@ -151,7 +152,7 @@ def group_by(self, entities: SparkEntities, *, config: Aggregation) -> Messages: def _perform_join( self, entities: SparkEntities, config: AbstractConditionalJoin - ) -> Tuple[Source, Target, Joined]: + ) -> tuple[Source, Target, Joined]: """Perform a conditional join between source and target, returning the source, target and joined DataFrames. 
diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index 12f32d3..921b04e 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -8,20 +8,14 @@ import datetime as dt import logging import time +from collections.abc import Callable, Generator, Iterator from dataclasses import dataclass, is_dataclass from decimal import Decimal from functools import wraps from typing import ( Any, - Callable, ClassVar, - Dict, - Generator, - Iterator, - List, Optional, - Set, - Type, TypeVar, Union, overload, @@ -96,7 +90,7 @@ def __post_init__(self): """The defualt decimal precision/scale config.""" -PYTHON_TYPE_TO_SPARK_TYPE: Dict[type, st.DataType] = { +PYTHON_TYPE_TO_SPARK_TYPE: dict[type, st.DataType] = { str: st.StringType(), int: st.LongType(), bool: st.BooleanType(), @@ -111,20 +105,20 @@ def __post_init__(self): PydanticModel = TypeVar("PydanticModel", bound=BaseModel) """An Pydantic model.""" -TypedDictSubclass = TypeVar("TypedDictSubclass", bound=Type[TypedDict]) # type: ignore +TypedDictSubclass = TypeVar("TypedDictSubclass", bound=type[TypedDict]) # type: ignore """A TypedDict subclass.""" class GenericDataclass(Protocol): # pylint: disable=too-few-public-methods """A dataclass-like class.""" - __dataclass_fields__: Dict[str, Any] # `is_dataclass` checks for this field. + __dataclass_fields__: dict[str, Any] # `is_dataclass` checks for this field. @overload def get_type_from_annotation( type_annotation: Union[ - Type[PydanticModel], Type[GenericDataclass], GenericDataclass, TypedDictSubclass + type[PydanticModel], type[GenericDataclass], GenericDataclass, TypedDictSubclass ], ) -> st.StructType: pass # pragma: no cover @@ -150,10 +144,10 @@ def get_type_from_annotation(type_annotation: Any) -> st.DataType: * `datetime.datetime`: a Spark `TimestampType` * `decimal.Decimal`: a Spark `DecimalType` with precision of 38 and scale of 18 - - A list of supported types (e.g. `List[str]` or `typing.List[str]`). + - A list of supported types (e.g. `list[str]` or `typing.list[str]`). This will return a Spark `ArrayType` with the specified element type. - A `typing.Optional` type or a `typing.Union` of the type and `None` (e.g. - `typing.Optional[str]`, `typing.Union[List[str], None]`). This will remove the + `typing.Optional[str]`, `typing.Union[list[str], None]`). This will remove the 'optional' wrapper and return the inner type (Spark types are all nullable). - A subclass of `typing.TypedDict` with values typed using supported types. This will parse the value types as Spark types and return a Spark `StructType`. @@ -176,7 +170,7 @@ def get_type_from_annotation(type_annotation: Any) -> st.DataType: python_type = _get_non_heterogenous_type(get_args(type_annotation)) return get_type_from_annotation(python_type) - # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. + # type hint is e.g. `list[str]`, check to ensure non-heterogenity. if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): element_type = _get_non_heterogenous_type(get_args(type_annotation)) return st.ArrayType(get_type_from_annotation(element_type)) @@ -200,11 +194,11 @@ def get_type_from_annotation(type_annotation: Any) -> st.DataType: raise ValueError(f"Unsupported type annotation {type_annotation!r}") if ( - # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`. 
+ # type hint is a dict subclass, but not dict. Possibly a `TypedDict`. (issubclass(type_annotation, dict) and type_annotation is not dict) - # Type hint is a dataclass. + # type hint is a dataclass. or is_dataclass(type_annotation) - # Type hint is a `pydantic` model. + # type hint is a `pydantic` model. or (type_origin is None and issubclass(type_annotation, BaseModel)) ): fields = [] @@ -234,10 +228,10 @@ def get_type_from_annotation(type_annotation: Any) -> st.DataType: if type_annotation is list: raise ValueError( - f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" + f"list must have type annotation (e.g. `list[str]`), got {type_annotation!r}" ) if type_annotation is dict or type_origin is dict: - raise ValueError(f"Dict must be `typing.TypedDict` subclass, got {type_annotation!r}") + raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") for type_ in type_annotation.mro(): spark_type = PYTHON_TYPE_TO_SPARK_TYPE.get(type_) @@ -284,7 +278,7 @@ def create_udf(function: Callable) -> Callable: SupportedBaseType = Union[str, int, bool, float, Decimal, dt.date, dt.datetime] """Supported base types for Spark literals.""" SparkLiteralType = Union[ # type: ignore - SupportedBaseType, Dict[str, "SparkLiteralType"], List["SparkLiteralType"] # type: ignore + SupportedBaseType, dict[str, "SparkLiteralType"], list["SparkLiteralType"] # type: ignore ] """Recursive definition of supported literal types.""" @@ -313,7 +307,7 @@ def object_to_spark_literal(obj: SparkLiteralType) -> Column: array_elements = [] arr_element_type: Optional[type] = None - arr_element_keys: Optional[Set[str]] = None + arr_element_keys: Optional[set[str]] = None for element in obj: element_type = type(element) @@ -347,12 +341,12 @@ def _spark_read_parquet(self, path: URI, **kwargs) -> DataFrame: def _spark_write_parquet( # pylint: disable=unused-argument - self, entity: Union[Iterator[Dict[str, Any]], DataFrame], target_location: URI, **kwargs + self, entity: Union[Iterator[dict[str, Any]], DataFrame], target_location: URI, **kwargs ) -> URI: """Method to write parquet files from type cast entities following data contract application """ - _options: Dict[str, Any] = {**kwargs} + _options: dict[str, Any] = {**kwargs} if isinstance(entity, Generator): _writer = self.spark_session.createDataFrame(entity).write else: @@ -388,7 +382,7 @@ def spark_get_entity_count(cls): return cls -def get_all_registered_udfs(spark: SparkSession) -> Set[str]: +def get_all_registered_udfs(spark: SparkSession) -> set[str]: """Function to supply the names of a registered functions stored in the supplied spark session. 
""" diff --git a/src/dve/core_engine/backends/implementations/spark/types.py b/src/dve/core_engine/backends/implementations/spark/types.py index 9994a6c..96398da 100644 --- a/src/dve/core_engine/backends/implementations/spark/types.py +++ b/src/dve/core_engine/backends/implementations/spark/types.py @@ -1,6 +1,6 @@ """Types used in Spark implementations.""" -from typing import MutableMapping +from collections.abc import MutableMapping from pyspark.sql import DataFrame diff --git a/src/dve/core_engine/backends/implementations/spark/utilities.py b/src/dve/core_engine/backends/implementations/spark/utilities.py index e876524..bd5b02a 100644 --- a/src/dve/core_engine/backends/implementations/spark/utilities.py +++ b/src/dve/core_engine/backends/implementations/spark/utilities.py @@ -1,9 +1,10 @@ """Some utilities which are useful for implementing Spark transformations.""" import datetime as dt +from collections.abc import Callable from json import JSONEncoder from operator import and_, or_ -from typing import Any, Callable, List +from typing import Any from pydantic import BaseModel from pyspark.sql import SparkSession @@ -50,7 +51,7 @@ def any_columns(*columns: Column) -> Column: return _apply_operation_to_column_sequence(*columns, operation=or_) -def expr_mapping_to_columns(expressions: ExpressionMapping) -> List[Column]: +def expr_mapping_to_columns(expressions: ExpressionMapping) -> list[Column]: """Convert a mapping of expression to alias to a list of columns. Where the expression requires a tuple of column names, the alias should be a list of column names. @@ -67,12 +68,12 @@ def expr_mapping_to_columns(expressions: ExpressionMapping) -> List[Column]: return columns -def expr_array_to_columns(expressions: ExpressionArray) -> List[Column]: +def expr_array_to_columns(expressions: ExpressionArray) -> list[Column]: """Convert an array of expressions to a list of columns.""" return list(map(sf.expr, expressions)) -def multiexpr_string_to_columns(expressions: MultiExpression) -> List[Column]: +def multiexpr_string_to_columns(expressions: MultiExpression) -> list[Column]: """Convert multiple SQL expressions in a comma-delimited string to a list of columns. @@ -81,7 +82,7 @@ def multiexpr_string_to_columns(expressions: MultiExpression) -> List[Column]: return expr_array_to_columns(expression_list) -def parse_multiple_expressions(expressions: MultipleExpressions) -> List[Column]: +def parse_multiple_expressions(expressions: MultipleExpressions) -> list[Column]: """Parse multiple expressions provided as a mapping or alias to expression, an array of expressions, or a string containing multiple comma-delimited SQL expressions. 
diff --git a/src/dve/core_engine/backends/metadata/contract.py b/src/dve/core_engine/backends/metadata/contract.py index 3eada84..12beb41 100644 --- a/src/dve/core_engine/backends/metadata/contract.py +++ b/src/dve/core_engine/backends/metadata/contract.py @@ -1,6 +1,6 @@ """Metadata classes for the data contract.""" -from typing import Any, Dict, Type +from typing import Any from pydantic import BaseModel, PrivateAttr, root_validator @@ -14,14 +14,14 @@ class ReaderConfig(BaseModel): reader: str """The name of the reader to be used.""" - parameters: Dict[str, Any] + parameters: dict[str, Any] """The parameters the reader should use.""" class DataContractMetadata(BaseModel, frozen=True, arbitrary_types_allowed=True): """Metadata for the data contract.""" - reader_metadata: Dict[EntityName, Dict[Extension, ReaderConfig]] + reader_metadata: dict[EntityName, dict[Extension, ReaderConfig]] """ The per-entity reader metadata. @@ -30,17 +30,17 @@ class DataContractMetadata(BaseModel, frozen=True, arbitrary_types_allowed=True) the requested reader. """ - validators: Dict[EntityName, RowValidator] + validators: dict[EntityName, RowValidator] """The per-entity record validators.""" - reporting_fields: Dict[EntityName, ReportingFields] + reporting_fields: dict[EntityName, ReportingFields] """The per-entity reporting fields.""" cache_originals: bool = False """Whether to cache the original entities after loading.""" - _schemas: Dict[EntityName, Type[BaseModel]] = PrivateAttr(default_factory=dict) + _schemas: dict[EntityName, type[BaseModel]] = PrivateAttr(default_factory=dict) """The pydantic models of the schmas.""" @property - def schemas(self) -> Dict[EntityName, Type[BaseModel]]: + def schemas(self) -> dict[EntityName, type[BaseModel]]: """The per-entity schemas, as pydantic models.""" if not self._schemas: for entity_name, validator in self.validators.items(): @@ -49,7 +49,7 @@ def schemas(self) -> Dict[EntityName, Type[BaseModel]]: @root_validator(allow_reuse=True) @classmethod - def _ensure_entities_complete(cls, values: Dict[str, Dict[EntityName, Any]]): + def _ensure_entities_complete(cls, values: dict[str, dict[EntityName, Any]]): """Ensure the entities in 'readers' and 'validators' are the same.""" try: reader_entities = set(values["reader_metadata"].keys()) diff --git a/src/dve/core_engine/backends/metadata/reporting.py b/src/dve/core_engine/backends/metadata/reporting.py index 83555ac..0f2079a 100644 --- a/src/dve/core_engine/backends/metadata/reporting.py +++ b/src/dve/core_engine/backends/metadata/reporting.py @@ -2,7 +2,8 @@ import json import warnings -from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union +from collections.abc import Callable +from typing import Any, ClassVar, Optional, Union from pydantic import BaseModel, root_validator, validate_arguments from typing_extensions import Literal @@ -27,7 +28,7 @@ class BaseReportingConfig(BaseModel): """ - UNTEMPLATED_FIELDS: ClassVar[Set[str]] = {"message"} + UNTEMPLATED_FIELDS: ClassVar[set[str]] = {"message"} """Fields that should not be templated.""" emit: Optional[str] = None @@ -107,7 +108,7 @@ def template( self, local_variables: TemplateVariables, *, - global_variables: TemplateVariables = None, + global_variables: Optional[TemplateVariables] = None, ) -> "BaseReportingConfig": """Template a reporting config.""" type_ = type(self) @@ -127,7 +128,7 @@ class ReportingConfig(BaseReportingConfig): emit: ErrorEmitValue = "record_failure" category: ErrorCategory = "Bad value" - def 
_get_root_and_fields(self) -> Tuple[Optional[str], Union[Literal["*"], List[str]]]: + def _get_root_and_fields(self) -> tuple[Optional[str], Union[Literal["*"], list[str]]]: """Get the source field (or None, if the source is the root of the record) and a list of fields (or `'*'`) if all fields are to be selected from the location. @@ -140,7 +141,7 @@ def _get_root_and_fields(self) -> Tuple[Optional[str], Union[Literal["*"], List[ if len(nesting_splits) > 2: raise ValueError("Nesting must be a maximum of one level") - fields: Union[Literal["*"], List[str]] + fields: Union[Literal["*"], list[str]] fields = [field.strip() for field in nesting_splits[-1].strip("{}").split(",")] if fields and fields[0] == "*": fields = "*" @@ -165,7 +166,7 @@ def legacy_location(self) -> Optional[str]: @property # pylint: disable=too-many-return-statements - def legacy_reporting_field(self) -> Union[str, List[str], None]: + def legacy_reporting_field(self) -> Union[str, list[str], None]: """DEPRECATED: The legacy reporting field, extracted from `location`.""" warnings.warn("Use new combined `location` field", DeprecationWarning) if self.location is None: @@ -207,7 +208,7 @@ def legacy_error_type(self) -> Literal["record", "submission", "integrity"]: def get_location_selector( self, - ) -> Callable[[Dict[str, Any]], Union[List[Dict[str, Any]], Dict[str, Any], None]]: + ) -> Callable[[dict[str, Any]], Union[list[dict[str, Any]], dict[str, Any], None]]: """Get a function which extracts the location from a provided record.""" # TODO: Check this against the schema to eliminate type checks at runtime. # This should enable us to use some really efficient 'getter' functions. @@ -223,7 +224,7 @@ def get_location_selector( if fields == "*": return lambda record: record[root_field] # type: ignore - def _selector(record: Dict[str, Any]) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + def _selector(record: dict[str, Any]) -> Union[list[dict[str, Any]], dict[str, Any], None]: map_or_list = record[root_field] # type: ignore if not isinstance(map_or_list, (dict, list)): return None @@ -236,8 +237,8 @@ def _selector(record: Dict[str, Any]) -> Union[List[Dict[str, Any]], Dict[str, A return _selector def get_location_value( - self, record: Optional[Dict[str, Any]] - ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + self, record: Optional[dict[str, Any]] + ) -> Union[list[dict[str, Any]], dict[str, Any], None]: """Get the value of the location field from a record.""" if record is None: return None @@ -254,7 +255,7 @@ class LegacyReportingConfig(BaseReportingConfig): legacy_location: Optional[str] = None """DEPRECATED: The legacy error location, now a component of `location`.""" - legacy_reporting_field: Optional[Union[str, List[str]]] = None + legacy_reporting_field: Optional[Union[str, list[str]]] = None """DEPRECATED: The legacy reporting field, now a component of `location`.""" legacy_error_type: Optional[str] = None """DEPRECATED: The legacy error type.""" @@ -263,7 +264,7 @@ class LegacyReportingConfig(BaseReportingConfig): @root_validator(allow_reuse=True, skip_on_failure=True) @classmethod - def _ensure_only_one_reporting_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def _ensure_only_one_reporting_config(cls, values: dict[str, Any]) -> dict[str, Any]: """Ensure only the modern or legacy location is populated.""" has_modern = bool(values.get("location")) has_legacy = bool(values.get("legacy_location") or values.get("legacy_reporting_field")) @@ -277,7 +278,7 @@ def 
_ensure_only_one_reporting_config(cls, values: Dict[str, Any]) -> Dict[str, @root_validator(allow_reuse=True, skip_on_failure=True) @classmethod - def _ensure_only_one_error_type_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def _ensure_only_one_error_type_config(cls, values: dict[str, Any]) -> dict[str, Any]: """Ensure only the modern or legacy error type is populated.""" has_modern = bool(values.get("emit")) has_legacy = bool( @@ -313,7 +314,7 @@ def _convert_legacy_emit_value( @staticmethod @validate_arguments def _convert_legacy_reporting_fields( - error_location: Optional[str] = None, reporting_field: Union[str, List[str], None] = None + error_location: Optional[str] = None, reporting_field: Union[str, list[str], None] = None ) -> Optional[str]: """Convert legacy reporting field specification to a new location string.""" if error_location is None and reporting_field is None: @@ -345,7 +346,7 @@ def template( self, local_variables: TemplateVariables, *, - global_variables: TemplateVariables = None, + global_variables: Optional[TemplateVariables] = None, ) -> "ReportingConfig": """Template the untemplated reporting config.""" if global_variables: diff --git a/src/dve/core_engine/backends/metadata/rules.py b/src/dve/core_engine/backends/metadata/rules.py index 30afcee..7bc0353 100644 --- a/src/dve/core_engine/backends/metadata/rules.py +++ b/src/dve/core_engine/backends/metadata/rules.py @@ -2,19 +2,8 @@ import warnings from abc import ABCMeta, abstractmethod -from typing import ( - Any, - ClassVar, - Dict, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - TypeVar, - Union, -) +from collections.abc import Iterator, Sequence +from typing import Any, ClassVar, Optional, TypeVar, Union from pydantic import BaseModel, Extra, Field, root_validator, validate_arguments, validator from typing_extensions import Literal @@ -77,7 +66,7 @@ class ParentMetadata(BaseModel): """The name of the stage in the rule the step belongs to.""" def __repr__(self) -> str: - components: List[str] = [] + components: list[str] = [] if isinstance(self.rule, Rule): components.append(f"rule=Rule(name={self.rule.name!r}, ...)") else: @@ -106,7 +95,7 @@ class AbstractStep(BaseModel, metaclass=ABCMeta): parent: Optional[ParentMetadata] = None """Data about the parent rule and the step's place within it.""" - UNTEMPLATED_KEYS: ClassVar[Set[str]] = {"id", "description", "parent"} + UNTEMPLATED_KEYS: ClassVar[set[str]] = {"id", "description", "parent"} """A set of aliases which are exempted from templating.""" class Config: # pylint: disable=too-few-public-methods @@ -115,21 +104,21 @@ class Config: # pylint: disable=too-few-public-methods frozen = True extra = Extra.forbid - def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: + def __repr_args__(self) -> Sequence[tuple[Optional[str], Any]]: # Exclude nulls from 'repr' for conciseness. 
return [(key, value) for key, value in super().__repr_args__() if value is not None] @abstractmethod - def get_required_entities(self) -> Set[EntityName]: + def get_required_entities(self) -> set[EntityName]: """Get a set of the required entity names for the transformation.""" raise NotImplementedError() # pragma: no cover @abstractmethod - def get_created_entities(self) -> Set[EntityName]: + def get_created_entities(self) -> set[EntityName]: """Get a set of the entity names created by the transformation.""" raise NotImplementedError() # pragma: no cover - def get_removed_entities(self) -> Set[EntityName]: + def get_removed_entities(self) -> set[EntityName]: """Get a set of the entity names removed by the transformation.""" return set() @@ -141,7 +130,7 @@ def template( self: ASSelf, local_variables: TemplateVariables, *, - global_variables: TemplateVariables = None, + global_variables: Optional[TemplateVariables] = None, ) -> ASSelf: """Template the rule, given the global and local variables.""" type_ = type(self) @@ -162,7 +151,7 @@ def __str__(self): # pydantic's default __str__ strips the model name. @root_validator(pre=True) @classmethod - def _warn_for_deprecated_aliases(cls, values: Dict[str, JSONable]) -> Dict[str, JSONable]: + def _warn_for_deprecated_aliases(cls, values: dict[str, JSONable]) -> dict[str, JSONable]: for deprecated_name, replacement in ( ("entity", "entity_name"), ("target", "target_entity_name"), @@ -191,11 +180,11 @@ class BaseStep(AbstractStep, metaclass=ABCMeta): new_entity_name: Optional[EntityName] = None """Optionally, a new entity to create after the operation.""" - def get_required_entities(self) -> Set[EntityName]: + def get_required_entities(self) -> set[EntityName]: """Get a set of the required entity names for the transformation.""" return {self.entity_name} - def get_created_entities(self) -> Set[EntityName]: + def get_created_entities(self) -> set[EntityName]: """Get a set of the entity names created by the transformation.""" return {self.new_entity_name or self.entity_name} @@ -239,9 +228,9 @@ class DeferredFilter(AbstractStep): def template( self: "DeferredFilter", - local_variables: Dict[Alias, Any], + local_variables: dict[Alias, Any], *, - global_variables: Dict[Alias, Any] = None, + global_variables: Optional[dict[Alias, Any]] = None, ) -> "DeferredFilter": """Template the rule, given the global and local variables.""" type_ = type(self) @@ -260,11 +249,11 @@ def template( return type_(**templated_data, **untemplated_data) - def get_required_entities(self) -> Set[EntityName]: + def get_required_entities(self) -> set[EntityName]: """Get a set of the required entity names for the transformation.""" return {self.entity_name} - def get_created_entities(self) -> Set[EntityName]: + def get_created_entities(self) -> set[EntityName]: """Get a set of the required entity names for the transformation.""" return {self.entity_name} @@ -287,15 +276,15 @@ class Notification(AbstractStep): emit notifications according to the reporting config. 
""" - excluded_columns: List[Alias] = Field(default_factory=list) + excluded_columns: list[Alias] = Field(default_factory=list) """Columns to be excluded from the record in the report.""" reporting: ReportingConfig """The reporting information for the filter.""" - def get_required_entities(self) -> Set[EntityName]: + def get_required_entities(self) -> set[EntityName]: return {self.entity_name} - def get_created_entities(self) -> Set[EntityName]: + def get_created_entities(self) -> set[EntityName]: return set() @@ -309,10 +298,10 @@ class AbstractJoin(AbstractStep, metaclass=ABCMeta): new_entity_name: Optional[EntityName] = None """Optionally, a new entity to create after the operation.""" - def get_required_entities(self) -> Set[EntityName]: + def get_required_entities(self) -> set[EntityName]: return {self.entity_name, self.target_name} - def get_created_entities(self) -> Set[EntityName]: + def get_created_entities(self) -> set[EntityName]: return {self.new_entity_name or self.entity_name} @@ -379,7 +368,7 @@ class Aggregation(BaseStep): """Multiple expressions to group by.""" pivot_column: Optional[Alias] = None """An optional pivot column for the table.""" - pivot_values: Optional[List[Any]] = None + pivot_values: Optional[list[Any]] = None """A list of values to translate to columns when pivoting.""" agg_columns: Optional[MultipleExpressions] = None """Multiple aggregate expressions to take from the group by (for spark backend)""" @@ -391,7 +380,7 @@ class Aggregation(BaseStep): def _ensure_column_if_values( cls, value: Optional[Any], - values: Dict[str, Any], + values: dict[str, Any], ): """Ensure that `pivot_column` is not null if pivot values are provided.""" if value and not values["pivot_column"]: @@ -403,7 +392,7 @@ def _ensure_column_if_values( def _ensure_column_if_function( cls, agg_function: Optional[Any], - values: Dict[str, Any], + values: dict[str, Any], ): """Ensure that `pivot_column` is not null if pivot values are provided.""" if agg_function and not values["agg_columns"]: @@ -414,18 +403,18 @@ def _ensure_column_if_function( class EntityRemoval(AbstractStep): """A transformation which drops entities.""" - entity_name: Union[EntityName, List[EntityName]] + entity_name: Union[EntityName, list[EntityName]] """The entity to drop.""" - def get_required_entities(self) -> Set[EntityName]: + def get_required_entities(self) -> set[EntityName]: """Get a set of the required entity names for the transformation.""" return set() - def get_created_entities(self) -> Set[EntityName]: + def get_created_entities(self) -> set[EntityName]: """Get a set of the entity names created by the transformation.""" return set() - def get_removed_entities(self) -> Set[EntityName]: + def get_removed_entities(self) -> set[EntityName]: """Get a set of the entity names created by the transformation.""" if isinstance(self.entity_name, list): return set(self.entity_name) @@ -440,15 +429,15 @@ class CopyEntity(AbstractStep): new_entity_name: EntityName """The new name for the copied entity.""" - def get_required_entities(self) -> Set[EntityName]: + def get_required_entities(self) -> set[EntityName]: """Get a set of the required entity names for the transformation.""" return {self.entity_name} - def get_created_entities(self) -> Set[EntityName]: + def get_created_entities(self) -> set[EntityName]: """Get a set of the entity names created by the transformation.""" return {self.new_entity_name} - def get_removed_entities(self) -> Set[EntityName]: + def get_removed_entities(self) -> set[EntityName]: """Gets 
the entity which has been removed""" return set() @@ -456,7 +445,7 @@ def get_removed_entities(self) -> Set[EntityName]: class RenameEntity(CopyEntity): """A transformation which renames an entity.""" - def get_removed_entities(self) -> Set[EntityName]: + def get_removed_entities(self) -> set[EntityName]: """Get a set of the entity names removed by the transformation.""" return {self.entity_name} @@ -582,11 +571,11 @@ class Rule(BaseModel): name: str """The name of the rule.""" - pre_sync_steps: List[AbstractStep] + pre_sync_steps: list[AbstractStep] """The pre-sync steps in the rule.""" - sync_filter_steps: List[DeferredFilter] + sync_filter_steps: list[DeferredFilter] """The sync filter steps in the rule.""" - post_sync_steps: List[AbstractStep] + post_sync_steps: list[AbstractStep] """The post-sync steps in the rule.""" def __str__(self): # pydantic's default __str__ strips the model name. @@ -594,11 +583,11 @@ def __str__(self): # pydantic's default __str__ strips the model name. @classmethod @validate_arguments - def from_step_list(cls, name: str, steps: List[Step]): + def from_step_list(cls, name: str, steps: list[Step]): """Load the rule from a single step list.""" - pre_sync_steps: List[AbstractStep] = [] - sync_filter_steps: List[DeferredFilter] = [] - post_sync_steps: List[AbstractStep] = [] + pre_sync_steps: list[AbstractStep] = [] + sync_filter_steps: list[DeferredFilter] = [] + post_sync_steps: list[AbstractStep] = [] self = cls( name=name, pre_sync_steps=pre_sync_steps, @@ -662,13 +651,13 @@ def template( self: RSelf, local_variables: TemplateVariables, *, - global_variables: TemplateVariables = None, + global_variables: Optional[TemplateVariables] = None, ) -> RSelf: """Template the rule, returning the new templated rule. This is only really useful for 'upfront' templating, as all stages of the rule will be templated at once. """ - rule_lists: Dict[str, Union[List[AbstractStep], List[DeferredFilter]]] = { + rule_lists: dict[str, Union[list[AbstractStep], list[DeferredFilter]]] = { "pre_sync_steps": self.pre_sync_steps, "sync_filter_steps": self.sync_filter_steps, "post_sync_steps": self.post_sync_steps, @@ -687,9 +676,9 @@ def template( class RuleMetadata(BaseModel): """Metadata about the rules.""" - rules: List[Rule] + rules: list[Rule] """A list of rules to be applied to to the entities.""" - local_variables: Optional[List[TemplateVariables]] = None + local_variables: Optional[list[TemplateVariables]] = None """ An optional list of local, rule-level template variables. @@ -713,7 +702,7 @@ class RuleMetadata(BaseModel): performing work. """ - reference_data_config: Dict[EntityName, ReferenceConfigUnion] + reference_data_config: dict[EntityName, ReferenceConfigUnion] """ Per-entity configuration options for the reference data. 
@@ -724,7 +713,7 @@ class RuleMetadata(BaseModel): @root_validator() @classmethod - def _ensure_locals_same_length_as_rules(cls, values: Dict[str, List[Any]]): + def _ensure_locals_same_length_as_rules(cls, values: dict[str, list[Any]]): """Ensure that if 'local_variables' is provided, it's the same length as 'rules'.""" local_vars = values["local_variables"] if local_vars is not None: @@ -737,7 +726,7 @@ def _ensure_locals_same_length_as_rules(cls, values: Dict[str, List[Any]]): ) return values - def __iter__(self) -> Iterator[Tuple[Rule, TemplateVariables]]: # type: ignore + def __iter__(self) -> Iterator[tuple[Rule, TemplateVariables]]: # type: ignore """Iterate over the rules and local variables.""" if self.local_variables is None: yield from ((rule, {}) for rule in self.rules) # type: ignore diff --git a/src/dve/core_engine/backends/readers/__init__.py b/src/dve/core_engine/backends/readers/__init__.py index ade6be4..296264c 100644 --- a/src/dve/core_engine/backends/readers/__init__.py +++ b/src/dve/core_engine/backends/readers/__init__.py @@ -5,7 +5,6 @@ """ import warnings -from typing import Dict, List, Type from dve.core_engine.backends.base.reader import BaseFileReader from dve.core_engine.backends.readers.csv import CSVFileReader @@ -14,14 +13,14 @@ ReaderName = str """The name of a reader type.""" -CORE_READERS: List[Type[BaseFileReader]] = [CSVFileReader, BasicXMLFileReader, XMLStreamReader] +CORE_READERS: list[type[BaseFileReader]] = [CSVFileReader, BasicXMLFileReader, XMLStreamReader] """A list of core reader types which should be registered.""" -_READER_REGISTRY: Dict[ReaderName, Type[BaseFileReader]] = {} +_READER_REGISTRY: dict[ReaderName, type[BaseFileReader]] = {} """A global registry of supported reader types.""" -def register_reader(reader_class: Type[BaseFileReader]): +def register_reader(reader_class: type[BaseFileReader]): """Register a reader type, making it accessible to the engine.""" if not issubclass(reader_class, BaseFileReader): raise TypeError(f"Reader type {reader_class} is not 'BaseFileReader' subclass") @@ -33,7 +32,7 @@ def register_reader(reader_class: Type[BaseFileReader]): _READER_REGISTRY[reader_name] = reader_class -def get_reader(reader_name: ReaderName) -> Type[BaseFileReader]: +def get_reader(reader_name: ReaderName) -> type[BaseFileReader]: """Get the reader type from the registry by name.""" return _READER_REGISTRY[reader_name] diff --git a/src/dve/core_engine/backends/readers/csv.py b/src/dve/core_engine/backends/readers/csv.py index 6969a32..c0b6479 100644 --- a/src/dve/core_engine/backends/readers/csv.py +++ b/src/dve/core_engine/backends/readers/csv.py @@ -2,8 +2,9 @@ """Core Python-based CSV reader.""" import csv +from collections.abc import Collection, Iterator from functools import partial -from typing import IO, Any, Collection, Dict, Iterator, List, Optional, Type +from typing import IO, Any, Optional import polars as pl from pydantic.main import BaseModel @@ -78,8 +79,8 @@ def __init__( self.encoding = encoding """Encoding of the CSV file.""" - def _get_reader_args(self) -> Dict[str, Any]: - reader_args: Dict[str, Any] = { + def _get_reader_args(self) -> dict[str, Any]: + reader_args: dict[str, Any] = { "delimiter": self.delimiter, "escapechar": self.escape_char, "quotechar": self.quote_char, @@ -103,7 +104,7 @@ def _parse_n_fields(self, stream: IO[str]) -> int: stream.seek(cursor_position) return len(row) - def _parse_field_names(self, stream: IO[str]) -> List[str]: + def _parse_field_names(self, stream: IO[str]) -> list[str]: 
"""Peek the provided field names from the CSV, returning a list of field names as strings. @@ -129,8 +130,8 @@ def _parse_field_names(self, stream: IO[str]) -> List[str]: def _get_field_names( self, stream: IO[str], - field_names: List[str], - ) -> List[str]: + field_names: list[str], + ) -> list[str]: """Get field names to be used by the reader.""" # CSV already expected to have named fields. if self.header: @@ -164,8 +165,8 @@ def _get_field_names( return field_names def _coerce( - self, row: Dict[str, Optional[str]], field_names: List[str] - ) -> Dict[str, Optional[str]]: + self, row: dict[str, Optional[str]], field_names: list[str] + ) -> dict[str, Optional[str]]: """Coerce a parsed row into the indended shape, nulling values which are expected to be parsed as nulls. @@ -187,8 +188,8 @@ def read_to_py_iterator( self, resource: URI, entity_name: EntityName, - schema: Type[BaseModel], - ) -> Iterator[Dict[str, Any]]: + schema: type[BaseModel], + ) -> Iterator[dict[str, Any]]: """Reads the data to an iterator of dictionaries""" if get_content_length(resource) == 0: raise EmptyFileError(f"File at {resource!r} is empty") @@ -206,9 +207,9 @@ def read_to_py_iterator( def write_parquet( # type: ignore self, - entity: Iterator[Dict[str, Any]], + entity: Iterator[dict[str, Any]], target_location: URI, - schema: Optional[Type[BaseModel]] = None, + schema: Optional[type[BaseModel]] = None, **kwargs, ) -> EntityName: """Writes the data of the given entity to a parquet file""" @@ -217,7 +218,7 @@ def write_parquet( # type: ignore if isinstance(_get_implementation(target_location), LocalFilesystemImplementation): target_location = file_uri_to_local_path(target_location).as_posix() if schema: - polars_schema: Dict[str, pl.DataType] = { # type: ignore + polars_schema: dict[str, pl.DataType] = { # type: ignore fld.name: get_polars_type_from_annotation(fld.annotation) for fld in stringify_model(schema).__fields__.values() } diff --git a/src/dve/core_engine/backends/readers/xml.py b/src/dve/core_engine/backends/readers/xml.py index ca55eb9..bd7b8e4 100644 --- a/src/dve/core_engine/backends/readers/xml.py +++ b/src/dve/core_engine/backends/readers/xml.py @@ -2,7 +2,15 @@ """XML parsers for the Data Validation Engine.""" import re -from typing import IO, Any, Collection, Dict, Iterator, List, Optional, Type, Union, overload +from collections.abc import Collection, Iterator +from typing import ( + IO, + Any, + GenericAlias, # type: ignore + Optional, + Union, + overload +) import polars as pl from lxml import etree # type: ignore @@ -21,13 +29,13 @@ ) from dve.parser.file_handling.service import _get_implementation -XMLType = Union[Optional[str], List["XMLType"], Dict[str, "XMLType"]] # type: ignore +XMLType = Union[Optional[str], list["XMLType"], dict[str, "XMLType"]] # type: ignore """The definition of a type within XML.""" -XMLRecord = Dict[str, XMLType] # type: ignore +XMLRecord = dict[str, XMLType] # type: ignore """A record within XML.""" -TemplateElement = Union[None, List["TemplateElement"], Dict[str, "TemplateElement"]] # type: ignore +TemplateElement = Union[None, list["TemplateElement"], dict[str, "TemplateElement"]] # type: ignore """The base types used in the template row.""" -TemplateRow = Dict[str, "TemplateElement"] # type: ignore +TemplateRow = dict[str, "TemplateElement"] # type: ignore """The type of a template row.""" @@ -39,14 +47,14 @@ def _strip_annotated(annotation: Any) -> Any: return get_args(annotation)[0] -def create_template_row(schema: Type[BaseModel]) -> Dict[str, Any]: +def 
create_template_row(schema: type[BaseModel]) -> dict[str, Any]: """Create a template row from a schema. A template row is essentially the shape of the record that would be populated by the reader (i.e. contains default values), except lists are pre-populated with a single 'empty' record as a hint to the reader about the data structure. """ - template_row: Dict[str, Any] = {} + template_row: dict[str, Any] = {} for field_name, model_field_def in schema.__fields__.items(): field_type = _strip_annotated(model_field_def.annotation) @@ -54,7 +62,7 @@ def create_template_row(schema: Type[BaseModel]) -> Dict[str, Any]: template_row[field_name] = None continue - if isinstance(field_type, type): + if isinstance(field_type, type) and not isinstance(field_type, GenericAlias): if issubclass(field_type, BaseModel): template_row[field_name] = create_template_row(field_type) continue @@ -97,8 +105,7 @@ class XMLElement(Protocol): def clear(self) -> None: """Clear the element, removing children/attrs/etc.""" - def __iter__(self) -> Iterator["XMLElement"]: - ... + def __iter__(self) -> Iterator["XMLElement"]: ... class BasicXMLFileReader(BaseFileReader): @@ -177,12 +184,10 @@ def _sanitise_field(self, value: Optional[str]) -> Optional[str]: return value @overload - def _parse_element(self, element: XMLElement, template: TemplateRow) -> XMLRecord: - ... + def _parse_element(self, element: XMLElement, template: TemplateRow) -> XMLRecord: ... @overload - def _parse_element(self, element: XMLElement, template: TemplateElement) -> XMLType: - ... + def _parse_element(self, element: XMLElement, template: TemplateElement) -> XMLType: ... def _parse_element(self, element: XMLElement, template: Union[TemplateElement, TemplateRow]): """Parse an XML element according to a template.""" @@ -232,7 +237,7 @@ def _get_elements_from_stream(self, stream: IO[bytes]) -> Iterator[XMLElement]: tree: etree._ElementTree = etree.parse(stream, parser) root: etree._Element = tree.getroot() - elements: List[XMLElement] + elements: list[XMLElement] if self.root_tag: elements = root.xpath( f"//*[local-name()='{self.root_tag}']/*[local-name()='{self.record_tag}']" @@ -249,8 +254,8 @@ def _get_elements_from_stream(self, stream: IO[bytes]) -> Iterator[XMLElement]: break def _parse_xml( - self, stream: IO[bytes], schema: Type[BaseModel] - ) -> Iterator[Dict[str, XMLType]]: + self, stream: IO[bytes], schema: type[BaseModel] + ) -> Iterator[dict[str, XMLType]]: """Coerce a parsed record into the intended shape, nulling values which are expected to be parsed as nulls. @@ -265,8 +270,8 @@ def read_to_py_iterator( self, resource: URI, entity_name: EntityName, - schema: Type[BaseModel], - ) -> Iterator[Dict[str, Any]]: + schema: type[BaseModel], + ) -> Iterator[dict[str, Any]]: """Iterate through the contents of the file at URI, yielding rows containing the data. 
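`create_template_row` above builds a "template row" from a schema: scalar fields default to `None`, nested models recurse, and list fields are seeded with a single empty record so the XML parser knows the expected shape. The sketch below illustrates the same idea against plain dataclasses rather than pydantic models, to stay self-contained; the real function also unwraps `Annotated`/`Optional` and guards against subscripted generics with the `GenericAlias` check shown above.

from dataclasses import dataclass, fields, is_dataclass
from typing import Any, get_args, get_origin


def template_row(record_type: type) -> dict[str, Any]:
    """Build a default-shaped dict for a dataclass: None scalars, seeded lists."""
    template: dict[str, Any] = {}
    for field in fields(record_type):
        annotation = field.type
        if is_dataclass(annotation):
            template[field.name] = template_row(annotation)
        elif get_origin(annotation) is list:
            (element_type,) = get_args(annotation)
            template[field.name] = [
                template_row(element_type) if is_dataclass(element_type) else None
            ]
        else:
            template[field.name] = None
    return template


@dataclass
class Address:
    postcode: str


@dataclass
class Person:
    name: str
    addresses: list[Address]


# -> {'name': None, 'addresses': [{'postcode': None}]}
print(template_row(Person))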
@@ -285,9 +290,9 @@ def read_to_py_iterator( def write_parquet( # type: ignore self, - entity: Iterator[Dict[str, Any]], + entity: Iterator[dict[str, Any]], target_location: URI, - schema: Optional[Type[BaseModel]] = None, + schema: Optional[type[BaseModel]] = None, **kwargs, ) -> URI: """Writes the data of the given entity out to a parquet file""" @@ -296,7 +301,7 @@ def write_parquet( # type: ignore if isinstance(_get_implementation(target_location), LocalFilesystemImplementation): target_location = file_uri_to_local_path(target_location).as_posix() if schema: - polars_schema: Dict[str, pl.DataType] = { # type: ignore + polars_schema: dict[str, pl.DataType] = { # type: ignore fld.name: get_polars_type_from_annotation(fld.type_) for fld in stringify_model(schema).__fields__.values() } @@ -392,9 +397,9 @@ def _get_elements_from_stream(self, stream: IO[bytes]) -> Iterator[XMLElement]: def write_parquet( # type: ignore self, - entity: Iterator[Dict[str, Any]], + entity: Iterator[dict[str, Any]], target_location: URI, - schema: Optional[Type[BaseModel]] = None, + schema: Optional[type[BaseModel]] = None, **kwargs, ) -> URI: """Writes the given entity data out to a parquet file""" diff --git a/src/dve/core_engine/backends/types.py b/src/dve/core_engine/backends/types.py index 61b3c14..2445ea8 100644 --- a/src/dve/core_engine/backends/types.py +++ b/src/dve/core_engine/backends/types.py @@ -5,7 +5,8 @@ """ -from typing import MutableMapping, TypeVar +from collections.abc import MutableMapping +from typing import TypeVar from dve.core_engine.type_hints import EntityName diff --git a/src/dve/core_engine/backends/utilities.py b/src/dve/core_engine/backends/utilities.py index 3086c81..9319780 100644 --- a/src/dve/core_engine/backends/utilities.py +++ b/src/dve/core_engine/backends/utilities.py @@ -4,7 +4,8 @@ from dataclasses import is_dataclass from datetime import date, datetime from decimal import Decimal -from typing import Any, ClassVar, Dict, Type, Union +from typing import Any, ClassVar, Union +from typing import GenericAlias # type: ignore import polars as pl # type: ignore from polars.datatypes.classes import DataTypeClass as PolarsType @@ -22,7 +23,7 @@ else: from typing import Annotated, get_args, get_origin, get_type_hints -PYTHON_TYPE_TO_POLARS_TYPE: Dict[type, PolarsType] = { +PYTHON_TYPE_TO_POLARS_TYPE: dict[type, PolarsType] = { # issue with decimal conversion at the moment... str: pl.Utf8, # type: ignore int: pl.Int64, # type: ignore @@ -36,9 +37,9 @@ """A mapping of Python types to the equivalent Polars types.""" -def stringify_type(type_: type) -> type: +def stringify_type(type_: Union[type, GenericAlias]) -> type: """Stringify an individual type.""" - if isinstance(type_, type): # A model, return the contents. + if isinstance(type_, type) and not isinstance(type_, GenericAlias): # A model, return the contents. # pylint: disable=C0301 if issubclass(type_, BaseModel): return stringify_model(type_) @@ -61,7 +62,7 @@ def stringify_type(type_: type) -> type: return origin[string_type_args] -def stringify_model(model: Type[BaseModel]) -> Type[BaseModel]: +def stringify_model(model: type[BaseModel]) -> type[BaseModel]: """Stringify a `pydantic` model.""" fields = {} for field_name, field in model.__fields__.items(): @@ -141,7 +142,7 @@ def get_polars_type_from_annotation(type_annotation: Any) -> PolarsType: # Type hint is a `pydantic` model. 
or (type_origin is None and issubclass(type_annotation, BaseModel)) ): - fields: Dict[str, PolarsType] = {} + fields: dict[str, PolarsType] = {} for field_name, field_annotation in get_type_hints(type_annotation).items(): # Technically non-string keys are disallowed, but people are bad. if not isinstance(field_name, str): diff --git a/src/dve/core_engine/configuration/base.py b/src/dve/core_engine/configuration/base.py index f01e258..2ec7d18 100644 --- a/src/dve/core_engine/configuration/base.py +++ b/src/dve/core_engine/configuration/base.py @@ -2,7 +2,7 @@ import json from abc import ABC, abstractmethod -from typing import Any, Dict, Type, TypeVar +from typing import Any, TypeVar from pydantic import BaseModel @@ -35,7 +35,7 @@ def get_contract_metadata(self) -> DataContractMetadata: """Build the contract metadata from the configuration.""" @abstractmethod - def get_reference_data_config(self) -> Dict[EntityName, Dict[str, Any]]: + def get_reference_data_config(self) -> dict[EntityName, dict[str, Any]]: """Get the configuration info for the reference data. This will likely be backend dependent, and should be a dict mapping reference @@ -45,7 +45,7 @@ def get_reference_data_config(self) -> Dict[EntityName, Dict[str, Any]]: """ @classmethod - def load(cls: Type[CSelf], location: URI) -> CSelf: + def load(cls: type[CSelf], location: URI) -> CSelf: """Load an instance of the config from the URI.""" with open_stream(location) as config_stream: json_config = json.load(config_stream) diff --git a/src/dve/core_engine/configuration/v1/__init__.py b/src/dve/core_engine/configuration/v1/__init__.py index ba1b9b6..057abe3 100644 --- a/src/dve/core_engine/configuration/v1/__init__.py +++ b/src/dve/core_engine/configuration/v1/__init__.py @@ -1,7 +1,7 @@ """The loader for the first JSON-based dataset configuration.""" import json -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Optional, Union from pydantic import BaseModel, Field, PrivateAttr, validate_arguments from typing_extensions import Annotated, Literal @@ -33,7 +33,7 @@ """The name of a nested schema.""" RuleName = str """The name of a business rule.""" -RuleDependencies = Set[RuleName] +RuleDependencies = set[RuleName] """A list of dependencies required by a rule.""" FieldName = str @@ -45,7 +45,7 @@ Operation = str """The operation """ -RuleType = Type[AbstractStep] +RuleType = type[AbstractStep] """The metadata step type implemented by the rule.""" @@ -63,7 +63,7 @@ class _CallableTypeDefinition(_BaseTypeDefintion): callable: str """The callable to be called to create the type.""" - constraints: Dict[str, Any] = Field(default_factory=dict) + constraints: dict[str, Any] = Field(default_factory=dict) """The keyword arguments passed to the callable as kwargs.""" @@ -84,9 +84,9 @@ class _TypeAliasDefinition(_BaseTypeDefintion): class _SchemaConfig(BaseModel): """Configuration for a component schema within a dataset.""" - fields: Dict[FieldName, TypeOrDef] + fields: dict[FieldName, TypeOrDef] """Field definitions within the schema.""" - mandatory_fields: List[FieldName] = Field(default_factory=list) + mandatory_fields: list[FieldName] = Field(default_factory=list) """A list of the field names within the schema which _must_ be provided.""" @@ -95,22 +95,22 @@ class _ReaderConfig(BaseModel): # type: ignore reader: str """The name of the reader to use.""" - kwargs_: Dict[str, Any] = Field(alias="kwargs", default_factory=dict) + kwargs_: dict[str, Any] = Field(alias="kwargs", default_factory=dict) 
"""Keyword arguments for the reader.""" - field_names: Optional[List[str]] = None + field_names: Optional[list[str]] = None """The field names to request from the reader. These are deprecated and will not be used.""" class _ModelConfig(_SchemaConfig): """A concrete model within the dataset.""" - reporting_fields: List[FieldName] = Field(default_factory=list) + reporting_fields: list[FieldName] = Field(default_factory=list) """A list of the reporting fields within the model.""" key_field: Optional[str] = None """A single key field to be used by the model.""" - reader_config: Dict[Extension, _ReaderConfig] + reader_config: dict[Extension, _ReaderConfig] """Reader configuration options for the model.""" - aliases: Dict[FieldName, FieldName] = Field(default_factory=dict) + aliases: dict[FieldName, FieldName] = Field(default_factory=dict) """An alias field name mapping.""" @@ -128,7 +128,7 @@ class _ComplexRuleConfig(BaseModel): rule_name: str """The name of the complex rule.""" - parameters: Dict[str, Any] = Field(default_factory=dict) + parameters: dict[str, Any] = Field(default_factory=dict) """The parameters for the rule.""" @@ -139,30 +139,30 @@ class V1DataContractConfig(BaseModel): """Whether to cache the original entities after loading.""" error_details: Optional[URI] = None """Optional URI containing custom data contract error codes and messages""" - types: Dict[TypeName, TypeOrDef] = Field(default_factory=dict) + types: dict[TypeName, TypeOrDef] = Field(default_factory=dict) """Dataset specific types defined within the config.""" - schemas: Dict[SchemaName, _SchemaConfig] = Field(default_factory=dict) + schemas: dict[SchemaName, _SchemaConfig] = Field(default_factory=dict) """Component schemas within the config.""" - datasets: Dict[SchemaName, _ModelConfig] + datasets: dict[SchemaName, _ModelConfig] """Concrete entity definitions which will be loaded by the config.""" class V1TransformationConfig(BaseModel): """Configuration for the transformation component of the dataset.""" - rule_stores: List[_RuleStoreConfig] = Field(default_factory=list) + rule_stores: list[_RuleStoreConfig] = Field(default_factory=list) """The external rule stores that rules can be referenced from.""" - reference_data: Dict[EntityName, ReferenceConfigUnion] = Field(default_factory=dict) + reference_data: dict[EntityName, ReferenceConfigUnion] = Field(default_factory=dict) """Configuration options for reference data.""" - parameters: Dict[str, Any] = Field(default_factory=dict) + parameters: dict[str, Any] = Field(default_factory=dict) """Global parameters to be passed to rules for templating.""" - rules: List[StepConfigUnion] = Field(default_factory=list) + rules: list[StepConfigUnion] = Field(default_factory=list) """Pre-filter stage rules.""" - filters: List[FilterConfigUnion] = Field(default_factory=list) + filters: list[FilterConfigUnion] = Field(default_factory=list) """Filter stage rules.""" - post_filter_rules: List[StepConfigUnion] = Field(default_factory=list) + post_filter_rules: list[StepConfigUnion] = Field(default_factory=list) """Post-filter stage rules/""" - complex_rules: List[_ComplexRuleConfig] = Field(default_factory=list) + complex_rules: list[_ComplexRuleConfig] = Field(default_factory=list) """Complex rules.""" @@ -173,13 +173,13 @@ class V1EngineConfig(BaseEngineConfig): """The data contract configuration.""" transformations: V1TransformationConfig = Field(default_factory=V1TransformationConfig) """The transformation/rules configuration.""" - _rule_store_rules: Dict[RuleName, 
BusinessComponentSpecConfigUnion] = PrivateAttr( + _rule_store_rules: dict[RuleName, BusinessComponentSpecConfigUnion] = PrivateAttr( default_factory=dict ) """Rule store rules from the loaded rule stores.""" @validate_arguments - def _update_rule_store(self, rule_store: Dict[RuleName, BusinessComponentSpecConfigUnion]): + def _update_rule_store(self, rule_store: dict[RuleName, BusinessComponentSpecConfigUnion]): """Update the rule store rules to add/override the rules from the new store.""" self._rule_store_rules.update(rule_store) @@ -204,7 +204,7 @@ def __init__(self, *args, **kwargs): def _resolve_business_filter( self, config: BusinessFilterConfig - ) -> Tuple[ConcreteFilterConfig, TemplateVariables]: + ) -> tuple[ConcreteFilterConfig, TemplateVariables]: """Resolve a business filter and create a concrete filter.""" local_params: TemplateVariables = config.parameters.copy() @@ -222,10 +222,10 @@ def _resolve_business_filter( def _create_rule( self, name: str, - rules: List[StepConfigUnion], - filters: List[FilterConfigUnion], - post_filter_rules: List[StepConfigUnion], - ) -> Tuple[Rule, TemplateVariables]: + rules: list[StepConfigUnion], + filters: list[FilterConfigUnion], + post_filter_rules: list[StepConfigUnion], + ) -> tuple[Rule, TemplateVariables]: """Create a rule from the config types, returning the rule and any template vars from the filters. @@ -252,7 +252,7 @@ def _create_rule( def _resolve_business_rule( self, config: _ComplexRuleConfig - ) -> Tuple[Rule, TemplateVariables, RuleDependencies]: + ) -> tuple[Rule, TemplateVariables, RuleDependencies]: """Load a complex business rule spec to a rule.""" rule_spec = self._rule_store_rules[config.rule_name] if not isinstance(rule_spec, BusinessRuleSpecConfig): @@ -269,10 +269,10 @@ def _resolve_business_rule( local_params.update(new_local_params) return rule, local_params, set(rule_spec.dependencies) - def _load_rules_and_vars(self) -> Tuple[List[Rule], List[TemplateVariables]]: + def _load_rules_and_vars(self) -> tuple[list[Rule], list[TemplateVariables]]: """Load the rules and local variables for the transformations.""" rules, local_variable_list = [], [] - added_rules: Set[RuleName] = set() + added_rules: set[RuleName] = set() for index, complex_rule_config in enumerate(self.transformations.complex_rules): rule, local_params, deps = self._resolve_business_rule(complex_rule_config) @@ -329,7 +329,7 @@ def load_error_message_info(self, uri): with open_stream(joinuri(uri_prefix, uri)) as stream: return json.load(stream) - def get_reference_data_config(self) -> Dict[EntityName, ReferenceConfig]: # type: ignore + def get_reference_data_config(self) -> dict[EntityName, ReferenceConfig]: # type: ignore """Gets the reference data configuration from the transformations""" return self.transformations.reference_data diff --git a/src/dve/core_engine/configuration/v1/filters.py b/src/dve/core_engine/configuration/v1/filters.py index 64a815c..c25bca4 100644 --- a/src/dve/core_engine/configuration/v1/filters.py +++ b/src/dve/core_engine/configuration/v1/filters.py @@ -1,6 +1,6 @@ """The loader for the first JSON-based dataset configuration.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from pydantic import BaseModel, Field @@ -21,7 +21,7 @@ class ConcreteFilterConfig(BaseModel): error_code: Optional[str] = None error_location: Optional[str] = None reporting_entity: Optional[str] = None - reporting_field: Optional[Union[str, List[str]]] = None + reporting_field: Optional[Union[str, 
list[str]]] = None reporting_field_name: Optional[str] = None category: ErrorCategory = "Bad value" @@ -52,7 +52,7 @@ class BusinessFilterConfig(BaseModel): rule_name: str """The name of the business rule.""" - parameters: Dict[str, Any] = Field(default_factory=dict) + parameters: dict[str, Any] = Field(default_factory=dict) """Parameters for the business rule.""" diff --git a/src/dve/core_engine/configuration/v1/rule_stores/models.py b/src/dve/core_engine/configuration/v1/rule_stores/models.py index f42d71c..67d9be6 100644 --- a/src/dve/core_engine/configuration/v1/rule_stores/models.py +++ b/src/dve/core_engine/configuration/v1/rule_stores/models.py @@ -1,6 +1,6 @@ """Models for components in the rule stores.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from pydantic import BaseModel, Field from typing_extensions import Annotated, Literal @@ -17,9 +17,9 @@ class BusinessSpecConfig(BaseModel): description: Optional[str] = None """A description of what the rule/filter should do.""" - parameter_descriptions: Dict[str, str] = Field(default_factory=dict) + parameter_descriptions: dict[str, str] = Field(default_factory=dict) """Descriptions of parameters used by the rule.""" - parameter_defaults: Dict[str, Any] = Field(default_factory=dict) + parameter_defaults: dict[str, Any] = Field(default_factory=dict) """Default parameters to be used by the rule if no param is passed.""" @@ -35,9 +35,9 @@ class BusinessFilterSpecConfig(BusinessSpecConfig): class ComplexRuleConfig(BaseModel): """The rule config for a business rule.""" - rules: List[StepConfigUnion] = Field(default_factory=list) - filters: List[FilterConfigUnion] = Field(default_factory=list) - post_filter_rules: List[StepConfigUnion] = Field(default_factory=list) + rules: list[StepConfigUnion] = Field(default_factory=list) + filters: list[FilterConfigUnion] = Field(default_factory=list) + post_filter_rules: list[StepConfigUnion] = Field(default_factory=list) class BusinessRuleSpecConfig(BusinessSpecConfig): @@ -47,7 +47,7 @@ class BusinessRuleSpecConfig(BusinessSpecConfig): rule_config: ComplexRuleConfig """The configuration for the rule.""" - dependencies: List[str] = Field(default_factory=list) + dependencies: list[str] = Field(default_factory=list) """The dependencies for the business rule.""" diff --git a/src/dve/core_engine/configuration/v1/steps.py b/src/dve/core_engine/configuration/v1/steps.py index 5a14394..f795c8c 100644 --- a/src/dve/core_engine/configuration/v1/steps.py +++ b/src/dve/core_engine/configuration/v1/steps.py @@ -10,7 +10,7 @@ # pylint: disable=missing-class-docstring from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from pydantic import BaseModel, Extra, Field, validator from typing_extensions import Annotated, Literal @@ -104,12 +104,12 @@ class GroupByConfig(ConfigStep): new_entity_name: Optional[str] = None group_by: MultipleExpressions pivot_column: Optional[str] = None - pivot_values: Optional[List[str]] = None + pivot_values: Optional[list[str]] = None agg_columns: MultipleExpressions @validator("pivot_values") @classmethod - def _ensure_no_values_if_not_column(cls, value: Optional[str], values: Dict[str, Any]): + def _ensure_no_values_if_not_column(cls, value: Optional[str], values: dict[str, Any]): if value and not values["pivot_column"]: raise ValueError("Cannot provide 'pivot_values' if no 'pivot_column'") return value @@ -336,7 +336,7 @@ class RemoveEntityConfig(ConfigStep): 
operation: Literal["remove_entity", "remove_entities"] - entity: Union[str, List[str]] + entity: Union[str, list[str]] def to_step(self) -> AbstractStep: """Takes a config object and returns a step object""" diff --git a/src/dve/core_engine/engine.py b/src/dve/core_engine/engine.py index 06d5e18..87ab0b6 100644 --- a/src/dve/core_engine/engine.py +++ b/src/dve/core_engine/engine.py @@ -8,7 +8,7 @@ from pathlib import Path from tempfile import NamedTemporaryFile from types import TracebackType -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from pydantic import BaseModel, Field, PrivateAttr, validate_arguments, validator from pydantic.types import FilePath @@ -84,7 +84,7 @@ def __init__(self, *args, **kwargs): @validator("backend", always=True) @classmethod - def _ensure_backend(cls, value: Optional[BaseBackend], values: Dict[str, Any]) -> BaseBackend: + def _ensure_backend(cls, value: Optional[BaseBackend], values: dict[str, Any]) -> BaseBackend: """Ensure a default backend is created if a backend is not specified.""" if value is not None: return value @@ -100,7 +100,7 @@ def _ensure_backend(cls, value: Optional[BaseBackend], values: Dict[str, Any]) - ) @classmethod - @validate_arguments(config=dict(arbitrary_types_allowed=True)) + @validate_arguments(config={"arbitrary_types_allowed": True}) def build( cls, dataset_config_path: Union[FilePath, URI], @@ -172,7 +172,7 @@ def __enter__(self) -> "CoreEngine": def __exit__( self, - exc_type: Optional[Type[Exception]], + exc_type: Optional[type[Exception]], exc_value: Optional[Exception], traceback: Optional[TracebackType], ) -> None: @@ -268,7 +268,7 @@ def _write_exception_report(self, messages: Messages) -> None: def _write_outputs( self, entities: SparkEntities, messages: Messages, verbose: bool = False - ) -> Tuple[SparkEntities, Messages]: + ) -> tuple[SparkEntities, Messages]: """Write the outputs from the pipeline, returning the written entities and messages. @@ -300,12 +300,12 @@ def _show_available_entities(self, entities: SparkEntities, *, verbose: bool = F def run_pipeline( self, - entity_locations: Dict[EntityName, URI], + entity_locations: dict[EntityName, URI], *, verbose: bool = False, # pylint: disable=unused-argument submission_info: Optional[SubmissionInfo] = None, - ) -> Tuple[SparkEntities, Messages]: + ) -> tuple[SparkEntities, Messages]: """Run the pipeline, reading in the entities and applying validation and transformation rules, and then write the outputs. 
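The `validate_arguments` change above only swaps `dict(...)` for a dict literal; behaviour is unchanged. A minimal, self-contained sketch of the pattern under pydantic v1 (`Backend` and `build_engine` are illustrative stand-ins, not names from this codebase):

from pydantic import validate_arguments


class Backend:
    """A stand-in for an arbitrary, non-pydantic type."""


@validate_arguments(config={"arbitrary_types_allowed": True})
def build_engine(name: str, backend: Backend) -> str:
    # arbitrary_types_allowed lets pydantic v1 validate `backend` with a plain
    # isinstance() check rather than requiring a pydantic-aware type.
    return f"{name}:{type(backend).__name__}"


print(build_engine("demo", Backend()))  # -> demo:Backend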
diff --git a/src/dve/core_engine/exceptions.py b/src/dve/core_engine/exceptions.py index 6fdb801..dae8647 100644 --- a/src/dve/core_engine/exceptions.py +++ b/src/dve/core_engine/exceptions.py @@ -1,6 +1,6 @@ """Exceptions emitted by the pipeline.""" -from typing import Iterator +from collections.abc import Iterator from dve.core_engine.backends.implementations.spark.types import SparkEntities from dve.core_engine.message import FeedbackMessage diff --git a/src/dve/core_engine/message.py b/src/dve/core_engine/message.py index daebe41..d81acde 100644 --- a/src/dve/core_engine/message.py +++ b/src/dve/core_engine/message.py @@ -2,11 +2,12 @@ import copy import datetime as dt -import json import operator +import json +from collections.abc import Callable from decimal import Decimal from functools import reduce -from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, ClassVar, Optional, Union from pydantic import BaseModel, ValidationError, validator from pydantic.dataclasses import dataclass @@ -32,8 +33,8 @@ class DataContractErrorDetail(BaseModel): def template_message( self, - variables: Dict[str, Any], - error_location: Optional[Tuple[Union[str, int], ...]] = None, + variables: dict[str, Any], + error_location: Optional[tuple[Union[str, int], ...]] = None, ) -> Optional[str]: """Template error messages with values from the record""" if error_location: @@ -53,7 +54,7 @@ def extract_error_value(records, error_location): return _records -DEFAULT_ERROR_DETAIL: Dict[ErrorCategory, DataContractErrorDetail] = { +DEFAULT_ERROR_DETAIL: dict[ErrorCategory, DataContractErrorDetail] = { "Blank": DataContractErrorDetail(error_code="FieldBlank", error_message="cannot be blank"), "Bad value": DataContractErrorDetail(error_code="BadValue", error_message="is invalid"), "Wrong format": DataContractErrorDetail( @@ -62,12 +63,12 @@ def extract_error_value(records, error_location): } -INTEGRITY_ERROR_CODES: Set[str] = {"blockingsubmission"} +INTEGRITY_ERROR_CODES: set[str] = {"blockingsubmission"} """ Error types which should raise integrity errors if encountered. """ -SUBMISSION_ERROR_CODES: Set[str] = {"submission"} +SUBMISSION_ERROR_CODES: set[str] = {"submission"} """ Error types which should raise submission errors if encountered. 
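The import moves above (`Iterator` and `Callable` now sourced from `collections.abc`) follow PEP 585: from Python 3.9 the `collections.abc` ABCs are subscriptable and the old `typing` aliases are deprecated. A small standalone sketch of the pattern, with illustrative names only:

from collections.abc import Callable, Iterator

# Subscriptable ABCs replace typing.Callable / typing.Iterator in aliases.
RecordConverter = Callable[[dict[str, object]], str]


def count_up(limit: int) -> Iterator[int]:
    # A generator function satisfies Iterator[int].
    yield from range(limit)


convert: RecordConverter = repr
print(convert({"a": 1}), list(count_up(3)))  # -> {'a': 1} [0, 1, 2]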
@@ -114,7 +115,7 @@ class FeedbackMessage: # pylint: disable=too-many-instance-attributes """The error message.""" error_code: Optional[str] = None """ETOS Error code for the error.""" - reporting_field: Union[str, List[str], None] = None + reporting_field: Union[str, list[str], None] = None """The field that the error pertains to.""" reporting_field_name: Optional[str] = None """ @@ -127,7 +128,7 @@ class FeedbackMessage: # pylint: disable=too-many-instance-attributes category: Optional[ErrorCategory] = None """The category of the error.""" - HEADER: ClassVar[List[str]] = [ + HEADER: ClassVar[list[str]] = [ "Entity", "Key", "FailureType", @@ -144,7 +145,7 @@ class FeedbackMessage: # pylint: disable=too-many-instance-attributes @validator("reporting_field") # pylint: disable=no-self-argument - def _split_reporting_field(cls, value) -> Union[List[str], str, None]: + def _split_reporting_field(cls, value) -> Union[list[str], str, None]: if isinstance(value, list): return value if isinstance(value, str): @@ -176,8 +177,8 @@ def _validate_error_location(cls, value: Any) -> Optional[str]: @validator("record") def _strip_rowid( # pylint: disable=no-self-argument - cls, value: Optional[Dict[str, Any]] - ) -> Optional[Dict[str, Any]]: + cls, value: Optional[dict[str, Any]] + ) -> Optional[dict[str, Any]]: """Strip the row ID column from the record, if present.""" if isinstance(value, dict): value.pop(ROWID_COLUMN_NAME, None) @@ -190,12 +191,12 @@ def is_critical(self) -> bool: @classmethod def from_pydantic_error( - cls: Type["FeedbackMessage"], + cls: type["FeedbackMessage"], entity: str, record: Record, error: ValidationError, error_details: Optional[ - Dict[FieldName, Dict[ErrorCategory, DataContractErrorDetail]] + dict[FieldName, dict[ErrorCategory, DataContractErrorDetail]] ] = None, ) -> Messages: """Create messages from a `pydantic` validation error.""" @@ -301,9 +302,9 @@ def _extract_value( loc = self.error_location if loc: # this is because, for some reason, even if error_location is set to be - # a List[str] or Tuple[str] and set smart_unions to be True, it still + # a list[str] or tuple[str] and set smart_unions to be True, it still # always comes in as a string - loc_items: List[Union[str, int]] = [ + loc_items: list[Union[str, int]] = [ field if not field.isnumeric() else int(field.strip()) for field in loc.strip("()").replace("'", "").replace(" ", "").split(",") ] @@ -346,11 +347,11 @@ def _extract_value( def _multi_reporting_fields( self, - reporting_field: List[str], + reporting_field: list[str], max_number_of_values: Optional[int], value_separator: str, loc: Optional[str], - loc_items: List[Union[str, int]], + loc_items: list[Union[str, int]], ) -> Any: value: Any @@ -409,7 +410,7 @@ def _cond_str(value: Any, str_none: bool = False) -> Optional[str]: return str(value) def _string_values( - self, reporting_field: Union[List[str], str], values: List[Any], value_separator: str + self, reporting_field: Union[list[str], str], values: list[Any], value_separator: str ) -> str: if all(isinstance(item, dict) for item in values) and isinstance(reporting_field, str): values = [ @@ -436,7 +437,7 @@ def to_dict( max_number_of_values: Optional[int] = None, value_separator: str = ", ", record_converter: Optional[Callable] = repr, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Create a reporting dict from the message.""" return dict( zip( diff --git a/src/dve/core_engine/models.py b/src/dve/core_engine/models.py index 17ac3be..2e6578f 100644 --- a/src/dve/core_engine/models.py +++ 
b/src/dve/core_engine/models.py @@ -7,8 +7,9 @@ import json import os import uuid +from collections.abc import MutableMapping from pathlib import Path, PurePath -from typing import Any, Dict, List, MutableMapping, Optional +from typing import Any, Optional from pydantic import UUID4, BaseModel, Field, FilePath, root_validator, validator @@ -83,7 +84,7 @@ def from_metadata_file(cls, submission_id: str, metadata_uri: Location): metadata_uri = resolve_location(metadata_uri) with open_stream(metadata_uri, "r", "utf-8") as stream: try: - metadata_dict: Dict[str, Any] = json.load(stream) + metadata_dict: dict[str, Any] = json.load(stream) except json.JSONDecodeError as exc: raise ValueError(f"File found at {metadata_uri!r} is not valid JSON") from exc @@ -179,14 +180,14 @@ class EngineRunValidation(EngineRun): class ConcreteEntity(EntitySpecification, arbitrary_types_allowed=True): """An entity which has a configured reader and (possibly) a key field.""" - reader_config: Dict[Extension, ReaderConfig] + reader_config: dict[Extension, ReaderConfig] """A reader configuration for the entity.""" key_field: Optional[str] = None """An optional key field to use for the entity.""" - reporting_fields: Optional[List[str]] = None + reporting_fields: Optional[list[str]] = None @validator("reporting_fields", pre=True) - def _ensure_list(cls, value: Optional[str]) -> Optional[List[str]]: # pylint: disable=E0213 + def _ensure_list(cls, value: Optional[str]) -> Optional[list[str]]: # pylint: disable=E0213 """Ensure the reporting fields are a list.""" if value is None: return None diff --git a/src/dve/core_engine/type_hints.py index 8320b29..a6c0c44 100644 --- a/src/dve/core_engine/type_hints.py +++ b/src/dve/core_engine/type_hints.py @@ -1,26 +1,18 @@ """Type aliases for the core engine.""" +from collections.abc import Callable, MutableMapping from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from multiprocessing import Queue as ProcessQueue from pathlib import Path from queue import Queue as ThreadQueue -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - MutableMapping, - Optional, - Tuple, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, List, Optional, TypeVar, Union # pylint: disable=W1901 +# TODO - cannot remove List from typing. See L60 for details. from pyspark.sql import DataFrame from pyspark.sql.types import StructType from typing_extensions import Literal, ParamSpec, get_args + if TYPE_CHECKING: # pragma: no cover from dve.core_engine.message import FeedbackMessage @@ -28,7 +20,7 @@ """The name of a field within a record""" ErrorValue = Any """The value contained in a specific field.""" -Record = Dict[Field, ErrorValue] +Record = dict[Field, ErrorValue] """A record within an entity.""" PathStr = str @@ -59,18 +51,21 @@ """The locations of entities as Parquet.""" Messages = List["FeedbackMessage"] """A queue of messages returned by a process.""" +# TODO - converting the Messages alias above to list["FeedbackMessage"] breaks get_type_hints +# TODO - in base/rules.py:113; the root cause is not yet understood. A ticket will be raised +# TODO - to resolve this in the future. Alias = str """A column alias.""" Expression = str """An SQL expression.""" -ExpressionMapping = Dict[Expression, Union[Alias, List[Alias]]] +ExpressionMapping = dict[Expression, Union[Alias, list[Alias]]] """ A mapping of expression to alias. Some expressions (e.g. `posexplode`) require multiple aliases, which can be passed as a list.
""" -ExpressionArray = List[Expression] +ExpressionArray = list[Expression] """An array of expressions with aliases in SQL (e.g. using `expression AS alias`).""" MultiExpression = str """ @@ -103,16 +98,16 @@ DeprecationMessage = str """A message indicating the reason/context for a rule's deprecation.""" -ContractContents = Dict[str, Any] +ContractContents = dict[str, Any] """A JSON mapping containing the data contract for a dataset.""" SparkSchema = StructType """The Spark schema for a given dataset.""" KeyField = Optional[str] """The name of the field containing the record field.""" -ReportingFields = List[Optional[str]] +ReportingFields = list[Optional[str]] """Field(s) used to identify records without a single identifying field.""" -Key = Union[str, Dict[str, Any], None] +Key = Union[str, dict[str, Any], None] """ If no record is attached to the message, then `None`. Otherwise, the value of the record in `KeyField`, or the whole row as a dict or repr for the consumer to determine how to handle @@ -141,7 +136,7 @@ ErrorCategory = Literal["Blank", "Wrong format", "Bad value", "Bad file"] """A string indicating the category of the error.""" -MessageTuple = Tuple[ +MessageTuple = tuple[ Optional[EntityName], Key, FailureType, @@ -172,7 +167,7 @@ ] """A union of the types of values that can be contained in a message.""" "" -MessageDict = Dict[MessageKeys, MessageValues] +MessageDict = dict[MessageKeys, MessageValues] """A dictionary representing the information from a message.""" JSONstring = str @@ -180,9 +175,9 @@ JSONBaseType = Union[str, int, float, bool, None] """The fundamental allowed types in JSON.""" # mypy doesn't support recursive type definitions. -JSONable = Union[Dict[str, "JSONable"], List["JSONable"], JSONBaseType] # type: ignore +JSONable = Union[dict[str, "JSONable"], list["JSONable"], JSONBaseType] # type: ignore """A recursive description of the types that come from parsing JSON.""" -JSONDict = Dict[str, JSONable] # type: ignore +JSONDict = dict[str, JSONable] # type: ignore """A JSON dictionary.""" Source = DataFrame @@ -196,7 +191,7 @@ """The name of a template variable.""" TemplateVariableValue = Any """The value of a template variable.""" -TemplateVariables = Dict[TemplateVariableName, TemplateVariableValue] +TemplateVariables = dict[TemplateVariableName, TemplateVariableValue] """Variables for templating.""" FP = ParamSpec("FP") @@ -238,13 +233,13 @@ ] """Allowed statuses for DVE submission""" -PROCESSING_STATUSES: Tuple[ProcessingStatus, ...] = tuple(list(get_args(ProcessingStatus))) +PROCESSING_STATUSES: tuple[ProcessingStatus, ...] = tuple(list(get_args(ProcessingStatus))) """List of all possible DVE submission statuses""" SubmissionResult = Literal["success", "failed", "failed_xml_generation", "archived"] """Allowed DVE submission results""" -SUBMISSION_RESULTS: Tuple[SubmissionResult, ...] = tuple(list(get_args(SubmissionResult))) +SUBMISSION_RESULTS: tuple[SubmissionResult, ...] 
= tuple(list(get_args(SubmissionResult))) """List of possible DVE submission results""" BinaryComparator = Callable[[Any, Any], bool] diff --git a/src/dve/core_engine/validation.py b/src/dve/core_engine/validation.py index d21e7e2..2be101e 100644 --- a/src/dve/core_engine/validation.py +++ b/src/dve/core_engine/validation.py @@ -1,7 +1,7 @@ """XML schema/contract configuration.""" import warnings -from typing import Dict, List, Optional, Tuple +from typing import Optional from pydantic import ValidationError from pydantic.main import ModelMetaclass @@ -36,7 +36,7 @@ def __init__( self._model: Optional[ModelMetaclass] = None self._error_info = error_info or {} self._error_details: Optional[ - Dict[FieldName, Dict[ErrorCategory, DataContractErrorDetail]] + dict[FieldName, dict[ErrorCategory, DataContractErrorDetail]] ] = None def __reduce__(self): # Don't attempt to pickle Pydantic models. @@ -61,7 +61,7 @@ def model(self) -> ModelMetaclass: return self._model @property - def error_details(self) -> Dict[FieldName, Dict[ErrorCategory, DataContractErrorDetail]]: + def error_details(self) -> dict[FieldName, dict[ErrorCategory, DataContractErrorDetail]]: """Custom error code and message mapping for contract phase""" if not self._error_details: _error_details = { @@ -74,7 +74,7 @@ def error_details(self) -> Dict[FieldName, Dict[ErrorCategory, DataContractError self._error_details = _error_details return self._error_details - def __call__(self, record: Record) -> Tuple[Optional[Record], Messages]: + def __call__(self, record: Record) -> tuple[Optional[Record], Messages]: """Take a record, returning a validated record (is successful) and a list of messages.""" with warnings.catch_warnings(record=True) as caught_warnings: messages: Messages = [] @@ -100,9 +100,9 @@ def __call__(self, record: Record) -> Tuple[Optional[Record], Messages]: return validated, messages - def handle_warnings(self, record, caught_warnings) -> List[FeedbackMessage]: + def handle_warnings(self, record, caught_warnings) -> list[FeedbackMessage]: """Handle warnings from the pydantic validation.""" - messages: List[FeedbackMessage] = [] + messages: list[FeedbackMessage] = [] for warning_message in caught_warnings: warning = warning_message.message diff --git a/src/dve/metadata_parser/domain_types.py b/src/dve/metadata_parser/domain_types.py index b0cd730..c944278 100644 --- a/src/dve/metadata_parser/domain_types.py +++ b/src/dve/metadata_parser/domain_types.py @@ -5,8 +5,9 @@ import itertools import re import warnings +from collections.abc import Iterator, Sequence from functools import lru_cache -from typing import ClassVar, Dict, Iterator, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import ClassVar, Optional, TypeVar, Union from pydantic import fields, types, validate_arguments from typing_extensions import Literal @@ -79,7 +80,7 @@ class NHSNumber(types.ConstrainedStr): """ - SENTINEL_VALUES: ClassVar[Dict[str, str]] = { + SENTINEL_VALUES: ClassVar[dict[str, str]] = { "0000000000": "returned by MPS to indicate no match", "1111111111": "common example value given for patient-facing forms", "9999999999": "returned by MPS to indicate multiple matches", @@ -91,7 +92,7 @@ class NHSNumber(types.ConstrainedStr): """ - _FACTORS: ClassVar[Tuple[int, ...]] = (10, 9, 8, 7, 6, 5, 4, 3, 2) + _FACTORS: ClassVar[tuple[int, ...]] = (10, 9, 8, 7, 6, 5, 4, 3, 2) """Weights for the NHS number digits in the checksum.""" warn_on_test_numbers = True @@ -286,7 +287,7 @@ def conformatteddate( le: Optional[dt.date] = None, # pylint: 
disable=invalid-name gt: Optional[dt.date] = None, # pylint: disable=invalid-name lt: Optional[dt.date] = None, # pylint: disable=invalid-name -) -> Type[ConFormattedDate]: +) -> type[ConFormattedDate]: """Return a formatted date class with a set date format and timezone treatment. @@ -396,7 +397,7 @@ def __get_validators__(cls) -> Iterator[classmethod]: def formatteddatetime( date_format: Optional[str] = None, timezone_treatment: Literal["forbid", "permit", "require"] = "permit", -) -> Type[FormattedDatetime]: +) -> type[FormattedDatetime]: """Return a formatted datetime class with a set date format and timezone treatment. @@ -466,7 +467,7 @@ def __get_validators__(cls) -> Iterator[classmethod]: @validate_arguments def reportingperiod( reporting_period_type: Literal["start", "end"], date_format: Optional[str] = "%Y-%m-%d" -) -> Type[ReportingPeriod]: +) -> type[ReportingPeriod]: """Return a check on whether a reporting period date is a valid date, and is the start/ end of the month supplied depending on reporting period type """ @@ -482,7 +483,7 @@ def reportingperiod( def alphanumeric( min_digits: types.NonNegativeInt = 1, max_digits: types.PositiveInt = 1, -) -> Type[_SimpleRegexValidator]: +) -> type[_SimpleRegexValidator]: """Return a regex-validated class which will ensure that passed numbers are alphanumeric. @@ -510,7 +511,7 @@ def alphanumeric( def identifier( min_digits: types.NonNegativeInt = 1, max_digits: types.PositiveInt = 1, -) -> Type[_SimpleRegexValidator]: +) -> type[_SimpleRegexValidator]: """ Return a regex-validated class which will ensure that passed strings are alphanumeric or in a fixed set of diff --git a/src/dve/metadata_parser/function_wrapper.py b/src/dve/metadata_parser/function_wrapper.py index 5d3d37b..eb69546 100644 --- a/src/dve/metadata_parser/function_wrapper.py +++ b/src/dve/metadata_parser/function_wrapper.py @@ -1,19 +1,20 @@ """Wrapping functions for wrapping generic functions""" import warnings -from typing import Any, Callable, Dict, Iterable, Optional, Type, Union +from collections.abc import Callable, Iterable +from typing import Any, Optional, Union import pydantic from dve.metadata_parser import exc PydanticCompatible = Callable[ - [Any, Dict[str, Any], pydantic.fields.ModelField, pydantic.BaseConfig], Any + [Any, dict[str, Any], pydantic.fields.ModelField, pydantic.BaseConfig], Any ] """Function Compatable with pydantic Args: value (Any): Value to be validated - values (Dict[str, Any]): dict of previously validated fields + values (dict[str, Any]): dict of previously validated fields field (pydantic.fields.ModelField): field object containing field name and type config (pydantic.BaseConfig): the config that determines things like aliases @@ -21,14 +22,14 @@ def error_handler( - error_type: Union[Type[Exception], Type[Warning]], + error_type: Union[type[Exception], type[Warning]], error_message: str, field: pydantic.fields.ModelField, ): """Determines whether to raise an error or warning based on error_type Args: - error_type (Union[Type[Exception], Type[Warning]]): type of error to raise + error_type (Union[type[Exception], type[Warning]]): type of error to raise error_message (str): message to apply field (pydantic.fields.ModelField): field that caused the error to be raised @@ -45,7 +46,7 @@ def error_handler( def pydantic_wrapper( - error_type: Union[Type[Exception], Type[Warning]], + error_type: Union[type[Exception], type[Warning]], error_message: str, *field_names: str, failure_function: Callable = lambda x: x is False, @@ -61,7 +62,7 @@ 
def pydantic_wrapper( takes a function that will result in the passed exception being raised Args: - error_type (Type[Exception]): The exception type to be raised if the failure_function + error_type (type[Exception]): The exception type to be raised if the failure_function evaluates to True error_message (str): Message to be passed to the above exception failure_function (Optional[Callable]): A callable that when it evaluates to True @@ -95,7 +96,7 @@ def wrapper( def inner( value: Any, - values: Dict[str, Any], + values: dict[str, Any], field: pydantic.fields.ModelField, # pylint: disable=unused-argument config: pydantic.BaseConfig, # pylint: disable=unused-argument ) -> Any: @@ -125,7 +126,7 @@ def inner( def create_validator( function: Callable, field: str, - error_type: Type[Exception], + error_type: type[Exception], error_message: str, fields: Optional[Iterable[str]] = None, return_result=True, @@ -137,7 +138,7 @@ def create_validator( function (Callable): function to wrap field (str): field validator is applier to fields (Iterable[str]): other fields to be included in validation (in order of arguments) - error_type (Type[Exception]): Error to be raised on failure + error_type (type[Exception]): Error to be raised on failure error_message (str): Message to be raised on failure kwargs: pydantic_wrapper_kwargs: diff --git a/src/dve/metadata_parser/model_generator.py b/src/dve/metadata_parser/model_generator.py index 34f88db..53a82d8 100644 --- a/src/dve/metadata_parser/model_generator.py +++ b/src/dve/metadata_parser/model_generator.py @@ -3,10 +3,11 @@ # pylint: disable=super-init-not-called import warnings from abc import ABCMeta, abstractmethod +from collections.abc import Mapping from copy import deepcopy # This _needs_ to be `typing.Mapping`, or pydantic complains. -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Optional, Union import pydantic as pyd from typing_extensions import Literal @@ -81,13 +82,13 @@ class ModelLoader(metaclass=ABCMeta): # pylint: disable=too-few-public-methods """An abstract model loader.""" @abstractmethod - def __init__(self, contract_contents: Dict[str, Any], type_map: Optional[dict] = None): + def __init__(self, contract_contents: dict[str, Any], type_map: Optional[dict] = None): raise NotImplementedError() @abstractmethod def generate_models( self, additional_validators: Optional[dict] = None - ) -> Dict[str, pyd.main.ModelMetaclass]: + ) -> dict[str, pyd.main.ModelMetaclass]: """Generates models from the instance schema. Args: @@ -95,7 +96,7 @@ def generate_models( those described in the schema. 
Defaults to None [DEPRECATED] Returns: - Dict[str, model]: dict of table names to pydantic models + dict[str, model]: dict of table names to pydantic models """ raise NotImplementedError() @@ -104,13 +105,13 @@ def generate_models( class JSONtoPyd(ModelLoader): # pylint: disable=too-few-public-methods """Generate pydantic model from a JSON schema.""" - def __init__(self, contract_contents: Dict[str, Any], type_map: Optional[dict] = None): + def __init__(self, contract_contents: dict[str, Any], type_map: Optional[dict] = None): self.contract_contents = contract_contents self.type_map = deepcopy(type_map or STR_TO_PY_MAPPING) def generate_models( self, additional_validators: Optional[dict] = None - ) -> Dict[str, pyd.main.ModelMetaclass]: + ) -> dict[str, pyd.main.ModelMetaclass]: """Generates pydantic models from a loaded json file""" if additional_validators: warnings.warn("Ignoring additional validator functions") diff --git a/src/dve/metadata_parser/models.py b/src/dve/metadata_parser/models.py index a861e55..18cdc68 100644 --- a/src/dve/metadata_parser/models.py +++ b/src/dve/metadata_parser/models.py @@ -4,7 +4,8 @@ import datetime as dt import warnings from collections import Counter -from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union +from collections.abc import Mapping, MutableMapping +from typing import Any, Optional, Union import pydantic as pyd from pydantic import BaseModel, Field, root_validator, validator @@ -30,7 +31,7 @@ """A pydantic-appropriate type.""" ValidatorName = str """The name of a validator.""" -Validators = Dict[ValidatorName, classmethod] +Validators = dict[ValidatorName, classmethod] """The validators for a class.""" @@ -57,9 +58,9 @@ class ValidationFunctionSpecification(BaseModel): # type: ignore """The type of error/warning to emit if the function fails.""" error_message: str = None # type: ignore """The message to associate with the error.""" - fields: List[str] = Field(default_factory=list) + fields: list[str] = Field(default_factory=list) """Fields to include in the validator.""" - kwargs_: Dict[str, Any] = Field(default_factory=dict, alias="kwargs") + kwargs_: dict[str, Any] = Field(default_factory=dict, alias="kwargs") """Keyword arguments for the validation function.""" @validator("name", allow_reuse=True) @@ -70,8 +71,8 @@ def validate_name(cls, value: str) -> str: return value @validator("error_message", allow_reuse=True) - def validate_error_message(cls, value: str, values: Dict[str, Any]) -> str: - """Set a default error message if one is not available.""" + def validate_error_message(cls, value: str, values: dict[str, Any]) -> str: + """set a default error message if one is not available.""" if value: return value name: str = values["name"] @@ -115,7 +116,7 @@ class FieldSpecification(BaseModel): a Python type. This is mututally exclusive with 'type' and 'model'. 
""" - constraints: Dict[str, Any] = Field(default_factory=dict) + constraints: dict[str, Any] = Field(default_factory=dict) """Keyword arguments to be used with 'callable'.""" is_array: bool = False """ @@ -124,11 +125,11 @@ class FieldSpecification(BaseModel): """ default: Any = None """A default value for the field, to be used if it is not provided.""" - functions: List[ValidationFunctionSpecification] = Field(default_factory=list) + functions: list[ValidationFunctionSpecification] = Field(default_factory=list) """Validation functions to be applied to the type.""" @root_validator(allow_reuse=True) - def ensure_one_type_spec_method(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def ensure_one_type_spec_method(cls, values: dict[str, Any]) -> dict[str, Any]: """Ensure that exactly one of 'type', 'model' and 'callable' was specified.""" has_type = bool(values.get("type_")) has_model = bool(values.get("model")) @@ -164,7 +165,7 @@ def ensure_one_type_spec_method(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values @validator("default", allow_reuse=True) - def validate_default(cls, value: Any, values: Dict[str, Any]) -> Any: + def validate_default(cls, value: Any, values: dict[str, Any]) -> Any: """Validate that 'default' is aligned with 'is_array'.""" if value is None: return value @@ -211,9 +212,9 @@ def get_type_and_validators( self, field_name: str, *type_mappings: Mapping[TypeName, FieldTypeOption], - schemas: Optional[Dict[EntityName, pyd.main.ModelMetaclass]] = None, + schemas: Optional[dict[EntityName, pyd.main.ModelMetaclass]] = None, is_mandatory: bool = False, - ) -> Tuple[PydanticType, Default, Validators]: + ) -> tuple[PydanticType, Default, Validators]: """Get the type, default value, and validators for the specification.""" default: Optional[Default] = self.default validators = self._get_validators(field_name) @@ -231,7 +232,7 @@ def get_type_and_validators( if nested_validators and self.is_array: # Need to work out how to hook into the validators and update - # them to take List[T] instead of T. Probably create validators + # them to take list[T] instead of T. Probably create validators # and wrap them later in `EntitySpecification` raise ValueError( f"{field_name!r}: Unable to create array of standard type with validators" @@ -262,29 +263,29 @@ def get_type_and_validators( default = default or (... if is_mandatory else None) if self.is_array: - python_type = List[python_type] # type: ignore + python_type = list[python_type] # type: ignore return python_type, default, validators class EntitySpecification(BaseModel): """Configuration options for an entity.""" - fields: Dict[FieldName, FieldSpecification] + fields: dict[FieldName, FieldSpecification] """ A mapping of field names to their Python types. These will either be strings representing Python types (if there are no argumements to the type), and field specification objects otherwise. 
""" - aliases: Dict[FieldName, FieldAlias] = Field(default_factory=dict) + aliases: dict[FieldName, FieldAlias] = Field(default_factory=dict) """A mapping of field name to allowed field alias.""" - mandatory_fields: List[FieldName] = Field(default_factory=list) + mandatory_fields: list[FieldName] = Field(default_factory=list) """An array of field names which should be considered mandatory.""" @validator("fields", pre=True, allow_reuse=True) def validate_fields( - cls, value: Dict[FieldName, Union[TypeName, FieldSpecification]] - ) -> Dict[FieldName, FieldSpecification]: + cls, value: dict[FieldName, Union[TypeName, FieldSpecification]] + ) -> dict[FieldName, FieldSpecification]: """Convert bare string fields to field specifications.""" for key in value: type_spec = value[key] @@ -295,8 +296,8 @@ def validate_fields( @validator("aliases", allow_reuse=True) def validate_aliases( - cls, value: Dict[FieldName, FieldAlias], values: Dict[str, Any] - ) -> Dict[FieldName, FieldAlias]: + cls, value: dict[FieldName, FieldAlias], values: dict[str, Any] + ) -> dict[FieldName, FieldAlias]: """Ensure that 'aliases' is aligned with 'fields'.""" # Check that aliases are not given more than once if not value: @@ -315,7 +316,7 @@ def validate_aliases( + f"more than once: {multiple_occurrences}" ) # And warn when unnecessary aliases were given. - field_names: Set[FieldName] = set(values["fields"].keys()) + field_names: set[FieldName] = set(values["fields"].keys()) missing_fields = set(value.keys()) - field_names if missing_fields: warnings.warn( @@ -328,13 +329,13 @@ def validate_aliases( @validator("mandatory_fields", allow_reuse=True) def validate_mandatory_fields( - cls, value: List[FieldName], values: Dict[str, Any] - ) -> List[FieldName]: + cls, value: list[FieldName], values: dict[str, Any] + ) -> list[FieldName]: """Ensure that 'mandatory_fields' is aligned with 'fields'.""" if not value: return value - field_names: Set[FieldName] = set(values["fields"].keys()) + field_names: set[FieldName] = set(values["fields"].keys()) missing_fields = set(value) - field_names if missing_fields: raise ValueError( @@ -348,7 +349,7 @@ def as_model( self, model_name: str, *type_mappings: Mapping[TypeName, FieldTypeOption], - schemas: Optional[Dict[EntityName, pyd.main.ModelMetaclass]] = None, + schemas: Optional[dict[EntityName, pyd.main.ModelMetaclass]] = None, ) -> pyd.main.ModelMetaclass: """Get the pydantic model from an entity definition.""" validators = {} @@ -383,9 +384,9 @@ class DatasetSpecification(BaseModel): """Configuration options for a dataset.""" cache_originals: bool = False - types: Dict[TypeName, FieldSpecification] = Field(default_factory=dict) + types: dict[TypeName, FieldSpecification] = Field(default_factory=dict) """Predefined types to be used within schema/dataset definitions.""" - schemas: Dict[EntityName, EntitySpecification] = Field(default_factory=dict) + schemas: dict[EntityName, EntitySpecification] = Field(default_factory=dict) """Predefined models to be used within dataset definitions.""" datasets: MutableMapping[EntityName, EntitySpecification] """Models which represent entities within the data.""" @@ -393,9 +394,9 @@ class DatasetSpecification(BaseModel): def load_models( self, *type_mappings: Mapping[TypeName, FieldTypeOption], - ) -> Dict[EntityName, pyd.main.ModelMetaclass]: + ) -> dict[EntityName, pyd.main.ModelMetaclass]: """Load the models from the dataset definition.""" - loaded_schemas: Dict[EntityName, pyd.main.ModelMetaclass] = {} + loaded_schemas: dict[EntityName, 
pyd.main.ModelMetaclass] = {} for model_name, specification in self.schemas.items(): loaded_schemas[model_name] = specification.as_model( model_name, self.types, *type_mappings, schemas=loaded_schemas diff --git a/src/dve/metadata_parser/utilities.py b/src/dve/metadata_parser/utilities.py index 756374e..0efa078 100644 --- a/src/dve/metadata_parser/utilities.py +++ b/src/dve/metadata_parser/utilities.py @@ -1,7 +1,8 @@ """Utility functions for the metadata parser.""" +from collections.abc import Mapping from types import ModuleType -from typing import TYPE_CHECKING, Any, Mapping, Union +from typing import TYPE_CHECKING, Any, Union from typing_extensions import Protocol @@ -14,8 +15,7 @@ class TypeCallable(Protocol): # pylint: disable=too-few-public-methods """A callable which returns a type.""" - def __call__(self, *args, **kwds: Any) -> type: - ... + def __call__(self, *args, **kwds: Any) -> type: ... FieldTypeOption = Union[type, TypeCallable, "FieldSpecification"] diff --git a/src/dve/parser/exceptions.py b/src/dve/parser/exceptions.py index 0ef345d..3e436c1 100644 --- a/src/dve/parser/exceptions.py +++ b/src/dve/parser/exceptions.py @@ -1,7 +1,6 @@ """Exceptions used in the file parser.""" from functools import partial -from typing import Dict class UnsupportedSchemeError(ValueError): @@ -30,7 +29,7 @@ class FieldCountMismatch(ValueError): class SDCSTemplateValidationFailure(ValueError): """An error to indicate a template validation failure.""" - def __init__(self, *args: object, errors: Dict[str, str]) -> None: + def __init__(self, *args: object, errors: dict[str, str]) -> None: super().__init__(*args) self.errors = errors """ diff --git a/src/dve/parser/file_handling/helpers.py b/src/dve/parser/file_handling/helpers.py index d797e93..9079e53 100644 --- a/src/dve/parser/file_handling/helpers.py +++ b/src/dve/parser/file_handling/helpers.py @@ -1,7 +1,6 @@ """Helpers for file handling.""" from io import TextIOWrapper -from typing import Tuple from urllib.parse import urlparse from dve.parser.type_hints import URI, Hostname, Scheme, URIPath @@ -22,7 +21,7 @@ def __exit__(self, *_): # pragma: no cover return -def parse_uri(uri: URI) -> Tuple[Scheme, Hostname, URIPath]: +def parse_uri(uri: URI) -> tuple[Scheme, Hostname, URIPath]: """Parse a URI, yielding the scheme, hostname and URI path.""" parse_result = urlparse(uri) scheme = parse_result.scheme.lower() or "file" # Assume missing scheme is file URI. 
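A hedged sketch of what the `parse_uri` helper above yields, based on the visible `urlparse` handling (lower-cased scheme, missing scheme treated as a file URI); the example URIs are made up and the exact hostname/path fields returned are an assumption:

from urllib.parse import urlparse

for uri in ("S3://my-bucket/inbound/data.xml", "/tmp/data.xml"):
    parsed = urlparse(uri)
    scheme = parsed.scheme.lower() or "file"  # assume a bare path is a local file
    print((scheme, parsed.netloc, parsed.path))

# ('s3', 'my-bucket', '/inbound/data.xml')
# ('file', '', '/tmp/data.xml')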
diff --git a/src/dve/parser/file_handling/implementations/base.py b/src/dve/parser/file_handling/implementations/base.py index 0a2fd4c..5774d20 100644 --- a/src/dve/parser/file_handling/implementations/base.py +++ b/src/dve/parser/file_handling/implementations/base.py @@ -1,7 +1,8 @@ """An abstract implementation of the filesystem layer.""" from abc import ABCMeta, abstractmethod -from typing import IO, Iterable, Iterator, Set, Tuple +from collections.abc import Iterable, Iterator +from typing import IO from dve.parser.exceptions import FileAccessError from dve.parser.type_hints import URI, NodeType, Scheme @@ -12,7 +13,7 @@ class BaseFilesystemImplementation(metaclass=ABCMeta): @property @abstractmethod - def SUPPORTED_SCHEMES(self) -> Set[Scheme]: # pylint: disable=invalid-name + def SUPPORTED_SCHEMES(self) -> set[Scheme]: # pylint: disable=invalid-name """Schemes supported by the filesystem implementation.""" @abstractmethod @@ -35,7 +36,7 @@ def get_resource_exists(self, resource: URI) -> bool: """ @abstractmethod - def iter_prefix(self, prefix: URI, recursive: bool = False) -> Iterator[Tuple[URI, NodeType]]: + def iter_prefix(self, prefix: URI, recursive: bool = False) -> Iterator[tuple[URI, NodeType]]: """List the contents of a given prefix. Directory URIs should be returned with a trailing /. diff --git a/src/dve/parser/file_handling/implementations/file.py b/src/dve/parser/file_handling/implementations/file.py index 8aeba94..eeed3de 100644 --- a/src/dve/parser/file_handling/implementations/file.py +++ b/src/dve/parser/file_handling/implementations/file.py @@ -2,8 +2,9 @@ import platform import shutil +from collections.abc import Callable, Iterator from pathlib import Path -from typing import IO, Any, Callable, Dict, Iterator, NoReturn, Optional, Set, Tuple +from typing import IO, Any, NoReturn, Optional from urllib.parse import unquote from typing_extensions import Literal @@ -13,7 +14,7 @@ from dve.parser.file_handling.implementations.base import BaseFilesystemImplementation from dve.parser.type_hints import URI, NodeType, PathStr, Scheme -FILE_URI_SCHEMES: Set[Scheme] = {"file"} +FILE_URI_SCHEMES: set[Scheme] = {"file"} """A set of all allowed file URI schemes.""" @@ -57,7 +58,7 @@ def _path_to_uri(self, path: Path) -> URI: @staticmethod def _handle_error( - err: Exception, resource: URI, mode: str, extra_args: Optional[Dict[str, Any]] = None + err: Exception, resource: URI, mode: str, extra_args: Optional[dict[str, Any]] = None ) -> NoReturn: """Handle a local file opening error.""" message = f"Unable to access file at {resource!r} ({mode!r} mode, got {err!r})" @@ -113,7 +114,7 @@ def _iter_prefix_path(self, prefix_path: Path, recursive: bool = False) -> Itera if child.is_dir() and recursive: yield from self._iter_prefix_path(child, recursive) - def iter_prefix(self, prefix: URI, recursive: bool = False) -> Iterator[Tuple[URI, NodeType]]: + def iter_prefix(self, prefix: URI, recursive: bool = False) -> Iterator[tuple[URI, NodeType]]: """Iterates over the given prefix yielding any resources or directories""" try: for child in self._iter_prefix_path(self._uri_to_path(prefix), recursive): diff --git a/src/dve/parser/file_handling/implementations/s3.py b/src/dve/parser/file_handling/implementations/s3.py index 44fd70e..f565527 100644 --- a/src/dve/parser/file_handling/implementations/s3.py +++ b/src/dve/parser/file_handling/implementations/s3.py @@ -3,10 +3,11 @@ # pylint: disable=broad-except import os from collections import deque +from collections.abc import Iterator from 
contextlib import contextmanager from math import ceil from threading import Lock -from typing import IO, TYPE_CHECKING, Any, Deque, Dict, Iterator, List, NoReturn, Optional, Tuple +from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional from urllib.parse import quote, unquote import boto3 @@ -46,7 +47,7 @@ class SessionPool: def __init__(self): self._lock: Lock = Lock() - self._pool: List[boto3.Session] = [] + self._pool: list[boto3.Session] = [] def pop(self) -> boto3.Session: """Take a session from the pool. If no sessions exist, @@ -86,7 +87,7 @@ def _handle_boto_error( err: Exception, resource: URI, access_type: str, - extra_args: Optional[Dict[str, Any]] = None, + extra_args: Optional[dict[str, Any]] = None, ) -> NoReturn: """Handle an error from boto3.""" if not isinstance(err, ClientError): @@ -106,7 +107,7 @@ def _handle_boto_error( raise FileAccessError(message) from err - def _parse_s3_uri(self, uri: URI) -> Tuple[Scheme, Bucket, Key]: + def _parse_s3_uri(self, uri: URI) -> tuple[Scheme, Bucket, Key]: """Parse an S3 URI to a bucket and key""" scheme, bucket, key = parse_uri(uri) if scheme not in self.SUPPORTED_SCHEMES: # pragma: no cover @@ -184,21 +185,21 @@ def get_resource_exists(self, resource: URI) -> bool: return True - def iter_prefix(self, prefix: URI, recursive: bool = False) -> Iterator[Tuple[URI, NodeType]]: + def iter_prefix(self, prefix: URI, recursive: bool = False) -> Iterator[tuple[URI, NodeType]]: """Iterates over the given prefix""" with get_session(_session_pool) as session: try: scheme, bucket, key = self._parse_s3_uri(prefix) client = session.client("s3", endpoint_url=ENDPOINT_URL) - prefix_keys: Deque[Key] = deque([key]) + prefix_keys: deque[Key] = deque([key]) while prefix_keys: next_key = prefix_keys.popleft().rstrip("/") + "/" - paginate_args = dict( - Bucket=bucket, - Delimiter="/", - ) + paginate_args = { + "Bucket": bucket, + "Delimiter": "/", + } if next_key != "/": paginate_args["Prefix"] = next_key @@ -243,7 +244,7 @@ def remove_resource(self, resource: URI): @staticmethod def _calculate_file_chunks( file_size_bytes: int, chunk_size: int = MULTIPART_CHUNK_SIZE - ) -> Iterator[Tuple[PartNumber, ByteRange]]: + ) -> Iterator[tuple[PartNumber, ByteRange]]: """Calculate the part numbers and byte ranges for a multipart upload's chunks.""" if chunk_size < FIVE_MEBIBYTES: raise ValueError("Chunk size must be at least five mebibytes") @@ -304,7 +305,7 @@ def copy_resource( ) upload_id = cmu_response["UploadId"] try: - parts: List["CompletedPartTypeDef"] = [] + parts: list["CompletedPartTypeDef"] = [] if source_size <= FIVE_MEBIBYTES: # Can't use UploadRange on <= 5MiB file upc_response = s3_client.upload_part_copy( @@ -341,13 +342,13 @@ def copy_resource( Bucket=target_bucket, Key=target_key, UploadId=upload_id ) raise - else: - s3_client.complete_multipart_upload( - Bucket=target_bucket, - Key=target_key, - MultipartUpload={"Parts": parts}, - UploadId=upload_id, - ) + + s3_client.complete_multipart_upload( + Bucket=target_bucket, + Key=target_key, + MultipartUpload={"Parts": parts}, + UploadId=upload_id, + ) return except Exception as err: # pragma: no cover diff --git a/src/dve/parser/file_handling/log_handler.py b/src/dve/parser/file_handling/log_handler.py index cc0e9bc..1b40605 100644 --- a/src/dve/parser/file_handling/log_handler.py +++ b/src/dve/parser/file_handling/log_handler.py @@ -6,9 +6,10 @@ import tempfile import warnings import weakref +from collections.abc import Iterator from contextlib import contextmanager from threading import 
Lock, RLock -from typing import IO, ClassVar, Dict, Iterator, Optional, Type, Union +from typing import IO, ClassVar, Optional, Union from dve.parser.exceptions import LogDataLossWarning from dve.parser.file_handling.service import open_stream @@ -28,7 +29,7 @@ class _ResourceHandlerManager: """ def __init__(self): - self._extant_handlers: Dict[URI, "ResourceHandler"] = weakref.WeakValueDictionary() + self._extant_handlers: dict[URI, "ResourceHandler"] = weakref.WeakValueDictionary() self._lock = RLock() def get_handler( @@ -36,7 +37,7 @@ def get_handler( level: Union[str, int] = logging.NOTSET, *, resource: URI, - type_: Type["ResourceHandler"], + type_: type["ResourceHandler"], ) -> "ResourceHandler": """Get a handler for a given URI.""" with self._lock: @@ -168,7 +169,7 @@ def close(self): @classmethod def _cleanup( - cls: Type["ResourceHandler"], *, stream: IO[str], resource: URI, lock: Optional[Lock] = None + cls: type["ResourceHandler"], *, stream: IO[str], resource: URI, lock: Optional[Lock] = None ): # pragma: no cover """Write the logs to the remote location. This needs to be done in a separate classmethod so that it can run as a 'real' finalizer. diff --git a/src/dve/parser/file_handling/service.py b/src/dve/parser/file_handling/service.py index 0b659a5..b638eb6 100644 --- a/src/dve/parser/file_handling/service.py +++ b/src/dve/parser/file_handling/service.py @@ -10,10 +10,11 @@ import shutil import uuid import warnings -from contextlib import contextmanager +from collections.abc import Iterator +from contextlib import AbstractContextManager, contextmanager from pathlib import Path from tempfile import TemporaryDirectory -from typing import IO, ContextManager, Iterator, List, Optional, Set, Tuple, overload +from typing import IO, Optional, overload from urllib.parse import unquote, urlparse from typing_extensions import Literal @@ -40,7 +41,7 @@ URIPath, ) -_IMPLEMENTATIONS: List[BaseFilesystemImplementation] = [ +_IMPLEMENTATIONS: list[BaseFilesystemImplementation] = [ S3FilesystemImplementation(), LocalFilesystemImplementation(), ] @@ -51,18 +52,18 @@ except ValueError: pass -_SUPPORTED_SCHEMES: Set[Scheme] = set().union( +_SUPPORTED_SCHEMES: set[Scheme] = set().union( *[impl.SUPPORTED_SCHEMES for impl in _IMPLEMENTATIONS] ) """Supported URI schemes.""" -ALL_FILE_MODES: Set[FileOpenMode] = {"r", "a", "w", "ab", "rb", "wb", "ba", "br", "bw"} +ALL_FILE_MODES: set[FileOpenMode] = {"r", "a", "w", "ab", "rb", "wb", "ba", "br", "bw"} """All supported file modes.""" -TEXT_MODES: Set[TextFileOpenMode] = {"r", "a", "w"} +TEXT_MODES: set[TextFileOpenMode] = {"r", "a", "w"} """Text file modes.""" -APPEND_MODES: Set[FileOpenMode] = {"a", "ab", "ba"} +APPEND_MODES: set[FileOpenMode] = {"a", "ab", "ba"} """Modes that append to the resource.""" -READ_ONLY_MODES: Set[FileOpenMode] = {"r", "br", "rb"} +READ_ONLY_MODES: set[FileOpenMode] = {"r", "br", "rb"} """Modes that only read the file and do not write to it.""" ONE_MEBIBYTE = 1024**3 @@ -109,7 +110,7 @@ def open_stream( mode: TextFileOpenMode = "r", encoding: Optional[str] = None, ensure_seekable: bool = False, -) -> ContextManager[IO[str]]: +) -> AbstractContextManager[IO[str]]: pass # pragma: no cover @@ -119,7 +120,7 @@ def open_stream( mode: BinaryFileOpenMode, encoding: None = None, ensure_seekable: bool = False, -) -> ContextManager[IO[bytes]]: +) -> AbstractContextManager[IO[bytes]]: pass # pragma: no cover @@ -219,7 +220,7 @@ def remove_resource(resource: URI): return _get_implementation(resource).remove_resource(resource) -def 
iter_prefix(prefix: URI, recursive: bool = False) -> Iterator[Tuple[URI, NodeType]]: +def iter_prefix(prefix: URI, recursive: bool = False) -> Iterator[tuple[URI, NodeType]]: """List the contents of a given prefix.""" return _get_implementation(prefix).iter_prefix(prefix, recursive) @@ -300,8 +301,8 @@ def _transfer_prefix( if not target_prefix.endswith("/"): target_prefix += "/" - source_uris: List[URI] = [] - target_uris: List[URI] = [] + source_uris: list[URI] = [] + target_uris: list[URI] = [] source_impl = _get_implementation(source_prefix) target_impl = _get_implementation(target_prefix) diff --git a/src/dve/parser/file_handling/utilities.py b/src/dve/parser/file_handling/utilities.py index 7b96715..11408d3 100644 --- a/src/dve/parser/file_handling/utilities.py +++ b/src/dve/parser/file_handling/utilities.py @@ -3,7 +3,7 @@ import tempfile from pathlib import Path from types import TracebackType -from typing import Optional, Type +from typing import Optional from dve.parser.exceptions import UnsupportedSchemeError from dve.parser.file_handling.service import is_supported, remove_prefix @@ -44,7 +44,7 @@ def __enter__(self) -> URI: def __exit__( self, - exc_type: Optional[Type[Exception]], + exc_type: Optional[type[Exception]], exc_value: Optional[Exception], traceback: Optional[TracebackType], ): diff --git a/src/dve/parser/type_hints.py b/src/dve/parser/type_hints.py index 3caa862..c8b2fb7 100644 --- a/src/dve/parser/type_hints.py +++ b/src/dve/parser/type_hints.py @@ -1,7 +1,7 @@ """Type hints for the parser.""" from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import Literal @@ -37,7 +37,7 @@ ReaderName = str """A parser name. This must be importable from `parser.readers`""" -ReaderArgs = Optional[Dict[str, Any]] +ReaderArgs = Optional[dict[str, Any]] """Keyword arguments to be passed to the parser's constructor.""" FieldName = str """The name of a field within the dataset.""" diff --git a/src/dve/parser/utilities.py b/src/dve/parser/utilities.py index e605e5c..95115bb 100644 --- a/src/dve/parser/utilities.py +++ b/src/dve/parser/utilities.py @@ -6,19 +6,20 @@ """ from collections import defaultdict +from collections.abc import Iterable, Iterator from itertools import tee -from typing import Dict, Iterable, Iterator, List, Tuple, TypeVar, Union, overload +from typing import TypeVar, Union, overload from pyspark.sql.types import ArrayType, StringType, StructField, StructType T = TypeVar("T") -TemplateElement = Union[None, List["TemplateElement"], Dict[str, "TemplateElement"]] # type: ignore +TemplateElement = Union[None, list["TemplateElement"], dict[str, "TemplateElement"]] # type: ignore """The base types used in the template row.""" -TemplateRow = Dict[str, "TemplateElement"] # type: ignore +TemplateRow = dict[str, "TemplateElement"] # type: ignore """The type of a template row.""" -def peek(iterable: Iterable[T]) -> Tuple[T, Iterator[T]]: +def peek(iterable: Iterable[T]) -> tuple[T, Iterator[T]]: """Peek the first item from an iterable, returning the first item and an iterator representing the state of the iterable _before_ the first item was taken. @@ -29,15 +30,13 @@ def peek(iterable: Iterable[T]) -> Tuple[T, Iterator[T]]: @overload -def template_row_to_spark_schema(template_element: TemplateRow) -> StructType: - ... +def template_row_to_spark_schema(template_element: TemplateRow) -> StructType: ... 
@overload def template_row_to_spark_schema( template_element: TemplateElement, -) -> Union[ArrayType, StringType, StructType]: - ... +) -> Union[ArrayType, StringType, StructType]: ... def template_row_to_spark_schema(template_element): @@ -84,7 +83,7 @@ def parse_template_row(field_names: Iterable[str]) -> TemplateRow: """ array_levels = set() - sub_levels_by_level: Dict[str, List[str]] = defaultdict(list) + sub_levels_by_level: dict[str, list[str]] = defaultdict(list) for name in field_names: is_array = name.startswith("[") diff --git a/src/dve/pipeline/duckdb_pipeline.py b/src/dve/pipeline/duckdb_pipeline.py index 1287d6b..4e7707b 100644 --- a/src/dve/pipeline/duckdb_pipeline.py +++ b/src/dve/pipeline/duckdb_pipeline.py @@ -1,6 +1,6 @@ """DuckDB implementation for `Pipeline` object.""" -from typing import Optional, Type +from typing import Optional from duckdb import DuckDBPyConnection, DuckDBPyRelation @@ -29,7 +29,7 @@ def __init__( rules_path: Optional[URI], processed_files_path: Optional[URI], submitted_files_path: Optional[URI], - reference_data_loader: Optional[Type[BaseRefDataLoader]] = None, + reference_data_loader: Optional[type[BaseRefDataLoader]] = None, ): self._connection = connection super().__init__( diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py index 6f3b1bb..69cb141 100644 --- a/src/dve/pipeline/pipeline.py +++ b/src/dve/pipeline/pipeline.py @@ -3,10 +3,11 @@ import json import re from collections import defaultdict +from collections.abc import Generator, Iterable, Iterator from concurrent.futures import Executor, Future, ThreadPoolExecutor from functools import lru_cache from threading import Lock -from typing import Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Type, Union +from typing import Optional, Union from uuid import uuid4 import polars as pl @@ -29,7 +30,7 @@ from dve.pipeline.utils import SubmissionStatus, deadletter_file, load_config, load_reader from dve.reporting.error_report import ERROR_SCHEMA, calculate_aggregates, conditional_cast -PERMISSIBLE_EXCEPTIONS: Tuple[Type[Exception]] = ( +PERMISSIBLE_EXCEPTIONS: tuple[type[Exception]] = ( FileNotFoundError, # type: ignore FileNotFoundError, ) @@ -49,7 +50,7 @@ def __init__( rules_path: Optional[URI], processed_files_path: Optional[URI], submitted_files_path: Optional[URI], - reference_data_loader: Optional[Type[BaseRefDataLoader]] = None, + reference_data_loader: Optional[type[BaseRefDataLoader]] = None, ): self._submitted_files_path = submitted_files_path self._processed_files_path = processed_files_path @@ -99,7 +100,7 @@ def _dump_errors( submission_id: str, step_name: str, messages: Messages, - key_fields: Optional[Dict[str, List[str]]] = None, + key_fields: Optional[dict[str, list[str]]] = None, ): if not self.processed_files_path: raise AttributeError("processed files path not passed") @@ -113,7 +114,7 @@ def _dump_errors( processed = [] for message in messages: - primary_keys: List[str] = key_fields.get(message.entity if message.entity else "", []) + primary_keys: list[str] = key_fields.get(message.entity if message.entity else "", []) error = message.to_dict( key_field=primary_keys, value_separator=" -- ", @@ -136,11 +137,11 @@ def _move_submission_to_working_location( submission_id: str, submitted_file_uri: URI, submission_info_uri: URI, - ) -> Tuple[URI, URI]: + ) -> tuple[URI, URI]: if not self.processed_files_path: raise AttributeError("Path for processed files not supplied.") - paths: List[URI] = [] + paths: list[URI] = [] for path in 
(submitted_file_uri, submission_info_uri): source = fh.resolve_location(path) dest = fh.joinuri(self.processed_files_path, submission_id, fh.get_file_name(path)) @@ -149,7 +150,7 @@ def _move_submission_to_working_location( return tuple(paths) # type: ignore - def _get_submission_files_for_run(self) -> Generator[Tuple[FileURI, InfoURI], None, None]: + def _get_submission_files_for_run(self) -> Generator[tuple[FileURI, InfoURI], None, None]: """Yields submission files from the submitted_files path""" # TODO - I think the metadata generation needs to be redesigned or at least generated # TODO - if we continue with this approach. This comments is based on the fact that @@ -239,18 +240,18 @@ def audit_received_file( return sub_info def audit_received_file_step( - self, pool: ThreadPoolExecutor, submitted_files: Iterable[Tuple[FileURI, InfoURI]] - ) -> Tuple[List[SubmissionInfo], List[SubmissionInfo]]: + self, pool: ThreadPoolExecutor, submitted_files: Iterable[tuple[FileURI, InfoURI]] + ) -> tuple[list[SubmissionInfo], list[SubmissionInfo]]: """Set files as being received and mark them for file transformation""" - audit_received_futures: List[Tuple[str, FileURI, Future]] = [] + audit_received_futures: list[tuple[str, FileURI, Future]] = [] for submission_file in submitted_files: data_uri, metadata_uri = submission_file submission_id = uuid4().hex future = pool.submit(self.audit_received_file, submission_id, data_uri, metadata_uri) audit_received_futures.append((submission_id, data_uri, future)) - success: List[SubmissionInfo] = [] - failed: List[SubmissionInfo] = [] + success: list[SubmissionInfo] = [] + failed: list[SubmissionInfo] = [] for submission_id, submission_file_uri, future in audit_received_futures: try: submission_info = future.result() @@ -295,7 +296,7 @@ def audit_received_file_step( def file_transformation( self, submission_info: SubmissionInfo - ) -> Union[SubmissionInfo, Dict[str, str]]: + ) -> Union[SubmissionInfo, dict[str, str]]: """Transform a file from its original format into a 'stringified' parquet file""" if not self.processed_files_path: raise AttributeError("processed files path not provided") @@ -322,23 +323,23 @@ def file_transformation( return submission_info.dict() def file_transformation_step( - self, pool: Executor, submissions_to_process: List[SubmissionInfo] - ) -> Tuple[List[SubmissionInfo], List[SubmissionInfo]]: + self, pool: Executor, submissions_to_process: list[SubmissionInfo] + ) -> tuple[list[SubmissionInfo], list[SubmissionInfo]]: """Step to transform files from their original format into parquet files""" - file_transform_futures: List[Tuple[SubmissionInfo, Future]] = [] + file_transform_futures: list[tuple[SubmissionInfo, Future]] = [] for submission_info in submissions_to_process: # add audit entry future = pool.submit(self.file_transformation, submission_info) file_transform_futures.append((submission_info, future)) - success: List[SubmissionInfo] = [] - failed: List[SubmissionInfo] = [] - failed_processing: List[SubmissionInfo] = [] + success: list[SubmissionInfo] = [] + failed: list[SubmissionInfo] = [] + failed_processing: list[SubmissionInfo] = [] for sub_info, future in file_transform_futures: try: - # sub_info passed here either return SubInfo or Dict. If SubInfo, not actually + # sub_info passed here either return SubInfo or dict. If SubInfo, not actually # modified in anyway during this step. 
result = future.result() except AttributeError as exc: @@ -386,7 +387,7 @@ def file_transformation_step( return success, failed - def apply_data_contract(self, submission_info: SubmissionInfo) -> Tuple[SubmissionInfo, Failed]: + def apply_data_contract(self, submission_info: SubmissionInfo) -> tuple[SubmissionInfo, Failed]: """Method for applying the data contract given a submission_info""" if not self.processed_files_path: @@ -425,12 +426,12 @@ def apply_data_contract(self, submission_info: SubmissionInfo) -> Tuple[Submissi return submission_info, failed def data_contract_step( - self, pool: Executor, file_transform_results: List[SubmissionInfo] - ) -> Tuple[List[Tuple[SubmissionInfo, Failed]], List[SubmissionInfo]]: + self, pool: Executor, file_transform_results: list[SubmissionInfo] + ) -> tuple[list[tuple[SubmissionInfo, Failed]], list[SubmissionInfo]]: """Step to validate the types of an untyped (stringly typed) parquet file""" - processed_files: List[Tuple[SubmissionInfo, Failed]] = [] - failed_processing: List[SubmissionInfo] = [] - dc_futures: List[Tuple[SubmissionInfo, Future]] = [] + processed_files: list[tuple[SubmissionInfo, Failed]] = [] + failed_processing: list[SubmissionInfo] = [] + dc_futures: list[tuple[SubmissionInfo, Future]] = [] for info in file_transform_results: dc_futures.append((info, pool.submit(self.apply_data_contract, info))) @@ -531,9 +532,7 @@ def apply_business_rules(self, submission_info: SubmissionInfo, failed: bool): entity_name, ), ) - entity_manager.entities[ - entity_name - ] = self.step_implementations.read_parquet( # type: ignore + entity_manager.entities[entity_name] = self.step_implementations.read_parquet( # type: ignore projected ) @@ -553,14 +552,14 @@ def apply_business_rules(self, submission_info: SubmissionInfo, failed: bool): def business_rule_step( self, pool: Executor, - files: List[Tuple[SubmissionInfo, Failed]], - ) -> Tuple[ - List[Tuple[SubmissionInfo, SubmissionStatus]], - List[Tuple[SubmissionInfo, SubmissionStatus]], - List[SubmissionInfo], + files: list[tuple[SubmissionInfo, Failed]], + ) -> tuple[ + list[tuple[SubmissionInfo, SubmissionStatus]], + list[tuple[SubmissionInfo, SubmissionStatus]], + list[SubmissionInfo], ]: """Step to apply business rules (Step impl) to a typed parquet file""" - future_files: List[Tuple[SubmissionInfo, Future]] = [] + future_files: list[tuple[SubmissionInfo, Future]] = [] for submission_info, submission_failed in files: future_files.append( @@ -570,9 +569,9 @@ def business_rule_step( ) ) - failed_processing: List[SubmissionInfo] = [] - unsucessful_files: List[Tuple[SubmissionInfo, SubmissionStatus]] = [] - successful_files: List[Tuple[SubmissionInfo, SubmissionStatus]] = [] + failed_processing: list[SubmissionInfo] = [] + unsucessful_files: list[tuple[SubmissionInfo, SubmissionStatus]] = [] + successful_files: list[tuple[SubmissionInfo, SubmissionStatus]] = [] for sub_info, future in future_files: status: SubmissionStatus @@ -637,9 +636,10 @@ def _get_error_dataframes(self, submission_id: str): df = pl.DataFrame(errors, schema={key: pl.Utf8() for key in errors[0]}) # type: ignore df = df.with_columns( - error_type=pl.when(pl.col("Status") == "error") # type: ignore - .then("Submission Failure") - .otherwise("Warning") + pl.when(pl.col("Status") == pl.lit("error")) # type: ignore + .then(pl.lit("Submission Failure")) # type: ignore + .otherwise(pl.lit("Warning")) # type: ignore + .alias("error_type") ) df = df.select( pl.col("Entity").alias("Table"), # type: ignore @@ -675,7 +675,7 @@ def 
error_report(self, submission_info: SubmissionInfo, status: SubmissionStatus else: err_types = { rw.get("Type"): rw.get("Count") - for rw in aggregates.groupby(pl.col("Type")) # type: ignore + for rw in aggregates.group_by(pl.col("Type")) # type: ignore .agg(pl.col("Count").sum()) # type: ignore .iter_rows(named=True) } @@ -714,19 +714,19 @@ def error_report(self, submission_info: SubmissionInfo, status: SubmissionStatus def error_report_step( self, pool: Executor, - processed: Iterable[Tuple[SubmissionInfo, SubmissionStatus]] = tuple(), + processed: Iterable[tuple[SubmissionInfo, SubmissionStatus]] = tuple(), failed_file_transformation: Iterable[SubmissionInfo] = tuple(), - ) -> List[ - Tuple[SubmissionInfo, SubmissionStatus, Union[None, SubmissionStatisticsRecord], URI] + ) -> list[ + tuple[SubmissionInfo, SubmissionStatus, Union[None, SubmissionStatisticsRecord], URI] ]: """Step to produce error reports takes processed files and files that failed file transformation """ - futures: List[Tuple[SubmissionInfo, Future]] = [] - reports: List[ - Tuple[SubmissionInfo, SubmissionStatus, Union[None, SubmissionStatisticsRecord], URI] + futures: list[tuple[SubmissionInfo, Future]] = [] + reports: list[ + tuple[SubmissionInfo, SubmissionStatus, Union[None, SubmissionStatisticsRecord], URI] ] = [] - failed_processing: List[SubmissionInfo] = [] + failed_processing: list[SubmissionInfo] = [] for info, status in processed: futures.append((info, pool.submit(self.error_report, info, status))) @@ -775,7 +775,7 @@ def error_report_step( def cluster_pipeline_run( self, max_workers: int = 7 - ) -> Iterator[List[Tuple[SubmissionInfo, SubmissionStatus, URI]]]: + ) -> Iterator[list[tuple[SubmissionInfo, SubmissionStatus, URI]]]: """Method for running the full DVE pipeline from start to finish.""" submission_files = self._get_submission_files_for_run() diff --git a/src/dve/pipeline/spark_pipeline.py b/src/dve/pipeline/spark_pipeline.py index 60b7b18..d31a2ee 100644 --- a/src/dve/pipeline/spark_pipeline.py +++ b/src/dve/pipeline/spark_pipeline.py @@ -1,7 +1,7 @@ """Spark implementation for `Pipeline` object.""" from concurrent.futures import Executor -from typing import List, Optional, Tuple, Type +from typing import Optional from pyspark.sql import DataFrame, SparkSession @@ -30,7 +30,7 @@ def __init__( rules_path: Optional[URI], processed_files_path: Optional[URI], submitted_files_path: Optional[URI], - reference_data_loader: Optional[Type[BaseRefDataLoader]] = None, + reference_data_loader: Optional[type[BaseRefDataLoader]] = None, spark: Optional[SparkSession] = None, ): self._spark = spark if spark else SparkSession.builder.getOrCreate() @@ -56,7 +56,7 @@ def write_file_to_parquet( # type: ignore def business_rule_step( self, pool: Executor, - files: List[Tuple[SubmissionInfo, Failed]], + files: list[tuple[SubmissionInfo, Failed]], ): successful_files, unsucessful_files, failed_processing = super().business_rule_step( pool, files diff --git a/src/dve/pipeline/utils.py b/src/dve/pipeline/utils.py index f4b6620..4fa9a02 100644 --- a/src/dve/pipeline/utils.py +++ b/src/dve/pipeline/utils.py @@ -1,7 +1,8 @@ """Utilities to be used with services to abstract away some of the config loading and threading""" + import json from threading import Lock -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from pydantic.main import ModelMetaclass from pyspark.sql import SparkSession @@ -14,15 +15,15 @@ from dve.core_engine.type_hints import URI, SubmissionResult from 
dve.metadata_parser.model_generator import JSONtoPyd -Dataset = Dict[SchemaName, _ModelConfig] -_configs: Dict[str, Tuple[Dict[str, ModelMetaclass], V1EngineConfig, Dataset]] = {} +Dataset = dict[SchemaName, _ModelConfig] +_configs: dict[str, tuple[dict[str, ModelMetaclass], V1EngineConfig, Dataset]] = {} locks = Lock() def load_config( dataset_id: str, file_uri: URI, -) -> Tuple[Dict[SchemaName, ModelMetaclass], V1EngineConfig, Dict[SchemaName, _ModelConfig]]: +) -> tuple[dict[SchemaName, ModelMetaclass], V1EngineConfig, dict[SchemaName, _ModelConfig]]: """Loads the configuration for a given dataset""" if dataset_id in _configs: return _configs[dataset_id] diff --git a/src/dve/pipeline/xml_linting.py b/src/dve/pipeline/xml_linting.py index 097d8da..fbf53a6 100644 --- a/src/dve/pipeline/xml_linting.py +++ b/src/dve/pipeline/xml_linting.py @@ -5,10 +5,11 @@ import shutil import sys import tempfile +from collections.abc import Iterable, Sequence from contextlib import ExitStack from pathlib import Path from subprocess import PIPE, STDOUT, Popen -from typing import Dict, Iterable, List, Optional, Sequence, Tuple +from typing import Optional from uuid import uuid4 from typing_extensions import Literal @@ -32,9 +33,9 @@ """The size of 5 binary megabytes, in bytes.""" # Patterns/strings for xmllint message sanitisation. -IGNORED_PATTERNS: List[re.Pattern] = [re.compile(r"^Unimplemented block at")] +IGNORED_PATTERNS: list[re.Pattern] = [re.compile(r"^Unimplemented block at")] """Regex patterns for messages that should result in their omission.""" -ERRONEOUS_PATTERNS: List[Tuple[re.Pattern, Replacement]] = [ +ERRONEOUS_PATTERNS: list[tuple[re.Pattern, Replacement]] = [ ( re.compile(r"^XSD schema .+/(?P.+) failed to compile$"), r"Missing required component of XSD schema '\g'", @@ -86,16 +87,16 @@ Pattern to match unexpected fields rather than out of order fields. Captures incorrect field """ -REMOVED_PATTERNS: List[re.Pattern] = [ +REMOVED_PATTERNS: list[re.Pattern] = [ re.compile(r"\{.+?\}"), re.compile(r"[\. ]+$"), ] """Regex patterns to remove from the xmllint output.""" -REPLACED_PATTERNS: List[Tuple[re.Pattern, Replacement]] = [ +REPLACED_PATTERNS: list[tuple[re.Pattern, Replacement]] = [ (re.compile(r":(?P\d+):"), r" on line \g:"), ] """Regex patterns to replace in the xmllint output.""" -REPLACED_STRINGS: List[Tuple[str, Replacement]] = [ +REPLACED_STRINGS: list[tuple[str, Replacement]] = [ ( "No matching global declaration available for the validation root", "Incorrect namespace version, please ensure you have the most recent namespace", @@ -139,7 +140,7 @@ def _sanitise_lint_issue(issue: str, file_name: str) -> Optional[str]: def _parse_lint_messages( lint_messages: Iterable[str], - error_mapping: Dict[re.Pattern, Tuple[ErrorMessage, ErrorCode]], + error_mapping: dict[re.Pattern, tuple[ErrorMessage, ErrorCode]], stage: Stage = "Pre-validation", file_name: Optional[str] = None, ) -> Messages: @@ -216,7 +217,7 @@ def run_xmllint( file_uri: URI, schema_uri: URI, *schema_resources: URI, - error_mapping: Dict[re.Pattern, Tuple[ErrorMessage, ErrorCode]], + error_mapping: dict[re.Pattern, tuple[ErrorMessage, ErrorCode]], stage: Stage = "Pre-validation", ) -> Messages: """Run `xmllint`, given a file and information about the schemas to apply. @@ -328,7 +329,7 @@ def run_xmllint( return messages -def _main(cli_args: List[str]): +def _main(cli_args: list[str]): """Command line interface for XML linting. 
Useful for testing.""" parser = argparse.ArgumentParser() parser.add_argument("xml_file_path", help="The path to the XML file to be validated") diff --git a/src/dve/reporting/error_report.py b/src/dve/reporting/error_report.py index 5cc26d9..ba4a4ac 100644 --- a/src/dve/reporting/error_report.py +++ b/src/dve/reporting/error_report.py @@ -2,12 +2,13 @@ import datetime as dt import json +from collections import deque from functools import partial from multiprocessing import Pool, cpu_count -from typing import Deque, Dict, List, Tuple, Union +from typing import Union import polars as pl -from polars import DataFrame, LazyFrame, Utf8, col, count # type: ignore +from polars import DataFrame, LazyFrame, Utf8, col # type: ignore from dve.core_engine.message import FeedbackMessage from dve.parser.file_handling.service import open_stream @@ -38,7 +39,7 @@ def get_error_codes(error_code_path: str) -> LazyFrame: """Returns an error code dataframe from a json file on any supported filesystem""" with open_stream(error_code_path) as stream: error_codes = json.load(stream) - df_lists: Dict[str, List[str]] = {"Category": [], "Data_Item": [], "Error_Code": []} + df_lists: dict[str, list[str]] = {"Category": [], "Data_Item": [], "Error_Code": []} for field, code in error_codes.items(): for category in ("Blank", "Wrong format", "Bad value"): df_lists["Category"].append(category) @@ -48,7 +49,7 @@ def get_error_codes(error_code_path: str) -> LazyFrame: return pl.DataFrame(df_lists).lazy() # type: ignore -def conditional_cast(value, primary_keys: List[str], value_separator: str) -> Union[List[str], str]: +def conditional_cast(value, primary_keys: list[str], value_separator: str) -> Union[list[str], str]: """Determines what to do with a value coming back from the error list""" if isinstance(value, list): casts = [ @@ -66,9 +67,11 @@ def conditional_cast(value, primary_keys: List[str], value_separator: str) -> Un def _convert_inner_dict(error: FeedbackMessage, key_fields): return { - key: str(conditional_cast(value, key_fields.get(error.entity, ""), " -- ")) - if value is not None - else None + key: ( + str(conditional_cast(value, key_fields.get(error.entity, ""), " -- ")) + if value is not None + else None + ) for key, value in error.to_dict( key_fields.get(error.entity), max_number_of_values=10, @@ -78,7 +81,7 @@ def _convert_inner_dict(error: FeedbackMessage, key_fields): } -def create_error_dataframe(errors: Deque[FeedbackMessage], key_fields): +def create_error_dataframe(errors: deque[FeedbackMessage], key_fields): """Creates a Lazyframe from a Deque of feedback messages and their key fields""" if not errors: return DataFrame({}, schema=ERROR_SCHEMA) @@ -96,20 +99,21 @@ def create_error_dataframe(errors: Deque[FeedbackMessage], key_fields): schema=schema, ) - df = df.with_columns( - error_type=pl.when(col("Status") == "error") # type: ignore - .then("Submission Failure") - .otherwise("Warning") + df = df.with_columns( # type: ignore + pl.when(pl.col("Status") == pl.lit("error")) # type: ignore + .then(pl.lit("Submission Failure")) # type: ignore + .otherwise(pl.lit("Warning")) # type: ignore + .alias("error_type") ) - df = df.select( - col("Entity").alias("Table"), - col("error_type").alias("Type"), - col("ErrorCode").alias("Error_Code"), - col("ReportingField").alias("Data_Item"), - col("ErrorMessage").alias("Error"), - col("Value"), - col("Key").alias("ID"), - col("Category"), + df = df.select( # type: ignore + col("Entity").alias("Table"), # type: ignore + col("error_type").alias("Type"), # type: ignore 
+ col("ErrorCode").alias("Error_Code"), # type: ignore + col("ReportingField").alias("Data_Item"), # type: ignore + col("ErrorMessage").alias("Error"), # type: ignore + col("Value"), # type: ignore + col("Key").alias("ID"), # type: ignore + col("Category"), # type: ignore ) return df.sort("Type", descending=False).collect() # type: ignore @@ -128,20 +132,27 @@ def calculate_aggregates(error_frame: DataFrame) -> DataFrame: if error_frame.is_empty(): return DataFrame({}, schema=AGGREGATE_SCHEMA) aggregates = ( - error_frame.lazy() # type: ignore - .groupby(["Table", "Type", "Data_Item", "Error_Code", "Category"]) - .agg(count("*")) + error_frame.group_by( + [ + pl.col("Table"), # type: ignore + pl.col("Type"), # type: ignore + pl.col("Data_Item"), # type: ignore + pl.col("Error_Code"), # type: ignore + pl.col("Category"), # type: ignore + ] + ) + .agg(pl.len()) # type: ignore .select( # type: ignore - "Type", - "Table", - "Data_Item", - "Category", - "Error_Code", - col("Value").alias("Count"), + pl.col("Type"), # type: ignore + pl.col("Table"), # type: ignore + pl.col("Data_Item"), # type: ignore + pl.col("Category"), # type: ignore + pl.col("Error_Code"), # type: ignore + pl.col("len").alias("Count"), # type: ignore ) - .sort("Type", "Count", descending=[False, True]) + .sort(pl.col("Type"), pl.col("Count"), descending=[False, True]) # type: ignore ) - return aggregates.collect() # type: ignore + return aggregates def generate_report_dataframes( @@ -149,7 +160,7 @@ def generate_report_dataframes( contract_error_codes, key_fields, populate_codes: bool = True, -) -> Tuple[pl.DataFrame, pl.DataFrame]: # type: ignore +) -> tuple[pl.DataFrame, pl.DataFrame]: # type: ignore """Generates the error detail and aggregates dataframes""" error_df = create_error_dataframe(errors, key_fields) diff --git a/src/dve/reporting/excel_report.py b/src/dve/reporting/excel_report.py index 717a1cc..727b244 100644 --- a/src/dve/reporting/excel_report.py +++ b/src/dve/reporting/excel_report.py @@ -1,10 +1,11 @@ # mypy: disable-error-code="attr-defined" """Creates an excel report from error data""" +from collections.abc import Iterable from dataclasses import dataclass, field from io import BytesIO from itertools import chain -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Optional, Union import polars as pl from openpyxl import Workbook, utils @@ -20,16 +21,16 @@ class SummaryItems: """Items to go into the Summary sheet""" - summary_dict: Dict[str, Any] = field(default_factory=dict) + summary_dict: dict[str, Any] = field(default_factory=dict) """Dictionary of items to show in the front sheet key is put into Column B and value in column C""" - row_headings: List[str] = field(default_factory=list) + row_headings: list[str] = field(default_factory=list) """Which errors are expected to show in the summary table""" - table_columns: List[str] = field(default_factory=list) + table_columns: list[str] = field(default_factory=list) """Names of the tables to show in summary table""" partion_key: Optional[str] = None """key to split summary items into multiple tables""" - aggregations: List[pl.Expr] = field(default_factory=lambda: [pl.sum("Count")]) # type: ignore + aggregations: list[pl.Expr] = field(default_factory=lambda: [pl.sum("Count")]) # type: ignore """List of aggregations to apply to the grouped up dataframe""" additional_columns: Optional[list] = None """any additional columns to add to the summary table""" @@ -66,7 +67,7 @@ def create_summary_sheet( error_summary = ( # 
chaining methods on dataframes seems to confuse mypy - aggregates.groupby(groups).agg(*self.aggregations) # type: ignore + aggregates.group_by(groups).agg(*self.aggregations) # type: ignore ) try: @@ -98,11 +99,11 @@ def get_submission_status(aggregates: DataFrame) -> str: return status def _write_table( - self, summary: Worksheet, columns: List[str], error_summary: DataFrame, row_headings + self, summary: Worksheet, columns: list[str], error_summary: DataFrame, row_headings ): summary.append(["", "", *columns]) for error_type in row_headings: - row: List[Any] = ["", error_type] + row: list[Any] = ["", error_type] for column in columns: if error_summary.is_empty(): counts = error_summary @@ -207,7 +208,7 @@ def create_summary_sheet( error_summary = ( # chaining methods on dataframes seems to confuse mypy - aggregates.groupby(groups).agg(*self.aggregations) # type: ignore + aggregates.group_by(groups).agg(*self.aggregations) # type: ignore ) tables = [table for table in tables if table is not None] column = self.partition_key @@ -239,7 +240,7 @@ def get_submission_status(aggregates: DataFrame) -> str: def _write_combined_table( self, summary: Worksheet, - tables: List[str], + tables: list[str], error_summary: DataFrame, ): try: @@ -256,7 +257,7 @@ def _write_combined_table( summary.append(["", self.row_field.capitalize(), *map(str.capitalize, tables)]) for row_type in sorted(row_headings): - row: List[Any] = ["", row_type] + row: list[Any] = ["", row_type] for table in tables: count_field = self.table_mapping.get(table, "Count") if table in self.table_columns: @@ -281,7 +282,7 @@ class ExcelFormat: def __init__( self, - error_details: Union[DataFrame, Dict[str, DataFrame]], + error_details: Union[DataFrame, dict[str, DataFrame]], error_aggregates: DataFrame, summary_aggregates: Optional[DataFrame] = None, overflow=1_000_000, @@ -353,7 +354,7 @@ def create_error_data_sheets( self, workbook: Workbook, invalid_data: Iterable[str], - headings: List[str], + headings: list[str], title: str = "Error Data", suffix: int = 0, additional_id: Optional[str] = None, @@ -408,7 +409,7 @@ def _format_error_sheet(self, error_report): self._expand_columns(error_report) def create_error_aggregate_sheet( - self, workbook: Workbook, aggregate: List[Dict[str, Any]], headings: List[str] + self, workbook: Workbook, aggregate: list[dict[str, Any]], headings: list[str] ): """Creates a sheet aggregating errors together to give a more granular overview""" # Create sheet for error summary info @@ -441,7 +442,7 @@ def _text_length(value): return 0 if value is None else len(str(value)) @staticmethod - def _format_headings(headings: List[str]) -> List[str]: + def _format_headings(headings: list[str]) -> list[str]: headings = [heading.title() if heading[0].islower() else heading for heading in headings] headings = [heading.replace("_", " ") for heading in headings] return headings diff --git a/tests/features/patches.py b/tests/features/patches.py index 3fe9a7e..6db6a8c 100644 --- a/tests/features/patches.py +++ b/tests/features/patches.py @@ -89,7 +89,7 @@ def get_spark_session() -> SparkSession: os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join( [ "--packages", - "com.databricks:spark-xml_2.12:0.16.0,io.delta:delta-core_2.12:1.1.0", + "com.databricks:spark-xml_2.12:0.16.0,io.delta:delta-core_2.12:2.4.0", "pyspark-shell", ] ) diff --git a/tests/test_parser/test_file_handling.py b/tests/test_parser/test_file_handling.py index 42d6011..cfa90be 100644 --- a/tests/test_parser/test_file_handling.py +++ 
b/tests/test_parser/test_file_handling.py @@ -10,6 +10,7 @@ import boto3 import pytest +from pytest_lazy_fixtures import lf as lazy_fixture from typing_extensions import Literal from dve.parser.exceptions import FileAccessError, LogDataLossWarning @@ -68,10 +69,10 @@ def test_s3_uri_raises_missing_bucket(): @pytest.mark.parametrize( "prefix", [ - pytest.lazy_fixture("temp_prefix"), - pytest.lazy_fixture("temp_s3_prefix"), - pytest.lazy_fixture("temp_dbfs_prefix"), - ], # type: ignore + lazy_fixture("temp_prefix"), + lazy_fixture("temp_s3_prefix"), + lazy_fixture("temp_dbfs_prefix"), + ], # type: ignore # pylint: disable=E1102 ) class TestParametrizedFileInteractions: """Tests which involve S3 and local filesystem.""" @@ -436,10 +437,10 @@ def test_filename_resolver_linux(uri, expected): @pytest.mark.parametrize( ["source_prefix", "target_prefix"], [ - (pytest.lazy_fixture("temp_prefix"), pytest.lazy_fixture("temp_prefix")), # type: ignore - (pytest.lazy_fixture("temp_s3_prefix"), pytest.lazy_fixture("temp_s3_prefix")), # type: ignore - (pytest.lazy_fixture("temp_prefix"), pytest.lazy_fixture("temp_s3_prefix")), # type: ignore - (pytest.lazy_fixture("temp_s3_prefix"), pytest.lazy_fixture("temp_prefix")), # type: ignore + (lazy_fixture("temp_prefix"), lazy_fixture("temp_prefix")), # type: ignore + (lazy_fixture("temp_s3_prefix"), lazy_fixture("temp_s3_prefix")), # type: ignore + (lazy_fixture("temp_prefix"), lazy_fixture("temp_s3_prefix")), # type: ignore + (lazy_fixture("temp_s3_prefix"), lazy_fixture("temp_prefix")), # type: ignore ], ) def test_copy_move_resource( @@ -476,11 +477,11 @@ def test_copy_move_resource( @pytest.mark.parametrize( ["source_prefix", "target_prefix"], [ - (pytest.lazy_fixture("temp_prefix"), pytest.lazy_fixture("temp_prefix")), # type: ignore - (pytest.lazy_fixture("temp_s3_prefix"), pytest.lazy_fixture("temp_s3_prefix")), # type: ignore - (pytest.lazy_fixture("temp_prefix"), pytest.lazy_fixture("temp_s3_prefix")), # type: ignore - (pytest.lazy_fixture("temp_s3_prefix"), pytest.lazy_fixture("temp_prefix")), # type: ignore - ], + (lazy_fixture("temp_prefix"), lazy_fixture("temp_prefix")), # type: ignore + (lazy_fixture("temp_s3_prefix"), lazy_fixture("temp_s3_prefix")), # type: ignore + (lazy_fixture("temp_prefix"), lazy_fixture("temp_s3_prefix")), # type: ignore + (lazy_fixture("temp_s3_prefix"), lazy_fixture("temp_prefix")), # type: ignore + ], # pylint: disable=E1102 ) def test_copy_move_prefix(source_prefix: str, target_prefix: str, action: Literal["copy", "move"]): """Test that resources can be copied and moved.""" diff --git a/tests/test_pipeline/pipeline_helpers.py b/tests/test_pipeline/pipeline_helpers.py index 3b3c706..de2ee10 100644 --- a/tests/test_pipeline/pipeline_helpers.py +++ b/tests/test_pipeline/pipeline_helpers.py @@ -399,4 +399,4 @@ def error_data_after_business_rules() -> Iterator[Tuple[SubmissionInfo, str]]: def pl_row_count(df: pl.DataFrame) -> int: - return df.select(pl.count()).to_dicts()[0]["count"] + return df.select(pl.len()).to_dicts()[0]["len"] diff --git a/tests/test_pipeline/test_spark_pipeline.py b/tests/test_pipeline/test_spark_pipeline.py index 126f6de..c3a7fb2 100644 --- a/tests/test_pipeline/test_spark_pipeline.py +++ b/tests/test_pipeline/test_spark_pipeline.py @@ -435,7 +435,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out report_records = ( pl.read_excel(report_uri) - .filter(pl.col("Data Summary") != pl.lit(None)) + .filter(pl.col("Data Summary").is_not_null()) .select(pl.col("Data 
Summary"), pl.col("_duplicated_0")) .rows() )