diff --git a/libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py b/libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py index bd2ca0ca0..1e71a576f 100644 --- a/libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py +++ b/libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py @@ -117,31 +117,6 @@ def validate_node_connections( "node_type": node_type, } ) - elif len(predecessors) > 1: - # Check if all predecessors are initial nodes - node_map = {n.id: n for n in nodes} - predecessor_nodes = [ - node_map.get(pred_id) for pred_id in predecessors - ] - all_initial = all( - pred_node - and pred_node.definition_id in initial_node_types - for pred_node in predecessor_nodes - if pred_node is not None - ) - - if not all_initial: - preds_info = ", ".join( - [p[:8] + "..." for p in predecessors] - ) - errors.append( - { - "reason": f"has multiple incoming connections ({len(predecessors)}) but not all are from initial nodes", - "node_id": node.id, - "node_type": node_type, - "details": f"Connected from: {preds_info}", - } - ) # Check outgoing connections (except terminal nodes) if node.definition_id not in terminal_node_types: diff --git a/libs/labelbox/tests/integration/test_workflow.py b/libs/labelbox/tests/integration/test_workflow.py index 96cb53b46..f09cac2a4 100644 --- a/libs/labelbox/tests/integration/test_workflow.py +++ b/libs/labelbox/tests/integration/test_workflow.py @@ -96,6 +96,40 @@ def test_workflow_creation(client, test_projects): assert WorkflowDefinitionId.Done in node_types +def test_workflow_allows_multiple_incoming_from_non_initial_nodes( + client, test_projects +): + """ + Nodes may have multiple incoming connections from any nodes (not only initial nodes). + + This used to fail validation when a node had >1 predecessor and at least one + predecessor was not an initial node. + """ + source_project, _ = test_projects + + workflow = source_project.get_workflow() + initial_nodes = workflow.reset_to_initial_nodes( + labeling_config=LabelingConfig(instructions="Start labeling here") + ) + + logic = workflow.add_node( + type=NodeType.Logic, + name="Gate", + filters=ProjectWorkflowFilter([labeled_by.is_one_of(["test-user"])]), + ) + review = workflow.add_node(type=NodeType.Review, name="Review Task") + done = workflow.add_node(type=NodeType.Done, name="Done") + + # Multiple incoming connections to review, including from a non-initial node (logic) + workflow.add_edge(initial_nodes.labeling, logic) + workflow.add_edge(logic, review, NodeOutput.If) + workflow.add_edge(initial_nodes.rework, review) + workflow.add_edge(review, done, NodeOutput.Approved) + + # Should validate and update successfully + workflow.update_config(reposition=False) + + def test_workflow_creation_simple(client): """Test creating a simple workflow with the working pattern.""" # Create a new project for this test diff --git a/libs/labelbox/tests/unit/test_workflow_utils_validation.py b/libs/labelbox/tests/unit/test_workflow_utils_validation.py new file mode 100644 index 000000000..6e608faa6 --- /dev/null +++ b/libs/labelbox/tests/unit/test_workflow_utils_validation.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass + +from labelbox.schema.workflow.enums import WorkflowDefinitionId +from labelbox.schema.workflow.graph import ProjectWorkflowGraph +from labelbox.schema.workflow.workflow_utils import WorkflowValidator + + +@dataclass(frozen=True) +class _Node: + id: str + definition_id: WorkflowDefinitionId + + +def test_validate_node_connections_allows_multiple_incoming_from_non_initial_nodes(): + """ + Regression test: nodes may have multiple incoming connections from any nodes. + + Historically validation required that if a node had >1 predecessors, they all had + to be initial nodes. Workflow Management now allows multi-input nodes from any + nodes, so this must not error. + """ + initial_labeling = _Node( + id="initial_labeling", + definition_id=WorkflowDefinitionId.InitialLabelingTask, + ) + initial_rework = _Node( + id="initial_rework", + definition_id=WorkflowDefinitionId.InitialReworkTask, + ) + logic = _Node(id="logic", definition_id=WorkflowDefinitionId.Logic) + review = _Node(id="review", definition_id=WorkflowDefinitionId.ReviewTask) + done = _Node(id="done", definition_id=WorkflowDefinitionId.Done) + + nodes = [initial_labeling, initial_rework, logic, review, done] + + graph = ProjectWorkflowGraph() + graph.add_edge(initial_labeling.id, logic.id) + graph.add_edge(logic.id, review.id) + graph.add_edge(initial_rework.id, review.id) + graph.add_edge(review.id, done.id) + + errors = WorkflowValidator.validate_node_connections(nodes, graph) + assert errors == [] + + +def test_validate_node_connections_still_flags_missing_incoming_connections(): + """Non-initial nodes must still have at least one incoming connection.""" + initial_labeling = _Node( + id="initial_labeling", + definition_id=WorkflowDefinitionId.InitialLabelingTask, + ) + review = _Node(id="review", definition_id=WorkflowDefinitionId.ReviewTask) + done = _Node(id="done", definition_id=WorkflowDefinitionId.Done) + + nodes = [initial_labeling, review, done] + graph = ProjectWorkflowGraph() + graph.add_edge(initial_labeling.id, done.id) + + errors = WorkflowValidator.validate_node_connections(nodes, graph) + assert any( + e.get("node_id") == review.id + and e.get("reason") == "has no incoming connections" + for e in errors + )