From 6eae543d68360b9ec7b2b5eb117770f6981e47c3 Mon Sep 17 00:00:00 2001 From: paulnoirel <87332996+paulnoirel@users.noreply.github.com> Date: Wed, 31 Dec 2025 14:09:01 +0000 Subject: [PATCH 1/2] Remove input tests Format updated by linter --- .../schema/workflow/workflow_utils.py | 25 -------- .../tests/integration/test_workflow.py | 34 ++++++++++ .../unit/test_workflow_utils_validation.py | 64 +++++++++++++++++++ 3 files changed, 98 insertions(+), 25 deletions(-) create mode 100644 libs/labelbox/tests/unit/test_workflow_utils_validation.py diff --git a/libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py b/libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py index bd2ca0ca0..1e71a576f 100644 --- a/libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py +++ b/libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py @@ -117,31 +117,6 @@ def validate_node_connections( "node_type": node_type, } ) - elif len(predecessors) > 1: - # Check if all predecessors are initial nodes - node_map = {n.id: n for n in nodes} - predecessor_nodes = [ - node_map.get(pred_id) for pred_id in predecessors - ] - all_initial = all( - pred_node - and pred_node.definition_id in initial_node_types - for pred_node in predecessor_nodes - if pred_node is not None - ) - - if not all_initial: - preds_info = ", ".join( - [p[:8] + "..." for p in predecessors] - ) - errors.append( - { - "reason": f"has multiple incoming connections ({len(predecessors)}) but not all are from initial nodes", - "node_id": node.id, - "node_type": node_type, - "details": f"Connected from: {preds_info}", - } - ) # Check outgoing connections (except terminal nodes) if node.definition_id not in terminal_node_types: diff --git a/libs/labelbox/tests/integration/test_workflow.py b/libs/labelbox/tests/integration/test_workflow.py index 96cb53b46..f09cac2a4 100644 --- a/libs/labelbox/tests/integration/test_workflow.py +++ b/libs/labelbox/tests/integration/test_workflow.py @@ -96,6 +96,40 @@ def test_workflow_creation(client, test_projects): assert WorkflowDefinitionId.Done in node_types +def test_workflow_allows_multiple_incoming_from_non_initial_nodes( + client, test_projects +): + """ + Nodes may have multiple incoming connections from any nodes (not only initial nodes). + + This used to fail validation when a node had >1 predecessor and at least one + predecessor was not an initial node. + """ + source_project, _ = test_projects + + workflow = source_project.get_workflow() + initial_nodes = workflow.reset_to_initial_nodes( + labeling_config=LabelingConfig(instructions="Start labeling here") + ) + + logic = workflow.add_node( + type=NodeType.Logic, + name="Gate", + filters=ProjectWorkflowFilter([labeled_by.is_one_of(["test-user"])]), + ) + review = workflow.add_node(type=NodeType.Review, name="Review Task") + done = workflow.add_node(type=NodeType.Done, name="Done") + + # Multiple incoming connections to review, including from a non-initial node (logic) + workflow.add_edge(initial_nodes.labeling, logic) + workflow.add_edge(logic, review, NodeOutput.If) + workflow.add_edge(initial_nodes.rework, review) + workflow.add_edge(review, done, NodeOutput.Approved) + + # Should validate and update successfully + workflow.update_config(reposition=False) + + def test_workflow_creation_simple(client): """Test creating a simple workflow with the working pattern.""" # Create a new project for this test diff --git a/libs/labelbox/tests/unit/test_workflow_utils_validation.py b/libs/labelbox/tests/unit/test_workflow_utils_validation.py new file mode 100644 index 000000000..6e608faa6 --- /dev/null +++ b/libs/labelbox/tests/unit/test_workflow_utils_validation.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass + +from labelbox.schema.workflow.enums import WorkflowDefinitionId +from labelbox.schema.workflow.graph import ProjectWorkflowGraph +from labelbox.schema.workflow.workflow_utils import WorkflowValidator + + +@dataclass(frozen=True) +class _Node: + id: str + definition_id: WorkflowDefinitionId + + +def test_validate_node_connections_allows_multiple_incoming_from_non_initial_nodes(): + """ + Regression test: nodes may have multiple incoming connections from any nodes. + + Historically validation required that if a node had >1 predecessors, they all had + to be initial nodes. Workflow Management now allows multi-input nodes from any + nodes, so this must not error. + """ + initial_labeling = _Node( + id="initial_labeling", + definition_id=WorkflowDefinitionId.InitialLabelingTask, + ) + initial_rework = _Node( + id="initial_rework", + definition_id=WorkflowDefinitionId.InitialReworkTask, + ) + logic = _Node(id="logic", definition_id=WorkflowDefinitionId.Logic) + review = _Node(id="review", definition_id=WorkflowDefinitionId.ReviewTask) + done = _Node(id="done", definition_id=WorkflowDefinitionId.Done) + + nodes = [initial_labeling, initial_rework, logic, review, done] + + graph = ProjectWorkflowGraph() + graph.add_edge(initial_labeling.id, logic.id) + graph.add_edge(logic.id, review.id) + graph.add_edge(initial_rework.id, review.id) + graph.add_edge(review.id, done.id) + + errors = WorkflowValidator.validate_node_connections(nodes, graph) + assert errors == [] + + +def test_validate_node_connections_still_flags_missing_incoming_connections(): + """Non-initial nodes must still have at least one incoming connection.""" + initial_labeling = _Node( + id="initial_labeling", + definition_id=WorkflowDefinitionId.InitialLabelingTask, + ) + review = _Node(id="review", definition_id=WorkflowDefinitionId.ReviewTask) + done = _Node(id="done", definition_id=WorkflowDefinitionId.Done) + + nodes = [initial_labeling, review, done] + graph = ProjectWorkflowGraph() + graph.add_edge(initial_labeling.id, done.id) + + errors = WorkflowValidator.validate_node_connections(nodes, graph) + assert any( + e.get("node_id") == review.id + and e.get("reason") == "has no incoming connections" + for e in errors + ) From 9b588c167ab5a2067e11e2c7b4e3e8b793ee2735 Mon Sep 17 00:00:00 2001 From: paulnoirel <87332996+paulnoirel@users.noreply.github.com> Date: Wed, 31 Dec 2025 14:32:35 +0000 Subject: [PATCH 2/2] Fix lint --- libs/labelbox/src/labelbox/client.py | 10 +++--- libs/labelbox/src/labelbox/schema/project.py | 36 ++++++++++++-------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 0d8c113a3..60fb8016d 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -507,16 +507,16 @@ def delete_project_memberships( self, project_id: str, user_ids: list[str] ) -> dict: """Deletes project memberships for one or more users. - + Args: project_id (str): ID of the project user_ids (list[str]): List of user IDs to remove from the project - + Returns: dict: Result containing: - success (bool): True if operation succeeded - errorMessage (str or None): Error message if operation failed - + Example: >>> result = client.delete_project_memberships( >>> project_id="project123", @@ -539,12 +539,12 @@ def delete_project_memberships( errorMessage } }""" - + params = { "projectId": project_id, "userIds": user_ids, } - + result = self.execute(mutation, params) return result["deleteProjectMemberships"] diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index f00a75cb2..60d6b6258 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -317,7 +317,9 @@ def get_resource_tags(self) -> List[ResourceTag]: return [ResourceTag(self.client, tag) for tag in results] - def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedCollection: + def labels( + self, datasets=None, order_by=None, created_by=None + ) -> PaginatedCollection: """Custom relationship expansion method to support limited filtering. Args: @@ -334,7 +336,7 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl Example: >>> # Get all labels >>> all_labels = project.labels() - >>> + >>> >>> # Get labels by specific user >>> user_labels = project.labels(created_by=user_id) >>> # or @@ -351,16 +353,22 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl # Build where clause where_clauses = [] - + if datasets is not None: - dataset_ids = ", ".join('"%s"' % dataset.uid for dataset in datasets) - where_clauses.append(f"dataRow: {{dataset: {{id_in: [{dataset_ids}]}}}}") - + dataset_ids = ", ".join( + '"%s"' % dataset.uid for dataset in datasets + ) + where_clauses.append( + f"dataRow: {{dataset: {{id_in: [{dataset_ids}]}}}}" + ) + if created_by is not None: # Handle both User object and user_id string - user_id = created_by.uid if hasattr(created_by, 'uid') else created_by + user_id = ( + created_by.uid if hasattr(created_by, "uid") else created_by + ) where_clauses.append(f'createdBy: {{id: "{user_id}"}}') - + if where_clauses: where = " where:{" + ", ".join(where_clauses) + "}" else: @@ -396,7 +404,7 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl def delete_labels_by_user(self, user_id: str) -> int: """Soft deletes all labels created by a specific user in this project. - + This performs a soft delete (sets deleted=true in the database). The labels will no longer appear in queries but remain in the database. Labels are deleted in chunks of 500 to avoid overwhelming the API. @@ -413,18 +421,18 @@ def delete_labels_by_user(self, user_id: str) -> int: >>> print(f"Deleted {deleted_count} labels") """ labels_to_delete = list(self.labels(created_by=user_id)) - + if not labels_to_delete: return 0 - + chunk_size = 500 total_deleted = 0 - + for i in range(0, len(labels_to_delete), chunk_size): - chunk = labels_to_delete[i:i + chunk_size] + chunk = labels_to_delete[i : i + chunk_size] Entity.Label.bulk_delete(chunk) total_deleted += len(chunk) - + return total_deleted def export(