From 87419865bd2de1e8357a181aecbc342b699a4f4a Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Mon, 20 Nov 2023 18:32:13 +0800 Subject: [PATCH 01/20] :sparkles: [metadata] create a lineage server --- metadata/__init__.py | 7 ++ metadata/models.py | 164 +++++++++++++++++++++++++++++++++++++++++++ metadata/server.py | 24 +++++++ requirements.txt | 2 + 4 files changed, 197 insertions(+) create mode 100644 metadata/__init__.py create mode 100644 metadata/models.py create mode 100644 metadata/server.py diff --git a/metadata/__init__.py b/metadata/__init__.py new file mode 100644 index 0000000..ca67a66 --- /dev/null +++ b/metadata/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Author: Li Yuanming +Email: yuanmingleee@gmail.com +Date: Nov 20, 2023 +""" diff --git a/metadata/models.py b/metadata/models.py new file mode 100644 index 0000000..f75b76d --- /dev/null +++ b/metadata/models.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Author: Li Yuanming +Email: yuanmingleee@gmail.com +Date: Nov 20, 2023 +""" +from datetime import datetime +from enum import Enum +from typing import List, Optional, Dict, Any +from uuid import UUID + +from pydantic import BaseModel, Field, AnyUrl, Extra, root_validator + + +class BaseEvent(BaseModel): + eventTime: datetime = Field(default_factory=datetime.utcnow, description="the time the event occurred at") + producer: AnyUrl = Field( + description="URI identifying the producer of this metadata. For example this could be a git url with a given tag or sha", + example="https://github.com/OpenLineage/OpenLineage/blob/v1-0-0/client", + ) + schemaURL: AnyUrl = Field( + description="The JSON Pointer (https://tools.ietf.org/html/rfc6901) URL to the corresponding version of the schema definition for this RunEvent", + example="https://openlineage.io/spec/0-0-1/OpenLineage.json", + ) + +class BaseFacet(BaseModel, extra=Extra.allow): + """all fields of the base facet are prefixed with _ to avoid name conflicts in facets""" + + _producer: AnyUrl = Field( + description="URI identifying the producer of this metadata. 
For example this could be a git url with a given tag or sha", + example="https://github.com/OpenLineage/OpenLineage/blob/v1-0-0/client" + ) + _schemaURL: AnyUrl = Field( + description="The JSON Pointer (https://tools.ietf.org/html/rfc6901) URL to the corresponding version of the schema definition for this facet", + example="https://openlineage.io/spec/1-0-2/OpenLineage.json#/$defs/BaseFacet" + ) + +class RunFacet(BaseFacet): + """A Run Facet""" + +class Run(BaseModel): + runId: UUID = Field(description="The globally unique ID of the run associated with the job.") + facets: Optional[Dict[Any, RunFacet]] = Field( + default_factory=dict, + description="The run facets.", + ) + + +class JobFacet(BaseFacet): + """A Job Facet""" + _deleted: bool = Field( + description="set to true to delete a facet", + ) + + +class DatasetFacet(BaseFacet): + """A Dataset Facet""" + _deleted: bool = Field( + description="set to true to delete a facet", + ) + + +class InputDatasetFacet(DatasetFacet): + """An Input Dataset Facet""" + + +class OutputDatasetFacet(DatasetFacet): + """An Output Dataset Facet""" + + +class Job(BaseModel): + namespace: str = Field(description="The namespace containing that job", example="my-scheduler-namespace") + name: str = Field(description="The unique name for that job within that namespace", example="myjob.mytask") + facets: Optional[Dict[Any, JobFacet]] = Field( + default_factory=dict, + description="The job facets.", + ) + +class Dataset(BaseModel): + namespace: str = Field(description="The namespace containing that dataset", example="my-datasource-namespace") + name: str = Field(description="The unique name for that dataset within that namespace", example="instance.schema.table") + facets: Optional[Dict[Any, Any]] = Field( + default_factory=dict, + description="The facets for this dataset", + ) + + +class StaticDataset(Dataset): + """A Dataset sent within static metadata events""" + + +class InputDataset(Dataset): + """An input dataset""" + inputFacets: Optional[Dict[Any, InputDatasetFacet]] = Field( + default_factory=dict, + description="The input facets for this dataset.", + ) + + +class OutputDataset(Dataset): + """An output dataset""" + outputFacets: Optional[Dict[Any, OutputDatasetFacet]] = Field( + default_factory=dict, + description="The output facets for this dataset", + ) + + +class RunState(Enum): + START = "START" + RUNNING = "RUNNING" + COMPLETE = "COMPLETE" + ABORT = "ABORT" + FAIL = "FAIL" + OTHER = "OTHER" + + +class RunEvent(BaseEvent): + eventType: RunState = Field( + description="the current transition of the run state. It is required to issue 1 START event and 1 of [ COMPLETE, ABORT, FAIL ] event per run. Additional events with OTHER eventType can be added to the same run. 
For example to send additional metadata after the run is complete", + example="START|RUNNING|COMPLETE|ABORT|FAIL|OTHER", + # enum=["START", "RUNNING", "COMPLETE", "ABORT", "FAIL", "OTHER"], + ) + run: Run + job: Job + inputs: Optional[List[InputDataset]] = Field(default_factory=list, description="The set of **input** datasets.") + outputs: Optional[List[OutputDataset]] = Field(default_factory=list, description="The set of **output** datasets.") + + +class DatasetEvent(BaseEvent): + dataset: StaticDataset + + @root_validator + def check_not_required(cls, values): + if "job" in values or "run" in values: + raise ValueError("DatasetEvent should not contain `job` or `run`") + return values + + class Config: + schema_extra = {"not": {"required": ["job", "run"]}} + + +class JobEvent(BaseEvent): + job: Job + inputs: Optional[List[InputDataset]] = Field(default_factory=list, description="The set of **input** datasets.") + outputs: Optional[List[OutputDataset]] = Field(default_factory=list, description="The set of **output** datasets.") + + @root_validator + def check_not_required(cls, values): + if "run" in values: + raise ValueError("JobEvent should not contain `run`") + return values + + class Config: + schema_extra = {"not": {"required": ["run"]}} + + +if __name__ == '__main__': + import rich + import builtins + + builtins.print = rich.print + + print(RunEvent.schema_json(indent=2)) diff --git a/metadata/server.py b/metadata/server.py new file mode 100644 index 0000000..304f830 --- /dev/null +++ b/metadata/server.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Author: Li Yuanming +Email: yuanmingleee@gmail.com +Date: Nov 20, 2023 +""" +from typing import List, Optional, Union + +from openlineage.client.run import DatasetEvent, JobEvent +from fastapi import FastAPI + +from metadata.models import RunEvent, DatasetEvent, JobEvent + +app = FastAPI() + +# Record lineage information as schema defined as OpenLineage (2-0-2) +# https://openlineage.io/apidocs/openapi/ +@app.post('/lineage', summary='Send an event related to the state of a run') +def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): + """Updates a run state for a job. 
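+
+    A minimal example of posting an OpenLineage run event to this endpoint
+    (a sketch; the host and port are assumed, and the runId is illustrative):
+
+        import requests
+
+        requests.post(
+            'http://localhost:8000/lineage',
+            json={
+                'eventType': 'START',
+                'eventTime': '2023-11-20T10:00:00+00:00',
+                'producer': 'https://github.com/OpenLineage/OpenLineage/blob/v1-0-0/client',
+                'schemaURL': 'https://openlineage.io/spec/0-0-1/OpenLineage.json',
+                'run': {'runId': 'd290f1ee-6c54-4b01-90e6-d701748f0851'},
+                'job': {'namespace': 'my-scheduler-namespace', 'name': 'myjob.mytask'},
+            },
+        )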
+ """ + print(event) + return {'status': 'success'} diff --git a/requirements.txt b/requirements.txt index 88da1f9..c8eb696 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,8 @@ networkx>=3.1 uvicorn>=0.15.0 click==8.1.3 tqdm>=4.62.3 +pydantic>=1.8.2 pygraphviz>=1.11 +pyan3>=1.2.0 rich>=13.3.5 GitPython>=3.1.24 \ No newline at end of file From ba784cb64b1efff22a6b73186c08b638ce0a49b0 Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Thu, 23 Nov 2023 18:25:06 +0800 Subject: [PATCH 02/20] :sparkles: [lineage] add lineage and run data model --- dataci/models/__init__.py | 4 +- dataci/models/base.py | 2 +- dataci/models/lineage.py | 91 +++++++++++++++++++++++++++++++++++ dataci/models/run.py | 59 +++++++++++++++++++++++ dataci/models/run/__init__.py | 10 ---- dataci/models/run/list.py | 42 ---------------- dataci/models/run/run.py | 90 ---------------------------------- dataci/models/run/save.py | 61 ----------------------- dataci/models/stage.py | 1 + dataci/models/workflow.py | 7 ++- metadata/models.py | 23 ++++++++- metadata/server.py | 21 ++++++-- 12 files changed, 199 insertions(+), 212 deletions(-) create mode 100644 dataci/models/lineage.py create mode 100644 dataci/models/run.py delete mode 100644 dataci/models/run/__init__.py delete mode 100644 dataci/models/run/list.py delete mode 100644 dataci/models/run/run.py delete mode 100644 dataci/models/run/save.py diff --git a/dataci/models/__init__.py b/dataci/models/__init__.py index 43089d8..3aa6ff4 100644 --- a/dataci/models/__init__.py +++ b/dataci/models/__init__.py @@ -8,10 +8,12 @@ from .base import BaseModel from .dataset import Dataset from .event import Event +from .lineage import Lineage +from .run import Run from .stage import Stage from .workflow import Workflow from .workspace import Workspace __all__ = [ - 'BaseModel', 'Workspace', 'Dataset', 'Event', 'Workflow', 'Stage', + 'BaseModel', 'Workspace', 'Dataset', 'Event', 'Workflow', 'Stage', 'Run', 'Lineage', ] diff --git a/dataci/models/base.py b/dataci/models/base.py index 5f94d74..294a912 100644 --- a/dataci/models/base.py +++ b/dataci/models/base.py @@ -55,7 +55,7 @@ def uri(self): return f'dataci://{self.workspace.name}/{self.type_name}/{self.name}/{self.version}' @abc.abstractmethod - def dict(self): + def dict(self, id_only=False): pass @classmethod diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py new file mode 100644 index 0000000..41f9edb --- /dev/null +++ b/dataci/models/lineage.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Author: Li Yuanming +Email: yuanmingleee@gmail.com +Date: Nov 22, 2023 +""" +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import List, Optional, Union + + from dataci.models import Dataset, Workflow, Stage, Run + +class Lineage(object): + + def __init__( + self, + run: 'Union[Run, dict]', + parent_run: 'Optional[Union[Run, dict]]' = None, + inputs: 'List[Union[Dataset, dict]]' = None, + outputs: 'List[Union[Dataset, dict]]' = None, + ): + self._run = run + self._parent_run = parent_run + self._inputs: 'List[Dataset]' = inputs or list() + self._outputs: 'List[Dataset]' = outputs or list() + + def dict(self): + return { + 'parent_run': self.parent_run.dict(id_only=True) if self.parent_run else None, + 'run': self.run.dict() if self.run else None, + 'inputs': [input_.dict(id_only=True) for input_ in self.inputs], + 'outputs': [output.dict(id_only=True) for output in self.outputs], + } + + @classmethod + def from_dict(cls, config): + pass + + @property + def job(self) -> 
'Union[Workflow, Stage]': + return self.run.job + + @property + def run(self) -> 'Run': + """Lazy load run from database.""" + from dataci.models import Run + + if not isinstance(self._run, Run): + self._run = Run.get(self._run['run_id']) + return self._run + + @property + def parent_run(self) -> 'Optional[Run]': + """Lazy load parent run from database.""" + from dataci.models import Run + + if self._parent_run is None: + return None + + if not isinstance(self._parent_run, Run): + self._parent_run = Run.get(self._parent_run['run_id']) + return self._parent_run + + @property + def inputs(self) -> 'List[Dataset]': + """Lazy load inputs from database.""" + inputs = list() + for input_ in self._inputs: + if not isinstance(input_, Dataset): + dataset_id = input_['workspace'] + '.' + input_['name'] + '@' + input_['version'] + inputs.append(Dataset.get(dataset_id)) + else: + inputs.append(input_) + self._inputs = inputs + return self._inputs + + @property + def outputs(self) -> 'List[Dataset]': + """Lazy load outputs from database.""" + outputs = list() + for output in self._outputs: + if not isinstance(output, Dataset): + dataset_id = output['workspace'] + '.' + output['name'] + '@' + output['version'] + outputs.append(Dataset.get(dataset_id)) + else: + outputs.append(output) + self._outputs = outputs + return self._outputs + diff --git a/dataci/models/run.py b/dataci/models/run.py new file mode 100644 index 0000000..bd5a29c --- /dev/null +++ b/dataci/models/run.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Author: Li Yuanming +Email: yuanmingleee@gmail.com +Date: Mar 09, 2023 + +Run for pipeline. +""" +import os +from copy import deepcopy +from shutil import rmtree, copy2 +from typing import TYPE_CHECKING + +from dataci.utils import symlink_force + +if TYPE_CHECKING: + from typing import Union + + from dataci.models import Workflow, Stage + + +class Run(object): + def __init__(self, run_id: str, status: str, job: 'Union[Workflow, Stage, dict]', try_num: int, **kwargs): + self.run_id = run_id + self.status: str = status + self._job = job + self.try_num = try_num + + @property + def job(self) -> 'Union[Workflow, Stage]': + """Lazy load job (workflow or stage) from database.""" + from dataci.models import Workflow, Stage + + if not isinstance(self._job, (Workflow, Stage)): + job_id = self._job['workspace'] + '.' 
+ self._job['name'] + '@' + self._job['version'] + if self._job['type'] == 'workflow': + self._job = Workflow.get(job_id) + elif self._job['type'] == 'stage': + self._job = Stage(self._job) + else: + raise ValueError(f'Invalid job type: {self._job}') + return self._job + + def dict(self, id_only=False): + if id_only: + return { + 'run_id': self.run_id, + } + return { + 'run_id': self.run_id, + 'status': self.status, + 'job': self.job.dict(id_only=True), + 'try_num': self.try_num, + } + + @classmethod + def from_dict(cls, config): + return cls(**config) diff --git a/dataci/models/run/__init__.py b/dataci/models/run/__init__.py deleted file mode 100644 index cd61e35..0000000 --- a/dataci/models/run/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Author: Li Yuanming -Email: yuanmingleee@gmail.com -Date: Mar 15, 2023 -""" -from .run import Run - -__all__ = ['Run'] diff --git a/dataci/models/run/list.py b/dataci/models/run/list.py deleted file mode 100644 index f1d733e..0000000 --- a/dataci/models/run/list.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Author: Li Yuanming -Email: yuanmingleee@gmail.com -Date: Mar 16, 2023 -""" -from collections import defaultdict -from typing import TYPE_CHECKING - -from dataci.pipeline.list import LIST_PIPELINE_IDENTIFIER_PATTERN -from dataci.run import Run - -from dataci.db.run import get_many_runs - -if TYPE_CHECKING: - from typing import Optional - from dataci.repo import Repo - - -def list_run(pipeline_identifier=None, tree_view=True, repo: 'Optional[Repo]' = None): - pipeline_identifier = pipeline_identifier or '*' - matched = LIST_PIPELINE_IDENTIFIER_PATTERN.match(pipeline_identifier) - if not matched: - raise ValueError(f'Invalid pipeline identifier {pipeline_identifier}') - pipeline_name, pipeline_version = matched.groups() - pipeline_version = (pipeline_version or '').lower() + '*' - - # Check matched runs - run_dict_list = get_many_runs(pipeline_name, pipeline_version) - run_list = list() - for run_dict in run_dict_list: - run_dict['repo'] = repo - run_list.append(Run.from_dict(run_dict)) - - if tree_view: - run_dict = defaultdict(lambda: defaultdict(list)) - for run in run_list: - run_dict[run.pipeline.name][run.pipeline.version].append(run) - return run_dict - - return run_list diff --git a/dataci/models/run/run.py b/dataci/models/run/run.py deleted file mode 100644 index 0fea86f..0000000 --- a/dataci/models/run/run.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Author: Li Yuanming -Email: yuanmingleee@gmail.com -Date: Mar 09, 2023 - -Run for pipeline. 
-""" -import os -from copy import deepcopy -from shutil import rmtree, copy2 -from typing import TYPE_CHECKING - -from dataci.utils import symlink_force - -if TYPE_CHECKING: - from dataci.pipeline.pipeline import Pipeline - - -class Run(object): - from .save import save # type: ignore[misc] - - def __init__(self, pipeline: 'Pipeline', run_num: int, **kwargs): - self.pipeline = pipeline - self.run_num = run_num - - @property - def workdir(self): - return self.pipeline.workdir / 'runs' / str(self.run_num) - - def prepare(self): - from dataci.dataset import Dataset - - # Clean all for the run workdir - if self.workdir.exists(): - rmtree(self.workdir) - # Create workdir folder - self.workdir.mkdir(parents=True) - # Link code to work directory - (self.workdir / self.pipeline.CODE_DIR).symlink_to( - self.pipeline.workdir / self.pipeline.CODE_DIR, target_is_directory=True - ) - # TODO: better way to prepare input feat - # Create feat dir and link feat into the feat dir - (self.workdir / self.pipeline.FEAT_DIR).mkdir(parents=True) - for stage in self.pipeline.stages: - for dependency in stage.dependency: - if isinstance(dependency, Dataset): - # Link global dataset files path to local - local_file_path = os.path.join( - self.workdir / self.pipeline.FEAT_DIR, dependency.name + dependency.dataset_files.suffix) - symlink_force(dependency.dataset_files, local_file_path) - dependency = local_file_path - - # Copy pipeline definition file to work directory - copy2(self.pipeline.workdir / 'dvc.yaml', self.workdir / 'dvc.yaml', ) - - @property - def feat(self): - outputs = deepcopy(self.pipeline.outputs) - outputs_dict = dict() - for output in outputs: - output.rebase(self.workdir) - outputs_dict[output.name] = output - return outputs_dict - - def __cmp__(self, other): - if not isinstance(other, Run): - raise ValueError(f'Compare between type {type(Run)} and {type(other)} is invalid.') - if self.pipeline != other.pipeline: - raise ValueError( - f'Compare between two different pipeline {self.pipeline} and {other.pipeline} is invalid.' 
- ) - return self.run_num.__cmp__(other.run_num) - - def __str__(self): - return str(self.pipeline) + f'.run{self.run_num}' - - def dict(self): - return {'run_num': self.run_num, 'pipeline': self.pipeline.dict()} - - @classmethod - def from_dict(cls, config): - from dataci.pipeline.pipeline import Pipeline - - config['pipeline']['repo'] = config.get('repo', None) - config['pipeline'] = Pipeline.from_dict(config['pipeline']) - config['pipeline'].restore() - return cls(**config) diff --git a/dataci/models/run/save.py b/dataci/models/run/save.py deleted file mode 100644 index ca06b07..0000000 --- a/dataci/models/run/save.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Author: Li Yuanming -Email: yuanmingleee@gmail.com -Date: Mar 15, 2023 -""" -import logging -import os -from pathlib import Path -from typing import TYPE_CHECKING - -import yaml - -from dataci.db.run import create_one_run -from dataci.utils import cwd - -if TYPE_CHECKING: - from .run import Run - -logger = logging.getLogger(__name__) - - -def save(run: 'Run' = ...): - with cwd(run.workdir): - ##################################################################### - # Step 1: Recover pipeline feat cached file (.dvc) from .lock - # TODO: The reason due to https://github.com/iterative/dvc/issues/4428 - ##################################################################### - if os.path.exists('dvc.lock'): - with open('dvc.lock', 'r') as f: - run_cache_lock = yaml.safe_load(f) - for k, v in run_cache_lock['stages'].items(): - for out in v['outs']: - logger.info(f'Recover dvc file {out["path"]}.dvc') - with open(out['path'] + '.dvc', 'w') as f: - yaml.safe_dump({ - 'outs': [ - { - 'md5': out['md5'], - 'path': os.path.basename(out['path']), - 'size': out['size'] - } - ] - }, f) - ##################################################################### - # Step 2: Publish pipeline output feat - ##################################################################### - for output in run.pipeline.outputs: - output.publish() - - ##################################################################### - # Step 3: Publish run object to DB - ##################################################################### - create_one_run(run.dict()) - - ##################################################################### - # Step 4: Remove feat cached file (.dvc) - ##################################################################### - for dvc_file in Path(run.pipeline.FEAT_DIR).glob('**/*.dvc'): - dvc_file.unlink() diff --git a/dataci/models/stage.py b/dataci/models/stage.py index bd9bd1e..d8ed2b8 100644 --- a/dataci/models/stage.py +++ b/dataci/models/stage.py @@ -74,6 +74,7 @@ def dict(self, id_only=False): if id_only: return { 'workspace': self.workspace.name, + 'type': self.type_name, 'name': self.name, 'version': self.version, } diff --git a/dataci/models/workflow.py b/dataci/models/workflow.py index 395dcba..a983ad5 100644 --- a/dataci/models/workflow.py +++ b/dataci/models/workflow.py @@ -91,7 +91,12 @@ def stage_script_paths(self): def dict(self, id_only=False): if id_only: - return {'workspace': self.workspace.name, 'name': self.name, 'version': self.version} + return { + 'workspace': self.workspace.name, + 'type': self.type_name, + 'name': self.name, + 'version': self.version + } # export the dag as a dict # 1. convert the dag to a list of edges # 2. 
convert each node from Stage to an id diff --git a/metadata/models.py b/metadata/models.py index f75b76d..68b14d0 100644 --- a/metadata/models.py +++ b/metadata/models.py @@ -5,12 +5,25 @@ Email: yuanmingleee@gmail.com Date: Nov 20, 2023 """ +import re from datetime import datetime from enum import Enum from typing import List, Optional, Dict, Any from uuid import UUID -from pydantic import BaseModel, Field, AnyUrl, Extra, root_validator +from packaging import version +from pydantic import BaseModel, Field, AnyUrl, Extra, root_validator, validator + +SCHEMA_VERSION = "2-0-2" +SCHEMA_PATH_PATTERN = re.compile(r'^/spec/(\d+)-(\d+)-(\d+)/OpenLineage.json$') + + +def parse_schema_version(schema_url: AnyUrl) -> str: + match = SCHEMA_PATH_PATTERN.match(schema_url.path) + if match: + major, minor, micro = match.groups() + return f"{major}-{minor}-{micro}" + raise ValueError(f"Invalid schema url: {schema_url}") class BaseEvent(BaseModel): @@ -24,6 +37,14 @@ class BaseEvent(BaseModel): example="https://openlineage.io/spec/0-0-1/OpenLineage.json", ) + @validator("schemaURL") + def check_schema_version(cls, v): + schema_version = parse_schema_version(v) + if version.parse(schema_version.replace('-', '.')) <= \ + version.parse(SCHEMA_VERSION.replace('-', '.')): + return v + raise ValueError(f"Invalid schema version: {v}") + class BaseFacet(BaseModel, extra=Extra.allow): """all fields of the base facet are prefixed with _ to avoid name conflicts in facets""" diff --git a/metadata/server.py b/metadata/server.py index 304f830..a612aed 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -5,20 +5,31 @@ Email: yuanmingleee@gmail.com Date: Nov 20, 2023 """ -from typing import List, Optional, Union +from typing import Union -from openlineage.client.run import DatasetEvent, JobEvent -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from metadata.models import RunEvent, DatasetEvent, JobEvent app = FastAPI() +api_router = APIRouter(prefix='/api/v1') + + # Record lineage information as schema defined as OpenLineage (2-0-2) # https://openlineage.io/apidocs/openapi/ -@app.post('/lineage', summary='Send an event related to the state of a run') +@api_router.post('/lineage', summary='Send an event related to the state of a run') def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): """Updates a run state for a job. """ - print(event) + print(event.json()) return {'status': 'success'} + + +app.include_router(api_router) + + +if __name__ == '__main__': + import uvicorn + + uvicorn.run(app, host='localhost', port=8000) From f8f3b3bbca4ddf6c1c37f6ef1c538caeeb1c1064 Mon Sep 17 00:00:00 2001 From: yuanmingleee Date: Fri, 24 Nov 2023 02:02:01 +0800 Subject: [PATCH 03/20] :wrench: [model] Add methods for run data model --- dataci/models/run.py | 67 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 11 deletions(-) diff --git a/dataci/models/run.py b/dataci/models/run.py index bd5a29c..900be50 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -7,25 +7,44 @@ Run for pipeline. 
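+
+A run is identified by a UUID-style name (see ``NAME_PATTERN`` below), and its
+``version`` doubles as the try number, so ``Run.version`` and ``Run.try_num``
+refer to the same value.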
""" -import os -from copy import deepcopy -from shutil import rmtree, copy2 +import re from typing import TYPE_CHECKING -from dataci.utils import symlink_force +from dataci.models import BaseModel if TYPE_CHECKING: + from datetime import datetime from typing import Union from dataci.models import Workflow, Stage -class Run(object): - def __init__(self, run_id: str, status: str, job: 'Union[Workflow, Stage, dict]', try_num: int, **kwargs): - self.run_id = run_id +class Run(BaseModel): + # run id (uuid) + NAME_PATTERN = re.compile(r'^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}$', flags=re.IGNORECASE) + VERSION_PATTERN = re.compile(r'^\d+$', flags=re.IGNORECASE) + type_name = 'run' + + def __init__( + self, + name: str, + status: str, + job: 'Union[Workflow, Stage, dict]', + try_num: int, + create_time: 'datetime', + update_time: 'datetime', + **kwargs + ): + super().__init__(name, **kwargs) self.status: str = status self._job = job - self.try_num = try_num + self.version = try_num + self.create_time = create_time + self.update_time = update_time + + @property + def try_num(self): + return self.version @property def job(self) -> 'Union[Workflow, Stage]': @@ -45,15 +64,41 @@ def job(self) -> 'Union[Workflow, Stage]': def dict(self, id_only=False): if id_only: return { - 'run_id': self.run_id, + 'workspace': self.workspace.name, + 'name': self.name, + 'version': self.version, + 'type': self.type_name, } return { - 'run_id': self.run_id, + 'workspace': self.workspace.name, + 'name': self.name, + 'version': self.version, 'status': self.status, 'job': self.job.dict(id_only=True), - 'try_num': self.try_num, + 'create_time': int(self.create_time.timestamp()), + 'update_time': int(self.update_time.timestamp()), } @classmethod def from_dict(cls, config): return cls(**config) + + def save(self, version=None): + # Get next run try number + version = self.version or get_next_run_version(name) + # Check if run exists + if exist_run_by_version(self.name, version): + # update run + return self.update() + create_one_run(self.dict()) + + def update(self): + pass + + @classmethod + def get(cls, name, version=None, not_found_ok=False): + """Get run by name and version.""" + workspace, name, version = cls.parse_data_model_get_identifier(name, version) + config = get_run_by_uuid(name) + + return cls.from_dict(config) From ef0301da700dc2947ae6b53b017e7aaf6ce8da61 Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Fri, 24 Nov 2023 18:41:10 +0800 Subject: [PATCH 04/20] :wrench: [lineage] Add dao for run model --- dataci/db/init.py | 38 +++++++- dataci/db/run.py | 162 ++++++++++++++++++++++++----------- dataci/db/stage.py | 8 ++ dataci/db/workflow.py | 13 ++- dataci/models/run.py | 21 +++-- dataci/models/stage.py | 1 + dataci/models/workflow.py | 2 + metadata/lineage_analysis.py | 60 +++++++++++++ metadata/server.py | 24 +++++- 9 files changed, 267 insertions(+), 62 deletions(-) create mode 100644 metadata/lineage_analysis.py diff --git a/dataci/db/init.py b/dataci/db/init.py index dee5627..67575a0 100644 --- a/dataci/db/init.py +++ b/dataci/db/init.py @@ -17,6 +17,7 @@ # Drop all tables with db_connection: db_connection.executescript(""" + DROP TABLE IF EXISTS run; DROP TABLE IF EXISTS dataset_tag; DROP TABLE IF EXISTS dataset; DROP TABLE IF EXISTS workflow_dag_node; @@ -24,12 +25,23 @@ DROP TABLE IF EXISTS stage; DROP TABLE IF EXISTS workflow_tag; DROP TABLE IF EXISTS workflow; + DROP TABLE IF EXISTS job; """) logger.info('Drop all tables.') # Create dataset table with db_connection: 
db_connection.executescript(""" + CREATE TABLE job + ( + workspace TEXT, + name TEXT, + version TEXT, + type TEXT, + PRIMARY KEY (workspace, name, version, type), + UNIQUE (workspace, name, version, type) + ); + CREATE TABLE workflow ( workspace TEXT, @@ -45,7 +57,9 @@ script_filelist TEXT, script_hash TEXT, PRIMARY KEY (workspace, name, version), - UNIQUE (workspace, name, version) + UNIQUE (workspace, name, version), + FOREIGN KEY (workspace, name, version) + REFERENCES job (workspace, name, version) ); CREATE TABLE workflow_tag @@ -73,7 +87,9 @@ script_filelist TEXT, script_hash TEXT, PRIMARY KEY (workspace, name, version), - UNIQUE (workspace, name, version) + UNIQUE (workspace, name, version), + FOREIGN KEY (workspace, name, version) + REFERENCES job (workspace, name, version) ); CREATE TABLE stage_tag @@ -133,5 +149,23 @@ FOREIGN KEY (workspace, name, version) REFERENCES dataset (workspace, name, version) ); + + CREATE TABLE run + ( + workspace TEXT, + name TEXT, + version INTEGER, + status TEXT, + job_workspace TEXT, + job_name TEXT, + job_version TEXT, + job_type TEXT, + create_time INTEGER, + update_time INTEGER, + PRIMARY KEY (workspace, name, version), + UNIQUE (name, version), + FOREIGN KEY (workspace, name, version, job_type) + REFERENCES job (workspace, name, version, type) + ); """) logger.info('Create all tables.') diff --git a/dataci/db/run.py b/dataci/db/run.py index 302b3a5..aba1454 100644 --- a/dataci/db/run.py +++ b/dataci/db/run.py @@ -5,71 +5,133 @@ Email: yuanmingleee@gmail.com Date: Mar 14, 2023 """ -from . import db_connection +import sqlite3 +from copy import deepcopy +from dataci.config import DB_FILE -def get_next_run_num(pipeline_name, pipeline_version): - with db_connection: - (next_run_id,), = db_connection.execute( + +def create_one_run(config: dict): + config = deepcopy(config) + job_config = config.pop('job') + config['job_workspace'] = job_config['workspace'] + config['job_name'] = job_config['name'] + config['job_version'] = job_config['version'] + config['job_type'] = job_config['type'] + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + cur.execute( """ - SELECT COALESCE(MAX(run_num), 0) + 1 AS next_run_id - FROM run - WHERE pipeline_name = ? - AND pipeline_version = ? + INSERT INTO run ( + workspace, + name, + version, + status, + job_workspace, + job_name, + job_version, + job_type, + create_time, + update_time + ) + VALUES (:workspace, :name, :version, :status, :job_workspace, :job_name, :job_version, :job_type, :create_time, :update_time) ; """, - (pipeline_name, pipeline_version) + config ) - return next_run_id + return cur.lastrowid -def create_one_run(run_dict): - pipeline_dict = run_dict['pipeline'] - with db_connection: - db_connection.execute( +def exist_run(name, version): + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + (exists,), = cur.execute( """ - INSERT INTO run(run_num, pipeline_name, pipeline_version) VALUES - (?,?,?) + SELECT EXISTS( + SELECT 1 + FROM run + WHERE name = ? + AND version = ? + ) ; """, - (run_dict['run_num'], pipeline_dict['name'], - pipeline_dict['version']) + (name, version) ) + return exists -def get_many_runs(pipeline_name, pipeline_version): - with db_connection: - run_dict_iter = db_connection.execute( +def get_next_run_version(name): + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + (version,), = cur.execute( """ - SELECT run.*, - timestamp - FROM ( - SELECT run_num, - pipeline_name, - pipeline_version - FROM run - WHERE pipeline_name GLOB ? 
- AND pipeline_version GLOB ? - ) run - JOIN ( - SELECT name, - version, - timestamp - FROM pipeline - WHERE name GLOB ? - AND version GLOB ? - ) pipeline - ON pipeline_name = name - AND pipeline_version = version + SELECT MAX(version) + 1 + FROM run + WHERE name = ? ; """, - (pipeline_name, pipeline_version, pipeline_name, pipeline_version), + (name,) ) - run_dict_list = list() - for run_po in run_dict_iter: - run_num, pipeline_name, pipeline_version, timestamp = run_po - run_dict_list.append({ - 'run_num': run_num, - 'pipeline': {'name': pipeline_name, 'version': pipeline_version, 'timestamp': timestamp}, - }) - return run_dict_list + return version or 1 + + +def get_one_run(name, version='latest'): + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + if version == 'latest': + cur.execute( + """ + SELECT workspace + , name + , version + , status + , job_workspace + , job_name + , job_version + , job_type + , create_time + , update_time + FROM run + WHERE name = ? + ORDER BY version DESC + LIMIT 1 + ; + """, + (name,) + ) + else: + cur.execute( + """ + SELECT workspace + , name + , version + , status + , job_workspace + , job_name + , job_version + , job_type + , create_time + , update_time + FROM run + WHERE name = ? + AND version = ? + ; + """, + (name, version) + ) + + config = cur.fetchone() + return { + 'workspace': config[0], + 'name': config[1], + 'version': config[2], + 'status': config[3], + 'job': { + 'workspace': config[4], + 'name': config[5], + 'version': config[6], + 'type': config[7], + }, + 'create_time': config[8], + 'update_time': config[9], + } if config else None diff --git a/dataci/db/stage.py b/dataci/db/stage.py index 05834ab..c4780d1 100644 --- a/dataci/db/stage.py +++ b/dataci/db/stage.py @@ -24,6 +24,14 @@ def create_one_stage(stage_dict): cur = conn.cursor() cur.execute( """ + -- Insert into job first + INSERT INTO job (workspace, name, version, type) + VALUES (:workspace, :name, :version, :type) + ; + """, + stage_dict, + ) + cur.execute(""" INSERT INTO stage ( workspace, name, version, params, timestamp, script_dir, script_entry, script_filelist, script_hash ) diff --git a/dataci/db/workflow.py b/dataci/db/workflow.py index 5de40c8..aee8bd8 100644 --- a/dataci/db/workflow.py +++ b/dataci/db/workflow.py @@ -29,6 +29,14 @@ def create_one_workflow(config): with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() + cur.execute( + """ + INSERT INTO job (workspace, name, version, type) + VALUES (:workspace, :name, :version, :type) + ; + """, + workflow_dict, + ) cur.execute( """ INSERT INTO workflow ( @@ -381,7 +389,7 @@ def get_one_workflow_by_tag(workspace, name, tag): 'timestamp': config[4], 'params': '', 'flag': '', - 'trigger': json.loads(config[6]), + 'trigger': json.loads(config[6]) if config[6] is not None else list(), 'dag': { 'edge': json.loads(config[8]), }, @@ -396,7 +404,7 @@ def get_one_workflow_by_tag(workspace, name, tag): version = workflow_dict['version'] cur.execute( dedent(""" - SELECT stage_workspace, stage_name, stage_version, dag_node_id + SELECT stage_workspace, stage_name, stage_version, dag_node_id, dag_node_path FROM workflow_dag_node WHERE workflow_workspace=:workspace AND workflow_name=:name @@ -414,6 +422,7 @@ def get_one_workflow_by_tag(workspace, name, tag): 'workspace': node[0], 'name': node[1], 'version': node[2] if node[2] != '' else None, + 'path': node[4], } for node in cur.fetchall() } return workflow_dict diff --git a/dataci/models/run.py b/dataci/models/run.py index 900be50..0417782 100644 --- 
a/dataci/models/run.py +++ b/dataci/models/run.py @@ -10,11 +10,12 @@ import re from typing import TYPE_CHECKING +from dataci.db.run import exist_run, create_one_run, get_next_run_version, get_one_run from dataci.models import BaseModel if TYPE_CHECKING: from datetime import datetime - from typing import Union + from typing import Optional, Union from dataci.models import Workflow, Stage @@ -22,7 +23,7 @@ class Run(BaseModel): # run id (uuid) NAME_PATTERN = re.compile(r'^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}$', flags=re.IGNORECASE) - VERSION_PATTERN = re.compile(r'^\d+$', flags=re.IGNORECASE) + VERSION_PATTERN = re.compile(r'^\d+|latest$', flags=re.IGNORECASE) type_name = 'run' def __init__( @@ -31,8 +32,8 @@ def __init__( status: str, job: 'Union[Workflow, Stage, dict]', try_num: int, - create_time: 'datetime', - update_time: 'datetime', + create_time: 'Optional[datetime]' = None, + update_time: 'Optional[datetime]' = None, **kwargs ): super().__init__(name, **kwargs) @@ -85,9 +86,9 @@ def from_dict(cls, config): def save(self, version=None): # Get next run try number - version = self.version or get_next_run_version(name) + version = self.version or get_next_run_version(self.name) # Check if run exists - if exist_run_by_version(self.name, version): + if exist_run(self.name, version): # update run return self.update() create_one_run(self.dict()) @@ -99,6 +100,12 @@ def update(self): def get(cls, name, version=None, not_found_ok=False): """Get run by name and version.""" workspace, name, version = cls.parse_data_model_get_identifier(name, version) - config = get_run_by_uuid(name) + # If version not set, get the latest version + version = version or 'latest' + config = get_one_run(name, version) + if config is None: + if not_found_ok: + return + raise ValueError(f'Run {name}@{version} not found.') return cls.from_dict(config) diff --git a/dataci/models/stage.py b/dataci/models/stage.py index d8ed2b8..88ae541 100644 --- a/dataci/models/stage.py +++ b/dataci/models/stage.py @@ -80,6 +80,7 @@ def dict(self, id_only=False): } return { 'name': self.name, + 'type': self.type_name, 'workspace': self.workspace.name, 'version': self.version, 'version_tag': self.version_tag, diff --git a/dataci/models/workflow.py b/dataci/models/workflow.py index a983ad5..9e9cde8 100644 --- a/dataci/models/workflow.py +++ b/dataci/models/workflow.py @@ -114,6 +114,7 @@ def dict(self, id_only=False): return { 'workspace': self.workspace.name, + 'type': self.type_name, 'name': self.name, 'version': self.version, 'dag': { @@ -226,6 +227,7 @@ def reload(self, config=None): if config is not None: stage_mapping[stage_full_name] = config # Update the stage script base path + print(config) self._stage_script_paths[stage_full_name] = stage_config['path'] for stage in self.stages.values(): if stage.full_name in stage_mapping: diff --git a/metadata/lineage_analysis.py b/metadata/lineage_analysis.py new file mode 100644 index 0000000..ba04601 --- /dev/null +++ b/metadata/lineage_analysis.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Author: Li Yuanming +Email: yuanmingleee@gmail.com +Date: Nov 21, 2023 +""" +import json +from copy import deepcopy + +from metadata.models import RunEvent + +JSON_PATH = 'bash_example_metadata.json' +# JSON_PATH = 'lineagetest_postgres.json' + +if __name__ == '__main__': + import builtins + import rich + + builtins.print = rich.print + +with open(JSON_PATH) as f: + lines = f.readlines() + + + +def merge_dicts(a: dict, b: dict): + """Merge 
dictionaries b into a.""" + result = deepcopy(a) + for k, v in b.items(): + if isinstance(v, dict): + result[k] = merge_dicts(result.get(k, dict()), v) + else: + result[k] = v + return result + + +events = [] +for line in lines: + events.append(json.loads(line)) + +runs = dict() +for event in events: + run_id = event['run']['runId'] + if run_id not in runs: + runs[run_id] = event + else: + # Merge the new event into the existing one + existing_event = runs[run_id] + if event['eventTime'] >= existing_event['eventTime']: + runs[run_id] = merge_dicts(existing_event, event) + else: + runs[run_id] = merge_dicts(event, existing_event) + + +# r = runs['db33cca4-8d48-3ade-9111-384c86657c25'] +r = runs['1ff96850-b38d-3247-b2cc-7ebe66e7d0d7'] +# r2 = runs['3b80ac13-0b1b-38b5-a70d-97e15e33189f'] + +print(json.dumps(r)) diff --git a/metadata/server.py b/metadata/server.py index a612aed..82c1650 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, FastAPI +from dataci.models import Run as RunModel from metadata.models import RunEvent, DatasetEvent, JobEvent app = FastAPI() @@ -22,7 +23,28 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): """Updates a run state for a job. """ - print(event.json()) + # Get job + job_workspace, job_name, job_version = event.job.name.split('--') + # Parse job type + if '.' in event.job.name: + job_type = 'workflow' + else: + job_type = 'stage' + + # Create run object and save + run = RunModel( + name=str(event.run.runId), + status=event.eventType.value, + job={ + 'job_workspace': job_workspace, + 'job_type': job_type, + 'job_name': job_name, + 'job_version': job_version, + }, + try_num=event.run.facets.get('airflow', {}).get('taskInstance', {}).get('try_number', None), + create_time=event.eventTime, + ) + print(run.dict()) return {'status': 'success'} From 2b08ab298d16a9a04946c93678b5d4ee0b72abdc Mon Sep 17 00:00:00 2001 From: yuanmingleee Date: Mon, 27 Nov 2023 01:58:56 +0800 Subject: [PATCH 05/20] :wrench: [model] Add run model update method --- dataci/config.py | 2 ++ dataci/db/run.py | 37 ++++++++++++++++++++++++++--- dataci/models/run.py | 56 ++++++++++++++++++++++++++++++++++---------- metadata/server.py | 39 ++++++++++++++++++------------ 4 files changed, 104 insertions(+), 30 deletions(-) diff --git a/dataci/config.py b/dataci/config.py index 2eb95e7..181a77f 100644 --- a/dataci/config.py +++ b/dataci/config.py @@ -8,6 +8,7 @@ import configparser import logging import os +from datetime import datetime, timezone from pathlib import Path from textwrap import dedent from threading import Event as ThreadEvent @@ -97,6 +98,7 @@ def load_config(): LOG_DIR = None LOG_LEVEL = None STORAGE_BACKEND = None +TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo # DataCI Trigger and Scheduler server SERVER_ADDRESS = '0.0.0.0' diff --git a/dataci/db/run.py b/dataci/db/run.py index aba1454..1b92b8a 100644 --- a/dataci/db/run.py +++ b/dataci/db/run.py @@ -42,6 +42,33 @@ def create_one_run(config: dict): return cur.lastrowid +def update_one_run(config): + config = deepcopy(config) + job_config = config.pop('job') + config['job_workspace'] = job_config['workspace'] + config['job_name'] = job_config['name'] + config['job_version'] = job_config['version'] + config['job_type'] = job_config['type'] + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + cur.execute( + """ + UPDATE run + SET status = :status + , job_workspace = :job_workspace + , job_name = :job_name + , job_version = :job_version + , 
job_type = :job_type + , update_time = :update_time + WHERE name = :name + AND version = :version + ; + """, + config + ) + return cur.lastrowid + + def exist_run(name, version): with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() @@ -60,19 +87,23 @@ def exist_run(name, version): return exists -def get_next_run_version(name): +def get_latest_run_version(name): with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() (version,), = cur.execute( """ - SELECT MAX(version) + 1 + SELECT MAX(version) FROM run WHERE name = ? ; """, (name,) ) - return version or 1 + return version or 0 + + +def get_next_run_version(name): + return get_latest_run_version(name) + 1 def get_one_run(name, version='latest'): diff --git a/dataci/models/run.py b/dataci/models/run.py index 0417782..f65f2aa 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -10,11 +10,13 @@ import re from typing import TYPE_CHECKING -from dataci.db.run import exist_run, create_one_run, get_next_run_version, get_one_run +from dataci.config import TIMEZONE +from dataci.db.run import exist_run, create_one_run, get_next_run_version, get_latest_run_version, get_one_run, \ + patch_one_run from dataci.models import BaseModel if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timezone from typing import Optional, Union from dataci.models import Workflow, Stage @@ -31,7 +33,7 @@ def __init__( name: str, status: str, job: 'Union[Workflow, Stage, dict]', - try_num: int, + version: int, create_time: 'Optional[datetime]' = None, update_time: 'Optional[datetime]' = None, **kwargs @@ -39,7 +41,7 @@ def __init__( super().__init__(name, **kwargs) self.status: str = status self._job = job - self.version = try_num + self.version = version self.create_time = create_time self.update_time = update_time @@ -75,26 +77,56 @@ def dict(self, id_only=False): 'name': self.name, 'version': self.version, 'status': self.status, - 'job': self.job.dict(id_only=True), - 'create_time': int(self.create_time.timestamp()), - 'update_time': int(self.update_time.timestamp()), + 'job': self.job.dict(id_only=True) if self.job else None, + 'create_time': int(self.create_time.replace(tzinfo=timezone.utc).timestamp()) if self.create_time else None, + 'update_time': int(self.update_time.replace(tzinfo=timezone.utc).timestamp()) if self.update_time else None, } @classmethod def from_dict(cls, config): - return cls(**config) + self = cls(**config) + return self.reload(config) + + def reload(self, config=None): + if config is None: + config = get_one_run(self.name, self.version) + self.version = config['version'] + self.create_time = datetime.fromtimestamp(config['create_time'], tz=TIMEZONE) + self.update_time = datetime.fromtimestamp(config['update_time'], tz=TIMEZONE) def save(self, version=None): # Get next run try number version = self.version or get_next_run_version(self.name) # Check if run exists if exist_run(self.name, version): - # update run - return self.update() - create_one_run(self.dict()) + # reload + config = get_one_run(self.name, version) + return self.reload(config) + + config = self.dict() + config['version'] = version + config['update_time'] = config['create_time'] + create_one_run(config) + return self.reload(config) def update(self): - pass + # Get latest run try number + version = self.version or get_latest_run_version(self.name) + # Check if run exists + run_prev = get_one_run(self.name, version) + if run_prev is None: + raise ValueError(f'Run {self.name}@{version} not found.') + # Update run by merging with 
previous run + config = self.dict() + config['version'] = version + config['create_time'] = run_prev['create_time'] + # Overwrite with previous field values if not set + for k, v in run_prev.items(): + if k not in config: + config[k] = v + + patch_one_run(config) + return self.reload(config) @classmethod def get(cls, name, version=None, not_found_ok=False): diff --git a/metadata/server.py b/metadata/server.py index 82c1650..9478175 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -10,7 +10,7 @@ from fastapi import APIRouter, FastAPI from dataci.models import Run as RunModel -from metadata.models import RunEvent, DatasetEvent, JobEvent +from metadata.models import RunEvent, DatasetEvent, JobEvent, RunState app = FastAPI() @@ -23,27 +23,36 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): """Updates a run state for a job. """ + # Skip if event is a test event (event job name cannot parse to workspace, job name and version) + name_parts = event.job.name.split('--') + if len(name_parts) != 3: + return {'status': 'skip'} + # Get job - job_workspace, job_name, job_version = event.job.name.split('--') + job_workspace, job_name, job_version = name_parts # Parse job type if '.' in event.job.name: job_type = 'workflow' else: job_type = 'stage' - # Create run object and save - run = RunModel( - name=str(event.run.runId), - status=event.eventType.value, - job={ - 'job_workspace': job_workspace, - 'job_type': job_type, - 'job_name': job_name, - 'job_version': job_version, - }, - try_num=event.run.facets.get('airflow', {}).get('taskInstance', {}).get('try_number', None), - create_time=event.eventTime, - ) + # If event type is START, create a new run + if event.eventType == RunState.START: + run = RunModel( + name=str(event.run.runId), + status=event.eventType.value, + job={ + 'job_workspace': job_workspace, + 'job_type': job_type, + 'job_name': job_name, + 'job_version': job_version, + }, + try_num=event.run.facets.get('airflow', {}).get('taskInstance', {}).get('try_number', None), + create_time=event.eventTime, + ) + run.save() + else: + run = RunModel.get(str(event.run.runId)) print(run.dict()) return {'status': 'success'} From fc2c63b463393169a75881ac1e7522483e121dad Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Mon, 27 Nov 2023 17:16:25 +0800 Subject: [PATCH 06/20] :bug: [run] Fix bug in run model create and update APIs --- dataci/decorators/base.py | 4 ++-- dataci/models/run.py | 25 ++++++++++++++++--------- dataci/models/stage.py | 19 ++++++++++++++++++- dataci/models/workflow.py | 11 ++++++++--- metadata/server.py | 36 +++++++++++++++++++++++++----------- 5 files changed, 69 insertions(+), 26 deletions(-) diff --git a/dataci/decorators/base.py b/dataci/decorators/base.py index f8f93c9..ef611e8 100644 --- a/dataci/decorators/base.py +++ b/dataci/decorators/base.py @@ -48,8 +48,8 @@ def script(self): def test(self, *args, **kwargs): return self._stage.test(*args, **kwargs) - def dict(self): - return self._stage.dict() + def dict(self, id_only=False): + return self._stage.dict(id_only=id_only) def from_dict(self, config): self._stage.from_dict(config) diff --git a/dataci/models/run.py b/dataci/models/run.py index f65f2aa..20364f7 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -8,15 +8,16 @@ Run for pipeline. 
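+
+``save()`` inserts the run on first sight and otherwise reloads the stored
+row, while ``update()`` merges the new state into the latest stored version.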
""" import re +import warnings +from datetime import datetime, timezone from typing import TYPE_CHECKING from dataci.config import TIMEZONE from dataci.db.run import exist_run, create_one_run, get_next_run_version, get_latest_run_version, get_one_run, \ - patch_one_run + update_one_run from dataci.models import BaseModel if TYPE_CHECKING: - from datetime import datetime, timezone from typing import Optional, Union from dataci.models import Workflow, Stage @@ -33,7 +34,6 @@ def __init__( name: str, status: str, job: 'Union[Workflow, Stage, dict]', - version: int, create_time: 'Optional[datetime]' = None, update_time: 'Optional[datetime]' = None, **kwargs @@ -41,7 +41,7 @@ def __init__( super().__init__(name, **kwargs) self.status: str = status self._job = job - self.version = version + self.version = None self.create_time = create_time self.update_time = update_time @@ -53,13 +53,14 @@ def try_num(self): def job(self) -> 'Union[Workflow, Stage]': """Lazy load job (workflow or stage) from database.""" from dataci.models import Workflow, Stage + from dataci.decorators.base import DecoratedOperatorStageMixin - if not isinstance(self._job, (Workflow, Stage)): - job_id = self._job['workspace'] + '.' + self._job['name'] + '@' + self._job['version'] + if not isinstance(self._job, (Workflow, Stage, DecoratedOperatorStageMixin)): + workflow_id = self._job['workspace'] + '.' + self._job['name'] + '@' + self._job['version'] if self._job['type'] == 'workflow': - self._job = Workflow.get(job_id) + self._job = Workflow.get(workflow_id) elif self._job['type'] == 'stage': - self._job = Stage(self._job) + self._job = Stage.get_by_workflow(self._job['stage_name'], workflow_id) else: raise ValueError(f'Invalid job type: {self._job}') return self._job @@ -91,8 +92,10 @@ def reload(self, config=None): if config is None: config = get_one_run(self.name, self.version) self.version = config['version'] + self.status = config['status'] self.create_time = datetime.fromtimestamp(config['create_time'], tz=TIMEZONE) self.update_time = datetime.fromtimestamp(config['update_time'], tz=TIMEZONE) + return self def save(self, version=None): # Get next run try number @@ -109,6 +112,10 @@ def save(self, version=None): create_one_run(config) return self.reload(config) + def publish(self): + warnings.warn('Run.publish(...) is not implemented. 
Use Run.save() instead.', DeprecationWarning) + pass + def update(self): # Get latest run try number version = self.version or get_latest_run_version(self.name) @@ -125,7 +132,7 @@ def update(self): if k not in config: config[k] = v - patch_one_run(config) + update_one_run(config) return self.reload(config) @classmethod diff --git a/dataci/models/stage.py b/dataci/models/stage.py index 88ae541..70c12c9 100644 --- a/dataci/models/stage.py +++ b/dataci/models/stage.py @@ -11,7 +11,6 @@ import shutil from collections import defaultdict from datetime import datetime -from pathlib import Path from typing import TYPE_CHECKING from dataci.db.stage import ( @@ -204,6 +203,24 @@ def get(cls, name, version=None): return cls.from_dict(config) + @classmethod + def get_by_workflow(cls, stage_name, workflow_name, workflow_version=None): + """Get the stage from the workspace.""" + from dataci.models.workflow import Workflow + + workflow_config = Workflow.get_config(workflow_name, workflow_version) + if workflow_config is None: + raise ValueError(f'Workflow {workflow_name}@{workflow_version} not found') + # Find stage version + for _, v in workflow_config['dag']['node'].items(): + if v['name'] == stage_name: + stage_version = v['version'] + break + else: + raise ValueError(f'Stage {stage_name} not found in workflow {workflow_name}@{workflow_version}') + + return cls.get(stage_name, stage_version) + @classmethod def find(cls, stage_identifier, tree_view=False, all=False): """Find the stage from the workspace.""" diff --git a/dataci/models/workflow.py b/dataci/models/workflow.py index 9e9cde8..a7a027d 100644 --- a/dataci/models/workflow.py +++ b/dataci/models/workflow.py @@ -227,7 +227,6 @@ def reload(self, config=None): if config is not None: stage_mapping[stage_full_name] = config # Update the stage script base path - print(config) self._stage_script_paths[stage_full_name] = stage_config['path'] for stage in self.stages.values(): if stage.full_name in stage_mapping: @@ -346,8 +345,8 @@ def publish(self): return self.reload(config) @classmethod - def get(cls, name: str, version: str = None): - """Get a models from the workspace.""" + def get_config(cls, name: str, version: str = None): + """Get workflow config only""" workspace, name, version = cls.parse_data_model_get_identifier(name, version) if version is None or version == 'latest' or version.startswith('v'): @@ -358,6 +357,12 @@ def get(cls, name: str, version: str = None): if version.lower() == 'none': version = None config = get_one_workflow_by_version(workspace, name, version) + return config + + @classmethod + def get(cls, name: str, version: str = None): + """Get a models from the workspace.""" + config = cls.get_config(name, version) if config is None: return diff --git a/metadata/server.py b/metadata/server.py index 9478175..8a37753 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -28,13 +28,16 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): if len(name_parts) != 3: return {'status': 'skip'} - # Get job - job_workspace, job_name, job_version = name_parts # Parse job type if '.' 
in event.job.name: - job_type = 'workflow' - else: job_type = 'stage' + # Get job + job_workspace, job_name, job_version = name_parts + job_version, stage_name = job_version.split('.') + else: + job_type = 'workflow' + job_workspace, job_name, job_version = name_parts + stage_name = None # If event type is START, create a new run if event.eventType == RunState.START: @@ -42,18 +45,29 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): name=str(event.run.runId), status=event.eventType.value, job={ - 'job_workspace': job_workspace, - 'job_type': job_type, - 'job_name': job_name, - 'job_version': job_version, + 'workspace': job_workspace, + 'type': job_type, + 'name': job_name, + 'version': job_version, + 'stage_name': stage_name, }, - try_num=event.run.facets.get('airflow', {}).get('taskInstance', {}).get('try_number', None), create_time=event.eventTime, ) run.save() else: - run = RunModel.get(str(event.run.runId)) - print(run.dict()) + run = RunModel( + name=str(event.run.runId), + status=event.eventType.value, + job={ + 'workspace': job_workspace, + 'type': job_type, + 'name': job_name, + 'version': job_version, + 'stage_name': stage_name, + }, + update_time=event.eventTime, + ) + run.update() return {'status': 'success'} From ef794ba02a14c901b98972fb0bc41c5561e7011a Mon Sep 17 00:00:00 2001 From: yuanmingleee Date: Tue, 28 Nov 2023 01:54:01 +0800 Subject: [PATCH 07/20] :sparkles: [model] Add lineage data model --- dataci/db/init.py | 30 +++++++++ dataci/db/lineage.py | 135 +++++++++++++++++++++++++++++++++++++++ dataci/models/lineage.py | 40 +++++++++++- 3 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 dataci/db/lineage.py diff --git a/dataci/db/init.py b/dataci/db/init.py index 67575a0..623cd91 100644 --- a/dataci/db/init.py +++ b/dataci/db/init.py @@ -17,6 +17,8 @@ # Drop all tables with db_connection: db_connection.executescript(""" + DROP TABLE IF EXISTS run_dataset_lineage; + DROP TABLE IF EXISTS run_lineage; DROP TABLE IF EXISTS run; DROP TABLE IF EXISTS dataset_tag; DROP TABLE IF EXISTS dataset; @@ -167,5 +169,33 @@ FOREIGN KEY (workspace, name, version, job_type) REFERENCES job (workspace, name, version, type) ); + + CREATE TABLE run_lineage + ( + run_name TEXT, + run_version INTEGER, + parent_run_name TEXT, + parent_run_version INTEGER, + PRIMARY KEY (run_name, run_version, parent_run_name, parent_run_version), + FOREIGN KEY (run_name, run_version) + REFERENCES run (name, version), + FOREIGN KEY (parent_run_name, parent_run_version) + REFERENCES run (name, version) + ); + + CREATE TABLE run_dataset_lineage + ( + run_name TEXT, + run_version INTEGER, + dataset_workspace TEXT, + dataset_name TEXT, + dataset_version INTEGER, + direction TEXT, + PRIMARY KEY (run_name, run_version, dataset_workspace, dataset_name, dataset_version, direction), + FOREIGN KEY (run_name, run_version) + REFERENCES run (name, version), + FOREIGN KEY (dataset_workspace, dataset_name, dataset_version) + REFERENCES dataset (workspace, name, version) + ); """) logger.info('Create all tables.') diff --git a/dataci/db/lineage.py b/dataci/db/lineage.py new file mode 100644 index 0000000..cad9d0a --- /dev/null +++ b/dataci/db/lineage.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Author: Li Yuanming +Email: yuanmingleee@gmail.com +Date: Nov 27, 2023 +""" +import sqlite3 + +from dataci.config import DB_FILE + + +def get_lineage(): + pass + + +def exist_run_lineage(run_name, run_version, parent_run_name, parent_run_version): + with 
sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT EXISTS( + SELECT 1 + FROM run_lineage + WHERE run_name = :run_name + AND run_version = :run_version + AND parent_run_name = :parent_run_name + AND parent_run_version = :parent_run_version + ) + ; + """, + { + 'run_name': run_name, + 'run_version': run_version, + 'parent_run_name': parent_run_name, + 'parent_run_version': parent_run_version, + } + ) + return cur.fetchone()[0] + + +def exist_many_run_dataset_lineage(run_name, run_version, dataset_configs): + + run_dataset_params = [ + { + 'run_name': run_name, + 'run_version': run_version, + 'dataset_workspace': dataset_configs['workspace'], + 'dataset_name': dataset_configs['name'], + 'dataset_version': dataset_configs['version'], + 'direction': dataset_configs['direction'], + } + for dataset_configs in dataset_configs + ] + + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + cur.executemany( + """ + SELECT EXISTS( + SELECT 1 + FROM run_dataset_lineage + WHERE run_name = :run_name + AND run_version = :run_version + AND dataset_workspace = :dataset_workspace + AND dataset_name = :dataset_name + AND dataset_version = :dataset_version + AND direction = :direction + ) + ; + """, + run_dataset_params, + ) + return [row[0] for row in cur.fetchall()] + + +def create_one_lineage(config): + run_lineage_config = { + 'run_name': config['run']['name'], + 'run_version': config['run']['version'], + 'parent_run_name': config['parent_run']['name'], + 'parent_run_version': config['parent_run']['version'], + } + + run_dataset_lineage_configs = list() + for dataset in config['inputs']: + run_dataset_lineage_configs.append({ + 'run_name': config['run']['name'], + 'run_version': config['run']['version'], + 'dataset_name': dataset['name'], + 'dataset_version': dataset['version'], + 'direction': 'input', + }) + for dataset in config['outputs']: + run_dataset_lineage_configs.append({ + 'run_name': config['run']['name'], + 'run_version': config['run']['version'], + 'dataset_name': dataset['name'], + 'dataset_version': dataset['version'], + 'direction': 'output', + }) + + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + # add parent_run -> run + cur.execute( + """ + INSERT INTO run_lineage ( + run_name + ,run_version + ,parent_run_name + ,parent_run_version + ) + VALUES (:run_name, :run_version, :parent_run_name, :parent_run_version) + ; + """, + run_lineage_config + ) + + # add run -> dataset + cur.executemany( + """ + INSERT INTO run_dataset_lineage ( + run_name + ,run_version + ,dataset_workspace + ,dataset_name + ,dataset_version + ,direction + ) + VALUES (:run_name, :run_version, :dataset_workspace, :dataset_name, :dataset_version, :direction) + ; + """, + run_dataset_lineage_configs + ) diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 41f9edb..66075d1 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -7,6 +7,8 @@ """ from typing import TYPE_CHECKING +from dataci.db.lineage import create_one_lineage, exist_run_lineage, exist_many_run_dataset_lineage + if TYPE_CHECKING: from typing import List, Optional, Union @@ -48,7 +50,7 @@ def run(self) -> 'Run': from dataci.models import Run if not isinstance(self._run, Run): - self._run = Run.get(self._run['run_id']) + self._run = Run.get(self._run['name']) return self._run @property @@ -60,7 +62,7 @@ def parent_run(self) -> 'Optional[Run]': return None if not isinstance(self._parent_run, Run): - self._parent_run = Run.get(self._parent_run['run_id']) + 
self._parent_run = Run.get(self._parent_run['name']) return self._parent_run @property @@ -89,3 +91,37 @@ def outputs(self) -> 'List[Dataset]': self._outputs = outputs return self._outputs + def save(self, exist_ok=False): + config = self.dict() + run_lineage_exist = exist_run_lineage( + config['run']['name'], + config['run']['version'], + config['parent_run'].get('name', None), + config['parent_run'].get('version', None) + ) + dataset_config_list = config['inputs'] + config['outputs'] + + dataset_lineage_exist_list = exist_many_run_dataset_lineage( + config['run']['name'], + config['run']['version'], + dataset_config_list + ) + + # Check if run lineage exists + if run_lineage_exist: + if not exist_ok: + raise ValueError(f'Run lineage {self.parent_run} -> {self.run} exists.') + else: + # Create run lineage + ... + + # Check if dataset lineage exists + if any(dataset_lineage_exist_list): + if not exist_ok: + raise ValueError(f'Dataset lineage exists.') + else: + # Create dataset lineage + ... + + create_one_lineage(config) + From 2ea0a4557f5f71694b5b2b073a477547bb99d116 Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Wed, 29 Nov 2023 15:48:34 +0800 Subject: [PATCH 08/20] :bug: [lineage] Fix lineage save API --- dataci/db/lineage.py | 61 +++++++++++++++++++++------------------- dataci/models/lineage.py | 49 +++++++++++++++++++++++--------- dataci/models/run.py | 5 +++- metadata/server.py | 19 ++++++++++++- 4 files changed, 90 insertions(+), 44 deletions(-) diff --git a/dataci/db/lineage.py b/dataci/db/lineage.py index cad9d0a..ce0bd53 100644 --- a/dataci/db/lineage.py +++ b/dataci/db/lineage.py @@ -6,6 +6,7 @@ Date: Nov 27, 2023 """ import sqlite3 +from contextlib import nullcontext from dataci.config import DB_FILE @@ -74,7 +75,7 @@ def exist_many_run_dataset_lineage(run_name, run_version, dataset_configs): return [row[0] for row in cur.fetchall()] -def create_one_lineage(config): +def create_one_run_lineage(config, cursor=None): run_lineage_config = { 'run_name': config['run']['name'], 'run_version': config['run']['version'], @@ -82,26 +83,8 @@ def create_one_lineage(config): 'parent_run_version': config['parent_run']['version'], } - run_dataset_lineage_configs = list() - for dataset in config['inputs']: - run_dataset_lineage_configs.append({ - 'run_name': config['run']['name'], - 'run_version': config['run']['version'], - 'dataset_name': dataset['name'], - 'dataset_version': dataset['version'], - 'direction': 'input', - }) - for dataset in config['outputs']: - run_dataset_lineage_configs.append({ - 'run_name': config['run']['name'], - 'run_version': config['run']['version'], - 'dataset_name': dataset['name'], - 'dataset_version': dataset['version'], - 'direction': 'output', - }) - - with sqlite3.connect(DB_FILE) as conn: - cur = conn.cursor() + with sqlite3.connect(DB_FILE) if cursor is None else nullcontext() as conn: + cur = cursor or conn.cursor() # add parent_run -> run cur.execute( """ @@ -117,19 +100,39 @@ def create_one_lineage(config): run_lineage_config ) - # add run -> dataset + +def create_many_dataset_lineage(config, cursor=None): + dataset_configs = list() + for dataset in config['inputs']: + dataset_configs.append({ + 'run_name': config['run']['name'], + 'run_version': config['run']['version'], + 'dataset_name': dataset['name'], + 'dataset_version': dataset['version'], + 'direction': 'input', + }) + for dataset in config['outputs']: + dataset_configs.append({ + 'run_name': config['run']['name'], + 'run_version': config['run']['version'], + 'dataset_name': dataset['name'], 
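+            # NOTE: 'dataset_workspace' is never added to this row dict, but the
+            # INSERT statement below binds :dataset_workspace, so executemany raises
+            # at run time; the next patch fixes this by copying dataset['workspace'].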
+ 'dataset_version': dataset['version'], + 'direction': 'output', + }) + with sqlite3.connect(DB_FILE) if cursor is None else nullcontext() as conn: + cur = cursor or conn.cursor() cur.executemany( """ INSERT INTO run_dataset_lineage ( - run_name - ,run_version - ,dataset_workspace - ,dataset_name - ,dataset_version - ,direction + run_name + ,run_version + ,dataset_workspace + ,dataset_name + ,dataset_version + ,direction ) VALUES (:run_name, :run_version, :dataset_workspace, :dataset_name, :dataset_version, :direction) ; """, - run_dataset_lineage_configs + dataset_configs, ) diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 66075d1..0cb4b69 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -5,9 +5,13 @@ Email: yuanmingleee@gmail.com Date: Nov 22, 2023 """ +import sqlite3 from typing import TYPE_CHECKING -from dataci.db.lineage import create_one_lineage, exist_run_lineage, exist_many_run_dataset_lineage +from dataci.config import DB_FILE +from dataci.db.lineage import ( + exist_run_lineage, exist_many_run_dataset_lineage, create_one_run_lineage, create_many_dataset_lineage +) if TYPE_CHECKING: from typing import List, Optional, Union @@ -91,37 +95,56 @@ def outputs(self) -> 'List[Dataset]': self._outputs = outputs return self._outputs - def save(self, exist_ok=False): + def save(self, exist_ok=True): config = self.dict() - run_lineage_exist = exist_run_lineage( + run_lineage_exist = (config['parent_run'] is None) or exist_run_lineage( config['run']['name'], config['run']['version'], config['parent_run'].get('name', None), config['parent_run'].get('version', None) ) - dataset_config_list = config['inputs'] + config['outputs'] dataset_lineage_exist_list = exist_many_run_dataset_lineage( config['run']['name'], config['run']['version'], - dataset_config_list + config['inputs'] + config['outputs'] ) # Check if run lineage exists + is_create_run_lineage = False if run_lineage_exist: if not exist_ok: raise ValueError(f'Run lineage {self.parent_run} -> {self.run} exists.') - else: - # Create run lineage - ... + else: + # Set create run lineage to True + is_create_run_lineage = True # Check if dataset lineage exists if any(dataset_lineage_exist_list): if not exist_ok: raise ValueError(f'Dataset lineage exists.') else: - # Create dataset lineage - ... 
- - create_one_lineage(config) - + # Remove the existed dataset lineage + config['inputs'] = [ + dataset_config for dataset_config, exist in zip( + config['inputs'], dataset_lineage_exist_list[:len(config['inputs'])] + ) if exist + ] + config['outputs'] = [ + dataset_config for dataset_config, exist in zip( + config['outputs'], dataset_lineage_exist_list[len(config['inputs']):] + ) if exist + ] + + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + # Create run lineage + if is_create_run_lineage: + create_one_run_lineage(config, cur) + # Create dataset lineage + create_many_dataset_lineage(config, cur) + + return self + + def get(self, run_name, run_version): + pass diff --git a/dataci/models/run.py b/dataci/models/run.py index 20364f7..fd901fb 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -27,6 +27,9 @@ class Run(BaseModel): # run id (uuid) NAME_PATTERN = re.compile(r'^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}$', flags=re.IGNORECASE) VERSION_PATTERN = re.compile(r'^\d+|latest$', flags=re.IGNORECASE) + GET_DATA_MODEL_IDENTIFIER_PATTERN = re.compile( + r'^(?:([a-z]\w*)\.)?([a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12})(\d+|latest$)?$', flags=re.IGNORECASE + ) type_name = 'run' def __init__( @@ -114,7 +117,7 @@ def save(self, version=None): def publish(self): warnings.warn('Run.publish(...) is not implemented. Use Run.save() instead.', DeprecationWarning) - pass + return self def update(self): # Get latest run try number diff --git a/metadata/server.py b/metadata/server.py index 8a37753..78e79c2 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -5,11 +5,12 @@ Email: yuanmingleee@gmail.com Date: Nov 20, 2023 """ +import json from typing import Union from fastapi import APIRouter, FastAPI -from dataci.models import Run as RunModel +from dataci.models import Run as RunModel, Lineage from metadata.models import RunEvent, DatasetEvent, JobEvent, RunState app = FastAPI() @@ -68,6 +69,22 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): update_time=event.eventTime, ) run.update() + + # get parent run if exists + if 'parent' in event.run.facets: + parent_run_config = { + 'name': str(event.run.facets['parent'].run['runId']), + } + else: + parent_run_config = None + + Lineage( + run=run, + parent_run=parent_run_config, + inputs=[], + outputs=[], + ).save() + return {'status': 'success'} From bb9e032660de1b1ec050547293c3f1a7d0b8103a Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Wed, 29 Nov 2023 18:38:50 +0800 Subject: [PATCH 09/20] :wrench: [lineage] Add lineage track for DataCI dataset --- dataci/db/lineage.py | 64 +++++++++++++------------- dataci/models/dataset.py | 10 +++- dataci/models/lineage.py | 35 ++++++++------ dataci/plugins/orchestrator/airflow.py | 10 +++- metadata/server.py | 13 +++++- 5 files changed, 82 insertions(+), 50 deletions(-) diff --git a/dataci/db/lineage.py b/dataci/db/lineage.py index ce0bd53..02bf45d 100644 --- a/dataci/db/lineage.py +++ b/dataci/db/lineage.py @@ -41,36 +41,42 @@ def exist_run_lineage(run_name, run_version, parent_run_name, parent_run_version def exist_many_run_dataset_lineage(run_name, run_version, dataset_configs): - - run_dataset_params = [ - { - 'run_name': run_name, - 'run_version': run_version, - 'dataset_workspace': dataset_configs['workspace'], - 'dataset_name': dataset_configs['name'], - 'dataset_version': dataset_configs['version'], - 'direction': dataset_configs['direction'], - } - for dataset_configs in dataset_configs - ] + # Return empty 
list if no dataset_configs, + # this prevents SQL syntax error when generating SQL statement + if len(dataset_configs) == 0: + return list() with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() - cur.executemany( - """ - SELECT EXISTS( - SELECT 1 + sql_dataset_values = ',\n'.join([ + repr((dataset_config['workspace'], dataset_config['name'], dataset_config['version'], + dataset_config['direction'])) + for dataset_config in dataset_configs + ]) + cur.execute( + f""" + WITH datasets (dataset_workspace, dataset_name, dataset_version, direction) AS ( + VALUES {sql_dataset_values} + ) + ,lineage AS ( + SELECT TRUE AS flg + ,dataset_workspace + ,dataset_name + ,dataset_version + ,direction FROM run_dataset_lineage WHERE run_name = :run_name AND run_version = :run_version - AND dataset_workspace = :dataset_workspace - AND dataset_name = :dataset_name - AND dataset_version = :dataset_version - AND direction = :direction - ) + ) + SELECT COALESCE(flg, FALSE) AS flg + FROM datasets + LEFT JOIN lineage USING (dataset_workspace, dataset_name, dataset_version, direction) ; """, - run_dataset_params, + { + 'run_name': run_name, + 'run_version': run_version, + } ) return [row[0] for row in cur.fetchall()] @@ -103,22 +109,16 @@ def create_one_run_lineage(config, cursor=None): def create_many_dataset_lineage(config, cursor=None): dataset_configs = list() - for dataset in config['inputs']: - dataset_configs.append({ - 'run_name': config['run']['name'], - 'run_version': config['run']['version'], - 'dataset_name': dataset['name'], - 'dataset_version': dataset['version'], - 'direction': 'input', - }) - for dataset in config['outputs']: + for dataset in config['inputs'] + config['outputs']: dataset_configs.append({ 'run_name': config['run']['name'], 'run_version': config['run']['version'], + 'dataset_workspace': dataset['workspace'], 'dataset_name': dataset['name'], 'dataset_version': dataset['version'], - 'direction': 'output', + 'direction': dataset['direction'], }) + with sqlite3.connect(DB_FILE) if cursor is None else nullcontext() as conn: cur = cursor or conn.cursor() cur.executemany( diff --git a/dataci/models/dataset.py b/dataci/models/dataset.py index 4a12ccf..97cd28a 100644 --- a/dataci/models/dataset.py +++ b/dataci/models/dataset.py @@ -341,6 +341,13 @@ def from_dict(cls, config): return dataset_obj def dict(self, id_only=False): + if id_only: + return { + 'workspace': self.workspace.name, + 'type': self.type_name, + 'name': self.name, + 'version': self.version, + } config = { 'workspace': self.workspace.name, 'name': self.name, @@ -447,7 +454,8 @@ def publish(self, version_tag=None): return self.reload(config) @classmethod - def get(cls, name: str, version=None, not_found_ok=False, file_reader='auto', file_writer='csv'): + def get(cls, name: str, workspace=None, version=None, not_found_ok=False, file_reader='auto', file_writer='csv'): + name = workspace + '.' 
+ name if workspace else name workspace, name, version_or_tag = cls.parse_data_model_get_identifier(name, version) if version_or_tag is None or cls.VERSION_TAG_PATTERN.match(version_or_tag) is not None: diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 0cb4b69..397eab6 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -6,17 +6,19 @@ Date: Nov 22, 2023 """ import sqlite3 +import warnings from typing import TYPE_CHECKING from dataci.config import DB_FILE from dataci.db.lineage import ( exist_run_lineage, exist_many_run_dataset_lineage, create_one_run_lineage, create_many_dataset_lineage ) +from dataci.models.dataset import Dataset if TYPE_CHECKING: from typing import List, Optional, Union - from dataci.models import Dataset, Workflow, Stage, Run + from dataci.models import Workflow, Stage, Run class Lineage(object): @@ -24,8 +26,8 @@ def __init__( self, run: 'Union[Run, dict]', parent_run: 'Optional[Union[Run, dict]]' = None, - inputs: 'List[Union[Dataset, dict]]' = None, - outputs: 'List[Union[Dataset, dict]]' = None, + inputs: 'List[Union[Dataset, dict, str]]' = None, + outputs: 'List[Union[Dataset, dict, str]]' = None, ): self._run = run self._parent_run = parent_run @@ -74,11 +76,12 @@ def inputs(self) -> 'List[Dataset]': """Lazy load inputs from database.""" inputs = list() for input_ in self._inputs: - if not isinstance(input_, Dataset): - dataset_id = input_['workspace'] + '.' + input_['name'] + '@' + input_['version'] - inputs.append(Dataset.get(dataset_id)) - else: + if isinstance(input_, Dataset): inputs.append(input_) + elif isinstance(input_, dict): + inputs.append(Dataset.get(**input_)) + else: + warnings.warn(f'Unable to parse input {input_}') self._inputs = inputs return self._inputs @@ -87,16 +90,22 @@ def outputs(self) -> 'List[Dataset]': """Lazy load outputs from database.""" outputs = list() for output in self._outputs: - if not isinstance(output, Dataset): - dataset_id = output['workspace'] + '.' 
+ output['name'] + '@' + output['version'] - outputs.append(Dataset.get(dataset_id)) - else: + if isinstance(output, Dataset): outputs.append(output) + elif isinstance(output, dict): + outputs.append(Dataset.get(**output)) + else: + warnings.warn(f'Unable to parse output {output}') self._outputs = outputs return self._outputs def save(self, exist_ok=True): config = self.dict() + for input_ in config['inputs']: + input_['direction'] = 'input' + for output in config['outputs']: + output['direction'] = 'output' + run_lineage_exist = (config['parent_run'] is None) or exist_run_lineage( config['run']['name'], config['run']['version'], @@ -128,12 +137,12 @@ def save(self, exist_ok=True): config['inputs'] = [ dataset_config for dataset_config, exist in zip( config['inputs'], dataset_lineage_exist_list[:len(config['inputs'])] - ) if exist + ) if not exist ] config['outputs'] = [ dataset_config for dataset_config, exist in zip( config['outputs'], dataset_lineage_exist_list[len(config['inputs']):] - ) if exist + ) if not exist ] with sqlite3.connect(DB_FILE) as conn: diff --git a/dataci/plugins/orchestrator/airflow.py b/dataci/plugins/orchestrator/airflow.py index d3331e4..4cd6f6c 100644 --- a/dataci/plugins/orchestrator/airflow.py +++ b/dataci/plugins/orchestrator/airflow.py @@ -227,6 +227,12 @@ def execute_callable(self) -> 'Any': bound.arguments[arg_name] = dataset.read() # For logging self.log.info(f'Input table {arg_name}: {dataset.identifier}') + # Load back to the input table + self.input_table[arg_name] = { + 'name': dataset.identifier, + 'file_reader': dataset.file_reader.NAME, + 'file_writer': dataset.file_writer.NAME + } self.op_args, self.op_kwargs = bound.args, bound.kwargs # Run the stage by backend @@ -237,7 +243,7 @@ def execute_callable(self) -> 'Any': if self.multiple_outputs: dataset.dataset_files = ret[key] dataset.save() - ret[key] = { + self.output_table[key] = ret[key] = { 'name': dataset.identifier, 'file_reader': dataset.file_reader.NAME, 'file_writer': dataset.file_writer.NAME @@ -245,7 +251,7 @@ def execute_callable(self) -> 'Any': else: dataset.dataset_files = ret dataset.save() - ret = { + self.output_table[key] = ret = { 'name': dataset.identifier, 'file_reader': dataset.file_reader.NAME, 'file_writer': dataset.file_writer.NAME diff --git a/metadata/server.py b/metadata/server.py index 78e79c2..c1162a2 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -6,6 +6,7 @@ Date: Nov 20, 2023 """ import json +import traceback from typing import Union from fastapi import APIRouter, FastAPI @@ -78,11 +79,19 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): else: parent_run_config = None + # Get input and output dataset + # Inputs: event.run.facets['unknownSourceAttribute'].unknownItems[0]['properties']['input_table'] + # Outputs: event.run.facets['unknownSourceAttribute'].unknownItems[0]['properties']['output_table'] + unknown_src_attr = event.run.facets.get('unknownSourceAttribute', object()) + ops_props = (getattr(unknown_src_attr, 'unknownItems', None) or [dict()])[0].get('properties', dict()) + inputs = list(ops_props.get('input_table', dict()).values()) + outputs = list(ops_props.get('output_table', dict()).values()) + Lineage( run=run, parent_run=parent_run_config, - inputs=[], - outputs=[], + inputs=inputs, + outputs=outputs, ).save() return {'status': 'success'} From 0ea3dfd131f6f7060296033677489a320b471b7c Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Thu, 30 Nov 2023 18:38:00 +0800 Subject: [PATCH 10/20] :bug: [lineage] Fix bug in output 
data not tracked by lineage --- dataci/db/lineage.py | 2 +- dataci/models/lineage.py | 6 ++++-- metadata/server.py | 7 +++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dataci/db/lineage.py b/dataci/db/lineage.py index 02bf45d..d12d6c8 100644 --- a/dataci/db/lineage.py +++ b/dataci/db/lineage.py @@ -78,7 +78,7 @@ def exist_many_run_dataset_lineage(run_name, run_version, dataset_configs): 'run_version': run_version, } ) - return [row[0] for row in cur.fetchall()] + return [bool(row[0]) for row in cur.fetchall()] def create_one_run_lineage(config, cursor=None): diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 397eab6..402c8c1 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -130,18 +130,20 @@ def save(self, exist_ok=True): # Check if dataset lineage exists if any(dataset_lineage_exist_list): + inputs_lineage_exists = dataset_lineage_exist_list[:len(config['inputs'])] + outputs_lineage_exists = dataset_lineage_exist_list[len(config['inputs']):] if not exist_ok: raise ValueError(f'Dataset lineage exists.') else: # Remove the existed dataset lineage config['inputs'] = [ dataset_config for dataset_config, exist in zip( - config['inputs'], dataset_lineage_exist_list[:len(config['inputs'])] + config['inputs'], inputs_lineage_exists ) if not exist ] config['outputs'] = [ dataset_config for dataset_config, exist in zip( - config['outputs'], dataset_lineage_exist_list[len(config['inputs']):] + config['outputs'], outputs_lineage_exists ) if not exist ] diff --git a/metadata/server.py b/metadata/server.py index c1162a2..5953949 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -5,8 +5,6 @@ Email: yuanmingleee@gmail.com Date: Nov 20, 2023 """ -import json -import traceback from typing import Union from fastapi import APIRouter, FastAPI @@ -87,12 +85,13 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): inputs = list(ops_props.get('input_table', dict()).values()) outputs = list(ops_props.get('output_table', dict()).values()) - Lineage( + lineage = Lineage( run=run, parent_run=parent_run_config, inputs=inputs, outputs=outputs, - ).save() + ) + lineage.save() return {'status': 'success'} From 19f84e78a848613164da8769fa63e6eb753756a3 Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Fri, 1 Dec 2023 18:43:14 +0800 Subject: [PATCH 11/20] :wrench: [lineage] Modify lineage data object structure --- dataci/db/init.py | 46 ++--- dataci/db/lineage.py | 224 +++++++++++++++++-------- dataci/db/run.py | 15 ++ dataci/models/lineage.py | 213 +++++++++++------------ dataci/models/run.py | 1 + dataci/plugins/orchestrator/airflow.py | 3 + metadata/server.py | 19 ++- 7 files changed, 298 insertions(+), 223 deletions(-) diff --git a/dataci/db/init.py b/dataci/db/init.py index 623cd91..23bcd51 100644 --- a/dataci/db/init.py +++ b/dataci/db/init.py @@ -17,8 +17,7 @@ # Drop all tables with db_connection: db_connection.executescript(""" - DROP TABLE IF EXISTS run_dataset_lineage; - DROP TABLE IF EXISTS run_lineage; + DROP TABLE IF EXISTS lineage; DROP TABLE IF EXISTS run; DROP TABLE IF EXISTS dataset_tag; DROP TABLE IF EXISTS dataset; @@ -166,36 +165,27 @@ update_time INTEGER, PRIMARY KEY (workspace, name, version), UNIQUE (name, version), - FOREIGN KEY (workspace, name, version, job_type) + FOREIGN KEY (workspace, name, version) + REFERENCES job (workspace, name, version), + FOREIGN KEY (job_workspace, job_name, job_version, job_type) REFERENCES job (workspace, name, version, type) ); - CREATE TABLE run_lineage - ( - run_name TEXT, - 
run_version INTEGER, - parent_run_name TEXT, - parent_run_version INTEGER, - PRIMARY KEY (run_name, run_version, parent_run_name, parent_run_version), - FOREIGN KEY (run_name, run_version) - REFERENCES run (name, version), - FOREIGN KEY (parent_run_name, parent_run_version) - REFERENCES run (name, version) - ); - - CREATE TABLE run_dataset_lineage + CREATE TABLE lineage ( - run_name TEXT, - run_version INTEGER, - dataset_workspace TEXT, - dataset_name TEXT, - dataset_version INTEGER, - direction TEXT, - PRIMARY KEY (run_name, run_version, dataset_workspace, dataset_name, dataset_version, direction), - FOREIGN KEY (run_name, run_version) - REFERENCES run (name, version), - FOREIGN KEY (dataset_workspace, dataset_name, dataset_version) - REFERENCES dataset (workspace, name, version) + upstream_workspace TEXT, + upstream_name TEXT, + upstream_version TEXT, + upstream_type TEXT, + downstream_workspace TEXT, + downstream_name TEXT, + downstream_version TEXT, + downstream_type TEXT, + PRIMARY KEY (upstream_workspace, upstream_name, upstream_version, upstream_type, downstream_workspace, downstream_name, downstream_version, downstream_type), + FOREIGN KEY (upstream_workspace, upstream_name, upstream_version, upstream_type) + REFERENCES job (workspace, name, version, type), + FOREIGN KEY (downstream_workspace, downstream_name, downstream_version, downstream_type) + REFERENCES job (workspace, name, version, type) ); """) logger.info('Create all tables.') diff --git a/dataci/db/lineage.py b/dataci/db/lineage.py index d12d6c8..c47f745 100644 --- a/dataci/db/lineage.py +++ b/dataci/db/lineage.py @@ -15,124 +15,206 @@ def get_lineage(): pass -def exist_run_lineage(run_name, run_version, parent_run_name, parent_run_version): +def exist_one_lineage(upstream_config, downstream_config): with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() cur.execute( """ SELECT EXISTS( SELECT 1 - FROM run_lineage - WHERE run_name = :run_name - AND run_version = :run_version - AND parent_run_name = :parent_run_name - AND parent_run_version = :parent_run_version + FROM lineage + WHERE upstream_workspace = :upstream_workspace + AND upstream_name = :upstream_name + AND upstream_version = :upstream_version + AND upstream_type = :upstream_type + AND downstream_workspace = :downstream_workspace + AND downstream_name = :downstream_name + AND downstream_version = :downstream_version + AND downstream_type = :downstream_type ) ; """, { - 'run_name': run_name, - 'run_version': run_version, - 'parent_run_name': parent_run_name, - 'parent_run_version': parent_run_version, + 'upstream_workspace': upstream_config['workspace'], + 'upstream_name': upstream_config['name'], + 'upstream_version': upstream_config['version'], + 'upstream_type': upstream_config['type'], + 'downstream_workspace': downstream_config['workspace'], + 'downstream_name': downstream_config['name'], + 'downstream_version': downstream_config['version'], + 'downstream_type': downstream_config['type'], } ) return cur.fetchone()[0] -def exist_many_run_dataset_lineage(run_name, run_version, dataset_configs): - # Return empty list if no dataset_configs, +def exist_many_downstream_lineage(upstream_config, downstream_configs): + # Return empty list if no upstream_configs or downstream_configs, # this prevents SQL syntax error when generating SQL statement - if len(dataset_configs) == 0: + if len(downstream_configs) == 0: return list() with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() - sql_dataset_values = ',\n'.join([ - repr((dataset_config['workspace'], 
dataset_config['name'], dataset_config['version'], - dataset_config['direction'])) - for dataset_config in dataset_configs + sql_lineage_values = ',\n'.join([ + repr(( + downstream_config['workspace'], + downstream_config['name'], + downstream_config['version'], + downstream_config['type'], + )) + for downstream_config in downstream_configs ]) cur.execute( f""" - WITH datasets (dataset_workspace, dataset_name, dataset_version, direction) AS ( - VALUES {sql_dataset_values} + WITH downstreams ( + downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + ) AS ( + VALUES {sql_lineage_values} ) - ,lineage AS ( + ,lineages AS ( SELECT TRUE AS flg - ,dataset_workspace - ,dataset_name - ,dataset_version - ,direction - FROM run_dataset_lineage - WHERE run_name = :run_name - AND run_version = :run_version + ,upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ,downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + FROM lineage + WHERE upstream_workspace = :upstream_workspace + AND upstream_name = :upstream_name + AND upstream_version = :upstream_version + AND upstream_type = :upstream_type ) SELECT COALESCE(flg, FALSE) AS flg - FROM datasets - LEFT JOIN lineage USING (dataset_workspace, dataset_name, dataset_version, direction) + FROM downstreams + LEFT JOIN lineages USING ( + downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + ) ; """, { - 'run_name': run_name, - 'run_version': run_version, + 'upstream_workspace': upstream_config['workspace'], + 'upstream_name': upstream_config['name'], + 'upstream_version': upstream_config['version'], + 'upstream_type': upstream_config['type'], } ) return [bool(row[0]) for row in cur.fetchall()] -def create_one_run_lineage(config, cursor=None): - run_lineage_config = { - 'run_name': config['run']['name'], - 'run_version': config['run']['version'], - 'parent_run_name': config['parent_run']['name'], - 'parent_run_version': config['parent_run']['version'], - } +def exist_many_upstream_lineage(upstream_configs, downstream_config): + # Return empty list if no upstream_configs or downstream_configs, + # this prevents SQL syntax error when generating SQL statement + if len(upstream_configs) == 0: + return list() - with sqlite3.connect(DB_FILE) if cursor is None else nullcontext() as conn: - cur = cursor or conn.cursor() - # add parent_run -> run + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + sql_lineage_values = ',\n'.join([ + repr(( + upstream_config['workspace'], + upstream_config['name'], + upstream_config['version'], + upstream_config['type'], + )) + for upstream_config in upstream_configs + ]) cur.execute( - """ - INSERT INTO run_lineage ( - run_name - ,run_version - ,parent_run_name - ,parent_run_version + f""" + WITH upstreams ( + upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ) AS ( + VALUES {sql_lineage_values} + ) + ,lineages AS ( + SELECT TRUE AS flg + ,upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ,downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + FROM lineage + WHERE downstream_workspace = :downstream_workspace + AND downstream_name = :downstream_name + AND downstream_version = :downstream_version + AND downstream_type = :downstream_type + ) + SELECT COALESCE(flg, FALSE) AS flg + FROM upstreams + LEFT JOIN lineages USING ( + upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type ) - VALUES (:run_name, :run_version, 
:parent_run_name, :parent_run_version) ; """, - run_lineage_config + { + 'downstream_workspace': downstream_config['workspace'], + 'downstream_name': downstream_config['name'], + 'downstream_version': downstream_config['version'], + 'downstream_type': downstream_config['type'], + } ) + return [bool(row[0]) for row in cur.fetchall()] -def create_many_dataset_lineage(config, cursor=None): - dataset_configs = list() - for dataset in config['inputs'] + config['outputs']: - dataset_configs.append({ - 'run_name': config['run']['name'], - 'run_version': config['run']['version'], - 'dataset_workspace': dataset['workspace'], - 'dataset_name': dataset['name'], - 'dataset_version': dataset['version'], - 'direction': dataset['direction'], - }) +def create_many_lineage(config): + # Permute all upstream and downstream lineage + lineage_configs = list() + for upstream_config in config['upstream']: + for downstream_config in config['downstream']: + lineage_configs.append({ + 'upstream_workspace': upstream_config['workspace'], + 'upstream_name': upstream_config['name'], + 'upstream_version': upstream_config['version'], + 'upstream_type': upstream_config['type'], + 'downstream_workspace': downstream_config['workspace'], + 'downstream_name': downstream_config['name'], + 'downstream_version': downstream_config['version'], + 'downstream_type': downstream_config['type'], + }) - with sqlite3.connect(DB_FILE) if cursor is None else nullcontext() as conn: - cur = cursor or conn.cursor() + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() cur.executemany( """ - INSERT INTO run_dataset_lineage ( - run_name - ,run_version - ,dataset_workspace - ,dataset_name - ,dataset_version - ,direction + INSERT INTO lineage ( + upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ,downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + ) + VALUES ( + :upstream_workspace + ,:upstream_name + ,:upstream_version + ,:upstream_type + ,:downstream_workspace + ,:downstream_name + ,:downstream_version + ,:downstream_type ) - VALUES (:run_name, :run_version, :dataset_workspace, :dataset_name, :dataset_version, :direction) ; """, - dataset_configs, + lineage_configs, ) diff --git a/dataci/db/run.py b/dataci/db/run.py index 1b92b8a..c807c53 100644 --- a/dataci/db/run.py +++ b/dataci/db/run.py @@ -20,6 +20,21 @@ def create_one_run(config: dict): config['job_type'] = job_config['type'] with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() + # Create job + cur.execute( + """ + INSERT INTO job ( + workspace, + name, + version, + type + ) + VALUES (:workspace, :name, :version, :type) + ; + """, + config + ) + # Create run cur.execute( """ INSERT INTO run ( diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 402c8c1..92ba8fa 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -5,41 +5,40 @@ Email: yuanmingleee@gmail.com Date: Nov 22, 2023 """ -import sqlite3 import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar -from dataci.config import DB_FILE from dataci.db.lineage import ( - exist_run_lineage, exist_many_run_dataset_lineage, create_one_run_lineage, create_many_dataset_lineage + exist_many_downstream_lineage, + exist_many_upstream_lineage, + create_many_lineage, ) from dataci.models.dataset import Dataset +from dataci.models.run import Run if TYPE_CHECKING: - from typing import List, Optional, Union + from typing import List, Union + +LineageAllowedType = TypeVar('LineageAllowedType', Dataset, Run) - 
from dataci.models import Workflow, Stage, Run class Lineage(object): def __init__( self, - run: 'Union[Run, dict]', - parent_run: 'Optional[Union[Run, dict]]' = None, - inputs: 'List[Union[Dataset, dict, str]]' = None, - outputs: 'List[Union[Dataset, dict, str]]' = None, + upstream: 'Union[List[LineageAllowedType], LineageAllowedType, dict]', + downstream: 'Union[List[LineageAllowedType], LineageAllowedType, dict]', ): - self._run = run - self._parent_run = parent_run - self._inputs: 'List[Dataset]' = inputs or list() - self._outputs: 'List[Dataset]' = outputs or list() + # only one of upstream and downstream can be list + if isinstance(upstream, list) and isinstance(downstream, list): + raise ValueError('Only one of upstream and downstream can be list.') + self._upstream = upstream if isinstance(upstream, list) else [upstream] + self._downstream = downstream if isinstance(downstream, list) else [downstream] def dict(self): return { - 'parent_run': self.parent_run.dict(id_only=True) if self.parent_run else None, - 'run': self.run.dict() if self.run else None, - 'inputs': [input_.dict(id_only=True) for input_ in self.inputs], - 'outputs': [output.dict(id_only=True) for output in self.outputs], + 'upstream': [node.dict(id_only=True) for node in self.upstream], + 'downstream': [node.dict(id_only=True) for node in self.downstream], } @classmethod @@ -47,113 +46,95 @@ def from_dict(cls, config): pass @property - def job(self) -> 'Union[Workflow, Stage]': - return self.run.job - - @property - def run(self) -> 'Run': - """Lazy load run from database.""" - from dataci.models import Run - - if not isinstance(self._run, Run): - self._run = Run.get(self._run['name']) - return self._run - - @property - def parent_run(self) -> 'Optional[Run]': - """Lazy load parent run from database.""" - from dataci.models import Run - - if self._parent_run is None: - return None - - if not isinstance(self._parent_run, Run): - self._parent_run = Run.get(self._parent_run['name']) - return self._parent_run - - @property - def inputs(self) -> 'List[Dataset]': - """Lazy load inputs from database.""" - inputs = list() - for input_ in self._inputs: - if isinstance(input_, Dataset): - inputs.append(input_) - elif isinstance(input_, dict): - inputs.append(Dataset.get(**input_)) + def upstream(self) -> 'List[LineageAllowedType]': + """Lazy load upstream from database.""" + nodes = list() + for node in self._upstream: + if isinstance(node, (Dataset, Run)): + nodes.append(node) + elif isinstance(node, dict): + node_type = node.pop('type', None) + if node_type == 'run': + node_cls = Run + nodes.append(node_cls.get(**node)) + elif node_type == 'dataset': + node_cls = Dataset + nodes.append(node_cls.get(**node)) + else: + warnings.warn(f'Unknown node type {node_type}') else: - warnings.warn(f'Unable to parse input {input_}') - self._inputs = inputs - return self._inputs + warnings.warn(f'Unable to parse upstream {node}') + self._upstream = nodes + return self._upstream @property - def outputs(self) -> 'List[Dataset]': - """Lazy load outputs from database.""" - outputs = list() - for output in self._outputs: - if isinstance(output, Dataset): - outputs.append(output) - elif isinstance(output, dict): - outputs.append(Dataset.get(**output)) + def downstream(self) -> 'List[LineageAllowedType]': + """Lazy load downstream from database.""" + downstream = list() + for node in self._downstream: + if isinstance(node, (Dataset, Run)): + downstream.append(node) + elif isinstance(node, dict): + node_type = node.pop('type', None) + if node_type 
== 'run': + node_cls = Run + downstream.append(node_cls.get(**node)) + elif node_type == 'dataset': + node_cls = Dataset + downstream.append(node_cls.get(**node)) + else: + warnings.warn(f'Unknown node type {node_type}') else: - warnings.warn(f'Unable to parse output {output}') - self._outputs = outputs - return self._outputs + warnings.warn(f'Unable to parse downstream {node}') + self._downstream = downstream + return self._downstream def save(self, exist_ok=True): config = self.dict() - for input_ in config['inputs']: - input_['direction'] = 'input' - for output in config['outputs']: - output['direction'] = 'output' - - run_lineage_exist = (config['parent_run'] is None) or exist_run_lineage( - config['run']['name'], - config['run']['version'], - config['parent_run'].get('name', None), - config['parent_run'].get('version', None) - ) - - dataset_lineage_exist_list = exist_many_run_dataset_lineage( - config['run']['name'], - config['run']['version'], - config['inputs'] + config['outputs'] - ) - - # Check if run lineage exists - is_create_run_lineage = False - if run_lineage_exist: - if not exist_ok: - raise ValueError(f'Run lineage {self.parent_run} -> {self.run} exists.') + + if len(config['upstream']) == 1: + # Check if downstream lineage exists + upstream_config = config['upstream'][0] + lineage_exist_status_list = exist_many_downstream_lineage( + upstream_config, config['downstream'], + ) + + if any(lineage_exist_status_list): + if not exist_ok: + exist_downstreams = [ + downstream_config for downstream_config, exist in zip( + config['downstream'], lineage_exist_status_list + ) if exist + ] + raise ValueError(f"Lineage exists: {upstream_config} -> {exist_downstreams}") + else: + # Remove the existed lineage + config['downstream'] = [ + node for node, exist in zip(config['downstream'], lineage_exist_status_list) if not exist + ] else: - # Set create run lineage to True - is_create_run_lineage = True - - # Check if dataset lineage exists - if any(dataset_lineage_exist_list): - inputs_lineage_exists = dataset_lineage_exist_list[:len(config['inputs'])] - outputs_lineage_exists = dataset_lineage_exist_list[len(config['inputs']):] - if not exist_ok: - raise ValueError(f'Dataset lineage exists.') - else: - # Remove the existed dataset lineage - config['inputs'] = [ - dataset_config for dataset_config, exist in zip( - config['inputs'], inputs_lineage_exists - ) if not exist - ] - config['outputs'] = [ - dataset_config for dataset_config, exist in zip( - config['outputs'], outputs_lineage_exists - ) if not exist - ] - - with sqlite3.connect(DB_FILE) as conn: - cur = conn.cursor() - # Create run lineage - if is_create_run_lineage: - create_one_run_lineage(config, cur) + # Check if upstream lineage exists + downstream_config = config['downstream'][0] + lineage_exist_status_list = exist_many_upstream_lineage( + config['upstream'], downstream_config, + ) + + if any(lineage_exist_status_list): + if not exist_ok: + exist_upstreams = [ + upstream_config for upstream_config, exist in zip( + config['upstream'], lineage_exist_status_list + ) if exist + ] + raise ValueError(f"Lineage exists: {exist_upstreams} -> {downstream_config}") + else: + # Remove the existed lineage + config['upstream'] = [ + node for node, exist in zip(config['upstream'], lineage_exist_status_list) if not exist + ] + # Create dataset lineage - create_many_dataset_lineage(config, cur) + create_many_lineage(config) return self diff --git a/dataci/models/run.py b/dataci/models/run.py index fd901fb..be19037 100644 --- 
a/dataci/models/run.py +++ b/dataci/models/run.py @@ -78,6 +78,7 @@ def dict(self, id_only=False): } return { 'workspace': self.workspace.name, + 'type': self.type_name, 'name': self.name, 'version': self.version, 'status': self.status, diff --git a/dataci/plugins/orchestrator/airflow.py b/dataci/plugins/orchestrator/airflow.py index 4cd6f6c..04df74d 100644 --- a/dataci/plugins/orchestrator/airflow.py +++ b/dataci/plugins/orchestrator/airflow.py @@ -230,6 +230,7 @@ def execute_callable(self) -> 'Any': # Load back to the input table self.input_table[arg_name] = { 'name': dataset.identifier, + 'type': dataset.type_name, 'file_reader': dataset.file_reader.NAME, 'file_writer': dataset.file_writer.NAME } @@ -245,6 +246,7 @@ def execute_callable(self) -> 'Any': dataset.save() self.output_table[key] = ret[key] = { 'name': dataset.identifier, + 'type': dataset.type_name, 'file_reader': dataset.file_reader.NAME, 'file_writer': dataset.file_writer.NAME } @@ -253,6 +255,7 @@ def execute_callable(self) -> 'Any': dataset.save() self.output_table[key] = ret = { 'name': dataset.identifier, + 'type': dataset.type_name, 'file_reader': dataset.file_reader.NAME, 'file_writer': dataset.file_writer.NAME } diff --git a/metadata/server.py b/metadata/server.py index 5953949..1456d49 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -73,6 +73,7 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): if 'parent' in event.run.facets: parent_run_config = { 'name': str(event.run.facets['parent'].run['runId']), + 'type': 'run', } else: parent_run_config = None @@ -82,23 +83,25 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): # Outputs: event.run.facets['unknownSourceAttribute'].unknownItems[0]['properties']['output_table'] unknown_src_attr = event.run.facets.get('unknownSourceAttribute', object()) ops_props = (getattr(unknown_src_attr, 'unknownItems', None) or [dict()])[0].get('properties', dict()) + # Input tables and parent run are upstream inputs = list(ops_props.get('input_table', dict()).values()) + if parent_run_config is not None: + inputs.append(parent_run_config) + # Output tables are downstream outputs = list(ops_props.get('output_table', dict()).values()) - lineage = Lineage( - run=run, - parent_run=parent_run_config, - inputs=inputs, - outputs=outputs, - ) - lineage.save() + if len(inputs) > 0: + upstream_lineage = Lineage(upstream=inputs, downstream=run) + upstream_lineage.save() + if len(outputs) > 0: + downstream_lineage = Lineage(upstream=run, downstream=outputs) + downstream_lineage.save() return {'status': 'success'} app.include_router(api_router) - if __name__ == '__main__': import uvicorn From 324a2e136df61f509d35949d9c3069c76f36c51c Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Mon, 4 Dec 2023 18:29:48 +0800 Subject: [PATCH 12/20] :bug: [lineage] Fix bug in lineage save --- dataci/models/lineage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 92ba8fa..547310b 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -133,8 +133,8 @@ def save(self, exist_ok=True): node for node, exist in zip(config['upstream'], lineage_exist_status_list) if not exist ] - # Create dataset lineage - create_many_lineage(config) + # Create dataset lineage + create_many_lineage(config) return self From 3519be507deb3016fe9cb8ae4c453c1750f98fdc Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Tue, 5 Dec 2023 18:36:11 +0800 Subject: [PATCH 13/20] :wrench: [lineage] Add lineage get 
API design --- dataci/db/lineage.py | 245 ++++++++++++++++++++++++++++++++++++--- dataci/models/base.py | 6 + dataci/models/lineage.py | 49 +++++++- dataci/models/run.py | 9 ++ 4 files changed, 292 insertions(+), 17 deletions(-) diff --git a/dataci/db/lineage.py b/dataci/db/lineage.py index c47f745..f4576e1 100644 --- a/dataci/db/lineage.py +++ b/dataci/db/lineage.py @@ -6,13 +6,226 @@ Date: Nov 27, 2023 """ import sqlite3 +from collections import OrderedDict from contextlib import nullcontext from dataci.config import DB_FILE -def get_lineage(): - pass +def get_many_upstream_lineage(downstream_config): + """Config in downstream.""" + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + FROM lineage + WHERE ( + downstream_workspace = :workspace + AND downstream_name = :name + AND downstream_version = :version + AND downstream_type = :type + ) + """, + downstream_config, + ) + return [ + { + 'workspace': row[0], + 'name': row[1], + 'version': row[2], + 'type': row[3], + } for row in cur.fetchall() + ] + + +def get_many_downstream_lineage(upstream_config): + """Config in upstream.""" + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + FROM lineage + WHERE ( + upstream_workspace = :workspace + AND upstream_name = :name + AND upstream_version = :version + AND upstream_type = :type + ) + """, + upstream_config, + ) + return [ + { + 'workspace': row[0], + 'name': row[1], + 'version': row[2], + 'type': row[3], + } for row in cur.fetchall() + ] + + +def list_many_upstream_lineage(downstream_configs): + """List all upstream lineage of downstream_configs.""" + # Return empty list if no downstream_configs, + # this prevents SQL syntax error when generating SQL statement + if len(downstream_configs) == 0: + return list() + + # Create a ordered dict to preserve the order of downstream_configs + od = OrderedDict() + for downstream_config in downstream_configs: + od[( + downstream_config['workspace'], + downstream_config['name'], + downstream_config['version'], + downstream_config['type'] + )] = list() + + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + sql_lineage_values = ',\n'.join([ + repr(( + downstream_config['workspace'], + downstream_config['name'], + downstream_config['version'], + downstream_config['type'], + )) + for downstream_config in downstream_configs + ]) + cur.execute( + f""" + WITH downstreams ( + downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + ) AS ( + VALUES {sql_lineage_values} + ) + ,lineages AS ( + SELECT upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ,downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + FROM lineage + ) + SELECT upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ,downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + FROM lineages + JOIN downstreams USING ( + downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + ) + ; + """ + ) + + for row in cur.fetchall(): + od[(row[4], row[5], row[6], row[7],)].append({ + 'workspace': row[0], + 'name': row[1], + 'version': row[2], + 'type': row[3], + }) + return list(od.values()) + + +def list_many_downstream_lineage(upstream_configs): + """List all downstream lineage of upstream_configs.""" + # Return empty 
list if no upstream_configs, + # this prevents SQL syntax error when generating SQL statement + if len(upstream_configs) == 0: + return list() + + # Create a ordered dict to preserve the order of upstream_configs + od = OrderedDict() + for upstream_config in upstream_configs: + od[( + upstream_config['workspace'], + upstream_config['name'], + upstream_config['version'], + upstream_config['type'] + )] = list() + + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + sql_lineage_values = ',\n'.join([ + repr(( + upstream_config['workspace'], + upstream_config['name'], + upstream_config['version'], + upstream_config['type'], + )) + for upstream_config in upstream_configs + ]) + cur.execute( + f""" + WITH upstreams ( + upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ) AS ( + VALUES {sql_lineage_values} + ) + ,lineages AS ( + SELECT upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ,downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + FROM lineage + ) + SELECT upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ,downstream_workspace + ,downstream_name + ,downstream_version + ,downstream_type + FROM lineages + JOIN upstreams USING ( + upstream_workspace + ,upstream_name + ,upstream_version + ,upstream_type + ) + ; + """ + ) + + for row in cur.fetchall(): + od[(row[0], row[1], row[2], row[3],)].append({ + 'workspace': row[4], + 'name': row[5], + 'version': row[6], + 'type': row[7], + }) + return list(od.values()) def exist_one_lineage(upstream_config, downstream_config): @@ -57,13 +270,13 @@ def exist_many_downstream_lineage(upstream_config, downstream_configs): with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() sql_lineage_values = ',\n'.join([ - repr(( - downstream_config['workspace'], - downstream_config['name'], - downstream_config['version'], - downstream_config['type'], - )) - for downstream_config in downstream_configs + repr(( + downstream_config['workspace'], + downstream_config['name'], + downstream_config['version'], + downstream_config['type'], + )) + for downstream_config in downstream_configs ]) cur.execute( f""" @@ -120,13 +333,13 @@ def exist_many_upstream_lineage(upstream_configs, downstream_config): with sqlite3.connect(DB_FILE) as conn: cur = conn.cursor() sql_lineage_values = ',\n'.join([ - repr(( - upstream_config['workspace'], - upstream_config['name'], - upstream_config['version'], - upstream_config['type'], - )) - for upstream_config in upstream_configs + repr(( + upstream_config['workspace'], + upstream_config['name'], + upstream_config['version'], + upstream_config['type'], + )) + for upstream_config in upstream_configs ]) cur.execute( f""" diff --git a/dataci/models/base.py b/dataci/models/base.py index 294a912..4b83cf7 100644 --- a/dataci/models/base.py +++ b/dataci/models/base.py @@ -76,6 +76,12 @@ def publish(self): def get(cls, name, version=None, not_found_ok=False): pass + def upstream(self, n=1, type=None): + pass + + def downstream(self, n=1, type=None): + pass + # @classmethod # @abc.abstractmethod # def find(cls, identifier=None, tree_view=False): diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 547310b..7832a1e 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -8,10 +8,14 @@ import warnings from typing import TYPE_CHECKING, TypeVar +import networkx as nx + from dataci.db.lineage import ( exist_many_downstream_lineage, exist_many_upstream_lineage, create_many_lineage, + get_many_upstream_lineage, 
+ get_many_downstream_lineage, ) from dataci.models.dataset import Dataset from dataci.models.run import Run @@ -138,5 +142,48 @@ def save(self, exist_ok=True): return self - def get(self, run_name, run_version): +class LineageGraph: + + def get_vertices(self): + # Retrieves all vertices from the vertices table V. + pass + + def get_edges(self): + # Retrieves all edges from the edge table E. + pass + + def get_vertex(self, vertex_id): + # Retrieves a vertex from the vertices table V by its ID. + pass + + def get_edge(self, edge_id): + # Retrieves an edge from the edge table E by its ID. pass + + @classmethod + def upstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'LineageGraph': + # Retrieves incoming edges that are connected to a vertex. + if isinstance(job, dict): + job_config = job + else: + job_config = job.dict(id_only=True) + + lineage_configs = get_many_upstream_lineage(job_config) + g = nx.DiGraph() + for lineage_config in lineage_configs: + g.add_edge(lineage_config['upstream'], lineage_config['downstream']) + return g + + @classmethod + def downstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'LineageGraph': + # Retrieves all outgoing edges that are connected to a vertex. + if isinstance(job, dict): + job_config = job + else: + job_config = job.dict(id_only=True) + + lineage_configs = get_many_downstream_lineage(job_config) + g = nx.DiGraph() + for lineage_config in lineage_configs: + g.add_edge(lineage_config['upstream'], lineage_config['downstream']) + return g diff --git a/dataci/models/run.py b/dataci/models/run.py index be19037..bd7d4fb 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -16,6 +16,7 @@ from dataci.db.run import exist_run, create_one_run, get_next_run_version, get_latest_run_version, get_one_run, \ update_one_run from dataci.models import BaseModel +from dataci.models.lineage import LineageGraph if TYPE_CHECKING: from typing import Optional, Union @@ -152,3 +153,11 @@ def get(cls, name, version=None, not_found_ok=False): raise ValueError(f'Run {name}@{version} not found.') return cls.from_dict(config) + + def upstream(self, n=1, type=None): + """Get upstream lineage.""" + return LineageGraph.upstream(self, n, type) + + def downstream(self, n=1, type=None): + """Get downstream lineage.""" + return LineageGraph.downstream(self, n, type) From 96269410f2a9128b2c5151487b76b3f3fd2cc01f Mon Sep 17 00:00:00 2001 From: YuanmingLeee Date: Thu, 7 Dec 2023 18:31:44 +0800 Subject: [PATCH 14/20] :beer: Add lineage downstream/upstream query for stage --- dataci/db/run.py | 44 ++++++++++++++++++++++ dataci/decorators/base.py | 6 +++ dataci/models/lineage.py | 78 +++++++++++++++++++++++++++++++-------- dataci/models/run.py | 36 ++++++++++++++++-- dataci/models/stage.py | 65 ++++++++++++++++++++++++++++++-- 5 files changed, 206 insertions(+), 23 deletions(-) diff --git a/dataci/db/run.py b/dataci/db/run.py index c807c53..94bacdd 100644 --- a/dataci/db/run.py +++ b/dataci/db/run.py @@ -181,3 +181,47 @@ def get_one_run(name, version='latest'): 'create_time': config[8], 'update_time': config[9], } if config else None + + +def list_run_by_job(workspace, name, version, type): + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT workspace + , name + , version + , status + , job_workspace + , job_name + , job_version + , job_type + , create_time + , update_time + FROM run + WHERE job_workspace = ? + AND job_name = ? + AND job_version = ? 
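+              -- the four predicates together match one composite job identity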
+ AND job_type = ? + ; + """, + (workspace, name, version, type) + ) + configs = cur.fetchall() + return [ + { + 'workspace': config[0], + 'name': config[1], + 'version': config[2], + 'status': config[3], + 'job': { + 'workspace': config[4], + 'name': config[5], + 'version': config[6], + 'type': config[7], + }, + 'create_time': config[8], + 'update_time': config[9], + } + for config in configs + ] diff --git a/dataci/decorators/base.py b/dataci/decorators/base.py index ef611e8..33e3fff 100644 --- a/dataci/decorators/base.py +++ b/dataci/decorators/base.py @@ -66,3 +66,9 @@ def save(self): def publish(self): self._stage.publish() return self + + def upstream(self, n=1, type=None): + return self._stage.upstream(n, type) + + def downstream(self, n=1, type=None): + return self._stage.downstream(n, type) diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 7832a1e..193c949 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -14,16 +14,16 @@ exist_many_downstream_lineage, exist_many_upstream_lineage, create_many_lineage, - get_many_upstream_lineage, - get_many_downstream_lineage, + list_many_upstream_lineage, + list_many_downstream_lineage, ) -from dataci.models.dataset import Dataset -from dataci.models.run import Run if TYPE_CHECKING: from typing import List, Union + from dataci.models.dataset import Dataset + from dataci.models.run import Run -LineageAllowedType = TypeVar('LineageAllowedType', Dataset, Run) + LineageAllowedType = TypeVar('LineageAllowedType', Dataset, Run) class Lineage(object): @@ -161,29 +161,77 @@ def get_edge(self, edge_id): pass @classmethod - def upstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'LineageGraph': - # Retrieves incoming edges that are connected to a vertex. + def upstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'nx.DiGraph': + """Retrieves incoming edges that are connected to a vertex. + """ if isinstance(job, dict): job_config = job else: job_config = job.dict(id_only=True) - lineage_configs = get_many_upstream_lineage(job_config) g = nx.DiGraph() - for lineage_config in lineage_configs: - g.add_edge(lineage_config['upstream'], lineage_config['downstream']) + g.add_node(job_config) + job_configs = [job_config] + # Retrieve upstream lineage up to n levels + for _ in range(n): + lineage_configs = list_many_upstream_lineage(job_configs) + g.add_edges_from( + (lineage_config['upstream'], lineage_config['downstream']) for lineage_config in lineage_configs + ) + # With type filter, we need to retrieve extra upstream jobs if the current job does not match the type + job_configs = [ + lineage_config for lineage_config in lineage_configs + if type is not None and lineage_config['upstream']['type'] != type + ] + + add_lineage_configs = list_many_upstream_lineage(job_configs) + g.add_edges_from( + (lineage_config['upstream'], lineage_config['downstream']) for lineage_config in add_lineage_configs + ) + # Add upstream jobs to job_configs for next iteration of lineage retrieval + job_configs = [ + lineage_config['upstream'] for lineage_config in lineage_configs + if type is None or lineage_config['upstream']['type'] == type + ] + [ + lineage_config['upstream'] for lineage_config in add_lineage_configs + ] + return g @classmethod - def downstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'LineageGraph': - # Retrieves all outgoing edges that are connected to a vertex. 
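The rewritten upstream above (and the matching downstream below) grows the lineage graph one level per loop iteration: it issues roughly one batched list_many_upstream_lineage call per level, then re-queries any node whose type does not match the requested filter. A minimal, self-contained sketch of that frontier expansion (not part of the patch; fetch_upstream is a hypothetical stand-in for the database lookup, and plain strings stand in for job configs):

import networkx as nx

def upstream_n_levels(start, fetch_upstream, n=1):
    """Expand the lineage graph one frontier (level) per iteration."""
    g = nx.DiGraph()
    g.add_node(start)
    frontier = [start]
    for _ in range(n):
        # One batched lookup per level: edges point upstream -> downstream
        edges = [(up, node) for node in frontier for up in fetch_upstream(node)]
        g.add_edges_from(edges)
        # Newly discovered upstream nodes become the next frontier
        frontier = [up for up, _ in edges]
    return g

# With the chain a -> b -> c, two levels up from 'c' recover both edges
lineage = {'c': ['b'], 'b': ['a'], 'a': []}
print(sorted(upstream_n_levels('c', lineage.get, n=2).edges()))
# [('a', 'b'), ('b', 'c')]

The sketch omits the type-filter re-query; the patch handles that by feeding mismatched nodes back into the same level's lookup before advancing.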
+ def downstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'nx.DiGraph': + """Retrieves outgoing edges that are connected to a vertex. + """ if isinstance(job, dict): job_config = job else: job_config = job.dict(id_only=True) - lineage_configs = get_many_downstream_lineage(job_config) g = nx.DiGraph() - for lineage_config in lineage_configs: - g.add_edge(lineage_config['upstream'], lineage_config['downstream']) + g.add_node(job_config) + job_configs = [job_config] + # Retrieve downstream lineage up to n levels + for _ in range(n): + lineage_configs = list_many_downstream_lineage(job_configs) + g.add_edges_from( + (lineage_config['upstream'], lineage_config['downstream']) for lineage_config in lineage_configs + ) + # With type filter, we need to retrieve extra downstream jobs if the current job does not match the type + job_configs = [ + lineage_config for lineage_config in lineage_configs + if type is not None and lineage_config['downstream']['type'] != type + ] + + add_lineage_configs = list_many_downstream_lineage(job_configs) + g.add_edges_from( + (lineage_config['upstream'], lineage_config['downstream']) for lineage_config in add_lineage_configs + ) + # Add downstream jobs to job_configs for next iteration of lineage retrieval + job_configs = [ + lineage_config['downstream'] for lineage_config in lineage_configs + if type is None or lineage_config['downstream']['type'] == type + ] + [ + lineage_config['downstream'] for lineage_config in add_lineage_configs + ] + return g diff --git a/dataci/models/run.py b/dataci/models/run.py index bd7d4fb..efa3880 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -12,9 +12,18 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING +import networkx as nx + from dataci.config import TIMEZONE -from dataci.db.run import exist_run, create_one_run, get_next_run_version, get_latest_run_version, get_one_run, \ +from dataci.db.run import ( + exist_run, + create_one_run, + get_next_run_version, + get_latest_run_version, + get_one_run, + list_run_by_job, update_one_run +) from dataci.models import BaseModel from dataci.models.lineage import LineageGraph @@ -141,7 +150,7 @@ def update(self): return self.reload(config) @classmethod - def get(cls, name, version=None, not_found_ok=False): + def get(cls, name, version=None, workspace=None, not_found_ok=False): """Get run by name and version.""" workspace, name, version = cls.parse_data_model_get_identifier(name, version) # If version not set, get the latest version @@ -154,10 +163,29 @@ def get(cls, name, version=None, not_found_ok=False): return cls.from_dict(config) + @classmethod + def find_by_job(cls, workspace, name, version, type): + """Find run by job id.""" + configs = list_run_by_job(workspace=workspace, name=name, version=version, type=type) + + return [cls.from_dict(config) for config in configs] + def upstream(self, n=1, type=None): """Get upstream lineage.""" - return LineageGraph.upstream(self, n, type) + g = LineageGraph.upstream(self, n, type) + node_mapping = {node.id: node for node in g.nodes()} + for node in g.nodes(): + if node['type'] == 'run': + node_mapping[node] = Run.get(name=node['name'], version=node['version']) + nx.relabel_nodes(g, node_mapping, copy=False) + return g def downstream(self, n=1, type=None): """Get downstream lineage.""" - return LineageGraph.downstream(self, n, type) + g = LineageGraph.downstream(self, n, type) + node_mapping = {node.id: node for node in g.nodes()} + for node in g.nodes(): + if 
node['type'] == 'run':
+                node_mapping[node] = Run.get(name=node['name'], version=node['version'])
+        nx.relabel_nodes(g, node_mapping, copy=False)
+        return g
diff --git a/dataci/models/stage.py b/dataci/models/stage.py
index 70c12c9..3be55ee 100644
--- a/dataci/models/stage.py
+++ b/dataci/models/stage.py
@@ -13,6 +13,8 @@
 from datetime import datetime
 from typing import TYPE_CHECKING
 
+import networkx as nx
+
 from dataci.db.stage import (
     create_one_stage,
     exist_stage,
@@ -184,9 +186,11 @@ def publish(self):
         return self.reload(config)
 
     @classmethod
-    def get_config(cls, name, version=None):
+    def get_config(cls, name, version=None, workspace=None):
         """Get the stage config from the workspace."""
-        workspace, name, version_or_tag = cls.parse_data_model_get_identifier(name, version)
+        workspace_, name, version_or_tag = cls.parse_data_model_get_identifier(name, version)
+        # Override workspace if provided
+        workspace = workspace or workspace_
         if version_or_tag == 'latest' or version_or_tag.startswith('v'):
             config = get_one_stage_by_tag(workspace, name, version_or_tag)
         else:
@@ -195,10 +199,12 @@ def get_config(cls, name, version=None):
         return config
 
     @classmethod
-    def get(cls, name, version=None):
+    def get(cls, name, version=None, workspace=None, not_found_ok=False):
         """Get the stage from the workspace."""
-        config = cls.get_config(name, version)
+        config = cls.get_config(name, version, workspace)
         if config is None:
+            if not not_found_ok:
+                raise ValueError(f'Stage {name}@{version} not found')
             return
         return cls.from_dict(config)
 
@@ -234,3 +240,54 @@ def find(cls, stage_identifier, tree_view=False, all=False):
             return stage_dict
 
         return stages
+
+    def upstream(self, n=1, type=None):
+        """Get the upstream stages of the stage.
+        TODO: type is misaligned with stage::upstream, only 'run' and 'dataset' are supported.
+        """
+        from dataci.models.run import Run
+
+        runs = Run.find(self.full_name, job_type=self.type_name)
+        graphs = list()
+        node_mapping = dict()
+        for run in runs:
+            g = run.upstream(n=n, type=type)
+            for node in g.nodes:
+                if isinstance(node, Run):
+                    node_mapping[node] = node._job
+            nx.relabel_nodes(g, node_mapping, copy=False)
+            graphs.append(g)
+        # Merge all graphs
+        upstream_graph = nx.compose_all(graphs)
+        # Replace the node with the stage object
+        node_mapping = {
+            node: Stage.get(name=node['name'], version=node['version'], workspace=node['workspace'])
+            for node in upstream_graph.nodes() if node['type'] == 'stage'
+        }
+        nx.relabel_nodes(upstream_graph, node_mapping, copy=False)
+        return upstream_graph
+
+    def downstream(self, n=1, type=None):
+        """Get the downstream stages of the stage.
+ """ + from dataci.models.run import Run + + runs = Run.find_by_job(workspace=self.workspace.name, name=self.name, version=self.version, type=self.type_name) + graphs = list() + node_mapping = dict() + for run in runs: + g = run.downstream(n=n, type=type) + for node in g.nodes: + if isinstance(node, Run): + node_mapping[node] = node._job + nx.relabel_nodes(g, node_mapping, copy=False) + graphs.append(g) + # Merge all graphs + downstream_graph = nx.compose_all(graphs) + # Replace the node with the stage object + node_mapping = { + node: Stage.get(name=node['name'], version=node['version'], workspace=node['workspace']) + for node in downstream_graph.nodes() if node['type'] == 'stage' + } + nx.relabel_nodes(downstream_graph, node_mapping, copy=False) + return downstream_graph From b1d9130c2f3ee468f8fc98a8c9c6907d266ddc31 Mon Sep 17 00:00:00 2001 From: yuanmingleee Date: Thu, 7 Dec 2023 23:04:18 +0800 Subject: [PATCH 15/20] :recycle: [rename] Rename class BaseModel -> Job to be consistent with db table --- dataci/db/run.py | 44 ++++++++++++++++++++++++++++++++++++++ dataci/decorators/event.py | 4 ++-- dataci/models/__init__.py | 4 ++-- dataci/models/base.py | 4 ++-- dataci/models/dataset.py | 4 ++-- dataci/models/run.py | 4 ++-- dataci/models/stage.py | 4 ++-- dataci/models/workflow.py | 4 ++-- 8 files changed, 58 insertions(+), 14 deletions(-) diff --git a/dataci/db/run.py b/dataci/db/run.py index 94bacdd..da7459b 100644 --- a/dataci/db/run.py +++ b/dataci/db/run.py @@ -225,3 +225,47 @@ def list_run_by_job(workspace, name, version, type): } for config in configs ] + + +def list_run_by_job(workspace, name, version, type): + with sqlite3.connect(DB_FILE) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT workspace + , name + , version + , status + , job_workspace + , job_name + , job_version + , job_type + , create_time + , update_time + FROM run + WHERE job_workspace = ? + AND job_name = ? + AND job_version = ? + AND job_type = ? 
+ ; + """, + (workspace, name, version, type) + ) + configs = cur.fetchall() + return [ + { + 'workspace': config[0], + 'name': config[1], + 'version': config[2], + 'status': config[3], + 'job': { + 'workspace': config[4], + 'name': config[5], + 'version': config[6], + 'type': config[7], + }, + 'create_time': config[8], + 'update_time': config[9], + } + for config in configs + ] diff --git a/dataci/decorators/event.py b/dataci/decorators/event.py index 3ad13e1..2765ef2 100644 --- a/dataci/decorators/event.py +++ b/dataci/decorators/event.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from typing import Type, Union, TypeVar, Callable - from dataci.models.base import BaseModel + from dataci.models.base import Job T = TypeVar('T', bound=Callable) @@ -18,7 +18,7 @@ def event(name: str = None, producer: str = None): def wrapper(func: 'T') -> 'T': @wraps(func) - def inner_wrapper(self: 'Union[BaseModel, Type[BaseModel]]', *args, **kwargs): + def inner_wrapper(self: 'Union[Job, Type[Job]]', *args, **kwargs): # Prevent circular import from dataci.models import Event diff --git a/dataci/models/__init__.py b/dataci/models/__init__.py index 3aa6ff4..fc05c0a 100644 --- a/dataci/models/__init__.py +++ b/dataci/models/__init__.py @@ -5,7 +5,7 @@ Email: yuanmingleee@gmail.com Date: Feb 20, 2023 """ -from .base import BaseModel +from .base import Job from .dataset import Dataset from .event import Event from .lineage import Lineage @@ -15,5 +15,5 @@ from .workspace import Workspace __all__ = [ - 'BaseModel', 'Workspace', 'Dataset', 'Event', 'Workflow', 'Stage', 'Run', 'Lineage', + 'Job', 'Workspace', 'Dataset', 'Event', 'Workflow', 'Stage', 'Run', 'Lineage', ] diff --git a/dataci/models/base.py b/dataci/models/base.py index 4b83cf7..822fd5a 100644 --- a/dataci/models/base.py +++ b/dataci/models/base.py @@ -12,7 +12,7 @@ from dataci.models.workspace import Workspace -class BaseModel(abc.ABC): +class Job(abc.ABC): NAME_PATTERN = re.compile(r'^(?:[a-z]\w*\.)?[a-z]\w*$', flags=re.IGNORECASE) VERSION_PATTERN = re.compile(r'latest|v\d+|none|[\da-f]+', flags=re.IGNORECASE) GET_DATA_MODEL_IDENTIFIER_PATTERN = re.compile( @@ -26,7 +26,7 @@ class BaseModel(abc.ABC): def __init__(self, name, *args, **kwargs): # Prevent to pass invalid arguments to object.__init__ mro = type(self).mro() - for next_cls in mro[mro.index(BaseModel) + 1:]: + for next_cls in mro[mro.index(Job) + 1:]: if '__init__' in next_cls.__dict__: break else: diff --git a/dataci/models/dataset.py b/dataci/models/dataset.py index 97cd28a..824d07e 100644 --- a/dataci/models/dataset.py +++ b/dataci/models/dataset.py @@ -32,7 +32,7 @@ ) from dataci.decorators.event import event from dataci.utils import hash_binary -from .base import BaseModel +from .base import Job if TYPE_CHECKING: from typing import Optional, Union, Type @@ -220,7 +220,7 @@ def __len__(self): } -class Dataset(BaseModel): +class Dataset(Job): type_name = 'dataset' VERSION_PATTERN = re.compile(r'latest|none|\w+', flags=re.IGNORECASE) diff --git a/dataci/models/run.py b/dataci/models/run.py index efa3880..855543b 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -24,7 +24,7 @@ list_run_by_job, update_one_run ) -from dataci.models import BaseModel +from dataci.models import Job from dataci.models.lineage import LineageGraph if TYPE_CHECKING: @@ -33,7 +33,7 @@ from dataci.models import Workflow, Stage -class Run(BaseModel): +class Run(Job): # run id (uuid) NAME_PATTERN = re.compile(r'^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}$', flags=re.IGNORECASE) 
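The run-id pattern in the hunk above deserves a quick sanity check: because every hyphen is optional, it accepts both the canonical and the compact UUID spelling. A standalone snippet (not part of the patch) demonstrating this, with invented sample ids:

import re

# Same pattern as Run.NAME_PATTERN above: each hyphen in the UUID is optional
RUN_ID = re.compile(
    r'^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}$',
    flags=re.IGNORECASE,
)

assert RUN_ID.match('1f8763e0-41f0-4916-8fa3-8e2c4ec7a4b1')  # canonical form
assert RUN_ID.match('1f8763e041f049168fa38e2c4ec7a4b1')      # compact form
assert not RUN_ID.match('not-a-run-id')

Note that the VERSION_PATTERN on the next context line is looser than it looks: in r'^\d+|latest$' each anchor binds to only one alternative, so a string such as '5abc' matches via the '^\d+' branch; r'^(?:\d+|latest)$' would reject it.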
VERSION_PATTERN = re.compile(r'^\d+|latest$', flags=re.IGNORECASE)
diff --git a/dataci/models/stage.py b/dataci/models/stage.py
index 3be55ee..3b1bf1f 100644
--- a/dataci/models/stage.py
+++ b/dataci/models/stage.py
@@ -23,7 +23,7 @@
     get_next_stage_version_tag,
     get_many_stages, create_one_stage_tag
 )
-from .base import BaseModel
+from .base import Job
 from .script import Script
 from ..utils import hash_binary, cwd
 
@@ -31,7 +31,7 @@
     from typing import Optional
 
 
-class Stage(BaseModel):
+class Stage(Job):
     """Stage mixin class.
 
     Attributes:
diff --git a/dataci/models/workflow.py b/dataci/models/workflow.py
index a7a027d..79b0de0 100644
--- a/dataci/models/workflow.py
+++ b/dataci/models/workflow.py
@@ -28,7 +28,7 @@
     get_next_workflow_version_id, create_one_workflow_tag, get_one_workflow_by_tag,
     get_one_workflow_by_version,
 )
-from .base import BaseModel
+from .base import Job
 from .event import Event
 from .script import Script
 from .stage import Stage
@@ -44,7 +44,7 @@
 logger = logging.getLogger(__name__)
 
 
-class Workflow(BaseModel, ABC):
+class Workflow(Job, ABC):
     name_arg = 'name'
     type_name = 'workflow'
 

From 64f5da019e4aeb099f586b3bbe8c16c5fe308e8f Mon Sep 17 00:00:00 2001
From: yuanmingleee
Date: Fri, 8 Dec 2023 01:51:01 +0800
Subject: [PATCH 16/20] :wrench: [refactor] Add a sub-class register for Job
 for better method overriding

---
 dataci/models/__init__.py |  4 ++++
 dataci/models/base.py     | 19 ++++++++++++++++---
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/dataci/models/__init__.py b/dataci/models/__init__.py
index fc05c0a..0f0a7e2 100644
--- a/dataci/models/__init__.py
+++ b/dataci/models/__init__.py
@@ -17,3 +17,7 @@
 __all__ = [
     'Job', 'Workspace', 'Dataset', 'Event', 'Workflow', 'Stage', 'Run', 'Lineage',
 ]
+
+
+# Register subclasses of Job
+getattr(Job, '_Job__register_job_type')()
diff --git a/dataci/models/base.py b/dataci/models/base.py
index 822fd5a..1ddade1 100644
--- a/dataci/models/base.py
+++ b/dataci/models/base.py
@@ -7,10 +7,14 @@
 """
 import abc
 import re
+from typing import TYPE_CHECKING
 
 from dataci.config import DEFAULT_WORKSPACE
 from dataci.models.workspace import Workspace
 
+if TYPE_CHECKING:
+    from typing import Dict, Type
+
 
 class Job(abc.ABC):
     NAME_PATTERN = re.compile(r'^(?:[a-z]\w*\.)?[a-z]\w*$', flags=re.IGNORECASE)
@@ -22,6 +26,7 @@
         r'^(?:([a-z]\w*)\.)?([\w:.*[\]]+?)(?:@(\d+|latest|none|\*))?$', re.IGNORECASE
     )
     type_name: str
+    __type_name_mapper__: 'Dict[str, Type[Job]]' = dict()
 
     def __init__(self, name, *args, **kwargs):
         # Prevent to pass invalid arguments to object.__init__
@@ -72,9 +77,11 @@ def publish(self):
         pass
 
     @classmethod
-    @abc.abstractmethod
-    def get(cls, name, version=None, not_found_ok=False):
-        pass
+    def get(cls, name, version=None, workspace=None, type=..., not_found_ok=False) -> 'Job':
+        subcls = cls.__type_name_mapper__.get(type, None)
+        if not subcls:
+            raise ValueError(f'Invalid type {type}')
+        return subcls.get(name, version, workspace, not_found_ok)
 
     def upstream(self, n=1, type=None):
         pass
@@ -125,3 +132,9 @@
         version = str(version or '*').lower()
 
         return workspace, name, version
+
+    @classmethod
+    def __register_job_type(cls):
+        """Register data classes to Job; this is essential for loading a data class from Job."""
+        for sub_cls in cls.__subclasses__():
+            cls.__type_name_mapper__[sub_cls.type_name] = sub_cls

From 6c04604f2dfb56d2eb63a82256d6d3bb5b420c53 Mon Sep 17 00:00:00 2001
From: yuanmingleee
Date: Sun, 10 Dec 2023 22:21:30 +0800
Subject: [PATCH 17/20] :bug:
 [lineage] Fix bug in get upstream/downstream lineage
---
 dataci/db/lineage.py     |   9 ++--
 dataci/models/base.py    |  12 +++++
 dataci/models/dataset.py |  10 ++++
 dataci/models/lineage.py | 111 +++++++++++++++++++++------------------
 dataci/models/run.py     |  27 ++--------
 dataci/models/stage.py   |  24 +++------
 dataci/utils.py          |   4 ++
 7 files changed, 99 insertions(+), 98 deletions(-)

diff --git a/dataci/db/lineage.py b/dataci/db/lineage.py
index f4576e1..e088013 100644
--- a/dataci/db/lineage.py
+++ b/dataci/db/lineage.py
@@ -7,7 +7,6 @@
 """
 import sqlite3
 from collections import OrderedDict
-from contextlib import nullcontext
 
 from dataci.config import DB_FILE
 
@@ -85,7 +84,7 @@ def list_many_upstream_lineage(downstream_configs):
         od[(
             downstream_config['workspace'],
             downstream_config['name'],
-            downstream_config['version'],
+            str(downstream_config['version']),
             downstream_config['type']
         )] = list()
 
@@ -95,7 +94,7 @@
             repr((
                 downstream_config['workspace'],
                 downstream_config['name'],
-                downstream_config['version'],
+                str(downstream_config['version']),
                 downstream_config['type'],
             ))
             for downstream_config in downstream_configs
@@ -163,7 +162,7 @@ def list_many_downstream_lineage(upstream_configs):
         od[(
             upstream_config['workspace'],
             upstream_config['name'],
-            upstream_config['version'],
+            str(upstream_config['version']),
             upstream_config['type']
         )] = list()
 
@@ -173,7 +172,7 @@
             repr((
                 upstream_config['workspace'],
                 upstream_config['name'],
-                upstream_config['version'],
+                str(upstream_config['version']),
                 upstream_config['type'],
             ))
             for upstream_config in upstream_configs
diff --git a/dataci/models/base.py b/dataci/models/base.py
index 1ddade1..d54934e 100644
--- a/dataci/models/base.py
+++ b/dataci/models/base.py
@@ -7,6 +7,7 @@
 """
 import abc
 import re
+from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 from dataci.config import DEFAULT_WORKSPACE
@@ -138,3 +139,14 @@ def __register_job_type(cls):
         """Register data classes to Job; this is essential for loading a data class from Job."""
         for sub_cls in cls.__subclasses__():
             cls.__type_name_mapper__[sub_cls.type_name] = sub_cls
+
+
+@dataclass(frozen=True)
+class JobView:
+    type: str
+    workspace: str
+    name: str
+    version: str = None
+
+    def get(self) -> 'Job':
+        return Job.get(workspace=self.workspace, name=self.name, version=self.version, type=self.type)
diff --git a/dataci/models/dataset.py b/dataci/models/dataset.py
index 824d07e..cbd6a11 100644
--- a/dataci/models/dataset.py
+++ b/dataci/models/dataset.py
@@ -33,6 +33,7 @@
 from dataci.decorators.event import event
 from dataci.utils import hash_binary
 from .base import Job
+from .lineage import LineageGraph
 
 if TYPE_CHECKING:
     from typing import Optional, Union, Type
@@ -523,3 +524,12 @@ def find(cls, dataset_identifier=None, tree_view=False, all=False):
             return dict(dataset_dict)
 
         return dataset_list
+
+    def upstream(self, n=1, type=None):
+        """Get upstream lineage."""
+        g = LineageGraph.upstream(self, n, type)
+        return g
+
+    def downstream(self, n=1, type=None):
+        """Get downstream lineage."""
+        return LineageGraph.downstream(self, n, type)
diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py
index 193c949..0795909 100644
--- a/dataci/models/lineage.py
+++ b/dataci/models/lineage.py
@@ -6,6 +6,7 @@
 Date: Nov 22, 2023
 """
 import warnings
+from itertools import chain
 from typing import TYPE_CHECKING, TypeVar
 
 import networkx as nx
@@ -17,6 +18,8 @@
     exist_many_downstream_lineage,
     exist_many_upstream_lineage,
     create_many_lineage,
     list_many_upstream_lineage,
list_many_downstream_lineage, ) +from dataci.models.base import JobView +from dataci.utils import dict_to_frozenset if TYPE_CHECKING: from typing import List, Union @@ -164,74 +167,78 @@ def get_edge(self, edge_id): def upstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'nx.DiGraph': """Retrieves incoming edges that are connected to a vertex. """ - if isinstance(job, dict): - job_config = job - else: + from dataci.models import Job + + if isinstance(job, Job): job_config = job.dict(id_only=True) + else: + job_config = job g = nx.DiGraph() - g.add_node(job_config) + g.add_node(JobView(**job_config)) job_configs = [job_config] # Retrieve upstream lineage up to n levels for _ in range(n): - lineage_configs = list_many_upstream_lineage(job_configs) - g.add_edges_from( - (lineage_config['upstream'], lineage_config['downstream']) for lineage_config in lineage_configs - ) - # With type filter, we need to retrieve extra upstream jobs if the current job does not match the type - job_configs = [ - lineage_config for lineage_config in lineage_configs - if type is not None and lineage_config['upstream']['type'] != type - ] - - add_lineage_configs = list_many_upstream_lineage(job_configs) - g.add_edges_from( - (lineage_config['upstream'], lineage_config['downstream']) for lineage_config in add_lineage_configs - ) - # Add upstream jobs to job_configs for next iteration of lineage retrieval - job_configs = [ - lineage_config['upstream'] for lineage_config in lineage_configs - if type is None or lineage_config['upstream']['type'] == type - ] + [ - lineage_config['upstream'] for lineage_config in add_lineage_configs - ] - + # (level n) -> . . + # job configs to query for next iteration, job configs to query and add to graph + job_configs, job_configs_add = list(), job_configs + # Recursively query for upstream lineage until all lineage_configs are the same `type` as the argument + while len(job_configs_add) > 0: + lineage_configs = list_many_upstream_lineage(job_configs_add) + for upstreams, downstream in zip(lineage_configs, job_configs_add): + downstream_job_view = JobView(**downstream) + g.add_edges_from((JobView(**upstream), downstream_job_view) for upstream in upstreams) + job_configs_add.clear() + # (level n+1) -> . . . x + # \/ \/ + # (level n) . . + # upstreams that are the same `type` as the argument (represented as dot ".") + # will be queried for next level of lineage + # the others (represented as cross "x") will query for next iteration in the loop, because they + # are not considered as a valid node for the current level + for upstreams in chain.from_iterable(lineage_configs): + if type is None or upstreams['type'] == type: + job_configs.append(upstreams) + else: + job_configs_add.append(upstreams) return g @classmethod def downstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'nx.DiGraph': """Retrieves outgoing edges that are connected to a vertex. 
""" - if isinstance(job, dict): - job_config = job - else: + from dataci.models import Job + + if isinstance(job, Job): job_config = job.dict(id_only=True) + else: + job_config = job g = nx.DiGraph() - g.add_node(job_config) + g.add_node(JobView(**job_config)) job_configs = [job_config] # Retrieve downstream lineage up to n levels for _ in range(n): - lineage_configs = list_many_downstream_lineage(job_configs) - g.add_edges_from( - (lineage_config['upstream'], lineage_config['downstream']) for lineage_config in lineage_configs - ) - # With type filter, we need to retrieve extra downstream jobs if the current job does not match the type - job_configs = [ - lineage_config for lineage_config in lineage_configs - if type is not None and lineage_config['downstream']['type'] != type - ] - - add_lineage_configs = list_many_downstream_lineage(job_configs) - g.add_edges_from( - (lineage_config['upstream'], lineage_config['downstream']) for lineage_config in add_lineage_configs - ) - # Add downstream jobs to job_configs for next iteration of lineage retrieval - job_configs = [ - lineage_config['downstream'] for lineage_config in lineage_configs - if type is None or lineage_config['downstream']['type'] == type - ] + [ - lineage_config['downstream'] for lineage_config in add_lineage_configs - ] - + # (level n) -> . . + # job configs to query for next iteration, job configs to query and add to graph + job_configs, job_configs_add = list(), job_configs + # Recursively query for downstream lineage until all lineage_configs are the same `type` as the argument + while len(job_configs_add) > 0: + lineage_configs = list_many_downstream_lineage(job_configs_add) + for upstream, downstreams in zip(job_configs_add, lineage_configs): + upstream_job_view = JobView(**upstream) + g.add_edges_from((upstream_job_view, JobView(**downstream)) for downstream in downstreams) + job_configs_add.clear() + # (level n) . . + # /\ /\ + # (level n+1) -> . . . x + # downstreams that are the same `type` as the argument (represented as dot ".") + # will be queried for next level of lineage + # the others (represented as cross "x") will query for next iteration in the loop, because they + # are not considered as a valid node for the current level + for downstreams in chain.from_iterable(lineage_configs): + if type is None or downstreams['type'] == type: + job_configs.append(downstreams) + else: + job_configs_add.append(downstreams) return g diff --git a/dataci/models/run.py b/dataci/models/run.py index 855543b..7042d6b 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -63,19 +63,11 @@ def try_num(self): return self.version @property - def job(self) -> 'Union[Workflow, Stage]': + def job(self) -> 'Job': """Lazy load job (workflow or stage) from database.""" - from dataci.models import Workflow, Stage from dataci.decorators.base import DecoratedOperatorStageMixin - - if not isinstance(self._job, (Workflow, Stage, DecoratedOperatorStageMixin)): - workflow_id = self._job['workspace'] + '.' 
+ self._job['name'] + '@' + self._job['version'] - if self._job['type'] == 'workflow': - self._job = Workflow.get(workflow_id) - elif self._job['type'] == 'stage': - self._job = Stage.get_by_workflow(self._job['stage_name'], workflow_id) - else: - raise ValueError(f'Invalid job type: {self._job}') + if not isinstance(self._job, (Job, DecoratedOperatorStageMixin)): + self._job = Job.get(**self._job) return self._job def dict(self, id_only=False): @@ -173,19 +165,8 @@ def find_by_job(cls, workspace, name, version, type): def upstream(self, n=1, type=None): """Get upstream lineage.""" g = LineageGraph.upstream(self, n, type) - node_mapping = {node.id: node for node in g.nodes()} - for node in g.nodes(): - if node['type'] == 'run': - node_mapping[node] = Run.get(name=node['name'], version=node['version']) - nx.relabel_nodes(g, node_mapping, copy=False) return g def downstream(self, n=1, type=None): """Get downstream lineage.""" - g = LineageGraph.downstream(self, n, type) - node_mapping = {node.id: node for node in g.nodes()} - for node in g.nodes(): - if node['type'] == 'run': - node_mapping[node] = Run.get(name=node['name'], version=node['version']) - nx.relabel_nodes(g, node_mapping, copy=False) - return g + return LineageGraph.downstream(self, n, type) diff --git a/dataci/models/stage.py b/dataci/models/stage.py index 3b1bf1f..97271d8 100644 --- a/dataci/models/stage.py +++ b/dataci/models/stage.py @@ -23,7 +23,7 @@ get_next_stage_version_tag, get_many_stages, create_one_stage_tag ) -from .base import Job +from .base import Job, JobView from .script import Script from ..utils import hash_binary, cwd @@ -247,24 +247,18 @@ def upstream(self, n=1, type=None): """ from dataci.models.run import Run - runs = Run.find(self.full_name, job_type=self.type_name) + runs = Run.find_by_job(workspace=self.workspace.name, name=self.name, version=self.version, type=self.type_name) graphs = list() node_mapping = dict() for run in runs: g = run.upstream(n=n, type=type) for node in g.nodes: - if isinstance(node, Run): - node_mapping[node] = node._job + if node.type == Run.type_name: + node_mapping[node] = JobView(**node.get()._job) nx.relabel_nodes(g, node_mapping, copy=False) graphs.append(g) # Merge all graphs upstream_graph = nx.compose_all(graphs) - # Replace the node with the stage object - node_mapping = { - node: Stage.get(name=node['name'], version=node['version'], workspace=node['workspace']) - for node in upstream_graph.nodes() if node['type'] == 'stage' - } - nx.relabel_nodes(upstream_graph, node_mapping, copy=False) return upstream_graph def downstream(self, n=1, type=None): @@ -278,16 +272,10 @@ def downstream(self, n=1, type=None): for run in runs: g = run.downstream(n=n, type=type) for node in g.nodes: - if isinstance(node, Run): - node_mapping[node] = node._job + if node.type == Run.type_name: + node_mapping[node] = JobView(**node.get()._job) nx.relabel_nodes(g, node_mapping, copy=False) graphs.append(g) # Merge all graphs downstream_graph = nx.compose_all(graphs) - # Replace the node with the stage object - node_mapping = { - node: Stage.get(name=node['name'], version=node['version'], workspace=node['workspace']) - for node in downstream_graph.nodes() if node['type'] == 'stage' - } - nx.relabel_nodes(downstream_graph, node_mapping, copy=False) return downstream_graph diff --git a/dataci/utils.py b/dataci/utils.py index 9dd3d8c..3ef5640 100644 --- a/dataci/utils.py +++ b/dataci/utils.py @@ -78,6 +78,10 @@ def hash_file(filepaths: 'Union[str, os.PathLike, List[Union[os.PathLike, str]]] return 
sha_hash.hexdigest() +def dict_to_frozenset(d): + return frozenset((k, d[k]) for k in sorted(d.keys())) + + def hash_binary(b: bytes): """ Compute the hash of a binary. From add5ea5524eafdb5065514d28c094161f3b3c354 Mon Sep 17 00:00:00 2001 From: yuanmingleee Date: Mon, 11 Dec 2023 00:18:02 +0800 Subject: [PATCH 18/20] :white_check_mark: [test] Add test for lineage (WIP) --- tests/lineage/python_ops_pipeline.py | 46 +++++++++++++++++++++++++++ tests/test_lineage.py | 47 ++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 tests/lineage/python_ops_pipeline.py create mode 100644 tests/test_lineage.py diff --git a/tests/lineage/python_ops_pipeline.py b/tests/lineage/python_ops_pipeline.py new file mode 100644 index 0000000..367f804 --- /dev/null +++ b/tests/lineage/python_ops_pipeline.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Author: Li Yuanming +Email: yuanmingleee@gmail.com +Date: Dec 10, 2023 +""" +from datetime import datetime +from dataci.plugins.decorators import dag, Dataset, stage + + +@stage +def task1(df): + return df + + +@stage +def task2_0(df): + return df + + +@stage +def task2_1(df): + return df + + +@stage +def task3(df1, df2): + import pandas as pd + + return pd.concat([df1, df2]) + + +@dag( + start_date=datetime(2020, 7, 30), schedule=None, +) +def python_ops_pipeline(): + raw_dataset_train = Dataset.get('test.yelp_review_test@latest') + dataset1 = Dataset(name='test.task1_out', dataset_files=task1(raw_dataset_train)) + dataset2_0 = Dataset(name='test.task2_0_out', dataset_files=task2_0(dataset1)) + dataset2_1 = Dataset(name='test.task2_1_out', dataset_files=task2_1(dataset1)) + dataset3 = Dataset(name='test.task3_out', dataset_files=task3(dataset2_0, dataset2_1)) + + +# Build the pipeline +python_ops_dag = python_ops_pipeline() diff --git a/tests/test_lineage.py b/tests/test_lineage.py new file mode 100644 index 0000000..e097628 --- /dev/null +++ b/tests/test_lineage.py @@ -0,0 +1,47 @@ +import unittest +from pathlib import Path + +TEST_DIR = Path(__file__).parent + + +class TestLineage(unittest.TestCase): + def setUp(self): + """Set up test fixtures. + 1. Create a test workspace and set it as the current default workspace + 2. Save and publish a test dataset + 3. Save the test pipeline + """ + from dataci.models import workspace + from dataci.models import Dataset + + workspace.DEFAULT_WORKSPACE = 'test' + self.test_dataset = Dataset('yelp_review_test', dataset_files=[ + {'date': '2020-10-05 00:44:08', 'review_id': 'HWRpzNHPqjA4pxN5863QUA', 'stars': 5.0, + 'text': "I called Anytime on Friday afternoon about the number pad lock on my front door. After several questions, the gentleman asked me if I had changed the battery.", }, + {'date': '2020-10-15 04:34:49', 'review_id': '01plHaNGM92IT0LLcHjovQ', 'stars': 5.0, + 'text': "Friend took me for lunch. Ordered the Chicken Pecan Tart although it was like a piece quiche, was absolutely delicious!", }, + {'date': '2020-10-17 06:58:09', 'review_id': '7CDDSuzoxTr4H5N4lOi9zw', 'stars': 4.0, + 'text': "I love coming here for my fruit and vegetables. It is always fresh and a great variety. The bags of already diced veggies are a huge time saver.", }, + ]) + self.assertEqual( + self.test_dataset.workspace.name, + 'test', + "Failed to set the default workspace to `test`." 
+ ) + self.test_dataset.publish(version_tag='2020-10') + + from dataci.models import Workflow + + self.workflow = Workflow.from_path( + TEST_DIR / 'lineage', + entry_path='python_ops_pipeline.py' + ) + self.workflow.publish() + self.workflow.run() + + def test_lineage(self): + self.assertEqual(True, True) # add assertion here + + +if __name__ == '__main__': + unittest.main() From 8785cdd25eba2e4a0ba075798be97a5db41aba9a Mon Sep 17 00:00:00 2001 From: yuanmingleee Date: Mon, 11 Dec 2023 00:19:04 +0800 Subject: [PATCH 19/20] :bug: [lineage] Fix bug for lineage post API Bug introduced due to change from data model get method to universal Job.get --- dataci/models/base.py | 4 +-- dataci/models/dataset.py | 11 ++++++- dataci/models/lineage.py | 47 ++++++++-------------------- dataci/models/run.py | 3 +- dataci/models/stage.py | 6 +++- dataci/models/workflow.py | 10 +++--- dataci/plugins/decorators/airflow.py | 3 +- metadata/server.py | 28 ++++++----------- 8 files changed, 50 insertions(+), 62 deletions(-) diff --git a/dataci/models/base.py b/dataci/models/base.py index d54934e..073272c 100644 --- a/dataci/models/base.py +++ b/dataci/models/base.py @@ -78,11 +78,11 @@ def publish(self): pass @classmethod - def get(cls, name, version=None, workspace=None, type=..., not_found_ok=False) -> 'Job': + def get(cls, name, version=None, workspace=None, type=..., **kwargs) -> 'Job': subcls = cls.__type_name_mapper__.get(type, None) if not subcls: raise ValueError(f'Invalid type {type}') - return subcls.get(name, version, workspace, not_found_ok) + return subcls.get(name=name, version=version, workspace=workspace, **kwargs) def upstream(self, n=1, type=None): pass diff --git a/dataci/models/dataset.py b/dataci/models/dataset.py index cbd6a11..d119723 100644 --- a/dataci/models/dataset.py +++ b/dataci/models/dataset.py @@ -455,7 +455,16 @@ def publish(self, version_tag=None): return self.reload(config) @classmethod - def get(cls, name: str, workspace=None, version=None, not_found_ok=False, file_reader='auto', file_writer='csv'): + def get( + cls, + name: str, + workspace=None, + version=None, + not_found_ok=False, + file_reader='auto', + file_writer='csv', + **kwargs, + ): name = workspace + '.' 
+ name if workspace else name workspace, name, version_or_tag = cls.parse_data_model_get_identifier(name, version) diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 0795909..73b6f00 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -7,7 +7,7 @@ """ import warnings from itertools import chain -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING import networkx as nx @@ -18,23 +18,18 @@ list_many_upstream_lineage, list_many_downstream_lineage, ) -from dataci.models.base import JobView -from dataci.utils import dict_to_frozenset +from dataci.models.base import Job, JobView if TYPE_CHECKING: from typing import List, Union - from dataci.models.dataset import Dataset - from dataci.models.run import Run - - LineageAllowedType = TypeVar('LineageAllowedType', Dataset, Run) class Lineage(object): def __init__( self, - upstream: 'Union[List[LineageAllowedType], LineageAllowedType, dict]', - downstream: 'Union[List[LineageAllowedType], LineageAllowedType, dict]', + upstream: 'Union[List[Job], Job, dict]', + downstream: 'Union[List[Job], Job, dict]', ): # only one of upstream and downstream can be list if isinstance(upstream, list) and isinstance(downstream, list): @@ -53,47 +48,31 @@ def from_dict(cls, config): pass @property - def upstream(self) -> 'List[LineageAllowedType]': + def upstream(self) -> 'List[Job]': """Lazy load upstream from database.""" nodes = list() for node in self._upstream: - if isinstance(node, (Dataset, Run)): + if isinstance(node, Job): nodes.append(node) elif isinstance(node, dict): - node_type = node.pop('type', None) - if node_type == 'run': - node_cls = Run - nodes.append(node_cls.get(**node)) - elif node_type == 'dataset': - node_cls = Dataset - nodes.append(node_cls.get(**node)) - else: - warnings.warn(f'Unknown node type {node_type}') + nodes.append(Job.get(**node)) else: warnings.warn(f'Unable to parse upstream {node}') self._upstream = nodes return self._upstream @property - def downstream(self) -> 'List[LineageAllowedType]': + def downstream(self) -> 'List[Job]': """Lazy load downstream from database.""" - downstream = list() + nodes = list() for node in self._downstream: - if isinstance(node, (Dataset, Run)): - downstream.append(node) + if isinstance(node, Job): + nodes.append(node) elif isinstance(node, dict): - node_type = node.pop('type', None) - if node_type == 'run': - node_cls = Run - downstream.append(node_cls.get(**node)) - elif node_type == 'dataset': - node_cls = Dataset - downstream.append(node_cls.get(**node)) - else: - warnings.warn(f'Unknown node type {node_type}') + nodes.append(Job.get(**node)) else: warnings.warn(f'Unable to parse downstream {node}') - self._downstream = downstream + self._downstream = nodes return self._downstream def save(self, exist_ok=True): diff --git a/dataci/models/run.py b/dataci/models/run.py index 7042d6b..5460740 100644 --- a/dataci/models/run.py +++ b/dataci/models/run.py @@ -51,6 +51,7 @@ def __init__( update_time: 'Optional[datetime]' = None, **kwargs ): + # TODO: get workspace from job super().__init__(name, **kwargs) self.status: str = status self._job = job @@ -142,7 +143,7 @@ def update(self): return self.reload(config) @classmethod - def get(cls, name, version=None, workspace=None, not_found_ok=False): + def get(cls, name, version=None, workspace=None, not_found_ok=False, **kwargs): """Get run by name and version.""" workspace, name, version = cls.parse_data_model_get_identifier(name, version) # If version not set, get the latest version diff --git 
a/dataci/models/stage.py b/dataci/models/stage.py index 97271d8..21289d1 100644 --- a/dataci/models/stage.py +++ b/dataci/models/stage.py @@ -199,7 +199,7 @@ def get_config(cls, name, version=None, workspace=None): return config @classmethod - def get(cls, name, version=None, workspace=None, not_found_ok=False): + def get(cls, name, version=None, workspace=None, not_found_ok=False, **kwargs): """Get the stage from the workspace.""" config = cls.get_config(name, version, workspace) if config is None: @@ -225,6 +225,10 @@ def get_by_workflow(cls, stage_name, workflow_name, workflow_version=None): else: raise ValueError(f'Stage {stage_name} not found in workflow {workflow_name}@{workflow_version}') + if '.' not in stage_name: + stage_workspace = workflow_config['workspace'] + stage_name = stage_workspace + '.' + stage_name + return cls.get(stage_name, stage_version) @classmethod diff --git a/dataci/models/workflow.py b/dataci/models/workflow.py index 79b0de0..f0562a6 100644 --- a/dataci/models/workflow.py +++ b/dataci/models/workflow.py @@ -345,9 +345,11 @@ def publish(self): return self.reload(config) @classmethod - def get_config(cls, name: str, version: str = None): + def get_config(cls, name: str, version: str = None, workspace: str = None): """Get workflow config only""" - workspace, name, version = cls.parse_data_model_get_identifier(name, version) + workspace_, name, version = cls.parse_data_model_get_identifier(name, version) + # Override the workspace if specified + workspace = workspace or workspace_ if version is None or version == 'latest' or version.startswith('v'): # Get by tag @@ -360,9 +362,9 @@ def get_config(cls, name: str, version: str = None): return config @classmethod - def get(cls, name: str, version: str = None): + def get(cls, name: str, version: str = None, workspace: str = None, not_found_ok=False, **kwargs): """Get a models from the workspace.""" - config = cls.get_config(name, version) + config = cls.get_config(name=name, version=version, workspace=workspace) if config is None: return diff --git a/dataci/plugins/decorators/airflow.py b/dataci/plugins/decorators/airflow.py index 9cd320a..8a2b6b6 100644 --- a/dataci/plugins/decorators/airflow.py +++ b/dataci/plugins/decorators/airflow.py @@ -60,7 +60,8 @@ def __call__(self, *args, **kwargs): self._stage.input_table[key] = ... elif isinstance(arg, _Dataset): # arg is a DataCI dataset self._stage.input_table[key] = { - 'name': arg.identifier, 'file_reader': arg.file_reader.NAME, 'file_writer': arg.file_writer.NAME + 'name': arg.identifier, 'file_reader': arg.file_reader.NAME, 'file_writer': arg.file_writer.NAME, + 'type': arg.type_name, } # Rewrite the argument with the dataset identifier bound.arguments[key] = arg.identifier diff --git a/metadata/server.py b/metadata/server.py index 1456d49..3fa9b92 100644 --- a/metadata/server.py +++ b/metadata/server.py @@ -9,7 +9,7 @@ from fastapi import APIRouter, FastAPI -from dataci.models import Run as RunModel, Lineage +from dataci.models import Run as RunModel, Lineage, Stage from metadata.models import RunEvent, DatasetEvent, JobEvent, RunState app = FastAPI() @@ -30,27 +30,25 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): # Parse job type if '.' 
in event.job.name: - job_type = 'stage' # Get job job_workspace, job_name, job_version = name_parts job_version, stage_name = job_version.split('.') + job = Stage.get_by_workflow(stage_name, f'{job_workspace}.{job_name}@{job_version}') else: - job_type = 'workflow' job_workspace, job_name, job_version = name_parts - stage_name = None + job = { + 'workspace': job_workspace, + 'type': 'workflow', + 'name': job_name, + 'version': job_version, + } # If event type is START, create a new run if event.eventType == RunState.START: run = RunModel( name=str(event.run.runId), status=event.eventType.value, - job={ - 'workspace': job_workspace, - 'type': job_type, - 'name': job_name, - 'version': job_version, - 'stage_name': stage_name, - }, + job=job, create_time=event.eventTime, ) run.save() @@ -58,13 +56,7 @@ def post_lineage(event: Union[RunEvent, DatasetEvent, JobEvent]): run = RunModel( name=str(event.run.runId), status=event.eventType.value, - job={ - 'workspace': job_workspace, - 'type': job_type, - 'name': job_name, - 'version': job_version, - 'stage_name': stage_name, - }, + job=job, update_time=event.eventTime, ) run.update() From 47f5631e271375ad603107eff3f5dadebf1fcba6 Mon Sep 17 00:00:00 2001 From: yuanmingleee Date: Mon, 11 Dec 2023 00:46:21 +0800 Subject: [PATCH 20/20] :bug: [lineage] Fix bug for lineage upstream/downstream due to common query for multiple times --- dataci/models/base.py | 5 +++- dataci/models/lineage.py | 58 +++++++++++++++++++++------------------- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/dataci/models/base.py b/dataci/models/base.py index 073272c..c2faea8 100644 --- a/dataci/models/base.py +++ b/dataci/models/base.py @@ -7,7 +7,7 @@ """ import abc import re -from dataclasses import dataclass +from dataclasses import dataclass, asdict from typing import TYPE_CHECKING from dataci.config import DEFAULT_WORKSPACE @@ -150,3 +150,6 @@ class JobView: def get(self) -> 'Job': return Job.get(workspace=self.workspace, name=self.name, version=self.version, type=self.type) + + def dict(self): + return asdict(self) diff --git a/dataci/models/lineage.py b/dataci/models/lineage.py index 73b6f00..6292033 100644 --- a/dataci/models/lineage.py +++ b/dataci/models/lineage.py @@ -5,6 +5,7 @@ Email: yuanmingleee@gmail.com Date: Nov 22, 2023 """ +import dataclasses import warnings from itertools import chain from typing import TYPE_CHECKING @@ -152,38 +153,40 @@ def upstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 's job_config = job.dict(id_only=True) else: job_config = job + job_view = JobView(**job_config) g = nx.DiGraph() - g.add_node(JobView(**job_config)) - job_configs = [job_config] + g.add_node(job_view) + job_views = {job_view} # Retrieve upstream lineage up to n levels for _ in range(n): # (level n) -> . . # job configs to query for next iteration, job configs to query and add to graph - job_configs, job_configs_add = list(), job_configs + job_views, job_views_add = set(), job_views # Recursively query for upstream lineage until all lineage_configs are the same `type` as the argument - while len(job_configs_add) > 0: - lineage_configs = list_many_upstream_lineage(job_configs_add) - for upstreams, downstream in zip(lineage_configs, job_configs_add): - downstream_job_view = JobView(**downstream) - g.add_edges_from((JobView(**upstream), downstream_job_view) for upstream in upstreams) - job_configs_add.clear() - # (level n+1) -> . . . 
x + while len(job_views_add) > 0: + lineage_configs = list_many_upstream_lineage([job_view.dict() for job_view in job_views_add]) + for upstream_job_view, upstreams in zip(job_views_add, lineage_configs): + g.add_edges_from((JobView(**upstream), upstream_job_view) for upstream in upstreams) + job_views_add.clear() + # (level n-1) -> . . . x # \/ \/ # (level n) . . # upstreams that are the same `type` as the argument (represented as dot ".") # will be queried for next level of lineage # the others (represented as cross "x") will query for next iteration in the loop, because they # are not considered as a valid node for the current level - for upstreams in chain.from_iterable(lineage_configs): - if type is None or upstreams['type'] == type: - job_configs.append(upstreams) + for upstream in chain.from_iterable(lineage_configs): + upstream_job_view = JobView(**upstream) + if type is None or upstream_job_view.type == type: + job_views.add(upstream_job_view) else: - job_configs_add.append(upstreams) + job_views_add.add(upstream_job_view) return g + @classmethod - def downstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: 'str' = None) -> 'nx.DiGraph': + def downstream(cls, job: 'Union[Job, dict]', n: 'int' = 1, type: 'str' = None) -> 'nx.DiGraph': """Retrieves outgoing edges that are connected to a vertex. """ from dataci.models import Job @@ -192,22 +195,22 @@ def downstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: job_config = job.dict(id_only=True) else: job_config = job + job_view = JobView(**job_config) g = nx.DiGraph() - g.add_node(JobView(**job_config)) - job_configs = [job_config] + g.add_node(job_view) + job_views = {job_view} # Retrieve downstream lineage up to n levels for _ in range(n): # (level n) -> . . # job configs to query for next iteration, job configs to query and add to graph - job_configs, job_configs_add = list(), job_configs + job_views, job_views_add = set(), job_views # Recursively query for downstream lineage until all lineage_configs are the same `type` as the argument - while len(job_configs_add) > 0: - lineage_configs = list_many_downstream_lineage(job_configs_add) - for upstream, downstreams in zip(job_configs_add, lineage_configs): - upstream_job_view = JobView(**upstream) + while len(job_views_add) > 0: + lineage_configs = list_many_downstream_lineage([job_view.dict() for job_view in job_views_add]) + for upstream_job_view, downstreams in zip(job_views_add, lineage_configs): g.add_edges_from((upstream_job_view, JobView(**downstream)) for downstream in downstreams) - job_configs_add.clear() + job_views_add.clear() # (level n) . . # /\ /\ # (level n+1) -> . . . x @@ -215,9 +218,10 @@ def downstream(cls, job: 'Union[LineageAllowedType, dict]', n: 'int' = 1, type: # will be queried for next level of lineage # the others (represented as cross "x") will query for next iteration in the loop, because they # are not considered as a valid node for the current level - for downstreams in chain.from_iterable(lineage_configs): - if type is None or downstreams['type'] == type: - job_configs.append(downstreams) + for downstream in chain.from_iterable(lineage_configs): + downstream_job_view = JobView(**downstream) + if type is None or downstream_job_view.type == type: + job_views.add(downstream_job_view) else: - job_configs_add.append(downstreams) + job_views_add.add(downstream_job_view) return g
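A closing observation on the node type the series converges on: the earlier revisions passed raw config dicts to g.add_edge, which networkx rejects because dicts are unhashable, and that is exactly the failure mode the last three patches work around with the frozen JobView dataclass. A standalone sketch of why that node type works (not part of the patches; the sample values are invented for illustration):

from dataclasses import dataclass, asdict

import networkx as nx


@dataclass(frozen=True)
class JobView:
    type: str
    workspace: str
    name: str
    version: str = None


g = nx.DiGraph()
up = JobView(type='dataset', workspace='test', name='yelp_review_test', version='v1')
down = JobView(type='run', workspace='test',
               name='1f8763e0-41f0-4916-8fa3-8e2c4ec7a4b1', version='1')
g.add_edge(up, down)

# Frozen dataclasses hash by field values, so an equal view finds the same node
assert JobView(type='dataset', workspace='test',
               name='yelp_review_test', version='v1') in g

# Each node also round-trips into the dict shape the db layer consumes
print(asdict(up))  # {'type': 'dataset', 'workspace': 'test', ...}

Freezing the dataclass is what makes a view usable both as a key in node_mapping and as a graph node, while asdict() (wrapped by JobView.dict in the patch) recovers the config dict that list_many_upstream_lineage and list_many_downstream_lineage expect.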