diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py new file mode 100644 index 000000000..b65e9cdbc --- /dev/null +++ b/agentlightning/adapter/annotation.py @@ -0,0 +1,551 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Find and repair the annotations from spans.""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, TypeVar, cast + +from opentelemetry.semconv.attributes import exception_attributes + +from agentlightning.adapter.preprocess import default_span_order +from agentlightning.emitter.message import get_message_value +from agentlightning.emitter.object import get_object_value +from agentlightning.emitter.reward import get_rewards_from_span +from agentlightning.semconv import ( + AGL_ANNOTATION, + AGL_EXCEPTION, + AGL_MESSAGE, + AGL_OBJECT, + AGL_OPERATION, + LightningSpanAttributes, + LinkPydanticModel, +) +from agentlightning.types.adapter import ( + AdaptingSpan, + AgentAnnotation, + Annotation, + BaseAdaptingSequence, + ExceptionAnnotation, + GeneralAnnotation, + MessageAnnotation, + ObjectAnnotation, + OperationAnnotation, + Tree, +) +from agentlightning.types.tracer import Span +from agentlightning.utils.otel import ( + check_linked_span, + extract_links_from_attributes, + extract_tags_from_attributes, + filter_and_unflatten_attributes, +) + +from .base import SequenceAdapter + +T_SpanSequence = TypeVar("T_SpanSequence", bound=Sequence[Span]) + +logger = logging.getLogger(__name__) + + +class IdentifyAnnotations(SequenceAdapter[AdaptingSpan, AdaptingSpan]): + """Identify and parse annotation data from spans based on span name conventions. + + This adapter inspects each span's name to determine if it represents a known + annotation type (general, message, object, exception, operation) or an agent span. 
+ When identified, the span's data field is populated with the corresponding annotation + model containing extracted attributes. + + Supported annotation types: + + - `AGL_ANNOTATION`: General annotations with rewards, tags, and custom fields. + - `AGL_MESSAGE`: Message annotations containing a message body. + - `AGL_OBJECT`: Object annotations containing serialized JSON or literal values. + - `AGL_EXCEPTION`: Exception annotations with type, message, and stacktrace. + - `AGL_OPERATION`: Operation annotations with name, input, and output. + - Agent spans: Detected via heuristics for various agent frameworks. + """ + + def _filter_custom_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]: + reserved_fields = [attr.value for attr in LightningSpanAttributes if attr.value in attributes] + return { + key: value + for key, value in attributes.items() + if not any( + # Filter out those that are reserved fields or start with reserved fields (plus ".") + key == reserved_field or key.startswith(reserved_field + ".") + for reserved_field in reserved_fields + ) + } + + def extract_links(self, span: Span) -> Sequence[LinkPydanticModel]: + """Extract link specifications from span attributes. + + Args: + span: The span to extract links from. + + Returns: + A sequence of link models. Returns an empty list if no links are found + or if the link attributes are malformed. + """ + try: + return extract_links_from_attributes(span.attributes) + except Exception as exc: + logger.error(f"Link is malformed for span {span.span_id}: {exc}") + return [] + + def identify_general(self, span: Span) -> Optional[GeneralAnnotation]: + """Parse a general annotation span into a `GeneralAnnotation` model. + + Extracts rewards, tags, links, and custom fields from the span attributes. + + Args: + span: A span with name `AGL_ANNOTATION`. + + Returns: + A `GeneralAnnotation` with extracted data, or None if parsing fails. 
+ """ + rewards = get_rewards_from_span(span) + primary_reward = rewards[0].value if rewards else None + return GeneralAnnotation( + annotation_type="general", + links=self.extract_links(span), + rewards=rewards, + primary_reward=primary_reward, + tags=extract_tags_from_attributes(span.attributes), + custom_fields=self._filter_custom_attributes(span.attributes), + ) + + def identify_message(self, span: Span) -> Optional[MessageAnnotation]: + """Parse a message span into a `MessageAnnotation` model. + + Args: + span: A span with name `AGL_MESSAGE`. + + Returns: + A `MessageAnnotation` containing the message body, or None if the + message body attribute is missing. + """ + msg_body = get_message_value(span) + if msg_body is None: + logger.warning(f"Message body is missing for message span {span.span_id}") + return None + + return MessageAnnotation( + annotation_type="message", + links=self.extract_links(span), + message=msg_body, + ) + + def identify_object(self, span: Span) -> Optional[ObjectAnnotation]: + """Parse an object span into an `ObjectAnnotation` model. + + Supports both JSON-serialized objects and literal values. + + Args: + span: A span with name `AGL_OBJECT`. + + Returns: + An `ObjectAnnotation` containing the deserialized object, or None if + deserialization fails. + """ + try: + obj_value = get_object_value(span) + except Exception as exc: + logger.error(f"Fail to deserialize object for object span {span.span_id}: {exc}") + return None + + return ObjectAnnotation( + annotation_type="object", + links=self.extract_links(span), + object=obj_value, + ) + + def identify_exception(self, span: Span) -> Optional[ExceptionAnnotation]: + """Parse an exception span into an `ExceptionAnnotation` model. + + Uses OpenTelemetry semantic conventions for exception attributes. + + Args: + span: A span with name `AGL_EXCEPTION`. + + Returns: + An `ExceptionAnnotation` containing exception type, message, and stacktrace. 
+ Missing fields default to "UnknownException" for type and empty string for others. + """ + exception_type = span.attributes.get(exception_attributes.EXCEPTION_TYPE, "UnknownException") + exception_message = span.attributes.get(exception_attributes.EXCEPTION_MESSAGE, "") + exception_stacktrace = span.attributes.get(exception_attributes.EXCEPTION_STACKTRACE, "") + + return ExceptionAnnotation( + annotation_type="exception", + links=self.extract_links(span), + type=str(exception_type), + message=str(exception_message), + stacktrace=str(exception_stacktrace), + ) + + def identify_operation(self, span: Span) -> Optional[OperationAnnotation]: + """Parse an operation span into an `OperationAnnotation` model. + + Extracts operation name, input, and output. Input/output can be either + direct values or nested structures reconstructed from flattened attributes. + + Args: + span: A span with name `AGL_OPERATION`. + + Returns: + An `OperationAnnotation` containing operation details, or None if + attribute unpacking fails. 
+ """ + try: + operation_name = span.attributes.get(LightningSpanAttributes.OPERATION_NAME.value, "UnknownOperation") + if LightningSpanAttributes.OPERATION_INPUT.value in span.attributes: + operation_input = span.attributes[LightningSpanAttributes.OPERATION_INPUT.value] + else: + operation_input = filter_and_unflatten_attributes( + span.attributes, LightningSpanAttributes.OPERATION_INPUT.value + ) + if LightningSpanAttributes.OPERATION_OUTPUT.value in span.attributes: + operation_output = span.attributes[LightningSpanAttributes.OPERATION_OUTPUT.value] + else: + operation_output = filter_and_unflatten_attributes( + span.attributes, LightningSpanAttributes.OPERATION_OUTPUT.value + ) + except Exception as exc: + logger.error(f"Fail to unpack operation context for operation span {span.span_id}: {exc}") + return None + + return OperationAnnotation( + annotation_type="operation", + links=self.extract_links(span), + name=str(operation_name), + input=operation_input, + output=operation_output, + ) + + def extract_agent_id(self, span: Span) -> Optional[str]: + """Extract agent ID from span attributes. + + Args: + span: The span to extract the agent ID from. + + Returns: + The agent ID if found, None otherwise. + """ + # TODO: Support agent id in other formats + return cast(Optional[str], span.attributes.get("agent.id")) + + def extract_agent_description(self, span: Span) -> Optional[str]: + """Extract agent description from span attributes. + + Args: + span: The span to extract the agent description from. + + Returns: + The agent description if found, None otherwise. + """ + # TODO: Support agent description in other formats + return cast(Optional[str], span.attributes.get("agent.description")) + + def extract_agent_name(self, span: Span) -> Optional[str]: + """Extract agent name from span attributes using framework-specific heuristics. + + Supports multiple agent frameworks by checking various attribute patterns: + + 1. OpenTelemetry agent spans (`agent.name`) + 2. 
AgentOps decorated agents (`agentops.span.kind` + `operation.name`) + 3. Autogen teams (`recipient_agent_type`) + 4. LangGraph (`langchain.chain.type`) + 5. agent-framework (`executor.id`) + 6. Weave (`type` == "agent" + `agentlightning.operation.input.name`) + 7. Weave + LangChain (`langchain.Chain.*` span names + `lc_name`) + + Args: + span: The span to extract the agent name from. + + Returns: + The agent name if detected via any supported pattern, None otherwise. + """ + # Case 1: OpenTelemetry Agent Spans + agent_name = cast(Optional[str], span.attributes.get("agent.name")) + if agent_name is not None: + return agent_name + + # Case 2: Agentops decorator @agent + is_agent = span.attributes.get("agentops.span.kind") == "agent" + if is_agent: + agent_name = cast(Optional[str], span.attributes.get("operation.name")) + if agent_name is not None: + return agent_name + + # Case 3: Autogen team + agent_name = cast(Optional[str], span.attributes.get("recipient_agent_type")) + if agent_name is not None: + return agent_name + + # Case 4: LangGraph + agent_name = cast(Optional[str], span.attributes.get("langchain.chain.type")) + if agent_name is not None: + return agent_name + + # Case 5: agent-framework + agent_name = cast(Optional[str], span.attributes.get("executor.id")) + if agent_name is not None: + return agent_name + + # Case 6: Weave + is_agent_type = span.attributes.get("type") == "agent" + if is_agent_type: + agent_name = cast(Optional[str], span.attributes.get("agentlightning.operation.input.name")) + if agent_name is not None: + return agent_name + + # Case 7: Weave + LangChain + if span.name.startswith("langchain.Chain."): + attributes_lc_name = cast(Optional[str], span.attributes.get("lc_name")) + if attributes_lc_name is not None: + return attributes_lc_name + + return None + + def detect_agent_annotation(self, span: Span) -> Optional[AgentAnnotation]: + """Detect and create an agent annotation from span attributes. 
+ + Uses heuristics to identify spans representing agent executions from + various frameworks (OpenTelemetry, AgentOps, Autogen, LangGraph, etc.). + + Args: + span: The span to check for agent indicators. + + Returns: + An `AgentAnnotation` if an agent is detected, None otherwise. + """ + agent_id = self.extract_agent_id(span) + agent_name = self.extract_agent_name(span) + agent_description = self.extract_agent_description(span) + + if agent_name is not None: + return AgentAnnotation( + annotation_type="agent", + links=self.extract_links(span), + id=agent_id, + name=agent_name, + description=agent_description, + ) + return None + + def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: + """Process a single span to identify and attach annotation data. + + Checks the span name against known annotation types and parses the + corresponding annotation model. Falls back to agent detection for + unrecognized span names. + + Args: + source: The span to process. + + Returns: + The span with annotation data attached if identified, otherwise + the original span unchanged. + """ + annotation: Optional[Annotation] = None + if source.name == AGL_ANNOTATION: + annotation = self.identify_general(source) + elif source.name == AGL_MESSAGE: + annotation = self.identify_message(source) + elif source.name == AGL_OBJECT: + annotation = self.identify_object(source) + elif source.name == AGL_EXCEPTION: + annotation = self.identify_exception(source) + elif source.name == AGL_OPERATION: + annotation = self.identify_operation(source) + else: + # Fallback to agent annotation detection + annotation = self.detect_agent_annotation(source) + if annotation is not None: + return source.with_data(annotation) + else: + return source + + +class SelectByAnnotation(SequenceAdapter[AdaptingSpan, AdaptingSpan]): + """Select the corresponding spans within the annotation sequence, as well as their linked spans + (and subtree spans if applicable). 
+ + The effective radius of an annotation is as follows: + + - If the annotation has links, it applies to the linked spans only. + - If the annotation is on a tree node, it applies to all spans in its subtree. + - If the annotation has neither links nor tree nodes, it applies to only itself. + + The adapter either selects the union of the effective radius of all annotations, + or excludes the union of effective radius. + + When the source is a tree, to avoid the tree nodes from becoming fragmented, + the adapter will also include the ancestors of the tree nodes in "include" mode. + + Args: + mode: "include" to select spans within the annotations; "exclude" to exclude them. + """ + + def __init__(self, mode: Literal["include", "exclude"]) -> None: + self.mode = mode + + def _filter_linked_spans(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Iterable[AdaptingSpan]: + annotation_spans = [span for span in source if isinstance(span.data, Annotation)] + annotation_span_ids = set(annotation_span.span_id for annotation_span in annotation_spans) + annotation_links = [cast(Annotation, span.data).links for span in annotation_spans] + for span in source: + if span.span_id in annotation_span_ids: + yield span + # Only check non-empty link lists; empty links means the annotation applies only to itself + # (check_linked_span returns True for empty links, which would incorrectly match all spans) + elif any(links and check_linked_span(span, links) for links in annotation_links): + yield span + # ignore the current span for now + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: + """Filter spans based on annotation membership and links. + + Args: + source: The span sequence to filter. + + Returns: + A filtered sequence containing only annotated spans and their linked + spans (include mode), or all spans except those (exclude mode). 
+ """ + linked_spans = list(self._filter_linked_spans(source)) + if self.mode == "include": + return source.retain(lambda span: span in linked_spans) + else: + # prune removes items where predicate is True, so we remove linked spans + return source.prune(lambda span: span in linked_spans) + + +class RepairMissingLinks(SequenceAdapter[AdaptingSpan, AdaptingSpan]): + """Populate missing annotation links by searching nearby spans. + + This adapter scans annotations and, for any annotation that has no linked spans, attempts + to infer and attach link targets using a configurable search strategy. + + Typical use case: upstream extraction produced annotations (e.g., entities, citations) + but failed to attach their target spans; this adapter backfills those links based on + proximity and eligibility rules. + + Args: + candidate_predicate: + A predicate to filter the candidate spans. If None, all spans within the candidate scope are considered. + + candidate_scope: + Controls which spans are eligible as link targets: + + - "siblings": search only among sibling spans of the annotation span. Only applicable when input span sequence is a tree. + - "all": search among all spans provided to the adapter. + + The intersection of the candidate scope and predicate forms the candidate span set. + + scan_direction: + Determines both (a) which direction the adapter searches for candidate targets + relative to an annotation and (b) the order in which annotations are processed: + + - "backward": search earlier spans; process annotations from latest to earliest. + - "forward": search later spans; process annotations from earliest to latest. + + allow_reuse_linked_spans: + If False, spans already linked by *any* annotation are not eligible targets for + additional links (i.e., enforce a one-to-one-ish linking constraint). + If True, a span may be linked multiple times by different annotations. 
+ """ + + def __init__( + self, + candidate_predicate: Optional[Callable[[AdaptingSpan], bool]] = None, + candidate_scope: Literal["siblings", "all"] = "all", + scan_direction: Literal["backward", "forward"] = "backward", + allow_reuse_linked_spans: bool = False, + ) -> None: + if candidate_predicate is not None: + self.candidate_predicate = candidate_predicate + else: + self.candidate_predicate: Callable[[AdaptingSpan], bool] = lambda _: True + self.candidate_scope = candidate_scope + self.scan_direction = scan_direction + self.allow_reuse_linked_spans = allow_reuse_linked_spans + + def _search_groups(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Iterable[Sequence[AdaptingSpan]]: + if self.candidate_scope == "siblings": + if not isinstance(source, Tree): + raise ValueError("Candidate scope 'siblings' is only applicable to tree sequences") + + def visit(node: Tree[AdaptingSpan]) -> Iterable[Sequence[AdaptingSpan]]: + # Each group must be siblings + yield [child.item for child in node.children] # yield siblings first + for child in node.children: # then yield children recursively + yield from visit(child) + + yield [source.item] # yield root first + yield from visit(source) + + elif self.candidate_scope == "all": + # Return as a single group containing all spans sorted by default order + yield sorted(list(source), key=lambda span: default_span_order(span)) + + else: + raise ValueError(f"Invalid candidate scope: {self.candidate_scope}") + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: + """Repair annotations that have no links by inferring targets from nearby spans. + + Scans the span sequence according to the configured direction and scope, + linking annotations without targets to the nearest eligible candidate spans. + + Args: + source: The span sequence containing annotations to repair. + + Returns: + A new sequence with repaired annotations containing inferred links. 
+ + Raises: + ValueError: If `candidate_scope` is "siblings" but source is not a tree. + """ + groups = list(self._search_groups(source)) + span_id_to_link: Dict[str, LinkPydanticModel] = {} + for group in groups: + if self.scan_direction == "backward": + group_to_scan = reversed(group) + else: + group_to_scan = group + + annotations_to_fill: List[AdaptingSpan] = [] + for span in group_to_scan: + if isinstance(span.data, Annotation): + if not span.data.links: + annotations_to_fill.append(span) + # The span is an annotation, skip it from being a candidate + else: + # The span is a candidate + if self.candidate_predicate(span): + while len(annotations_to_fill) > 0: + # Fill the link with the earliest-encountered annotation first (FIFO order) + # This ensures each annotation links to its nearest candidate in the scan direction + annotation_span = annotations_to_fill.pop(0) + span_id_to_link[annotation_span.span_id] = LinkPydanticModel( + key_match="span_id", value_match=span.span_id + ) + + if not self.allow_reuse_linked_spans: + # Once used, the candidate span cannot be reused + break + # If no annotations to fill, the candidate is wasted + # Otherwise, the span is not a candidate, skip it + + def _update_links(span: AdaptingSpan) -> AdaptingSpan: + if span.span_id in span_id_to_link and isinstance(span.data, Annotation): + new_annotation = span.data.model_copy(update={"links": [span_id_to_link[span.span_id]]}) + return span.model_copy(update={"data": new_annotation}) + else: + return span + + return source.map(_update_links) diff --git a/agentlightning/adapter/base.py b/agentlightning/adapter/base.py index 2a148f78d..0f2e78bfe 100644 --- a/agentlightning/adapter/base.py +++ b/agentlightning/adapter/base.py @@ -1,13 +1,17 @@ # Copyright (c) Microsoft. All rights reserved. 
# Copyright (c) Microsoft. All rights reserved.

from typing import Any, Callable, Generic, Sequence, TypeVar, overload

from opentelemetry.sdk.trace import ReadableSpan

from agentlightning.types import Span
from agentlightning.types.adapter import BaseAdaptingSequence, BaseAdaptingSequenceItem

T_inv = TypeVar("T_inv")
T_from = TypeVar("T_from", contravariant=True)
T_to = TypeVar("T_to", covariant=True)
T_seq_from = TypeVar("T_seq_from", contravariant=True, bound=BaseAdaptingSequenceItem)
T_seq_to = TypeVar("T_seq_to", covariant=True, bound=BaseAdaptingSequenceItem)

# NOTE(review): the pre-existing `Adapter` base class sits between these TypeVar
# declarations and SequenceAdapter in the real file; its body is outside this view
# and is intentionally not reproduced here.


class SequenceAdapter(
    Adapter[BaseAdaptingSequence[T_seq_from], BaseAdaptingSequence[T_seq_to]],
    Generic[T_seq_from, T_seq_to],
):
    """Base class for adapters that convert adapting sequences item by item.

    Specializes [`Adapter`][agentlightning.Adapter] for
    [`AdaptingSequence`][agentlightning.AdaptingSequence] inputs: `adapt` maps
    `adapt_one` over every element while preserving the sequence structure.
    """

    def adapt(self, source: BaseAdaptingSequence[T_seq_from]) -> BaseAdaptingSequence[T_seq_to]:
        """Apply `adapt_one` to each item of the sequence."""
        return source.map(self.adapt_one)

    def adapt_one(self, source: T_seq_from) -> T_seq_to:
        """Convert a single sequence item; subclasses must override this."""
        raise NotImplementedError(f"{self.__class__.__name__}.adapt_one() is not implemented")


class Filter(Adapter[Sequence[T_inv], Sequence[T_inv]], Generic[T_inv]):
    """Filter items of type T to items of type T based on a predicate."""

    def __init__(self, predicate: Callable[[T_inv], bool]) -> None:
        self.predicate = predicate

    def adapt(self, source: Sequence[T_inv]) -> Sequence[T_inv]:
        """Return the items of `source` for which the predicate holds, as a list."""
        return list(filter(self.predicate, source))


class Sort(Adapter[Sequence[T_inv], Sequence[T_inv]], Generic[T_inv]):
    """Sort items of type T based on a key function."""

    def __init__(self, key: Callable[[T_inv], Any]) -> None:
        self.key = key

    def adapt(self, source: Sequence[T_inv]) -> Sequence[T_inv]:
        """Return a new list of `source` items ordered by the key function (stable sort)."""
        ordered = list(source)
        ordered.sort(key=self.key)
        return ordered


T_chain1 = TypeVar("T_chain1")
T_chain2 = TypeVar("T_chain2")
T_chain3 = TypeVar("T_chain3")
T_chain4 = TypeVar("T_chain4")
T_chain5 = TypeVar("T_chain5")
T_chain6 = TypeVar("T_chain6")
T_chain7 = TypeVar("T_chain7")
T_chain8 = TypeVar("T_chain8")
T_chain9 = TypeVar("T_chain9")


class Chain(Adapter[T_from, T_to]):
    """Chain multiple adapters together to form a single adapter.

    The output of each adapter is passed as input to the next adapter in the chain.
    """

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_to],
        /,
    ) -> None: ...

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_chain2],
        adapter3: Adapter[T_chain2, T_to],
        /,
    ) -> None: ...

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_chain2],
        adapter3: Adapter[T_chain2, T_chain3],
        adapter4: Adapter[T_chain3, T_to],
        /,
    ) -> None: ...

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_chain2],
        adapter3: Adapter[T_chain2, T_chain3],
        adapter4: Adapter[T_chain3, T_chain4],
        adapter5: Adapter[T_chain4, T_to],
        /,
    ) -> None: ...

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_chain2],
        adapter3: Adapter[T_chain2, T_chain3],
        adapter4: Adapter[T_chain3, T_chain4],
        adapter5: Adapter[T_chain4, T_chain5],
        adapter6: Adapter[T_chain5, T_to],
        /,
    ) -> None: ...

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_chain2],
        adapter3: Adapter[T_chain2, T_chain3],
        adapter4: Adapter[T_chain3, T_chain4],
        adapter5: Adapter[T_chain4, T_chain5],
        adapter6: Adapter[T_chain5, T_chain6],
        adapter7: Adapter[T_chain6, T_to],
        /,
    ) -> None: ...

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_chain2],
        adapter3: Adapter[T_chain2, T_chain3],
        adapter4: Adapter[T_chain3, T_chain4],
        adapter5: Adapter[T_chain4, T_chain5],
        adapter6: Adapter[T_chain5, T_chain6],
        adapter7: Adapter[T_chain6, T_chain7],
        adapter8: Adapter[T_chain7, T_to],
        /,
    ) -> None: ...

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_chain2],
        adapter3: Adapter[T_chain2, T_chain3],
        adapter4: Adapter[T_chain3, T_chain4],
        adapter5: Adapter[T_chain4, T_chain5],
        adapter6: Adapter[T_chain5, T_chain6],
        adapter7: Adapter[T_chain6, T_chain7],
        adapter8: Adapter[T_chain7, T_chain8],
        adapter9: Adapter[T_chain8, T_to],
        /,
    ) -> None: ...

    @overload
    def __init__(
        self,
        adapter1: Adapter[T_from, T_chain1],
        adapter2: Adapter[T_chain1, T_chain2],
        adapter3: Adapter[T_chain2, T_chain3],
        adapter4: Adapter[T_chain3, T_chain4],
        adapter5: Adapter[T_chain4, T_chain5],
        adapter6: Adapter[T_chain5, T_chain6],
        adapter7: Adapter[T_chain6, T_chain7],
        adapter8: Adapter[T_chain7, T_chain8],
        adapter9: Adapter[T_chain8, T_chain9],
        adapter10: Adapter[T_chain9, T_to],
        /,
    ) -> None: ...

    def __init__(self, adapter1: Adapter[Any, Any], *adapters: Adapter[Any, Any]) -> None:
        # The mandatory first positional argument guarantees the chain is never empty.
        self.adapters: tuple[Adapter[Any, Any], ...] = (adapter1, *adapters)

    def adapt(self, source: T_from) -> T_to:
        """Run the input through every adapter in order and return the final result.

        Raises:
            RuntimeError: Wrapping the failing adapter's exception (chained as cause)
                with the index and class name of the stage that failed.
        """
        current: Any = source
        for position, stage in enumerate(self.adapters):
            try:
                current = stage.adapt(current)
            except Exception as exc:
                raise RuntimeError(
                    f"Adapter chain failed at adapter index {position} ({stage.__class__.__name__}). See inner exception for details."
                ) from exc
        return current
Extracts them from spans and annotates them with annotations.""" + +from __future__ import annotations + +import ast +import json +from typing import Any, Dict, List, Optional, cast + +from openai.types.chat import ( + ChatCompletion, + CompletionCreateParams, +) +from pydantic import TypeAdapter + +from agentlightning.types.adapter import ( + AdaptingSpan, + AnnotatedChatCompletionCall, + Annotation, + BaseAdaptingSequence, + ChatCompletionCall, + Tree, +) +from agentlightning.types.tracer import Span +from agentlightning.utils.otel import filter_and_unflatten_attributes, query_linked_spans +from agentlightning.utils.pydantic import to_plain_object + +from .base import SequenceAdapter + +CompletionCreateParamsType: TypeAdapter[CompletionCreateParams] = TypeAdapter(CompletionCreateParams) + + +class IdentifyChatCompletionCalls(SequenceAdapter[AdaptingSpan, AdaptingSpan]): + """Curate the chat completion calls from the spans.""" + + def _parse_request(self, span: AdaptingSpan) -> Dict[str, Any]: + request = filter_and_unflatten_attributes(span.attributes, "gen_ai.request") + if not isinstance(request, dict): + raise ValueError(f"Invalid request format in span attributes: {request}") + if "functions" in request: + if not isinstance(request["functions"], list): + raise ValueError(f"Invalid functions format in request. Must be a list: {request['functions']}") + for function in cast(List[Any], request["functions"]): + if not isinstance(function, dict): + raise ValueError(f"Invalid function format in request. 
Must be a dict: {function}") + if "parameters" in function and isinstance(function["parameters"], str): + function["parameters"] = json.loads(function["parameters"]) + return request + + def _parse_response(self, span: AdaptingSpan) -> Dict[str, Any]: + response = filter_and_unflatten_attributes(span.attributes, "gen_ai.response") + if not isinstance(response, dict): + raise ValueError(f"Invalid response format in span attributes: {response}") + if "created" not in response: + response["created"] = int(span.ensure_end_time()) + if "object" not in response: + response["object"] = "chat.completion" + return response + + def _parse_completion_choices(self, span: AdaptingSpan) -> List[Dict[str, Any]]: + completion_choices = filter_and_unflatten_attributes(span.attributes, "gen_ai.completion") + if not isinstance(completion_choices, list): + raise ValueError(f"Invalid completion choices format in span attributes: {completion_choices}") + for index, choice in enumerate(completion_choices): + if not isinstance(choice, dict): + raise ValueError( + f"Invalid completion choice format in span attributes. 
Choice must be a dict: {choice}" + ) + + choice["index"] = index + + # Uncover the message from the choice + message: Dict[str, Any] = { + "role": "assistant", + } + if "content" in choice: + message["content"] = cast(Dict[str, Any], choice).pop("content") + if isinstance(span.container, Tree): + # Get additional fields from child spans if any + for child in span.children(): + tool_call = self._parse_agentops_tool_calls(child) + if tool_call is not None: + message.setdefault("tool_calls", []).append(tool_call) + choice["message"] = message + + return completion_choices + + def _parse_prompt_messages(self, span: AdaptingSpan) -> List[Dict[str, Any]]: + prompt_messages = filter_and_unflatten_attributes(span.attributes, "gen_ai.prompt") + if not isinstance(prompt_messages, list) or not all(isinstance(msg, dict) for msg in prompt_messages): + raise ValueError(f"Invalid prompt messages format in span attributes: {prompt_messages}") + + # Fix discrepency between OpenAI API and AGL span attributes. + for message in prompt_messages: + if not isinstance(message, dict): + raise ValueError(f"Invalid message format in prompt messages. Must be a dict: {message}") + if "tool_calls" in message: + if not isinstance(message["tool_calls"], list): + raise ValueError(f"Invalid tool calls format in message. Must be a list: {message['tool_calls']}") + for tool_call in cast(List[Any], message["tool_calls"]): + if not isinstance(tool_call, dict): + raise ValueError(f"Invalid tool call format in message. 
Must be a dict: {tool_call}") + if "type" not in tool_call: + tool_call["type"] = "function" + if "function" not in tool_call and "name" in tool_call and "arguments" in tool_call: + tool_call["function"] = { + "name": cast(Dict[str, Any], tool_call).pop("name"), + "arguments": cast(Dict[str, Any], tool_call).pop("arguments"), + } + return prompt_messages + + def _parse_usages(self, span: AdaptingSpan) -> Dict[str, Any]: + usages = filter_and_unflatten_attributes(span.attributes, "gen_ai.usage") + if not isinstance(usages, dict): + raise ValueError(f"Invalid usages format in span attributes: {usages}") + if "prompt_tokens" not in usages: + usages["prompt_tokens"] = 0 + if "completion_tokens" not in usages: + usages["completion_tokens"] = 0 + if "total_tokens" not in usages: + usages["total_tokens"] = usages["prompt_tokens"] + usages["completion_tokens"] + return usages + + def _parse_agentops_tool_calls(self, span: Span) -> Optional[Dict[str, Any]]: + if span.name.startswith("tool_call."): + tool_call_data = filter_and_unflatten_attributes(span.attributes, "tool") + if isinstance(tool_call_data, dict): + if "call" in tool_call_data: + # Example tool_call_data: + # {'tool.name': 'get_rooms', 'tool.parameters': '{"date": ...}', + # 'tool.call.id': 'call_owd6', 'tool.call.type': 'function'} + tool_call_data = { + **tool_call_data["call"], + "function": {k: v for k, v in tool_call_data.items() if k != "call"}, + } + if ( + "function" in tool_call_data + and isinstance(tool_call_data["function"], dict) + and "parameters" in tool_call_data["function"] + ): + tool_call_data["function"]["arguments"] = tool_call_data["function"].pop("parameters") # type: ignore + + return cast(Dict[str, Any], tool_call_data) + return None + + def _normalize_request(self, request_body: Dict[str, Any]) -> CompletionCreateParams: + validated_request = CompletionCreateParamsType.validate_python(request_body) + return to_plain_object(validated_request, []) + + def 
_parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatCompletionCall: + prompt_messages = self._parse_prompt_messages(span) + request_metadata = self._parse_request(span) + completion_choices = self._parse_completion_choices(span) + usages = self._parse_usages(span) + response_metadata = self._parse_response(span) + + return self._construct_chat_completion_call( + {"messages": prompt_messages, **request_metadata}, + { + **response_metadata, + "choices": completion_choices, + "usage": usages, + }, + ) + + def _augment_litellm_raw_gen_ai_request( + self, span: AdaptingSpan, request: Dict[str, Any], response: Dict[str, Any] + ) -> None: + """Augment the request/response with more rich info from the sibling raw_gen_ai_request span. + + The request and response are modified in place. + """ + hosted_vllm = filter_and_unflatten_attributes(span.attributes, "llm.hosted_vllm") + if not hosted_vllm: + return + + if not isinstance(hosted_vllm, dict): + raise ValueError(f"Invalid hosted_vllm format in span attributes: {hosted_vllm}") + + if "choices" in hosted_vllm: + choices = ast.literal_eval(hosted_vllm["choices"]) + if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict): + raise ValueError(f"Invalid choices format in hosted_vllm: {choices}") + if "token_ids" in choices[0]: + response["choices"][0]["token_ids"] = choices[0]["token_ids"] + + if "prompt_token_ids" in hosted_vllm: + request["prompt_token_ids"] = ast.literal_eval(hosted_vllm["prompt_token_ids"]) + + def _parse_litellm_request(self, span: AdaptingSpan) -> ChatCompletionCall: + prompt_messages = self._parse_prompt_messages(span) + completion_choices = self._parse_completion_choices(span) + usages = self._parse_usages(span) + request_metadata = self._parse_request(span) + response_metadata = self._parse_response(span) + + request_body = {"messages": prompt_messages, **request_metadata} + response_body = { + **response_metadata, + "choices": completion_choices, + "usage": 
usages, + } + + # If the underlying backend is vllm, we have more rich info in sibling span. + for sibling in span.siblings(): + if sibling.name == "raw_gen_ai_request": + self._augment_litellm_raw_gen_ai_request(sibling, request_body, response_body) + + return self._construct_chat_completion_call(request_body, response_body) + + def _construct_chat_completion_call( + self, request_body: Dict[str, Any], response_body: Dict[str, Any] + ) -> ChatCompletionCall: + request = self._normalize_request(request_body) + response = ChatCompletion.model_validate(response_body, extra="allow") + return ChatCompletionCall.model_construct( + request=request, + response=response, + malformed_fields={}, # TODO: malformed fields + ) + + def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: + if source.name == "openai.chat.completion.create" or source.name == "openai.chat.completion": + chat_completion_call = self._parse_openai_chat_completion_create(source) + return source.with_data(chat_completion_call) + elif source.name == "litellm_request": + # Litellm request span + chat_completion_call = self._parse_litellm_request(source) + return source.with_data(chat_completion_call) + else: + # Not a chat completion call span. Do nothing + return source + + +class AnnotateChatCompletionCalls(SequenceAdapter[AdaptingSpan, AdaptingSpan]): + """Annotate chat completion calls with the given annotations. + + The intersection of "effective radius" of annotations and chat completion calls is used to determine + which annotations apply to which chat completion calls. + + If an annotation is not linked to any span, try to use `RepairMissingLinks` first to link it to spans. 
+ """ + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: + annotation_spans = [span for span in source if isinstance(span.data, Annotation)] + span_id_to_updated_annotation: Dict[str, AnnotatedChatCompletionCall] = {} + for annotation_span in annotation_spans: + annotation = cast(Annotation, annotation_span.data) + for linked_span in query_linked_spans(source, annotation.links): + if isinstance(linked_span.container, Tree): + linked_spans = list(linked_span.container.traverse()) + else: + linked_spans = [linked_span] + + for linked_span in linked_spans: + if isinstance(linked_span.data, ChatCompletionCall): + existing_annotations: List[Annotation] = ( + list(linked_span.data.annotations) + if isinstance(linked_span.data, AnnotatedChatCompletionCall) + else [] + ) + # Annotate the chat completion call + annotated_call = AnnotatedChatCompletionCall( + request=linked_span.data.request, + response=linked_span.data.response, + malformed_fields=linked_span.data.malformed_fields, + annotations=existing_annotations + [annotation], + ) + # Update the linked span with the annotated call + span_id_to_updated_annotation[linked_span.span_id] = annotated_call + + def _update_span(span: AdaptingSpan) -> AdaptingSpan: + if span.span_id in span_id_to_updated_annotation: + annotated_call = span_id_to_updated_annotation[span.span_id] + return span.with_data(annotated_call, override="silent") # override is expected here + else: + return span + + return source.map(_update_span) diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py new file mode 100644 index 000000000..c9b416c96 --- /dev/null +++ b/agentlightning/adapter/postprocess.py @@ -0,0 +1,422 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Post-process the data to make it more suitable for training.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, Iterable, List, Literal, Optional, Sequence, TypeVar, Union, cast + +from openai.types.chat import ChatCompletion, ChatCompletionAssistantMessageParam +from pydantic import TypeAdapter + +from agentlightning.types.adapter import ( + AdaptingSpan, + AnnotatedChatCompletionCall, + BaseAdaptingSequence, + ChatCompletionCall, + GeneralAnnotation, + PromptCompletionAccumulation, + PromptCompletionTriplet, + TokenInput, + TokenOutput, + TokensAccumulation, + TokensAccumulationDiagnosis, + TokensTriplet, +) +from agentlightning.utils.pydantic import to_plain_object + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + +from .base import Adapter + +TripletOrAccumulation = Union[TokensTriplet, TokensAccumulation, PromptCompletionTriplet, PromptCompletionAccumulation] +T_triplet_or_accumulation = TypeVar( + "T_triplet_or_accumulation", + bound=TripletOrAccumulation, +) +T_triplet = TypeVar("T_triplet", bound=Union[TokensTriplet, PromptCompletionTriplet]) + + +def is_prefix(shorter: Iterable[Any], longer: Iterable[Any]) -> bool: + """Check if the shorter sequence is a prefix of the longer sequence.""" + longer_iter = iter(longer) + for item in shorter: + try: + expected_item = next(longer_iter) + if item != expected_item: + return False + except StopIteration: + return False + return True + + +class ToTripletMixin(Generic[T_triplet]): + """Mixin for adapters that convert chat completion calls to triplets.""" + + def __init__(self, strict: bool = False): + self.strict = strict + + def get_reward(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Optional[float]: + if isinstance(call, AnnotatedChatCompletionCall): + for annotation in call.annotations: + # The first general annotation + if isinstance(annotation, GeneralAnnotation): + # The primary reward + return 
annotation.primary_reward + return None + + def to_triplet( + self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + ) -> Union[T_triplet, BaseException]: + raise NotImplementedError() + + def to_triplets(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[T_triplet]: + exceptions: List[BaseException] = [] + triplets: List[T_triplet] = [] + for span in source: + if isinstance(span.data, (AnnotatedChatCompletionCall, ChatCompletionCall)): + triplet = self.to_triplet(span.data) + if isinstance(triplet, BaseException): + exceptions.append(triplet) + else: + triplets.append(triplet) + if len(triplets) == 0: + error_msg = ( + f"{self.__class__.__name__} failed to create any triplets. " + f"The adapter has raised {len(exceptions)} exceptions when processing the spans:\n" + + "\n".join([f" - {exc}" for exc in exceptions]) + ) + raise RuntimeError(error_msg) + triplets[-1] = triplets[-1].model_copy(update={"done": True}) + return triplets + + +class ToTokensTriplets( + ToTripletMixin[TokensTriplet], Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensTriplet]] +): + """Convert adapting spans to token input-output triplets. + + Args: + strict: Whether to raise an exception if the triplet cannot be created. + If False, the exception will be added to the list of exceptions and the triplet will be skipped. + The exceptions will also be raised when the resulting sequence is empty. + If True, the exception will be raised. + Default is False. 
+ """ + + def to_triplet( + self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + ) -> Union[TokensTriplet, BaseException]: + try: + return TokensTriplet.model_construct( + observation=TokenInput( + token_ids=self._get_prompt_token_ids(call), image_urls=self._get_image_urls(call) + ), + action=TokenOutput(token_ids=self._get_completion_token_ids(call), logprobs=self._get_logprobs(call)), + reward=self.get_reward(call), + done=False, # False by now + raw_call=call, + ) + except Exception as exc: + if self.strict: + raise exc + return exc + + def _get_prompt_token_ids(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Sequence[int]: + prompt = call.request + if "prompt_token_ids" in prompt: + prompt_token_ids = prompt["prompt_token_ids"] + else: + raise ValueError(f"Prompt token ids not found in call: {call}") + # Validate the prompt token ids + if not isinstance(prompt_token_ids, list) or not all(isinstance(x, int) for x in prompt_token_ids): # type: ignore + raise ValueError(f"Invalid prompt token ids. Must be a list of ints. 
Got: {prompt_token_ids}") + if len(prompt_token_ids) == 0: + raise ValueError("Prompt token ids is empty.") + return prompt_token_ids + + def _get_image_urls(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Sequence[str]: + image_urls: List[str] = [] + for message in call.request["messages"]: + if "content" not in message: + continue + content = message["content"] + if content is None: + continue + if isinstance(content, list): + for part in content: + if part["type"] == "image_url": + image_urls.append(part["image_url"]["url"]) + return image_urls + + def _get_completion_token_ids(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Sequence[int]: + completion_choice = call.response.choices[0] + if hasattr(completion_choice, "token_ids") and completion_choice.token_ids is not None: # type: ignore + response_token_ids = cast(Any, completion_choice.token_ids) # type: ignore + elif hasattr(completion_choice, "provider_specific_fields") and "token_ids" in completion_choice.provider_specific_fields: # type: ignore + response_token_ids = cast(Any, completion_choice.provider_specific_fields["token_ids"]) # type: ignore + else: + raise ValueError(f"Completion token ids not found in call: {call}") + if not isinstance(response_token_ids, list) or not all(isinstance(x, int) for x in response_token_ids): # type: ignore + raise ValueError(f"Invalid completion token ids. Must be a list of ints. 
Got: {response_token_ids}") + response_token_ids = cast(Sequence[int], response_token_ids) + if len(response_token_ids) == 0: + raise ValueError("Completion token ids is empty.") + return response_token_ids + + def _get_logprobs(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Optional[Sequence[float]]: + logprobs = call.response.choices[0].logprobs + if logprobs is not None: + content_logprobs = logprobs.content + if content_logprobs is not None: + return [logprob.logprob for logprob in content_logprobs] + return None + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: + return self.to_triplets(source) + + +class ToTokensAccumulations(Adapter[Sequence[TokensTriplet], Sequence[TokensAccumulation]]): + """Assemble multiple token input-output triplets into accumulated token sequences. + + Args: + diagnosis: Whether to include diagnosis information in the resulting TokensAccumulation. + tokenizer: An optional tokenizer to decode token IDs to text for diagnosis. 
+ """ + + def __init__(self, diagnosis: bool = False, tokenizer: Optional[PreTrainedTokenizer] = None): + self.diagnosis = diagnosis + self.tokenizer = tokenizer + + def _triplet_to_accumulation( + self, triplet: TokensTriplet, diagnosis_info: Optional[TokensAccumulationDiagnosis] + ) -> TokensAccumulation: + if triplet.action.logprobs is not None: + logprobs = [0.0] * len(triplet.observation.token_ids) + list(triplet.action.logprobs) + else: + logprobs = None + + return TokensAccumulation( + token_ids=[*triplet.observation.token_ids, *triplet.action.token_ids], + image_urls=triplet.observation.image_urls, + logprobs=logprobs, + response_mask=[0] * len(triplet.observation.token_ids) + [1] * len(triplet.action.token_ids), + final_reward=triplet.reward, + raw_calls=[triplet.raw_call], + diagnosis_info=diagnosis_info, + ) + + def _special_token_sequence(self, ids: Sequence[int]) -> List[int]: + assert self.tokenizer is not None, "Tokenizer must be provided for special token sequence extraction." + return [id for id in ids if id in self.tokenizer.all_special_ids] + + def _non_special_token_sequence(self, ids: Sequence[int]) -> List[int]: + assert self.tokenizer is not None, "Tokenizer must be provided for non-special token sequence extraction." 
+ return [id for id in ids if id not in self.tokenizer.all_special_ids] + + def _diagnose_mismatch( + self, prev: TokensAccumulation, next: TokensTriplet + ) -> Optional[TokensAccumulationDiagnosis]: + if not self.diagnosis: + return None + + if self.tokenizer is None: + raise ValueError("Tokenizer must be provided for diagnosis.") + + image_urls_match = is_prefix(prev.image_urls, next.observation.image_urls) + + # Check whether the special tokens match + next_special_ids = self._special_token_sequence(next.observation.token_ids) + prev_special_ids = self._special_token_sequence(prev.token_ids) + special_tokens_match = is_prefix(prev_special_ids, next_special_ids) + + # Check whether the non-special tokens match + next_non_special_ids = self._non_special_token_sequence(next.observation.token_ids) + prev_non_special_ids = self._non_special_token_sequence(prev.token_ids) + non_special_tokens_match = is_prefix(prev_non_special_ids, next_non_special_ids) + + # Check whether the detokenized text matches + next_string = self.tokenizer.decode(next.observation.token_ids, skip_special_tokens=True) # type: ignore + prev_string = self.tokenizer.decode(prev.token_ids, skip_special_tokens=True) # type: ignore + detokenized_text_match = next_string.startswith(prev_string) + + return TokensAccumulationDiagnosis( + special_tokens_mismatch=not special_tokens_match, + non_special_tokens_mismatch=not non_special_tokens_match, + detokenized_text_mismatch=not detokenized_text_match, + image_urls_mismatch=not image_urls_match, + accumulation_prev=prev, + special_tokens_prev=prev_special_ids, + special_tokens_next=next_special_ids, + detokenized_text_prev=prev_string, + detokenized_text_next=next_string, + ) + + def _attempt_to_merge(self, prev: TokensAccumulation, next: TokensTriplet) -> List[TokensAccumulation]: + # Check if we can merge the next triplet into the previous accumulation + if not is_prefix(prev.image_urls, next.observation.image_urls): + return [prev, 
self._triplet_to_accumulation(next, self._diagnose_mismatch(prev, next))] + # Merge token ids + if not is_prefix(prev.token_ids, next.observation.token_ids): + return [prev, self._triplet_to_accumulation(next, self._diagnose_mismatch(prev, next))] + tokens_to_add = [*next.observation.token_ids[len(prev.token_ids) :], *next.action.token_ids] + if prev.logprobs is not None and next.action.logprobs is not None: + # Add zeros only for observation extension tokens, not for action tokens + observation_extension_len = len(next.observation.token_ids) - len(prev.token_ids) + new_logprobs = list(prev.logprobs) + [0.0] * observation_extension_len + list(next.action.logprobs) + else: + new_logprobs = None + response_mask_to_add = [0] * (len(next.observation.token_ids) - len(prev.token_ids)) + [1] * len( + next.action.token_ids + ) + new_reward = ( + (prev.final_reward or 0.0) + (next.reward or 0.0) + if next.reward is not None or prev.final_reward is not None + else None + ) + + return [ + TokensAccumulation( + token_ids=[*prev.token_ids, *tokens_to_add], + image_urls=next.observation.image_urls, + logprobs=new_logprobs, + response_mask=[*prev.response_mask, *response_mask_to_add], + final_reward=new_reward, + raw_calls=[*prev.raw_calls, next.raw_call], + diagnosis_info=None, + ) + ] + + def adapt(self, source: Sequence[TokensTriplet]) -> Sequence[TokensAccumulation]: + accumulations: List[TokensAccumulation] = [] + for triplet in source: + if not accumulations: + accumulations.append(self._triplet_to_accumulation(triplet, None)) + else: + last_accumulation = accumulations[-1] + accumulations = accumulations[:-1] + self._attempt_to_merge(last_accumulation, triplet) + return accumulations + + +class ToPromptCompletionTriplets( + ToTripletMixin[PromptCompletionTriplet], + Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[PromptCompletionTriplet]], +): + """Convert annotated chat completion calls to prompt-completion triplets. 
+ + Args: + strict: Whether to raise an exception if the triplet cannot be created. + If False, the exception will be added to the list of exceptions and the triplet will be skipped. + The exceptions will also be raised when the resulting sequence is empty. + If True, the exception will be raised. + Default is False. + """ + + def to_triplet( + self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + ) -> Union[PromptCompletionTriplet, BaseException]: + try: + return PromptCompletionTriplet.model_construct( + observation=call.request, + action=call.response, + reward=self.get_reward(call), + done=False, # False by now + raw_call=call, + ) + except Exception as exc: + if self.strict: + raise exc + return exc + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[PromptCompletionTriplet]: + return self.to_triplets(source) + + +class ToPromptCompletionAccumulations( + Adapter[Sequence[PromptCompletionTriplet], Sequence[PromptCompletionAccumulation]] +): + """Assemble multiple prompt-completion triplets into accumulated prompt-completion pairs.""" + + def _to_messages(self, completion: ChatCompletion) -> List[ChatCompletionAssistantMessageParam]: + ChatCompletionAssistantMessageParamType = TypeAdapter(ChatCompletionAssistantMessageParam) + # Convert message to dict first since TypeAdapter expects dict for TypedDict validation + # Exclude none to avoid issues like tools=None + message_dict = completion.choices[0].message.model_dump(exclude_none=True) + validated_message = ChatCompletionAssistantMessageParamType.validate_python(message_dict) + return [to_plain_object(validated_message, [])] + + def _to_accumulation(self, triplet: PromptCompletionTriplet) -> PromptCompletionAccumulation: + tools = [*triplet.observation["tools"]] if "tools" in triplet.observation else None + return PromptCompletionAccumulation.model_construct( + messages=[*triplet.observation["messages"], *self._to_messages(triplet.action)], + tools=tools, + 
final_reward=triplet.reward, + raw_calls=[triplet.raw_call], + ) + + def _attempt_to_merge( + self, prev: PromptCompletionAccumulation, next: PromptCompletionTriplet + ) -> List[PromptCompletionAccumulation]: + next_acc = self._to_accumulation(next) + # Check if we can merge the next triplet into the previous accumulation + if prev.tools != next_acc.tools: + # Cannot merge because tools are different + return [prev, next_acc] + if not is_prefix(prev.messages, next_acc.messages): + # Cannot merge because messages do not match + return [prev, next_acc] + # Merge messages + merged_messages = [*prev.messages, *next_acc.messages[len(prev.messages) :]] + new_reward = ( + (prev.final_reward or 0.0) + (next.reward or 0.0) + if next.reward is not None or prev.final_reward is not None + else None + ) + return [ + PromptCompletionAccumulation( + messages=merged_messages, + tools=prev.tools, + final_reward=new_reward, + raw_calls=[*prev.raw_calls, next.raw_call], + ) + ] + + def adapt(self, source: Sequence[PromptCompletionTriplet]) -> Sequence[PromptCompletionAccumulation]: + accumulations: List[PromptCompletionAccumulation] = [] + for triplet in source: + if not accumulations: + accumulations.append(self._to_accumulation(triplet)) + else: + last_accumulation = accumulations[-1] + accumulations = accumulations[:-1] + self._attempt_to_merge(last_accumulation, triplet) + return accumulations + + +class PropagateRewards(Adapter[Sequence[TripletOrAccumulation], Sequence[TripletOrAccumulation]]): + """Propagate rewards forward or backward from one triplet or accumulation to the next.""" + + def __init__(self, direction: Literal["forward", "backward"]) -> None: + self.direction = direction + + def adapt(self, source: Sequence[T_triplet_or_accumulation]) -> Sequence[T_triplet_or_accumulation]: + prev_reward: Optional[float] = None + if self.direction == "forward": + scan_order = list(range(len(source))) + else: + scan_order = list(reversed(range(len(source)))) + transformed_source = 
[*source] + for idx in scan_order: + item = source[idx] + if isinstance(item, PromptCompletionTriplet) or isinstance(item, TokensTriplet): + if item.reward is not None: + prev_reward = item.reward + else: + transformed_source[idx] = item.model_copy(update={"reward": prev_reward}) + else: + if item.final_reward is not None: + prev_reward = item.final_reward + else: + transformed_source[idx] = item.model_copy(update={"final_reward": prev_reward}) + + return transformed_source diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py new file mode 100644 index 000000000..01da8c796 --- /dev/null +++ b/agentlightning/adapter/preprocess.py @@ -0,0 +1,529 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Span re-organization adapters.""" + +from __future__ import annotations + +import logging +import time +from collections import defaultdict +from typing import Dict, List, Literal, Sequence, Set, Tuple, TypeVar + +from agentlightning.semconv import AGL_VIRTUAL +from agentlightning.types.adapter import AdaptingSequence, AdaptingSpan, Tree +from agentlightning.types.tracer import Span, SpanLike +from agentlightning.utils.id import generate_id + +from .base import Adapter + +T_from = TypeVar("T_from") +T_to = TypeVar("T_to") +T_span = TypeVar("T_span", bound=Span, covariant=True) + +logger = logging.getLogger(__name__) + + +def default_span_order(span: Span) -> tuple[int, float, float]: + """Return a tuple for sorting spans by sequence ID, then start time, then end time. + + Args: + span: The span to extract ordering keys from. + + Returns: + A tuple of (sequence_id, start_time, end_time) for use as a sort key. + + Raises: + ValueError: If the span has no start or end time set. + """ + return (span.sequence_id, span.ensure_start_time(), span.ensure_end_time()) + + +class _TreeLikeGraph: + """A simple directed graph implementation for span hierarchy. + + In preparation before forming a tree. 
+ """ + + def __init__(self) -> None: + self.forward_graph: Dict[str, List[str]] = defaultdict(list) + self.parent_map: Dict[str, str] = {} + self.root_ids: Set[str] = set() + self.spans_dict: Dict[str, Span] = {} + + def add_edge(self, from_node: str, to_node: str) -> None: + self.forward_graph[from_node].append(to_node) + + def move_subtree(self, node_id: str, new_parent_id: str) -> None: + old_parent_id = self.parent_map.get(node_id, None) + if old_parent_id is not None: + self.forward_graph[old_parent_id].remove(node_id) + self.add_edge(new_parent_id, node_id) + self.parent_map[node_id] = new_parent_id + if node_id in self.root_ids: + self.root_ids.remove(node_id) + + def validate_no_cycles(self) -> None: + visited = set[str]() + + def visit(node_id: str) -> None: + if node_id in visited: + raise ValueError(f"Cycle detected in the graph at node {node_id}") + visited.add(node_id) + for child_id in self.forward_graph[node_id]: + visit(child_id) + + for root_id in self.root_ids: + visit(root_id) + + if len(visited) != len(self.forward_graph): + raise ValueError("Some nodes are not reachable from the roots") + + def compute_depths(self) -> Dict[str, int]: + depths = {root: 0 for root in self.root_ids} + + def visit(node_id: str) -> None: + for child_id in self.forward_graph[node_id]: + depths[child_id] = depths[node_id] + 1 + visit(child_id) + + for root_id in self.root_ids: + visit(root_id) + + return depths + + def compute_ancestors(self) -> Dict[str, Set[str]]: + ancestors = {root: set[str]() for root in self.root_ids} + + def visit(node_id: str) -> None: + for child_id in self.forward_graph[node_id]: + ancestors[child_id] = ancestors[node_id] | {node_id} + visit(child_id) + + for root_id in self.root_ids: + visit(root_id) + + return ancestors + + def to_tree(self) -> Tree[AdaptingSpan]: + def build_subtree(node_id: str) -> Tree[AdaptingSpan]: + children = [build_subtree(child_id) for child_id in self.forward_graph.get(node_id, [])] + item = 
AdaptingSpan.from_span(self.spans_dict[node_id], None) + return Tree(item, sorted(children, key=lambda child: default_span_order(child.item))) + + if len(self.root_ids) != 1: + raise ValueError( + "Cannot convert to tree: multiple or no roots found; enable repair options of ToTree to fix this." + ) + root_id = next(iter(self.root_ids)) + return build_subtree(root_id) + + @staticmethod + def from_spans(spans: Sequence[Span], logs_invalid_parent: bool = True) -> _TreeLikeGraph: + graph = _TreeLikeGraph() + + valid_span_ids = set(span.span_id for span in spans) + graph.root_ids = set(span.span_id for span in spans if span.parent_id is None) + graph.spans_dict = {span.span_id: span for span in spans} + for span in spans: + if span.parent_id is not None: + if span.parent_id in valid_span_ids: + graph.forward_graph[span.parent_id].append(span.span_id) + graph.root_ids.discard(span.span_id) + graph.parent_map[span.span_id] = span.parent_id + # We don't care about the wrong parent id in the spans dict + # This fix here is only for graph construction + else: + # Span has invalid parent, treat as root + graph.root_ids.add(span.span_id) + if logs_invalid_parent: + logger.debug( + f'Span {span.span_id} has an invalid parent ID "{span.parent_id}". ' + "The parent will be ignored and the span will be treated as a root." + ) + + graph.validate_no_cycles() + + return graph + + +class ToSpans(Adapter[Sequence[SpanLike], Sequence[Span]]): + """Normalize span-like objects (e.g., OpenTelemetry `ReadableSpan`) to [Span][agentlightning.Span]. + + This adapter handles conversion from various span formats to the internal Span type. + Native Span objects pass through unchanged, while OpenTelemetry spans are converted + using the provided default values for rollout, attempt, and sequence identifiers. + + Args: + default_rollout_id: Default rollout ID for converted OpenTelemetry spans. + default_attempt_id: Default attempt ID for converted OpenTelemetry spans. 
+ default_sequence_id: Default sequence ID for converted OpenTelemetry spans. + """ + + def __init__( + self, + default_rollout_id: str = "rollout-dummy", + default_attempt_id: str = "attempt-dummy", + default_sequence_id: int = 0, + ): + self.default_rollout_id = default_rollout_id + self.default_attempt_id = default_attempt_id + self.default_sequence_id = default_sequence_id + + def adapt_one(self, source: SpanLike) -> Span: + """Convert a single span-like object to a Span. + + Args: + source: A Span or OpenTelemetry ReadableSpan to convert. + + Returns: + The converted Span object. Native Spans pass through unchanged. + """ + if isinstance(source, Span): + return source + return Span.from_opentelemetry( + source, + rollout_id=self.default_rollout_id, + attempt_id=self.default_attempt_id, + sequence_id=self.default_sequence_id, + ) + + def adapt(self, source: Sequence[SpanLike]) -> Sequence[Span]: + """Convert a sequence of span-like objects to Spans. + + Args: + source: A sequence of Span or OpenTelemetry ReadableSpan objects. + """ + return [self.adapt_one(span) for span in source] + + +class ToTree(Adapter[Sequence[Span], Tree[AdaptingSpan]]): + """Convert a sequence of spans into a tree structure. + + This adapter organizes flat span sequences into a hierarchical tree based on parent-child + relationships. It can repair various structural issues in the span data: + + - **Bad hierarchy**: Spans that are incorrectly positioned (e.g., dangling spans without + proper parents despite being contained within other spans' time ranges). + - **Multiple roots**: Cases where more than one span has no parent. + + Note: + Spans with invalid parent IDs (referencing non-existent spans) will cause a ValueError. + Use [RepairMalformedSpans][agentlightning.adapter.preprocess.RepairMalformedSpans] with + `ensure_valid_parent_ids=True` before calling this adapter if you need to handle + invalid parent references. + + Args: + repair_bad_hierarchy: Controls hierarchy repair. 
`"dangling"` repairs only orphaned spans, + `"all"` re-evaluates all span placements, `"none"` skips hierarchy repair. + repair_multiple_roots: If True, creates a virtual root when multiple root spans exist. + + Raises: + TypeError: If the input is not a sequence. + ValueError: If no spans are provided, if any span has an invalid parent ID, + or if the tree cannot be constructed. + """ + + def __init__( + self, + repair_bad_hierarchy: Literal["dangling", "all", "none"] = "dangling", + repair_multiple_roots: bool = True, + ): + self.repair_bad_hierarchy = repair_bad_hierarchy + self.repair_multiple_roots = repair_multiple_roots + + def _find_eligible_parents( + self, + all_spans: Sequence[Span], + current: Span, + graph: _TreeLikeGraph, + cache_depths: Dict[str, int], + ) -> List[Span]: + """We wish to find a good place to insert the span, which is ideally it's sibling or sibling's child. + + Filter the candidates by: (1) must not in current's ancestors; (2) must not be in current's subtree; + (3) must have an ancestor that is the parent of current span; (4) have start and end time covering the current span. + The third condition can be optional if current span has no parent. + + Then sort the candidates by: (1) shortest to longest, (2) deep to shallow. + """ + spans_to_consider: List[Span] = [] + # This needs to be re-computed every time. + # It will be too troublesome to maintain the dynamic cache of ancestors. + ancestors = graph.compute_ancestors() + + for candidate_parent in all_spans: + if candidate_parent.span_id == current.span_id: + continue + if ( + candidate_parent.ensure_start_time() > current.ensure_start_time() + or candidate_parent.ensure_end_time() < current.ensure_end_time() + ): + # If the span is not covering the current span, it cannot be a parent. + continue + if candidate_parent.span_id in ancestors[current.span_id]: + # If the span is in the current's ancestors, it cannot be a parent. 
+ continue + if current.span_id in ancestors[candidate_parent.span_id]: + # If the span is in the current's subtree, it cannot be a parent. + continue + if current.span_id in graph.parent_map: + # If the current span has a parent, the eligible parent must live in the parent's subtree. + if graph.parent_map[current.span_id] not in ancestors[candidate_parent.span_id]: + continue + spans_to_consider.append(candidate_parent) + + # Sort the spans: (1) shortest to longest duration, (2) deeper to shallower (prefer more specific ancestors) + return sorted( + spans_to_consider, + key=lambda span: (span.ensure_end_time() - span.ensure_start_time(), -cache_depths[span.span_id]), + ) + + def _repair_bad_hierarchy(self, source: Sequence[Span]) -> Sequence[Span]: + """Repair bad hierarchy by re-attaching dangling spans or all spans. + + This is based on the chronological relationships between start time and end time of spans. + """ + if self.repair_bad_hierarchy == "none": + return source + + graph = _TreeLikeGraph.from_spans(source) + depths = graph.compute_depths() + + # Scan all the spans by: (1) shallow to deep, (2) longest to shortest, (3) earliest to latest. + scan_order = sorted( + source, + key=lambda span: ( + depths[span.span_id], + -(span.ensure_end_time() - span.ensure_start_time()), + span.ensure_start_time(), + ), + ) + for i, span in enumerate(scan_order): + # Check whether we should repair this span. + # It must be a dangling span, or the user wants to repair all the spans. + if ( + self.repair_bad_hierarchy == "dangling" and span.span_id not in graph.parent_map + ) or self.repair_bad_hierarchy == "all": + # We wish to find a good place to insert the span. 
+ eligible_parents = self._find_eligible_parents(source, span, graph, depths) + if eligible_parents: + new_parent_id = eligible_parents[0].span_id + scan_order[i] = span.model_copy(update={"parent_id": new_parent_id}) + + # Maintain/update the cache + graph.move_subtree(span.span_id, new_parent_id) + + return scan_order + + def _validate_parent_ids(self, source: Sequence[Span]) -> None: + """Validate that all parent IDs reference existing spans. + + Raises: + ValueError: If any span references a non-existent parent. + """ + valid_span_ids: Set[str] = set(span.span_id for span in source) + invalid_refs: List[str] = [] + + for span in source: + if span.parent_id is not None and span.parent_id not in valid_span_ids: + invalid_refs.append(f"{span.span_id} -> {span.parent_id}") + + if invalid_refs: + raise ValueError( + f"Spans reference non-existent parent IDs: {', '.join(invalid_refs)}. " + "Use RepairMalformedSpans with ensure_valid_parent_ids=True to fix this." + ) + + def _repair_multiple_roots(self, source: Sequence[Span]) -> Sequence[Span]: + root_spans = [span for span in source if span.parent_id is None] + + if len(root_spans) <= 1: + return source + + # Create a new root span + new_root_span = Span.from_attributes( + rollout_id=root_spans[0].rollout_id, + attempt_id=root_spans[0].attempt_id, + sequence_id=root_spans[0].sequence_id, + trace_id=root_spans[0].trace_id, + span_id="span-" + generate_id(12), + parent_id=None, + name=AGL_VIRTUAL, + attributes={}, + start_time=min(span.ensure_start_time() for span in root_spans), + end_time=max(span.ensure_end_time() for span in root_spans), + ) + + updated_spans = [ + span.model_copy(update={"parent_id": new_root_span.span_id}) if span in root_spans else span + for span in source + ] + return [new_root_span, *updated_spans] + + def adapt(self, source: Sequence[Span]) -> Tree[AdaptingSpan]: + """Convert a sequence of spans into a tree of AdaptingSpan objects. 
+ + Args: + source: A sequence of Span objects to organize into a tree. + + Returns: + A Tree with AdaptingSpan items representing the hierarchical structure. + + Raises: + TypeError: If source is not a sequence. + ValueError: If source is empty, has invalid parent IDs, or cannot form a valid tree. + """ + if not isinstance(source, Sequence): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError(f"Expected a sequence of spans, but got {type(source)}") + if not source: + raise ValueError("No spans provided to create Tree") + + # Validate parent IDs before any processing + self._validate_parent_ids(source) + + source = self._repair_bad_hierarchy(source) + + if self.repair_multiple_roots: + source = self._repair_multiple_roots(source) + + graph = _TreeLikeGraph.from_spans(source) + return graph.to_tree() + + +class ToAdaptingSpans(Adapter[Sequence[Span], AdaptingSequence[AdaptingSpan]]): + """Convert spans to a sorted AdaptingSequence. + + Sorts spans by sequence ID (primary), start time (secondary), and end time (tertiary), + then wraps each in an AdaptingSpan for use in adaptation pipelines. + """ + + def adapt(self, source: Sequence[Span]) -> AdaptingSequence[AdaptingSpan]: + """Sort spans and convert to AdaptingSequence. + + Args: + source: A sequence of Span objects to sort and convert. + + Returns: + An AdaptingSequence containing sorted AdaptingSpan objects. + """ + sorted_spans = sorted(source, key=default_span_order) + return AdaptingSequence([AdaptingSpan.from_span(span, None) for span in sorted_spans]) + + +class RepairMalformedSpans(Adapter[Sequence[Span], Sequence[Span]]): + """Repair common structural issues in span data. + + This adapter fixes several types of malformed span data: + + - **Missing times**: Fills in missing start/end times using the maximum known time. + - **Negative duration**: Adjusts end times that are earlier than start times. + - **Improper nesting**: Expands parent time ranges to contain all children. 
+ - **Invalid parent IDs**: Removes references to non-existent parent spans. + + Spans that don't require repair pass through unchanged (same object reference). + + Args: + ensure_positive_duration: If True, sets end_time = start_time when end < start. + ensure_proper_nesting: If True, expands parent spans to contain children's time ranges. + ensure_valid_parent_ids: If True, sets parent_id to None for orphaned spans. + """ + + def __init__( + self, + ensure_positive_duration: bool = True, + ensure_proper_nesting: bool = True, + ensure_valid_parent_ids: bool = True, + ) -> None: + self.ensure_positive_duration = ensure_positive_duration + self.ensure_proper_nesting = ensure_proper_nesting + self.ensure_valid_parent_ids = ensure_valid_parent_ids + + def _repair_start_end_time(self, source: Sequence[Span]) -> List[Span]: + times_set = set[float]() + for span in source: + if span.start_time is not None: + times_set.add(span.start_time) + if span.end_time is not None: + times_set.add(span.end_time) + + if not times_set: + logger.debug("No times set in the spans. 
Setting all the time to current time.") + current_time = time.time() + else: + current_time = max(times_set) + + new_spans: List[Span] = [] + + for span in source: + update_fields: Dict[str, float] = {} + if span.start_time is None: + update_fields["start_time"] = current_time + if span.end_time is None: + update_fields["end_time"] = current_time + if ( + self.ensure_positive_duration + and span.start_time is not None + and span.end_time is not None + and span.end_time < span.start_time + ): + update_fields["end_time"] = span.start_time + if update_fields: + new_spans.append(span.model_copy(update=update_fields)) + else: + new_spans.append(span) + + return new_spans + + def _repair_nesting(self, source: Sequence[Span]) -> List[Span]: + graph = _TreeLikeGraph.from_spans(source, logs_invalid_parent=False) + spans = {span.span_id: span for span in source} + + def visit(node_id: str) -> Tuple[float, float]: + child_start_end_times: List[Tuple[float, float]] = [] + cur_start_time = spans[node_id].ensure_start_time() + cur_end_time = spans[node_id].ensure_end_time() + if graph.forward_graph.get(node_id): + for child_id in graph.forward_graph[node_id]: + child_start_end_times.append(visit(child_id)) + start_times, end_times = zip(*child_start_end_times) + start_time = min(cur_start_time, *start_times) + end_time = max(cur_end_time, *end_times) + if start_time != cur_start_time or end_time != cur_end_time: + spans[node_id] = spans[node_id].model_copy(update={"start_time": start_time, "end_time": end_time}) + + return spans[node_id].ensure_start_time(), spans[node_id].ensure_end_time() + + for root_id in graph.root_ids: + visit(root_id) + + return [spans[span.span_id] for span in source] + + def _repair_invalid_parent_ids(self, source: Sequence[Span]) -> List[Span]: + valid_span_ids: Set[str] = set(span.span_id for span in source) + new_spans: List[Span] = [] + + for span in source: + if span.parent_id is not None and span.parent_id not in valid_span_ids: + 
new_spans.append(span.model_copy(update={"parent_id": None, "parent": None})) + else: + new_spans.append(span) + + return new_spans + + def adapt(self, source: Sequence[Span]) -> Sequence[Span]: + """Repair malformed spans according to the configured repair options. + + Args: + source: A sequence of Span objects to repair. + + Returns: + A sequence of repaired Span objects. Unmodified spans retain their original + object reference. + """ + # This step is always performed first no matter whether the flags are set. + new_spans = self._repair_start_end_time(source) + if self.ensure_proper_nesting: + new_spans = self._repair_nesting(new_spans) + if self.ensure_valid_parent_ids: + new_spans = self._repair_invalid_parent_ids(new_spans) + return new_spans diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py new file mode 100644 index 000000000..71f1d6664 --- /dev/null +++ b/agentlightning/types/adapter.py @@ -0,0 +1,587 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Data formats used by adapters, usually the target format converted from trace spans.""" + +from __future__ import annotations + +import logging +import weakref +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Iterator, + Literal, + Optional, + Protocol, + Sequence, + TypeVar, + Union, + overload, +) + +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessageParam, + ChatCompletionToolUnionParam, + CompletionCreateParams, +) +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Self + +from agentlightning.semconv import LinkPydanticModel, RewardPydanticModel + +from .tracer import Attributes, Span + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + + +class BaseAdaptingSequenceItem(Protocol): + + def with_container(self, container: BaseAdaptingSequence[Self]) -> Self: + """Create a new item with the same properties but different container.""" + raise NotImplementedError() + + +T_co = TypeVar("T_co", covariant=True, bound=BaseAdaptingSequenceItem) +V_co = TypeVar("V_co", covariant=True, bound=BaseAdaptingSequenceItem) + + +class BaseAdaptingSequence(Sequence[T_co], Generic[T_co]): + """Interface that makes adapter easier to work with sequences.""" + + @overload + def __getitem__(self, index: int) -> T_co: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[T_co]: ... 
+ + def __getitem__(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: + return self.get(index) + + def __iter__(self) -> Iterator[T_co]: + return iter(self.traverse()) + + def __len__(self) -> int: + return self.size() + + def get(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: + """Get the index-th item in the sequence.""" + raise NotImplementedError() + + def map(self, func: Callable[[T_co], V_co]) -> BaseAdaptingSequence[V_co]: + """Map a function over all items in the sequence.""" + raise NotImplementedError() + + def retain(self, predicate: Callable[[T_co], bool]) -> BaseAdaptingSequence[T_co]: + """Filter items in the sequence by a predicate (true for items to be kept). + + Depending on the implementation, the returned sequence may contain more or less items than a standard filter. + """ + raise NotImplementedError() + + def prune(self, predicate: Callable[[T_co], bool]) -> BaseAdaptingSequence[T_co]: + """Prune items in the sequence by a predicate (true for items to be pruned). + + Depending on the implementation, the returned sequence may contain more or less items than a standard prune. + """ + raise NotImplementedError() + + def size(self) -> int: + """Get the size of the sequence.""" + raise NotImplementedError() + + def traverse(self) -> Iterable[T_co]: + """Traverse all items in the sequence.""" + raise NotImplementedError() + + +# General containers + + +class Tree(BaseAdaptingSequence[T_co], Generic[T_co]): + """This is a generic tree data structure that can be used to represent the structure of a tree. + + This data structure is immutable. 
+ """ + + def __init__(self, item: T_co, children: Sequence[Tree[T_co]]) -> None: + self._children = children + self._parent: Optional[weakref.ReferenceType[Tree[T_co]]] = None + for child in self._children: + child._parent = weakref.ref(self) # type: ignore + self._item = item.with_container(self) + + @property + def item(self) -> T_co: + return self._item + + @property + def children(self) -> Sequence[Tree[T_co]]: + return self._children + + @property + def parent(self) -> Optional[Tree[T_co]]: + return self._parent() if self._parent is not None else None + + def traverse(self) -> Iterable[T_co]: + yield self._item + for child in self._children: + yield from child.traverse() + + def size(self) -> int: + return 1 + sum(child.size() for child in self._children) + + def get(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: + """Get the index-th item in the tree (O(n) time complexity). + + I think this is not efficient, but it's seldomly used. + """ + return list(self.traverse())[index] + + def map(self, func: Callable[[T_co], V_co]) -> Tree[V_co]: + """Map a function over all items in the tree.""" + return Tree(func(self._item), [child.map(func) for child in self._children]) + + def _retain_subtree(self, predicate: Callable[[T_co], bool]) -> Optional[Tree[T_co]]: + if predicate(self._item): + # If the current node satisfies the predicate, retain the subtree + return self + + subtrees = [child._retain_subtree(predicate) for child in self._children] + if all(subtree is None for subtree in subtrees): + # no subtrees satisfy the predicate, remove the current node + return None + + return Tree(self._item, [subtree for subtree in subtrees if subtree is not None]) + + def retain(self, predicate: Callable[[T_co], bool]) -> Tree[T_co]: + """Prune the tree by retaining subtrees with root nodes that satisfy the predicate. + + The root node is always retained. 
+ """ + return self._retain_subtree(predicate) or Tree(self._item, []) + + def prune(self, predicate: Callable[[T_co], bool]) -> Tree[T_co]: + """Prune the tree by removing nodes that satisfy the predicate. + + The root node is always retained. + """ + return Tree(self._item, [child.prune(predicate) for child in self._children if not predicate(child._item)]) + + def visualize(self, filename: str, item_to_str: Callable[[T_co], str]) -> None: + """Render the tree with Graphviz for debugging purposes. + + Args: + filename: Base filename for the generated `.png` diagram (without extension). + + !!! note + + The method requires the optional `graphviz` dependency to be available in the runtime + environment. + """ + import graphviz + + dot = graphviz.Digraph(comment="Tree") + + def visit(node: Tree[T_co]): + dot.node(str(id(node)), item_to_str(node.item)) # type: ignore + for child in node._children: + visit(child) + dot.edge(str(id(node)), str(id(child))) # type: ignore + + visit(self) + dot.render(filename, format="png", cleanup=True) # type: ignore + + +class AdaptingSequence(BaseAdaptingSequence[T_co], Generic[T_co]): + """A simple list implementation of AdaptingSequence.""" + + def __init__(self, items: Sequence[T_co]) -> None: + # Set container on items if they are AdaptingSpan instances + self._items = [item.with_container(self) for item in items] + + def get(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: + return self._items[index] + + def traverse(self) -> Iterable[T_co]: + return iter(self._items) + + def size(self) -> int: + return len(self._items) + + def map(self, func: Callable[[T_co], V_co]) -> AdaptingSequence[V_co]: + return AdaptingSequence([func(item) for item in self._items]) + + def retain(self, predicate: Callable[[T_co], bool]) -> AdaptingSequence[T_co]: + return AdaptingSequence([item for item in self._items if predicate(item)]) + + def prune(self, predicate: Callable[[T_co], bool]) -> AdaptingSequence[T_co]: + return 
AdaptingSequence([item for item in self._items if not predicate(item)]) + + +class AdaptingSpan(Span): + """A span that has been adapted to a different format. + + This class extends the base [`Span`][agentlightning.Span] class to represent spans that have + been converted to a different format by an adapter. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + data: Any + """The data in the adapted format. Could be annotations, calls, or other structured data.""" + + container: Optional[BaseAdaptingSequence[AdaptingSpan]] = None + """An optional container that holds multiple adapted data items.""" + + def with_data(self, data: Any, override: Literal["silent", "warning", "forbidden"] = "warning") -> AdaptingSpan: + """Create a new [`AdaptingSpan`][agentlightning.AdaptingSpan] with the same properties but different adapted data. + + Args: + data: The new adapted data. + + Returns: + An instance of [`AdaptingSpan`][agentlightning.AdaptingSpan] with the same properties as + the current span but with the provided adapted data. + """ + if self.data is not None: + if override == "forbidden": + raise ValueError( + "Overwriting existing data on AdaptingSpan is forbidden. " + f"Current data: {self.data}, New data: {data}" + ) + elif override == "warning": + logger.warning( + "Found annotation on an adapting span with existing data; overwriting the data. " + f"Current data: {self.data}, New data: {data}" + ) + return self.model_copy(update={"data": data}) + + def with_container(self, container: BaseAdaptingSequence[AdaptingSpan]) -> AdaptingSpan: + """Create a new [`AdaptingSpan`][agentlightning.AdaptingSpan] with the same properties but different container.""" + return self.model_copy(update={"container": container}) + + @classmethod + def from_span(cls, span: Span, data: Any) -> AdaptingSpan: + """Create an [`AdaptingSpan`][agentlightning.AdaptingSpan] from a base [`Span`][agentlightning.Span]. + + Args: + span: The base span to convert. 
+ data: The data in the adapted format. + + Returns: + An instance of [`AdaptingSpan`][agentlightning.AdaptingSpan] with the same properties as + the input span and the provided adapted data. + """ + if isinstance(span, AdaptingSpan): + return span.model_copy(update={"data": data}) + else: + return AdaptingSpan.model_validate({**span.model_dump(), "data": data}) + + def children(self) -> Sequence[AdaptingSpan]: + """Get the child spans as [`AdaptingSpan`][agentlightning.AdaptingSpan] instances. + + Only applicable when the container is a [`Tree`][agentlightning.Tree]. + """ + if self.container is None or not isinstance(self.container, Tree): + raise ValueError("AdaptingSpan.children() is only applicable when container is non-empty and a Tree.") + return [child.item for child in self.container.children] + + def siblings(self) -> Sequence[AdaptingSpan]: + """Get the sibling spans as [`AdaptingSpan`][agentlightning.AdaptingSpan] instances. + + Only applicable when the container is a [`Tree`][agentlightning.Tree]. + """ + if self.container is None or not isinstance(self.container, Tree): + raise ValueError("AdaptingSpan.siblings() is only applicable when container is non-empty and a Tree.") + parent_tree = self.container.parent + if parent_tree is None: + return [] + return [child.item for child in parent_tree.children if child is not self.container] + + def parent_span(self) -> Optional[AdaptingSpan]: + """Get the parent span if available. + + Only applicable when the container is a [`Tree`][agentlightning.Tree]. 
+ """ + if self.container is None or not isinstance(self.container, Tree): + raise ValueError("AdaptingSpan.parent() is only applicable when container is non-empty and a Tree.") + parent_tree = self.container.parent + return parent_tree.item if parent_tree is not None else None + + +# Annotation-related types + +AnnotationType = Literal["agent", "general", "message", "object", "exception", "operation"] + + +class Annotation(BaseModel): + """An annotation is an approach to parse a span into some kind of structured attachments to another object. + + Not necessarily an [AGL_ANNOTATION][agentlightning.semconv.AGL_ANNOTATION] span. + + Note that a span can be parsed in multiple ways, and annotation is just one of them. + """ + + annotation_type: AnnotationType + """Type of the annotation.""" + + links: Sequence[LinkPydanticModel] = Field(default_factory=list) # type: ignore + """Links to other spans or objects.""" + + +class AgentAnnotation(Annotation): + """Parsed from [OTel Agent Spans](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-agent-spans/).""" + + annotation_type: AnnotationType = "agent" + """Type of the annotation.""" + + id: Optional[str] = None + """The unique identifier of the GenAI agent.""" + + name: Optional[str] = None + """Human-readable name of the GenAI agent provided by the application.""" + + description: Optional[str] = None + """Free-form description of the GenAI agent provided by the application.""" + + +class GeneralAnnotation(Annotation): + """An annotation payload that is parsed from an [annotation][agentlightning.semconv.AGL_ANNOTATION] span.""" + + annotation_type: AnnotationType = "general" + """Type of the annotation.""" + + rewards: Sequence[RewardPydanticModel] = Field(default_factory=list) # type: ignore + """Reward dimensions and values.""" + + primary_reward: Optional[float] = None + """Primary reward value.""" + + tags: Sequence[str] = Field(default_factory=list) + """Tags for the annotation.""" + + custom_fields: 
Dict[str, Any] = Field(default_factory=dict) + """Raw payload from the annotation.""" + + +class MessageAnnotation(Annotation): + """A log message that is parsed from a [message][agentlightning.semconv.AGL_MESSAGE] span.""" + + annotation_type: AnnotationType = "message" + """Type of the annotation.""" + + message: str + """Message text.""" + + +class ObjectAnnotation(Annotation): + """An artifact that is parsed from a [object][agentlightning.semconv.AGL_OBJECT] span.""" + + annotation_type: AnnotationType = "object" + """Type of the annotation.""" + + object: Any + """The object payload.""" + + +class ExceptionAnnotation(Annotation): + """An exception that is parsed from an [exception][agentlightning.semconv.AGL_EXCEPTION] span.""" + + annotation_type: AnnotationType = "exception" + """Type of the annotation.""" + + type: str + """Type of the exception.""" + + message: str + """Message of the exception.""" + + stacktrace: Optional[str] = None + """Stacktrace of the exception.""" + + +class OperationAnnotation(Annotation): + """An operation that is parsed from an [operation][agentlightning.semconv.AGL_OPERATION] span.""" + + annotation_type: AnnotationType = "operation" + """Type of the annotation.""" + + name: str + """Name of the operation.""" + + input: Optional[Any] = None + """Input of the operation.""" + + output: Optional[Any] = None + """Output of the operation.""" + + +class ChatCompletionCall(BaseModel): + """Corresponding to exactly one chat completion call. + + OpenAI chat completion request and response are used as standards here. + Convert to other chat completion formats if needed. + """ + + request: CompletionCreateParams + """OpenAI chat completion request parameters.""" + + response: ChatCompletion + """OpenAI chat completion response payload.""" + + malformed_fields: Dict[str, Attributes] + """Fields that are not supported by the adapter. + + Mapping from span names to a dict of malformed fields. 
+ """ + + +class AnnotatedChatCompletionCall(ChatCompletionCall): + """A chat completion call with annotations.""" + + annotations: Sequence[Annotation] + """Annotations for the chat completion call.""" + + +# Algorithm-specific requirements + +T_observation = TypeVar("T_observation") +T_action = TypeVar("T_action") + + +class Triplet(BaseModel, Generic[T_observation, T_action]): + """A triplet of observation, action and reward.""" + + observation: T_observation + """Observation for the model input.""" + + action: T_action + """Action from the model output.""" + + reward: Optional[float] + """Reward of the model input.""" + + done: bool + """Whether it's the end of the trajectory.""" + + raw_call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + """Raw chat completion call.""" + + +class Accumulation(BaseModel): + """Accumulation from multiple triplets.""" + + final_reward: Optional[float] + """Single reward value for the entire sequence. + + An accumulation can only have one single final reward value. + """ + + raw_calls: Sequence[Union[AnnotatedChatCompletionCall, ChatCompletionCall]] + """Raw chat completion calls. The order of the calls must be the same as the order of the token IDs.""" + + +class TokenInput(BaseModel): + """Token-based model input.""" + + token_ids: Sequence[int] + """Token IDs of the model input.""" + + image_urls: Sequence[str] + """A list of image URLs. 
Could be pointers to local files or base64-encoded images.""" + + +class TokenOutput(BaseModel): + """Token-based model output.""" + + token_ids: Sequence[int] + """Token IDs of the model output.""" + + logprobs: Optional[Sequence[float]] + """Log probabilities of the model output.""" + + @model_validator(mode="after") + def validate_logprobs(self) -> Self: + if self.logprobs is not None: + if len(self.logprobs) != len(self.token_ids): + raise ValueError("Log probabilities must be the same length as the token IDs.") + return self + + +class TokensTriplet(Triplet[TokenInput, TokenOutput]): + """A triplet of token IDs for the input and output, useful for reinforcement learning. + + This is not a stable interface and the fields here highly depend on RL implementations. + """ + + +class TokensAccumulationDiagnosis(BaseModel): + + special_tokens_mismatch: bool + """Whether there is a mismatch in special tokens.""" + + non_special_tokens_mismatch: bool + """Whether there is a mismatch in non-special tokens.""" + + detokenized_text_mismatch: bool + """Whether there is a mismatch in detokenized text.""" + + image_urls_mismatch: bool + """Whether there is a mismatch in image URLs.""" + + accumulation_prev: TokensAccumulation + """The previous accumulation which this triplet fails to match.""" + + special_tokens_prev: Sequence[int] + """Special tokens in the previous accumulation.""" + + special_tokens_next: Sequence[int] + """Special tokens in the next sequence.""" + + detokenized_text_prev: str + """Detokenized text in the previous accumulation.""" + + detokenized_text_next: str + """Detokenized text in the next sequence.""" + + +class TokensAccumulation(Accumulation): + """A sequence of token IDs that are accumulated from multiple model calls. + + Output is implied in the token IDs. 
+ """ + + token_ids: Sequence[int] + """Token IDs of the model input and output.""" + + image_urls: Sequence[str] + """Image URLs of the model input and output.""" + + logprobs: Optional[Sequence[float]] + """Log probabilities of the model output.""" + + response_mask: Sequence[int] + """Mask for the response tokens. Must a sequence of 0s and 1s, with 1s for the completion tokens and 0s for the prompt tokens.""" + + diagnosis_info: Optional[TokensAccumulationDiagnosis] + """Diagnosis information for token accumulation mismatches.""" + + +class PromptCompletionTriplet(Triplet[CompletionCreateParams, ChatCompletion]): + """A triplet of prompt and completion.""" + + +class PromptCompletionAccumulation(Accumulation): + """A conversation that is accumulated from multiple model calls.""" + + messages: Sequence[ChatCompletionMessageParam] + """Messages of the conversation.""" + + tools: Optional[Sequence[ChatCompletionToolUnionParam]] + """Tools provided for the conversation.""" diff --git a/agentlightning/types/tracer.py b/agentlightning/types/tracer.py index 2db9a2f48..e7825276e 100644 --- a/agentlightning/types/tracer.py +++ b/agentlightning/types/tracer.py @@ -477,6 +477,18 @@ def from_core_fields( status=core.status, ) + def ensure_start_time(self) -> float: + """Get the start time or raise an error if it's not set.""" + if self.start_time is None: + raise ValueError("Start time is not set. Try to use the `RepairTime` adapter to repair the span.") + return self.start_time + + def ensure_end_time(self) -> float: + """Get the end time or raise an error if it's not set.""" + if self.end_time is None: + raise ValueError("End time is not set. Try to use the `RepairTime` adapter to repair the span.") + return self.end_time + class SpanNames(str, Enum): """Enumerated span names recognised by Agent-lightning. 
Deprecated in favor of [semconv][agentlightning.semconv].""" diff --git a/agentlightning/utils/otel.py b/agentlightning/utils/otel.py index 59909d2ff..2bf8c17eb 100644 --- a/agentlightning/utils/otel.py +++ b/agentlightning/utils/otel.py @@ -32,6 +32,7 @@ "make_tag_attributes", "extract_tags_from_attributes", "make_link_attributes", + "check_linked_span", "query_linked_spans", "extract_links_from_attributes", "filter_attributes", @@ -229,7 +230,42 @@ def make_link_attributes(links: Dict[str, str]) -> Dict[str, Any]: return flatten_attributes({LightningSpanAttributes.LINK.value: link_list}, expand_leaf_lists=True) -def query_linked_spans(spans: Sequence[T_SpanLike], links: List[LinkPydanticModel]) -> List[T_SpanLike]: +def check_linked_span(span: SpanLike, links: Sequence[LinkPydanticModel]) -> bool: + """Check if a span matches a link attribute. + + Args: + span: A span to check. + links: A list of link attributes to match. + """ + span_attributes = span.attributes or {} + for link in links: + # trace_id and span_id must be full match. + if link.key_match == "trace_id": + if isinstance(span, ReadableSpan): + trace_id = trace_api.format_trace_id(span.context.trace_id) if span.context else None + else: + trace_id = span.trace_id + if trace_id != link.value_match: + return False + + elif link.key_match == "span_id": + if isinstance(span, ReadableSpan): + span_id = trace_api.format_span_id(span.context.span_id) if span.context else None + else: + span_id = span.span_id + if span_id != link.value_match: + return False + + else: + attribute = span_attributes.get(link.key_match) + # attributes must also be a full match currently. + if attribute != link.value_match: + return False + + return True + + +def query_linked_spans(spans: Sequence[T_SpanLike], links: Sequence[LinkPydanticModel]) -> List[T_SpanLike]: """Query spans that are linked by the given link attributes. 
Args: @@ -239,42 +275,7 @@ def query_linked_spans(spans: Sequence[T_SpanLike], links: List[LinkPydanticMode Returns: A list of spans that match the given link attributes. """ - matched_spans: List[T_SpanLike] = [] - - for span in spans: - span_attributes = span.attributes or {} - is_match = True - for link in links: - # trace_id and span_id must be full match. - if link.key_match == "trace_id": - if isinstance(span, ReadableSpan): - trace_id = trace_api.format_trace_id(span.context.trace_id) if span.context else None - else: - trace_id = span.trace_id - if trace_id != link.value_match: - is_match = False - break - - elif link.key_match == "span_id": - if isinstance(span, ReadableSpan): - span_id = trace_api.format_span_id(span.context.span_id) if span.context else None - else: - span_id = span.span_id - if span_id != link.value_match: - is_match = False - break - - else: - attribute = span_attributes.get(link.key_match) - # attributes must also be a full match currently. - if attribute != link.value_match: - is_match = False - break - - if is_match: - matched_spans.append(span) - - return matched_spans + return [span for span in spans if check_linked_span(span, links)] def extract_links_from_attributes(attributes: Dict[str, Any]) -> List[LinkPydanticModel]: diff --git a/agentlightning/utils/pydantic.py b/agentlightning/utils/pydantic.py new file mode 100644 index 000000000..b230d50af --- /dev/null +++ b/agentlightning/utils/pydantic.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft. All rights reserved. 
+
+"""Helper functions to make pydantic more useful."""
+
+from typing import Any, List
+
+from pydantic import ValidationError
+
+
+def to_plain_object(object: Any, path: List[str]) -> Any:  # Recursively materialize pydantic lazy values into plain Python containers; `path` tracks the dotted location for error reporting.
+    if type(object).__name__ == "ValidatorIterator":  # name-based check: pydantic does not export ValidatorIterator publicly — TODO confirm this holds across pydantic versions
+        try:
+            return [to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)]  # consuming the iterator is what triggers validation, hence the except below
+        except ValidationError as exc:
+            raise ValueError(
+                "Failed to convert ValidatorIterator to list.\n"
+                "ValidatorIterator path (see below for subpath): " + ".".join(path) + "\nError: " + str(exc)
+            ) from exc  # re-raise with the dotted path of the element that failed validation
+    elif isinstance(object, dict):
+        return {k: to_plain_object(v, path + [k]) for k, v in object.items()} # type: ignore
+    elif isinstance(object, list):
+        return [to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)] # type: ignore
+    elif isinstance(object, tuple):
+        return tuple(to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)) # type: ignore
+    else:
+        return object  # scalars and any other types pass through unchanged
diff --git a/examples/apo/room_selector.py b/examples/apo/room_selector.py
index 289351619..b566f734d 100644
--- a/examples/apo/room_selector.py
+++ b/examples/apo/room_selector.py
@@ -349,6 +349,11 @@ async def debug_room_selector(limit: int = 1):
     # Get the spans and convert them to messages
     # Useful for debugging and analysis
     spans = await store.query_spans(rollout.rollout_id)
+    for span in spans:
+        console.print(
+            f"[bold blue]=== Span {span.span_id} ({span.name}, parent {span.parent_id}) ===[/bold blue]"
+        )
+        console.print(span.attributes)
     adapter = TraceToMessages()
     messages = adapter.adapt(spans)
     for message_idx, message in enumerate(messages):
diff --git a/tests/adapter/test_annotation.py b/tests/adapter/test_annotation.py
new file mode 100644
index 000000000..b3b347f0b
--- /dev/null
+++ b/tests/adapter/test_annotation.py
@@ -0,0 +1,1747 @@
+# Copyright (c) Microsoft. All rights reserved.
+ +"""Tests for the annotation module adapters.""" + +import itertools +from typing import Any, Dict, Optional + +import pytest + +from agentlightning.adapter.annotation import ( + IdentifyAnnotations, + RepairMissingLinks, + SelectByAnnotation, +) +from agentlightning.semconv import ( + AGL_ANNOTATION, + AGL_EXCEPTION, + AGL_MESSAGE, + AGL_OBJECT, + AGL_OPERATION, + LightningSpanAttributes, + LinkPydanticModel, +) +from agentlightning.types import Span +from agentlightning.types.adapter import ( + AdaptingSequence, + AdaptingSpan, + AgentAnnotation, + ExceptionAnnotation, + GeneralAnnotation, + MessageAnnotation, + ObjectAnnotation, + OperationAnnotation, + Tree, +) + +_SEQ = itertools.count() + + +def make_span( + span_id: str, + name: str, + *, + parent_id: Optional[str], + start_time: float, + end_time: float, + attributes: Optional[Dict[str, Any]] = None, + rollout_id: str = "rollout-1", + attempt_id: str = "attempt-1", + sequence_id: Optional[int] = None, +) -> Span: + """Create a test span with sensible defaults.""" + return Span.from_attributes( + rollout_id=rollout_id, + attempt_id=attempt_id, + sequence_id=sequence_id if sequence_id is not None else next(_SEQ), + trace_id="trace-1", + span_id=span_id, + parent_id=parent_id, + name=name, + attributes=attributes or {}, + start_time=start_time, + end_time=end_time, + ) + + +def make_adapting_span( + span_id: str, + name: str, + *, + parent_id: Optional[str] = None, + start_time: float = 0.0, + end_time: float = 1.0, + attributes: Optional[Dict[str, Any]] = None, + data: Any = None, +) -> AdaptingSpan: + """Create a test AdaptingSpan with sensible defaults.""" + span = make_span( + span_id=span_id, + name=name, + parent_id=parent_id, + start_time=start_time, + end_time=end_time, + attributes=attributes, + ) + return AdaptingSpan.from_span(span, data) + + +# Tests for IdentifyAnnotations._filter_custom_attributes + + +def test_filter_custom_attributes_filters_reserved_fields() -> None: + """Reserved fields 
should be filtered out.""" + adapter = IdentifyAnnotations() + attributes = { + LightningSpanAttributes.REWARD.value: 1.0, + LightningSpanAttributes.TAG.value: ["tag1"], + "custom.field": "value", + } + + result = adapter._filter_custom_attributes(attributes) # pyright: ignore[reportPrivateUsage] + + assert "custom.field" in result + assert LightningSpanAttributes.REWARD.value not in result + assert LightningSpanAttributes.TAG.value not in result + + +def test_filter_custom_attributes_filters_nested_reserved_fields() -> None: + """Nested reserved fields are only filtered when base field is present.""" + adapter = IdentifyAnnotations() + # When the base field IS present, nested fields are also filtered + attributes = { + LightningSpanAttributes.REWARD.value: "base", + f"{LightningSpanAttributes.REWARD.value}.0.value": 1.0, + f"{LightningSpanAttributes.REWARD.value}.0.name": "quality", + "custom.field": "value", + } + + result = adapter._filter_custom_attributes(attributes) # pyright: ignore[reportPrivateUsage] + + assert "custom.field" in result + assert len(result) == 1 + + +def test_filter_custom_attributes_nested_fields_not_filtered_without_base() -> None: + """Nested reserved fields pass through if base field is not present. + + Note: This is the current behavior - nested fields like 'agentlightning.reward.0.value' + are NOT filtered unless the base field 'agentlightning.reward' is also present. 
+ """ + adapter = IdentifyAnnotations() + attributes = { + f"{LightningSpanAttributes.REWARD.value}.0.value": 1.0, + "custom.field": "value", + } + + result = adapter._filter_custom_attributes(attributes) # pyright: ignore[reportPrivateUsage] + + # Nested fields pass through since base field is not present + assert len(result) == 2 + + +def test_filter_custom_attributes_preserves_custom_fields() -> None: + """Custom fields not matching reserved prefixes should be preserved.""" + adapter = IdentifyAnnotations() + attributes = { + "my.custom.attribute": "value1", + "another_attribute": "value2", + } + + result = adapter._filter_custom_attributes(attributes) # pyright: ignore[reportPrivateUsage] + + assert result == attributes + + +# Tests for IdentifyAnnotations.extract_links + + +def test_extract_links_extracts_valid_links() -> None: + """Valid links should be extracted from span attributes.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "span", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.LINK.value}.0.key_match": "span_id", + f"{LightningSpanAttributes.LINK.value}.0.value_match": "target-span", + }, + ) + + links = adapter.extract_links(span) + + assert len(links) == 1 + assert links[0].key_match == "span_id" + assert links[0].value_match == "target-span" + + +def test_extract_links_returns_empty_for_no_links() -> None: + """Should return empty list when no links are present.""" + adapter = IdentifyAnnotations() + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + + links = adapter.extract_links(span) + + assert links == [] + + +def test_extract_links_handles_malformed_links() -> None: + """Malformed links should return empty list without raising.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "span", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.LINK.value}.0.invalid_field": "value", + }, + ) + + 
# Should not raise, returns empty list + links = adapter.extract_links(span) + assert links == [] + + +# Tests for IdentifyAnnotations.identify_general + + +def test_identify_general_creates_general_annotation() -> None: + """Should create GeneralAnnotation from annotation span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_ANNOTATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "custom.field": "value", + # Need to provide empty list for tags since extract_tags_from_attributes + # doesn't handle missing tags gracefully + f"{LightningSpanAttributes.TAG.value}.0": "test-tag", + }, + ) + + result = adapter.identify_general(span) + + assert result is not None + assert result.annotation_type == "general" + # custom.field is preserved, tag is filtered out + assert "custom.field" in result.custom_fields + + +def test_identify_general_extracts_rewards() -> None: + """Should extract rewards from annotation span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_ANNOTATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.REWARD.value}.0.name": "quality", + f"{LightningSpanAttributes.REWARD.value}.0.value": 0.9, + f"{LightningSpanAttributes.TAG.value}.0": "test-tag", + }, + ) + + result = adapter.identify_general(span) + + assert result is not None + assert len(result.rewards) == 1 + assert result.rewards[0].name == "quality" + assert result.rewards[0].value == 0.9 + assert result.primary_reward == 0.9 + + +def test_identify_general_extracts_tags() -> None: + """Should extract tags from annotation span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_ANNOTATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.TAG.value}.0": "important", + f"{LightningSpanAttributes.TAG.value}.1": "reviewed", + }, + ) + + result = adapter.identify_general(span) + + assert result is not None + assert 
"important" in result.tags + assert "reviewed" in result.tags + + +# Tests for IdentifyAnnotations.identify_message + + +def test_identify_message_creates_message_annotation() -> None: + """Should create MessageAnnotation from message span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_MESSAGE, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.MESSAGE_BODY.value: "Hello, world!", + }, + ) + + result = adapter.identify_message(span) + + assert result is not None + assert result.annotation_type == "message" + assert result.message == "Hello, world!" + + +def test_identify_message_returns_none_for_missing_body() -> None: + """Should return None when message body is missing.""" + adapter = IdentifyAnnotations() + span = make_span("s1", AGL_MESSAGE, parent_id=None, start_time=0.0, end_time=1.0) + + result = adapter.identify_message(span) + + assert result is None + + +# Tests for IdentifyAnnotations.identify_object + + +def test_identify_object_creates_object_annotation_from_json() -> None: + """Should create ObjectAnnotation from JSON object.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OBJECT, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OBJECT_JSON.value: '{"key": "value"}', + }, + ) + + result = adapter.identify_object(span) + + assert result is not None + assert result.annotation_type == "object" + assert result.object == {"key": "value"} + + +def test_identify_object_creates_object_annotation_from_literal() -> None: + """Should create ObjectAnnotation from literal value.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OBJECT, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OBJECT_LITERAL.value: "simple string", + }, + ) + + result = adapter.identify_object(span) + + assert result is not None + assert result.annotation_type == "object" + + +def 
test_identify_object_handles_invalid_json() -> None: + """Should handle invalid JSON gracefully.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OBJECT, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OBJECT_JSON.value: "invalid json {", + }, + ) + + # Should not raise, returns None due to error handling + result = adapter.identify_object(span) + assert result is None + + +# Tests for IdentifyAnnotations.identify_exception + + +def test_identify_exception_creates_exception_annotation() -> None: + """Should create ExceptionAnnotation from exception span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_EXCEPTION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "exception.type": "ValueError", + "exception.message": "Invalid input", + "exception.stacktrace": "Traceback...", + }, + ) + + result = adapter.identify_exception(span) + + assert result is not None + assert result.annotation_type == "exception" + assert result.type == "ValueError" + assert result.message == "Invalid input" + assert result.stacktrace == "Traceback..." 
+ + +def test_identify_exception_uses_defaults_for_missing_fields() -> None: + """Should use default values when exception fields are missing.""" + adapter = IdentifyAnnotations() + span = make_span("s1", AGL_EXCEPTION, parent_id=None, start_time=0.0, end_time=1.0) + + result = adapter.identify_exception(span) + + assert result is not None + assert result.type == "UnknownException" + assert result.message == "" + + +# Tests for IdentifyAnnotations.identify_operation + + +def test_identify_operation_creates_operation_annotation() -> None: + """Should create OperationAnnotation from operation span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OPERATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OPERATION_NAME.value: "process_data", + LightningSpanAttributes.OPERATION_INPUT.value: "input_value", + LightningSpanAttributes.OPERATION_OUTPUT.value: "output_value", + }, + ) + + result = adapter.identify_operation(span) + + assert result is not None + assert result.annotation_type == "operation" + assert result.name == "process_data" + assert result.input == "input_value" + assert result.output == "output_value" + + +def test_identify_operation_extracts_nested_input_output() -> None: + """Should extract nested input/output from flattened attributes.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OPERATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OPERATION_NAME.value: "process_data", + f"{LightningSpanAttributes.OPERATION_INPUT.value}.arg1": "value1", + f"{LightningSpanAttributes.OPERATION_INPUT.value}.arg2": "value2", + }, + ) + + result = adapter.identify_operation(span) + + assert result is not None + assert result.input == {"arg1": "value1", "arg2": "value2"} + + +# Tests for IdentifyAnnotations.detect_agent_annotation + + +def test_detect_agent_annotation_detects_otel_agent_span() -> None: + """Should detect agent from 
OpenTelemetry agent spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "agent.run", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "agent.name": "MyAgent", + "agent.id": "agent-123", + "agent.description": "A helpful agent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.annotation_type == "agent" + assert result.name == "MyAgent" + assert result.id == "agent-123" + assert result.description == "A helpful agent" + + +def test_detect_agent_annotation_detects_agentops_agent() -> None: + """Should detect agent from AgentOps spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "some.operation", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "agentops.span.kind": "agent", + "operation.name": "AgentOpsAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "AgentOpsAgent" + + +def test_detect_agent_annotation_detects_autogen_agent() -> None: + """Should detect agent from Autogen spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "autogen.task", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "recipient_agent_type": "AssistantAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "AssistantAgent" + + +def test_detect_agent_annotation_detects_langgraph_agent() -> None: + """Should detect agent from LangGraph spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "langgraph.node", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "langchain.chain.type": "ReActAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "ReActAgent" + + +def test_detect_agent_annotation_detects_weave_agent() -> None: + """Should detect agent from Weave spans.""" + adapter = 
IdentifyAnnotations() + span = make_span( + "s1", + "weave.call", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "type": "agent", + "agentlightning.operation.input.name": "WeaveAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "WeaveAgent" + + +def test_detect_agent_annotation_detects_langchain_weave_agent() -> None: + """Should detect agent from LangChain + Weave spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "langchain.Chain.MyChain", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "lc_name": "LangChainAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "LangChainAgent" + + +def test_detect_agent_annotation_detects_agent_framework() -> None: + """Should detect agent from agent-framework spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "executor.run", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "executor.id": "ExecutorAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "ExecutorAgent" + + +def test_detect_agent_annotation_returns_none_for_non_agent_span() -> None: + """Should return None for spans without agent indicators.""" + adapter = IdentifyAnnotations() + span = make_span("s1", "some.operation", parent_id=None, start_time=0.0, end_time=1.0) + + result = adapter.detect_agent_annotation(span) + + assert result is None + + +# Tests for IdentifyAnnotations.adapt_one + + +def test_adapt_one_identifies_annotation_span() -> None: + """Should identify AGL_ANNOTATION spans.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_ANNOTATION, + attributes={ + "custom": "value", + f"{LightningSpanAttributes.TAG.value}.0": "test-tag", + }, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, 
GeneralAnnotation) + + +def test_adapt_one_identifies_message_span() -> None: + """Should identify AGL_MESSAGE spans.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_MESSAGE, + attributes={LightningSpanAttributes.MESSAGE_BODY.value: "Hello"}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, MessageAnnotation) + + +def test_adapt_one_identifies_exception_span() -> None: + """Should identify AGL_EXCEPTION spans.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_EXCEPTION, + attributes={"exception.type": "Error"}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, ExceptionAnnotation) + + +def test_adapt_one_identifies_operation_span() -> None: + """Should identify AGL_OPERATION spans.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_OPERATION, + attributes={LightningSpanAttributes.OPERATION_NAME.value: "op"}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, OperationAnnotation) + + +def test_adapt_one_falls_back_to_agent_detection() -> None: + """Should fall back to agent detection for unknown span names.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + "agent.task", + attributes={"agent.name": "MyAgent"}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, AgentAnnotation) + + +def test_adapt_one_returns_unchanged_for_unrecognized_span() -> None: + """Should return unchanged span when no annotation is detected.""" + adapter = IdentifyAnnotations() + source = make_adapting_span("s1", "some.operation") + + result = adapter.adapt_one(source) + + assert result.data is None + + +# Tests for SelectByAnnotation + + +def test_select_by_annotation_include_mode_selects_annotated_spans() -> None: + """Include mode should select only annotated spans and their links.""" + adapter = SelectByAnnotation(mode="include") + + # Create spans with one 
annotation + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[LinkPydanticModel(key_match="span_id", value_match="linked")], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + linked_span = make_adapting_span("linked", "some.operation") + unrelated_span = make_adapting_span("unrelated", "other.operation") + + source = AdaptingSequence([annotation_span, linked_span, unrelated_span]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + assert "annotation" in span_ids + assert "linked" in span_ids + assert "unrelated" not in span_ids + + +def test_select_by_annotation_exclude_mode_removes_annotated_spans() -> None: + """Exclude mode should remove annotated spans and their links.""" + adapter = SelectByAnnotation(mode="exclude") + + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[LinkPydanticModel(key_match="span_id", value_match="linked")], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + linked_span = make_adapting_span("linked", "some.operation") + unrelated_span = make_adapting_span("unrelated", "other.operation") + + source = AdaptingSequence([annotation_span, linked_span, unrelated_span]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + assert "annotation" not in span_ids + assert "linked" not in span_ids + assert "unrelated" in span_ids + + +def test_select_by_annotation_include_mode_with_no_annotations() -> None: + """Include mode with no annotations should return empty result.""" + adapter = SelectByAnnotation(mode="include") + + span1 = make_adapting_span("s1", "operation") + span2 = make_adapting_span("s2", "operation") + + source = AdaptingSequence([span1, span2]) + result = adapter.adapt(source) + + assert len(list(result)) == 0 + + +def 
test_select_by_annotation_include_mode_with_empty_links() -> None: + """Include mode with annotation having empty links should only include the annotation itself. + + An annotation with links=[] should apply only to itself, not to all spans. + This tests that check_linked_span returning True for empty links doesn't + cause all spans to be incorrectly selected. + """ + adapter = SelectByAnnotation(mode="include") + + # Annotation with no links - should only include itself + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[], # Empty links! + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + unrelated_span1 = make_adapting_span("unrelated1", "some.operation") + unrelated_span2 = make_adapting_span("unrelated2", "other.operation") + + source = AdaptingSequence([annotation_span, unrelated_span1, unrelated_span2]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + # Only the annotation span should be selected, not the unrelated spans + assert span_ids == ["annotation"] + + +def test_select_by_annotation_exclude_mode_with_empty_links() -> None: + """Exclude mode with annotation having empty links should only exclude the annotation itself. + + An annotation with links=[] should apply only to itself, so only the annotation + span should be excluded, not all spans. + """ + adapter = SelectByAnnotation(mode="exclude") + + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[], # Empty links! 
+ rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + unrelated_span1 = make_adapting_span("unrelated1", "some.operation") + unrelated_span2 = make_adapting_span("unrelated2", "other.operation") + + source = AdaptingSequence([annotation_span, unrelated_span1, unrelated_span2]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + # Only the annotation span should be excluded, unrelated spans should remain + assert "annotation" not in span_ids + assert "unrelated1" in span_ids + assert "unrelated2" in span_ids + + +def test_select_by_annotation_exclude_mode_with_no_annotations() -> None: + """Exclude mode with no annotations should return all spans.""" + adapter = SelectByAnnotation(mode="exclude") + + span1 = make_adapting_span("s1", "operation") + span2 = make_adapting_span("s2", "operation") + + source = AdaptingSequence([span1, span2]) + result = adapter.adapt(source) + + assert len(list(result)) == 2 + + +# Tests for RepairMissingLinks + + +def test_repair_missing_links_backward() -> None: + """Should repair missing links by searching backward.""" + adapter = RepairMissingLinks(scan_direction="backward") + + # Create sequence: candidate -> annotation (without link) + candidate = make_adapting_span("candidate", "some.operation", start_time=0.0, end_time=1.0) + annotation = make_adapting_span( + "annotation", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], # No links + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([candidate, annotation]) + result = adapter.adapt(source) + + # Find the annotation span in result + annotation_result = next(s for s in result if s.span_id == "annotation") + assert isinstance(annotation_result.data, GeneralAnnotation) + assert len(annotation_result.data.links) == 1 + assert annotation_result.data.links[0].value_match == "candidate" + + +def 
test_repair_missing_links_forward() -> None: + """Should repair missing links by searching forward.""" + adapter = RepairMissingLinks(scan_direction="forward") + + # Create sequence: annotation (without link) -> candidate + annotation = make_adapting_span( + "annotation", + AGL_ANNOTATION, + start_time=0.0, + end_time=1.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + candidate = make_adapting_span("candidate", "some.operation", start_time=1.0, end_time=2.0) + + source = AdaptingSequence([annotation, candidate]) + result = adapter.adapt(source) + + annotation_result = next(s for s in result if s.span_id == "annotation") + assert isinstance(annotation_result.data, GeneralAnnotation) + assert len(annotation_result.data.links) == 1 + assert annotation_result.data.links[0].value_match == "candidate" + + +def test_repair_missing_links_respects_candidate_predicate() -> None: + """Should only consider candidates matching the predicate.""" + adapter = RepairMissingLinks( + scan_direction="backward", + candidate_predicate=lambda span: span.name == "valid.candidate", + ) + + invalid_candidate = make_adapting_span("invalid", "invalid.operation", start_time=0.0, end_time=1.0) + valid_candidate = make_adapting_span("valid", "valid.candidate", start_time=1.0, end_time=2.0) + annotation = make_adapting_span( + "annotation", + AGL_ANNOTATION, + start_time=2.0, + end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([invalid_candidate, valid_candidate, annotation]) + result = adapter.adapt(source) + + annotation_result = next(s for s in result if s.span_id == "annotation") + assert annotation_result.data.links[0].value_match == "valid" + + +def test_repair_missing_links_does_not_reuse_linked_spans_by_default() -> None: + """By default, linked spans 
should not be reused for other annotations.""" + adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=False) + + candidate = make_adapting_span("candidate", "operation", start_time=0.0, end_time=1.0) + annotation1 = make_adapting_span( + "ann1", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + annotation2 = make_adapting_span( + "ann2", + AGL_ANNOTATION, + start_time=2.0, + end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([candidate, annotation1, annotation2]) + result = adapter.adapt(source) + + # Only one annotation should get the link + linked_annotations = [s for s in result if isinstance(s.data, GeneralAnnotation) and len(s.data.links) > 0] + assert len(linked_annotations) == 1 + + +def test_repair_missing_links_allows_reuse_linked_spans_when_enabled() -> None: + """When enabled, linked spans can be reused for multiple annotations.""" + adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=True) + + candidate = make_adapting_span("candidate", "operation", start_time=0.0, end_time=1.0) + annotation1 = make_adapting_span( + "ann1", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + annotation2 = make_adapting_span( + "ann2", + AGL_ANNOTATION, + start_time=2.0, + end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([candidate, annotation1, annotation2]) + result = adapter.adapt(source) + + # Both annotations should get links to the same candidate + 
    # NOTE(review): tail of a test function whose `def` line lies before this
    # chunk boundary; reproduced verbatim.
    linked_annotations = [s for s in result if isinstance(s.data, GeneralAnnotation) and len(s.data.links) > 0]
    assert len(linked_annotations) == 2


def test_repair_missing_links_preserves_existing_links() -> None:
    """Annotations with existing links should not be modified."""
    adapter = RepairMissingLinks(scan_direction="backward")

    candidate = make_adapting_span("candidate", "operation", start_time=0.0, end_time=1.0)
    annotation_with_link = make_adapting_span(
        "annotation",
        AGL_ANNOTATION,
        start_time=1.0,
        end_time=2.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[LinkPydanticModel(key_match="span_id", value_match="existing-target")],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )

    source = AdaptingSequence([candidate, annotation_with_link])
    result = adapter.adapt(source)

    annotation_result = next(s for s in result if s.span_id == "annotation")
    # Original link should be preserved
    assert annotation_result.data.links[0].value_match == "existing-target"


def test_repair_missing_links_siblings_scope_requires_tree() -> None:
    """Siblings scope should raise error for non-tree sequences."""
    adapter = RepairMissingLinks(candidate_scope="siblings")

    span = make_adapting_span("s1", "operation")
    source = AdaptingSequence([span])

    with pytest.raises(ValueError, match="siblings.*only applicable to tree"):
        adapter.adapt(source)


def test_repair_missing_links_siblings_scope_with_tree() -> None:
    """Siblings scope should work correctly with tree sequences."""
    adapter = RepairMissingLinks(candidate_scope="siblings", scan_direction="backward")

    # Build a simple tree: root -> [child1, child2]
    root = make_adapting_span("root", "root", start_time=0.0, end_time=10.0)
    child1 = make_adapting_span("child1", "operation", start_time=1.0, end_time=4.0)
    child2 = make_adapting_span(
        "child2",
        AGL_ANNOTATION,
        start_time=5.0,
        end_time=9.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )

    # Create tree structure
    child1_tree = Tree(child1, [])
    child2_tree = Tree(child2, [])
    tree = Tree(root, [child1_tree, child2_tree])

    result = adapter.adapt(tree)

    # child2 annotation should link to child1 (its sibling)
    child2_result = next(s for s in result.traverse() if s.span_id == "child2")
    assert isinstance(child2_result.data, GeneralAnnotation)
    assert len(child2_result.data.links) == 1
    assert child2_result.data.links[0].value_match == "child1"


# Edge case tests


def test_identify_annotations_empty_sequence() -> None:
    """Should handle empty sequence."""
    adapter = IdentifyAnnotations()
    result = adapter.adapt(AdaptingSequence([]))
    assert list(result) == []


def test_select_by_annotation_empty_sequence() -> None:
    """Should handle empty sequence."""
    adapter = SelectByAnnotation(mode="include")
    result = adapter.adapt(AdaptingSequence([]))
    assert len(list(result)) == 0


def test_repair_missing_links_empty_sequence() -> None:
    """Should handle empty sequence."""
    adapter = RepairMissingLinks()
    result = adapter.adapt(AdaptingSequence([]))
    assert len(list(result)) == 0


# Additional corner case tests


def test_identify_object_returns_annotation_with_none_when_no_object_attributes() -> None:
    """Should return ObjectAnnotation with None object when no JSON or literal present."""
    adapter = IdentifyAnnotations()
    span = make_span(
        "s1",
        AGL_OBJECT,
        parent_id=None,
        start_time=0.0,
        end_time=1.0,
        attributes={},
    )

    result = adapter.identify_object(span)

    assert result is not None
    assert result.annotation_type == "object"
    assert result.object is None


def test_extract_agent_name_returns_none_for_agentops_without_operation_name() -> None:
    """AgentOps span with kind=agent but no operation.name should return None."""
    adapter = IdentifyAnnotations()
    span = make_span(
        "s1",
        "some.operation",
        parent_id=None,
        start_time=0.0,
        end_time=1.0,
        attributes={
            "agentops.span.kind": "agent",
            # Missing operation.name
        },
    )

    result = adapter.extract_agent_name(span)

    assert result is None


def test_extract_agent_name_returns_none_for_weave_without_input_name() -> None:
    """Weave span with type=agent but no input.name should return None."""
    adapter = IdentifyAnnotations()
    span = make_span(
        "s1",
        "weave.call",
        parent_id=None,
        start_time=0.0,
        end_time=1.0,
        attributes={
            "type": "agent",
            # Missing agentlightning.operation.input.name
        },
    )

    result = adapter.extract_agent_name(span)

    assert result is None


def test_extract_agent_name_returns_none_for_langchain_chain_without_lc_name() -> None:
    """LangChain chain span without lc_name should return None."""
    adapter = IdentifyAnnotations()
    span = make_span(
        "s1",
        "langchain.Chain.MyChain",
        parent_id=None,
        start_time=0.0,
        end_time=1.0,
        attributes={
            # Missing lc_name
        },
    )

    result = adapter.extract_agent_name(span)

    assert result is None


def test_identify_annotations_adapt_processes_multiple_spans() -> None:
    """Should process multiple spans in a sequence."""
    adapter = IdentifyAnnotations()

    span1 = make_adapting_span(
        "s1",
        AGL_ANNOTATION,
        attributes={f"{LightningSpanAttributes.TAG.value}.0": "tag1"},
    )
    span2 = make_adapting_span(
        "s2",
        AGL_MESSAGE,
        attributes={LightningSpanAttributes.MESSAGE_BODY.value: "Hello"},
    )
    span3 = make_adapting_span("s3", "regular.span")

    source = AdaptingSequence([span1, span2, span3])
    result = adapter.adapt(source)

    result_list = list(result)
    assert len(result_list) == 3
    assert isinstance(result_list[0].data, GeneralAnnotation)
    assert isinstance(result_list[1].data, MessageAnnotation)
    assert result_list[2].data is None


def test_select_by_annotation_links_use_and_logic() -> None:
    """Multiple links in an annotation use AND logic - span must match ALL links."""
    adapter = SelectByAnnotation(mode="include")

    # Links specify span_id AND a custom attribute must both match
    annotation_span = make_adapting_span(
        "annotation",
        AGL_ANNOTATION,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[
                LinkPydanticModel(key_match="span_id", value_match="target"),
                LinkPydanticModel(key_match="custom.tag", value_match="important"),
            ],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    # This span matches both constraints
    target = make_adapting_span("target", "target.operation", attributes={"custom.tag": "important"})
    # This span only matches span_id but not the custom attribute
    partial_match = make_adapting_span("partial", "other.operation", attributes={"custom.tag": "important"})
    unrelated = make_adapting_span("unrelated", "operation3")

    source = AdaptingSequence([annotation_span, target, partial_match, unrelated])
    result = adapter.adapt(source)

    span_ids = [s.span_id for s in result]
    assert "annotation" in span_ids
    assert "target" in span_ids
    # partial_match doesn't match span_id constraint
    assert "partial" not in span_ids
    assert "unrelated" not in span_ids


def test_select_by_annotation_multiple_annotations_link_different_spans() -> None:
    """Multiple annotations can each link to different spans."""
    adapter = SelectByAnnotation(mode="include")

    ann1 = make_adapting_span(
        "ann1",
        AGL_ANNOTATION,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[LinkPydanticModel(key_match="span_id", value_match="linked1")],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    ann2 = make_adapting_span(
        "ann2",
        AGL_ANNOTATION,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[LinkPydanticModel(key_match="span_id", value_match="linked2")],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    linked1 = make_adapting_span("linked1", "operation1")
    linked2 = make_adapting_span("linked2", "operation2")
    unrelated = make_adapting_span("unrelated", "operation3")

    source = AdaptingSequence([ann1, ann2, linked1, linked2, unrelated])
    result = adapter.adapt(source)

    span_ids = [s.span_id for s in result]
    assert "ann1" in span_ids
    assert "ann2" in span_ids
    assert "linked1" in span_ids
    assert "linked2" in span_ids
    assert "unrelated" not in span_ids


def test_repair_missing_links_multiple_annotations_single_candidate_no_reuse() -> None:
    """Without reuse, only one annotation gets linked to a candidate."""
    adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=False)

    candidate = make_adapting_span("candidate", "operation", start_time=0.0, end_time=1.0)
    ann1 = make_adapting_span(
        "ann1",
        AGL_ANNOTATION,
        start_time=1.0,
        end_time=2.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    ann2 = make_adapting_span(
        "ann2",
        AGL_ANNOTATION,
        start_time=2.0,
        end_time=3.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    ann3 = make_adapting_span(
        "ann3",
        AGL_ANNOTATION,
        start_time=3.0,
        end_time=4.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )

    source = AdaptingSequence([candidate, ann1, ann2, ann3])
    result = adapter.adapt(source)

    # Only one annotation should get linked
    linked = [s for s in result if isinstance(s.data, GeneralAnnotation) and len(s.data.links) > 0]
    assert len(linked) == 1


def test_repair_missing_links_backward_links_nearest_candidate() -> None:
    """Backward scan should link each annotation to its nearest preceding candidate.

    When multiple annotations appear before candidates (in chronological order),
    the annotation closest to the candidate should get that candidate.

    Chronological order: [C1, C2, A1, A2]
    Expected: A2->C2 (nearest), A1->C1 (remaining)
    """
    adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=False)

    c1 = make_adapting_span("c1", "operation", start_time=0.0, end_time=1.0)
    c2 = make_adapting_span("c2", "operation", start_time=1.0, end_time=2.0)
    a1 = make_adapting_span(
        "a1",
        AGL_ANNOTATION,
        start_time=2.0,
        end_time=3.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    a2 = make_adapting_span(
        "a2",
        AGL_ANNOTATION,
        start_time=3.0,
        end_time=4.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )

    source = AdaptingSequence([c1, c2, a1, a2])
    result = adapter.adapt(source)

    # A2 (latest annotation) should link to C2 (nearest candidate before it)
    a2_result = next(s for s in result if s.span_id == "a2")
    assert isinstance(a2_result.data, GeneralAnnotation)
    assert len(a2_result.data.links) == 1
    assert a2_result.data.links[0].value_match == "c2"

    # A1 should link to C1 (remaining candidate)
    a1_result = next(s for s in result if s.span_id == "a1")
    assert isinstance(a1_result.data, GeneralAnnotation)
    assert len(a1_result.data.links) == 1
    assert a1_result.data.links[0].value_match == "c1"


def test_repair_missing_links_forward_links_nearest_candidate() -> None:
    """Forward scan should link each annotation to its nearest following candidate.

    When multiple annotations appear before candidates (in chronological order),
    the annotation closest to the candidate should get that candidate.

    Chronological order: [A1, A2, C1, C2]
    Expected: A1->C1 (nearest), A2->C2 (remaining)
    """
    adapter = RepairMissingLinks(scan_direction="forward", allow_reuse_linked_spans=False)

    a1 = make_adapting_span(
        "a1",
        AGL_ANNOTATION,
        start_time=0.0,
        end_time=1.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    a2 = make_adapting_span(
        "a2",
        AGL_ANNOTATION,
        start_time=1.0,
        end_time=2.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    c1 = make_adapting_span("c1", "operation", start_time=2.0, end_time=3.0)
    c2 = make_adapting_span("c2", "operation", start_time=3.0, end_time=4.0)

    source = AdaptingSequence([a1, a2, c1, c2])
    result = adapter.adapt(source)

    # A1 (earliest annotation) should link to C1 (nearest candidate after it)
    a1_result = next(s for s in result if s.span_id == "a1")
    assert isinstance(a1_result.data, GeneralAnnotation)
    assert len(a1_result.data.links) == 1
    assert a1_result.data.links[0].value_match == "c1"

    # A2 should link to C2 (remaining candidate)
    a2_result = next(s for s in result if s.span_id == "a2")
    assert isinstance(a2_result.data, GeneralAnnotation)
    assert len(a2_result.data.links) == 1
    assert a2_result.data.links[0].value_match == "c2"


def test_repair_missing_links_backward_interleaved() -> None:
    """Backward scan with interleaved annotations and candidates.

    Chronological order: [C1, A1, C2, A2]
    Expected: A1->C1 (nearest before A1), A2->C2 (nearest before A2)
    """
    adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=False)

    c1 = make_adapting_span("c1", "operation", start_time=0.0, end_time=1.0)
    a1 = make_adapting_span(
        "a1",
        AGL_ANNOTATION,
        start_time=1.0,
        end_time=2.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    c2 = make_adapting_span("c2", "operation", start_time=2.0, end_time=3.0)
    a2 = make_adapting_span(
        "a2",
        AGL_ANNOTATION,
        start_time=3.0,
        end_time=4.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )

    source = AdaptingSequence([c1, a1, c2, a2])
    result = adapter.adapt(source)

    # A2 should link to C2 (nearest candidate before A2)
    a2_result = next(s for s in result if s.span_id == "a2")
    assert isinstance(a2_result.data, GeneralAnnotation)
    assert len(a2_result.data.links) == 1
    assert a2_result.data.links[0].value_match == "c2"

    # A1 should link to C1 (nearest candidate before A1)
    a1_result = next(s for s in result if s.span_id == "a1")
    assert isinstance(a1_result.data, GeneralAnnotation)
    assert len(a1_result.data.links) == 1
    assert a1_result.data.links[0].value_match == "c1"


def test_repair_missing_links_forward_interleaved() -> None:
    """Forward scan with interleaved annotations and candidates.

    Chronological order: [A1, C1, A2, C2]
    Expected: A1->C1 (nearest after A1), A2->C2 (nearest after A2)
    """
    adapter = RepairMissingLinks(scan_direction="forward", allow_reuse_linked_spans=False)

    a1 = make_adapting_span(
        "a1",
        AGL_ANNOTATION,
        start_time=0.0,
        end_time=1.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    c1 = make_adapting_span("c1", "operation", start_time=1.0, end_time=2.0)
    a2 = make_adapting_span(
        "a2",
        AGL_ANNOTATION,
        start_time=2.0,
        end_time=3.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    c2 = make_adapting_span("c2", "operation", start_time=3.0, end_time=4.0)

    source = AdaptingSequence([a1, c1, a2, c2])
    result = adapter.adapt(source)

    # A1 should link to C1 (nearest candidate after A1)
    a1_result = next(s for s in result if s.span_id == "a1")
    assert isinstance(a1_result.data, GeneralAnnotation)
    assert len(a1_result.data.links) == 1
    assert a1_result.data.links[0].value_match == "c1"

    # A2 should link to C2 (nearest candidate after A2)
    a2_result = next(s for s in result if s.span_id == "a2")
    assert isinstance(a2_result.data, GeneralAnnotation)
    assert len(a2_result.data.links) == 1
    assert a2_result.data.links[0].value_match == "c2"


def test_repair_missing_links_skips_annotation_spans_as_candidates() -> None:
    """Annotation spans should not be considered as link candidates."""
    adapter = RepairMissingLinks(scan_direction="backward")

    # Two annotations in sequence - second should not link to first
    ann1 = make_adapting_span(
        "ann1",
        AGL_ANNOTATION,
        start_time=0.0,
        end_time=1.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    ann2 = make_adapting_span(
        "ann2",
        AGL_ANNOTATION,
        start_time=1.0,
        end_time=2.0,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )

    source = AdaptingSequence([ann1, ann2])
    result = adapter.adapt(source)

    # Neither annotation should get links since there are no candidates
    for span in result:
        if isinstance(span.data, GeneralAnnotation):
            assert len(span.data.links) == 0


def test_identify_operation_with_no_input_output() -> None:
    """Should handle operation spans with no input or output."""
    adapter = IdentifyAnnotations()
    span = make_span(
        "s1",
        AGL_OPERATION,
        parent_id=None,
        start_time=0.0,
        end_time=1.0,
        attributes={
            LightningSpanAttributes.OPERATION_NAME.value: "simple_op",
        },
    )

    result = adapter.identify_operation(span)

    assert result is not None
    assert result.name == "simple_op"
    assert result.input == {}
    assert result.output == {}


def test_identify_general_with_multiple_rewards() -> None:
    """Should extract multiple rewards and use first as primary."""
    adapter = IdentifyAnnotations()
    span = make_span(
        "s1",
        AGL_ANNOTATION,
        parent_id=None,
        start_time=0.0,
        end_time=1.0,
        attributes={
            f"{LightningSpanAttributes.REWARD.value}.0.name": "quality",
            f"{LightningSpanAttributes.REWARD.value}.0.value": 0.9,
            f"{LightningSpanAttributes.REWARD.value}.1.name": "relevance",
            f"{LightningSpanAttributes.REWARD.value}.1.value": 0.7,
            f"{LightningSpanAttributes.TAG.value}.0": "test",
        },
    )

    result = adapter.identify_general(span)

    assert result is not None
    assert len(result.rewards) == 2
    assert result.primary_reward == 0.9


def test_adapt_one_with_object_span() -> None:
    """Should identify AGL_OBJECT spans via adapt_one."""
    adapter = IdentifyAnnotations()
    source = make_adapting_span(
        "s1",
        AGL_OBJECT,
        attributes={LightningSpanAttributes.OBJECT_JSON.value: '{"key": "value"}'},
    )

    result = adapter.adapt_one(source)

    assert isinstance(result.data, ObjectAnnotation)
    assert result.data.object == {"key": "value"}


def test_select_by_annotation_with_link_using_attribute_match() -> None:
    """Should support links that match by custom attribute."""
    adapter = SelectByAnnotation(mode="include")

    annotation_span = make_adapting_span(
        "annotation",
        AGL_ANNOTATION,
        data=GeneralAnnotation(
            annotation_type="general",
            links=[LinkPydanticModel(key_match="custom.tag", value_match="important")],
            rewards=[],
            primary_reward=None,
            tags=[],
            custom_fields={},
        ),
    )
    target = make_adapting_span("target", "target.operation", attributes={"custom.tag": "important"})
    other = make_adapting_span("other", "other.operation", attributes={"custom.tag": "other"})

    source = AdaptingSequence([annotation_span, target, other])
    result = adapter.adapt(source)

    span_ids = [s.span_id for s in result]
    assert "annotation" in span_ids
    assert "target" in span_ids
    assert "other" not in span_ids
diff --git a/tests/adapter/test_call.py b/tests/adapter/test_call.py
new file mode 100644
index 000000000..fe1ca0636
--- /dev/null
+++ b/tests/adapter/test_call.py
@@ -0,0 +1,284 @@
# Copyright (c) Microsoft. All rights reserved.
+ +from typing import List + +from agentlightning.adapter.base import Chain +from agentlightning.adapter.call import IdentifyChatCompletionCalls +from agentlightning.adapter.preprocess import RepairMalformedSpans, ToTree +from agentlightning.types.tracer import Span + + +def test_openai_calls(): + spans: List[Span] = [ + Span.from_attributes( + attributes={ + "tool.name": "get_rooms_and_availability", + "tool.parameters": '{"date":"2025-10-13","time":"16:30","duration_min":60}', + "tool.call.id": "call_owd6...l9Gv", + "tool.call.type": "function", + }, + name="tool_call.get_rooms_and_availability", + span_id="18edfd18ea659820", + parent_id="e858708413368c22", + sequence_id=0, + ), + Span.from_attributes( + attributes={ + "tool.name": "get_rooms_and_availability", + "tool.parameters": '{"date":"2025-10-13","time":"16:30","duration_min":60}', + "tool.call.id": "call_VKsy...5ULX", + "tool.call.type": "function", + }, + name="tool_call.get_rooms_and_availability", + span_id="6c9ba649e512a7f8", + parent_id="e858708413368c22", + sequence_id=1, + ), + Span.from_attributes( + attributes={ + "gen_ai.request.type": "chat", + "gen_ai.system": "OpenAI", + "gen_ai.request.model": "gpt-4.1-nano", + "gen_ai.request.temperature": 0.0, + "gen_ai.request.streaming": False, + "gen_ai.prompt.0.role": "system", + "gen_ai.prompt.0.content": "You are a scheduling assistant.", + "gen_ai.prompt.1.role": "user", + "gen_ai.prompt.1.content": "Find a room 2025-10-13 16:30 for 60m, 12 attendees; needs confphone+whiteboard; accessible.", + "gen_ai.request.functions.0.name": "get_rooms_and_availability", + "gen_ai.request.functions.0.description": "Return rooms with capacity/equipment/accessibility/bookings.", + "gen_ai.request.functions.0.parameters": '{"type":"object","properties":{"date":{"type":"string"},"time":{"type":"string"},"duration_min":{"type":"integer"}},"required":["date","time","duration_min"]}', + "gen_ai.response.id": "chatcmpl-...1jC4", + "gen_ai.response.model": 
"gpt-4.1-nano-2025-04-14", + "gen_ai.openai.system_fingerprint": "fp_03e44fcc34", + "gen_ai.usage.total_tokens": 211, + "gen_ai.usage.prompt_tokens": 128, + "gen_ai.usage.completion_tokens": 83, + "gen_ai.completion.0.finish_reason": "tool_calls", + "gen_ai.completion.0.role": "assistant", + }, + name="openai.chat.completion", + span_id="e858708413368c22", + parent_id="3db86425087d211f", + sequence_id=2, + ), + Span.from_attributes( + attributes={ + "gen_ai.request.type": "chat", + "gen_ai.system": "OpenAI", + "gen_ai.request.model": "gpt-4.1-nano", + "gen_ai.request.temperature": 0.0, + "gen_ai.request.streaming": False, + "gen_ai.prompt.0.role": "system", + "gen_ai.prompt.0.content": "You are a scheduling assistant.", + "gen_ai.prompt.1.role": "user", + "gen_ai.prompt.1.content": "Find a room 2025-10-13 16:30 for 60m, 12 attendees; needs confphone+whiteboard; accessible.", + "gen_ai.prompt.2.role": "assistant", + "gen_ai.prompt.2.tool_calls.0.id": "call_owd6...l9Gv", + "gen_ai.prompt.2.tool_calls.0.name": "get_rooms_and_availability", + "gen_ai.prompt.2.tool_calls.0.arguments": '{"date":"2025-10-13","time":"16:30","duration_min":60}', + "gen_ai.prompt.2.tool_calls.1.id": "call_VKsy...5ULX", + "gen_ai.prompt.2.tool_calls.1.name": "get_rooms_and_availability", + "gen_ai.prompt.2.tool_calls.1.arguments": '{"date":"2025-10-13","time":"16:30","duration_min":60}', + "gen_ai.prompt.3.role": "tool", + "gen_ai.prompt.3.content": '{"rooms":[{"id":"Nova","capacity":12,"equipment":["whiteboard","confphone"],"accessible":true,"distance_m":45,"booked":[],"free":true},{"id":"Pulse","capacity":8,"equipment":["whiteboard","confphone"],"accessible":true,"booked":[["2025-10-13","16:30",30]],"free":false}]}', + "gen_ai.prompt.3.tool_call_id": "call_owd6...l9Gv", + "gen_ai.prompt.4.role": "tool", + "gen_ai.prompt.4.content": 
'{"rooms":[{"id":"Nova","capacity":12,"equipment":["whiteboard","confphone"],"accessible":true,"free":true},{"id":"Pulse","capacity":8,"equipment":["whiteboard","confphone"],"accessible":true,"free":false}]}', + "gen_ai.prompt.4.tool_call_id": "call_VKsy...5ULX", + "gen_ai.response.id": "chatcmpl-...Syso", + "gen_ai.response.model": "gpt-4.1-nano-2025-04-14", + "gen_ai.openai.system_fingerprint": "fp_03e44fcc34", + "gen_ai.usage.total_tokens": 1189, + "gen_ai.usage.prompt_tokens": 1082, + "gen_ai.usage.completion_tokens": 107, + "gen_ai.completion.0.finish_reason": "stop", + "gen_ai.completion.0.role": "assistant", + "gen_ai.completion.0.content": "Available rooms: Nova (cap 12, whiteboard+confphone, accessible); Lyra (cap 10...).", + }, + name="openai.chat.completion", + span_id="9a44818e0901d0a1", + parent_id="3db86425087d211f", + sequence_id=3, + ), + Span.from_attributes( + attributes={ + "agentops.span.kind": "session", + "operation.name": "ro-90201d0a24cb", + }, + name="ro-90201d0a24cb.session", + span_id="3db86425087d211f", + parent_id=None, + sequence_id=4, + ), + Span.from_attributes( + attributes={ + "agentlightning.reward.0.name": "primary", + "agentlightning.reward.0.value": 1.0, + }, + name="agentlightning.annotation", + span_id="dc5e3c27f4378b6e", + parent_id=None, + sequence_id=5, + ), + ] + + # r1 = RepairMalformedSpans()(spans) + # r2 = ToTree()(r1) + # r2.visualize(filename="test_openai_calls", item_to_str=lambda span: span.name) + + adapter = Chain( + RepairMalformedSpans(), + ToTree(), + IdentifyChatCompletionCalls(), + ) + + adapted_spans = adapter(spans) + + for span in adapted_spans: + print(span) + + +def test_litellm_call(): + """Test LiteLLM proxy spans with chat completion calls.""" + spans: List[Span] = [ + # proxy_pre_call span + Span.from_attributes( + attributes={ + "call_type": "add_litellm_data_to_request", + "service": "proxy_pre_call", + }, + name="proxy_pre_call", + span_id="a82547704417abf4", + parent_id="9e1058bdd104a886", + 
sequence_id=0, + ), + # router span (async_get_available_deployment) + Span.from_attributes( + attributes={ + "call_type": "async_get_available_deployment", + "service": "router", + }, + name="router", + span_id="48086befab5d70cd", + parent_id="9e1058bdd104a886", + sequence_id=1, + ), + # self span (make_openai_chat_completion_request) + Span.from_attributes( + attributes={ + "call_type": "make_openai_chat_completion_request <- track_llm_api_timing", + "service": "self", + }, + name="self", + span_id="7b2c9b9d544c2107", + parent_id="9e1058bdd104a886", + sequence_id=2, + ), + # router span (acompletion) + Span.from_attributes( + attributes={ + "call_type": "acompletion", + "service": "router", + }, + name="router", + span_id="44f9efc7cd957922", + parent_id="9e1058bdd104a886", + sequence_id=3, + ), + # litellm_request span - main LLM call span with gen_ai attributes + Span.from_attributes( + attributes={ + "metadata.user_api_key_hash": "", + "metadata.user_api_key_alias": "", + "metadata.user_api_key_spend": 0.0, + "metadata.user_api_key_max_budget": "", + "metadata.user_api_key_budget_reset_at": "", + "metadata.user_api_key_team_id": "", + "metadata.user_api_key_org_id": "", + "metadata.user_api_key_user_id": "", + "metadata.user_api_key_team_alias": "", + "metadata.user_api_key_user_email": "", + "metadata.user_api_key_end_user_id": "", + "metadata.user_api_key_request_route": "/v1/chat/completions", + "metadata.spend_logs_metadata": "", + "metadata.requester_ip_address": "", + "metadata.requester_metadata": "{}", + "metadata.prompt_management_metadata": "", + "metadata.applied_guardrails": "[]", + "metadata.mcp_tool_call_metadata": "", + "metadata.vector_store_request_metadata": "", + "metadata.usage_object": "{'completion_tokens': 48, 'prompt_tokens': 36, 'total_tokens': 84}", + "metadata.requester_custom_headers": "{'x-rollout-id': 'ro-0b4d59a7d478'}", + "metadata.cold_storage_object_key": "", + "metadata.user_api_key_auth_metadata": "{}", + "hidden_params": 
'{"model_id": "f6746e78...b010", "api_base": "http://localhost:45177/v1"}', + "gen_ai.cost.input_cost": 0, + "gen_ai.cost.output_cost": 0, + "gen_ai.cost.total_cost": 0.0, + "gen_ai.cost.tool_usage_cost": 0.0, + "gen_ai.cost.original_cost": 0.0, + "gen_ai.cost.discount_percent": 0.0, + "gen_ai.cost.discount_amount": 0.0, + "gen_ai.request.model": "Qwen/Qwen2.5-0.5B-Instruct", + "llm.request.type": "acompletion", + "gen_ai.system": "hosted_vllm", + "llm.is_streaming": "False", + "gen_ai.response.id": "chatcmpl-8b25...3e6e", + "gen_ai.response.model": "hosted_vllm/Qwen/Qwen2.5-0.5B-Instruct", + "llm.usage.total_tokens": 84, + "gen_ai.usage.completion_tokens": 48, + "gen_ai.usage.prompt_tokens": 36, + "gen_ai.prompt.0.role": "user", + "gen_ai.prompt.0.content": "Hello, what's your name?", + "gen_ai.completion.0.finish_reason": "stop", + "gen_ai.completion.0.role": "assistant", + "gen_ai.completion.0.content": "Hello! I am Qwen, an AI assistant created by Alibaba Cloud.", + }, + name="litellm_request", + span_id="c43d325d68344786", + parent_id="9e1058bdd104a886", + sequence_id=4, + ), + # raw_gen_ai_request span - sibling of litellm_request (same parent) + Span.from_attributes( + attributes={ + "llm.hosted_vllm.messages": '[{"role": "user", "content": "Hello, what\'s your name?"}]', + "llm.hosted_vllm.extra_body": "{'return_token_ids': True}", + "llm.hosted_vllm.id": "chatcmpl-8b25...3e6e", + "llm.hosted_vllm.choices": '[{"finish_reason": "stop", "index": 0, "message": {"content": "Hello! 
I am Qwen...", "role": "assistant"}}]', + "llm.hosted_vllm.created": 1767842789, + "llm.hosted_vllm.model": "Qwen/Qwen2.5-0.5B-Instruct", + "llm.hosted_vllm.object": "chat.completion", + "llm.hosted_vllm.service_tier": "", + "llm.hosted_vllm.system_fingerprint": "", + "llm.hosted_vllm.usage": "{'completion_tokens': 48, 'prompt_tokens': 36, 'total_tokens': 84}", + "llm.hosted_vllm.prompt_logprobs": "", + "llm.hosted_vllm.prompt_token_ids": "[151644, 8948, 198, ...]", + "llm.hosted_vllm.kv_transfer_params": "", + }, + name="raw_gen_ai_request", + span_id="bb5a15dd04c9e74b", + parent_id="9e1058bdd104a886", # Same parent as litellm_request - they are siblings + sequence_id=5, + ), + # Root span - Received Proxy Server Request + Span.from_attributes( + attributes={}, + name="Received Proxy Server Request", + span_id="9e1058bdd104a886", + parent_id=None, + sequence_id=6, + ), + ] + + adapter = Chain( + RepairMalformedSpans(), + ToTree(), + IdentifyChatCompletionCalls(), + ) + + adapted_spans = adapter(spans) + + for span in adapted_spans: + print(span) diff --git a/tests/adapter/test_postprocess.py b/tests/adapter/test_postprocess.py new file mode 100644 index 000000000..5543dadc9 --- /dev/null +++ b/tests/adapter/test_postprocess.py @@ -0,0 +1,1257 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for the postprocess module adapters.""" + +from typing import Any, Dict, List, Optional, Sequence + +import pytest +from openai.types.chat import ChatCompletion +from openai.types.chat.chat_completion import Choice, ChoiceLogprobs +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob + +from agentlightning.adapter.postprocess import ( + PropagateRewards, + ToPromptCompletionAccumulations, + ToPromptCompletionTriplets, + ToTokensAccumulations, + ToTokensTriplets, + is_prefix, +) +from agentlightning.types.adapter import ( + AdaptingSequence, + AdaptingSpan, + AnnotatedChatCompletionCall, + ChatCompletionCall, + GeneralAnnotation, + PromptCompletionAccumulation, + PromptCompletionTriplet, + TokenInput, + TokenOutput, + TokensAccumulation, + TokensTriplet, +) +from agentlightning.types.tracer import Span + +# ============================================================================== +# Helper functions to create mock objects +# ============================================================================== + + +def make_chat_completion( + *, + content: str = "Hello", + role: str = "assistant", + finish_reason: str = "stop", + model: str = "gpt-4", + completion_id: str = "chatcmpl-123", + token_ids: Optional[Sequence[int]] = None, + logprobs: Optional[List[float]] = None, + provider_specific_fields: Optional[Dict[str, Any]] = None, +) -> ChatCompletion: + """Create a mock ChatCompletion object.""" + choice_logprobs = None + if logprobs is not None: + choice_logprobs = ChoiceLogprobs( + content=[ + ChatCompletionTokenLogprob(token=f"tok{i}", bytes=None, logprob=lp, top_logprobs=[]) + for i, lp in enumerate(logprobs) + ], + refusal=None, + ) + + message = ChatCompletionMessage(content=content, role=role, refusal=None) # type: ignore + choice = Choice( + finish_reason=finish_reason, # type: ignore + index=0, + message=message, + 
logprobs=choice_logprobs, + ) + + # Add token_ids as an attribute if provided + if token_ids is not None: + choice.token_ids = token_ids # type: ignore + if provider_specific_fields is not None: + choice.provider_specific_fields = provider_specific_fields # type: ignore + + return ChatCompletion( + id=completion_id, + choices=[choice], + created=1234567890, + model=model, + object="chat.completion", + ) + + +def make_completion_request( + *, + messages: Optional[List[Dict[str, Any]]] = None, + model: str = "gpt-4", + prompt_token_ids: Optional[Sequence[int]] = None, + tools: Optional[List[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Create a mock CompletionCreateParams-like dict.""" + if messages is None: + messages = [{"role": "user", "content": "Hello"}] + + request: Dict[str, Any] = { + "messages": messages, + "model": model, + } + if prompt_token_ids is not None: + request["prompt_token_ids"] = prompt_token_ids + if tools is not None: + request["tools"] = tools + return request + + +def make_chat_completion_call( + *, + request: Optional[Dict[str, Any]] = None, + response: Optional[ChatCompletion] = None, + prompt_token_ids: Optional[Sequence[int]] = None, + completion_token_ids: Optional[Sequence[int]] = None, + logprobs: Optional[List[float]] = None, +) -> ChatCompletionCall: + """Create a mock ChatCompletionCall.""" + if request is None: + request = make_completion_request(prompt_token_ids=prompt_token_ids) + if response is None: + response = make_chat_completion(token_ids=completion_token_ids, logprobs=logprobs) + # Use model_construct to bypass Pydantic validation which strips unknown fields like prompt_token_ids + return ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + + +def make_annotated_call( + *, + request: Optional[Dict[str, Any]] = None, + response: Optional[ChatCompletion] = None, + reward: Optional[float] = None, + prompt_token_ids: Optional[Sequence[int]] = None, + completion_token_ids: 
Optional[Sequence[int]] = None, + logprobs: Optional[List[float]] = None, +) -> AnnotatedChatCompletionCall: + """Create a mock AnnotatedChatCompletionCall with optional reward.""" + if request is None: + request = make_completion_request(prompt_token_ids=prompt_token_ids) + if response is None: + response = make_chat_completion(token_ids=completion_token_ids, logprobs=logprobs) + + annotations: List[GeneralAnnotation] = [] + if reward is not None: + annotations.append(GeneralAnnotation(primary_reward=reward)) + + # Use model_construct to bypass Pydantic validation which strips unknown fields like prompt_token_ids + return AnnotatedChatCompletionCall.model_construct( + request=request, + response=response, + malformed_fields={}, + annotations=annotations, + ) + + +def make_adapting_span(data: Any, span_id: str = "span-1") -> AdaptingSpan: + """Create an AdaptingSpan with the given data.""" + span = Span.from_attributes( + rollout_id="rollout-1", + attempt_id="attempt-1", + sequence_id=0, + trace_id="trace-1", + span_id=span_id, + parent_id=None, + name="test-span", + attributes={}, + start_time=0.0, + end_time=1.0, + ) + return AdaptingSpan.from_span(span, data=data) + + +def make_adapting_sequence(calls: Sequence[Any]) -> AdaptingSequence[AdaptingSpan]: + """Create an AdaptingSequence from a list of calls.""" + spans = [make_adapting_span(call, span_id=f"span-{i}") for i, call in enumerate(calls)] + return AdaptingSequence(spans) + + +# ============================================================================== +# Tests for is_prefix function +# ============================================================================== + + +def test_is_prefix_empty_shorter(): + """Empty sequence is a prefix of any sequence.""" + assert is_prefix([], [1, 2, 3]) is True + assert is_prefix([], []) is True + + +def test_is_prefix_identical_sequences(): + """Identical sequences should be prefixes of each other.""" + assert is_prefix([1, 2, 3], [1, 2, 3]) is True + + +def 
def test_is_prefix_true_case():
    """A strictly shorter matching head is recognized as a prefix."""
    assert is_prefix([1, 2], [1, 2, 3, 4]) is True
    assert is_prefix(["a"], ["a", "b", "c"]) is True


def test_is_prefix_false_mismatch():
    """Any mismatched element disqualifies a prefix."""
    assert is_prefix([1, 2, 3], [1, 2, 4]) is False
    assert is_prefix([1, 3], [1, 2, 3]) is False


def test_is_prefix_longer_than_sequence():
    """A longer sequence can never be a prefix of a shorter one."""
    assert is_prefix([1, 2, 3, 4], [1, 2, 3]) is False


def test_is_prefix_with_iterables():
    """Arbitrary iterables (not just lists) are accepted."""
    assert is_prefix(iter([1, 2]), iter([1, 2, 3])) is True
    assert is_prefix(range(3), [0, 1, 2, 3, 4]) is True


# ==============================================================================
# Tests for ToTokensTriplets
# ==============================================================================


def test_to_tokens_triplets_basic():
    """One chat completion call yields one triplet carrying both token id lists."""
    llm_call = make_chat_completion_call(prompt_token_ids=[1, 2, 3], completion_token_ids=[4, 5, 6])
    result = ToTokensTriplets().adapt(make_adapting_sequence([llm_call]))

    assert len(result) == 1
    assert list(result[0].observation.token_ids) == [1, 2, 3]
    assert list(result[0].action.token_ids) == [4, 5, 6]
    # The final triplet of a rollout is always terminal.
    assert result[0].done is True


def test_to_tokens_triplets_with_reward():
    """A reward attached via annotations shows up on the triplet."""
    llm_call = make_annotated_call(prompt_token_ids=[1, 2, 3], completion_token_ids=[4, 5], reward=0.75)
    result = ToTokensTriplets().adapt(make_adapting_sequence([llm_call]))
    assert result[0].reward == 0.75


def test_to_tokens_triplets_with_logprobs():
    """Logprobs present on the response are carried onto the action."""
    llm_call = make_chat_completion_call(
        prompt_token_ids=[1, 2, 3],
        completion_token_ids=[4, 5],
        logprobs=[-0.1, -0.2],
    )
    result = ToTokensTriplets().adapt(make_adapting_sequence([llm_call]))

    assert result[0].action.logprobs is not None
    assert list(result[0].action.logprobs) == [-0.1, -0.2]


def test_to_tokens_triplets_with_image_urls():
    """Image URLs embedded in message content are collected on the observation."""
    req = make_completion_request(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is this?"},
                    {"type": "image_url", "image_url": {"url": "http://example.com/img1.png"}},
                    {"type": "image_url", "image_url": {"url": "http://example.com/img2.png"}},
                ],
            }
        ],
        prompt_token_ids=[1, 2, 3],
    )
    llm_call = make_chat_completion_call(request=req, response=make_chat_completion(token_ids=[4, 5]))
    result = ToTokensTriplets().adapt(make_adapting_sequence([llm_call]))

    assert list(result[0].observation.image_urls) == [
        "http://example.com/img1.png",
        "http://example.com/img2.png",
    ]


def test_to_tokens_triplets_missing_prompt_token_ids():
    """A call without prompt_token_ids cannot produce a triplet."""
    llm_call = make_chat_completion_call(
        request=make_completion_request(),  # no prompt_token_ids
        response=make_chat_completion(token_ids=[4, 5]),
    )
    sequence = make_adapting_sequence([llm_call])

    with pytest.raises(RuntimeError, match="failed to create any triplets"):
        ToTokensTriplets().adapt(sequence)


def test_to_tokens_triplets_missing_completion_token_ids():
    """A response without completion token ids cannot produce a triplet."""
    llm_call = make_chat_completion_call(
        request=make_completion_request(prompt_token_ids=[1, 2, 3]),
        response=make_chat_completion(),  # no token_ids
    )
    sequence = make_adapting_sequence([llm_call])

    with pytest.raises(RuntimeError, match="failed to create any triplets"):
        ToTokensTriplets().adapt(sequence)
def test_to_tokens_triplets_empty_prompt_token_ids():
    """An empty prompt_token_ids list is treated like a missing one."""
    llm_call = make_chat_completion_call(
        request=make_completion_request(prompt_token_ids=[]),
        response=make_chat_completion(token_ids=[4, 5]),
    )
    sequence = make_adapting_sequence([llm_call])

    with pytest.raises(RuntimeError, match="failed to create any triplets"):
        ToTokensTriplets().adapt(sequence)


def test_to_tokens_triplets_empty_completion_token_ids():
    """An empty completion token id list is treated like a missing one."""
    llm_call = make_chat_completion_call(
        request=make_completion_request(prompt_token_ids=[1, 2, 3]),
        response=make_chat_completion(token_ids=[]),
    )
    sequence = make_adapting_sequence([llm_call])

    with pytest.raises(RuntimeError, match="failed to create any triplets"):
        ToTokensTriplets().adapt(sequence)


def test_to_tokens_triplets_strict_mode():
    """With strict=True the adapter raises immediately instead of skipping."""
    llm_call = make_chat_completion_call(
        request=make_completion_request(),  # missing prompt_token_ids
        response=make_chat_completion(token_ids=[4, 5]),
    )
    sequence = make_adapting_sequence([llm_call])

    with pytest.raises(ValueError, match="Prompt token ids not found"):
        ToTokensTriplets(strict=True).adapt(sequence)


def test_to_tokens_triplets_multiple_calls():
    """Each call produces its own triplet; only the last one is terminal."""
    first = make_chat_completion_call(prompt_token_ids=[1, 2], completion_token_ids=[3, 4])
    second = make_chat_completion_call(prompt_token_ids=[5, 6], completion_token_ids=[7, 8])
    result = ToTokensTriplets().adapt(make_adapting_sequence([first, second]))

    assert len(result) == 2
    assert result[0].done is False
    assert result[1].done is True


def test_to_tokens_triplets_provider_specific_token_ids():
    """Completion token ids may come from provider_specific_fields."""
    llm_call = make_chat_completion_call(
        request=make_completion_request(prompt_token_ids=[1, 2, 3]),
        response=make_chat_completion(provider_specific_fields={"token_ids": [4, 5, 6]}),
    )
    result = ToTokensTriplets().adapt(make_adapting_sequence([llm_call]))

    assert list(result[0].action.token_ids) == [4, 5, 6]


def test_to_tokens_triplets_skip_non_call_spans():
    """Spans whose payload is not a chat completion call are ignored."""
    llm_call = make_chat_completion_call(prompt_token_ids=[1, 2], completion_token_ids=[3, 4])
    sequence = AdaptingSequence(
        [
            make_adapting_span("not a call", span_id="span-0"),
            make_adapting_span(llm_call, span_id="span-1"),
            make_adapting_span(None, span_id="span-2"),
        ]
    )
    result = ToTokensTriplets().adapt(sequence)

    assert len(result) == 1


# ==============================================================================
# Tests for ToTokensAccumulations
# ==============================================================================


def test_to_tokens_accumulations_single_triplet():
    """A lone triplet becomes one accumulation with a prompt/response mask."""
    triplet = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2, 3], image_urls=[]),
        action=TokenOutput(token_ids=[4, 5], logprobs=[-0.1, -0.2]),
        reward=0.5,
        done=True,
        raw_call=make_chat_completion_call(),
    )
    merged = ToTokensAccumulations().adapt([triplet])

    assert len(merged) == 1
    assert list(merged[0].token_ids) == [1, 2, 3, 4, 5]
    # 0 marks prompt tokens, 1 marks response tokens.
    assert merged[0].response_mask == [0, 0, 0, 1, 1]
    assert merged[0].final_reward == 0.5
def test_to_tokens_accumulations_merge_sequential():
    """Triplets whose observation extends the prior context merge into one accumulation."""
    first = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2], image_urls=[]),
        action=TokenOutput(token_ids=[3, 4], logprobs=None),
        reward=0.3,
        done=False,
        raw_call=make_chat_completion_call(),
    )
    second = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2, 3, 4, 5], image_urls=[]),
        action=TokenOutput(token_ids=[6, 7], logprobs=None),
        reward=0.4,
        done=True,
        raw_call=make_chat_completion_call(),
    )
    merged = ToTokensAccumulations().adapt([first, second])

    assert len(merged) == 1
    # Final tokens: [1, 2, 3, 4] + [5] + [6, 7]
    assert list(merged[0].token_ids) == [1, 2, 3, 4, 5, 6, 7]
    # Response mask: [0, 0, 1, 1] + [0] + [1, 1]
    assert merged[0].response_mask == [0, 0, 1, 1, 0, 1, 1]
    # Rewards of merged triplets are summed.
    assert merged[0].final_reward == 0.7


def test_to_tokens_accumulations_no_merge_mismatch():
    """A triplet whose observation does not extend the prior tokens stays separate."""
    first = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2], image_urls=[]),
        action=TokenOutput(token_ids=[3, 4], logprobs=None),
        reward=0.3,
        done=False,
        raw_call=make_chat_completion_call(),
    )
    second = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[5, 6, 7], image_urls=[]),  # does not start with [1, 2, 3, 4]
        action=TokenOutput(token_ids=[8, 9], logprobs=None),
        reward=0.4,
        done=True,
        raw_call=make_chat_completion_call(),
    )
    merged = ToTokensAccumulations().adapt([first, second])

    assert len(merged) == 2


def test_to_tokens_accumulations_image_url_mismatch():
    """Differing image URLs prevent two otherwise-matching triplets from merging."""
    first = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2], image_urls=["img1.png"]),
        action=TokenOutput(token_ids=[3, 4], logprobs=None),
        reward=None,
        done=False,
        raw_call=make_chat_completion_call(),
    )
    second = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2, 3, 4, 5], image_urls=["img2.png"]),  # different image
        action=TokenOutput(token_ids=[6, 7], logprobs=None),
        reward=None,
        done=True,
        raw_call=make_chat_completion_call(),
    )
    merged = ToTokensAccumulations().adapt([first, second])

    assert len(merged) == 2


def test_to_tokens_accumulations_image_url_extension():
    """A superset of the prior image URLs is a valid extension and merges."""
    first = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2], image_urls=["img1.png"]),
        action=TokenOutput(token_ids=[3, 4], logprobs=None),
        reward=None,
        done=False,
        raw_call=make_chat_completion_call(),
    )
    second = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2, 3, 4, 5], image_urls=["img1.png", "img2.png"]),
        action=TokenOutput(token_ids=[6, 7], logprobs=None),
        reward=None,
        done=True,
        raw_call=make_chat_completion_call(),
    )
    merged = ToTokensAccumulations().adapt([first, second])

    assert len(merged) == 1
    assert list(merged[0].image_urls) == ["img1.png", "img2.png"]


def test_to_tokens_accumulations_with_logprobs_merge():
    """Logprobs are accumulated and padded with zeros for prompt tokens when merging."""
    first = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2], image_urls=[]),
        action=TokenOutput(token_ids=[3, 4], logprobs=[-0.1, -0.2]),
        reward=None,
        done=False,
        raw_call=make_chat_completion_call(),
    )
    second = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2, 3, 4, 5], image_urls=[]),
        action=TokenOutput(token_ids=[6, 7], logprobs=[-0.3, -0.4]),
        reward=None,
        done=True,
        raw_call=make_chat_completion_call(),
    )
    merged = ToTokensAccumulations().adapt([first, second])

    assert len(merged) == 1
    # Final tokens: [1, 2, 3, 4] + [5] + [6, 7] = 7 tokens
    assert list(merged[0].token_ids) == [1, 2, 3, 4, 5, 6, 7]
    assert merged[0].logprobs is not None
    # [0, 0, -0.1, -0.2] from the first triplet, [0] for the observation
    # extension, then [-0.3, -0.4] for the second action.
    expected = [0.0, 0.0, -0.1, -0.2, 0.0, -0.3, -0.4]
    assert list(merged[0].logprobs) == expected
    assert len(merged[0].logprobs) == len(merged[0].token_ids)
def test_to_tokens_accumulations_empty_input():
    """No triplets in, no accumulations out."""
    assert len(ToTokensAccumulations().adapt([])) == 0


def test_to_tokens_accumulations_none_reward_handling():
    """A None reward contributes nothing to the merged final reward."""
    first = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2], image_urls=[]),
        action=TokenOutput(token_ids=[3, 4], logprobs=None),
        reward=None,
        done=False,
        raw_call=make_chat_completion_call(),
    )
    second = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2, 3, 4], image_urls=[]),
        action=TokenOutput(token_ids=[5, 6], logprobs=None),
        reward=0.5,
        done=True,
        raw_call=make_chat_completion_call(),
    )
    merged = ToTokensAccumulations().adapt([first, second])

    assert len(merged) == 1
    assert merged[0].final_reward == 0.5


def test_to_tokens_accumulations_raw_calls_aggregated():
    """All raw calls from merged triplets are kept, by reference, in order."""
    first_raw = make_chat_completion_call(completion_token_ids=[1])
    second_raw = make_chat_completion_call(completion_token_ids=[2])
    first = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2], image_urls=[]),
        action=TokenOutput(token_ids=[3, 4], logprobs=None),
        reward=None,
        done=False,
        raw_call=first_raw,
    )
    second = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2, 3, 4], image_urls=[]),
        action=TokenOutput(token_ids=[5, 6], logprobs=None),
        reward=None,
        done=True,
        raw_call=second_raw,
    )
    merged = ToTokensAccumulations().adapt([first, second])

    assert len(merged[0].raw_calls) == 2
    # Verify the originals are referenced, not copied.
    assert merged[0].raw_calls[0].response.choices[0].token_ids == [1]  # type: ignore
    assert merged[0].raw_calls[1].response.choices[0].token_ids == [2]  # type: ignore


# ==============================================================================
# Tests for ToPromptCompletionTriplets
# ==============================================================================


def test_to_prompt_completion_triplets_basic():
    """A call maps to a triplet whose observation/action are the raw request/response."""
    llm_call = make_chat_completion_call()
    result = ToPromptCompletionTriplets().adapt(make_adapting_sequence([llm_call]))

    assert len(result) == 1
    assert result[0].observation == llm_call.request
    assert result[0].action == llm_call.response
    assert result[0].done is True


def test_to_prompt_completion_triplets_with_reward():
    """An annotated reward is carried onto the triplet."""
    llm_call = make_annotated_call(reward=0.9)
    result = ToPromptCompletionTriplets().adapt(make_adapting_sequence([llm_call]))

    assert result[0].reward == 0.9


def test_to_prompt_completion_triplets_multiple():
    """Several calls yield several triplets; only the last is terminal."""
    calls = [make_chat_completion_call(), make_chat_completion_call(), make_chat_completion_call()]
    result = ToPromptCompletionTriplets().adapt(make_adapting_sequence(calls))

    assert len(result) == 3
    assert result[0].done is False
    assert result[1].done is False
    assert result[2].done is True
def test_to_prompt_completion_triplets_strict_mode():
    """A sequence containing no chat completion calls yields no triplets and raises."""
    sequence = AdaptingSequence([make_adapting_span("not a call")])

    with pytest.raises(RuntimeError, match="failed to create any triplets"):
        ToPromptCompletionTriplets().adapt(sequence)


# ==============================================================================
# Tests for ToPromptCompletionAccumulations
# ==============================================================================


def test_to_prompt_completion_accumulations_single():
    """A single triplet becomes a conversation of prompt messages plus the assistant reply."""
    req = make_completion_request(messages=[{"role": "user", "content": "Hi"}])
    resp = make_chat_completion(content="Hello!")
    llm_call = make_chat_completion_call(request=req, response=resp)

    triplet = PromptCompletionTriplet.model_construct(
        observation=req,
        action=resp,
        reward=0.5,
        done=True,
        raw_call=llm_call,
    )
    merged = ToPromptCompletionAccumulations().adapt([triplet])

    assert len(merged) == 1
    # user message + assistant response
    assert len(merged[0].messages) == 2
    assert merged[0].messages[0]["role"] == "user"
    assert merged[0].messages[1]["role"] == "assistant"


def test_to_prompt_completion_accumulations_merge():
    """A follow-up request that replays the prior conversation merges into one accumulation."""
    first_req = make_completion_request(messages=[{"role": "user", "content": "Hi"}])
    first_resp = make_chat_completion(content="Hello!")
    first_call = make_chat_completion_call(request=first_req, response=first_resp)

    # The second request replays the full prior conversation.
    second_req = make_completion_request(
        messages=[
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "How are you?"},
        ]
    )
    second_resp = make_chat_completion(content="I'm good!")
    second_call = make_chat_completion_call(request=second_req, response=second_resp)

    first = PromptCompletionTriplet.model_construct(
        observation=first_req,
        action=first_resp,
        reward=0.3,
        done=False,
        raw_call=first_call,
    )
    second = PromptCompletionTriplet.model_construct(
        observation=second_req,
        action=second_resp,
        reward=0.4,
        done=True,
        raw_call=second_call,
    )
    merged = ToPromptCompletionAccumulations().adapt([first, second])

    assert len(merged) == 1
    # All four messages of the conversation are present.
    assert len(merged[0].messages) == 4
    assert merged[0].final_reward == 0.7


def test_to_prompt_completion_accumulations_no_merge_different_tools():
    """Requests with differing tool sets never merge."""
    first_req = make_completion_request(tools=[{"type": "function", "function": {"name": "tool1"}}])
    first_call = make_chat_completion_call(request=first_req, response=make_chat_completion())

    second_req = make_completion_request(tools=[{"type": "function", "function": {"name": "tool2"}}])
    second_call = make_chat_completion_call(request=second_req, response=make_chat_completion())

    first = PromptCompletionTriplet.model_construct(
        observation=first_req,
        action=first_call.response,
        reward=None,
        done=False,
        raw_call=first_call,
    )
    second = PromptCompletionTriplet.model_construct(
        observation=second_req,
        action=second_call.response,
        reward=None,
        done=True,
        raw_call=second_call,
    )
    merged = ToPromptCompletionAccumulations().adapt([first, second])

    assert len(merged) == 2
def test_to_prompt_completion_accumulations_no_merge_message_mismatch():
    """Requests with unrelated message histories stay as separate accumulations."""
    first_req = make_completion_request(messages=[{"role": "user", "content": "Hi"}])
    first_resp = make_chat_completion(content="Hello!")
    first_call = make_chat_completion_call(request=first_req, response=first_resp)

    # The second request shares no history with the first.
    second_req = make_completion_request(messages=[{"role": "user", "content": "Different question"}])
    second_resp = make_chat_completion(content="Different answer")
    second_call = make_chat_completion_call(request=second_req, response=second_resp)

    first = PromptCompletionTriplet.model_construct(
        observation=first_req,
        action=first_resp,
        reward=None,
        done=False,
        raw_call=first_call,
    )
    second = PromptCompletionTriplet.model_construct(
        observation=second_req,
        action=second_resp,
        reward=None,
        done=True,
        raw_call=second_call,
    )
    merged = ToPromptCompletionAccumulations().adapt([first, second])

    assert len(merged) == 2


def test_to_prompt_completion_accumulations_preserves_tools():
    """Tool definitions on the request survive the accumulation."""
    tools: List[Dict[str, Any]] = [{"type": "function", "function": {"name": "my_tool", "parameters": {}}}]
    req = make_completion_request(tools=tools)
    resp = make_chat_completion()
    llm_call = make_chat_completion_call(request=req, response=resp)

    triplet = PromptCompletionTriplet.model_construct(
        observation=req,
        action=resp,
        reward=None,
        done=True,
        raw_call=llm_call,
    )
    merged = ToPromptCompletionAccumulations().adapt([triplet])

    assert merged[0].tools == tools


def test_to_prompt_completion_accumulations_empty():
    """No triplets in, no accumulations out."""
    assert len(ToPromptCompletionAccumulations().adapt([])) == 0


# ==============================================================================
# Tests for PropagateRewards
# ==============================================================================


def test_propagate_rewards_forward_triplets():
    """Forward propagation carries the last seen reward into later None slots."""
    items = [
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1], image_urls=[]),
            action=TokenOutput(token_ids=[2], logprobs=None),
            reward=0.5,
            done=False,
            raw_call=make_chat_completion_call(),
        ),
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1, 2], image_urls=[]),
            action=TokenOutput(token_ids=[3], logprobs=None),
            reward=None,
            done=False,
            raw_call=make_chat_completion_call(),
        ),
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1, 2, 3], image_urls=[]),
            action=TokenOutput(token_ids=[4], logprobs=None),
            reward=None,
            done=True,
            raw_call=make_chat_completion_call(),
        ),
    ]
    result = PropagateRewards(direction="forward").adapt(items)

    assert result[0].reward == 0.5
    assert result[1].reward == 0.5
    assert result[2].reward == 0.5


def test_propagate_rewards_backward_triplets():
    """Backward propagation carries a later reward into earlier None slots."""
    items = [
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1], image_urls=[]),
            action=TokenOutput(token_ids=[2], logprobs=None),
            reward=None,
            done=False,
            raw_call=make_chat_completion_call(),
        ),
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1, 2], image_urls=[]),
            action=TokenOutput(token_ids=[3], logprobs=None),
            reward=None,
            done=False,
            raw_call=make_chat_completion_call(),
        ),
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1, 2, 3], image_urls=[]),
            action=TokenOutput(token_ids=[4], logprobs=None),
            reward=0.8,
            done=True,
            raw_call=make_chat_completion_call(),
        ),
    ]
    result = PropagateRewards(direction="backward").adapt(items)

    assert result[0].reward == 0.8
    assert result[1].reward == 0.8
    assert result[2].reward == 0.8
def test_propagate_rewards_forward_accumulations():
    """Forward propagation also applies to accumulations via final_reward."""
    items = [
        TokensAccumulation(
            token_ids=[1, 2],
            image_urls=[],
            logprobs=None,
            response_mask=[0, 1],
            final_reward=0.3,
            raw_calls=[],
            diagnosis_info=None,
        ),
        TokensAccumulation(
            token_ids=[3, 4],
            image_urls=[],
            logprobs=None,
            response_mask=[0, 1],
            final_reward=None,
            raw_calls=[],
            diagnosis_info=None,
        ),
    ]
    result = PropagateRewards(direction="forward").adapt(items)

    assert result[0].final_reward == 0.3
    assert result[1].final_reward == 0.3


def test_propagate_rewards_backward_accumulations():
    """Backward propagation also applies to accumulations via final_reward."""
    items = [
        TokensAccumulation(
            token_ids=[1, 2],
            image_urls=[],
            logprobs=None,
            response_mask=[0, 1],
            final_reward=None,
            raw_calls=[],
            diagnosis_info=None,
        ),
        TokensAccumulation(
            token_ids=[3, 4],
            image_urls=[],
            logprobs=None,
            response_mask=[0, 1],
            final_reward=0.7,
            raw_calls=[],
            diagnosis_info=None,
        ),
    ]
    result = PropagateRewards(direction="backward").adapt(items)

    assert result[0].final_reward == 0.7
    assert result[1].final_reward == 0.7


def test_propagate_rewards_preserves_existing():
    """Existing rewards are never overwritten; only None slots are filled."""
    items = [
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1], image_urls=[]),
            action=TokenOutput(token_ids=[2], logprobs=None),
            reward=0.1,
            done=False,
            raw_call=make_chat_completion_call(),
        ),
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1, 2], image_urls=[]),
            action=TokenOutput(token_ids=[3], logprobs=None),
            reward=0.5,
            done=False,
            raw_call=make_chat_completion_call(),
        ),
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1, 2, 3], image_urls=[]),
            action=TokenOutput(token_ids=[4], logprobs=None),
            reward=None,
            done=True,
            raw_call=make_chat_completion_call(),
        ),
    ]
    result = PropagateRewards(direction="forward").adapt(items)

    assert result[0].reward == 0.1
    assert result[1].reward == 0.5
    assert result[2].reward == 0.5  # propagated from index 1


def test_propagate_rewards_empty_input():
    """No items in, no items out."""
    items: List[TokensTriplet] = []
    assert len(PropagateRewards(direction="forward").adapt(items)) == 0


def test_propagate_rewards_all_none():
    """When nothing has a reward there is nothing to propagate."""
    items = [
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1], image_urls=[]),
            action=TokenOutput(token_ids=[2], logprobs=None),
            reward=None,
            done=False,
            raw_call=make_chat_completion_call(),
        ),
        TokensTriplet.model_construct(
            observation=TokenInput(token_ids=[1, 2], image_urls=[]),
            action=TokenOutput(token_ids=[3], logprobs=None),
            reward=None,
            done=True,
            raw_call=make_chat_completion_call(),
        ),
    ]
    result = PropagateRewards(direction="forward").adapt(items)

    assert result[0].reward is None
    assert result[1].reward is None


def test_propagate_rewards_prompt_completion_triplets():
    """Propagation also works over PromptCompletionTriplet items."""
    req = make_completion_request()
    resp = make_chat_completion()
    llm_call = make_chat_completion_call(request=req, response=resp)

    items = [
        PromptCompletionTriplet.model_construct(
            observation=req,
            action=resp,
            reward=0.6,
            done=False,
            raw_call=llm_call,
        ),
        PromptCompletionTriplet.model_construct(
            observation=req,
            action=resp,
            reward=None,
            done=True,
            raw_call=llm_call,
        ),
    ]
    result = PropagateRewards(direction="forward").adapt(items)

    assert result[0].reward == 0.6
    assert result[1].reward == 0.6
def test_propagate_rewards_prompt_completion_accumulations():
    """Propagation also works over PromptCompletionAccumulation items."""
    items = [
        PromptCompletionAccumulation(
            messages=[{"role": "user", "content": "Hi"}],
            tools=None,
            final_reward=None,
            raw_calls=[],
        ),
        PromptCompletionAccumulation(
            messages=[{"role": "user", "content": "Bye"}],
            tools=None,
            final_reward=0.9,
            raw_calls=[],
        ),
    ]
    result = PropagateRewards(direction="backward").adapt(items)

    assert result[0].final_reward == 0.9
    assert result[1].final_reward == 0.9


# ==============================================================================
# Tests for ToTokensAccumulations diagnosis feature
# ==============================================================================


def test_to_tokens_accumulations_diagnosis_disabled_by_default():
    """With diagnosis=False (the default) no diagnosis info is attached."""
    first = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[1, 2], image_urls=[]),
        action=TokenOutput(token_ids=[3, 4], logprobs=None),
        reward=None,
        done=False,
        raw_call=make_chat_completion_call(),
    )
    second = TokensTriplet.model_construct(
        observation=TokenInput(token_ids=[5, 6, 7], image_urls=[]),  # mismatch, will not merge
        action=TokenOutput(token_ids=[8, 9], logprobs=None),
        reward=None,
        done=True,
        raw_call=make_chat_completion_call(),
    )
    merged = ToTokensAccumulations(diagnosis=False).adapt([first, second])

    assert len(merged) == 2
    assert merged[0].diagnosis_info is None
    assert merged[1].diagnosis_info is None


# ==============================================================================
# Edge case tests
# ==============================================================================


def test_to_tokens_triplets_invalid_prompt_token_ids_type():
    """A non-list prompt_token_ids value is rejected."""
    req = make_completion_request()
    req["prompt_token_ids"] = "not a list"
    llm_call = make_chat_completion_call(request=req, response=make_chat_completion(token_ids=[4, 5]))
    sequence = make_adapting_sequence([llm_call])

    with pytest.raises(RuntimeError, match="failed to create any triplets"):
        ToTokensTriplets().adapt(sequence)


def test_to_tokens_triplets_invalid_completion_token_ids_type():
    """Non-integer completion token ids are rejected."""
    llm_call = make_chat_completion_call(
        request=make_completion_request(prompt_token_ids=[1, 2, 3]),
        response=make_chat_completion(token_ids=["not", "ints"]),  # type: ignore
    )
    sequence = make_adapting_sequence([llm_call])

    with pytest.raises(RuntimeError, match="failed to create any triplets"):
        ToTokensTriplets().adapt(sequence)


def test_to_tokens_triplets_message_without_content():
    """A message lacking a content field is handled gracefully."""
    req = make_completion_request(
        messages=[{"role": "assistant"}],  # no content field
        prompt_token_ids=[1, 2, 3],
    )
    llm_call = make_chat_completion_call(request=req, response=make_chat_completion(token_ids=[4, 5]))
    result = ToTokensTriplets().adapt(make_adapting_sequence([llm_call]))

    # Succeeds, with no image URLs collected.
    assert len(result) == 1
    assert list(result[0].observation.image_urls) == []
def test_to_tokens_triplets_message_with_none_content():
    """A message with content=None is handled gracefully."""
    req = make_completion_request(
        messages=[{"role": "assistant", "content": None}],
        prompt_token_ids=[1, 2, 3],
    )
    llm_call = make_chat_completion_call(request=req, response=make_chat_completion(token_ids=[4, 5]))
    result = ToTokensTriplets().adapt(make_adapting_sequence([llm_call]))

    assert len(result) == 1
    assert list(result[0].observation.image_urls) == []


def test_to_tokens_triplets_mixed_content_types():
    """Only image_url parts of mixed content contribute image URLs."""
    req = make_completion_request(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe these"},
                    {"type": "image_url", "image_url": {"url": "http://img1.png"}},
                    {"type": "text", "text": "Thanks"},
                ],
            }
        ],
        prompt_token_ids=[1, 2, 3],
    )
    llm_call = make_chat_completion_call(request=req, response=make_chat_completion(token_ids=[4, 5]))
    result = ToTokensTriplets().adapt(make_adapting_sequence([llm_call]))

    assert list(result[0].observation.image_urls) == ["http://img1.png"]


def test_get_reward_from_non_annotated_call():
    """A plain (non-annotated) call has no reward."""
    llm_call = make_chat_completion_call()
    assert ToTokensTriplets().get_reward(llm_call) is None


def test_get_reward_from_annotated_call_without_general_annotation():
    """An annotated call with no GeneralAnnotation also has no reward."""
    base = make_annotated_call()  # no reward specified
    # Rebuild the call with an explicitly empty annotation list.
    llm_call = AnnotatedChatCompletionCall(
        request=base.request,
        response=base.response,
        malformed_fields={},
        annotations=[],  # empty annotations
    )
    assert ToTokensTriplets().get_reward(llm_call) is None


def test_to_triplets_sets_done_on_last():
    """Only the final triplet of a rollout is marked done."""
    calls = [
        make_chat_completion_call(prompt_token_ids=[1], completion_token_ids=[2]),
        make_chat_completion_call(prompt_token_ids=[3], completion_token_ids=[4]),
        make_chat_completion_call(prompt_token_ids=[5], completion_token_ids=[6]),
    ]
    result = ToTokensTriplets().adapt(make_adapting_sequence(calls))

    assert result[0].done is False
    assert result[1].done is False
    assert result[2].done is True
next(_SEQ), + trace_id="trace-1", + span_id=span_id, + parent_id=parent_id, + name=name, + attributes=attributes or {}, + start_time=start_time, + end_time=end_time, + ) + + +# Tests for default_span_order + + +def test_default_span_order_by_sequence_id(): + span1 = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0, sequence_id=2) + span2 = make_span("s2", "span", parent_id=None, start_time=0.0, end_time=1.0, sequence_id=1) + spans = [span1, span2] + sorted_spans = sorted(spans, key=default_span_order) + assert [s.span_id for s in sorted_spans] == ["s2", "s1"] + + +def test_default_span_order_by_start_time(): + span1 = make_span("s1", "span", parent_id=None, start_time=2.0, end_time=3.0, sequence_id=0) + span2 = make_span("s2", "span", parent_id=None, start_time=1.0, end_time=3.0, sequence_id=0) + spans = [span1, span2] + sorted_spans = sorted(spans, key=default_span_order) + assert [s.span_id for s in sorted_spans] == ["s2", "s1"] + + +def test_default_span_order_by_end_time(): + span1 = make_span("s1", "span", parent_id=None, start_time=1.0, end_time=3.0, sequence_id=0) + span2 = make_span("s2", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=0) + spans = [span1, span2] + sorted_spans = sorted(spans, key=default_span_order) + assert [s.span_id for s in sorted_spans] == ["s2", "s1"] + + +# Tests for _TreeLikeGraph + + +def test_tree_like_graph_from_spans_creates_correct_graph(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child1 = make_span("child1", "child", parent_id="root", start_time=1.0, end_time=5.0) + child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) + grandchild = make_span("grandchild", "grandchild", parent_id="child1", start_time=2.0, end_time=4.0) + + graph = _TreeLikeGraph.from_spans([root, child1, child2, grandchild]) + + assert graph.root_ids == {"root"} + assert set(graph.forward_graph["root"]) == {"child1", "child2"} + assert 
graph.forward_graph["child1"] == ["grandchild"] + assert graph.parent_map["child1"] == "root" + assert graph.parent_map["child2"] == "root" + assert graph.parent_map["grandchild"] == "child1" + + +def test_tree_like_graph_from_spans_handles_invalid_parent(): + """Spans with invalid parent IDs should be treated as roots.""" + orphan = make_span("orphan", "orphan", parent_id="nonexistent", start_time=0.0, end_time=1.0) + + graph = _TreeLikeGraph.from_spans([orphan]) + + assert graph.root_ids == {"orphan"} + assert "orphan" not in graph.parent_map + + +def test_tree_like_graph_from_spans_multiple_roots(): + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=5.0) + root2 = make_span("root2", "root", parent_id=None, start_time=5.0, end_time=10.0) + + graph = _TreeLikeGraph.from_spans([root1, root2]) + + assert graph.root_ids == {"root1", "root2"} + + +def test_tree_like_graph_compute_depths(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + grandchild = make_span("grandchild", "grandchild", parent_id="child", start_time=2.0, end_time=8.0) + + graph = _TreeLikeGraph.from_spans([root, child, grandchild]) + depths = graph.compute_depths() + + assert depths["root"] == 0 + assert depths["child"] == 1 + assert depths["grandchild"] == 2 + + +def test_tree_like_graph_compute_ancestors(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + grandchild = make_span("grandchild", "grandchild", parent_id="child", start_time=2.0, end_time=8.0) + + graph = _TreeLikeGraph.from_spans([root, child, grandchild]) + ancestors = graph.compute_ancestors() + + assert ancestors["root"] == set() + assert ancestors["child"] == {"root"} + assert ancestors["grandchild"] == {"root", "child"} + + +def test_tree_like_graph_move_subtree(): + 
root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child1 = make_span("child1", "child", parent_id="root", start_time=1.0, end_time=5.0) + child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) + grandchild = make_span("grandchild", "grandchild", parent_id="child1", start_time=2.0, end_time=4.0) + + graph = _TreeLikeGraph.from_spans([root, child1, child2, grandchild]) + graph.move_subtree("grandchild", "child2") + + assert "grandchild" not in graph.forward_graph["child1"] + assert "grandchild" in graph.forward_graph["child2"] + assert graph.parent_map["grandchild"] == "child2" + + +def test_tree_like_graph_move_subtree_from_root(): + """Moving a root node should remove it from root_ids.""" + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=10.0) + root2 = make_span("root2", "root", parent_id=None, start_time=0.0, end_time=10.0) + + graph = _TreeLikeGraph.from_spans([root1, root2]) + assert "root2" in graph.root_ids + + graph.move_subtree("root2", "root1") + + assert "root2" not in graph.root_ids + assert graph.parent_map["root2"] == "root1" + + +def test_tree_like_graph_to_tree_single_root(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child1 = make_span("child1", "child", parent_id="root", start_time=1.0, end_time=5.0) + child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) + + graph = _TreeLikeGraph.from_spans([root, child1, child2]) + tree = graph.to_tree() + + assert tree.item.span_id == "root" + assert len(tree.children) == 2 + child_ids = {child.item.span_id for child in tree.children} + assert child_ids == {"child1", "child2"} + + +def test_tree_like_graph_to_tree_multiple_roots_raises(): + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=5.0) + root2 = make_span("root2", "root", parent_id=None, start_time=5.0, end_time=10.0) + + graph = _TreeLikeGraph.from_spans([root1, 
root2]) + + with pytest.raises(ValueError, match="multiple or no roots"): + graph.to_tree() + + +# Tests for ToSpans + + +def test_to_spans_pass_through_span(): + """Span objects should pass through unchanged.""" + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + adapter = ToSpans() + + result = adapter.adapt_one(span) + + assert result is span + + +def test_to_spans_default_values(): + """Adapter should use default values for rollout_id, attempt_id, and sequence_id.""" + adapter = ToSpans( + default_rollout_id="my-rollout", + default_attempt_id="my-attempt", + default_sequence_id=42, + ) + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + + result = adapter.adapt_one(span) + + # Span passes through unchanged, defaults only apply to OpenTelemetry spans + assert result is span + + +# Tests for ToTree + + +def test_to_tree_basic_creation(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + spans = [root, child] + + adapter = ToTree() + tree = adapter.adapt(spans) + + assert tree.item.span_id == "root" + assert len(tree.children) == 1 + assert tree.children[0].item.span_id == "child" + + +def test_to_tree_empty_spans_raises(): + adapter = ToTree() + + with pytest.raises(ValueError, match="No spans provided"): + adapter.adapt([]) + + +def test_to_tree_non_sequence_raises(): + adapter = ToTree() + + # String is technically a sequence but will fail when trying to access span attributes + with pytest.raises(AttributeError): + adapter.adapt("not a sequence") # type: ignore + + +def test_to_tree_repair_multiple_roots(): + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=5.0) + root2 = make_span("root2", "root", parent_id=None, start_time=5.0, end_time=10.0) + spans = [root1, root2] + + adapter = ToTree(repair_multiple_roots=True) + tree = adapter.adapt(spans) + + assert 
tree.item.name == AGL_VIRTUAL + assert len(tree.children) == 2 + child_ids = {child.item.span_id for child in tree.children} + assert child_ids == {"root1", "root2"} + + +def test_to_tree_repair_multiple_roots_disabled(): + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=5.0) + root2 = make_span("root2", "root", parent_id=None, start_time=5.0, end_time=10.0) + spans = [root1, root2] + + adapter = ToTree(repair_multiple_roots=False) + + with pytest.raises(ValueError, match="multiple or no roots"): + adapter.adapt(spans) + + +def test_to_tree_invalid_parent_raises_error(): + """Spans with invalid parent IDs should raise ValueError.""" + orphan = make_span("orphan", "orphan", parent_id="missing-parent", start_time=0.0, end_time=1.0) + spans = [orphan] + + adapter = ToTree() + + with pytest.raises(ValueError, match="non-existent parent IDs"): + adapter.adapt(spans) + + +def test_to_tree_with_repaired_invalid_parents(): + """Using RepairMalformedSpans before ToTree should handle invalid parent IDs.""" + orphan = make_span("orphan", "orphan", parent_id="missing-parent", start_time=0.0, end_time=1.0) + spans = [orphan] + + # First repair invalid parent IDs + repair_adapter = RepairMalformedSpans(ensure_valid_parent_ids=True) + repaired = repair_adapter.adapt(spans) + + # Now ToTree should work + tree_adapter = ToTree() + tree = tree_adapter.adapt(repaired) + + # Orphan becomes root since parent_id was set to None + assert tree.item.span_id == "orphan" + + +def test_to_tree_repair_bad_hierarchy_dangling(): + """Dangling spans should be re-attached based on time containment.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + container = make_span("container", "container", parent_id="root", start_time=1.0, end_time=9.0) + # Dangling span (no parent) that should fit inside container + dangling = make_span("dangling", "dangling", parent_id=None, start_time=2.0, end_time=8.0) + spans = [root, container, dangling] + 
+ adapter = ToTree(repair_bad_hierarchy="dangling") + tree = adapter.adapt(spans) + + # Dangling should be moved under container (best fit by time) + container_node = next(c for c in tree.children if c.item.span_id == "container") + dangling_in_container = any(c.item.span_id == "dangling" for c in container_node.children) + assert dangling_in_container + + +def test_to_tree_repair_bad_hierarchy_none(): + """When repair_bad_hierarchy is 'none', hierarchy is not repaired.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + dangling = make_span("dangling", "dangling", parent_id=None, start_time=2.0, end_time=8.0) + spans = [root, dangling] + + adapter = ToTree(repair_bad_hierarchy="none", repair_multiple_roots=True) + tree = adapter.adapt(spans) + + # Both should be roots under virtual root + assert tree.item.name == AGL_VIRTUAL + child_ids = {c.item.span_id for c in tree.children} + assert child_ids == {"root", "dangling"} + + +def test_to_tree_children_sorted_by_time(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child_late = make_span("child-late", "child", parent_id="root", start_time=5.0, end_time=9.0, sequence_id=0) + child_early = make_span("child-early", "child", parent_id="root", start_time=1.0, end_time=4.0, sequence_id=0) + spans = [root, child_late, child_early] + + adapter = ToTree() + tree = adapter.adapt(spans) + + child_ids = [c.item.span_id for c in tree.children] + assert child_ids == ["child-early", "child-late"] + + +def test_to_tree_adapting_span_properties(): + """Tree should contain AdaptingSpan instances with proper container references.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + spans = [root, child] + + adapter = ToTree() + tree = adapter.adapt(spans) + + assert isinstance(tree.item, AdaptingSpan) + assert isinstance(tree.children[0].item, 
AdaptingSpan) + + +# Tests for ToAdaptingSpans + + +def test_to_adapting_spans_sorts_by_default_order(): + span1 = make_span("s1", "span", parent_id=None, start_time=2.0, end_time=3.0, sequence_id=0) + span2 = make_span("s2", "span", parent_id=None, start_time=1.0, end_time=3.0, sequence_id=0) + spans = [span1, span2] + + adapter = ToAdaptingSpans() + result = adapter.adapt(spans) + + span_ids = [s.span_id for s in result] + assert span_ids == ["s2", "s1"] + + +def test_to_adapting_spans_returns_adapting_sequence(): + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + spans = [span] + + adapter = ToAdaptingSpans() + result = adapter.adapt(spans) + + assert len(result) == 1 + assert isinstance(result[0], AdaptingSpan) + + +# Tests for RepairMalformedSpans + + +def test_repair_malformed_spans_missing_start_time(): + span = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s1", + parent_id=None, + name="span", + attributes={}, + start_time=None, + end_time=5.0, + ) + spans = [span] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # Start time should be set to max of all times (5.0) + assert result[0].start_time == 5.0 + + +def test_repair_malformed_spans_missing_end_time(): + span = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s1", + parent_id=None, + name="span", + attributes={}, + start_time=1.0, + end_time=None, + ) + spans = [span] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # End time should be set to max of all times (1.0) + assert result[0].end_time == 1.0 + + +def test_repair_malformed_spans_both_missing_times(): + span = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s1", + parent_id=None, + name="span", + attributes={}, + start_time=None, + end_time=None, + ) + spans = [span] + + adapter = 
RepairMalformedSpans() + result = adapter.adapt(spans) + + # Both should be set to current time (they should be equal and non-None) + assert result[0].start_time is not None + assert result[0].end_time is not None + assert result[0].start_time == result[0].end_time + + +def test_repair_malformed_spans_negative_duration(): + """When end_time < start_time, end_time should be set to start_time.""" + span = make_span("s1", "span", parent_id=None, start_time=5.0, end_time=3.0) + spans = [span] + + adapter = RepairMalformedSpans(ensure_positive_duration=True) + result = adapter.adapt(spans) + + assert result[0].end_time == 5.0 + + +def test_repair_malformed_spans_no_repair_negative_duration_when_disabled(): + span = make_span("s1", "span", parent_id=None, start_time=5.0, end_time=3.0) + spans = [span] + + adapter = RepairMalformedSpans(ensure_positive_duration=False) + result = adapter.adapt(spans) + + assert result[0].end_time == 3.0 + + +def test_repair_malformed_spans_invalid_parent_ids(): + span = make_span("s1", "span", parent_id="nonexistent", start_time=0.0, end_time=1.0) + spans = [span] + + adapter = RepairMalformedSpans(ensure_valid_parent_ids=True) + result = adapter.adapt(spans) + + assert result[0].parent_id is None + + +def test_repair_malformed_spans_no_repair_invalid_parent_ids_when_disabled(): + span = make_span("s1", "span", parent_id="nonexistent", start_time=0.0, end_time=1.0) + spans = [span] + + adapter = RepairMalformedSpans(ensure_valid_parent_ids=False) + result = adapter.adapt(spans) + + assert result[0].parent_id == "nonexistent" + + +def test_repair_malformed_spans_proper_nesting(): + """Parent span's time range should be expanded to contain all children.""" + parent = make_span("parent", "parent", parent_id=None, start_time=2.0, end_time=8.0) + child = make_span("child", "child", parent_id="parent", start_time=1.0, end_time=9.0) + spans = [parent, child] + + adapter = RepairMalformedSpans(ensure_proper_nesting=True) + result = 
adapter.adapt(spans) + + parent_result = next(s for s in result if s.span_id == "parent") + # Parent's time should be expanded to contain child + assert parent_result.start_time == 1.0 + assert parent_result.end_time == 9.0 + + +def test_repair_malformed_spans_no_repair_proper_nesting_when_disabled(): + parent = make_span("parent", "parent", parent_id=None, start_time=2.0, end_time=8.0) + child = make_span("child", "child", parent_id="parent", start_time=1.0, end_time=9.0) + spans = [parent, child] + + adapter = RepairMalformedSpans(ensure_proper_nesting=False) + result = adapter.adapt(spans) + + parent_result = next(s for s in result if s.span_id == "parent") + assert parent_result.start_time == 2.0 + assert parent_result.end_time == 8.0 + + +def test_repair_malformed_spans_unchanged_pass_through(): + """Spans that don't need repair should not be modified.""" + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + spans = [span] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # The span object should be the same (not copied) + assert result[0] is span + + +# Integration tests + + +def test_integration_complex_tree_with_repairs(): + """Test a complex scenario with multiple issues that need repair.""" + # Root span + root = make_span("root", "session", parent_id=None, start_time=0.0, end_time=100.0) + + # Agent span with misaligned timing + agent = make_span("agent", "agent.node", parent_id="root", start_time=5.0, end_time=90.0) + + # LLM span that's a sibling of agent but should be under it (dangling repair) + llm = make_span("llm", "openai.chat", parent_id="root", start_time=10.0, end_time=20.0) + + # Orphan span with missing parent + orphan = make_span("orphan", "tool.call", parent_id="missing", start_time=30.0, end_time=40.0) + + spans = [root, agent, llm, orphan] + + # First repair invalid parent IDs + repair_adapter = RepairMalformedSpans(ensure_valid_parent_ids=True) + repaired = repair_adapter.adapt(spans) + 
+ # Apply hierarchy repairs and convert to tree + tree_adapter = ToTree( + repair_bad_hierarchy="dangling", + repair_multiple_roots=True, + ) + tree = tree_adapter.adapt(repaired) + + # The tree should be properly structured + assert tree.size() >= 4 # At least the original spans + + +def test_integration_adapting_spans_after_tree(): + """AdaptingSpans should work correctly on tree output.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + spans = [root, child] + + tree_adapter = ToTree() + tree = tree_adapter.adapt(spans) + + # Traverse and verify all items are AdaptingSpans + for span in tree.traverse(): + assert isinstance(span, AdaptingSpan) + + +def test_integration_repair_then_tree(): + """RepairMalformedSpans followed by ToTree should work correctly.""" + parent = make_span("parent", "parent", parent_id=None, start_time=5.0, end_time=3.0) # Invalid: end < start + child = make_span("child", "child", parent_id="parent", start_time=1.0, end_time=2.0) + + # First repair the spans + repair_adapter = RepairMalformedSpans() + repaired = repair_adapter.adapt([parent, child]) + + # Then create tree + tree_adapter = ToTree() + tree = tree_adapter.adapt(repaired) + + # Parent should have repaired time and contain child + assert tree.item.start_time is not None + assert tree.item.end_time is not None + assert tree.item.end_time >= tree.item.start_time + + +# Edge case tests + + +def test_edge_case_single_span_tree(): + span = make_span("only", "only", parent_id=None, start_time=0.0, end_time=1.0) + + adapter = ToTree() + tree = adapter.adapt([span]) + + assert tree.item.span_id == "only" + assert len(tree.children) == 0 + + +def test_edge_case_deep_tree(): + """Test a deeply nested tree.""" + spans: List[Span] = [] + for i in range(10): + parent_id = f"span-{i-1}" if i > 0 else None + spans.append( + make_span( + f"span-{i}", + f"level-{i}", + 
parent_id=parent_id, + start_time=float(i), + end_time=float(20 - i), + ) + ) + + adapter = ToTree() + tree = adapter.adapt(spans) + + assert tree.size() == 10 + + # Verify depth + current = tree + depth = 0 + while current.children: + depth += 1 + current = current.children[0] + assert depth == 9 + + +def test_edge_case_wide_tree(): + """Test a tree with many siblings.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=100.0) + children = [ + make_span(f"child-{i}", "child", parent_id="root", start_time=float(i), end_time=float(i + 1)) + for i in range(20) + ] + + adapter = ToTree() + tree = adapter.adapt([root] + children) + + assert tree.item.span_id == "root" + assert len(tree.children) == 20 + + +def test_edge_case_spans_with_same_times(): + """Test handling of spans with identical timestamps. + + When repair_bad_hierarchy is disabled, spans with same times become multiple roots. + """ + spans = [ + make_span(f"span-{i}", "span", parent_id=None, start_time=0.0, end_time=1.0, sequence_id=i) for i in range(5) + ] + + # With repair_bad_hierarchy="none", no hierarchy repair happens + adapter = ToTree(repair_bad_hierarchy="none", repair_multiple_roots=True) + tree = adapter.adapt(spans) + + # All should become children of virtual root + assert tree.item.name == AGL_VIRTUAL + assert len(tree.children) == 5 + + +def test_edge_case_repair_preserves_order(): + """Repaired spans should maintain their original order where possible.""" + spans = [ + make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0, sequence_id=0), + make_span("s2", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=1), + make_span("s3", "span", parent_id=None, start_time=2.0, end_time=3.0, sequence_id=2), + ] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # Order should be preserved + assert [s.span_id for s in result] == ["s1", "s2", "s3"] + + +# Additional corner case tests + + +def 
test_tree_like_graph_full_cycle_no_roots(): + """Graph where all nodes form a cycle (no roots) should raise ValueError.""" + # Create spans with a full cycle: A -> B -> C -> A (no roots) + span_a = make_span("A", "span", parent_id="C", start_time=0.0, end_time=10.0) + span_b = make_span("B", "span", parent_id="A", start_time=1.0, end_time=9.0) + span_c = make_span("C", "span", parent_id="B", start_time=2.0, end_time=8.0) + + with pytest.raises(ValueError, match="not reachable from the roots"): + _TreeLikeGraph.from_spans([span_a, span_b, span_c]) + + +def test_tree_like_graph_cycle_in_subtree(): + """Graph with a cycle in a subtree should raise ValueError.""" + # Create a graph with cycle: root -> A -> B -> A + # This tests the "Cycle detected" error path in validate_no_cycles + graph = _TreeLikeGraph() + graph.root_ids.add("root") + graph.add_edge("root", "A") + graph.add_edge("A", "B") + graph.add_edge("B", "A") # Creates cycle: A -> B -> A + + with pytest.raises(ValueError, match="Cycle detected"): + graph.validate_no_cycles() + + +def test_tree_like_graph_add_edge_direct(): + """Test add_edge method directly.""" + graph = _TreeLikeGraph() + graph.root_ids.add("root") + graph.add_edge("root", "child1") + graph.add_edge("root", "child2") + graph.add_edge("child1", "grandchild") + + assert "child1" in graph.forward_graph["root"] + assert "child2" in graph.forward_graph["root"] + assert "grandchild" in graph.forward_graph["child1"] + + +def test_to_tree_repair_bad_hierarchy_all(): + """Test repair_bad_hierarchy='all' mode re-evaluates all span placements.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=100.0) + # child1 is correctly placed under root + child1 = make_span("child1", "child", parent_id="root", start_time=10.0, end_time=90.0) + # child2 is incorrectly a sibling of child1 but should be inside child1 based on time + child2 = make_span("child2", "child", parent_id="root", start_time=20.0, end_time=80.0) + spans = [root, 
child1, child2] + + adapter = ToTree(repair_bad_hierarchy="all") + tree = adapter.adapt(spans) + + # With "all" mode, child2 should be moved under child1 (tighter fit) + child1_node = next(c for c in tree.children if c.item.span_id == "child1") + child2_in_child1 = any(c.item.span_id == "child2" for c in child1_node.children) + assert child2_in_child1 + + +def test_repair_malformed_spans_empty_sequence(): + """Empty sequence should pass through unchanged.""" + adapter = RepairMalformedSpans() + result = adapter.adapt([]) + assert result == [] + + +def test_to_tree_invalid_parent_error_message(): + """Test that invalid parent error message includes helpful guidance.""" + grandchild = make_span("grandchild", "gc", parent_id="missing-parent", start_time=2.0, end_time=3.0) + spans = [grandchild] + + adapter = ToTree() + + with pytest.raises(ValueError, match="RepairMalformedSpans"): + adapter.adapt(spans) + + +def test_default_span_order_with_tied_values(): + """Spans with identical ordering keys should maintain stable sort.""" + span1 = make_span("s1", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=0) + span2 = make_span("s2", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=0) + span3 = make_span("s3", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=0) + spans = [span1, span2, span3] + + # Default order should be deterministic + order1 = [default_span_order(s) for s in spans] + order2 = [default_span_order(s) for s in spans] + assert order1 == order2 + + +def test_to_tree_depth_based_parent_selection(): + """Deeper eligible parents should be preferred over shallower ones. + + This verifies the fix for the depth sorting in _find_eligible_parents. 
+ """ + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=100.0) + # Two nested containers with identical durations + container1 = make_span("container1", "c1", parent_id="root", start_time=10.0, end_time=90.0) + container2 = make_span("container2", "c2", parent_id="container1", start_time=10.0, end_time=90.0) + # Dangling span that fits in both containers + dangling = make_span("dangling", "d", parent_id=None, start_time=20.0, end_time=80.0) + spans = [root, container1, container2, dangling] + + adapter = ToTree(repair_bad_hierarchy="dangling") + tree = adapter.adapt(spans) + + # Dangling should be placed under container2 (deeper) rather than container1 + container1_node = next(c for c in tree.children if c.item.span_id == "container1") + container2_node = next(c for c in container1_node.children if c.item.span_id == "container2") + dangling_in_container2 = any(c.item.span_id == "dangling" for c in container2_node.children) + assert dangling_in_container2, "Dangling span should be placed under the deeper container" + + +def test_to_tree_repair_bad_hierarchy_prefers_shorter_duration(): + """When multiple parents are eligible, prefer shorter duration (tighter fit).""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=100.0) + # Wide container (duration 80) + wide = make_span("wide", "wide", parent_id="root", start_time=10.0, end_time=90.0) + # Narrow container (duration 40) - should be preferred + narrow = make_span("narrow", "narrow", parent_id="root", start_time=25.0, end_time=65.0) + # Dangling span that fits in both + dangling = make_span("dangling", "d", parent_id=None, start_time=30.0, end_time=60.0) + spans = [root, wide, narrow, dangling] + + adapter = ToTree(repair_bad_hierarchy="dangling") + tree = adapter.adapt(spans) + + # Dangling should be placed under narrow (shorter duration = tighter fit) + narrow_node = next(c for c in tree.children if c.item.span_id == "narrow") + dangling_in_narrow = 
any(c.item.span_id == "dangling" for c in narrow_node.children) + assert dangling_in_narrow, "Dangling span should be placed under the tighter-fitting container" + + +def test_repair_nesting_deep_hierarchy(): + """Test nesting repair with deeply nested hierarchy.""" + # Create hierarchy where child times exceed parent times at multiple levels + root = make_span("root", "root", parent_id=None, start_time=5.0, end_time=15.0) + child = make_span("child", "child", parent_id="root", start_time=4.0, end_time=16.0) + grandchild = make_span("grandchild", "gc", parent_id="child", start_time=3.0, end_time=17.0) + great_grandchild = make_span("great-gc", "ggc", parent_id="grandchild", start_time=2.0, end_time=18.0) + spans = [root, child, grandchild, great_grandchild] + + adapter = RepairMalformedSpans(ensure_proper_nesting=True) + result = adapter.adapt(spans) + + # All ancestors should be expanded to contain descendants + root_result = next(s for s in result if s.span_id == "root") + assert root_result.start_time == 2.0 + assert root_result.end_time == 18.0 + + +def test_to_adapting_spans_empty_sequence(): + """Empty sequence should return empty AdaptingSequence.""" + adapter = ToAdaptingSpans() + result = adapter.adapt([]) + assert len(result) == 0 + + +def test_tree_like_graph_move_subtree_to_same_parent(): + """Moving subtree to the same parent should be a no-op.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + + graph = _TreeLikeGraph.from_spans([root, child]) + + # Move child to root (its current parent) + graph.move_subtree("child", "root") + + # Should still have the same structure + assert graph.parent_map["child"] == "root" + assert "child" in graph.forward_graph["root"] + + +def test_repair_malformed_spans_uses_max_time_for_missing(): + """Missing times should be filled with max time from other spans.""" + span_with_times = make_span("s1", 
"span", parent_id=None, start_time=5.0, end_time=10.0) + span_missing_start = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s2", + parent_id=None, + name="span", + attributes={}, + start_time=None, + end_time=7.0, + ) + span_missing_end = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s3", + parent_id=None, + name="span", + attributes={}, + start_time=3.0, + end_time=None, + ) + spans = [span_with_times, span_missing_start, span_missing_end] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # Max time across all spans is 10.0 + s2_result = next(s for s in result if s.span_id == "s2") + s3_result = next(s for s in result if s.span_id == "s3") + assert s2_result.start_time == 10.0 # Filled with max + assert s3_result.end_time == 10.0 # Filled with max + + +def test_to_tree_no_repair_needed(): + """Test that properly structured spans don't get modified.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child1 = make_span("child1", "child", parent_id="root", start_time=1.0, end_time=4.0) + child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) + grandchild = make_span("grandchild", "gc", parent_id="child1", start_time=2.0, end_time=3.0) + spans = [root, child1, child2, grandchild] + + adapter = ToTree() + tree = adapter.adapt(spans) + + # Verify structure is preserved + assert tree.item.span_id == "root" + assert len(tree.children) == 2 + child1_node = next(c for c in tree.children if c.item.span_id == "child1") + assert len(child1_node.children) == 1 + assert child1_node.children[0].item.span_id == "grandchild" diff --git a/tests/types/test_adapter.py b/tests/types/test_adapter.py new file mode 100644 index 000000000..d11a39d62 --- /dev/null +++ b/tests/types/test_adapter.py @@ -0,0 +1,681 @@ +# Copyright (c) Microsoft. All rights reserved. 
"""Tests for Tree, AdaptingSequence, and AdaptingSpan data structures."""

import logging
from collections import UserString
from typing import Any, Dict, List, Optional

import pytest

from agentlightning.types import OtelResource, Span, TraceStatus
from agentlightning.types.adapter import AdaptingSequence, AdaptingSpan, Tree


class SequenceTestString(UserString):
    """Simple string wrapper that implements with_container for BaseAdaptingSequenceItem."""

    def with_container(self, container: Any) -> "SequenceTestString":
        # Container is intentionally discarded; the tests only need the protocol.
        return type(self)(self.data)


class SequenceTestInt(int):
    """Simple int wrapper that implements with_container for BaseAdaptingSequenceItem."""

    def __new__(cls, value: int) -> "SequenceTestInt":
        return int.__new__(cls, value)

    def with_container(self, container: Any) -> "SequenceTestInt":
        # Container is intentionally discarded; the tests only need the protocol.
        return type(self)(int(self))


def s(value: str) -> SequenceTestString:
    """Helper to create SequenceTestString instances."""
    return SequenceTestString(value)


def i(value: int) -> SequenceTestInt:
    """Helper to create SequenceTestInt instances."""
    return SequenceTestInt(value)


def strs(*values: str) -> list[SequenceTestString]:
    """Helper to create lists of SequenceTestString instances."""
    return [s(value) for value in values]


def ints(*values: int) -> list[SequenceTestInt]:
    """Helper to create lists of SequenceTestInt instances."""
    return [i(value) for value in values]


def make_span(
    name: str,
    attributes: Dict[str, Any],
    sequence_id: int,
    *,
    parent_id: Optional[str] = None,
) -> Span:
    """Create a test span with minimal required fields.

    Args:
        name: Span name.
        attributes: Span attributes dictionary.
        sequence_id: Sequence number; also used to derive trace/span ids
            (``trace-{sequence_id}`` / ``span-{sequence_id}``).
        parent_id: Optional parent span id.

    Returns:
        A `Span` with fixed rollout/attempt ids and empty events/links.
    """
    return Span(
        rollout_id="rollout-id",
        attempt_id="attempt-id",
        sequence_id=sequence_id,
        trace_id=f"trace-{sequence_id}",
        span_id=f"span-{sequence_id}",
        parent_id=parent_id,
        name=name,
        status=TraceStatus(status_code="OK"),
        attributes=attributes,
        events=[],
        links=[],
        start_time=None,
        end_time=None,
        context=None,
        parent=None,
        resource=OtelResource(attributes={}, schema_url=""),
    )


# ============================================================================
# Tree Tests
# ============================================================================


def test_tree_single_node():
    tree = Tree(s("root"), [])
    assert tree.item == "root"
    assert tree.children == []
    assert tree.parent is None


def test_tree_with_children():
    child1 = Tree(s("child1"), [])
    child2 = Tree(s("child2"), [])
    root = Tree(s("root"), [child1, child2])

    assert root.item == "root"
    assert len(root.children) == 2
    assert root.children[0].item == "child1"
    assert root.children[1].item == "child2"


def test_tree_parent_reference():
    grandchild = Tree(s("grandchild"), [])
    child = Tree(s("child"), [grandchild])
    root = Tree(s("root"), [child])

    assert root.parent is None
    assert child.parent is root
    assert grandchild.parent is child


def test_tree_len():
    child1 = Tree(s("child1"), [])
    child2 = Tree(s("child2"), [])
    root = Tree(s("root"), [child1, child2])
    assert len(root) == 3


def test_tree_len_deep():
    grandchild = Tree(s("grandchild"), [])
    child = Tree(s("child"), [grandchild])
    root = Tree(s("root"), [child])
    assert len(root) == 3


def test_tree_getitem_single():
    root = Tree(s("root"), [])
    assert root[0] == "root"


def test_tree_getitem_with_children():
    child1 = Tree(s("child1"), [])
    child2 = Tree(s("child2"), [])
    root = Tree(s("root"), [child1, child2])
    # DFS order: root, child1, child2
    assert root[0] == "root"
    assert root[1] == "child1"
    assert root[2] == "child2"


def test_tree_getitem_slice():
    child1 = Tree(s("child1"), [])
    child2 = Tree(s("child2"), [])
    root = Tree(s("root"), [child1, child2])
    assert root[1:] == ["child1", "child2"]


def test_tree_iter():
    child1 = Tree(s("child1"), [])
    child2 = Tree(s("child2"), [])
    root = Tree(s("root"), [child1, child2])
    assert list(root) == ["root", "child1", "child2"]


def test_tree_traverse_single_node():
    tree = Tree(s("root"), [])
    assert list(tree.traverse()) == ["root"]


def test_tree_traverse_dfs_order():
    #       root
    #      /    \
    #  child1  child2
    #    |
    # grandchild
    grandchild = Tree(s("grandchild"), [])
    child1 = Tree(s("child1"), [grandchild])
    child2 = Tree(s("child2"), [])
    root = Tree(s("root"), [child1, child2])

    # DFS order: root, child1, grandchild, child2
    assert list(root.traverse()) == ["root", "child1", "grandchild", "child2"]


def test_tree_size_single_node():
    tree = Tree(s("root"), [])
    assert tree.size() == 1


def test_tree_size_with_children():
    grandchild = Tree(s("grandchild"), [])
    child1 = Tree(s("child1"), [grandchild])
    child2 = Tree(s("child2"), [])
    root = Tree(s("root"), [child1, child2])
    assert root.size() == 4


def test_tree_map_single_node():
    tree = Tree(i(1), [])
    mapped = tree.map(lambda x: type(x)(x * 2))
    assert mapped.item == 2
    assert mapped.children == []


def test_tree_map_with_children():
    child1 = Tree(i(2), [])
    child2 = Tree(i(3), [])
    root = Tree(i(1), [child1, child2])

    mapped = root.map(lambda x: type(x)(x * 10))
    assert mapped.item == 10
    assert mapped.children[0].item == 20
    assert mapped.children[1].item == 30


def test_tree_map_preserves_structure():
    grandchild = Tree(s("gc"), [])
    child = Tree(s("c"), [grandchild])
    root = Tree(s("r"), [child])

    mapped = root.map(lambda x: type(x)(x.upper()))
    assert mapped.item == "R"
    assert mapped.children[0].item == "C"
    assert mapped.children[0].children[0].item == "GC"


def test_tree_retain_keeps_root_always():
    tree = Tree(s("root"), [])
    retained = tree.retain(lambda x: False)
    assert retained.item == "root"
    assert retained.children == []


def test_tree_retain_keeps_matching_subtrees():
    #      root
    #     /    \
    #  keep1  drop1
    #    |
    #  drop2
    drop2 = Tree(s("drop2"), [])
    keep1 = Tree(s("keep1"), [drop2])
    drop1 = Tree(s("drop1"), [])
    root = Tree(s("root"), [keep1, drop1])

    # Retain subtrees rooted at nodes containing "keep"
    retained = root.retain(lambda x: "keep" in x)

    assert retained.item == "root"
    assert len(retained.children) == 1
    # keep1 is retained along with its entire subtree
    assert retained.children[0].item == "keep1"


def test_tree_retain_removes_branches_without_matches():
    #      root
    #     /    \
    #  drop1  drop2
    #    |
    #  keep
    keep = Tree(s("keep"), [])
    drop1 = Tree(s("drop1"), [keep])
    drop2 = Tree(s("drop2"), [])
    root = Tree(s("root"), [drop1, drop2])

    retained = root.retain(lambda x: x == "keep")

    assert retained.item == "root"
    # drop1 branch is kept because it leads to "keep"
    assert len(retained.children) == 1
    assert retained.children[0].item == "drop1"


def test_tree_retain_deep_tree():
    # root
    #   |
    #   a
    #   |
    #   b (keep)
    #   |
    #   c
    c = Tree(s("c"), [])
    b = Tree(s("b"), [c])
    a = Tree(s("a"), [b])
    root = Tree(s("root"), [a])

    retained = root.retain(lambda x: x == "b")
    # When b matches, the entire subtree rooted at b (including c) is retained
    assert list(retained.traverse()) == ["root", "a", "b", "c"]


def test_tree_prune_does_not_remove_root():
    tree = Tree(s("root"), [])
    pruned = tree.prune(lambda x: x == "root")
    assert pruned.item == "root"


def test_tree_prune_removes_matching_children():
    child1 = Tree(s("remove_me"), [])
    child2 = Tree(s("keep_me"), [])
    root = Tree(s("root"), [child1, child2])

    pruned = root.prune(lambda x: x == "remove_me")

    assert pruned.item == "root"
    assert len(pruned.children) == 1
    assert pruned.children[0].item == "keep_me"


def test_tree_prune_removes_subtrees():
    #       root
    #      /    \
    #  remove  keep
    #    |
    # child_of_remove
    child_of_remove = Tree(s("child_of_remove"), [])
    remove = Tree(s("remove"), [child_of_remove])
    keep = Tree(s("keep"), [])
    root = Tree(s("root"), [remove, keep])

    pruned = root.prune(lambda x: x == "remove")

    assert pruned.item == "root"
    assert len(pruned.children) == 1
    assert pruned.children[0].item == "keep"


def test_tree_prune_recursive():
    #    root
    #     |
    #    keep
    #   /    \
    # remove keep2
    keep2 = Tree(s("keep2"), [])
    remove = Tree(s("remove"), [])
    keep = Tree(s("keep"), [remove, keep2])
    root = Tree(s("root"), [keep])

    pruned = root.prune(lambda x: x == "remove")

    assert pruned.item == "root"
    assert pruned.children[0].item == "keep"
    assert len(pruned.children[0].children) == 1
    assert pruned.children[0].children[0].item == "keep2"


# ============================================================================
# AdaptingSequence Tests
# ============================================================================


def test_adapting_sequence_empty():
    seq = AdaptingSequence[Any]([])
    assert len(seq) == 0
    assert list(seq) == []


def test_adapting_sequence_with_items():
    seq = AdaptingSequence(ints(1, 2, 3))
    assert len(seq) == 3
    assert list(seq) == [1, 2, 3]


def test_adapting_sequence_getitem_single():
    seq = AdaptingSequence(strs("a", "b", "c"))
    assert seq[0] == "a"
    assert seq[1] == "b"
    assert seq[2] == "c"


def test_adapting_sequence_getitem_negative_index():
    seq = AdaptingSequence(strs("a", "b", "c"))
    assert seq[-1] == "c"


def test_adapting_sequence_getitem_slice():
    seq = AdaptingSequence(strs("a", "b", "c", "d"))
    assert seq[1:3] == ["b", "c"]


def test_adapting_sequence_iter():
    seq = AdaptingSequence(ints(1, 2, 3))
    result: List[Any] = []
    for item in seq:
        result.append(item)
    assert result == [1, 2, 3]


def test_adapting_sequence_traverse():
    seq = AdaptingSequence(ints(1, 2, 3))
    assert list(seq.traverse()) == [1, 2, 3]


def test_adapting_sequence_size():
    seq = AdaptingSequence(ints(1, 2, 3, 4))
    assert seq.size() == 4


def test_adapting_sequence_get():
    seq = AdaptingSequence(strs("x", "y", "z"))
    assert seq.get(0) == "x"
    assert seq.get(1) == "y"


def test_adapting_sequence_map_empty():
    seq = AdaptingSequence[Any]([])
    mapped = seq.map(lambda x: type(x)(x * 2))
    assert list(mapped) == []


def test_adapting_sequence_map_integers():
    seq = AdaptingSequence(ints(1, 2, 3))
    mapped = seq.map(lambda x: type(x)(x * 2))
    assert list(mapped) == [2, 4, 6]


def test_adapting_sequence_map_strings():
    seq = AdaptingSequence(strs("a", "b", "c"))
    mapped = seq.map(lambda x: type(x)(x.upper()))
    assert list(mapped) == ["A", "B", "C"]


def test_adapting_sequence_map_returns_adapting_sequence():
    seq = AdaptingSequence(ints(1, 2, 3))
    mapped = seq.map(lambda x: x)
    assert isinstance(mapped, AdaptingSequence)


def test_adapting_sequence_retain_all():
    seq = AdaptingSequence(ints(1, 2, 3))
    retained = seq.retain(lambda x: True)
    assert list(retained) == [1, 2, 3]


def test_adapting_sequence_retain_none():
    seq = AdaptingSequence(ints(1, 2, 3))
    retained = seq.retain(lambda x: False)
    assert list(retained) == []


def test_adapting_sequence_retain_some():
    seq = AdaptingSequence(ints(1, 2, 3, 4, 5))
    retained = seq.retain(lambda x: x % 2 == 0)
    assert list(retained) == [2, 4]


def test_adapting_sequence_retain_returns_adapting_sequence():
    seq = AdaptingSequence(ints(1, 2, 3))
    retained = seq.retain(lambda x: True)
    assert isinstance(retained, AdaptingSequence)


def test_adapting_sequence_prune_none():
    seq = AdaptingSequence(ints(1, 2, 3))
    pruned = seq.prune(lambda x: False)
    assert list(pruned) == [1, 2, 3]


def test_adapting_sequence_prune_all():
    seq = AdaptingSequence(ints(1, 2, 3))
    pruned = seq.prune(lambda x: True)
    assert list(pruned) == []


def test_adapting_sequence_prune_some():
    seq = AdaptingSequence(ints(1, 2, 3, 4, 5))
    pruned = seq.prune(lambda x: x % 2 == 0)
    assert list(pruned) == [1, 3, 5]


def test_adapting_sequence_prune_returns_adapting_sequence():
    seq = AdaptingSequence(ints(1, 2, 3))
    pruned = seq.prune(lambda x: False)
    assert isinstance(pruned, AdaptingSequence)


# ============================================================================
# AdaptingSpan Tests
# ============================================================================


def test_adapting_span_from_span_creates_adapting_span():
    span = make_span("test-span", {"key": "value"}, 0)
    adapting_span = AdaptingSpan.from_span(span, data="test-data")

    assert adapting_span.name == "test-span"
    assert adapting_span.attributes == {"key": "value"}
    assert adapting_span.data == "test-data"


def test_adapting_span_from_span_preserves_fields():
    span = make_span("my-span", {"attr": 123}, 5, parent_id="parent-span")
    adapting_span = AdaptingSpan.from_span(span, data={"nested": "data"})

    assert adapting_span.rollout_id == "rollout-id"
    assert adapting_span.attempt_id == "attempt-id"
    assert adapting_span.sequence_id == 5
    assert adapting_span.parent_id == "parent-span"
    assert adapting_span.data == {"nested": "data"}


def test_adapting_span_from_adapting_span_updates_data():
    span = make_span("test-span", {}, 0)
    adapting_span1 = AdaptingSpan.from_span(span, data="original")
    adapting_span2 = AdaptingSpan.from_span(adapting_span1, data="updated")

    assert adapting_span2.data == "updated"


def test_adapting_span_with_data_creates_copy():
    span = make_span("test-span", {"key": "value"}, 0)
    adapting_span = AdaptingSpan.from_span(span, data=None)
    new_span = adapting_span.with_data("new-data", override="silent")

    assert new_span.data == "new-data"
    assert new_span is not adapting_span


def test_adapting_span_with_data_preserves_other_fields():
    span = make_span("test-span", {"attr": 42}, 3)
    adapting_span = AdaptingSpan.from_span(span, data=None)
    new_span = adapting_span.with_data("new-data", override="silent")

    assert new_span.name == "test-span"
    assert new_span.attributes == {"attr": 42}
    assert new_span.sequence_id == 3


def test_adapting_span_with_data_silent_override():
    span = make_span("test-span", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data="original")
    new_span = adapting_span.with_data("updated", override="silent")

    assert new_span.data == "updated"


def test_adapting_span_with_data_warning_override(caplog: pytest.LogCaptureFixture) -> None:
    span = make_span("test-span", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data="original")

    with caplog.at_level(logging.WARNING):
        new_span = adapting_span.with_data("updated", override="warning")

    assert new_span.data == "updated"
    assert "overwriting" in caplog.text.lower()


def test_adapting_span_with_data_forbidden_override():
    span = make_span("test-span", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data="original")

    with pytest.raises(ValueError, match="forbidden"):
        adapting_span.with_data("updated", override="forbidden")


def test_adapting_span_with_data_none_does_not_warn(caplog: pytest.LogCaptureFixture) -> None:
    span = make_span("test-span", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data=None)

    with caplog.at_level(logging.WARNING):
        new_span = adapting_span.with_data("updated", override="warning")

    assert new_span.data == "updated"
    assert "overwriting" not in caplog.text.lower()


def test_adapting_span_container_default_none():
    span = make_span("test", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data=None)
    assert adapting_span.container is None


def test_adapting_span_container_can_be_set():
    span = make_span("test", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data=None)
    seq = AdaptingSequence([adapting_span])
    adapting_span = adapting_span.model_copy(update={"container": seq})

    assert adapting_span.container is seq


# The fixture below yields the four spans (root, child1, child2, grandchild),
# each with its ``container`` pointing at the corresponding Tree node.
# NOTE: the previous annotation ``Tree[AdaptingSpan]`` on the consuming tests
# was wrong (it required ``# type: ignore`` on every unpack); a tuple alias
# describes the actual fixture value.
SpanQuad = tuple[AdaptingSpan, AdaptingSpan, AdaptingSpan, AdaptingSpan]


@pytest.fixture
def tree_of_adapting_spans() -> SpanQuad:
    """Create a tree structure of AdaptingSpans for testing.

    Structure:
            root
           /    \\
       child1  child2
         |
      grandchild
    """
    root_span = make_span("root", {}, 0)
    child1_span = make_span("child1", {}, 1, parent_id="span-0")
    child2_span = make_span("child2", {}, 2, parent_id="span-0")
    grandchild_span = make_span("grandchild", {}, 3, parent_id="span-1")

    grandchild_tree: Tree[AdaptingSpan] = Tree(
        AdaptingSpan.from_span(grandchild_span, data="gc-data"),
        [],
    )
    child1_tree: Tree[AdaptingSpan] = Tree(
        AdaptingSpan.from_span(child1_span, data="c1-data"),
        [grandchild_tree],
    )
    child2_tree: Tree[AdaptingSpan] = Tree(
        AdaptingSpan.from_span(child2_span, data="c2-data"),
        [],
    )
    root_tree: Tree[AdaptingSpan] = Tree(
        AdaptingSpan.from_span(root_span, data="root-data"),
        [child1_tree, child2_tree],
    )

    # Set container references
    root_adapting = root_tree.item.model_copy(update={"container": root_tree})
    child1_adapting = child1_tree.item.model_copy(update={"container": child1_tree})
    child2_adapting = child2_tree.item.model_copy(update={"container": child2_tree})
    grandchild_adapting = grandchild_tree.item.model_copy(update={"container": grandchild_tree})

    return root_adapting, child1_adapting, child2_adapting, grandchild_adapting


def test_adapting_span_children_returns_child_spans(tree_of_adapting_spans: SpanQuad):
    root, child1, child2, grandchild = tree_of_adapting_spans
    children = root.children()

    assert len(children) == 2
    assert children[0].name == "child1"
    assert children[1].name == "child2"


def test_adapting_span_children_leaf_node_empty(tree_of_adapting_spans: SpanQuad):
    root, child1, child2, grandchild = tree_of_adapting_spans
    assert child2.children() == []


def test_adapting_span_children_raises_without_tree_container():
    span = make_span("test", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data=None)

    with pytest.raises(ValueError, match="container"):
        adapting_span.children()


def test_adapting_span_children_raises_with_non_tree_container():
    span = make_span("test", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data=None)
    adapting_span = adapting_span.model_copy(update={"container": AdaptingSequence([])})

    with pytest.raises(ValueError, match="Tree"):
        adapting_span.children()


def test_adapting_span_parent_span_returns_parent(tree_of_adapting_spans: SpanQuad):
    root, child1, child2, grandchild = tree_of_adapting_spans

    parent = child1.parent_span()
    assert parent is not None
    assert parent.name == "root"


def test_adapting_span_parent_span_root_returns_none(tree_of_adapting_spans: SpanQuad):
    root, child1, child2, grandchild = tree_of_adapting_spans
    assert root.parent_span() is None


def test_adapting_span_parent_span_raises_without_tree_container():
    span = make_span("test", {}, 0)
    adapting_span = AdaptingSpan.from_span(span, data=None)

    with pytest.raises(ValueError, match="container"):
        adapting_span.parent_span()


def test_adapting_span_siblings_returns_sibling_spans(tree_of_adapting_spans: SpanQuad):
    root, child1, child2, grandchild = tree_of_adapting_spans

    siblings = child1.siblings()
    assert len(siblings) == 1
    assert siblings[0].name == "child2"


def test_adapting_span_siblings_only_child_returns_empty(tree_of_adapting_spans: SpanQuad):
    root, child1, child2, grandchild = tree_of_adapting_spans
    assert grandchild.siblings() == []


def test_adapting_span_siblings_root_returns_empty(tree_of_adapting_spans: SpanQuad):
    root, child1, child2, grandchild = tree_of_adapting_spans
    assert root.siblings() == []