Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions libs/labelbox/src/labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,16 +507,16 @@ def delete_project_memberships(
self, project_id: str, user_ids: list[str]
) -> dict:
"""Deletes project memberships for one or more users.

Args:
project_id (str): ID of the project
user_ids (list[str]): List of user IDs to remove from the project

Returns:
dict: Result containing:
- success (bool): True if operation succeeded
- errorMessage (str or None): Error message if operation failed

Example:
>>> result = client.delete_project_memberships(
>>> project_id="project123",
Expand All @@ -539,12 +539,12 @@ def delete_project_memberships(
errorMessage
}
}"""

params = {
"projectId": project_id,
"userIds": user_ids,
}

result = self.execute(mutation, params)
return result["deleteProjectMemberships"]

Expand Down
36 changes: 22 additions & 14 deletions libs/labelbox/src/labelbox/schema/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ def get_resource_tags(self) -> List[ResourceTag]:

return [ResourceTag(self.client, tag) for tag in results]

def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedCollection:
def labels(
self, datasets=None, order_by=None, created_by=None
) -> PaginatedCollection:
"""Custom relationship expansion method to support limited filtering.

Args:
Expand All @@ -334,7 +336,7 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl
Example:
>>> # Get all labels
>>> all_labels = project.labels()
>>>
>>>
>>> # Get labels by specific user
>>> user_labels = project.labels(created_by=user_id)
>>> # or
Expand All @@ -351,16 +353,22 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl

# Build where clause
where_clauses = []

if datasets is not None:
dataset_ids = ", ".join('"%s"' % dataset.uid for dataset in datasets)
where_clauses.append(f"dataRow: {{dataset: {{id_in: [{dataset_ids}]}}}}")

dataset_ids = ", ".join(
'"%s"' % dataset.uid for dataset in datasets
)
where_clauses.append(
f"dataRow: {{dataset: {{id_in: [{dataset_ids}]}}}}"
)

if created_by is not None:
# Handle both User object and user_id string
user_id = created_by.uid if hasattr(created_by, 'uid') else created_by
user_id = (
created_by.uid if hasattr(created_by, "uid") else created_by
)
where_clauses.append(f'createdBy: {{id: "{user_id}"}}')

if where_clauses:
where = " where:{" + ", ".join(where_clauses) + "}"
else:
Expand Down Expand Up @@ -396,7 +404,7 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl

def delete_labels_by_user(self, user_id: str) -> int:
"""Soft deletes all labels created by a specific user in this project.

This performs a soft delete (sets deleted=true in the database).
The labels will no longer appear in queries but remain in the database.
Labels are deleted in chunks of 500 to avoid overwhelming the API.
Expand All @@ -413,18 +421,18 @@ def delete_labels_by_user(self, user_id: str) -> int:
>>> print(f"Deleted {deleted_count} labels")
"""
labels_to_delete = list(self.labels(created_by=user_id))

if not labels_to_delete:
return 0

chunk_size = 500
total_deleted = 0

for i in range(0, len(labels_to_delete), chunk_size):
chunk = labels_to_delete[i:i + chunk_size]
chunk = labels_to_delete[i : i + chunk_size]
Entity.Label.bulk_delete(chunk)
total_deleted += len(chunk)

return total_deleted

def export(
Expand Down
25 changes: 0 additions & 25 deletions libs/labelbox/src/labelbox/schema/workflow/workflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,31 +117,6 @@ def validate_node_connections(
"node_type": node_type,
}
)
elif len(predecessors) > 1:
# Check if all predecessors are initial nodes
node_map = {n.id: n for n in nodes}
predecessor_nodes = [
node_map.get(pred_id) for pred_id in predecessors
]
all_initial = all(
pred_node
and pred_node.definition_id in initial_node_types
for pred_node in predecessor_nodes
if pred_node is not None
)

if not all_initial:
preds_info = ", ".join(
[p[:8] + "..." for p in predecessors]
)
errors.append(
{
"reason": f"has multiple incoming connections ({len(predecessors)}) but not all are from initial nodes",
"node_id": node.id,
"node_type": node_type,
"details": f"Connected from: {preds_info}",
}
)

# Check outgoing connections (except terminal nodes)
if node.definition_id not in terminal_node_types:
Expand Down
34 changes: 34 additions & 0 deletions libs/labelbox/tests/integration/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,40 @@ def test_workflow_creation(client, test_projects):
assert WorkflowDefinitionId.Done in node_types


def test_workflow_allows_multiple_incoming_from_non_initial_nodes(
    client, test_projects
):
    """
    A node may receive several incoming edges, even from non-initial nodes.

    Validation previously rejected any node with more than one predecessor
    unless every predecessor was an initial node; this test wires a review
    node to both a logic node and the initial rework node and expects the
    workflow update to go through.
    """
    source_project, _unused_target = test_projects
    workflow = source_project.get_workflow()

    start_nodes = workflow.reset_to_initial_nodes(
        labeling_config=LabelingConfig(instructions="Start labeling here")
    )

    gate_filters = ProjectWorkflowFilter([labeled_by.is_one_of(["test-user"])])
    gate_node = workflow.add_node(
        type=NodeType.Logic, name="Gate", filters=gate_filters
    )
    review_node = workflow.add_node(type=NodeType.Review, name="Review Task")
    done_node = workflow.add_node(type=NodeType.Done, name="Done")

    # The review node ends up with two predecessors: the logic gate (a
    # non-initial node) and the initial rework node. Edges are added in the
    # order listed; entries without an explicit output use add_edge's default.
    edge_specs = (
        (start_nodes.labeling, gate_node),
        (gate_node, review_node, NodeOutput.If),
        (start_nodes.rework, review_node),
        (review_node, done_node, NodeOutput.Approved),
    )
    for edge_args in edge_specs:
        workflow.add_edge(*edge_args)

    # Must pass validation and persist without raising.
    workflow.update_config(reposition=False)


def test_workflow_creation_simple(client):
"""Test creating a simple workflow with the working pattern."""
# Create a new project for this test
Expand Down
64 changes: 64 additions & 0 deletions libs/labelbox/tests/unit/test_workflow_utils_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from dataclasses import dataclass

from labelbox.schema.workflow.enums import WorkflowDefinitionId
from labelbox.schema.workflow.graph import ProjectWorkflowGraph
from labelbox.schema.workflow.workflow_utils import WorkflowValidator


@dataclass(frozen=True)
class _Node:
    """Minimal immutable stand-in for a workflow node, used only by these tests.

    It carries just the two attributes the tests populate. Presumably these
    are the only node attributes the validator reads -- confirm against
    ``WorkflowValidator.validate_node_connections``.
    """

    # Node identifier; the tests also pass it to the graph as the vertex key.
    id: str
    # Workflow definition type of the node (a WorkflowDefinitionId member).
    definition_id: WorkflowDefinitionId


def test_validate_node_connections_allows_multiple_incoming_from_non_initial_nodes():
    """
    Regression test: multi-input nodes are valid whatever their predecessors are.

    The validator used to demand that a node with more than one predecessor
    be fed exclusively by initial nodes. Here the review node is fed by a
    logic node and by the initial rework node, and no error may be reported.
    """
    labeling_start = _Node(
        id="initial_labeling",
        definition_id=WorkflowDefinitionId.InitialLabelingTask,
    )
    rework_start = _Node(
        id="initial_rework",
        definition_id=WorkflowDefinitionId.InitialReworkTask,
    )
    gate = _Node(id="logic", definition_id=WorkflowDefinitionId.Logic)
    reviewer = _Node(id="review", definition_id=WorkflowDefinitionId.ReviewTask)
    finish = _Node(id="done", definition_id=WorkflowDefinitionId.Done)

    # (source, target) pairs; the review node has two incoming edges.
    connections = (
        (labeling_start, gate),
        (gate, reviewer),
        (rework_start, reviewer),
        (reviewer, finish),
    )
    workflow_graph = ProjectWorkflowGraph()
    for source, target in connections:
        workflow_graph.add_edge(source.id, target.id)

    all_nodes = [labeling_start, rework_start, gate, reviewer, finish]
    assert (
        WorkflowValidator.validate_node_connections(all_nodes, workflow_graph)
        == []
    )


def test_validate_node_connections_still_flags_missing_incoming_connections():
    """A non-initial node with no incoming edge must still be reported."""
    labeling_start = _Node(
        id="initial_labeling",
        definition_id=WorkflowDefinitionId.InitialLabelingTask,
    )
    orphan_review = _Node(
        id="review", definition_id=WorkflowDefinitionId.ReviewTask
    )
    finish = _Node(id="done", definition_id=WorkflowDefinitionId.Done)

    # Only labeling -> done is connected, so the review node has no predecessor.
    workflow_graph = ProjectWorkflowGraph()
    workflow_graph.add_edge(labeling_start.id, finish.id)

    reported = WorkflowValidator.validate_node_connections(
        [labeling_start, orphan_review, finish], workflow_graph
    )

    # Reduce each error to the two fields under test and look for the orphan.
    flagged = [(err.get("node_id"), err.get("reason")) for err in reported]
    assert (orphan_review.id, "has no incoming connections") in flagged
Loading