diff --git a/CHANGELOG.md b/CHANGELOG.md index 19d2fec4..7ef02a62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed +- **Top-level imports optimized**: Key classes are now importable directly from `graflo`: + - `GraphEngine`, `IngestionParams` promoted to top-level alongside existing `Caster` + - Architecture classes `Resource`, `Vertex`, `VertexConfig`, `Edge`, `EdgeConfig`, `FieldType` now at top-level + - `FilterExpression` promoted to top-level (alongside existing `ComparisonOperator`, `LogicalOperator`) + - `InMemoryDataSource` added to top-level data-source exports + - Import groups reorganized: orchestration, architecture, data sources, database, filters, enums & utilities +- **`graflo.filter` package exports**: `FilterExpression`, `ComparisonOperator`, and `LogicalOperator` are now re-exported from `graflo.filter.__init__` (previously only available via `graflo.filter.onto`) + +### Documentation +- Added **Mermaid class diagrams** to Concepts page showing: + - `GraphEngine` orchestration: how `GraphEngine` delegates to `InferenceManager`, `ResourceMapper`, `Caster`, and `ConnectionManager` + - `Schema` architecture: the full hierarchy from `Schema` through `VertexConfig`/`EdgeConfig`, `Resource`, `Actor` subtypes, `Field`, and `FilterExpression` + - `Caster` ingestion pipeline: how `Caster` coordinates `RegistryBuilder`, `DataSourceRegistry`, `DBWriter`, `GraphContainer`, and `ConnectionManager` +- Enabled Mermaid rendering in mkdocs configuration +- Updated top-level package docstring with modern usage example (`GraphEngine` workflow) + ## [1.5.0] - 2026-02-02 ### Added diff --git a/docs/concepts/index.md b/docs/concepts/index.md index 5b56f501..3f5f7447 100644 --- a/docs/concepts/index.md +++ b/docs/concepts/index.md @@ -10,6 +10,266 @@ graflo transforms data sources into property graphs through a pipeline of compon Each component plays a specific role in this transformation process. +## Class Diagrams + +### GraphEngine orchestration + +`GraphEngine` is the top-level orchestrator that coordinates schema inference, +pattern creation, schema definition, and data ingestion. The diagram below shows +how it delegates to specialised components. + +```mermaid +classDiagram + direction TB + + class GraphEngine { + +target_db_flavor: DBType + +resource_mapper: ResourceMapper + +introspect(postgres_config) SchemaIntrospectionResult + +infer_schema(postgres_config) Schema + +create_patterns(postgres_config) Patterns + +define_schema(schema, target_db_config) + +define_and_ingest(schema, target_db_config, ...) + +ingest(schema, target_db_config, ...) + } + + class InferenceManager { + +conn: PostgresConnection + +target_db_flavor: DBType + +introspect(schema_name) SchemaIntrospectionResult + +infer_complete_schema(schema_name) Schema + } + + class ResourceMapper { + +create_patterns_from_postgres(conn, ...) Patterns + } + + class Caster { + +schema: Schema + +ingestion_params: IngestionParams + +ingest(target_db_config, patterns, ...) 
+ } + + class ConnectionManager { + +connection_config: DBConfig + +init_db(schema, recreate_schema) + +clear_data(schema) + } + + class Schema { + «see Schema diagram» + } + + class Patterns { + +file_patterns: list~FilePattern~ + +table_patterns: list~TablePattern~ + } + + class DBConfig { + <> + +uri: str + +effective_schema: str? + +connection_type: DBType + } + + GraphEngine --> InferenceManager : creates for introspect / infer_schema + GraphEngine --> ResourceMapper : resource_mapper + GraphEngine --> Caster : creates for ingest + GraphEngine --> ConnectionManager : creates for define_schema + GraphEngine ..> Schema : produces / consumes + GraphEngine ..> Patterns : produces / consumes + GraphEngine ..> DBConfig : target_db_config +``` + +### Schema architecture + +`Schema` is the central configuration object that defines how data is +transformed into a property graph. The diagram below shows its constituent +parts and their relationships. + +```mermaid +classDiagram + direction TB + + class Schema { + +general: SchemaMetadata + +vertex_config: VertexConfig + +edge_config: EdgeConfig + +resources: list~Resource~ + +transforms: dict~str,ProtoTransform~ + +finish_init() + +fetch_resource(name) Resource + +remove_disconnected_vertices() + } + + class SchemaMetadata { + +name: str + +version: str? + +description: str? + } + + class VertexConfig { + +vertices: list~Vertex~ + +blank_vertices: list~Vertex~ + +db_flavor: DBType? + } + + class Vertex { + +name: str + +indexes: list~list~str~~ + +fields: list~Field~ + +filters: FilterExpression? + +dbname: str? + } + + class Field { + +name: str + +type: FieldType? + } + + class EdgeConfig { + +edges: list~Edge~ + +extra_edges: list~Edge~ + } + + class Edge { + +source: str + +target: str + +indexes: list~str~ + +weights: WeightConfig? + +relation: str? + +relation_field: str? + +filters: FilterExpression? + } + + class Resource { + +name: str + +root: ActorWrapper + +finish_init(vertex_config, edge_config, transforms) + } + + class ActorWrapper { + +actor: Actor + +children: list~ActorWrapper~ + } + note for ActorWrapper "Recursive tree: each
child is an ActorWrapper" + + class Actor { + <> + } + class VertexActor + class EdgeActor + class TransformActor + class DescendActor + + class ProtoTransform { + +name: str + } + + class FilterExpression { + +kind: leaf | composite + +from_dict(data) FilterExpression + } + + Schema *-- SchemaMetadata : general + Schema *-- VertexConfig : vertex_config + Schema *-- EdgeConfig : edge_config + Schema *-- "0..*" Resource : resources + Schema *-- "0..*" ProtoTransform : transforms + + VertexConfig *-- "0..*" Vertex : vertices + Vertex *-- "0..*" Field : fields + Vertex --> FilterExpression : filters + + EdgeConfig *-- "0..*" Edge : edges + Edge --> FilterExpression : filters + + Resource *-- ActorWrapper : root + ActorWrapper --> Actor : actor + + Actor <|-- VertexActor + Actor <|-- EdgeActor + Actor <|-- TransformActor + Actor <|-- DescendActor +``` + +### Caster ingestion pipeline + +`Caster` is the ingestion workhorse. It builds a `DataSourceRegistry` via +`RegistryBuilder`, casts each batch of source data into a `GraphContainer`, +and hands that container to `DBWriter` which pushes vertices and edges to the +target database through `ConnectionManager`. + +```mermaid +classDiagram + direction LR + + class Caster { + +schema: Schema + +ingestion_params: IngestionParams + +ingest(target_db_config, patterns, ...) + +cast_normal_resource(data, resource_name) GraphContainer + +process_batch(batch, resource_name, conn_conf) + +process_data_source(data_source, ...) + +ingest_data_sources(registry, conn_conf, ...) + } + + class IngestionParams { + +clear_data: bool + +n_cores: int + +batch_size: int + +max_items: int? + +dry: bool + +datetime_after: str? + +datetime_before: str? + +datetime_column: str? + } + + class RegistryBuilder { + +schema: Schema + +build(patterns, ingestion_params) DataSourceRegistry + } + + class DataSourceRegistry { + +register(data_source, resource_name) + +get_data_sources(resource_name) list~AbstractDataSource~ + } + + class DBWriter { + +schema: Schema + +dry: bool + +max_concurrent: int + +write(gc, conn_conf, resource_name) + } + + class GraphContainer { + +vertices: dict + +edges: dict + +from_docs_list(docs) GraphContainer + } + + class ConnectionManager { + +connection_config: DBConfig + +upsert_docs_batch(...) + +insert_edges_batch(...) + } + + class AbstractDataSource { + <> + +resource_name: str? 
+ +iter_batches(batch_size, limit) + } + + Caster --> IngestionParams : ingestion_params + Caster --> RegistryBuilder : creates + RegistryBuilder --> DataSourceRegistry : builds + Caster --> DBWriter : creates per batch + Caster ..> GraphContainer : produces + DBWriter ..> GraphContainer : consumes + DBWriter --> ConnectionManager : opens connections + DataSourceRegistry o-- "0..*" AbstractDataSource : contains +``` + ### Data Sources vs Resources It's important to understand the distinction between **Data Sources** and **Resources**: diff --git a/docs/reference/filter/sql.md b/docs/reference/filter/sql.md new file mode 100644 index 00000000..71760c2a --- /dev/null +++ b/docs/reference/filter/sql.md @@ -0,0 +1,3 @@ +# `graflo.filter.sql` + +::: graflo.filter.sql diff --git a/docs/reference/hq/auto_join.md b/docs/reference/hq/auto_join.md new file mode 100644 index 00000000..be90bedb --- /dev/null +++ b/docs/reference/hq/auto_join.md @@ -0,0 +1,3 @@ +# `graflo.hq.auto_join` + +::: graflo.hq.auto_join diff --git a/docs/reference/hq/db_writer.md b/docs/reference/hq/db_writer.md new file mode 100644 index 00000000..98d815ce --- /dev/null +++ b/docs/reference/hq/db_writer.md @@ -0,0 +1,3 @@ +# `graflo.hq.db_writer` + +::: graflo.hq.db_writer diff --git a/docs/reference/hq/registry_builder.md b/docs/reference/hq/registry_builder.md new file mode 100644 index 00000000..c968e4ff --- /dev/null +++ b/docs/reference/hq/registry_builder.md @@ -0,0 +1,3 @@ +# `graflo.hq.registry_builder` + +::: graflo.hq.registry_builder diff --git a/graflo/__init__.py b/graflo/__init__.py index 7589f2ba..6073562e 100644 --- a/graflo/__init__.py +++ b/graflo/__init__.py @@ -1,26 +1,41 @@ """graflo: A flexible graph database abstraction layer. graflo provides a unified interface for working with different graph databases -(ArangoDB, Neo4j) through a common API. It handles graph operations, data -transformations, and query generation while abstracting away database-specific -details. +(ArangoDB, Neo4j, TigerGraph, FalkorDB, Memgraph) through a common API. +It handles graph operations, data transformations, and query generation while +abstracting away database-specific details. Key Features: - Database-agnostic graph operations - - Flexible schema management + - Flexible schema management with typed fields + - Automatic schema inference from PostgreSQL databases - Query generation and execution - Data transformation utilities - Filter expression system Example: - >>> from graflo.db.manager import ConnectionManager - >>> with ConnectionManager(config) as conn: - ... conn.init_db(schema, recreate_schema=True) - ... 
conn.upsert_docs_batch(docs, "users") + >>> from graflo import GraphEngine, Schema, IngestionParams + >>> engine = GraphEngine() + >>> schema = engine.infer_schema(postgres_config) + >>> engine.define_and_ingest(schema, target_db_config) """ -from .architecture import Index, Schema -from graflo.hq.caster import Caster +# --- Core orchestration --------------------------------------------------- +from .hq import Caster, GraphEngine, IngestionParams + +# --- Architecture ---------------------------------------------------------- +from .architecture import ( + Edge, + EdgeConfig, + FieldType, + Index, + Resource, + Schema, + Vertex, + VertexConfig, +) + +# --- Data sources ---------------------------------------------------------- from .data_source import ( APIConfig, APIDataSource, @@ -29,6 +44,7 @@ DataSourceRegistry, DataSourceType, FileDataSource, + InMemoryDataSource, JsonFileDataSource, JsonlFileDataSource, PaginationConfig, @@ -36,36 +52,58 @@ SQLDataSource, TableFileDataSource, ) + +# --- Database -------------------------------------------------------------- from .db import ConnectionManager, ConnectionType -from .filter.onto import ComparisonOperator, LogicalOperator + +# --- Filters --------------------------------------------------------------- +from .filter import ComparisonOperator, FilterExpression, LogicalOperator + +# --- Enums & utilities ----------------------------------------------------- from .onto import AggregationType, DBType from .util.onto import FilePattern, Patterns, ResourcePattern, TablePattern __all__ = [ + # Orchestration + "GraphEngine", + "Caster", + "IngestionParams", + # Architecture + "Schema", + "Resource", + "Vertex", + "VertexConfig", + "Edge", + "EdgeConfig", + "FieldType", + "Index", + # Data sources "AbstractDataSource", "APIConfig", "APIDataSource", - "AggregationType", - "ComparisonOperator", - "ConnectionManager", - "ConnectionType", - "Caster", "DataSourceFactory", "DataSourceRegistry", "DataSourceType", - "DBType", "FileDataSource", - "FilePattern", - "Index", + "InMemoryDataSource", "JsonFileDataSource", "JsonlFileDataSource", - "LogicalOperator", "PaginationConfig", - "Patterns", - "ResourcePattern", - "Schema", "SQLConfig", "SQLDataSource", "TableFileDataSource", + # Database + "ConnectionManager", + "ConnectionType", + # Filters + "ComparisonOperator", + "FilterExpression", + "LogicalOperator", + # Enums & utilities + "AggregationType", + "DBType", + "FilePattern", + "Patterns", + "ResourcePattern", "TablePattern", ] diff --git a/graflo/architecture/actor.py b/graflo/architecture/actor.py index 43b33a06..e368b7c6 100644 --- a/graflo/architecture/actor.py +++ b/graflo/architecture/actor.py @@ -35,6 +35,7 @@ EdgeActorConfig, TransformActorConfig, VertexActorConfig, + VertexRouterActorConfig, parse_root_config, normalize_actor_step, validate_actor_step, @@ -931,10 +932,135 @@ def fetch_actors(self, level, edges): return level, type(self), str(self), edges +class VertexRouterActor(Actor): + """Routes documents to the correct VertexActor based on a type field. + + Maintains an internal ``dict[str, ActorWrapper]`` mapping vertex type names + to pre-initialised VertexActor wrappers, giving O(1) dispatch per document + instead of iterating over all known vertex types. + + On ``__call__``: + + 1. Read ``doc[type_field]`` to determine the vertex type name. + 2. Look up ``_vertex_actors[vtype]`` for the matching wrapper. + 3. Strip *prefix* from field keys (or apply *field_map*) to build a sub-doc. + 4. Delegate to the looked-up wrapper. 
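+
+    For example (hypothetical field names), with ``type_field="kind"`` and
+    ``prefix="person_"``, the document ``{"kind": "person", "person_name": "Ada"}``
+    is routed to the ``person`` VertexActor with sub-doc ``{"name": "Ada"}``,
+    provided ``person`` is a vertex type known to the VertexConfig.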
+ + Attributes: + type_field: Document field whose value names the target vertex type. + prefix: Optional prefix to strip from field keys. + field_map: Optional explicit rename mapping (original_key -> vertex_key). + """ + + def __init__(self, config: VertexRouterActorConfig): + """Initialise from config.""" + self.type_field = config.type_field + self.prefix = config.prefix + self.field_map = config.field_map + self._vertex_actors: dict[str, ActorWrapper] = {} + self.vertex_config: VertexConfig = VertexConfig(vertices=[]) + + @classmethod + def from_config(cls, config: VertexRouterActorConfig) -> VertexRouterActor: + """Create a VertexRouterActor from a VertexRouterActorConfig.""" + return cls(config) + + def fetch_important_items(self) -> dict[str, Any]: + """Get important items for string representation.""" + items: dict[str, Any] = {"type_field": self.type_field} + if self.prefix: + items["prefix"] = self.prefix + if self.field_map: + items["field_map"] = self.field_map + items["vertex_types"] = sorted(self._vertex_actors.keys()) + return items + + def finish_init(self, **kwargs: Any) -> None: + """Build the internal vertex-type -> ActorWrapper mapping. + + One wrapper is created for every vertex type known to *vertex_config*, + so that any dynamically-typed document can be routed at runtime. + """ + self.vertex_config = kwargs.get("vertex_config", VertexConfig(vertices=[])) + for vertex in self.vertex_config.vertex_list: + wrapper = ActorWrapper.from_config(VertexActorConfig(vertex=vertex.name)) + wrapper.finish_init(**kwargs) + self._vertex_actors[vertex.name] = wrapper + logger.debug( + "VertexRouterActor: registered VertexActor(%s) for type_field=%s", + vertex.name, + self.type_field, + ) + + def count(self) -> int: + """Total actors managed by this router (self + all wrapped vertex actors).""" + return 1 + sum(w.count() for w in self._vertex_actors.values()) + + def _extract_sub_doc(self, doc: dict[str, Any]) -> dict[str, Any]: + """Build the vertex sub-document from *doc*. + + If *prefix* is set, extracts and strips prefixed keys. + If *field_map* is set, renames keys according to the map. + Otherwise returns *doc* unchanged. + """ + if self.prefix: + return { + k[len(self.prefix) :]: v + for k, v in doc.items() + if k.startswith(self.prefix) + } + if self.field_map: + return { + new_key: doc[old_key] + for old_key, new_key in self.field_map.items() + if old_key in doc + } + return doc + + def __call__( + self, ctx: ActionContext, lindex: LocationIndex, *nargs: Any, **kwargs: Any + ) -> ActionContext: + """Route the document to the matching VertexActor. + + Args: + ctx: Action context. + lindex: Current location index. + **kwargs: Must contain ``doc``. + + Returns: + Updated ActionContext. 
+ """ + doc: dict[str, Any] = kwargs.get("doc", {}) + vtype = doc.get(self.type_field) + if vtype is None: + logger.debug( + "VertexRouterActor: type_field '%s' not in doc, skipping", + self.type_field, + ) + return ctx + + wrapper = self._vertex_actors.get(vtype) + if wrapper is None: + logger.debug( + "VertexRouterActor: vertex type '%s' (from field '%s') " + "not in VertexConfig, skipping", + vtype, + self.type_field, + ) + return ctx + + sub_doc = self._extract_sub_doc(doc) + if not sub_doc: + return ctx + + return wrapper(ctx, lindex, doc=sub_doc) + + _NodeTypePriority: MappingProxyType[Type[Actor], int] = MappingProxyType( { DescendActor: 10, TransformActor: 20, + VertexRouterActor: 30, VertexActor: 50, EdgeActor: 90, } @@ -1038,9 +1164,12 @@ def from_config(cls, config: ActorConfig) -> ActorWrapper: actor = EdgeActor.from_config(config) elif isinstance(config, DescendActorConfig): actor = DescendActor.from_config(config) + elif isinstance(config, VertexRouterActorConfig): + actor = VertexRouterActor.from_config(config) else: raise ValueError( - f"Expected VertexActorConfig, TransformActorConfig, EdgeActorConfig, or DescendActorConfig, got {type(config)}" + f"Expected VertexActorConfig, TransformActorConfig, EdgeActorConfig, " + f"DescendActorConfig, or VertexRouterActorConfig, got {type(config)}" ) wrapper = cls.__new__(cls) wrapper.actor = actor diff --git a/graflo/architecture/actor_config.py b/graflo/architecture/actor_config.py index 263ced2b..9c9c286d 100644 --- a/graflo/architecture/actor_config.py +++ b/graflo/architecture/actor_config.py @@ -91,6 +91,13 @@ def normalize_actor_step(data: dict[str, Any]) -> dict[str, Any]: del data["apply"] return data + if "vertex_router" in data: + inner = data.pop("vertex_router") + if isinstance(inner, dict): + data.update(inner) + data["type"] = "vertex_router" + return data + if "apply" in data: data["type"] = "descend" data["pipeline"] = [normalize_actor_step(s) for s in _steps_list(data["apply"])] @@ -258,9 +265,46 @@ def set_type_and_normalize(cls, data: Any) -> Any: return data +class VertexRouterActorConfig(ConfigBaseModel): + """Configuration for a VertexRouterActor. + + Routes documents to the correct VertexActor based on a type field value. + Optionally strips a prefix from field keys or applies an explicit field map. + """ + + type: Literal["vertex_router"] = Field( + default="vertex_router", description="Actor type discriminator" + ) + type_field: str = Field( + ..., + description="Document field whose value determines the target vertex type name.", + ) + prefix: str | None = Field( + default=None, + description="Optional prefix to strip from document field keys when building the vertex sub-doc.", + ) + field_map: dict[str, str] | None = Field( + default=None, + description="Optional explicit rename map (original_key -> vertex_field_key). 
" + "Mutually exclusive with prefix.", + ) + + @model_validator(mode="before") + @classmethod + def set_type(cls, data: Any) -> Any: + if isinstance(data, dict) and "type_field" in data and "type" not in data: + data = dict(data) + data["type"] = "vertex_router" + return data + + # Discriminated union for parsing a single pipeline step (used in ActorWrapper and Resource) ActorConfig = Annotated[ - VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig, + VertexActorConfig + | TransformActorConfig + | EdgeActorConfig + | DescendActorConfig + | VertexRouterActorConfig, Field(discriminator="type"), ] @@ -268,9 +312,17 @@ def set_type_and_normalize(cls, data: Any) -> Any: # TypeAdapter for validating a single pipeline step (union type has no model_validate) _actor_config_adapter: TypeAdapter[ - VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig + VertexActorConfig + | TransformActorConfig + | EdgeActorConfig + | DescendActorConfig + | VertexRouterActorConfig ] = TypeAdapter( - VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig + VertexActorConfig + | TransformActorConfig + | EdgeActorConfig + | DescendActorConfig + | VertexRouterActorConfig ) @@ -288,7 +340,13 @@ def set_type_and_normalize(cls, data: Any) -> Any: def validate_actor_step( data: dict[str, Any], -) -> VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig: +) -> ( + VertexActorConfig + | TransformActorConfig + | EdgeActorConfig + | DescendActorConfig + | VertexRouterActorConfig +): """Validate a normalized step dict as ActorConfig (discriminated union).""" return _actor_config_adapter.validate_python(data) @@ -296,7 +354,13 @@ def validate_actor_step( def parse_root_config( *args: Any, **kwargs: Any, -) -> VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig: +) -> ( + VertexActorConfig + | TransformActorConfig + | EdgeActorConfig + | DescendActorConfig + | VertexRouterActorConfig +): """Parse root input into a single ActorConfig (single step or descend pipeline). Accepts the same shapes as ActorWrapper: diff --git a/graflo/architecture/resource.py b/graflo/architecture/resource.py index 65481646..ede48e4e 100644 --- a/graflo/architecture/resource.py +++ b/graflo/architecture/resource.py @@ -15,6 +15,7 @@ - Weight management - Collection merging - Type casting and validation + - Dynamic vertex-type routing via VertexRouterActor in the pipeline Example: >>> resource = Resource( @@ -54,6 +55,9 @@ class Resource(ConfigBaseModel): structures. Manages the processing pipeline through actors and handles data encoding, transformation, and mapping. Suitable for LLM-generated schema constituents. + + Dynamic vertex-type routing is handled by ``vertex_router`` steps in the + pipeline (see :class:`~graflo.architecture.actor.VertexRouterActor`). """ model_config = {"extra": "forbid"} @@ -166,9 +170,7 @@ def finish_init( edge_greedy=self.edge_greedy, ) - logger.debug( - "total resource actor count (after 2 finit): %s", self.root.count() - ) + logger.debug("total resource actor count (after finit): %s", self.root.count()) for e in self.extra_weights: e.finish_init(vertex_config) diff --git a/graflo/architecture/schema.py b/graflo/architecture/schema.py index e4e92c6f..8a95c167 100644 --- a/graflo/architecture/schema.py +++ b/graflo/architecture/schema.py @@ -19,7 +19,7 @@ Example: >>> schema = Schema( - ... general=SchemaMetadata(name="social_network", version="1.0"), + ... 
general=SchemaMetadata(name="social_network", version="1.0.0"), ... vertex_config=VertexConfig(...), ... edge_config=EdgeConfig(...), ... resources=[Resource(...)] @@ -30,6 +30,7 @@ from __future__ import annotations import logging +import re from collections import Counter from typing import Any @@ -50,12 +51,19 @@ logger = logging.getLogger(__name__) +_SEMVER_RE = re.compile( + r"^\d+\.\d+\.\d+" + r"(-[0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?" + r"(\+[0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?$" +) + + class SchemaMetadata(ConfigBaseModel): """Schema metadata and versioning information. - Holds metadata about the schema, including its name and version. - Used for schema identification and versioning. Suitable for LLM-generated - schema constituents. + Holds metadata about the schema, including its name, version, and + description. Used for schema identification and versioning. + Suitable for LLM-generated schema constituents. """ name: str = PydanticField( @@ -64,9 +72,23 @@ class SchemaMetadata(ConfigBaseModel): ) version: str | None = PydanticField( default=None, - description="Optional version string of the schema (e.g. semantic version).", + description="Semantic version of the schema (e.g. '1.0.0', '2.1.3-beta+build.42').", + ) + description: str | None = PydanticField( + default=None, + description="Optional human-readable description of the schema.", ) + @field_validator("version") + @classmethod + def _validate_semver(cls, v: str | None) -> str | None: + if v is not None and not _SEMVER_RE.match(v): + raise ValueError( + f"version '{v}' is not a valid semantic version " + f"(expected MAJOR.MINOR.PATCH[-prerelease][+build])" + ) + return v + class Schema(ConfigBaseModel): """Graph database schema configuration. diff --git a/graflo/filter/__init__.py b/graflo/filter/__init__.py index 55a22bd4..dc9c5a7f 100644 --- a/graflo/filter/__init__.py +++ b/graflo/filter/__init__.py @@ -9,7 +9,7 @@ - FilterExpression: Filter expression (leaf or composite logical formulae) Example: - >>> from graflo.filter.onto import FilterExpression + >>> from graflo.filter import FilterExpression >>> expr = FilterExpression.from_dict({ ... "AND": [ ... {"field": "age", "cmp_operator": ">=", "value": 18}, @@ -18,3 +18,11 @@ ... 
}) >>> # Converts to: "age >= 18 AND status == 'active'" """ + +from .onto import ComparisonOperator, FilterExpression, LogicalOperator + +__all__ = [ + "ComparisonOperator", + "FilterExpression", + "LogicalOperator", +] diff --git a/graflo/filter/onto.py b/graflo/filter/onto.py index 942727df..ac88f2de 100644 --- a/graflo/filter/onto.py +++ b/graflo/filter/onto.py @@ -81,6 +81,8 @@ class ComparisonOperator(BaseEnum): GT: Greater than (>) LT: Less than (<) IN: Membership test (IN) + IS_NULL: Null check (IS NULL) + IS_NOT_NULL: Non-null check (IS NOT NULL) """ NEQ = "!=" @@ -90,6 +92,8 @@ class ComparisonOperator(BaseEnum): GT = ">" LT = "<" IN = "IN" + IS_NULL = "IS_NULL" + IS_NOT_NULL = "IS_NOT_NULL" class FilterExpression(ConfigBaseModel): @@ -141,10 +145,16 @@ def leaf_operator_to_unary_op(cls, data: Any) -> Any: @model_validator(mode="after") def check_discriminated_shape(self) -> FilterExpression: - """Enforce exactly one shape per kind.""" + """Enforce exactly one shape per kind and normalise null-check operators.""" if self.kind == "leaf": if self.operator is not None or self.deps: raise ValueError("leaf expression must not have operator or deps") + # IS_NULL / IS_NOT_NULL are unary; clear any spurious value list + if self.cmp_operator in ( + ComparisonOperator.IS_NULL, + ComparisonOperator.IS_NOT_NULL, + ): + object.__setattr__(self, "value", []) else: if self.operator is None: raise ValueError("composite expression must have operator") @@ -218,13 +228,20 @@ def __call__( return self._call_leaf(doc_name=doc_name, kind=kind, **kwargs) return self._call_composite(doc_name=doc_name, kind=kind, **kwargs) + def _is_null_operator(self) -> bool: + """Check if this is a null-checking operator (IS_NULL or IS_NOT_NULL).""" + return self.cmp_operator in ( + ComparisonOperator.IS_NULL, + ComparisonOperator.IS_NOT_NULL, + ) + def _call_leaf( self, doc_name="doc", kind: ExpressionFlavor = ExpressionFlavor.AQL, **kwargs, ) -> str | bool: - if not self.value: + if not self._is_null_operator() and not self.value: logger.warning(f"for {self} value is not set : {self.value}") if kind == ExpressionFlavor.AQL: assert self.cmp_operator is not None @@ -275,6 +292,10 @@ def _cast_value(self) -> str: return value def _cast_arango(self, doc_name: str) -> str: + if self.cmp_operator == ComparisonOperator.IS_NULL: + return f'{doc_name}["{self.field}"] == null' + if self.cmp_operator == ComparisonOperator.IS_NOT_NULL: + return f'{doc_name}["{self.field}"] != null' const = self._cast_value() lemma = f"{self.cmp_operator} {const}" if self.unary_op is not None: @@ -284,6 +305,10 @@ def _cast_arango(self, doc_name: str) -> str: return lemma def _cast_cypher(self, doc_name: str) -> str: + if self.cmp_operator == ComparisonOperator.IS_NULL: + return f"{doc_name}.{self.field} IS NULL" + if self.cmp_operator == ComparisonOperator.IS_NOT_NULL: + return f"{doc_name}.{self.field} IS NOT NULL" const = self._cast_value() cmp_op = ( "=" if self.cmp_operator == ComparisonOperator.EQ else self.cmp_operator @@ -296,6 +321,10 @@ def _cast_cypher(self, doc_name: str) -> str: return lemma def _cast_tigergraph(self, doc_name: str) -> str: + if self.cmp_operator == ComparisonOperator.IS_NULL: + return f"{doc_name}.{self.field} IS NULL" + if self.cmp_operator == ComparisonOperator.IS_NOT_NULL: + return f"{doc_name}.{self.field} IS NOT NULL" const = self._cast_value() cmp_op = ( "==" if self.cmp_operator == ComparisonOperator.EQ else self.cmp_operator @@ -307,10 +336,27 @@ def _cast_tigergraph(self, doc_name: str) -> str: lemma = 
f"{doc_name}.{self.field} {lemma}" return lemma + @staticmethod + def _quote_sql_field(field: str) -> str: + """Quote a SQL field name, handling dotted alias.column references. + + ``sys_id`` -> ``"sys_id"`` + ``s.sys_id`` -> ``s."sys_id"`` + """ + if "." in field: + alias, col = field.split(".", 1) + return f'{alias}."{col}"' + return f'"{field}"' + def _cast_sql(self) -> str: """Render leaf as SQL WHERE fragment: \"column\" op value (strings/dates single-quoted).""" if not self.field: return "" + quoted = self._quote_sql_field(self.field) + if self.cmp_operator == ComparisonOperator.IS_NULL: + return f"{quoted} IS NULL" + if self.cmp_operator == ComparisonOperator.IS_NOT_NULL: + return f"{quoted} IS NOT NULL" if self.cmp_operator == ComparisonOperator.EQ: op_str = "=" elif self.cmp_operator == ComparisonOperator.NEQ: @@ -333,11 +379,15 @@ def _cast_sql(self) -> str: # Strings and ISO datetimes: single-quoted for SQL value_str = str(value).replace("'", "''") value_str = f"'{value_str}'" - return f'"{self.field}" {op_str} {value_str}' + return f"{quoted} {op_str} {value_str}" def _cast_restpp(self, field_types: dict[str, Any] | None = None) -> str: if not self.field: return "" + if self.cmp_operator == ComparisonOperator.IS_NULL: + return f'{self.field}=""' + if self.cmp_operator == ComparisonOperator.IS_NOT_NULL: + return f'{self.field}!=""' if self.cmp_operator == ComparisonOperator.EQ: op_str = "=" elif self.cmp_operator == ComparisonOperator.NEQ: @@ -376,6 +426,10 @@ def _cast_restpp(self, field_types: dict[str, Any] | None = None) -> str: def _cast_python(self, **kwargs: Any) -> bool: if self.field is not None: field_val = kwargs.pop(self.field, None) + if self.cmp_operator == ComparisonOperator.IS_NULL: + return field_val is None + if self.cmp_operator == ComparisonOperator.IS_NOT_NULL: + return field_val is not None if field_val is not None and self.unary_op is not None: foo = getattr(field_val, self.unary_op) return foo(self.value[0]) diff --git a/graflo/filter/sql.py b/graflo/filter/sql.py new file mode 100644 index 00000000..74a277ce --- /dev/null +++ b/graflo/filter/sql.py @@ -0,0 +1,56 @@ +"""SQL filter helpers built on top of FilterExpression. + +Provides utility functions for generating SQL WHERE fragments from +structured filter parameters. +""" + +from __future__ import annotations + +from typing import cast + +from graflo.filter.onto import ( + ComparisonOperator, + FilterExpression, + LogicalOperator, +) +from graflo.onto import ExpressionFlavor + + +def datetime_range_where_sql( + datetime_after: str | None, + datetime_before: str | None, + date_column: str, +) -> str: + """Build SQL WHERE fragment for [datetime_after, datetime_before) via FilterExpression. + + Returns empty string if both bounds are None; otherwise uses column with >= and <. 
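+
+    Example (sketch; exact quoting is delegated to FilterExpression's SQL
+    rendering)::
+
+        where = datetime_range_where_sql("2024-01-01", "2024-07-01", "created_at")
+        # roughly: "created_at" >= '2024-01-01' AND "created_at" < '2024-07-01'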
+ """ + if not datetime_after and not datetime_before: + return "" + parts: list[FilterExpression] = [] + if datetime_after is not None: + parts.append( + FilterExpression( + kind="leaf", + field=date_column, + cmp_operator=ComparisonOperator.GE, + value=[datetime_after], + ) + ) + if datetime_before is not None: + parts.append( + FilterExpression( + kind="leaf", + field=date_column, + cmp_operator=ComparisonOperator.LT, + value=[datetime_before], + ) + ) + if len(parts) == 1: + return cast(str, parts[0](kind=ExpressionFlavor.SQL)) + expr = FilterExpression( + kind="composite", + operator=LogicalOperator.AND, + deps=parts, + ) + return cast(str, expr(kind=ExpressionFlavor.SQL)) diff --git a/graflo/hq/__init__.py b/graflo/hq/__init__.py index fc681d4e..58174b48 100644 --- a/graflo/hq/__init__.py +++ b/graflo/hq/__init__.py @@ -5,16 +5,20 @@ """ from graflo.hq.caster import Caster, IngestionParams +from graflo.hq.db_writer import DBWriter from graflo.hq.graph_engine import GraphEngine from graflo.hq.inferencer import InferenceManager +from graflo.hq.registry_builder import RegistryBuilder from graflo.hq.resource_mapper import ResourceMapper from graflo.hq.sanitizer import SchemaSanitizer __all__ = [ "Caster", + "DBWriter", "GraphEngine", "IngestionParams", "InferenceManager", + "RegistryBuilder", "ResourceMapper", "SchemaSanitizer", ] diff --git a/graflo/hq/auto_join.py b/graflo/hq/auto_join.py new file mode 100644 index 00000000..dcc974e9 --- /dev/null +++ b/graflo/hq/auto_join.py @@ -0,0 +1,157 @@ +"""Auto-JOIN generation for edge resources. + +When a Resource's pipeline contains an EdgeActor whose edge has +``match_source`` / ``match_target``, and the source/target vertex types +have known TablePatterns, this module can auto-generate JoinClauses and +IS_NOT_NULL filters on the edge resource's TablePattern so that the +resulting SQL fetches fully resolved rows. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from graflo.architecture.actor import ActorWrapper, EdgeActor +from graflo.architecture.resource import Resource +from graflo.filter.onto import ComparisonOperator, FilterExpression +from graflo.util.onto import JoinClause, TablePattern + +if TYPE_CHECKING: + from graflo.architecture.vertex import VertexConfig + from graflo.util.onto import Patterns + +logger = logging.getLogger(__name__) + +# Alias prefixes assigned to source / target joins. +_SOURCE_ALIAS = "s" +_TARGET_ALIAS = "t" + + +def enrich_edge_pattern_with_joins( + resource: Resource, + pattern: TablePattern, + patterns: Patterns, + vertex_config: VertexConfig, +) -> None: + """Mutate *pattern* in-place, adding JoinClauses + IS_NOT_NULL filters. + + The function inspects the Resource's actor pipeline for EdgeActors and, + for each edge that declares ``match_source`` **and** ``match_target``, + looks up the source / target vertex TablePatterns and primary keys to + construct LEFT JOINs and NOT-NULL guards. + + If the pattern already has joins, this function is a no-op (the user + provided explicit join specs). + + Args: + resource: The Resource whose pipeline is inspected. + pattern: The TablePattern to enrich (mutated in-place). + patterns: The Patterns collection holding all vertex TablePatterns. + vertex_config: VertexConfig for looking up primary keys. 
+ """ + if pattern.joins: + return + + edge_actors = _collect_edge_actors(resource.root) + if not edge_actors: + return + + new_joins: list[JoinClause] = [] + new_filters: list[FilterExpression] = [] + + for ea in edge_actors: + edge = ea.edge + if not edge.match_source or not edge.match_target: + continue + + source_info = _vertex_table_info(edge.source, patterns, vertex_config) + target_info = _vertex_table_info(edge.target, patterns, vertex_config) + if source_info is None or target_info is None: + logger.debug( + "Skipping auto-join for edge %s->%s: missing vertex pattern", + edge.source, + edge.target, + ) + continue + + src_table, src_schema, src_pk = source_info + tgt_table, tgt_schema, tgt_pk = target_info + + src_alias = _SOURCE_ALIAS + tgt_alias = _TARGET_ALIAS + + new_joins.append( + JoinClause( + table=src_table, + schema_name=src_schema, + alias=src_alias, + on_self=edge.match_source, + on_other=src_pk, + join_type="LEFT", + ) + ) + new_joins.append( + JoinClause( + table=tgt_table, + schema_name=tgt_schema, + alias=tgt_alias, + on_self=edge.match_target, + on_other=tgt_pk, + join_type="LEFT", + ) + ) + + new_filters.append( + FilterExpression( + kind="leaf", + field=f"{src_alias}.{src_pk}", + cmp_operator=ComparisonOperator.IS_NOT_NULL, + ) + ) + new_filters.append( + FilterExpression( + kind="leaf", + field=f"{tgt_alias}.{tgt_pk}", + cmp_operator=ComparisonOperator.IS_NOT_NULL, + ) + ) + + if new_joins: + pattern.joins = new_joins + pattern.filters = list(pattern.filters) + new_filters + + +# ------------------------------------------------------------------ +# Internal helpers +# ------------------------------------------------------------------ + + +def _collect_edge_actors(wrapper: ActorWrapper) -> list[EdgeActor]: + """Recursively collect all EdgeActors from an ActorWrapper tree.""" + result: list[EdgeActor] = [] + for actor in wrapper.collect_actors(): + if isinstance(actor, EdgeActor): + result.append(actor) + return result + + +def _vertex_table_info( + vertex_name: str, + patterns: Patterns, + vertex_config: VertexConfig, +) -> tuple[str, str | None, str] | None: + """Return (table_name, schema_name, primary_key_field) for a vertex. + + Returns None if the vertex has no TablePattern in *patterns*. 
+ """ + tp = patterns.table_patterns.get(vertex_name) + if tp is None: + return None + try: + pk_fields = vertex_config.index(vertex_name).fields + except (KeyError, IndexError): + return None + if not pk_fields: + return None + return tp.table_name, tp.schema_name, pk_fields[0] diff --git a/graflo/hq/caster.py b/graflo/hq/caster.py index 93c5c458..04f7b951 100644 --- a/graflo/hq/caster.py +++ b/graflo/hq/caster.py @@ -16,34 +16,27 @@ import asyncio import logging -import re import sys from pathlib import Path from typing import Any, cast import pandas as pd from pydantic import BaseModel + from suthing import Timer -from graflo.architecture.edge import Edge from graflo.architecture.onto import EncodingType, GraphContainer from graflo.architecture.schema import Schema -from graflo.filter.onto import ( - ComparisonOperator, - FilterExpression, - LogicalOperator, -) -from graflo.onto import ExpressionFlavor from graflo.data_source import ( AbstractDataSource, DataSourceFactory, DataSourceRegistry, ) -from graflo.data_source.sql import SQLConfig, SQLDataSource -from graflo.db import ConnectionManager from graflo.db.connection.onto import DBConfig +from graflo.hq.db_writer import DBWriter +from graflo.hq.registry_builder import RegistryBuilder from graflo.util.chunker import ChunkerType -from graflo.util.onto import FilePattern, Patterns, ResourceType, TablePattern +from graflo.util.onto import Patterns logger = logging.getLogger(__name__) @@ -116,91 +109,13 @@ def __init__( - dry: Whether to perform a dry run """ if ingestion_params is None: - # Create IngestionParams from kwargs or use defaults ingestion_params = IngestionParams(**kwargs) self.ingestion_params = ingestion_params self.schema = schema - @staticmethod - def _datetime_range_where_sql( - datetime_after: str | None, - datetime_before: str | None, - date_column: str, - ) -> str: - """Build SQL WHERE fragment for [datetime_after, datetime_before) via FilterExpression. - - Returns empty string if both bounds are None; otherwise uses column with >= and <. - """ - if not datetime_after and not datetime_before: - return "" - parts: list[FilterExpression] = [] - if datetime_after is not None: - parts.append( - FilterExpression( - kind="leaf", - field=date_column, - cmp_operator=ComparisonOperator.GE, - value=[datetime_after], - ) - ) - if datetime_before is not None: - parts.append( - FilterExpression( - kind="leaf", - field=date_column, - cmp_operator=ComparisonOperator.LT, - value=[datetime_before], - ) - ) - if len(parts) == 1: - return cast(str, parts[0](kind=ExpressionFlavor.SQL)) - expr = FilterExpression( - kind="composite", - operator=LogicalOperator.AND, - deps=parts, - ) - return cast(str, expr(kind=ExpressionFlavor.SQL)) - - @staticmethod - def discover_files( - fpath: Path | str, pattern: FilePattern, limit_files=None - ) -> list[Path]: - """Discover files matching a pattern in a directory. 
- - Args: - fpath: Path to search in (should be the directory containing files) - pattern: Pattern to match files against - limit_files: Optional limit on number of files to return - - Returns: - list[Path]: List of matching file paths - - Raises: - AssertionError: If pattern.sub_path is None - """ - assert pattern.sub_path is not None - if isinstance(fpath, str): - fpath_pathlib = Path(fpath) - else: - fpath_pathlib = fpath - - # fpath is already the directory to search (pattern.sub_path from caller) - # so we use it directly, not combined with pattern.sub_path again - files = [ - f - for f in fpath_pathlib.iterdir() - if f.is_file() - and ( - True - if pattern.regex is None - else re.search(pattern.regex, f.name) is not None - ) - ] - - if limit_files is not None: - files = files[:limit_files] - - return files + # ------------------------------------------------------------------ + # Casting + # ------------------------------------------------------------------ async def cast_normal_resource( self, data, resource_name: str | None = None @@ -216,7 +131,6 @@ async def cast_normal_resource( """ rr = self.schema.fetch_resource(resource_name) - # Process documents in parallel using asyncio semaphore = asyncio.Semaphore(self.ingestion_params.n_cores) async def process_doc(doc): @@ -228,6 +142,10 @@ async def process_doc(doc): graph = GraphContainer.from_docs_list(docs) return graph + # ------------------------------------------------------------------ + # Processing pipeline + # ------------------------------------------------------------------ + async def process_batch( self, batch, @@ -244,7 +162,8 @@ async def process_batch( gc = await self.cast_normal_resource(batch, resource_name=resource_name) if conn_conf is not None: - await self.push_db(gc=gc, conn_conf=conn_conf, resource_name=resource_name) + writer = self._make_db_writer() + await writer.write(gc=gc, conn_conf=conn_conf, resource_name=resource_name) async def process_data_source( self, @@ -259,10 +178,8 @@ async def process_data_source( resource_name: Optional name of the resource (overrides data_source.resource_name) conn_conf: Optional database connection configuration """ - # Use provided resource_name or fall back to data_source's resource_name actual_resource_name = resource_name or data_source.resource_name - # Use pattern-specific limit if available, otherwise use global max_items limit = getattr(data_source, "_pattern_limit", None) if limit is None: limit = self.ingestion_params.max_items @@ -302,16 +219,11 @@ async def process_resource( **kwargs: Additional arguments passed to data source creation (e.g., columns for list[list], encoding for files) """ - # Handle configuration dictionary if isinstance(resource_instance, dict): config = resource_instance.copy() - # Merge with kwargs (kwargs take precedence) config.update(kwargs) data_source = DataSourceFactory.create_data_source_from_config(config) - # Handle file paths elif isinstance(resource_instance, (Path, str)): - # File path - create FileDataSource - # Extract only valid file data source parameters with proper typing file_type: str | ChunkerType | None = cast( str | ChunkerType | None, kwargs.get("file_type", None) ) @@ -325,10 +237,7 @@ async def process_resource( encoding=encoding, sep=sep, ) - # Handle in-memory data else: - # In-memory data - create InMemoryDataSource - # Extract only valid in-memory data source parameters with proper typing columns: list[str] | None = cast( list[str] | None, kwargs.get("columns", None) ) @@ -339,164 +248,15 @@ async def 
process_resource( data_source.resource_name = resource_name - # Process using the data source await self.process_data_source( data_source=data_source, resource_name=resource_name, conn_conf=conn_conf, ) - async def push_db( - self, - gc: GraphContainer, - conn_conf: DBConfig, - resource_name: str | None, - ): - """Push graph container data to the database. - - Args: - gc: Graph container with data to push - conn_conf: Database connection configuration - resource_name: Optional name of the resource - """ - vc = self.schema.vertex_config - resource = self.schema.fetch_resource(resource_name) - - # Push vertices in parallel (with configurable concurrency control to prevent deadlocks) - # Some databases can deadlock when multiple transactions modify the same nodes - # Use a semaphore to limit concurrent operations based on max_concurrent_db_ops - max_concurrent = ( - self.ingestion_params.max_concurrent_db_ops - if self.ingestion_params.max_concurrent_db_ops is not None - else self.ingestion_params.n_cores - ) - vertex_semaphore = asyncio.Semaphore(max_concurrent) - - async def push_vertex(vcol: str, data: list[dict]): - async with vertex_semaphore: - - def _push_vertex_sync(): - with ConnectionManager(connection_config=conn_conf) as db_client: - # blank nodes: push and get back their keys {"_key": ...} - if vcol in vc.blank_vertices: - query0 = db_client.insert_return_batch( - data, vc.vertex_dbname(vcol) - ) - cursor = db_client.execute(query0) - return vcol, [item for item in cursor] - else: - db_client.upsert_docs_batch( - data, - vc.vertex_dbname(vcol), - vc.index(vcol), - update_keys="doc", - filter_uniques=True, - dry=self.ingestion_params.dry, - ) - return vcol, None - - return await asyncio.to_thread(_push_vertex_sync) - - # Process all vertices in parallel (with semaphore limiting concurrency for Neo4j) - vertex_results = await asyncio.gather( - *[push_vertex(vcol, data) for vcol, data in gc.vertices.items()] - ) - - # Update blank vertices with returned keys - for vcol, result in vertex_results: - if result is not None: - gc.vertices[vcol] = result - - # update edge misc with blank node edges - for vcol in vc.blank_vertices: - for edge_id, edge in self.schema.edge_config.edges_items(): - vfrom, vto, relation = edge_id - if vcol == vfrom or vcol == vto: - if edge_id not in gc.edges: - gc.edges[edge_id] = [] - gc.edges[edge_id].extend( - [ - (x, y, {}) - for x, y in zip(gc.vertices[vfrom], gc.vertices[vto]) - ] - ) - - # Process extra weights - async def process_extra_weights(): - def _process_extra_weights_sync(): - with ConnectionManager(connection_config=conn_conf) as db_client: - # currently works only on item level - for edge in resource.extra_weights: - if edge.weights is None: - continue - for weight in edge.weights.vertices: - if weight.name in vc.vertex_set: - index_fields = vc.index(weight.name) - - if ( - not self.ingestion_params.dry - and weight.name in gc.vertices - ): - weights_per_item = ( - db_client.fetch_present_documents( - class_name=vc.vertex_dbname(weight.name), - batch=gc.vertices[weight.name], - match_keys=index_fields.fields, - keep_keys=weight.fields, - ) - ) - - for j, item in enumerate(gc.linear): - weights = weights_per_item[j] - - for ee in item[edge.edge_id]: - weight_collection_attached = { - weight.cfield(k): v - for k, v in weights[0].items() - } - ee.update(weight_collection_attached) - else: - logger.error(f"{weight.name} not a valid vertex") - - await asyncio.to_thread(_process_extra_weights_sync) - - await process_extra_weights() - - # Push edges 
in parallel (with configurable concurrency control to prevent deadlocks) - # Some databases can deadlock when multiple transactions modify the same nodes/relationships - # Use a semaphore to limit concurrent operations based on max_concurrent_db_ops - edge_semaphore = asyncio.Semaphore(max_concurrent) - - async def push_edge(edge_id: tuple, edge: Edge): - async with edge_semaphore: - - def _push_edge_sync(): - with ConnectionManager(connection_config=conn_conf) as db_client: - for ee in gc.loop_over_relations(edge_id): - _, _, relation = ee - if not self.ingestion_params.dry: - data = gc.edges[ee] - db_client.insert_edges_batch( - docs_edges=data, - source_class=vc.vertex_dbname(edge.source), - target_class=vc.vertex_dbname(edge.target), - relation_name=relation, - match_keys_source=vc.index(edge.source).fields, - match_keys_target=vc.index(edge.target).fields, - filter_uniques=False, - dry=self.ingestion_params.dry, - collection_name=edge.database_name, - ) - - await asyncio.to_thread(_push_edge_sync) - - # Process all edges in parallel (with semaphore limiting concurrency for Neo4j) - await asyncio.gather( - *[ - push_edge(edge_id, edge) - for edge_id, edge in self.schema.edge_config.edges_items() - ] - ) + # ------------------------------------------------------------------ + # Queue-based processing + # ------------------------------------------------------------------ async def process_with_queue( self, tasks: asyncio.Queue, conn_conf: DBConfig | None = None @@ -507,20 +267,16 @@ async def process_with_queue( tasks: Async queue of tasks to process conn_conf: Optional database connection configuration """ - # Sentinel value to signal completion SENTINEL = None while True: try: - # Get task from queue (will wait if queue is empty) task = await tasks.get() - # Check for sentinel value if task is SENTINEL: tasks.task_done() break - # Support both (Path, str) tuples and DataSource instances if isinstance(task, tuple) and len(task) == 2: filepath, resource_name = task await self.process_resource( @@ -538,6 +294,10 @@ async def process_with_queue( tasks.task_done() break + # ------------------------------------------------------------------ + # Normalization utility + # ------------------------------------------------------------------ + @staticmethod def normalize_resource( data: pd.DataFrame | list[list] | list[dict], columns: list[str] | None = None @@ -558,14 +318,18 @@ def normalize_resource( columns = data.columns.tolist() _data = data.values.tolist() elif data and isinstance(data[0], list): - _data = cast(list[list], data) # Tell mypy this is list[list] + _data = cast(list[list], data) if columns is None: raise ValueError("columns should be set") else: - return cast(list[dict], data) # Tell mypy this is list[dict] + return cast(list[dict], data) rows_dressed = [{k: v for k, v in zip(columns, item)} for item in _data] return rows_dressed + # ------------------------------------------------------------------ + # Ingestion orchestration + # ------------------------------------------------------------------ + async def ingest_data_sources( self, data_source_registry: DataSourceRegistry, @@ -586,7 +350,6 @@ async def ingest_data_sources( if ingestion_params is None: ingestion_params = IngestionParams() - # Update ingestion params (may override defaults set in __init__) self.ingestion_params = ingestion_params init_only = ingestion_params.init_only @@ -594,7 +357,6 @@ async def ingest_data_sources( logger.info("ingest execution bound to init") sys.exit(0) - # Collect all data sources tasks: 
list[AbstractDataSource] = [] for resource_name in self.schema._resources.keys(): data_sources = data_source_registry.get_data_sources(resource_name) @@ -606,22 +368,18 @@ async def ingest_data_sources( with Timer() as klepsidra: if self.ingestion_params.n_cores > 1: - # Use asyncio for parallel processing queue_tasks: asyncio.Queue = asyncio.Queue() for item in tasks: await queue_tasks.put(item) - # Add sentinel values to signal workers to stop for _ in range(self.ingestion_params.n_cores): await queue_tasks.put(None) - # Create worker tasks worker_tasks = [ self.process_with_queue(queue_tasks, conn_conf=conn_conf) for _ in range(self.ingestion_params.n_cores) ] - # Run all workers in parallel await asyncio.gather(*worker_tasks) else: for data_source in tasks: @@ -630,192 +388,10 @@ async def ingest_data_sources( ) logger.info(f"Processing took {klepsidra.elapsed:.1f} sec") - def _register_file_sources( - self, - registry: DataSourceRegistry, - resource_name: str, - pattern: FilePattern, - ingestion_params: IngestionParams, - ) -> None: - """Register file data sources for a resource. - - Args: - registry: Data source registry to add sources to - resource_name: Name of the resource - pattern: File pattern configuration - ingestion_params: Ingestion parameters - """ - if pattern.sub_path is None: - logger.warning( - f"FilePattern for resource '{resource_name}' has no sub_path, skipping" - ) - return - - path_obj = pattern.sub_path.expanduser() - files = Caster.discover_files( - path_obj, limit_files=ingestion_params.limit_files, pattern=pattern - ) - logger.info(f"For resource name {resource_name} {len(files)} files were found") - - for file_path in files: - file_source = DataSourceFactory.create_file_data_source(path=file_path) - registry.register(file_source, resource_name=resource_name) - - def _register_sql_table_sources( - self, - registry: DataSourceRegistry, - resource_name: str, - pattern: TablePattern, - patterns: "Patterns", - ingestion_params: IngestionParams, - ) -> None: - """Register SQL table data sources for a resource. - - Uses SQLDataSource with batch processing (cursors) instead of loading - all data into memory. This is efficient for large tables. 
- - Args: - registry: Data source registry to add sources to - resource_name: Name of the resource - pattern: Table pattern configuration - patterns: Patterns instance for accessing configs - ingestion_params: Ingestion parameters - """ - postgres_config = patterns.get_postgres_config(resource_name) - if postgres_config is None: - logger.warning( - f"PostgreSQL table '{resource_name}' has no connection config, skipping" - ) - return - - table_info = patterns.get_table_info(resource_name) - if table_info is None: - logger.warning( - f"Could not get table info for resource '{resource_name}', skipping" - ) - return - - table_name, schema_name = table_info - effective_schema = schema_name or postgres_config.schema_name or "public" - - try: - # Build base query - query = f'SELECT * FROM "{effective_schema}"."{table_name}"' - where_parts: list[str] = [] - pattern_where = pattern.build_where_clause() - if pattern_where: - where_parts.append(pattern_where) - # Ingestion datetime range [datetime_after, datetime_before) - date_column = pattern.date_field or ingestion_params.datetime_column - if ( - ingestion_params.datetime_after or ingestion_params.datetime_before - ) and date_column: - datetime_where = Caster._datetime_range_where_sql( - ingestion_params.datetime_after, - ingestion_params.datetime_before, - date_column, - ) - if datetime_where: - where_parts.append(datetime_where) - elif ingestion_params.datetime_after or ingestion_params.datetime_before: - logger.warning( - "datetime_after/datetime_before set but no date column: " - "set TablePattern.date_field or IngestionParams.datetime_column for resource %s", - resource_name, - ) - if where_parts: - query += " WHERE " + " AND ".join(where_parts) - - # Get SQLAlchemy connection string from PostgresConfig - connection_string = postgres_config.to_sqlalchemy_connection_string() - - # Create SQLDataSource with pagination for efficient batch processing - # Note: max_items limit is handled by SQLDataSource.iter_batches() limit parameter - sql_config = SQLConfig( - connection_string=connection_string, - query=query, - pagination=True, - page_size=ingestion_params.batch_size, # Use batch_size for page size - ) - sql_source = SQLDataSource(config=sql_config) - - # Register the SQL data source (it will be processed in batches) - registry.register(sql_source, resource_name=resource_name) - - logger.info( - f"Created SQL data source for table '{effective_schema}.{table_name}' " - f"mapped to resource '{resource_name}' (will process in batches of {ingestion_params.batch_size})" - ) - except Exception as e: - logger.error( - f"Failed to create data source for PostgreSQL table '{resource_name}': {e}", - exc_info=True, - ) - - def _build_registry_from_patterns( - self, - patterns: "Patterns", - ingestion_params: IngestionParams, - ) -> DataSourceRegistry: - """Build data source registry from patterns. 
- - Args: - patterns: Patterns instance mapping resources to data sources - ingestion_params: Ingestion parameters - - Returns: - DataSourceRegistry with registered data sources - """ - registry = DataSourceRegistry() - - for resource in self.schema.resources: - resource_name = resource.name - resource_type = patterns.get_resource_type(resource_name) - - if resource_type is None: - logger.warning( - f"No resource type found for resource '{resource_name}', skipping" - ) - continue - - pattern = patterns.patterns.get(resource_name) - if pattern is None: - logger.warning( - f"No pattern found for resource '{resource_name}', skipping" - ) - continue - - if resource_type == ResourceType.FILE: - if not isinstance(pattern, FilePattern): - logger.warning( - f"Pattern for resource '{resource_name}' is not a FilePattern, skipping" - ) - continue - self._register_file_sources( - registry, resource_name, pattern, ingestion_params - ) - - elif resource_type == ResourceType.SQL_TABLE: - if not isinstance(pattern, TablePattern): - logger.warning( - f"Pattern for resource '{resource_name}' is not a TablePattern, skipping" - ) - continue - self._register_sql_table_sources( - registry, resource_name, pattern, patterns, ingestion_params - ) - - else: - logger.warning( - f"Unsupported resource type '{resource_type}' for resource '{resource_name}', skipping" - ) - - return registry - def ingest( self, target_db_config: DBConfig, - patterns: "Patterns | None" = None, + patterns: Patterns | None = None, ingestion_params: IngestionParams | None = None, ): """Ingest data into the graph database. @@ -833,21 +409,16 @@ def ingest( ingestion_params: IngestionParams instance with ingestion configuration. If None, uses default IngestionParams() """ - # Normalize parameters patterns = patterns or Patterns() ingestion_params = ingestion_params or IngestionParams() - # Initialize vertex config with correct field types based on database type db_flavor = target_db_config.connection_type self.schema.vertex_config.db_flavor = db_flavor self.schema.vertex_config.finish_init() - # Initialize edge config after vertex config is fully initialized self.schema.edge_config.finish_init(self.schema.vertex_config) - # Build registry from patterns - registry = self._build_registry_from_patterns(patterns, ingestion_params) + registry = RegistryBuilder(self.schema).build(patterns, ingestion_params) - # Ingest data sources asyncio.run( self.ingest_data_sources( data_source_registry=registry, @@ -855,3 +426,20 @@ def ingest( ingestion_params=ingestion_params, ) ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _make_db_writer(self) -> DBWriter: + """Create a :class:`DBWriter` from the current ingestion params.""" + max_concurrent = ( + self.ingestion_params.max_concurrent_db_ops + if self.ingestion_params.max_concurrent_db_ops is not None + else self.ingestion_params.n_cores + ) + return DBWriter( + schema=self.schema, + dry=self.ingestion_params.dry, + max_concurrent=max_concurrent, + ) diff --git a/graflo/hq/db_writer.py b/graflo/hq/db_writer.py new file mode 100644 index 00000000..97d99805 --- /dev/null +++ b/graflo/hq/db_writer.py @@ -0,0 +1,189 @@ +"""Database writer for pushing graph data to the target database. + +Handles vertex upserts (including blank-node resolution), extra-weight +enrichment, and edge insertion. All heavy DB I/O lives here so that +:class:`Caster` stays a lightweight orchestrator. 
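+
+Typical use (sketch; ``Caster`` normally builds the writer itself, and the
+resource name here is illustrative)::
+
+    writer = DBWriter(schema=schema, dry=False, max_concurrent=4)
+    await writer.write(gc=graph_container, conn_conf=target_db_config,
+                       resource_name="users")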
+""" + +from __future__ import annotations + +import asyncio +import logging + +from graflo.architecture.edge import Edge +from graflo.architecture.onto import GraphContainer +from graflo.architecture.schema import Schema +from graflo.db import ConnectionManager +from graflo.db.connection.onto import DBConfig + +logger = logging.getLogger(__name__) + + +class DBWriter: + """Push :class:`GraphContainer` data to the target graph database. + + Attributes: + schema: Schema configuration providing vertex/edge metadata. + dry: When ``True`` no database mutations are performed. + max_concurrent: Upper bound on concurrent DB operations (semaphore size). + """ + + def __init__(self, schema: Schema, *, dry: bool = False, max_concurrent: int = 1): + self.schema = schema + self.dry = dry + self.max_concurrent = max_concurrent + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def write( + self, + gc: GraphContainer, + conn_conf: DBConfig, + resource_name: str | None, + ) -> None: + """Push *gc* to the database (vertices, extra weights, then edges). + + .. note:: + *gc* is mutated in-place: blank-vertex keys are updated and blank + edges are extended after the vertex round-trip. + """ + _ = self.schema.vertex_config + resource = self.schema.fetch_resource(resource_name) + + await self._push_vertices(gc, conn_conf) + self._resolve_blank_edges(gc) + await self._enrich_extra_weights(gc, conn_conf, resource) + await self._push_edges(gc, conn_conf) + + # ------------------------------------------------------------------ + # Vertices + # ------------------------------------------------------------------ + + async def _push_vertices(self, gc: GraphContainer, conn_conf: DBConfig) -> None: + """Upsert all vertex collections in *gc*, resolving blank nodes.""" + vc = self.schema.vertex_config + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def _push_one(vcol: str, data: list[dict]): + async with semaphore: + + def _sync(): + with ConnectionManager(connection_config=conn_conf) as db: + if vcol in vc.blank_vertices: + query = db.insert_return_batch(data, vc.vertex_dbname(vcol)) + cursor = db.execute(query) + return vcol, list(cursor) + db.upsert_docs_batch( + data, + vc.vertex_dbname(vcol), + vc.index(vcol), + update_keys="doc", + filter_uniques=True, + dry=self.dry, + ) + return vcol, None + + return await asyncio.to_thread(_sync) + + results = await asyncio.gather( + *[_push_one(vcol, data) for vcol, data in gc.vertices.items()] + ) + + for vcol, result in results: + if result is not None: + gc.vertices[vcol] = result + + # ------------------------------------------------------------------ + # Blank-edge resolution + # ------------------------------------------------------------------ + + def _resolve_blank_edges(self, gc: GraphContainer) -> None: + """Extend edge lists for blank vertices after their keys are resolved.""" + vc = self.schema.vertex_config + for vcol in vc.blank_vertices: + for edge_id, _edge in self.schema.edge_config.edges_items(): + vfrom, vto, _relation = edge_id + if vcol == vfrom or vcol == vto: + if edge_id not in gc.edges: + gc.edges[edge_id] = [] + gc.edges[edge_id].extend( + (x, y, {}) for x, y in zip(gc.vertices[vfrom], gc.vertices[vto]) + ) + + # ------------------------------------------------------------------ + # Extra weights + # ------------------------------------------------------------------ + + async def _enrich_extra_weights( + self, gc: 
GraphContainer, conn_conf: DBConfig, resource + ) -> None: + """Fetch extra-weight vertex data from the DB and attach to edges.""" + vc = self.schema.vertex_config + + def _sync(): + with ConnectionManager(connection_config=conn_conf) as db: + for edge in resource.extra_weights: + if edge.weights is None: + continue + for weight in edge.weights.vertices: + if weight.name not in vc.vertex_set: + logger.error(f"{weight.name} not a valid vertex") + continue + index_fields = vc.index(weight.name) + if self.dry or weight.name not in gc.vertices: + continue + weights_per_item = db.fetch_present_documents( + class_name=vc.vertex_dbname(weight.name), + batch=gc.vertices[weight.name], + match_keys=index_fields.fields, + keep_keys=weight.fields, + ) + for j, item in enumerate(gc.linear): + weights = weights_per_item[j] + for ee in item[edge.edge_id]: + ee.update( + {weight.cfield(k): v for k, v in weights[0].items()} + ) + + await asyncio.to_thread(_sync) + + # ------------------------------------------------------------------ + # Edges + # ------------------------------------------------------------------ + + async def _push_edges(self, gc: GraphContainer, conn_conf: DBConfig) -> None: + """Insert all edges in *gc*.""" + vc = self.schema.vertex_config + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def _push_one(edge_id: tuple, edge: Edge): + async with semaphore: + + def _sync(): + with ConnectionManager(connection_config=conn_conf) as db: + for ee in gc.loop_over_relations(edge_id): + _, _, relation = ee + if not self.dry: + data = gc.edges[ee] + db.insert_edges_batch( + docs_edges=data, + source_class=vc.vertex_dbname(edge.source), + target_class=vc.vertex_dbname(edge.target), + relation_name=relation, + match_keys_source=vc.index(edge.source).fields, + match_keys_target=vc.index(edge.target).fields, + filter_uniques=False, + dry=self.dry, + collection_name=edge.database_name, + ) + + await asyncio.to_thread(_sync) + + await asyncio.gather( + *[ + _push_one(edge_id, edge) + for edge_id, edge in self.schema.edge_config.edges_items() + ] + ) diff --git a/graflo/hq/graph_engine.py b/graflo/hq/graph_engine.py index 7d758c85..88ac1565 100644 --- a/graflo/hq/graph_engine.py +++ b/graflo/hq/graph_engine.py @@ -7,7 +7,7 @@ import logging -from graflo import Schema +from graflo.architecture.schema import Schema from graflo.onto import DBType from graflo.architecture.onto_sql import SchemaIntrospectionResult from graflo.db import ConnectionManager, PostgresConnection diff --git a/graflo/hq/inferencer.py b/graflo/hq/inferencer.py index 5073c254..29cc53ce 100644 --- a/graflo/hq/inferencer.py +++ b/graflo/hq/inferencer.py @@ -1,4 +1,4 @@ -from graflo import Schema +from graflo.architecture.schema import Schema from graflo.onto import DBType from graflo.architecture import Resource from graflo.db import PostgresConnection diff --git a/graflo/hq/registry_builder.py b/graflo/hq/registry_builder.py new file mode 100644 index 00000000..484e502a --- /dev/null +++ b/graflo/hq/registry_builder.py @@ -0,0 +1,256 @@ +"""Build a :class:`DataSourceRegistry` from :class:`Patterns` and :class:`Schema`. + +Handles file discovery, SQL table source creation (with auto-JOIN +enrichment and datetime filtering), and pattern dispatch by resource type. 
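+
+A minimal usage sketch (``schema``, ``patterns`` and ``ingestion_params`` are
+assumed to be already constructed; this mirrors how ``Caster.ingest`` uses the
+builder)::
+
+    registry = RegistryBuilder(schema).build(patterns, ingestion_params)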
+""" + +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING + +from graflo.architecture.schema import Schema +from graflo.data_source import DataSourceFactory, DataSourceRegistry +from graflo.data_source.sql import SQLConfig, SQLDataSource +from graflo.filter.sql import datetime_range_where_sql +from graflo.util.onto import FilePattern, ResourceType, TablePattern + +if TYPE_CHECKING: + from graflo.hq.caster import IngestionParams + from graflo.util.onto import Patterns + +logger = logging.getLogger(__name__) + + +class RegistryBuilder: + """Create a :class:`DataSourceRegistry` from :class:`Patterns`. + + Attributes: + schema: Schema providing the resource definitions and vertex/edge config. + """ + + def __init__(self, schema: Schema): + self.schema = schema + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def build( + self, + patterns: Patterns, + ingestion_params: IngestionParams, + ) -> DataSourceRegistry: + """Return a populated :class:`DataSourceRegistry`. + + Iterates over every resource in the schema, looks up its pattern and + resource type, then delegates to the appropriate registration helper. + """ + registry = DataSourceRegistry() + + for resource in self.schema.resources: + resource_name = resource.name + resource_type = patterns.get_resource_type(resource_name) + + if resource_type is None: + logger.warning( + f"No resource type found for resource '{resource_name}', skipping" + ) + continue + + pattern = patterns.patterns.get(resource_name) + if pattern is None: + logger.warning( + f"No pattern found for resource '{resource_name}', skipping" + ) + continue + + if resource_type == ResourceType.FILE: + if not isinstance(pattern, FilePattern): + logger.warning( + f"Pattern for resource '{resource_name}' is not a FilePattern, skipping" + ) + continue + self._register_file_sources( + registry, resource_name, pattern, ingestion_params + ) + + elif resource_type == ResourceType.SQL_TABLE: + if not isinstance(pattern, TablePattern): + logger.warning( + f"Pattern for resource '{resource_name}' is not a TablePattern, skipping" + ) + continue + self._register_sql_table_sources( + registry, resource_name, pattern, patterns, ingestion_params + ) + + else: + logger.warning( + f"Unsupported resource type '{resource_type}' for resource '{resource_name}', skipping" + ) + + return registry + + # ------------------------------------------------------------------ + # File sources + # ------------------------------------------------------------------ + + @staticmethod + def discover_files( + fpath: Path | str, pattern: FilePattern, limit_files: int | None = None + ) -> list[Path]: + """Discover files matching *pattern* in a directory. + + Args: + fpath: Directory to search in. + pattern: Pattern to match files against. + limit_files: Optional cap on the number of files returned. + + Returns: + Matching file paths. 
+ """ + assert pattern.sub_path is not None + path = Path(fpath) if isinstance(fpath, str) else fpath + + files = [ + f + for f in path.iterdir() + if f.is_file() + and ( + True + if pattern.regex is None + else re.search(pattern.regex, f.name) is not None + ) + ] + + if limit_files is not None: + files = files[:limit_files] + + return files + + def _register_file_sources( + self, + registry: DataSourceRegistry, + resource_name: str, + pattern: FilePattern, + ingestion_params: IngestionParams, + ) -> None: + if pattern.sub_path is None: + logger.warning( + f"FilePattern for resource '{resource_name}' has no sub_path, skipping" + ) + return + + path_obj = pattern.sub_path.expanduser() + files = self.discover_files( + path_obj, limit_files=ingestion_params.limit_files, pattern=pattern + ) + logger.info(f"For resource name {resource_name} {len(files)} files were found") + + for file_path in files: + file_source = DataSourceFactory.create_file_data_source(path=file_path) + registry.register(file_source, resource_name=resource_name) + + # ------------------------------------------------------------------ + # SQL / table sources + # ------------------------------------------------------------------ + + def _register_sql_table_sources( + self, + registry: DataSourceRegistry, + resource_name: str, + pattern: TablePattern, + patterns: Patterns, + ingestion_params: IngestionParams, + ) -> None: + """Register SQL table data sources for a resource. + + Uses SQLDataSource with batch processing (cursors) instead of loading + all data into memory. + + When the matching Resource has edge actors with ``match_source`` / + ``match_target`` and the source/target vertex types have known + TablePatterns, JoinClauses and IS_NOT_NULL filters are auto-generated + on the pattern before building the SQL query. + """ + from graflo.hq.auto_join import enrich_edge_pattern_with_joins + + postgres_config = patterns.get_postgres_config(resource_name) + if postgres_config is None: + logger.warning( + f"PostgreSQL table '{resource_name}' has no connection config, skipping" + ) + return + + table_info = patterns.get_table_info(resource_name) + if table_info is None: + logger.warning( + f"Could not get table info for resource '{resource_name}', skipping" + ) + return + + table_name, schema_name = table_info + effective_schema = schema_name or postgres_config.schema_name or "public" + + try: + resource = self.schema.fetch_resource(resource_name) + if not pattern.joins: + enrich_edge_pattern_with_joins( + resource=resource, + pattern=pattern, + patterns=patterns, + vertex_config=self.schema.vertex_config, + ) + + date_column = pattern.date_field or ingestion_params.datetime_column + if ( + ingestion_params.datetime_after or ingestion_params.datetime_before + ) and date_column: + # Handled below via build_query + appended WHERE. 
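+                # (The ingestion range is appended to the query after
+                # build_query(), and only when date_column differs from
+                # pattern.date_field; date_field-based filters are rendered by
+                # build_where_clause() inside build_query().)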
+ pass + elif ingestion_params.datetime_after or ingestion_params.datetime_before: + logger.warning( + "datetime_after/datetime_before set but no date column: " + "set TablePattern.date_field or IngestionParams.datetime_column for resource %s", + resource_name, + ) + + query = pattern.build_query(effective_schema) + + if date_column and date_column != pattern.date_field: + dt_where = datetime_range_where_sql( + ingestion_params.datetime_after, + ingestion_params.datetime_before, + date_column, + ) + if dt_where: + if " WHERE " in query: + query += f" AND {dt_where}" + else: + query += f" WHERE {dt_where}" + + connection_string = postgres_config.to_sqlalchemy_connection_string() + + sql_config = SQLConfig( + connection_string=connection_string, + query=query, + pagination=True, + page_size=ingestion_params.batch_size, + ) + sql_source = SQLDataSource(config=sql_config) + + registry.register(sql_source, resource_name=resource_name) + + logger.info( + f"Created SQL data source for table '{effective_schema}.{table_name}' " + f"mapped to resource '{resource_name}' " + f"(will process in batches of {ingestion_params.batch_size})" + ) + except Exception as e: + logger.error( + f"Failed to create data source for PostgreSQL table '{resource_name}': {e}", + exc_info=True, + ) diff --git a/graflo/util/onto.py b/graflo/util/onto.py index c68e2706..e7556403 100644 --- a/graflo/util/onto.py +++ b/graflo/util/onto.py @@ -137,9 +137,51 @@ def get_resource_type(self) -> ResourceType: return ResourceType.FILE +class JoinClause(ConfigBaseModel): + """Specification for a SQL JOIN operation. + + Used by TablePattern to describe multi-table queries. Each JoinClause + adds one JOIN to the generated SQL. + + Attributes: + table: Table name to join (e.g. "all_classes"). + schema_name: Optional schema override for the joined table. + alias: SQL alias for the joined table (e.g. "s", "t"). Required when + the same table is joined more than once. + on_self: Column on the base (left) table used in the ON condition. + on_other: Column on the joined (right) table used in the ON condition. + join_type: Type of join -- LEFT, INNER, etc. Defaults to LEFT. + select_fields: Explicit list of columns to SELECT from this join. + When None every column of the joined table is included (aliased + with the join alias prefix). + """ + + table: str = Field(..., description="Table name to join.") + schema_name: str | None = Field( + default=None, description="Schema override for the joined table." + ) + alias: str | None = Field( + default=None, description="SQL alias for the joined table." + ) + on_self: str = Field( + ..., description="Column on the base table for the ON condition." + ) + on_other: str = Field( + ..., description="Column on the joined table for the ON condition." + ) + join_type: str = Field(default="LEFT", description="JOIN type (LEFT, INNER, etc.).") + select_fields: list[str] | None = Field( + default=None, + description="Columns to SELECT from this join (None = all columns).", + ) + + class TablePattern(ResourcePattern): """Pattern for matching database tables. + Supports simple single-table queries as well as multi-table JOINs and + pushdown filters via ``FilterExpression``. 
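+
+    For example, a CMDB-style relation table joined twice to its class table
+    (illustrative; this mirrors what the auto-JOIN helper generates)::
+
+        pattern = TablePattern(
+            table_name="cmdb_rel_ci",
+            joins=[
+                JoinClause(table="classes", alias="s", on_self="parent", on_other="id"),
+                JoinClause(table="classes", alias="t", on_self="child", on_other="id"),
+            ],
+        )
+        sql = pattern.build_query("sn")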
+ Attributes: table_name: Exact table name or regex pattern schema_name: Schema name (optional, defaults to public) @@ -148,6 +190,10 @@ class TablePattern(ResourcePattern): date_filter: SQL-style date filter condition (e.g., "> '2020-10-10'") date_range_start: Start date for range filtering (e.g., "2015-11-11") date_range_days: Number of days after start date (used with date_range_start) + filters: General-purpose pushdown filters rendered as SQL WHERE fragments. + joins: Multi-table JOIN specifications (auto-generated or explicit). + select_columns: Explicit SELECT column list. None means ``*`` for the + base table (plus aliased columns from joins). """ table_name: str = "" @@ -157,6 +203,18 @@ class TablePattern(ResourcePattern): date_filter: str | None = None date_range_start: str | None = None date_range_days: int | None = None + filters: list[Any] = Field( + default_factory=list, + description="Pushdown FilterExpression filters (rendered to SQL WHERE).", + ) + joins: list[JoinClause] = Field( + default_factory=list, + description="JOIN clauses for multi-table queries.", + ) + select_columns: list[str] | None = Field( + default=None, + description="Explicit SELECT columns. None = SELECT * (plus join aliases).", + ) @model_validator(mode="after") def _validate_table_pattern(self) -> Self: @@ -208,17 +266,19 @@ def get_resource_type(self) -> ResourceType: return ResourceType.SQL_TABLE def build_where_clause(self) -> str: - """Build SQL WHERE clause from date filtering parameters. + """Build SQL WHERE clause from date filtering parameters **and** general filters. Returns: WHERE clause string (without the WHERE keyword) or empty string if no filters """ - conditions = [] + from graflo.filter.onto import FilterExpression + from graflo.onto import ExpressionFlavor + conditions: list[str] = [] + + # Date-specific conditions (legacy fields) if self.date_field: if self.date_range_start and self.date_range_days is not None: - # Range filtering: dt >= start_date AND dt < start_date + interval - # Example: Ingest for k days after 2015-11-11 conditions.append( f"\"{self.date_field}\" >= '{self.date_range_start}'::date" ) @@ -226,26 +286,93 @@ def build_where_clause(self) -> str: f"\"{self.date_field}\" < '{self.date_range_start}'::date + INTERVAL '{self.date_range_days} days'" ) elif self.date_filter: - # Direct filter: dt > 2020-10-10 or dt > '2020-10-10' - # The date_filter should include the operator and value - # If value doesn't have quotes, add them filter_parts = self.date_filter.strip().split(None, 1) if len(filter_parts) == 2: operator, value = filter_parts - # Add quotes if not already present and value looks like a date if not (value.startswith("'") and value.endswith("'")): - # Check if it's a date-like string (YYYY-MM-DD format) if len(value) == 10 and value.count("-") == 2: value = f"'{value}'" conditions.append(f'"{self.date_field}" {operator} {value}') else: - # If format is unexpected, use as-is conditions.append(f'"{self.date_field}" {self.date_filter}') + # General-purpose FilterExpression filters + for filt in self.filters: + if isinstance(filt, FilterExpression): + rendered = filt(kind=ExpressionFlavor.SQL) + if rendered: + conditions.append(str(rendered)) + if conditions: return " AND ".join(conditions) return "" + # ------------------------------------------------------------------ + # Full SQL query builder (handles SELECT, FROM, JOINs, WHERE) + # ------------------------------------------------------------------ + + def build_query(self, effective_schema: str | None = 
None) -> str: + """Build a complete SQL SELECT query. + + Incorporates the base table, any JoinClauses, explicit select_columns, + date filters, and FilterExpression filters. + + Args: + effective_schema: Schema to use if ``self.schema_name`` is None. + + Returns: + Complete SQL query string. + """ + schema = self.schema_name or effective_schema or "public" + base_alias = "r" if self.joins else None + base_ref = f'"{schema}"."{self.table_name}"' + if base_alias: + base_ref_aliased = f"{base_ref} {base_alias}" + else: + base_ref_aliased = base_ref + + # --- SELECT --- + select_parts: list[str] = [] + if self.select_columns is not None: + select_parts = list(self.select_columns) + elif self.joins: + select_parts.append(f"{base_alias}.*") + for jc in self.joins: + alias = jc.alias or jc.table + jc_schema = jc.schema_name or schema + if jc.select_fields is not None: + for col in jc.select_fields: + select_parts.append(f'{alias}."{col}" AS "{alias}__{col}"') + else: + select_parts.append(f"{alias}.*") + else: + select_parts.append("*") + + select_clause = ", ".join(select_parts) + + # --- FROM + JOINs --- + from_clause = base_ref_aliased + for jc in self.joins: + jc_schema = jc.schema_name or schema + alias = jc.alias or jc.table + join_ref = f'"{jc_schema}"."{jc.table}"' + left_col = ( + f'{base_alias}."{jc.on_self}"' if base_alias else f'"{jc.on_self}"' + ) + right_col = f'{alias}."{jc.on_other}"' + from_clause += ( + f" {jc.join_type} JOIN {join_ref} {alias} ON {left_col} = {right_col}" + ) + + query = f"SELECT {select_clause} FROM {from_clause}" + + # --- WHERE --- + where = self.build_where_clause() + if where: + query += f" WHERE {where}" + + return query + class Patterns(ConfigBaseModel): """Collection of named resource patterns with connection management. diff --git a/mkdocs.yml b/mkdocs.yml index b7da2314..6b3409b0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -72,7 +72,11 @@ markdown_extensions: pygments_lang_class: true - pymdownx.inlinehilite - pymdownx.snippets -- pymdownx.superfences +- pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format - attr_list - md_in_html - admonition @@ -80,8 +84,6 @@ markdown_extensions: - footnotes - meta - pymdownx.blocks.caption -- admonition -- pymdownx.superfences - pymdownx.details - pymdownx.tabbed - pymdownx.highlight diff --git a/pyproject.toml b/pyproject.toml index bfb00f57..83443eab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ description = "A framework for transforming tabular (CSV, SQL) and hierarchical name = "graflo" readme = "README.md" requires-python = ">=3.11" -version = "1.5.0" +version = "1.6.0" [project.optional-dependencies] plot = [ diff --git a/test/architecture/test_resource_dynamic.py b/test/architecture/test_resource_dynamic.py new file mode 100644 index 00000000..9701e646 --- /dev/null +++ b/test/architecture/test_resource_dynamic.py @@ -0,0 +1,436 @@ +"""Integration tests for resource-level query enhancements with mock SQL. 
+ +Tests use SQLite file-based (via SQLAlchemy) to exercise: +- Case 1: Filtered vertex resources (same table, different WHERE) +- Case 2: Edge resource with auto-JOIN and dynamic vertex types +""" + +from __future__ import annotations + +import tempfile +import os + +from sqlalchemy import create_engine, text + +from graflo.architecture.schema import Schema +from graflo.data_source.sql import SQLConfig, SQLDataSource +from graflo.filter.onto import ComparisonOperator, FilterExpression +from graflo.util.onto import TablePattern +from graflo.hq.auto_join import enrich_edge_pattern_with_joins +from graflo.util.onto import Patterns + + +# --------------------------------------------------------------- +# Helper: create a fresh SQLite file-based DB with test data +# --------------------------------------------------------------- + + +def _setup_db() -> str: + """Create a SQLite file DB with CMDB-like tables and return the connection string. + + Uses a temporary file so that multiple SQLAlchemy engines can share it + (in-memory SQLite is connection-private). + """ + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + conn_str = f"sqlite:///{path}" + engine = create_engine(conn_str) + with engine.connect() as conn: + conn.execute( + text( + """ + CREATE TABLE classes ( + id TEXT PRIMARY KEY, + class_name TEXT NOT NULL, + description TEXT + ) + """ + ) + ) + conn.execute( + text( + """ + INSERT INTO classes (id, class_name, description) VALUES + ('1', 'server', 'Web Server'), + ('2', 'database', 'PostgreSQL'), + ('3', 'server', 'App Server'), + ('4', 'database', 'MySQL'), + ('5', 'network', 'Router') + """ + ) + ) + conn.execute( + text( + """ + CREATE TABLE relations ( + id INTEGER PRIMARY KEY, + parent TEXT NOT NULL, + child TEXT NOT NULL, + type_display TEXT NOT NULL + ) + """ + ) + ) + conn.execute( + text( + """ + INSERT INTO relations (id, parent, child, type_display) VALUES + (1, '1', '2', 'runs_on'), + (2, '3', '4', 'runs_on'), + (3, '1', '5', 'connects_to') + """ + ) + ) + conn.commit() + return conn_str + + +# --------------------------------------------------------------- +# Case 1: Filtered vertex resources +# --------------------------------------------------------------- + + +class TestFilteredVertexResources: + """Two Resources read the same table with different filter predicates.""" + + def test_filtered_queries_produce_correct_subsets(self): + """Each Resource's generated query returns only its filtered rows.""" + conn_str = _setup_db() + + # Resource "server" -> classes WHERE class_name = 'server' + f_server = FilterExpression( + kind="leaf", + field="class_name", + cmp_operator=ComparisonOperator.EQ, + value=["server"], + ) + tp_server = TablePattern( + table_name="classes", + filters=[f_server], + ) + + # Resource "database" -> classes WHERE class_name = 'database' + f_db = FilterExpression( + kind="leaf", + field="class_name", + cmp_operator=ComparisonOperator.EQ, + value=["database"], + ) + tp_db = TablePattern( + table_name="classes", + filters=[f_db], + ) + + # SQLite doesn't use schemas, so we use "main" as effective_schema + # but build_query quotes it -- SQLite ignores schema prefix on tables + # so we pass None and rely on the default "public" which SQLite also ignores. + # We'll just use the raw query from build_where_clause instead. 
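+        # (Concretely: build "SELECT * FROM classes" by hand and append the
+        # WHERE fragment rendered by build_where_clause().)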
+ query_server = "SELECT * FROM classes" + where_server = tp_server.build_where_clause() + if where_server: + query_server += f" WHERE {where_server}" + + query_db = "SELECT * FROM classes" + where_db = tp_db.build_where_clause() + if where_db: + query_db += f" WHERE {where_db}" + + # Execute the queries + ds_server = SQLDataSource( + config=SQLConfig( + connection_string=conn_str, + query=query_server, + pagination=False, + ) + ) + ds_db = SQLDataSource( + config=SQLConfig( + connection_string=conn_str, + query=query_db, + pagination=False, + ) + ) + + server_rows = list(ds_server) + db_rows = list(ds_db) + + # server has 2 rows (id 1, 3) + assert len(server_rows) == 2 + assert all(r["class_name"] == "server" for r in server_rows) + + # database has 2 rows (id 2, 4) + assert len(db_rows) == 2 + assert all(r["class_name"] == "database" for r in db_rows) + + def test_build_query_filter_sql_renders_correctly(self): + """Verify build_where_clause() renders FilterExpression filters.""" + f = FilterExpression( + kind="leaf", + field="class_name", + cmp_operator=ComparisonOperator.EQ, + value=["server"], + ) + tp = TablePattern(table_name="classes", filters=[f]) + where = tp.build_where_clause() + assert "\"class_name\" = 'server'" in where + + +# --------------------------------------------------------------- +# Case 2: Edge resource with auto-JOIN and dynamic vertex types +# --------------------------------------------------------------- + + +class TestEdgeResourceAutoJoin: + """Edge resource with JOINs and dynamic vertex types through full pipeline.""" + + def _build_schema(self) -> Schema: + return Schema.model_validate( + { + "general": {"name": "test", "version": "0.0.1"}, + "vertex_config": { + "vertices": [ + { + "name": "server", + "fields": ["id", "class_name", "description"], + }, + { + "name": "database", + "fields": ["id", "class_name", "description"], + }, + { + "name": "network", + "fields": ["id", "class_name", "description"], + }, + ], + }, + "edge_config": { + "edges": [ + {"source": "server", "target": "database"}, + {"source": "server", "target": "network"}, + ], + }, + "resources": [ + { + "resource_name": "relations", + "pipeline": [ + { + "vertex_router": { + "type_field": "s__class_name", + "prefix": "s__", + } + }, + { + "vertex_router": { + "type_field": "t__class_name", + "prefix": "t__", + } + }, + { + "edge": { + "from": "server", + "to": "database", + "match_source": "parent", + "match_target": "child", + "relation_field": "type_display", + } + }, + ], + }, + ], + } + ) + + def test_auto_join_query_generates_correct_sql(self): + """Verify the auto-JOIN + build_query produces valid SQL structure.""" + + schema = self._build_schema() + resource = schema.fetch_resource("relations") + + tp_edge = TablePattern(table_name="relations", schema_name="main") + patterns_table = { + "server": TablePattern(table_name="classes", schema_name="main"), + "database": TablePattern(table_name="classes", schema_name="main"), + "network": TablePattern(table_name="classes", schema_name="main"), + "relations": tp_edge, + } + + patterns = Patterns(table_patterns=patterns_table) + + enrich_edge_pattern_with_joins( + resource=resource, + pattern=tp_edge, + patterns=patterns, + vertex_config=schema.vertex_config, + ) + + q = tp_edge.build_query("main") + # Structure checks + assert "LEFT JOIN" in q + assert "IS NOT NULL" in q + assert '"main"."relations"' in q + assert '"main"."classes"' in q + + def test_auto_join_query_executes_on_sqlite(self): + """Run the generated JOIN query against a 
real SQLite DB.""" + from graflo.hq.auto_join import enrich_edge_pattern_with_joins + from graflo.util.onto import Patterns + + conn_str = _setup_db() + schema = self._build_schema() + resource = schema.fetch_resource("relations") + + tp_edge = TablePattern(table_name="relations") + patterns_table = { + "server": TablePattern(table_name="classes"), + "database": TablePattern(table_name="classes"), + "network": TablePattern(table_name="classes"), + "relations": tp_edge, + } + patterns = Patterns(table_patterns=patterns_table) + + enrich_edge_pattern_with_joins( + resource=resource, + pattern=tp_edge, + patterns=patterns, + vertex_config=schema.vertex_config, + ) + + # SQLite doesn't use schema prefixes, so build query manually + # mimicking what build_query does but without schema quoting + base = "relations r" + join_parts = [] + for jc in tp_edge.joins: + alias = jc.alias + join_parts.append( + f"{jc.join_type} JOIN classes {alias} ON r.{jc.on_self} = {alias}.{jc.on_other}" + ) + + # SELECT with aliased columns to simulate the prefix convention + select_cols = [ + "r.*", + 's.id AS "s__id"', + 's.class_name AS "s__class_name"', + 's.description AS "s__description"', + 't.id AS "t__id"', + 't.class_name AS "t__class_name"', + 't.description AS "t__description"', + ] + query = f"SELECT {', '.join(select_cols)} FROM {base} {' '.join(join_parts)}" + query += " WHERE s.id IS NOT NULL AND t.id IS NOT NULL" + + ds = SQLDataSource( + config=SQLConfig( + connection_string=conn_str, + query=query, + pagination=False, + ) + ) + rows = list(ds) + + # We inserted 3 relations; all should have valid source/target + assert len(rows) == 3 + # Each row should have the aliased columns + assert "s__class_name" in rows[0] + assert "t__class_name" in rows[0] + + def test_pipeline_contains_vertex_router_actors(self): + """Pipeline with vertex_router steps produces VertexRouterActor instances.""" + from graflo.architecture.actor import VertexRouterActor + + schema = self._build_schema() + resource = schema.fetch_resource("relations") + + all_actors = resource.root.collect_actors() + router_actors = [a for a in all_actors if isinstance(a, VertexRouterActor)] + + # Should have 2 routers (source and target) + assert len(router_actors) == 2 + type_fields = {a.type_field for a in router_actors} + assert type_fields == {"s__class_name", "t__class_name"} + + # Each router should have all 3 vertex types registered + for ra in router_actors: + assert set(ra._vertex_actors.keys()) == {"server", "database", "network"} + + def test_vertex_router_extract_sub_doc_strips_prefix(self): + """VertexRouterActor._extract_sub_doc strips prefix from field keys.""" + from graflo.architecture.actor import VertexRouterActor + from graflo.architecture.actor_config import VertexRouterActorConfig + + config = VertexRouterActorConfig(type_field="s__class_name", prefix="s__") + router = VertexRouterActor(config) + + doc = { + "parent": "1", + "child": "2", + "type_display": "runs_on", + "s__id": "1", + "s__class_name": "server", + "s__description": "Web Server", + "t__id": "2", + "t__class_name": "database", + "t__description": "PostgreSQL", + } + + sub_doc = router._extract_sub_doc(doc) + + # Only s__-prefixed keys extracted, with prefix stripped + assert sub_doc == { + "id": "1", + "class_name": "server", + "description": "Web Server", + } + + def test_vertex_router_extract_sub_doc_with_field_map(self): + """VertexRouterActor._extract_sub_doc applies field_map when set.""" + from graflo.architecture.actor import VertexRouterActor + from 
graflo.architecture.actor_config import VertexRouterActorConfig + + config = VertexRouterActorConfig( + type_field="src_type", + field_map={"src_id": "id", "src_name": "class_name"}, + ) + router = VertexRouterActor(config) + + doc = { + "src_type": "server", + "src_id": "1", + "src_name": "server", + "extra": "ignored", + } + sub_doc = router._extract_sub_doc(doc) + + assert sub_doc == {"id": "1", "class_name": "server"} + + def test_full_resource_call_produces_vertices_and_edges(self): + """Resource.__call__ with dynamic types creates vertices and edges.""" + schema = self._build_schema() + resource = schema.fetch_resource("relations") + + doc = { + "parent": "1", + "child": "2", + "type_display": "runs_on", + "s__id": "1", + "s__class_name": "server", + "s__description": "Web Server", + "t__id": "2", + "t__class_name": "database", + "t__description": "PostgreSQL", + } + + result = resource(doc) + + # Should have vertices + vertex_keys = [k for k in result if isinstance(k, str)] + assert "server" in vertex_keys + assert "database" in vertex_keys + + # server should have the routed vertex doc + server_docs = result["server"] + assert len(server_docs) >= 1 + assert any(d.get("id") == "1" for d in server_docs) + + db_docs = result["database"] + assert len(db_docs) >= 1 + assert any(d.get("id") == "2" for d in db_docs) diff --git a/test/architecture/test_resource_filters.py b/test/architecture/test_resource_filters.py new file mode 100644 index 00000000..8839768a --- /dev/null +++ b/test/architecture/test_resource_filters.py @@ -0,0 +1,398 @@ +"""Unit tests for resource-level query enhancements. + +Tests cover: +- FilterExpression IS_NULL / IS_NOT_NULL across all flavours +- TablePattern.build_query() with joins and filters +- Auto-JOIN generation helper (enrich_edge_pattern_with_joins) +""" + +from __future__ import annotations + + +from graflo.filter.onto import ComparisonOperator, FilterExpression, LogicalOperator +from graflo.onto import ExpressionFlavor +from graflo.util.onto import JoinClause, TablePattern + + +# --------------------------------------------------------------- +# Phase 1: IS_NULL / IS_NOT_NULL operators +# --------------------------------------------------------------- + + +class TestIsNullIsNotNull: + """FilterExpression rendering of IS_NULL / IS_NOT_NULL across flavours.""" + + def _leaf(self, field: str, op: ComparisonOperator) -> FilterExpression: + return FilterExpression(kind="leaf", field=field, cmp_operator=op) + + # --- SQL --- + + def test_is_null_sql(self): + expr = self._leaf("class_name", ComparisonOperator.IS_NULL) + assert expr(kind=ExpressionFlavor.SQL) == '"class_name" IS NULL' + + def test_is_not_null_sql(self): + expr = self._leaf("class_name", ComparisonOperator.IS_NOT_NULL) + assert expr(kind=ExpressionFlavor.SQL) == '"class_name" IS NOT NULL' + + def test_is_not_null_sql_aliased_field(self): + """Dotted field ``s.id`` should render as ``s."id" IS NOT NULL``.""" + expr = self._leaf("s.id", ComparisonOperator.IS_NOT_NULL) + assert expr(kind=ExpressionFlavor.SQL) == 's."id" IS NOT NULL' + + # --- AQL --- + + def test_is_null_aql(self): + expr = self._leaf("name", ComparisonOperator.IS_NULL) + assert expr(doc_name="d", kind=ExpressionFlavor.AQL) == 'd["name"] == null' + + def test_is_not_null_aql(self): + expr = self._leaf("name", ComparisonOperator.IS_NOT_NULL) + assert expr(doc_name="d", kind=ExpressionFlavor.AQL) == 'd["name"] != null' + + # --- Cypher --- + + def test_is_null_cypher(self): + expr = self._leaf("age", ComparisonOperator.IS_NULL) + 
assert expr(doc_name="n", kind=ExpressionFlavor.CYPHER) == "n.age IS NULL" + + def test_is_not_null_cypher(self): + expr = self._leaf("age", ComparisonOperator.IS_NOT_NULL) + assert expr(doc_name="n", kind=ExpressionFlavor.CYPHER) == "n.age IS NOT NULL" + + # --- GSQL (TigerGraph) --- + + def test_is_null_gsql(self): + expr = self._leaf("status", ComparisonOperator.IS_NULL) + assert expr(doc_name="v", kind=ExpressionFlavor.GSQL) == "v.status IS NULL" + + def test_is_not_null_gsql(self): + expr = self._leaf("status", ComparisonOperator.IS_NOT_NULL) + assert expr(doc_name="v", kind=ExpressionFlavor.GSQL) == "v.status IS NOT NULL" + + # --- REST++ (TigerGraph with empty doc_name) --- + + def test_is_null_restpp(self): + expr = self._leaf("x", ComparisonOperator.IS_NULL) + result = expr(doc_name="", kind=ExpressionFlavor.GSQL) + assert result == 'x=""' + + def test_is_not_null_restpp(self): + expr = self._leaf("x", ComparisonOperator.IS_NOT_NULL) + result = expr(doc_name="", kind=ExpressionFlavor.GSQL) + assert result == 'x!=""' + + # --- Python --- + + def test_is_null_python_true(self): + expr = self._leaf("col", ComparisonOperator.IS_NULL) + assert expr(kind=ExpressionFlavor.PYTHON, col=None) is True + + def test_is_null_python_false(self): + expr = self._leaf("col", ComparisonOperator.IS_NULL) + assert expr(kind=ExpressionFlavor.PYTHON, col="val") is False + + def test_is_not_null_python_true(self): + expr = self._leaf("col", ComparisonOperator.IS_NOT_NULL) + assert expr(kind=ExpressionFlavor.PYTHON, col="val") is True + + def test_is_not_null_python_false(self): + expr = self._leaf("col", ComparisonOperator.IS_NOT_NULL) + assert expr(kind=ExpressionFlavor.PYTHON, col=None) is False + + # --- Composite with IS_NOT_NULL --- + + def test_and_with_is_not_null_sql(self): + expr = FilterExpression( + kind="composite", + operator=LogicalOperator.AND, + deps=[ + self._leaf("s.id", ComparisonOperator.IS_NOT_NULL), + self._leaf("t.id", ComparisonOperator.IS_NOT_NULL), + ], + ) + result = expr(kind=ExpressionFlavor.SQL) + assert isinstance(result, str) + assert 's."id" IS NOT NULL' in result + assert 't."id" IS NOT NULL' in result + assert " AND " in result + + # --- value list is cleared for null ops --- + + def test_value_cleared(self): + expr = FilterExpression( + kind="leaf", + field="f", + cmp_operator=ComparisonOperator.IS_NULL, + value=["should_be_cleared"], + ) + assert expr.value == [] + + +# --------------------------------------------------------------- +# Phase 2: TablePattern.build_query +# --------------------------------------------------------------- + + +class TestTablePatternBuildQuery: + """TablePattern.build_query() generates correct SQL.""" + + def test_simple_select_star(self): + tp = TablePattern(table_name="users") + q = tp.build_query("public") + assert q == 'SELECT * FROM "public"."users"' + + def test_with_date_filter(self): + tp = TablePattern( + table_name="events", + date_field="created_at", + date_filter="> '2020-01-01'", + ) + q = tp.build_query("public") + assert 'WHERE "created_at" >' in q + + def test_with_filter_expression(self): + f = FilterExpression( + kind="leaf", + field="class_name", + cmp_operator=ComparisonOperator.EQ, + value=["server"], + ) + tp = TablePattern(table_name="classes", filters=[f]) + q = tp.build_query("myschema") + assert "WHERE" in q + assert "\"class_name\" = 'server'" in q + + def test_with_multiple_filters(self): + f1 = FilterExpression( + kind="leaf", + field="status", + cmp_operator=ComparisonOperator.EQ, + value=["active"], + ) + f2 
= FilterExpression( + kind="leaf", + field="age", + cmp_operator=ComparisonOperator.GE, + value=[18], + ) + tp = TablePattern(table_name="people", filters=[f1, f2]) + q = tp.build_query("public") + assert "AND" in q + assert '"status"' in q + assert '"age"' in q + + def test_with_single_join(self): + jc = JoinClause( + table="addresses", + alias="a", + on_self="address_id", + on_other="id", + ) + tp = TablePattern(table_name="users", joins=[jc]) + q = tp.build_query("public") + assert "LEFT JOIN" in q + assert '"public"."addresses" a' in q + assert 'r."address_id" = a."id"' in q + # base table aliased as 'r' + assert "r.*" in q + + def test_with_two_joins_same_table(self): + """CMDB-style: two joins to same table with different aliases.""" + jc_s = JoinClause( + table="classes", + alias="s", + on_self="parent", + on_other="id", + ) + jc_t = JoinClause( + table="classes", + alias="t", + on_self="child", + on_other="id", + ) + f1 = FilterExpression( + kind="leaf", + field="s.id", + cmp_operator=ComparisonOperator.IS_NOT_NULL, + ) + f2 = FilterExpression( + kind="leaf", + field="t.id", + cmp_operator=ComparisonOperator.IS_NOT_NULL, + ) + tp = TablePattern( + table_name="cmdb_rel_ci", + joins=[jc_s, jc_t], + filters=[f1, f2], + ) + q = tp.build_query("sn") + # Both JOINs present + assert '"sn"."classes" s' in q + assert '"sn"."classes" t' in q + assert 'r."parent" = s."id"' in q + assert 'r."child" = t."id"' in q + # IS NOT NULL filters + assert 's."id" IS NOT NULL' in q + assert 't."id" IS NOT NULL' in q + + def test_join_with_select_fields(self): + jc = JoinClause( + table="classes", + alias="s", + on_self="parent", + on_other="id", + select_fields=["id", "class_name"], + ) + tp = TablePattern(table_name="rel", joins=[jc]) + q = tp.build_query("public") + assert 's."id" AS "s__id"' in q + assert 's."class_name" AS "s__class_name"' in q + + def test_explicit_select_columns(self): + tp = TablePattern( + table_name="t", + select_columns=["a", "b"], + ) + q = tp.build_query("public") + assert q.startswith("SELECT a, b FROM") + + def test_schema_defaults_to_public(self): + tp = TablePattern(table_name="t") + q = tp.build_query() + assert '"public"."t"' in q + + +# --------------------------------------------------------------- +# Phase 3: Auto-JOIN generation +# --------------------------------------------------------------- + + +class TestAutoJoin: + """enrich_edge_pattern_with_joins adds JoinClauses from edge defs.""" + + def _make_schema_and_patterns(self): + """Build a minimal Schema + Patterns for the CMDB-like scenario.""" + from graflo.architecture.schema import Schema + from graflo.util.onto import Patterns + + schema = Schema.model_validate( + { + "general": {"name": "test", "version": "0.0.1"}, + "vertex_config": { + "vertices": [ + {"name": "server", "fields": ["id", "class_name"]}, + {"name": "database", "fields": ["id", "class_name"]}, + ], + }, + "edge_config": { + "edges": [ + {"source": "server", "target": "database"}, + ], + }, + "resources": [ + { + "resource_name": "cmdb_relations", + "pipeline": [ + { + "edge": { + "from": "server", + "to": "database", + "match_source": "parent", + "match_target": "child", + } + } + ], + } + ], + } + ) + + patterns = Patterns( + table_patterns={ + "server": TablePattern(table_name="classes", schema_name="sn"), + "database": TablePattern(table_name="classes", schema_name="sn"), + "cmdb_relations": TablePattern( + table_name="cmdb_rel_ci", schema_name="sn" + ), + }, + ) + return schema, patterns + + def test_enrichment_adds_joins(self): + from 
graflo.hq.auto_join import enrich_edge_pattern_with_joins + + schema, patterns = self._make_schema_and_patterns() + resource = schema.fetch_resource("cmdb_relations") + pattern = patterns.table_patterns["cmdb_relations"] + + enrich_edge_pattern_with_joins( + resource=resource, + pattern=pattern, + patterns=patterns, + vertex_config=schema.vertex_config, + ) + + assert len(pattern.joins) == 2 + aliases = {j.alias for j in pattern.joins} + assert aliases == {"s", "t"} + # The on_self fields come from edge match_source / match_target + on_self_cols = {j.on_self for j in pattern.joins} + assert on_self_cols == {"parent", "child"} + + def test_enrichment_adds_is_not_null_filters(self): + from graflo.hq.auto_join import enrich_edge_pattern_with_joins + + schema, patterns = self._make_schema_and_patterns() + resource = schema.fetch_resource("cmdb_relations") + pattern = patterns.table_patterns["cmdb_relations"] + + enrich_edge_pattern_with_joins( + resource=resource, + pattern=pattern, + patterns=patterns, + vertex_config=schema.vertex_config, + ) + + assert len(pattern.filters) == 2 + rendered = [f(kind=ExpressionFlavor.SQL) for f in pattern.filters] + assert 's."id" IS NOT NULL' in rendered + assert 't."id" IS NOT NULL' in rendered + + def test_enrichment_noop_when_joins_already_set(self): + from graflo.hq.auto_join import enrich_edge_pattern_with_joins + + schema, patterns = self._make_schema_and_patterns() + resource = schema.fetch_resource("cmdb_relations") + pattern = patterns.table_patterns["cmdb_relations"] + pattern.joins = [JoinClause(table="x", alias="x", on_self="a", on_other="b")] + + enrich_edge_pattern_with_joins( + resource=resource, + pattern=pattern, + patterns=patterns, + vertex_config=schema.vertex_config, + ) + + # Should not have modified the existing join + assert len(pattern.joins) == 1 + assert pattern.joins[0].table == "x" + + def test_full_query_after_enrichment(self): + from graflo.hq.auto_join import enrich_edge_pattern_with_joins + + schema, patterns = self._make_schema_and_patterns() + resource = schema.fetch_resource("cmdb_relations") + pattern = patterns.table_patterns["cmdb_relations"] + + enrich_edge_pattern_with_joins( + resource=resource, + pattern=pattern, + patterns=patterns, + vertex_config=schema.vertex_config, + ) + + q = pattern.build_query("sn") + assert "LEFT JOIN" in q + assert "IS NOT NULL" in q + assert '"sn"."cmdb_rel_ci"' in q diff --git a/test/db/postgres/test_ingest_datetime_range.py b/test/db/postgres/test_ingest_datetime_range.py index 0dfa8df1..239fa9cf 100644 --- a/test/db/postgres/test_ingest_datetime_range.py +++ b/test/db/postgres/test_ingest_datetime_range.py @@ -5,7 +5,8 @@ filter rows correctly. 
""" -from graflo.hq.caster import Caster, IngestionParams +from graflo.filter.sql import datetime_range_where_sql +from graflo.hq.caster import IngestionParams from graflo.hq.graph_engine import GraphEngine from graflo.onto import DBType from graflo.util.onto import TablePattern @@ -60,7 +61,7 @@ def test_ingest_datetime_range_postgres(postgres_conn, load_mock_schema): resource_name="purchases", date_field="purchase_date", ) - datetime_where = Caster._datetime_range_where_sql( + datetime_where = datetime_range_where_sql( "2020-02-01", "2020-06-01", pattern.date_field or "purchase_date", @@ -93,7 +94,7 @@ def test_ingest_datetime_range_with_global_column(postgres_conn, load_mock_schem ) date_column = pattern.date_field or ingestion_params.datetime_column assert date_column == "purchase_date" - datetime_where = Caster._datetime_range_where_sql( + datetime_where = datetime_range_where_sql( ingestion_params.datetime_after, ingestion_params.datetime_before, date_column, diff --git a/test/test_ingestion_datetime.py b/test/test_ingestion_datetime.py index f3225681..011afa27 100644 --- a/test/test_ingestion_datetime.py +++ b/test/test_ingestion_datetime.py @@ -1,6 +1,7 @@ """Tests for ingestion datetime range params and SQL WHERE building.""" -from graflo.hq.caster import Caster, IngestionParams +from graflo.filter.sql import datetime_range_where_sql +from graflo.hq.caster import IngestionParams from graflo.util.onto import TablePattern @@ -25,14 +26,14 @@ def test_ingestion_params_datetime_set(): def test_datetime_range_where_sql_empty(): - """_datetime_range_where_sql returns empty when both bounds None.""" - out = Caster._datetime_range_where_sql(None, None, "dt") + """datetime_range_where_sql returns empty when both bounds None.""" + out = datetime_range_where_sql(None, None, "dt") assert out == "" def test_datetime_range_where_sql_both_bounds(): - """_datetime_range_where_sql produces [after, before) with AND.""" - out = Caster._datetime_range_where_sql( + """datetime_range_where_sql produces [after, before) with AND.""" + out = datetime_range_where_sql( "2020-01-01", "2020-12-31", "created_at", @@ -43,20 +44,20 @@ def test_datetime_range_where_sql_both_bounds(): def test_datetime_range_where_sql_only_after(): - """_datetime_range_where_sql with only datetime_after.""" - out = Caster._datetime_range_where_sql("2020-06-01", None, "dt") + """datetime_range_where_sql with only datetime_after.""" + out = datetime_range_where_sql("2020-06-01", None, "dt") assert out == "\"dt\" >= '2020-06-01'" def test_datetime_range_where_sql_only_before(): - """_datetime_range_where_sql with only datetime_before.""" - out = Caster._datetime_range_where_sql(None, "2021-01-01", "ts") + """datetime_range_where_sql with only datetime_before.""" + out = datetime_range_where_sql(None, "2021-01-01", "ts") assert out == "\"ts\" < '2021-01-01'" def test_datetime_range_where_sql_iso_format(): - """_datetime_range_where_sql accepts ISO datetime strings.""" - out = Caster._datetime_range_where_sql( + """datetime_range_where_sql accepts ISO datetime strings.""" + out = datetime_range_where_sql( "2020-01-15T10:30:00", "2020-01-15T18:00:00", "updated_at", @@ -68,19 +69,18 @@ def test_datetime_range_where_sql_iso_format(): def test_sql_query_where_combines_pattern_and_ingestion_datetime(): """Query WHERE combines TablePattern date_filter and ingestion datetime range.""" - # Simulate the logic in _register_sql_table_sources: pattern WHERE + datetime WHERE pattern = TablePattern( table_name="events", date_field="dt", 
date_filter="!= '2020-01-01'", ) pattern_where = pattern.build_where_clause() - datetime_where = Caster._datetime_range_where_sql( + dt_where = datetime_range_where_sql( "2020-06-01", "2020-07-01", pattern.date_field or "dt", ) - where_parts = [p for p in [pattern_where, datetime_where] if p] + where_parts = [p for p in [pattern_where, dt_where] if p] combined = " AND ".join(where_parts) assert "\"dt\" != '2020-01-01'" in combined assert "\"dt\" >= '2020-06-01'" in combined diff --git a/uv.lock b/uv.lock index cac297b7..de4ee54e 100644 --- a/uv.lock +++ b/uv.lock @@ -348,7 +348,7 @@ wheels = [ [[package]] name = "graflo" -version = "1.5.0" +version = "1.6.0" source = { editable = "." } dependencies = [ { name = "click" },