From 62e8dbcd9ae9ee019ef778d1a5c6af2a20e974db Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Thu, 25 Dec 2025 13:17:13 +0800 Subject: [PATCH 01/41] initiate adapter pipeline types --- agentlightning/types/adapter.py | 64 +++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 agentlightning/types/adapter.py diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py new file mode 100644 index 000000000..5e672e143 --- /dev/null +++ b/agentlightning/types/adapter.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Data formats used by adapters, usually the target format converted from trace spans.""" + +from __future__ import annotations + +from typing import Callable, Dict, Generic, Iterable, Iterator, MutableSequence, TypeVar + +T = TypeVar("T") + + +class Tree(Generic[T]): + + def __init__(self, item: T, children: MutableSequence[Tree[T]]) -> None: + self.item = item + self.children = children + + def traverse(self) -> Iterable[T]: + yield self.item + for child in self.children: + yield from child.traverse() + + def count(self) -> int: + return 1 + sum(child.count() for child in self.children) + + def __iter__(self) -> Iterator[T]: + return iter(self.traverse()) + + def __len__(self) -> int: + return self.count() + + def add(self, child: Tree[T]) -> None: + self.children.append(child) + + def prune(self, predicate: Callable[[T], bool]) -> Tree[T]: + return Tree(self.item, [child.prune(predicate) for child in self.children if predicate(child.item)]) + + def visualize(self, filename: str, item_to_str: Callable[[T], str]) -> None: + """Render the tree with Graphviz for debugging purposes. + + Args: + filename: Base filename for the generated `.png` diagram (without extension). + + !!! note + + The method requires the optional `graphviz` dependency to be available in the runtime + environment. 
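+
+        Example (a minimal illustrative call, assuming `graphviz` is installed):
+
+            tree = Tree("root", [Tree("leaf", [])])
+            tree.visualize("tree_debug", item_to_str=str)  # renders tree_debug.png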
+ """ + import graphviz + + dot = graphviz.Digraph(comment="Tree") + + def visit(node: Tree[T]): + dot.node(str(id(node)), item_to_str(node.item)) # type: ignore + for child in node.children: + visit(child) + dot.edge(str(id(node)), str(id(child))) # type: ignore + + visit(self) + dot.render(filename, format="png", cleanup=True) # type: ignore + + +class ChatCompletionCall(TypedDict): + pass From 1d7b7e1939c470ea5a8312f6517ebd9de6366e1e Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 26 Dec 2025 11:47:33 +0800 Subject: [PATCH 02/41] update adapter types --- agentlightning/types/adapter.py | 240 +++++++++++++++++++++++++++++++- 1 file changed, 237 insertions(+), 3 deletions(-) diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 5e672e143..548d86750 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -4,12 +4,40 @@ from __future__ import annotations -from typing import Callable, Dict, Generic, Iterable, Iterator, MutableSequence, TypeVar +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Iterator, + Literal, + MutableSequence, + Optional, + Sequence, + TypeVar, +) + +from openai.types.chat import ( + ChatCompletion, + ChatCompletionFunctionToolParam, + ChatCompletionMessageParam, + CompletionCreateParams, +) +from pydantic import BaseModel, Field + +from agentlightning.semconv import LinkPydanticModel + +from .tracer import Attributes T = TypeVar("T") +# General containers + + class Tree(Generic[T]): + """This is a generic tree data structure that can be used to represent the structure of a tree.""" def __init__(self, item: T, children: MutableSequence[Tree[T]]) -> None: self.item = item @@ -60,5 +88,211 @@ def visit(node: Tree[T]): dot.render(filename, format="png", cleanup=True) # type: ignore -class ChatCompletionCall(TypedDict): - pass +# Annotation-related types + + +class Annotation(BaseModel): + """An annotation is an approach to parse a span into some kind of structured attachments to another object. + + Note that a span can be parsed in multiple ways, and annotation is just one of them. + """ + + annotation_type: Literal["agent", "general", "message", "object", "exception", "operation"] + """Type of the annotation.""" + + span_id: str + """Span ID of the annotation span. 
Not necessarily an [AGL_ANNOTATION][agentlightning.semconv.AGL_ANNOTATION] span.""" + + links: Optional[Sequence[LinkPydanticModel]] = None + """Links to other spans or objects.""" + + +class AgentAnnotation(Annotation): + """Parsed from [OTel Agent Spans](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-agent-spans/).""" + + annotation_type = "agent" + """Type of the annotation.""" + + id: Optional[str] = None + """The unique identifier of the GenAI agent.""" + + name: Optional[str] = None + """Human-readable name of the GenAI agent provided by the application.""" + + description: Optional[str] = None + """Free-form description of the GenAI agent provided by the application.""" + + +class GeneralAnnotation(Annotation): + """An annotation payload that is parsed from an [annotation][agentlightning.semconv.AGL_ANNOTATION] span.""" + + annotation_type = "general" + """Type of the annotation.""" + + reward: Dict[str, float] = Field(default_factory=dict) + """Reward dimensions and values.""" + + primary_reward: Optional[float] = None + """Primary reward value.""" + + tag: Sequence[str] = Field(default_factory=list) + """Tags for the annotation.""" + + custom_fields: Dict[str, Any] = Field(default_factory=dict) + """Raw payload from the annotation.""" + + +class MessageAnnotation(Annotation): + """A log message that is parsed from a [message][agentlightning.semconv.AGL_MESSAGE] span.""" + + annotation_type = "message" + """Type of the annotation.""" + + message: str + """Message text.""" + + +class ObjectAnnotation(Annotation): + """An artifact that is parsed from a [object][agentlightning.semconv.AGL_OBJECT] span.""" + + annotation_type = "object" + """Type of the annotation.""" + + object: Any + """The object payload.""" + + +class ExceptionAnnotation(Annotation): + """An exception that is parsed from an [exception][agentlightning.semconv.AGL_EXCEPTION] span.""" + + annotation_type = "exception" + """Type of the annotation.""" + + type: str + """Type of the exception.""" + + message: str + """Message of the exception.""" + + stacktrace: Optional[str] = None + """Stacktrace of the exception.""" + + +class OperationAnnotation(Annotation): + """An operation that is parsed from an [operation][agentlightning.semconv.AGL_OPERATION] span.""" + + annotation_type = "operation" + """Type of the annotation.""" + + name: str + """Name of the operation.""" + + input: Optional[Any] = None + """Input of the operation.""" + + output: Optional[Any] = None + """Output of the operation.""" + + +class ChatCompletionCall(BaseModel): + """Corresponding to exactly one chat completion call. + + OpenAI chat completion request and response are used as standards here. + Convert to other chat completion formats if needed. + """ + + request: CompletionCreateParams + """OpenAI chat completion request parameters.""" + + response: ChatCompletion + """OpenAI chat completion response payload.""" + + malformed_fields: Dict[str, Attributes] + """Fields that are not supported by the adapter. + + Mapping from span names to a dict of malformed fields. 
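+
+    A purely illustrative shape (actual keys depend on the instrumentation):
+    `{"openai.chat.completions.create": {"unsupported.attribute": "raw value"}}`.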
+ """ + + span_ids: Sequence[str] + """Span IDs of the spans that contributed to this chat completion.""" + + +class AnnotatedChatCompletionCall(ChatCompletionCall): + """A chat completion call with annotations.""" + + annotations: Sequence[Annotation] + """Annotations for the chat completion call.""" + + +# Algorithm-specific requirements + + +class TokenInput(BaseModel): + """Token-based model input.""" + + token_ids: Sequence[int] + """Token IDs of the model input.""" + + image_urls: Any + """A list of image URLs. Could be pointers to local files or base64-encoded images.""" + + +class TokenOutput(BaseModel): + """Token-based model output.""" + + token_ids: Sequence[int] + """Token IDs of the model output.""" + + +class TokenIOTriplet(BaseModel): + """A triplet of token IDs for the input and output, useful for reinforcement learning. + + This is not a stable interface and the fields here highly depend on RL implementations. + """ + + observation: TokenInput + """Observation for the model input. Corresponding to prompt.""" + + action: TokenOutput + """Action, corresponding to completion result.""" + + reward: Optional[float] + """Reward of the model input.""" + + done: bool + """Whether it's the end of the trajectory.""" + + raw_call: AnnotatedChatCompletionCall + """Raw chat completion call.""" + + +class AccumulatedTokenSequence(TokenInput): + """A sequence of token IDs that are accumulated from multiple model calls. + + Output is implied in the token IDs. + """ + + response_mask: Sequence[int] + """Mask for the response tokens. Must a sequence of 0s and 1s, with 1s for the completion tokens and 0s for the prompt tokens.""" + + final_reward: Optional[float] + """Single reward value for the entire sequence.""" + + raw_calls: Sequence[AnnotatedChatCompletionCall] + """Raw chat completion calls. The order of the calls must be the same as the order of the token IDs.""" + + +class AccumulatedMessages(BaseModel): + """A conversation that is accumulated from multiple model calls.""" + + messages: Sequence[ChatCompletionMessageParam] + """Messages of the conversation.""" + + tools: Optional[Sequence[ChatCompletionFunctionToolParam]] + """Tools provided for the conversation.""" + + final_reward: Optional[float] + """Single reward value for the entire conversation.""" + + raw_calls: Sequence[AnnotatedChatCompletionCall] + """Raw chat completion calls. The order of the calls must be the same as the order of the messages.""" From b7ad2d33eeaf42e821d763a6173d47a3954c7c10 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 26 Dec 2025 12:10:37 +0800 Subject: [PATCH 03/41] . --- agentlightning/adapter/base.py | 13 ++++++ agentlightning/adapter/consolidation.py | 19 ++++++++ agentlightning/adapter/conversion.py | 62 +++++++++++++++++++++++++ agentlightning/types/adapter.py | 2 +- 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 agentlightning/adapter/consolidation.py create mode 100644 agentlightning/adapter/conversion.py diff --git a/agentlightning/adapter/base.py b/agentlightning/adapter/base.py index 2a148f78d..e31a1764a 100644 --- a/agentlightning/adapter/base.py +++ b/agentlightning/adapter/base.py @@ -66,6 +66,19 @@ def adapt(self, source: T_from, /) -> T_to: raise NotImplementedError("Adapter.adapt() is not implemented") +class SequenceAdapter(Adapter[Sequence[T_from], Sequence[T_to]], Generic[T_from, T_to]): + """Base class for adapters that convert sequences of data from one format to another. 
+ + This class specializes [`Adapter`][agentlightning.Adapter] for working with sequences of data. + """ + + def adapt(self, source: Sequence[T_from]) -> Sequence[T_to]: + return [self.adapt_one(item) for item in source] + + def adapt_one(self, source: T_from) -> T_to: + raise NotImplementedError("SequenceAdapter.adapt_one() is not implemented") + + class OtelTraceAdapter(Adapter[Sequence[ReadableSpan], T_to], Generic[T_to]): """Base class for adapters that convert OpenTelemetry trace spans into other formats. diff --git a/agentlightning/adapter/consolidation.py b/agentlightning/adapter/consolidation.py new file mode 100644 index 000000000..34d0b84dc --- /dev/null +++ b/agentlightning/adapter/consolidation.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Filter/aggregate multiple objects into a single object. + +Opinionated towards which objects to keep and how to aggregate them. +""" + +from __future__ import annotations + +from typing import Generic, Sequence, TypeVar + +from agentlightning.types.adapter import ( + AccumulatedTokenSequence, + AnnotatedChatCompletionCall, + Annotation, + ChatCompletionCall, + TokenInputOutputTriplet, + Tree, +) diff --git a/agentlightning/adapter/conversion.py b/agentlightning/adapter/conversion.py new file mode 100644 index 000000000..4b4e7cf62 --- /dev/null +++ b/agentlightning/adapter/conversion.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Non-opinionated conversion adapters for different data formats, without loss of information.""" + +from __future__ import annotations + +from typing import Generic, Sequence, TypeVar + +from agentlightning.types.adapter import ( + AccumulatedTokenSequence, + AnnotatedChatCompletionCall, + Annotation, + ChatCompletionCall, + TokenInputOutputTriplet, + Tree, +) +from agentlightning.types.tracer import Span, SpanLike + +from .base import Adapter, SequenceAdapter + +T_from = TypeVar("T_from") +T_to = TypeVar("T_to") + + +class ToSpans(SequenceAdapter[SpanLike, Span]): + + def __init__( + self, + default_rollout_id: str = "rollout-dummy", + default_attempt_id: str = "attempt-dummy", + default_sequence_id: int = 0, + ): + self.default_rollout_id = default_rollout_id + self.default_attempt_id = default_attempt_id + self.default_sequence_id = default_sequence_id + + def adapt_one(self, source: SpanLike) -> Span: + if isinstance(source, Span): + return source + return Span.from_opentelemetry( + source, + rollout_id=self.default_rollout_id, + attempt_id=self.default_attempt_id, + sequence_id=self.default_sequence_id, + ) + + +class ToTree(Adapter[Sequence[Span], Tree[Span]]): + + def __init__(self, repair_hierarchy: bool = True): + self.repair_hierarchy = repair_hierarchy + + def adapt(self, source: Sequence[Span]) -> Tree[Span]: ... + + +class ToChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): ... + + +class ToAnnotations(Adapter[Sequence[Span], Sequence[Annotation]]): ... + + +class ToTokenInputOutputTriplet(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[TokenInputOutputTriplet]]): ... diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 548d86750..80e836de5 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -244,7 +244,7 @@ class TokenOutput(BaseModel): """Token IDs of the model output.""" -class TokenIOTriplet(BaseModel): +class TokenInputOutputTriplet(BaseModel): """A triplet of token IDs for the input and output, useful for reinforcement learning. 
This is not a stable interface and the fields here highly depend on RL implementations. From 257ebeba24fef519ca293dbd893cdeb76a7306f4 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 26 Dec 2025 18:18:44 +0800 Subject: [PATCH 04/41] add more definitions --- agentlightning/adapter/consolidation.py | 82 ++++++++++++++++++++++++- agentlightning/adapter/conversion.py | 13 ++-- agentlightning/adapter/speculation.py | 70 +++++++++++++++++++++ agentlightning/types/adapter.py | 16 ++++- 4 files changed, 171 insertions(+), 10 deletions(-) create mode 100644 agentlightning/adapter/speculation.py diff --git a/agentlightning/adapter/consolidation.py b/agentlightning/adapter/consolidation.py index 34d0b84dc..678ff1421 100644 --- a/agentlightning/adapter/consolidation.py +++ b/agentlightning/adapter/consolidation.py @@ -7,9 +7,11 @@ from __future__ import annotations -from typing import Generic, Sequence, TypeVar +import re +from typing import Callable, Generic, Literal, Sequence, Tuple, TypeVar from agentlightning.types.adapter import ( + AccumulatedMessages, AccumulatedTokenSequence, AnnotatedChatCompletionCall, Annotation, @@ -17,3 +19,81 @@ TokenInputOutputTriplet, Tree, ) +from agentlightning.types.tracer import Span + +from .base import Adapter + +T_SpanSequence = TypeVar("T_SpanSequence", bound=Sequence[Span]) + +T_from = TypeVar("T_from") +T_to = TypeVar("T_to") + + +class CurateChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): + """Curate the chat completion calls from the spans.""" + + def adapt(self, source: Sequence[Span]) -> Sequence[ChatCompletionCall]: ... + + +class CurateAnnotations(Adapter[Sequence[Span], Sequence[Annotation]]): + """Curate the annotations from the spans.""" + + def adapt(self, source: Sequence[Span]) -> Sequence[Annotation]: ... + + +class Filter(Adapter[Sequence[T_from], Sequence[T_to]]): + """Filter items of type T_from to items of type T_to based on a predicate.""" + + def __init__(self, predicate: Callable[[T_from], bool]) -> None: + self.predicate = predicate + + def adapt(self, source: Sequence[T_from]) -> Sequence[T_to]: ... + + +class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_SpanSequence]): + """Select the corresponding spans within the annotation sequence, as well as their linked spans + (and subtree spans if applicable). + + The effective radius of an annotation is as follows: + + - If the annotation has links, it applies to the linked spans only. + - If the annotation is on a tree node, it applies to all spans in the subtree. + - If the annotation has neither links nor tree nodes, it applies to only itself. + + Args: + mode: "include" to select spans within the annotations; "exclude" to exclude them. + """ + + def __init__(self, mode: Literal["include", "exclude"]) -> None: + self.mode = mode + + def adapt(self, source: Tuple[T_SpanSequence, Sequence[Annotation]]) -> T_SpanSequence: ... + + +class AnnotateChatCompletionCalls( + Adapter[Tuple[Sequence[ChatCompletionCall], Sequence[Annotation]], Sequence[AnnotatedChatCompletionCall]] +): + """Annotate chat completion calls with the given annotations. + + The intersection of "effective radius" of annotations and chat completion calls is used to determine + which annotations apply to which chat completion calls. + + If an annotation is not linked to any span, try to use `FillMissingLinks` first to link it to spans. 
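+
+    Example (illustrative sketch; `calls` and `annotations` are assumed to come from
+    `CurateChatCompletionCalls` and `CurateAnnotations`):
+
+        annotated = AnnotateChatCompletionCalls().adapt((calls, annotations))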
+ """ + + def adapt( + self, + source: Tuple[Sequence[ChatCompletionCall], Sequence[Annotation]], + ) -> Sequence[AnnotatedChatCompletionCall]: ... + + +class AccumulateTokenSequence(Adapter[Sequence[TokenInputOutputTriplet], Sequence[AccumulatedTokenSequence]]): + """Assemble multiple token input-output triplets into accumulated token sequences.""" + + def adapt(self, source: Sequence[TokenInputOutputTriplet]) -> Sequence[AccumulatedTokenSequence]: ... + + +class AccumulateMessages(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[AccumulatedMessages]]): + """Assemble multiple token input-output triplets into accumulated chat messages.""" + + def adapt(self, source: Sequence[AnnotatedChatCompletionCall]) -> Sequence[AccumulatedMessages]: ... diff --git a/agentlightning/adapter/conversion.py b/agentlightning/adapter/conversion.py index 4b4e7cf62..0c3c48d6d 100644 --- a/agentlightning/adapter/conversion.py +++ b/agentlightning/adapter/conversion.py @@ -47,16 +47,17 @@ def adapt_one(self, source: SpanLike) -> Span: class ToTree(Adapter[Sequence[Span], Tree[Span]]): - def __init__(self, repair_hierarchy: bool = True): - self.repair_hierarchy = repair_hierarchy - def adapt(self, source: Sequence[Span]) -> Tree[Span]: ... -class ToChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): ... +class ToSortedSpans(Adapter[Sequence[Span], Sequence[Span]]): + """Sort the spans with sequence ID as the primary key and start time as the secondary key.""" + def adapt(self, source: Sequence[Span]) -> Sequence[Span]: + return sorted(source, key=lambda span: (span.sequence_id, span.start_time)) -class ToAnnotations(Adapter[Sequence[Span], Sequence[Annotation]]): ... +class ToTokenInputOutputTriplet(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[TokenInputOutputTriplet]]): + """Convert annotated chat completion calls to token input-output triplets.""" -class ToTokenInputOutputTriplet(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[TokenInputOutputTriplet]]): ... + def adapt(self, source: Sequence[AnnotatedChatCompletionCall]) -> Sequence[TokenInputOutputTriplet]: ... diff --git a/agentlightning/adapter/speculation.py b/agentlightning/adapter/speculation.py new file mode 100644 index 000000000..b0864d383 --- /dev/null +++ b/agentlightning/adapter/speculation.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Adapters that are making guesses based on heuristics to fill in missing information.""" + +from __future__ import annotations + +from typing import Literal, Sequence, Tuple + +from agentlightning.types.adapter import Annotation, Tree +from agentlightning.types.tracer import Span + +from .base import Adapter + + +class FillMissingLinks(Adapter[Tuple[Sequence[Span], Sequence[Annotation]], Sequence[Span]]): + """Populate missing annotation links by searching nearby spans. + + This adapter scans annotations and, for any annotation that has no linked spans, attempts + to infer and attach link targets using a configurable search strategy. + + Typical use case: upstream extraction produced annotations (e.g., entities, citations) + but failed to attach their target spans; this adapter backfills those links based on + proximity and eligibility rules. + + Args: + require_annotation_span_child: + If True, only attempt to fill links for annotations whose *own* span is present + as a child span in the candidate span set. If False, annotations are considered + regardless of whether their span is a child. 
+ + candidate_scope: + Controls which spans are eligible as link targets: + + - "siblings": search only among sibling spans of the annotation span. + - "all": search among all spans provided to the adapter. + + scan_direction: + Determines both (a) which direction the adapter searches for candidate targets + relative to an annotation and (b) the order in which annotations are processed: + + - "backward": search earlier spans; process annotations from latest to earliest. + - "forward": search later spans; process annotations from earliest to latest. + + allow_reuse_linked_spans: + If False, spans already linked by *any* annotation are not eligible targets for + additional links (i.e., enforce a one-to-one-ish linking constraint). + If True, a span may be linked multiple times by different annotations. + """ + + def __init__( + self, + require_annotation_span_child: bool = True, + candidate_scope: Literal["siblings", "all"] = "all", + scan_direction: Literal["backward", "forward"] = "backward", + allow_reuse_linked_spans: bool = False, + ) -> None: + self.require_annotation_span_child = require_annotation_span_child + self.candidate_scope = candidate_scope + self.scan_direction = scan_direction + self.allow_reuse_linked_spans = allow_reuse_linked_spans + + def adapt(self, source: Tuple[Sequence[Span], Sequence[Annotation]]) -> Sequence[Span]: ... + + +class RepairTreeHierarchy(Adapter[Tree[Span], Tree[Span]]): + """Repair the tree hierarchy by ensuring that parent-child relationships are consistent + with span start and end times. Adding missing parent-child relationships as needed. + """ + + def adapt(self, source: Tree[Span]) -> Tree[Span]: ... diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 80e836de5..23be24d60 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -48,14 +48,24 @@ def traverse(self) -> Iterable[T]: for child in self.children: yield from child.traverse() - def count(self) -> int: - return 1 + sum(child.count() for child in self.children) + def size(self) -> int: + return 1 + sum(child.size() for child in self.children) def __iter__(self) -> Iterator[T]: return iter(self.traverse()) + def __getitem__(self, index: int) -> T: + """Get the index-th item in the tree (O(n) time complexity). + + I think this is not efficient, but it's seldomly used. 
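+
+        Indexing follows the pre-order `traverse()` order, so `tree[0]` always returns the root's item.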
+ """ + for i, item in enumerate(self.traverse()): + if i == index: + return item + raise IndexError(f"Tree index out of range: {index}") + def __len__(self) -> int: - return self.count() + return self.size() def add(self, child: Tree[T]) -> None: self.children.append(child) From 0b27bed23de73392c1cd36b0a88d40423f2c1360 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 29 Dec 2025 11:03:36 +0800 Subject: [PATCH 05/41] reorganize folder --- .../adapter/{speculation.py => annotation.py} | 42 +++++--- agentlightning/adapter/base.py | 22 ++++- agentlightning/adapter/call.py | 35 +++++++ agentlightning/adapter/consolidation.py | 99 ------------------- agentlightning/adapter/postprocess.py | 34 +++++++ .../adapter/{conversion.py => preprocess.py} | 29 +++--- 6 files changed, 133 insertions(+), 128 deletions(-) rename agentlightning/adapter/{speculation.py => annotation.py} (67%) create mode 100644 agentlightning/adapter/call.py delete mode 100644 agentlightning/adapter/consolidation.py create mode 100644 agentlightning/adapter/postprocess.py rename agentlightning/adapter/{conversion.py => preprocess.py} (56%) diff --git a/agentlightning/adapter/speculation.py b/agentlightning/adapter/annotation.py similarity index 67% rename from agentlightning/adapter/speculation.py rename to agentlightning/adapter/annotation.py index b0864d383..7cbf64c59 100644 --- a/agentlightning/adapter/speculation.py +++ b/agentlightning/adapter/annotation.py @@ -1,16 +1,44 @@ # Copyright (c) Microsoft. All rights reserved. -"""Adapters that are making guesses based on heuristics to fill in missing information.""" +"""Find and repair the annotations from spans.""" from __future__ import annotations -from typing import Literal, Sequence, Tuple +from typing import Literal, Sequence, Tuple, TypeVar -from agentlightning.types.adapter import Annotation, Tree +from agentlightning.types.adapter import Annotation from agentlightning.types.tracer import Span from .base import Adapter +T_SpanSequence = TypeVar("T_SpanSequence", bound=Sequence[Span]) + + +class CurateAnnotations(Adapter[Sequence[Span], Sequence[Annotation]]): + """Curate the annotations from the spans.""" + + def adapt(self, source: Sequence[Span]) -> Sequence[Annotation]: ... + + +class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_SpanSequence]): + """Select the corresponding spans within the annotation sequence, as well as their linked spans + (and subtree spans if applicable). + + The effective radius of an annotation is as follows: + + - If the annotation has links, it applies to the linked spans only. + - If the annotation is on a tree node, it applies to all spans in the subtree. + - If the annotation has neither links nor tree nodes, it applies to only itself. + + Args: + mode: "include" to select spans within the annotations; "exclude" to exclude them. + """ + + def __init__(self, mode: Literal["include", "exclude"]) -> None: + self.mode = mode + + def adapt(self, source: Tuple[T_SpanSequence, Sequence[Annotation]]) -> T_SpanSequence: ... + class FillMissingLinks(Adapter[Tuple[Sequence[Span], Sequence[Annotation]], Sequence[Span]]): """Populate missing annotation links by searching nearby spans. @@ -60,11 +88,3 @@ def __init__( self.allow_reuse_linked_spans = allow_reuse_linked_spans def adapt(self, source: Tuple[Sequence[Span], Sequence[Annotation]]) -> Sequence[Span]: ... 
- - -class RepairTreeHierarchy(Adapter[Tree[Span], Tree[Span]]): - """Repair the tree hierarchy by ensuring that parent-child relationships are consistent - with span start and end times. Adding missing parent-child relationships as needed. - """ - - def adapt(self, source: Tree[Span]) -> Tree[Span]: ... diff --git a/agentlightning/adapter/base.py b/agentlightning/adapter/base.py index e31a1764a..ef941ca72 100644 --- a/agentlightning/adapter/base.py +++ b/agentlightning/adapter/base.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Generic, Sequence, TypeVar +from typing import Any, Callable, Generic, Sequence, TypeVar from opentelemetry.sdk.trace import ReadableSpan @@ -79,6 +79,26 @@ def adapt_one(self, source: T_from) -> T_to: raise NotImplementedError("SequenceAdapter.adapt_one() is not implemented") +class Filter(Adapter[Sequence[T_from], Sequence[T_from]], Generic[T_from]): + """Filter items of type T to items of type T based on a predicate.""" + + def __init__(self, predicate: Callable[[T_from], bool]) -> None: + self.predicate = predicate + + def adapt(self, source: Sequence[T_from]) -> Sequence[T_from]: + return [item for item in source if self.predicate(item)] + + +class Sort(Adapter[Sequence[T_from], Sequence[T_from]], Generic[T_from]): + """Sort items of type T based on a key function.""" + + def __init__(self, key: Callable[[T_from], Any]) -> None: + self.key = key + + def adapt(self, source: Sequence[T_from]) -> Sequence[T_from]: + return sorted(source, key=self.key) + + class OtelTraceAdapter(Adapter[Sequence[ReadableSpan], T_to], Generic[T_to]): """Base class for adapters that convert OpenTelemetry trace spans into other formats. diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py new file mode 100644 index 000000000..8330dd9ef --- /dev/null +++ b/agentlightning/adapter/call.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Handles chat completion / response calls. Extracts them from spans and annotates them with annotations.""" + +from __future__ import annotations + +from typing import Sequence, Tuple + +from agentlightning.types.adapter import AnnotatedChatCompletionCall, Annotation, ChatCompletionCall +from agentlightning.types.tracer import Span + +from .base import Adapter + + +class CurateChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): + """Curate the chat completion calls from the spans.""" + + def adapt(self, source: Sequence[Span]) -> Sequence[ChatCompletionCall]: ... + + +class AnnotateChatCompletionCalls( + Adapter[Tuple[Sequence[ChatCompletionCall], Sequence[Annotation]], Sequence[AnnotatedChatCompletionCall]] +): + """Annotate chat completion calls with the given annotations. + + The intersection of "effective radius" of annotations and chat completion calls is used to determine + which annotations apply to which chat completion calls. + + If an annotation is not linked to any span, try to use `FillMissingLinks` first to link it to spans. + """ + + def adapt( + self, + source: Tuple[Sequence[ChatCompletionCall], Sequence[Annotation]], + ) -> Sequence[AnnotatedChatCompletionCall]: ... diff --git a/agentlightning/adapter/consolidation.py b/agentlightning/adapter/consolidation.py deleted file mode 100644 index 678ff1421..000000000 --- a/agentlightning/adapter/consolidation.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Filter/aggregate multiple objects into a single object. 
- -Opinionated towards which objects to keep and how to aggregate them. -""" - -from __future__ import annotations - -import re -from typing import Callable, Generic, Literal, Sequence, Tuple, TypeVar - -from agentlightning.types.adapter import ( - AccumulatedMessages, - AccumulatedTokenSequence, - AnnotatedChatCompletionCall, - Annotation, - ChatCompletionCall, - TokenInputOutputTriplet, - Tree, -) -from agentlightning.types.tracer import Span - -from .base import Adapter - -T_SpanSequence = TypeVar("T_SpanSequence", bound=Sequence[Span]) - -T_from = TypeVar("T_from") -T_to = TypeVar("T_to") - - -class CurateChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): - """Curate the chat completion calls from the spans.""" - - def adapt(self, source: Sequence[Span]) -> Sequence[ChatCompletionCall]: ... - - -class CurateAnnotations(Adapter[Sequence[Span], Sequence[Annotation]]): - """Curate the annotations from the spans.""" - - def adapt(self, source: Sequence[Span]) -> Sequence[Annotation]: ... - - -class Filter(Adapter[Sequence[T_from], Sequence[T_to]]): - """Filter items of type T_from to items of type T_to based on a predicate.""" - - def __init__(self, predicate: Callable[[T_from], bool]) -> None: - self.predicate = predicate - - def adapt(self, source: Sequence[T_from]) -> Sequence[T_to]: ... - - -class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_SpanSequence]): - """Select the corresponding spans within the annotation sequence, as well as their linked spans - (and subtree spans if applicable). - - The effective radius of an annotation is as follows: - - - If the annotation has links, it applies to the linked spans only. - - If the annotation is on a tree node, it applies to all spans in the subtree. - - If the annotation has neither links nor tree nodes, it applies to only itself. - - Args: - mode: "include" to select spans within the annotations; "exclude" to exclude them. - """ - - def __init__(self, mode: Literal["include", "exclude"]) -> None: - self.mode = mode - - def adapt(self, source: Tuple[T_SpanSequence, Sequence[Annotation]]) -> T_SpanSequence: ... - - -class AnnotateChatCompletionCalls( - Adapter[Tuple[Sequence[ChatCompletionCall], Sequence[Annotation]], Sequence[AnnotatedChatCompletionCall]] -): - """Annotate chat completion calls with the given annotations. - - The intersection of "effective radius" of annotations and chat completion calls is used to determine - which annotations apply to which chat completion calls. - - If an annotation is not linked to any span, try to use `FillMissingLinks` first to link it to spans. - """ - - def adapt( - self, - source: Tuple[Sequence[ChatCompletionCall], Sequence[Annotation]], - ) -> Sequence[AnnotatedChatCompletionCall]: ... - - -class AccumulateTokenSequence(Adapter[Sequence[TokenInputOutputTriplet], Sequence[AccumulatedTokenSequence]]): - """Assemble multiple token input-output triplets into accumulated token sequences.""" - - def adapt(self, source: Sequence[TokenInputOutputTriplet]) -> Sequence[AccumulatedTokenSequence]: ... - - -class AccumulateMessages(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[AccumulatedMessages]]): - """Assemble multiple token input-output triplets into accumulated chat messages.""" - - def adapt(self, source: Sequence[AnnotatedChatCompletionCall]) -> Sequence[AccumulatedMessages]: ... 
diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py new file mode 100644 index 000000000..0fdc6c461 --- /dev/null +++ b/agentlightning/adapter/postprocess.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Post-process the data to make it more suitable for training.""" + +from __future__ import annotations + +from typing import Sequence + +from agentlightning.types.adapter import ( + AccumulatedMessages, + AccumulatedTokenSequence, + AnnotatedChatCompletionCall, + TokenInputOutputTriplet, +) + +from .base import Adapter + + +class AccumulateTokenSequence(Adapter[Sequence[TokenInputOutputTriplet], Sequence[AccumulatedTokenSequence]]): + """Assemble multiple token input-output triplets into accumulated token sequences.""" + + def adapt(self, source: Sequence[TokenInputOutputTriplet]) -> Sequence[AccumulatedTokenSequence]: ... + + +class AccumulateMessages(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[AccumulatedMessages]]): + """Assemble multiple token input-output triplets into accumulated chat messages.""" + + def adapt(self, source: Sequence[AnnotatedChatCompletionCall]) -> Sequence[AccumulatedMessages]: ... + + +class ToTokenInputOutputTriplet(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[TokenInputOutputTriplet]]): + """Convert annotated chat completion calls to token input-output triplets.""" + + def adapt(self, source: Sequence[AnnotatedChatCompletionCall]) -> Sequence[TokenInputOutputTriplet]: ... diff --git a/agentlightning/adapter/conversion.py b/agentlightning/adapter/preprocess.py similarity index 56% rename from agentlightning/adapter/conversion.py rename to agentlightning/adapter/preprocess.py index 0c3c48d6d..1e4dd1af2 100644 --- a/agentlightning/adapter/conversion.py +++ b/agentlightning/adapter/preprocess.py @@ -1,22 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. -"""Non-opinionated conversion adapters for different data formats, without loss of information.""" +"""Span re-organization adapters.""" from __future__ import annotations -from typing import Generic, Sequence, TypeVar +from typing import Sequence, TypeVar -from agentlightning.types.adapter import ( - AccumulatedTokenSequence, - AnnotatedChatCompletionCall, - Annotation, - ChatCompletionCall, - TokenInputOutputTriplet, - Tree, -) +from agentlightning.types.adapter import Tree from agentlightning.types.tracer import Span, SpanLike -from .base import Adapter, SequenceAdapter +from .base import Adapter, SequenceAdapter, Sort T_from = TypeVar("T_from") T_to = TypeVar("T_to") @@ -50,14 +43,16 @@ class ToTree(Adapter[Sequence[Span], Tree[Span]]): def adapt(self, source: Sequence[Span]) -> Tree[Span]: ... -class ToSortedSpans(Adapter[Sequence[Span], Sequence[Span]]): +class ToSortedSpans(Sort[Span]): """Sort the spans with sequence ID as the primary key and start time as the secondary key.""" - def adapt(self, source: Sequence[Span]) -> Sequence[Span]: - return sorted(source, key=lambda span: (span.sequence_id, span.start_time)) + def __init__(self) -> None: + super().__init__(key=lambda span: (span.sequence_id, span.start_time)) -class ToTokenInputOutputTriplet(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[TokenInputOutputTriplet]]): - """Convert annotated chat completion calls to token input-output triplets.""" +class RepairTreeHierarchy(Adapter[Tree[Span], Tree[Span]]): + """Repair the tree hierarchy by ensuring that parent-child relationships are consistent + with span start and end times. 
Adding missing parent-child relationships as needed. + """ - def adapt(self, source: Sequence[AnnotatedChatCompletionCall]) -> Sequence[TokenInputOutputTriplet]: ... + def adapt(self, source: Tree[Span]) -> Tree[Span]: ... From a38dacf0d27cb62e5b41e0116f57479eea876e56 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 29 Dec 2025 13:19:23 +0800 Subject: [PATCH 06/41] update repair hierarchy --- agentlightning/adapter/preprocess.py | 212 ++++++++++++++++++++++++++- agentlightning/types/tracer.py | 12 ++ 2 files changed, 222 insertions(+), 2 deletions(-) diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index 1e4dd1af2..e1ddf6db5 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -4,7 +4,10 @@ from __future__ import annotations -from typing import Sequence, TypeVar +import logging +import time +from collections import defaultdict +from typing import Dict, List, Literal, Sequence, Set, TypeVar from agentlightning.types.adapter import Tree from agentlightning.types.tracer import Span, SpanLike @@ -14,8 +17,11 @@ T_from = TypeVar("T_from") T_to = TypeVar("T_to") +logger = logging.getLogger(__name__) + class ToSpans(SequenceAdapter[SpanLike, Span]): + """Normalize the span-like objects (e.g., OpenTelemetry `ReadableSpan`) to [spans][agentlightning.Span].""" def __init__( self, @@ -40,7 +46,171 @@ def adapt_one(self, source: SpanLike) -> Span: class ToTree(Adapter[Sequence[Span], Tree[Span]]): - def adapt(self, source: Sequence[Span]) -> Tree[Span]: ... + def __init__( + self, + repair_bad_hierarchy: Literal["dangling", "all", "none"] = "dangling", + repair_missing_parents: bool = True, + ): + self.repair_bad_hierarchy = repair_bad_hierarchy + self.repair_missing_parents = repair_missing_parents + + def _validate_tree(self, graph: Dict[str, List[str]], root_ids: Set[str]) -> None: + visited = set[str]() + + def visit(node_id: str) -> None: + if node_id in visited: + raise ValueError(f"Cycle detected in the tree: {node_id}") + visited.add(node_id) + for child_id in graph[node_id]: + visit(child_id) + + for root_id in root_ids: + visit(root_id) + + if len(visited) != len(graph): + raise ValueError(f"Some spans are not reachable from the roots: {set(graph.keys()) - visited}") + + def _compute_depths(self, graph: Dict[str, List[str]], root_ids: Set[str]) -> Dict[str, int]: + depths = {root: 0 for root in root_ids} + + def visit(node_id: str) -> None: + for child_id in graph[node_id]: + depths[child_id] = depths[node_id] + 1 + visit(child_id) + + for root_id in root_ids: + visit(root_id) + + return depths + + def _compute_ancestors(self, graph: Dict[str, List[str]], root_ids: Set[str]) -> Dict[str, Set[str]]: + ancestors = {root: set[str]() for root in root_ids} + + def visit(node_id: str) -> None: + for child_id in graph[node_id]: + ancestors[child_id] = ancestors[node_id] | {node_id} + visit(child_id) + + for root_id in root_ids: + visit(root_id) + + return ancestors + + def _find_eligible_parents( + self, + all_spans: Sequence[Span], + current: Span, + forward_graph: Dict[str, List[str]], + root_ids: Set[str], + depths: Dict[str, int], + parent_ids: Dict[str, str], + ) -> List[Span]: + """We wish to find a good place to insert the span, which is ideally it's sibling or sibling's child. + + Filter the candidates by: (1) must not in current's ancestors; (2) must not be in current's subtree; + (3) must have an ancestor that is the parent of current span; (4) have start and end time covering the current span. 
+ The third condition can be optional if current span has no parent. + + Then sort the candidates by: (1) shortest to longest, (2) deep to shallow. + """ + spans_to_consider: List[Span] = [] + # This needs to be re-computed every time. + # It will be too troublesome to maintain the dynamic cache of ancestors. + ancestors = self._compute_ancestors(forward_graph, root_ids) + + for candidate_parent in all_spans: + if candidate_parent.span_id == current.span_id: + continue + if ( + candidate_parent.ensure_start_time() > current.ensure_start_time() + or candidate_parent.ensure_end_time() < current.ensure_end_time() + ): + # If the span is not covering the current span, it cannot be a parent. + continue + if candidate_parent.span_id in ancestors[current.span_id]: + # If the span is in the current's ancestors, it cannot be a parent. + continue + if current.span_id in ancestors[candidate_parent.span_id]: + # If the span is in the current's subtree, it cannot be a parent. + continue + if current.span_id in parent_ids: + # If the current span has a parent, the eligible parent must live in the parent's subtree. + if parent_ids[current.span_id] not in ancestors[candidate_parent.span_id]: + continue + spans_to_consider.append(candidate_parent) + + # Sort the spans + return sorted( + spans_to_consider, + key=lambda span: (span.ensure_end_time() - span.ensure_start_time(), depths[span.span_id]), + ) + + def _repair_bad_hierarchy( + self, + source: Sequence[Span], + forward_graph: Dict[str, List[str]], + root_ids: Set[str], + parent_ids: Dict[str, str], + ) -> Sequence[Span]: + depths = self._compute_depths(forward_graph, root_ids) + + # Scan all the spans by: (1) shallow to deep, (2) longest to shortest, (3) earliest to latest. + scan_order = sorted( + source, + key=lambda span: ( + depths[span.span_id], + span.ensure_end_time() - span.ensure_start_time(), + span.ensure_start_time(), + ), + ) + for i, span in enumerate(scan_order): + # Check whether we should repair this span. + # It must be a dangling span, or the user wants to repair all the spans. + if ( + self.repair_bad_hierarchy == "dangling" and span.span_id not in parent_ids + ) or self.repair_bad_hierarchy == "all": + # We wish to find a good place to insert the span. 
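+                # The candidates come back sorted by duration (then depth), so the first
+                # one is the tightest span that still covers the current span.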
+ eligible_parents = self._find_eligible_parents( + source, span, forward_graph, root_ids, depths, parent_ids + ) + if eligible_parents: + original_parent_id = parent_ids.get(span.span_id, None) + new_parent_id = eligible_parents[0].span_id + scan_order[i] = span.model_copy(update={"parent_id": new_parent_id}) + # Maintain the cache + parent_ids[span.span_id] = new_parent_id + if original_parent_id is not None: + forward_graph[original_parent_id].remove(span.span_id) + forward_graph[new_parent_id].append(span.span_id) + + return scan_order + + def adapt(self, source: Sequence[Span]) -> Tree[Span]: + if not isinstance(source, Sequence): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError(f"Expected a sequence of spans, but got {type(source)}") + if not source: + raise ValueError("No spans provided to create Tree") + + valid_span_ids = set(span.span_id for span in source) + root_ids = set(span.span_id for span in source if span.parent_id is None) + forward_graph: Dict[str, List[str]] = defaultdict(list) + parent_ids: Dict[str, str] = {} + for span in source: + if span.parent_id is not None: + if span.parent_id in valid_span_ids: + forward_graph[span.parent_id].append(span.span_id) + root_ids.discard(span.span_id) + parent_ids[span.span_id] = span.parent_id + else: + logger.debug( + f'Span {span.span_id} has an invalid parent ID "{span.parent_id}". The parent will be ignored.' + ) + + self._validate_tree(forward_graph, root_ids) + + source = self._repair_bad_hierarchy(source, forward_graph, root_ids, parent_ids) + + return Tree(source[0], [Tree(span, []) for span in source if span.parent_id is not None]) class ToSortedSpans(Sort[Span]): @@ -50,6 +220,44 @@ def __init__(self) -> None: super().__init__(key=lambda span: (span.sequence_id, span.start_time)) +class RepairTime(Adapter[Sequence[Span], Sequence[Span]]): + """Repair the end time of the spans by: + + 1. Ensuring the end time is greater than the start time. + 2. Fill the spans with no end time to be the maximum start/end time of all spans. + """ + + def adapt(self, source: Sequence[Span]) -> Sequence[Span]: + times_set = set[float]() + for span in source: + if span.start_time is not None: + times_set.add(span.start_time) + if span.end_time is not None: + times_set.add(span.end_time) + + if not times_set: + logger.debug("No times set in the spans. Setting all the time to current time.") + current_time = time.time() + else: + current_time = max(times_set) + + new_spans: List[Span] = [] + + for span in source: + update_fields: Dict[str, float] = {} + if span.start_time is None: + update_fields["start_time"] = current_time + if span.end_time is None: + update_fields["end_time"] = current_time + if span.start_time is not None and span.end_time is not None and span.end_time < span.start_time: + update_fields["end_time"] = span.start_time + if update_fields: + new_spans.append(span.model_copy(update=update_fields)) + else: + new_spans.append(span) + return new_spans + + class RepairTreeHierarchy(Adapter[Tree[Span], Tree[Span]]): """Repair the tree hierarchy by ensuring that parent-child relationships are consistent with span start and end times. Adding missing parent-child relationships as needed. 
diff --git a/agentlightning/types/tracer.py b/agentlightning/types/tracer.py index 2db9a2f48..e7825276e 100644 --- a/agentlightning/types/tracer.py +++ b/agentlightning/types/tracer.py @@ -477,6 +477,18 @@ def from_core_fields( status=core.status, ) + def ensure_start_time(self) -> float: + """Get the start time or raise an error if it's not set.""" + if self.start_time is None: + raise ValueError("Start time is not set. Try to use the `RepairTime` adapter to repair the span.") + return self.start_time + + def ensure_end_time(self) -> float: + """Get the end time or raise an error if it's not set.""" + if self.end_time is None: + raise ValueError("End time is not set. Try to use the `RepairTime` adapter to repair the span.") + return self.end_time + class SpanNames(str, Enum): """Enumerated span names recognised by Agent-lightning. Deprecated in favor of [semconv][agentlightning.semconv].""" From 62d5581a3683757d53b8b6ee171a01346cf1691e Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 29 Dec 2025 13:19:53 +0800 Subject: [PATCH 07/41] . --- agentlightning/adapter/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index e1ddf6db5..a6d53a7a3 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -159,7 +159,7 @@ def _repair_bad_hierarchy( source, key=lambda span: ( depths[span.span_id], - span.ensure_end_time() - span.ensure_start_time(), + -(span.ensure_end_time() - span.ensure_start_time()), span.ensure_start_time(), ), ) From 70efeb5c470c78d545bb17dc372526e9cbf0b04a Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 29 Dec 2025 18:26:58 +0800 Subject: [PATCH 08/41] fix tree --- agentlightning/adapter/preprocess.py | 322 +++++++++++++++++++-------- 1 file changed, 226 insertions(+), 96 deletions(-) diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index a6d53a7a3..8b73f1ebf 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -7,10 +7,12 @@ import logging import time from collections import defaultdict -from typing import Dict, List, Literal, Sequence, Set, TypeVar +from typing import Dict, List, Literal, Sequence, Set, Tuple, TypeVar +from agentlightning.semconv import AGL_VIRTUAL from agentlightning.types.adapter import Tree from agentlightning.types.tracer import Span, SpanLike +from agentlightning.utils.id import generate_id from .base import Adapter, SequenceAdapter, Sort @@ -20,6 +22,107 @@ logger = logging.getLogger(__name__) +class _TreeLikeGraph: + """A simple directed graph implementation for span hierarchy. + + In preparation before forming a tree. 
+ """ + + def __init__(self) -> None: + self.forward_graph: Dict[str, List[str]] = defaultdict(list) + self.parent_map: Dict[str, str] = {} + self.root_ids: Set[str] = set() + + def add_edge(self, from_node: str, to_node: str) -> None: + self.forward_graph[from_node].append(to_node) + + def move_subtree(self, node_id: str, new_parent_id: str) -> None: + old_parent_id = self.parent_map.get(node_id, None) + if old_parent_id is not None: + self.forward_graph[old_parent_id].remove(node_id) + self.add_edge(new_parent_id, node_id) + self.parent_map[node_id] = new_parent_id + if node_id in self.root_ids: + self.root_ids.remove(node_id) + + def validate_no_cycles(self) -> None: + visited = set[str]() + + def visit(node_id: str) -> None: + if node_id in visited: + raise ValueError(f"Cycle detected in the graph at node {node_id}") + visited.add(node_id) + for child_id in self.forward_graph[node_id]: + visit(child_id) + + for root_id in self.root_ids: + visit(root_id) + + if len(visited) != len(self.forward_graph): + raise ValueError("Some nodes are not reachable from the roots") + + def compute_depths(self) -> Dict[str, int]: + depths = {root: 0 for root in self.root_ids} + + def visit(node_id: str) -> None: + for child_id in self.forward_graph[node_id]: + depths[child_id] = depths[node_id] + 1 + visit(child_id) + + for root_id in self.root_ids: + visit(root_id) + + return depths + + def compute_ancestors(self) -> Dict[str, Set[str]]: + ancestors = {root: set[str]() for root in self.root_ids} + + def visit(node_id: str) -> None: + for child_id in self.forward_graph[node_id]: + ancestors[child_id] = ancestors[node_id] | {node_id} + visit(child_id) + + for root_id in self.root_ids: + visit(root_id) + + return ancestors + + def to_tree(self, spans: Sequence[Span]) -> Tree[Span]: + spans_dict = {span.span_id: span for span in spans} + + def build_subtree(node_id: str) -> Tree[Span]: + children = [build_subtree(child_id) for child_id in self.forward_graph.get(node_id, [])] + return Tree(spans_dict[node_id], children) + + if len(self.root_ids) != 1: + raise ValueError( + "Cannot convert to tree: multiple or no roots found; enable repair options of ToTree to fix this." + ) + root_id = next(iter(self.root_ids)) + return build_subtree(root_id) + + @staticmethod + def from_spans(spans: Sequence[Span], logs_invalid_parent: bool = True) -> _TreeLikeGraph: + graph = _TreeLikeGraph() + + valid_span_ids = set(span.span_id for span in spans) + graph.root_ids = set(span.span_id for span in spans if span.parent_id is None) + for span in spans: + if span.parent_id is not None: + if span.parent_id in valid_span_ids: + graph.forward_graph[span.parent_id].append(span.span_id) + graph.root_ids.discard(span.span_id) + graph.parent_map[span.span_id] = span.parent_id + elif logs_invalid_parent: + logger.debug( + f'Span {span.span_id} has an invalid parent ID "{span.parent_id}". The parent will be ignored.' 
+ ) + + graph.validate_no_cycles() + + return graph + + class ToSpans(SequenceAdapter[SpanLike, Span]): """Normalize the span-like objects (e.g., OpenTelemetry `ReadableSpan`) to [spans][agentlightning.Span].""" @@ -50,60 +153,18 @@ def __init__( self, repair_bad_hierarchy: Literal["dangling", "all", "none"] = "dangling", repair_missing_parents: bool = True, + repair_multiple_roots: bool = True, ): self.repair_bad_hierarchy = repair_bad_hierarchy self.repair_missing_parents = repair_missing_parents - - def _validate_tree(self, graph: Dict[str, List[str]], root_ids: Set[str]) -> None: - visited = set[str]() - - def visit(node_id: str) -> None: - if node_id in visited: - raise ValueError(f"Cycle detected in the tree: {node_id}") - visited.add(node_id) - for child_id in graph[node_id]: - visit(child_id) - - for root_id in root_ids: - visit(root_id) - - if len(visited) != len(graph): - raise ValueError(f"Some spans are not reachable from the roots: {set(graph.keys()) - visited}") - - def _compute_depths(self, graph: Dict[str, List[str]], root_ids: Set[str]) -> Dict[str, int]: - depths = {root: 0 for root in root_ids} - - def visit(node_id: str) -> None: - for child_id in graph[node_id]: - depths[child_id] = depths[node_id] + 1 - visit(child_id) - - for root_id in root_ids: - visit(root_id) - - return depths - - def _compute_ancestors(self, graph: Dict[str, List[str]], root_ids: Set[str]) -> Dict[str, Set[str]]: - ancestors = {root: set[str]() for root in root_ids} - - def visit(node_id: str) -> None: - for child_id in graph[node_id]: - ancestors[child_id] = ancestors[node_id] | {node_id} - visit(child_id) - - for root_id in root_ids: - visit(root_id) - - return ancestors + self.repair_multiple_roots = repair_multiple_roots def _find_eligible_parents( self, all_spans: Sequence[Span], current: Span, - forward_graph: Dict[str, List[str]], - root_ids: Set[str], - depths: Dict[str, int], - parent_ids: Dict[str, str], + graph: _TreeLikeGraph, + cache_depths: Dict[str, int], ) -> List[Span]: """We wish to find a good place to insert the span, which is ideally it's sibling or sibling's child. @@ -116,7 +177,7 @@ def _find_eligible_parents( spans_to_consider: List[Span] = [] # This needs to be re-computed every time. # It will be too troublesome to maintain the dynamic cache of ancestors. - ancestors = self._compute_ancestors(forward_graph, root_ids) + ancestors = graph.compute_ancestors() for candidate_parent in all_spans: if candidate_parent.span_id == current.span_id: @@ -133,26 +194,25 @@ def _find_eligible_parents( if current.span_id in ancestors[candidate_parent.span_id]: # If the span is in the current's subtree, it cannot be a parent. continue - if current.span_id in parent_ids: + if current.span_id in graph.parent_map: # If the current span has a parent, the eligible parent must live in the parent's subtree. 
- if parent_ids[current.span_id] not in ancestors[candidate_parent.span_id]: + if graph.parent_map[current.span_id] not in ancestors[candidate_parent.span_id]: continue spans_to_consider.append(candidate_parent) # Sort the spans return sorted( spans_to_consider, - key=lambda span: (span.ensure_end_time() - span.ensure_start_time(), depths[span.span_id]), + key=lambda span: (span.ensure_end_time() - span.ensure_start_time(), cache_depths[span.span_id]), ) - def _repair_bad_hierarchy( - self, - source: Sequence[Span], - forward_graph: Dict[str, List[str]], - root_ids: Set[str], - parent_ids: Dict[str, str], - ) -> Sequence[Span]: - depths = self._compute_depths(forward_graph, root_ids) + def _repair_bad_hierarchy(self, source: Sequence[Span]) -> Sequence[Span]: + """Repair bad hierarchy by re-attaching dangling spans or all spans. + + This is based on the chronological relationships between start time and end time of spans. + """ + graph = _TreeLikeGraph.from_spans(source) + depths = graph.compute_depths() # Scan all the spans by: (1) shallow to deep, (2) longest to shortest, (3) earliest to latest. scan_order = sorted( @@ -167,50 +227,91 @@ def _repair_bad_hierarchy( # Check whether we should repair this span. # It must be a dangling span, or the user wants to repair all the spans. if ( - self.repair_bad_hierarchy == "dangling" and span.span_id not in parent_ids + self.repair_bad_hierarchy == "dangling" and span.span_id not in graph.parent_map ) or self.repair_bad_hierarchy == "all": # We wish to find a good place to insert the span. - eligible_parents = self._find_eligible_parents( - source, span, forward_graph, root_ids, depths, parent_ids - ) + eligible_parents = self._find_eligible_parents(source, span, graph, depths) if eligible_parents: - original_parent_id = parent_ids.get(span.span_id, None) new_parent_id = eligible_parents[0].span_id scan_order[i] = span.model_copy(update={"parent_id": new_parent_id}) - # Maintain the cache - parent_ids[span.span_id] = new_parent_id - if original_parent_id is not None: - forward_graph[original_parent_id].remove(span.span_id) - forward_graph[new_parent_id].append(span.span_id) + + # Maintain/update the cache + graph.move_subtree(span.span_id, new_parent_id) return scan_order + def _repair_missing_parents(self, source: Sequence[Span]) -> Sequence[Span]: + valid_span_ids: Set[str] = set(span.span_id for span in source) + parent_to_children: Dict[str, List[Span]] = defaultdict(list) + + for span in source: + if span.parent_id is not None and span.parent_id not in valid_span_ids: + parent_to_children[span.parent_id].append(span) + + created_spans: List[Span] = [] + for parent_id, children in parent_to_children.items(): + child = children[0] + # Create a virtual span for the missing parent + created_spans.append( + Span.from_attributes( + rollout_id=child.rollout_id, + attempt_id=child.attempt_id, + sequence_id=child.sequence_id, + trace_id=child.trace_id, + span_id=parent_id, + parent_id=None, + name=AGL_VIRTUAL, + attributes={}, + start_time=min(child.ensure_start_time() for child in children), + end_time=max(child.ensure_end_time() for child in children), + ) + ) + + return [*source, *created_spans] + + def _repair_multiple_roots(self, source: Sequence[Span]) -> Sequence[Span]: + root_spans = [span for span in source if span.parent_id is None] + + if len(root_spans) <= 1: + return source + + # Create a new root span + new_root_span = Span.from_attributes( + rollout_id=root_spans[0].rollout_id, + attempt_id=root_spans[0].attempt_id, + 
sequence_id=root_spans[0].sequence_id, + trace_id=root_spans[0].trace_id, + span_id="span-" + generate_id(12), + parent_id=None, + name=AGL_VIRTUAL, + attributes={}, + start_time=min(span.ensure_start_time() for span in root_spans), + end_time=max(span.ensure_end_time() for span in root_spans), + ) + + updated_spans = [ + span.model_copy(update={"parent_id": new_root_span.span_id}) if span in root_spans else span + for span in source + ] + return [new_root_span, *updated_spans] + def adapt(self, source: Sequence[Span]) -> Tree[Span]: if not isinstance(source, Sequence): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError(f"Expected a sequence of spans, but got {type(source)}") if not source: raise ValueError("No spans provided to create Tree") - valid_span_ids = set(span.span_id for span in source) - root_ids = set(span.span_id for span in source if span.parent_id is None) - forward_graph: Dict[str, List[str]] = defaultdict(list) - parent_ids: Dict[str, str] = {} - for span in source: - if span.parent_id is not None: - if span.parent_id in valid_span_ids: - forward_graph[span.parent_id].append(span.span_id) - root_ids.discard(span.span_id) - parent_ids[span.span_id] = span.parent_id - else: - logger.debug( - f'Span {span.span_id} has an invalid parent ID "{span.parent_id}". The parent will be ignored.' - ) - - self._validate_tree(forward_graph, root_ids) + source = self._repair_bad_hierarchy(source) - source = self._repair_bad_hierarchy(source, forward_graph, root_ids, parent_ids) + # The other repairing steps should be done *after* repairing the bad hierarchy because + # some problems might be fixed during the bad hierarchy repairing. + if self.repair_missing_parents: + source = self._repair_missing_parents(source) + # repair missing parents might have introduced new roots. + if self.repair_multiple_roots: + source = self._repair_multiple_roots(source) - return Tree(source[0], [Tree(span, []) for span in source if span.parent_id is not None]) + return _TreeLikeGraph.from_spans(source).to_tree(source) class ToSortedSpans(Sort[Span]): @@ -227,6 +328,34 @@ class RepairTime(Adapter[Sequence[Span], Sequence[Span]]): 2. Fill the spans with no end time to be the maximum start/end time of all spans. 
""" + def __init__(self, ensure_positive_duration: bool = True, ensure_proper_nesting: bool = True) -> None: + self.ensure_positive_duration = ensure_positive_duration + self.ensure_proper_nesting = ensure_proper_nesting + + def _repair_nesting(self, source: Sequence[Span]) -> Sequence[Span]: + graph = _TreeLikeGraph.from_spans(source, logs_invalid_parent=False) + spans = {span.span_id: span for span in source} + + def visit(node_id: str) -> Tuple[float, float]: + child_start_end_times: List[Tuple[float, float]] = [] + cur_start_time = spans[node_id].ensure_start_time() + cur_end_time = spans[node_id].ensure_end_time() + if graph.forward_graph.get(node_id): + for child_id in graph.forward_graph[node_id]: + child_start_end_times.append(visit(child_id)) + start_times, end_times = zip(*child_start_end_times) + start_time = min(cur_start_time, *start_times) + end_time = max(cur_end_time, *end_times) + if start_time != cur_start_time or end_time != cur_end_time: + spans[node_id] = spans[node_id].model_copy(update={"start_time": start_time, "end_time": end_time}) + + return spans[node_id].ensure_start_time(), spans[node_id].ensure_end_time() + + for root_id in graph.root_ids: + visit(root_id) + + return [spans[span.span_id] for span in source] + def adapt(self, source: Sequence[Span]) -> Sequence[Span]: times_set = set[float]() for span in source: @@ -249,18 +378,19 @@ def adapt(self, source: Sequence[Span]) -> Sequence[Span]: update_fields["start_time"] = current_time if span.end_time is None: update_fields["end_time"] = current_time - if span.start_time is not None and span.end_time is not None and span.end_time < span.start_time: + if ( + self.ensure_positive_duration + and span.start_time is not None + and span.end_time is not None + and span.end_time < span.start_time + ): update_fields["end_time"] = span.start_time if update_fields: new_spans.append(span.model_copy(update=update_fields)) else: new_spans.append(span) - return new_spans - -class RepairTreeHierarchy(Adapter[Tree[Span], Tree[Span]]): - """Repair the tree hierarchy by ensuring that parent-child relationships are consistent - with span start and end times. Adding missing parent-child relationships as needed. - """ - - def adapt(self, source: Tree[Span]) -> Tree[Span]: ... 
+ if self.ensure_proper_nesting: + return self._repair_nesting(new_spans) + else: + return new_spans From 8b2963e63ed7a4a00add2eeb652d674d32520aca Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 30 Dec 2025 12:10:15 +0800 Subject: [PATCH 09/41] implement CurateAnnotations --- agentlightning/adapter/annotation.py | 221 ++++++++++++++++++++++++++- agentlightning/types/adapter.py | 8 +- 2 files changed, 222 insertions(+), 7 deletions(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index 7cbf64c59..aa6a847da 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -4,20 +4,235 @@ from __future__ import annotations -from typing import Literal, Sequence, Tuple, TypeVar +import logging +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, TypeVar, cast -from agentlightning.types.adapter import Annotation +from opentelemetry.semconv.attributes import exception_attributes + +from agentlightning.emitter.message import get_message_value +from agentlightning.emitter.object import get_object_value +from agentlightning.emitter.reward import get_rewards_from_span +from agentlightning.semconv import ( + AGL_ANNOTATION, + AGL_EXCEPTION, + AGL_MESSAGE, + AGL_OBJECT, + AGL_OPERATION, + LightningSpanAttributes, + LinkPydanticModel, +) +from agentlightning.types.adapter import ( + AgentAnnotation, + Annotation, + ExceptionAnnotation, + GeneralAnnotation, + MessageAnnotation, + ObjectAnnotation, + OperationAnnotation, +) from agentlightning.types.tracer import Span +from agentlightning.utils.otel import ( + extract_links_from_attributes, + extract_tags_from_attributes, + filter_and_unflatten_attributes, +) from .base import Adapter T_SpanSequence = TypeVar("T_SpanSequence", bound=Sequence[Span]) +logger = logging.getLogger(__name__) + class CurateAnnotations(Adapter[Sequence[Span], Sequence[Annotation]]): """Curate the annotations from the spans.""" - def adapt(self, source: Sequence[Span]) -> Sequence[Annotation]: ... 
+ def _filter_custom_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]: + reserved_fields = [attr.value for attr in LightningSpanAttributes if attr.value in attributes] + return { + key: value + for key, value in attributes.items() + if not any( + # Filter out those that are reserved fields or start with reserved fields (plus ".") + key == reserved_field or key.startswith(reserved_field + ".") + for reserved_field in reserved_fields + ) + } + + def extract_links(self, span: Span) -> Sequence[LinkPydanticModel]: + try: + return extract_links_from_attributes(span.attributes) + except Exception as exc: + logger.error(f"Link is malformed for span {span.span_id}: {exc}") + return [] + + def adapt_general(self, span: Span) -> Optional[GeneralAnnotation]: + rewards = get_rewards_from_span(span) + primary_reward = rewards[0].value if rewards else None + return GeneralAnnotation( + annotation_type="general", + span_id=span.span_id, + links=self.extract_links(span), + rewards=rewards, + primary_reward=primary_reward, + tags=extract_tags_from_attributes(span.attributes), + custom_fields=self._filter_custom_attributes(span.attributes), + ) + + def adapt_message(self, span: Span) -> Optional[MessageAnnotation]: + msg_body = get_message_value(span) + if msg_body is None: + logger.warning(f"Message body is missing for message span {span.span_id}") + return None + + return MessageAnnotation( + annotation_type="message", + span_id=span.span_id, + links=self.extract_links(span), + message=msg_body, + ) + + def adapt_object(self, span: Span) -> Optional[ObjectAnnotation]: + try: + obj_value = get_object_value(span) + except Exception as exc: + logger.error(f"Fail to deserialize object for object span {span.span_id}: {exc}") + return None + + return ObjectAnnotation( + annotation_type="object", + span_id=span.span_id, + links=self.extract_links(span), + object=obj_value, + ) + + def adapt_exception(self, span: Span) -> Optional[ExceptionAnnotation]: + exception_type = span.attributes.get(exception_attributes.EXCEPTION_TYPE, "UnknownException") + exception_message = span.attributes.get(exception_attributes.EXCEPTION_MESSAGE, "") + exception_stacktrace = span.attributes.get(exception_attributes.EXCEPTION_STACKTRACE, "") + + return ExceptionAnnotation( + annotation_type="exception", + span_id=span.span_id, + links=self.extract_links(span), + type=str(exception_type), + message=str(exception_message), + stacktrace=str(exception_stacktrace), + ) + + def adapt_operation(self, span: Span) -> Optional[OperationAnnotation]: + try: + operation_name = span.attributes.get(LightningSpanAttributes.OPERATION_NAME.value, "UnknownOperation") + if LightningSpanAttributes.OPERATION_INPUT.value in span.attributes: + operation_input = span.attributes[LightningSpanAttributes.OPERATION_INPUT.value] + else: + operation_input = filter_and_unflatten_attributes( + span.attributes, LightningSpanAttributes.OPERATION_INPUT.value + ) + if LightningSpanAttributes.OPERATION_OUTPUT.value in span.attributes: + operation_output = span.attributes[LightningSpanAttributes.OPERATION_OUTPUT.value] + else: + operation_output = filter_and_unflatten_attributes( + span.attributes, LightningSpanAttributes.OPERATION_OUTPUT.value + ) + except Exception as exc: + logger.error(f"Fail to unpack operation context for operation span {span.span_id}: {exc}") + return None + + return OperationAnnotation( + annotation_type="operation", + span_id=span.span_id, + links=self.extract_links(span), + name=str(operation_name), + input=operation_input, + 
output=operation_output, + ) + + def extract_agent_id(self, span: Span) -> Optional[str]: + # TODO: Support agent id in other formats + return cast(Optional[str], span.attributes.get("agent.id")) + + def extract_agent_description(self, span: Span) -> Optional[str]: + # TODO: Support agent description in other formats + return cast(Optional[str], span.attributes.get("agent.description")) + + def extract_agent_name(self, span: Span) -> Optional[str]: + # Case 1: OpenTelemetry Agent Spans + agent_name = cast(Optional[str], span.attributes.get("agent.name")) + if agent_name is not None: + return agent_name + + # Case 2: Agentops decorator @agent + is_agent = span.attributes.get("agentops.span.kind") == "agent" + if is_agent: + agent_name = cast(Optional[str], span.attributes.get("operation.name")) + if agent_name is not None: + return agent_name + + # Case 3: Autogen team + agent_name = cast(Optional[str], span.attributes.get("recipient_agent_type")) + if agent_name is not None: + return agent_name + + # Case 4: LangGraph + agent_name = cast(Optional[str], span.attributes.get("langchain.chain.type")) + if agent_name is not None: + return agent_name + + # Case 5: agent-framework + agent_name = cast(Optional[str], span.attributes.get("executor.id")) + if agent_name is not None: + return agent_name + + # Case 6: Weave + is_agent_type = span.attributes.get("type") == "agent" + if is_agent_type: + agent_name = cast(Optional[str], span.attributes.get("agentlightning.operation.input.name")) + if agent_name is not None: + return agent_name + + # Case 7: Weave + LangChain + if span.name.startswith("langchain.Chain."): + attributes_lc_name = cast(Optional[str], span.attributes.get("lc_name")) + if attributes_lc_name is not None: + return attributes_lc_name + + def detect_agent_annotation(self, span: Span) -> Optional[AgentAnnotation]: + agent_id = self.extract_agent_id(span) + agent_name = self.extract_agent_name(span) + agent_description = self.extract_agent_description(span) + + if agent_name is not None: + return AgentAnnotation( + annotation_type="agent", + span_id=span.span_id, + links=self.extract_links(span), + id=agent_id, + name=agent_name, + description=agent_description, + ) + return None + + def adapt(self, source: Sequence[Span]) -> Sequence[Annotation]: + annotations: List[Annotation] = [] + for span in source: + annotation: Optional[Annotation] = None + if span.name == AGL_ANNOTATION: + annotation = self.adapt_general(span) + elif span.name == AGL_MESSAGE: + annotation = self.adapt_message(span) + elif span.name == AGL_OBJECT: + annotation = self.adapt_object(span) + elif span.name == AGL_EXCEPTION: + annotation = self.adapt_exception(span) + elif span.name == AGL_OPERATION: + annotation = self.adapt_operation(span) + else: + # Fallback to agent annotation detection + annotation = self.detect_agent_annotation(span) + if annotation is not None: + annotations.append(annotation) + return annotations class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_SpanSequence]): diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 23be24d60..1b4125370 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -26,7 +26,7 @@ ) from pydantic import BaseModel, Field -from agentlightning.semconv import LinkPydanticModel +from agentlightning.semconv import LinkPydanticModel, RewardPydanticModel from .tracer import Attributes @@ -113,7 +113,7 @@ class Annotation(BaseModel): span_id: str """Span ID of the annotation span. 
Not necessarily an [AGL_ANNOTATION][agentlightning.semconv.AGL_ANNOTATION] span.""" - links: Optional[Sequence[LinkPydanticModel]] = None + links: Sequence[LinkPydanticModel] = Field(default_factory=list[LinkPydanticModel]) """Links to other spans or objects.""" @@ -139,13 +139,13 @@ class GeneralAnnotation(Annotation): annotation_type = "general" """Type of the annotation.""" - reward: Dict[str, float] = Field(default_factory=dict) + rewards: Sequence[RewardPydanticModel] = Field(default_factory=list[RewardPydanticModel]) """Reward dimensions and values.""" primary_reward: Optional[float] = None """Primary reward value.""" - tag: Sequence[str] = Field(default_factory=list) + tags: Sequence[str] = Field(default_factory=list) """Tags for the annotation.""" custom_fields: Dict[str, Any] = Field(default_factory=dict) From d81c6dcb9b22efd17df51b260a25f60fd9ede82a Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 30 Dec 2025 15:05:01 +0800 Subject: [PATCH 10/41] . --- agentlightning/adapter/annotation.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index aa6a847da..68265bce3 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -242,11 +242,18 @@ class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_ The effective radius of an annotation is as follows: - If the annotation has links, it applies to the linked spans only. - - If the annotation is on a tree node, it applies to all spans in the subtree. + - If the annotation is on a tree node, it applies to all spans in its subtree. - If the annotation has neither links nor tree nodes, it applies to only itself. + The adapter either selects the union of the effective radius of all annotations, + or excludes the union of effective radius. + + When the source is a tree, to avoid the tree nodes from becoming fragmented, + the adapter will also include the ancestors of the tree nodes in "include" mode. + Args: mode: "include" to select spans within the annotations; "exclude" to exclude them. + """ def __init__(self, mode: Literal["include", "exclude"]) -> None: From 3747315df2bce60df9cd9cd5b55437d5757bfba0 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Sun, 4 Jan 2026 14:16:48 +0800 Subject: [PATCH 11/41] . 
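This commit also fleshes out Tree so it can be consumed as a sequence and filtered with retain()/prune(), which SelectByAnnotation then builds on. A minimal sketch of the two filters, assuming the Tree API added in this series (the integer payloads and the even() predicate are invented purely for illustration):

    from agentlightning.types.adapter import Tree

    def even(x: int) -> bool:
        return x % 2 == 0

    # A small tree:  1 -> (2 -> 4), 3
    tree = Tree(1, [Tree(2, [Tree(4, [])]), Tree(3, [])])

    # retain() keeps whole subtrees whose root satisfies the predicate
    # (plus the overall root), so node 2 survives together with its child 4.
    print(list(tree.retain(even).traverse()))  # [1, 2, 4]

    # prune() drops children that satisfy the predicate along with their
    # subtrees, keeping the root, so only the odd child 3 is left.
    print(list(tree.prune(even).traverse()))   # [1, 3]

In "include" mode, SelectByAnnotation uses retain() in exactly this way to keep subtrees rooted at annotation spans or at spans they link to.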
--- agentlightning/adapter/annotation.py | 41 ++++++++++++--- agentlightning/adapter/call.py | 10 ++-- agentlightning/types/adapter.py | 42 +++++++++++++--- agentlightning/utils/otel.py | 75 ++++++++++++++-------------- 4 files changed, 113 insertions(+), 55 deletions(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index 68265bce3..fd49d2c8f 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, TypeVar, cast +from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, cast from opentelemetry.semconv.attributes import exception_attributes @@ -29,9 +29,11 @@ MessageAnnotation, ObjectAnnotation, OperationAnnotation, + Tree, ) from agentlightning.types.tracer import Span from agentlightning.utils.otel import ( + check_linked_span, extract_links_from_attributes, extract_tags_from_attributes, filter_and_unflatten_attributes, @@ -259,10 +261,31 @@ class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_ def __init__(self, mode: Literal["include", "exclude"]) -> None: self.mode = mode - def adapt(self, source: Tuple[T_SpanSequence, Sequence[Annotation]]) -> T_SpanSequence: ... + def _filter_linked_spans(self, source: Sequence[Span], annotation: Sequence[Annotation]) -> Iterable[Span]: + annotation_span_ids = set(annotation.span_id for annotation in annotation) + for span in source: + if span.span_id in annotation_span_ids: + yield span + elif any(check_linked_span(span, annotation.links) for annotation in annotation): + yield span + # ignore the current span for now + + def adapt(self, source: Tuple[T_SpanSequence, Sequence[Annotation]]) -> T_SpanSequence: + spans, annotations = source + linked_spans = self._filter_linked_spans(spans, annotations) + if isinstance(spans, Tree): + if self.mode == "include": + return cast(T_SpanSequence, spans.retain(lambda span: span in linked_spans)) + else: + return cast(T_SpanSequence, spans.prune(lambda span: span not in linked_spans)) + else: + if self.mode == "include": + return cast(T_SpanSequence, list(linked_spans)) + else: + return cast(T_SpanSequence, [span for span in spans if span not in linked_spans]) -class FillMissingLinks(Adapter[Tuple[Sequence[Span], Sequence[Annotation]], Sequence[Span]]): +class RepairMissingLinks(Adapter[Tuple[Sequence[Span], Sequence[Annotation]], Sequence[Span]]): """Populate missing annotation links by searching nearby spans. This adapter scans annotations and, for any annotation that has no linked spans, attempts @@ -272,16 +295,18 @@ class FillMissingLinks(Adapter[Tuple[Sequence[Span], Sequence[Annotation]], Sequ but failed to attach their target spans; this adapter backfills those links based on proximity and eligibility rules. + The annotation spans do not necessarily have to appear in the input span sequence. + Args: - require_annotation_span_child: - If True, only attempt to fill links for annotations whose *own* span is present - as a child span in the candidate span set. If False, annotations are considered - regardless of whether their span is a child. + annotation_span_required: + If True, only attempt to fill links for annotations whose *span_id* is present + in the candidate span set. If False, annotations are considered + regardless of whether their span is in the input span sequence. 
candidate_scope: Controls which spans are eligible as link targets: - - "siblings": search only among sibling spans of the annotation span. + - "siblings": search only among sibling spans of the annotation span. Only applicable when input span sequence is a tree. - "all": search among all spans provided to the adapter. scan_direction: diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index 8330dd9ef..ca30a4fe8 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -4,9 +4,9 @@ from __future__ import annotations -from typing import Sequence, Tuple +from typing import Sequence, Tuple, Union -from agentlightning.types.adapter import AnnotatedChatCompletionCall, Annotation, ChatCompletionCall +from agentlightning.types.adapter import AnnotatedChatCompletionCall, Annotation, ChatCompletionCall, Tree from agentlightning.types.tracer import Span from .base import Adapter @@ -15,6 +15,10 @@ class CurateChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): """Curate the chat completion calls from the spans.""" + def _parse_openai_chat_completion_create(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: ... + + def _parse_litellm_request(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: ... + def adapt(self, source: Sequence[Span]) -> Sequence[ChatCompletionCall]: ... @@ -26,7 +30,7 @@ class AnnotateChatCompletionCalls( The intersection of "effective radius" of annotations and chat completion calls is used to determine which annotations apply to which chat completion calls. - If an annotation is not linked to any span, try to use `FillMissingLinks` first to link it to spans. + If an annotation is not linked to any span, try to use `RepairMissingLinks` first to link it to spans. """ def adapt( diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 1b4125370..e41e37870 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -16,6 +16,8 @@ Optional, Sequence, TypeVar, + Union, + overload, ) from openai.types.chat import ( @@ -36,7 +38,7 @@ # General containers -class Tree(Generic[T]): +class Tree(Sequence[T], Generic[T]): """This is a generic tree data structure that can be used to represent the structure of a tree.""" def __init__(self, item: T, children: MutableSequence[Tree[T]]) -> None: @@ -54,15 +56,18 @@ def size(self) -> int: def __iter__(self) -> Iterator[T]: return iter(self.traverse()) - def __getitem__(self, index: int) -> T: + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[T]: ... + + def __getitem__(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: """Get the index-th item in the tree (O(n) time complexity). I think this is not efficient, but it's seldomly used. 
""" - for i, item in enumerate(self.traverse()): - if i == index: - return item - raise IndexError(f"Tree index out of range: {index}") + return list(self.traverse())[index] def __len__(self) -> int: return self.size() @@ -70,8 +75,31 @@ def __len__(self) -> int: def add(self, child: Tree[T]) -> None: self.children.append(child) + def _retain_subtree(self, predicate: Callable[[T], bool]) -> Optional[Tree[T]]: + if predicate(self.item): + # If the current node satisfies the predicate, retain the subtree + return self + + subtrees = [child._retain_subtree(predicate) for child in self.children] + if all(subtree is None for subtree in subtrees): + # no subtrees satisfy the predicate, remove the current node + return None + + return Tree(self.item, [subtree for subtree in subtrees if subtree is not None]) + + def retain(self, predicate: Callable[[T], bool]) -> Tree[T]: + """Prune the tree by retaining subtrees with root nodes that satisfy the predicate. + + The root node is always retained. + """ + return self._retain_subtree(predicate) or Tree(self.item, []) + def prune(self, predicate: Callable[[T], bool]) -> Tree[T]: - return Tree(self.item, [child.prune(predicate) for child in self.children if predicate(child.item)]) + """Prune the tree by removing nodes that satisfy the predicate. + + The root node is always retained. + """ + return Tree(self.item, [child.prune(predicate) for child in self.children if not predicate(child.item)]) def visualize(self, filename: str, item_to_str: Callable[[T], str]) -> None: """Render the tree with Graphviz for debugging purposes. diff --git a/agentlightning/utils/otel.py b/agentlightning/utils/otel.py index 59909d2ff..2bf8c17eb 100644 --- a/agentlightning/utils/otel.py +++ b/agentlightning/utils/otel.py @@ -32,6 +32,7 @@ "make_tag_attributes", "extract_tags_from_attributes", "make_link_attributes", + "check_linked_span", "query_linked_spans", "extract_links_from_attributes", "filter_attributes", @@ -229,7 +230,42 @@ def make_link_attributes(links: Dict[str, str]) -> Dict[str, Any]: return flatten_attributes({LightningSpanAttributes.LINK.value: link_list}, expand_leaf_lists=True) -def query_linked_spans(spans: Sequence[T_SpanLike], links: List[LinkPydanticModel]) -> List[T_SpanLike]: +def check_linked_span(span: SpanLike, links: Sequence[LinkPydanticModel]) -> bool: + """Check if a span matches a link attribute. + + Args: + span: A span to check. + links: A list of link attributes to match. + """ + span_attributes = span.attributes or {} + for link in links: + # trace_id and span_id must be full match. + if link.key_match == "trace_id": + if isinstance(span, ReadableSpan): + trace_id = trace_api.format_trace_id(span.context.trace_id) if span.context else None + else: + trace_id = span.trace_id + if trace_id != link.value_match: + return False + + elif link.key_match == "span_id": + if isinstance(span, ReadableSpan): + span_id = trace_api.format_span_id(span.context.span_id) if span.context else None + else: + span_id = span.span_id + if span_id != link.value_match: + return False + + else: + attribute = span_attributes.get(link.key_match) + # attributes must also be a full match currently. + if attribute != link.value_match: + return False + + return True + + +def query_linked_spans(spans: Sequence[T_SpanLike], links: Sequence[LinkPydanticModel]) -> List[T_SpanLike]: """Query spans that are linked by the given link attributes. 
Args: @@ -239,42 +275,7 @@ def query_linked_spans(spans: Sequence[T_SpanLike], links: List[LinkPydanticMode Returns: A list of spans that match the given link attributes. """ - matched_spans: List[T_SpanLike] = [] - - for span in spans: - span_attributes = span.attributes or {} - is_match = True - for link in links: - # trace_id and span_id must be full match. - if link.key_match == "trace_id": - if isinstance(span, ReadableSpan): - trace_id = trace_api.format_trace_id(span.context.trace_id) if span.context else None - else: - trace_id = span.trace_id - if trace_id != link.value_match: - is_match = False - break - - elif link.key_match == "span_id": - if isinstance(span, ReadableSpan): - span_id = trace_api.format_span_id(span.context.span_id) if span.context else None - else: - span_id = span.span_id - if span_id != link.value_match: - is_match = False - break - - else: - attribute = span_attributes.get(link.key_match) - # attributes must also be a full match currently. - if attribute != link.value_match: - is_match = False - break - - if is_match: - matched_spans.append(span) - - return matched_spans + return [span for span in spans if check_linked_span(span, links)] def extract_links_from_attributes(attributes: Dict[str, Any]) -> List[LinkPydanticModel]: From 932752bcb8b32233e865ed31fcca159790b266eb Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 5 Jan 2026 16:10:48 +0800 Subject: [PATCH 12/41] . --- agentlightning/adapter/annotation.py | 59 +++++----- agentlightning/adapter/base.py | 14 ++- agentlightning/types/adapter.py | 157 +++++++++++++++++++++------ 3 files changed, 162 insertions(+), 68 deletions(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index fd49d2c8f..672dafe3e 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -22,9 +22,12 @@ LinkPydanticModel, ) from agentlightning.types.adapter import ( + AdaptingSequence, + AdaptingSpan, AgentAnnotation, Annotation, ExceptionAnnotation, + FromSpan, GeneralAnnotation, MessageAnnotation, ObjectAnnotation, @@ -39,14 +42,14 @@ filter_and_unflatten_attributes, ) -from .base import Adapter +from .base import Adapter, AdaptingSequenceAdapter T_SpanSequence = TypeVar("T_SpanSequence", bound=Sequence[Span]) logger = logging.getLogger(__name__) -class CurateAnnotations(Adapter[Sequence[Span], Sequence[Annotation]]): +class CurateAnnotations(AdaptingSequenceAdapter[AdaptingSpan, AdaptingSpan]): """Curate the annotations from the spans.""" def _filter_custom_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]: @@ -73,7 +76,6 @@ def adapt_general(self, span: Span) -> Optional[GeneralAnnotation]: primary_reward = rewards[0].value if rewards else None return GeneralAnnotation( annotation_type="general", - span_id=span.span_id, links=self.extract_links(span), rewards=rewards, primary_reward=primary_reward, @@ -89,7 +91,6 @@ def adapt_message(self, span: Span) -> Optional[MessageAnnotation]: return MessageAnnotation( annotation_type="message", - span_id=span.span_id, links=self.extract_links(span), message=msg_body, ) @@ -103,7 +104,6 @@ def adapt_object(self, span: Span) -> Optional[ObjectAnnotation]: return ObjectAnnotation( annotation_type="object", - span_id=span.span_id, links=self.extract_links(span), object=obj_value, ) @@ -115,7 +115,6 @@ def adapt_exception(self, span: Span) -> Optional[ExceptionAnnotation]: return ExceptionAnnotation( annotation_type="exception", - span_id=span.span_id, links=self.extract_links(span), 
type=str(exception_type), message=str(exception_message), @@ -143,7 +142,6 @@ def adapt_operation(self, span: Span) -> Optional[OperationAnnotation]: return OperationAnnotation( annotation_type="operation", - span_id=span.span_id, links=self.extract_links(span), name=str(operation_name), input=operation_input, @@ -207,7 +205,6 @@ def detect_agent_annotation(self, span: Span) -> Optional[AgentAnnotation]: if agent_name is not None: return AgentAnnotation( annotation_type="agent", - span_id=span.span_id, links=self.extract_links(span), id=agent_id, name=agent_name, @@ -215,26 +212,30 @@ def detect_agent_annotation(self, span: Span) -> Optional[AgentAnnotation]: ) return None - def adapt(self, source: Sequence[Span]) -> Sequence[Annotation]: - annotations: List[Annotation] = [] - for span in source: - annotation: Optional[Annotation] = None - if span.name == AGL_ANNOTATION: - annotation = self.adapt_general(span) - elif span.name == AGL_MESSAGE: - annotation = self.adapt_message(span) - elif span.name == AGL_OBJECT: - annotation = self.adapt_object(span) - elif span.name == AGL_EXCEPTION: - annotation = self.adapt_exception(span) - elif span.name == AGL_OPERATION: - annotation = self.adapt_operation(span) - else: - # Fallback to agent annotation detection - annotation = self.detect_agent_annotation(span) - if annotation is not None: - annotations.append(annotation) - return annotations + def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: + annotation: Optional[Annotation] = None + if source.name == AGL_ANNOTATION: + annotation = self.adapt_general(source) + elif source.name == AGL_MESSAGE: + annotation = self.adapt_message(source) + elif source.name == AGL_OBJECT: + annotation = self.adapt_object(source) + elif source.name == AGL_EXCEPTION: + annotation = self.adapt_exception(source) + elif source.name == AGL_OPERATION: + annotation = self.adapt_operation(source) + else: + # Fallback to agent annotation detection + annotation = self.detect_agent_annotation(source) + if annotation is not None: + if source.data is not None: + logger.warning( + "Found annotation on an adapting span with existing data; overwriting the data. " + f"Current data: {source.data}, New data: {annotation}" + ) + return AdaptingSpan.from_span(source, data=annotation) + else: + return source class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_SpanSequence]): @@ -285,7 +286,7 @@ def adapt(self, source: Tuple[T_SpanSequence, Sequence[Annotation]]) -> T_SpanSe return cast(T_SpanSequence, [span for span in spans if span not in linked_spans]) -class RepairMissingLinks(Adapter[Tuple[Sequence[Span], Sequence[Annotation]], Sequence[Span]]): +class RepairMissingLinks(Adapter[Tuple[Sequence[Annotation], Sequence[FromSpan], Sequence[Span]], Sequence[Span]]): """Populate missing annotation links by searching nearby spans. 
This adapter scans annotations and, for any annotation that has no linked spans, attempts diff --git a/agentlightning/adapter/base.py b/agentlightning/adapter/base.py index ef941ca72..e33f4de18 100644 --- a/agentlightning/adapter/base.py +++ b/agentlightning/adapter/base.py @@ -5,6 +5,7 @@ from opentelemetry.sdk.trace import ReadableSpan from agentlightning.types import Span +from agentlightning.types.adapter import AdaptingSequence T_from = TypeVar("T_from") T_to = TypeVar("T_to") @@ -66,17 +67,18 @@ def adapt(self, source: T_from, /) -> T_to: raise NotImplementedError("Adapter.adapt() is not implemented") -class SequenceAdapter(Adapter[Sequence[T_from], Sequence[T_to]], Generic[T_from, T_to]): - """Base class for adapters that convert sequences of data from one format to another. +class SequenceAdapter(Adapter[AdaptingSequence[T_from], AdaptingSequence[T_to]], Generic[T_from, T_to]): + """Base class for adapters that convert adapting sequences of data from one format to another. - This class specializes [`Adapter`][agentlightning.Adapter] for working with sequences of data. + This class specializes [`Adapter`][agentlightning.Adapter] for working with + [`AdaptingSequence`][agentlightning.AdaptingSequence] instances. """ - def adapt(self, source: Sequence[T_from]) -> Sequence[T_to]: - return [self.adapt_one(item) for item in source] + def adapt(self, source: AdaptingSequence[T_from]) -> AdaptingSequence[T_to]: + return source.map(self.adapt_one) def adapt_one(self, source: T_from) -> T_to: - raise NotImplementedError("SequenceAdapter.adapt_one() is not implemented") + raise NotImplementedError("AdaptingSequenceAdapter.adapt_one() is not implemented") class Filter(Adapter[Sequence[T_from], Sequence[T_from]], Generic[T_from]): diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index e41e37870..cb544f120 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -30,76 +30,118 @@ from agentlightning.semconv import LinkPydanticModel, RewardPydanticModel -from .tracer import Attributes +from .tracer import Attributes, Span T = TypeVar("T") +V = TypeVar("V") + + +class AdaptingSequence(Sequence[T], Generic[T]): + """Interface that makes adapter easier to work with sequences.""" + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[T]: ... + + def __getitem__(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: + return self.get(index) + + def __iter__(self) -> Iterator[T]: + return iter(self.traverse()) + + def __len__(self) -> int: + return self.size() + + def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: + """Get the index-th item in the sequence.""" + raise NotImplementedError() + + def map(self, func: Callable[[T], V]) -> AdaptingSequence[V]: + """Map a function over all items in the sequence.""" + raise NotImplementedError() + + def retain(self, predicate: Callable[[T], bool]) -> AdaptingSequence[T]: + """Filter items in the sequence by a predicate (true for items to be kept). + + Depending on the implementation, the returned sequence may contain more or less items than a standard filter. + """ + raise NotImplementedError() + + def prune(self, predicate: Callable[[T], bool]) -> AdaptingSequence[T]: + """Prune items in the sequence by a predicate (true for items to be pruned). + + Depending on the implementation, the returned sequence may contain more or less items than a standard prune. 
+ """ + raise NotImplementedError() + + def size(self) -> int: + """Get the size of the sequence.""" + raise NotImplementedError() + + def traverse(self) -> Iterable[T]: + """Traverse all items in the sequence.""" + raise NotImplementedError() # General containers -class Tree(Sequence[T], Generic[T]): +class Tree(AdaptingSequence[T], Generic[T]): """This is a generic tree data structure that can be used to represent the structure of a tree.""" def __init__(self, item: T, children: MutableSequence[Tree[T]]) -> None: - self.item = item - self.children = children + self._item = item + self._children = children def traverse(self) -> Iterable[T]: - yield self.item - for child in self.children: + yield self._item + for child in self._children: yield from child.traverse() def size(self) -> int: - return 1 + sum(child.size() for child in self.children) - - def __iter__(self) -> Iterator[T]: - return iter(self.traverse()) - - @overload - def __getitem__(self, index: int) -> T: ... + return 1 + sum(child.size() for child in self._children) - @overload - def __getitem__(self, index: slice) -> Sequence[T]: ... - - def __getitem__(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: + def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: """Get the index-th item in the tree (O(n) time complexity). I think this is not efficient, but it's seldomly used. """ return list(self.traverse())[index] - def __len__(self) -> int: - return self.size() - def add(self, child: Tree[T]) -> None: - self.children.append(child) + self._children.append(child) + + def map(self, func: Callable[[T], V]) -> Tree[V]: + """Map a function over all items in the tree.""" + return Tree(func(self._item), [child.map(func) for child in self._children]) def _retain_subtree(self, predicate: Callable[[T], bool]) -> Optional[Tree[T]]: - if predicate(self.item): + if predicate(self._item): # If the current node satisfies the predicate, retain the subtree return self - subtrees = [child._retain_subtree(predicate) for child in self.children] + subtrees = [child._retain_subtree(predicate) for child in self._children] if all(subtree is None for subtree in subtrees): # no subtrees satisfy the predicate, remove the current node return None - return Tree(self.item, [subtree for subtree in subtrees if subtree is not None]) + return Tree(self._item, [subtree for subtree in subtrees if subtree is not None]) def retain(self, predicate: Callable[[T], bool]) -> Tree[T]: """Prune the tree by retaining subtrees with root nodes that satisfy the predicate. The root node is always retained. """ - return self._retain_subtree(predicate) or Tree(self.item, []) + return self._retain_subtree(predicate) or Tree(self._item, []) def prune(self, predicate: Callable[[T], bool]) -> Tree[T]: """Prune the tree by removing nodes that satisfy the predicate. The root node is always retained. """ - return Tree(self.item, [child.prune(predicate) for child in self.children if not predicate(child.item)]) + return Tree(self._item, [child.prune(predicate) for child in self._children if not predicate(child._item)]) def visualize(self, filename: str, item_to_str: Callable[[T], str]) -> None: """Render the tree with Graphviz for debugging purposes. 
@@ -118,7 +160,7 @@ def visualize(self, filename: str, item_to_str: Callable[[T], str]) -> None: def visit(node: Tree[T]): dot.node(str(id(node)), item_to_str(node.item)) # type: ignore - for child in node.children: + for child in node._children: visit(child) dot.edge(str(id(node)), str(id(child))) # type: ignore @@ -126,21 +168,73 @@ def visit(node: Tree[T]): dot.render(filename, format="png", cleanup=True) # type: ignore +class SimpleList(AdaptingSequence[T], Generic[T]): + """A simple list implementation of AdaptingSequence.""" + + def __init__(self, items: Sequence[T]) -> None: + self._items = list(items) + + def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: + return self._items[index] + + def traverse(self) -> Iterable[T]: + return iter(self._items) + + def size(self) -> int: + return len(self._items) + + def map(self, func: Callable[[T], V]) -> SimpleList[V]: + return SimpleList([func(item) for item in self._items]) + + def retain(self, predicate: Callable[[T], bool]) -> SimpleList[T]: + return SimpleList([item for item in self._items if predicate(item)]) + + def prune(self, predicate: Callable[[T], bool]) -> SimpleList[T]: + return SimpleList([item for item in self._items if not predicate(item)]) + + +class AdaptingSpan(Span): + """A span that has been adapted to a different format. + + This class extends the base [`Span`][agentlightning.Span] class to represent spans that have + been converted to a different format by an adapter. + """ + + data: Any + """The data in the adapted format. Could be annotations, calls, or other structured data.""" + + @classmethod + def from_span(cls, span: Span, data: Any) -> AdaptingSpan: + """Create an [`AdaptingSpan`][agentlightning.AdaptingSpan] from a base [`Span`][agentlightning.Span]. + + Args: + span: The base span to convert. + data: The data in the adapted format. + + Returns: + An instance of [`AdaptingSpan`][agentlightning.AdaptingSpan] with the same properties as + the input span and the provided adapted data. + """ + if isinstance(span, AdaptingSpan): + return span.model_copy(update={"data": data}) + else: + return AdaptingSpan.model_validate(span.model_dump()).model_copy(update={"data": data}) + + # Annotation-related types class Annotation(BaseModel): """An annotation is an approach to parse a span into some kind of structured attachments to another object. + Not necessarily an [AGL_ANNOTATION][agentlightning.semconv.AGL_ANNOTATION] span. + Note that a span can be parsed in multiple ways, and annotation is just one of them. """ annotation_type: Literal["agent", "general", "message", "object", "exception", "operation"] """Type of the annotation.""" - span_id: str - """Span ID of the annotation span. Not necessarily an [AGL_ANNOTATION][agentlightning.semconv.AGL_ANNOTATION] span.""" - links: Sequence[LinkPydanticModel] = Field(default_factory=list[LinkPydanticModel]) """Links to other spans or objects.""" @@ -251,9 +345,6 @@ class ChatCompletionCall(BaseModel): Mapping from span names to a dict of malformed fields. """ - span_ids: Sequence[str] - """Span IDs of the spans that contributed to this chat completion.""" - class AnnotatedChatCompletionCall(ChatCompletionCall): """A chat completion call with annotations.""" From bea08333a62955751722c2e1753e61a820052f20 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 7 Jan 2026 13:43:19 +0800 Subject: [PATCH 13/41] . 
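Besides the Identify* renames, this commit settles the container split: BaseAdaptingSequence is the interface, while Tree and the plain-list AdaptingSequence implement it, so a SequenceAdapter only needs to provide adapt_one() and map() preserves whatever shape the container has. A hedged, minimal sketch of that contract (the Doubler class and the integer payloads are invented for illustration; only the imported names come from this series):

    from agentlightning.adapter.base import SequenceAdapter
    from agentlightning.types.adapter import AdaptingSequence, Tree

    class Doubler(SequenceAdapter[int, int]):
        # adapt() is inherited and simply calls source.map(self.adapt_one).
        def adapt_one(self, source: int) -> int:
            return source * 2

    flat = AdaptingSequence([1, 2, 3])
    nested = Tree(1, [Tree(2, []), Tree(3, [])])

    doubler = Doubler()
    print(list(doubler.adapt(flat)))    # [2, 4, 6], still a flat AdaptingSequence
    print(list(doubler.adapt(nested)))  # [2, 4, 6], still a Tree with the same hierarchy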
--- agentlightning/adapter/annotation.py | 97 +++++++++++++++++----------- agentlightning/adapter/base.py | 8 +-- agentlightning/adapter/preprocess.py | 6 +- agentlightning/types/adapter.py | 32 +++++---- 4 files changed, 87 insertions(+), 56 deletions(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index 672dafe3e..7b5e4e8da 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -5,10 +5,11 @@ from __future__ import annotations import logging -from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, cast +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, cast from opentelemetry.semconv.attributes import exception_attributes +from agentlightning.adapter.preprocess import default_span_order from agentlightning.emitter.message import get_message_value from agentlightning.emitter.object import get_object_value from agentlightning.emitter.reward import get_rewards_from_span @@ -26,6 +27,7 @@ AdaptingSpan, AgentAnnotation, Annotation, + BaseAdaptingSequence, ExceptionAnnotation, FromSpan, GeneralAnnotation, @@ -42,14 +44,14 @@ filter_and_unflatten_attributes, ) -from .base import Adapter, AdaptingSequenceAdapter +from .base import Adapter, SequenceAdapter T_SpanSequence = TypeVar("T_SpanSequence", bound=Sequence[Span]) logger = logging.getLogger(__name__) -class CurateAnnotations(AdaptingSequenceAdapter[AdaptingSpan, AdaptingSpan]): +class IdentifyAnnotations(SequenceAdapter[AdaptingSpan, AdaptingSpan]): """Curate the annotations from the spans.""" def _filter_custom_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]: @@ -71,7 +73,7 @@ def extract_links(self, span: Span) -> Sequence[LinkPydanticModel]: logger.error(f"Link is malformed for span {span.span_id}: {exc}") return [] - def adapt_general(self, span: Span) -> Optional[GeneralAnnotation]: + def identify_general(self, span: Span) -> Optional[GeneralAnnotation]: rewards = get_rewards_from_span(span) primary_reward = rewards[0].value if rewards else None return GeneralAnnotation( @@ -83,7 +85,7 @@ def adapt_general(self, span: Span) -> Optional[GeneralAnnotation]: custom_fields=self._filter_custom_attributes(span.attributes), ) - def adapt_message(self, span: Span) -> Optional[MessageAnnotation]: + def identify_message(self, span: Span) -> Optional[MessageAnnotation]: msg_body = get_message_value(span) if msg_body is None: logger.warning(f"Message body is missing for message span {span.span_id}") @@ -95,7 +97,7 @@ def adapt_message(self, span: Span) -> Optional[MessageAnnotation]: message=msg_body, ) - def adapt_object(self, span: Span) -> Optional[ObjectAnnotation]: + def identify_object(self, span: Span) -> Optional[ObjectAnnotation]: try: obj_value = get_object_value(span) except Exception as exc: @@ -108,7 +110,7 @@ def adapt_object(self, span: Span) -> Optional[ObjectAnnotation]: object=obj_value, ) - def adapt_exception(self, span: Span) -> Optional[ExceptionAnnotation]: + def identify_exception(self, span: Span) -> Optional[ExceptionAnnotation]: exception_type = span.attributes.get(exception_attributes.EXCEPTION_TYPE, "UnknownException") exception_message = span.attributes.get(exception_attributes.EXCEPTION_MESSAGE, "") exception_stacktrace = span.attributes.get(exception_attributes.EXCEPTION_STACKTRACE, "") @@ -121,7 +123,7 @@ def adapt_exception(self, span: Span) -> Optional[ExceptionAnnotation]: stacktrace=str(exception_stacktrace), ) - def 
adapt_operation(self, span: Span) -> Optional[OperationAnnotation]: + def identify_operation(self, span: Span) -> Optional[OperationAnnotation]: try: operation_name = span.attributes.get(LightningSpanAttributes.OPERATION_NAME.value, "UnknownOperation") if LightningSpanAttributes.OPERATION_INPUT.value in span.attributes: @@ -215,15 +217,15 @@ def detect_agent_annotation(self, span: Span) -> Optional[AgentAnnotation]: def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: annotation: Optional[Annotation] = None if source.name == AGL_ANNOTATION: - annotation = self.adapt_general(source) + annotation = self.identify_general(source) elif source.name == AGL_MESSAGE: - annotation = self.adapt_message(source) + annotation = self.identify_message(source) elif source.name == AGL_OBJECT: - annotation = self.adapt_object(source) + annotation = self.identify_object(source) elif source.name == AGL_EXCEPTION: - annotation = self.adapt_exception(source) + annotation = self.identify_exception(source) elif source.name == AGL_OPERATION: - annotation = self.adapt_operation(source) + annotation = self.identify_operation(source) else: # Fallback to agent annotation detection annotation = self.detect_agent_annotation(source) @@ -238,7 +240,7 @@ def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: return source -class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_SpanSequence]): +class SelectByAnnotation(SequenceAdapter[AdaptingSpan, AdaptingSpan]): """Select the corresponding spans within the annotation sequence, as well as their linked spans (and subtree spans if applicable). @@ -262,31 +264,26 @@ class SelectByAnnotation(Adapter[Tuple[T_SpanSequence, Sequence[Annotation]], T_ def __init__(self, mode: Literal["include", "exclude"]) -> None: self.mode = mode - def _filter_linked_spans(self, source: Sequence[Span], annotation: Sequence[Annotation]) -> Iterable[Span]: - annotation_span_ids = set(annotation.span_id for annotation in annotation) + def _filter_linked_spans(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Iterable[AdaptingSpan]: + annotation_spans = [span for span in source if isinstance(span.data, Annotation)] + annotation_span_ids = set(annotation_span.span_id for annotation_span in annotation_spans) + annotation_links = [cast(Annotation, span.data).links for span in annotation_spans] for span in source: if span.span_id in annotation_span_ids: yield span - elif any(check_linked_span(span, annotation.links) for annotation in annotation): + elif any(check_linked_span(span, links) for links in annotation_links): yield span # ignore the current span for now - def adapt(self, source: Tuple[T_SpanSequence, Sequence[Annotation]]) -> T_SpanSequence: - spans, annotations = source - linked_spans = self._filter_linked_spans(spans, annotations) - if isinstance(spans, Tree): - if self.mode == "include": - return cast(T_SpanSequence, spans.retain(lambda span: span in linked_spans)) - else: - return cast(T_SpanSequence, spans.prune(lambda span: span not in linked_spans)) + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: + linked_spans = list(self._filter_linked_spans(source)) + if self.mode == "include": + return source.retain(lambda span: span in linked_spans) else: - if self.mode == "include": - return cast(T_SpanSequence, list(linked_spans)) - else: - return cast(T_SpanSequence, [span for span in spans if span not in linked_spans]) + return source.prune(lambda span: span not in linked_spans) -class 
RepairMissingLinks(Adapter[Tuple[Sequence[Annotation], Sequence[FromSpan], Sequence[Span]], Sequence[Span]]): +class RepairMissingLinks(SequenceAdapter[AdaptingSpan, AdaptingSpan]): """Populate missing annotation links by searching nearby spans. This adapter scans annotations and, for any annotation that has no linked spans, attempts @@ -296,13 +293,9 @@ class RepairMissingLinks(Adapter[Tuple[Sequence[Annotation], Sequence[FromSpan], but failed to attach their target spans; this adapter backfills those links based on proximity and eligibility rules. - The annotation spans do not necessarily have to appear in the input span sequence. - Args: - annotation_span_required: - If True, only attempt to fill links for annotations whose *span_id* is present - in the candidate span set. If False, annotations are considered - regardless of whether their span is in the input span sequence. + candidate_predicate: + A predicate to filter the candidate spans. If None, all spans within the candidate scope are considered. candidate_scope: Controls which spans are eligible as link targets: @@ -310,6 +303,8 @@ class RepairMissingLinks(Adapter[Tuple[Sequence[Annotation], Sequence[FromSpan], - "siblings": search only among sibling spans of the annotation span. Only applicable when input span sequence is a tree. - "all": search among all spans provided to the adapter. + The intersection of the candidate scope and predicate forms the candidate span set. + scan_direction: Determines both (a) which direction the adapter searches for candidate targets relative to an annotation and (b) the order in which annotations are processed: @@ -325,14 +320,38 @@ class RepairMissingLinks(Adapter[Tuple[Sequence[Annotation], Sequence[FromSpan], def __init__( self, - require_annotation_span_child: bool = True, + candidate_predicate: Optional[Callable[[AdaptingSpan], bool]] = None, candidate_scope: Literal["siblings", "all"] = "all", scan_direction: Literal["backward", "forward"] = "backward", allow_reuse_linked_spans: bool = False, ) -> None: - self.require_annotation_span_child = require_annotation_span_child + if candidate_predicate is not None: + self.candidate_predicate = candidate_predicate + else: + self.candidate_predicate = lambda _: True # type: ignore self.candidate_scope = candidate_scope self.scan_direction = scan_direction self.allow_reuse_linked_spans = allow_reuse_linked_spans - def adapt(self, source: Tuple[Sequence[Span], Sequence[Annotation]]) -> Sequence[Span]: ... 
+ def _search_groups(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Iterable[Sequence[AdaptingSpan]]: + if self.candidate_scope == "siblings": + if not isinstance(source, Tree): + raise ValueError("Candidate scope 'siblings' is only applicable to tree sequences") + + def visit(node: Tree[AdaptingSpan]) -> Iterable[Sequence[AdaptingSpan]]: + # Each group must be siblings + yield [child.item for child in node.children] # yield siblings first + for child in node.children: # then yield children recursively + yield from visit(child) + + yield [source.item] # yield root first + yield from visit(source) + + elif self.candidate_scope == "all": + return sorted(list(source), key=lambda span: default_span_order(span)) + + else: + raise ValueError(f"Invalid candidate scope: {self.candidate_scope}") + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: + groups = list(self._search_groups(source)) diff --git a/agentlightning/adapter/base.py b/agentlightning/adapter/base.py index e33f4de18..1a88471e7 100644 --- a/agentlightning/adapter/base.py +++ b/agentlightning/adapter/base.py @@ -5,7 +5,7 @@ from opentelemetry.sdk.trace import ReadableSpan from agentlightning.types import Span -from agentlightning.types.adapter import AdaptingSequence +from agentlightning.types.adapter import AdaptingSequence, BaseAdaptingSequence T_from = TypeVar("T_from") T_to = TypeVar("T_to") @@ -67,18 +67,18 @@ def adapt(self, source: T_from, /) -> T_to: raise NotImplementedError("Adapter.adapt() is not implemented") -class SequenceAdapter(Adapter[AdaptingSequence[T_from], AdaptingSequence[T_to]], Generic[T_from, T_to]): +class SequenceAdapter(Adapter[BaseAdaptingSequence[T_from], BaseAdaptingSequence[T_to]], Generic[T_from, T_to]): """Base class for adapters that convert adapting sequences of data from one format to another. This class specializes [`Adapter`][agentlightning.Adapter] for working with [`AdaptingSequence`][agentlightning.AdaptingSequence] instances. """ - def adapt(self, source: AdaptingSequence[T_from]) -> AdaptingSequence[T_to]: + def adapt(self, source: BaseAdaptingSequence[T_from]) -> BaseAdaptingSequence[T_to]: return source.map(self.adapt_one) def adapt_one(self, source: T_from) -> T_to: - raise NotImplementedError("AdaptingSequenceAdapter.adapt_one() is not implemented") + raise NotImplementedError(f"{self.__class__.__name__}.adapt_one() is not implemented") class Filter(Adapter[Sequence[T_from], Sequence[T_from]], Generic[T_from]): diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index 8b73f1ebf..3431675e7 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -22,6 +22,10 @@ logger = logging.getLogger(__name__) +def default_span_order(span: Span) -> tuple[int, float, float]: + return (span.sequence_id, span.ensure_start_time(), span.ensure_end_time()) + + class _TreeLikeGraph: """A simple directed graph implementation for span hierarchy. 
@@ -92,7 +96,7 @@ def to_tree(self, spans: Sequence[Span]) -> Tree[Span]: def build_subtree(node_id: str) -> Tree[Span]: children = [build_subtree(child_id) for child_id in self.forward_graph.get(node_id, [])] - return Tree(spans_dict[node_id], children) + return Tree(spans_dict[node_id], sorted(children, key=lambda child: default_span_order(child.item))) if len(self.root_ids) != 1: raise ValueError( diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index cb544f120..d1d5c275f 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -36,7 +36,7 @@ V = TypeVar("V") -class AdaptingSequence(Sequence[T], Generic[T]): +class BaseAdaptingSequence(Sequence[T], Generic[T]): """Interface that makes adapter easier to work with sequences.""" @overload @@ -58,18 +58,18 @@ def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: """Get the index-th item in the sequence.""" raise NotImplementedError() - def map(self, func: Callable[[T], V]) -> AdaptingSequence[V]: + def map(self, func: Callable[[T], V]) -> BaseAdaptingSequence[V]: """Map a function over all items in the sequence.""" raise NotImplementedError() - def retain(self, predicate: Callable[[T], bool]) -> AdaptingSequence[T]: + def retain(self, predicate: Callable[[T], bool]) -> BaseAdaptingSequence[T]: """Filter items in the sequence by a predicate (true for items to be kept). Depending on the implementation, the returned sequence may contain more or less items than a standard filter. """ raise NotImplementedError() - def prune(self, predicate: Callable[[T], bool]) -> AdaptingSequence[T]: + def prune(self, predicate: Callable[[T], bool]) -> BaseAdaptingSequence[T]: """Prune items in the sequence by a predicate (true for items to be pruned). Depending on the implementation, the returned sequence may contain more or less items than a standard prune. 
@@ -88,13 +88,21 @@ def traverse(self) -> Iterable[T]: # General containers -class Tree(AdaptingSequence[T], Generic[T]): +class Tree(BaseAdaptingSequence[T], Generic[T]): """This is a generic tree data structure that can be used to represent the structure of a tree.""" def __init__(self, item: T, children: MutableSequence[Tree[T]]) -> None: self._item = item self._children = children + @property + def item(self) -> T: + return self._item + + @property + def children(self) -> Sequence[Tree[T]]: + return self._children + def traverse(self) -> Iterable[T]: yield self._item for child in self._children: @@ -168,7 +176,7 @@ def visit(node: Tree[T]): dot.render(filename, format="png", cleanup=True) # type: ignore -class SimpleList(AdaptingSequence[T], Generic[T]): +class AdaptingSequence(BaseAdaptingSequence[T], Generic[T]): """A simple list implementation of AdaptingSequence.""" def __init__(self, items: Sequence[T]) -> None: @@ -183,14 +191,14 @@ def traverse(self) -> Iterable[T]: def size(self) -> int: return len(self._items) - def map(self, func: Callable[[T], V]) -> SimpleList[V]: - return SimpleList([func(item) for item in self._items]) + def map(self, func: Callable[[T], V]) -> AdaptingSequence[V]: + return AdaptingSequence([func(item) for item in self._items]) - def retain(self, predicate: Callable[[T], bool]) -> SimpleList[T]: - return SimpleList([item for item in self._items if predicate(item)]) + def retain(self, predicate: Callable[[T], bool]) -> AdaptingSequence[T]: + return AdaptingSequence([item for item in self._items if predicate(item)]) - def prune(self, predicate: Callable[[T], bool]) -> SimpleList[T]: - return SimpleList([item for item in self._items if not predicate(item)]) + def prune(self, predicate: Callable[[T], bool]) -> AdaptingSequence[T]: + return AdaptingSequence([item for item in self._items if not predicate(item)]) class AdaptingSpan(Span): From cf343fc7ada7ec271a2b6785e3f7339cc019e100 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 7 Jan 2026 16:53:41 +0800 Subject: [PATCH 14/41] . 
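The adapt() body added here walks each candidate group and backfills span_id links for annotations that carry none. A simplified, self-contained model of the default backward scan with allow_reuse_linked_spans=False (plain strings stand in for spans; fill_links_backward and the event names are made up and are not part of the library):

    from typing import Dict, List, Tuple

    def fill_links_backward(events: List[Tuple[str, bool]]) -> Dict[str, str]:
        # events are (name, is_link_less_annotation) in chronological order.
        links: Dict[str, str] = {}
        pending: List[str] = []
        for name, is_annotation in reversed(events):
            if is_annotation:
                pending.append(name)           # remember link-less annotations
            elif pending:
                links[pending.pop()] = name    # nearest preceding candidate wins
        return links

    print(fill_links_backward([
        ("llm_call_1", False), ("reward_a", True),
        ("llm_call_2", False), ("reward_b", True),
    ]))
    # {'reward_b': 'llm_call_2', 'reward_a': 'llm_call_1'}

Each candidate takes at most one pending annotation per pass, mirroring the break after the first pop when linked spans may not be reused; a candidate encountered with nothing pending is simply skipped.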
--- agentlightning/adapter/annotation.py | 45 ++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index 7b5e4e8da..f5b76633e 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, cast +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, TypeVar, cast from opentelemetry.semconv.attributes import exception_attributes @@ -23,13 +23,11 @@ LinkPydanticModel, ) from agentlightning.types.adapter import ( - AdaptingSequence, AdaptingSpan, AgentAnnotation, Annotation, BaseAdaptingSequence, ExceptionAnnotation, - FromSpan, GeneralAnnotation, MessageAnnotation, ObjectAnnotation, @@ -44,7 +42,7 @@ filter_and_unflatten_attributes, ) -from .base import Adapter, SequenceAdapter +from .base import SequenceAdapter T_SpanSequence = TypeVar("T_SpanSequence", bound=Sequence[Span]) @@ -328,7 +326,7 @@ def __init__( if candidate_predicate is not None: self.candidate_predicate = candidate_predicate else: - self.candidate_predicate = lambda _: True # type: ignore + self.candidate_predicate: Callable[[AdaptingSpan], bool] = lambda _: True self.candidate_scope = candidate_scope self.scan_direction = scan_direction self.allow_reuse_linked_spans = allow_reuse_linked_spans @@ -355,3 +353,40 @@ def visit(node: Tree[AdaptingSpan]) -> Iterable[Sequence[AdaptingSpan]]: def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: groups = list(self._search_groups(source)) + span_id_to_link: Dict[str, LinkPydanticModel] = {} + for group in groups: + if self.scan_direction == "backward": + group_to_scan = reversed(group) + else: + group_to_scan = group + + annotations_to_fill: List[AdaptingSpan] = [] + for span in group_to_scan: + if isinstance(span.data, Annotation): + if not span.data.links: + annotations_to_fill.append(span) + # The span is an annotation, skip it from being a candidate + else: + # The span is a candidate + if self.candidate_predicate(span): + while len(annotations_to_fill) > 0: + # Fill the link + annotation_span = annotations_to_fill.pop(-1) + span_id_to_link[annotation_span.span_id] = LinkPydanticModel( + key_match="span_id", value_match=span.span_id + ) + + if not self.allow_reuse_linked_spans: + # Once used, the candidate span cannot be reused + break + # If no annotations to fill, the candidate is wasted + # Otherwise, the span is not a candidate, skip it + + def _update_links(span: AdaptingSpan) -> AdaptingSpan: + if span.span_id in span_id_to_link and isinstance(span.data, Annotation): + new_annotation = span.data.model_copy(update={"links": [span_id_to_link[span.span_id]]}) + return span.model_copy(update={"data": new_annotation}) + else: + return span + + return source.map(_update_links) From 3559256d9749b8de1038092eb60431d60522dfd5 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 7 Jan 2026 17:07:01 +0800 Subject: [PATCH 15/41] . 
--- agentlightning/adapter/base.py | 2 +- agentlightning/adapter/call.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/agentlightning/adapter/base.py b/agentlightning/adapter/base.py index 1a88471e7..6f9d1ad99 100644 --- a/agentlightning/adapter/base.py +++ b/agentlightning/adapter/base.py @@ -5,7 +5,7 @@ from opentelemetry.sdk.trace import ReadableSpan from agentlightning.types import Span -from agentlightning.types.adapter import AdaptingSequence, BaseAdaptingSequence +from agentlightning.types.adapter import BaseAdaptingSequence T_from = TypeVar("T_from") T_to = TypeVar("T_to") diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index ca30a4fe8..6a6b5a136 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -12,7 +12,7 @@ from .base import Adapter -class CurateChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): +class IdentifyChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): """Curate the chat completion calls from the spans.""" def _parse_openai_chat_completion_create(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: ... From 16741ab0a31e3eb1d69afd7f0503e2c5884fc5c2 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 7 Jan 2026 17:13:08 +0800 Subject: [PATCH 16/41] fix adapter --- agentlightning/types/adapter.py | 16 +++++++++------- examples/apo/room_selector.py | 5 +++++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index d1d5c275f..2d57f6c2e 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -231,6 +231,8 @@ def from_span(cls, span: Span, data: Any) -> AdaptingSpan: # Annotation-related types +AnnotationType = Literal["agent", "general", "message", "object", "exception", "operation"] + class Annotation(BaseModel): """An annotation is an approach to parse a span into some kind of structured attachments to another object. @@ -240,7 +242,7 @@ class Annotation(BaseModel): Note that a span can be parsed in multiple ways, and annotation is just one of them. 
""" - annotation_type: Literal["agent", "general", "message", "object", "exception", "operation"] + annotation_type: AnnotationType """Type of the annotation.""" links: Sequence[LinkPydanticModel] = Field(default_factory=list[LinkPydanticModel]) @@ -250,7 +252,7 @@ class Annotation(BaseModel): class AgentAnnotation(Annotation): """Parsed from [OTel Agent Spans](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-agent-spans/).""" - annotation_type = "agent" + annotation_type: AnnotationType = "agent" """Type of the annotation.""" id: Optional[str] = None @@ -266,7 +268,7 @@ class AgentAnnotation(Annotation): class GeneralAnnotation(Annotation): """An annotation payload that is parsed from an [annotation][agentlightning.semconv.AGL_ANNOTATION] span.""" - annotation_type = "general" + annotation_type: AnnotationType = "general" """Type of the annotation.""" rewards: Sequence[RewardPydanticModel] = Field(default_factory=list[RewardPydanticModel]) @@ -285,7 +287,7 @@ class GeneralAnnotation(Annotation): class MessageAnnotation(Annotation): """A log message that is parsed from a [message][agentlightning.semconv.AGL_MESSAGE] span.""" - annotation_type = "message" + annotation_type: AnnotationType = "message" """Type of the annotation.""" message: str @@ -295,7 +297,7 @@ class MessageAnnotation(Annotation): class ObjectAnnotation(Annotation): """An artifact that is parsed from a [object][agentlightning.semconv.AGL_OBJECT] span.""" - annotation_type = "object" + annotation_type: AnnotationType = "object" """Type of the annotation.""" object: Any @@ -305,7 +307,7 @@ class ObjectAnnotation(Annotation): class ExceptionAnnotation(Annotation): """An exception that is parsed from an [exception][agentlightning.semconv.AGL_EXCEPTION] span.""" - annotation_type = "exception" + annotation_type: AnnotationType = "exception" """Type of the annotation.""" type: str @@ -321,7 +323,7 @@ class ExceptionAnnotation(Annotation): class OperationAnnotation(Annotation): """An operation that is parsed from an [operation][agentlightning.semconv.AGL_OPERATION] span.""" - annotation_type = "operation" + annotation_type: AnnotationType = "operation" """Type of the annotation.""" name: str diff --git a/examples/apo/room_selector.py b/examples/apo/room_selector.py index 289351619..b566f734d 100644 --- a/examples/apo/room_selector.py +++ b/examples/apo/room_selector.py @@ -349,6 +349,11 @@ async def debug_room_selector(limit: int = 1): # Get the spans and convert them to messages # Useful for debugging and analysis spans = await store.query_spans(rollout.rollout_id) + for span in spans: + console.print( + f"[bold blue]=== Span {span.span_id} ({span.name}, parent {span.parent_id}) ===[/bold blue]" + ) + console.print(span.attributes) adapter = TraceToMessages() messages = adapter.adapt(spans) for message_idx, message in enumerate(messages): From fc6ae3819c47c92a7007dfcbefb621e763cb880a Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 7 Jan 2026 17:56:31 +0800 Subject: [PATCH 17/41] . 
--- agentlightning/adapter/call.py | 96 +++++++++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index 6a6b5a136..b7b9b1f59 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -4,10 +4,20 @@ from __future__ import annotations -from typing import Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeGuard, Union, cast + +from openai.types.chat import ( + ChatCompletion, + ChatCompletionFunctionToolParam, + ChatCompletionMessageFunctionToolCall, + ChatCompletionMessageParam, + CompletionCreateParams, +) +from pydantic.type_adapter import TypeAdapter from agentlightning.types.adapter import AnnotatedChatCompletionCall, Annotation, ChatCompletionCall, Tree from agentlightning.types.tracer import Span +from agentlightning.utils.otel import filter_and_unflatten_attributes from .base import Adapter @@ -15,7 +25,89 @@ class IdentifyChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): """Curate the chat completion calls from the spans.""" - def _parse_openai_chat_completion_create(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: ... + def _validate_metadata(self, metadata: Any) -> TypeGuard[Dict[str, Any]]: + if not isinstance(metadata, dict) or not all(isinstance(key, str) for key in metadata.keys()): # type: ignore + return False + return True + + def _validate_completion(self, completion: Any) -> TypeGuard[List[Dict[str, Any]]]: + if not isinstance(completion, list): + return False + for choice in cast(List[Any], completion): + if not isinstance(choice, dict): + return False + if "message" not in choice or not isinstance(choice["message"], dict): + return False + return True + + def _parse_agentops_tool_calls(self, span: Span) -> Optional[ChatCompletionMessageFunctionToolCall]: + if span.name.startswith("tool_call."): + tool_call_data = filter_and_unflatten_attributes(span.attributes, "tool.") + if isinstance(tool_call_data, dict) and "call" in tool_call_data: + # Example tool_call_data: + # {'tool.name': 'get_rooms', 'tool.parameters': '{"date": ...}', + # 'tool.call.id': 'call_owd6', 'tool.call.type': 'function'} + tool_call_data = { + **tool_call_data["call"], + "function": {k: v for k, v in tool_call_data.items() if k != "call"}, + } + return ChatCompletionMessageFunctionToolCall.model_validate(tool_call_data) + return None + + def _parse_openai_chat_completion_create(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: + core_content = span.attributes if isinstance(span, Span) else span.item.attributes + prompt_messages = filter_and_unflatten_attributes(core_content, "gen_ai.prompt") + request_metadata = filter_and_unflatten_attributes(core_content, "gen_ai.request") + completion_choices = filter_and_unflatten_attributes(core_content, "gen_ai.completion") + usages = filter_and_unflatten_attributes(core_content, "gen_ai.usage") + response_metadata = filter_and_unflatten_attributes(core_content, "gen_ai.response") + + if not self._validate_metadata(request_metadata): + raise ValueError(f"Invalid request metadata format in span attributes: {request_metadata}") + if not self._validate_metadata(response_metadata): + raise ValueError(f"Invalid response metadata format in span attributes: {response_metadata}") + if not self._validate_completion(completion_choices): + raise ValueError( + "Invalid completion choices format in span attributes. 
Must be a list of dict, " + f"each containing a 'message' dict: {completion_choices}" + ) + + request_body = cast( + CompletionCreateParams, + TypeAdapter(CompletionCreateParams).validate_python({"messages": prompt_messages, **request_metadata}), + ) + + if isinstance(span, Tree): + # Get additional tool calls from child spans + additional_tool_calls: List[ChatCompletionMessageFunctionToolCall] = [] + for child in span.children: + tool_call = self._parse_agentops_tool_calls(child.item) + if tool_call is not None: + additional_tool_calls.append(tool_call) + + for choice in completion_choices: + tool_calls = choice["message"].get("tool_calls", []) + if isinstance(tool_calls, list): + cast(List[Any], tool_calls).extend(additional_tool_calls) + else: + raise ValueError( + f"Invalid tool_calls format in completion choice message. Must be a list: {completion_choices}" + ) + choice["message"]["tool_calls"] = tool_calls + # Only assign to the first choice. + break + + return ChatCompletionCall( + request=request_body, + response=ChatCompletion.model_validate( + { + **response_metadata, + "choices": completion_choices, + "usage": usages, + } + ), + malformed_fields={}, # TODO: malformed fields + ) def _parse_litellm_request(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: ... From c2f694045de5079cda498d3d6a75622b38051523 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 7 Jan 2026 18:29:53 +0800 Subject: [PATCH 18/41] . --- agentlightning/adapter/annotation.py | 7 +- agentlightning/adapter/call.py | 96 ++++++++++++++++++++-------- agentlightning/types/adapter.py | 67 +++++++++++++++++-- 3 files changed, 133 insertions(+), 37 deletions(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index f5b76633e..94c6a6611 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -228,12 +228,7 @@ def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: # Fallback to agent annotation detection annotation = self.detect_agent_annotation(source) if annotation is not None: - if source.data is not None: - logger.warning( - "Found annotation on an adapting span with existing data; overwriting the data. 
" - f"Current data: {source.data}, New data: {annotation}" - ) - return AdaptingSpan.from_span(source, data=annotation) + return source.with_data(annotation) else: return source diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index b7b9b1f59..cf143ee0a 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeGuard, Union, cast +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, TypeGuard, Union, cast from openai.types.chat import ( ChatCompletion, @@ -13,16 +13,23 @@ ChatCompletionMessageParam, CompletionCreateParams, ) -from pydantic.type_adapter import TypeAdapter - -from agentlightning.types.adapter import AnnotatedChatCompletionCall, Annotation, ChatCompletionCall, Tree +from pydantic import TypeAdapter + +from agentlightning.types.adapter import ( + AdaptingSpan, + AnnotatedChatCompletionCall, + Annotation, + BaseAdaptingSequence, + ChatCompletionCall, + Tree, +) from agentlightning.types.tracer import Span -from agentlightning.utils.otel import filter_and_unflatten_attributes +from agentlightning.utils.otel import filter_and_unflatten_attributes, query_linked_spans -from .base import Adapter +from .base import SequenceAdapter -class IdentifyChatCompletionCalls(Adapter[Sequence[Span], Sequence[ChatCompletionCall]]): +class IdentifyChatCompletionCalls(SequenceAdapter[AdaptingSpan, AdaptingSpan]): """Curate the chat completion calls from the spans.""" def _validate_metadata(self, metadata: Any) -> TypeGuard[Dict[str, Any]]: @@ -54,13 +61,12 @@ def _parse_agentops_tool_calls(self, span: Span) -> Optional[ChatCompletionMessa return ChatCompletionMessageFunctionToolCall.model_validate(tool_call_data) return None - def _parse_openai_chat_completion_create(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: - core_content = span.attributes if isinstance(span, Span) else span.item.attributes - prompt_messages = filter_and_unflatten_attributes(core_content, "gen_ai.prompt") - request_metadata = filter_and_unflatten_attributes(core_content, "gen_ai.request") - completion_choices = filter_and_unflatten_attributes(core_content, "gen_ai.completion") - usages = filter_and_unflatten_attributes(core_content, "gen_ai.usage") - response_metadata = filter_and_unflatten_attributes(core_content, "gen_ai.response") + def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatCompletionCall: + prompt_messages = filter_and_unflatten_attributes(span.attributes, "gen_ai.prompt") + request_metadata = filter_and_unflatten_attributes(span.attributes, "gen_ai.request") + completion_choices = filter_and_unflatten_attributes(span.attributes, "gen_ai.completion") + usages = filter_and_unflatten_attributes(span.attributes, "gen_ai.usage") + response_metadata = filter_and_unflatten_attributes(span.attributes, "gen_ai.response") if not self._validate_metadata(request_metadata): raise ValueError(f"Invalid request metadata format in span attributes: {request_metadata}") @@ -77,11 +83,11 @@ def _parse_openai_chat_completion_create(self, span: Union[Span, Tree[Span]]) -> TypeAdapter(CompletionCreateParams).validate_python({"messages": prompt_messages, **request_metadata}), ) - if isinstance(span, Tree): + if isinstance(span.container, Tree): # Get additional tool calls from child spans additional_tool_calls: List[ChatCompletionMessageFunctionToolCall] = [] - for child in span.children: - tool_call = 
self._parse_agentops_tool_calls(child.item) + for child in span.children(): + tool_call = self._parse_agentops_tool_calls(child) if tool_call is not None: additional_tool_calls.append(tool_call) @@ -111,12 +117,20 @@ def _parse_openai_chat_completion_create(self, span: Union[Span, Tree[Span]]) -> def _parse_litellm_request(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: ... - def adapt(self, source: Sequence[Span]) -> Sequence[ChatCompletionCall]: ... + def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: + if source.name == "openai.chat.completion.create" or source.name == "openai.chat.completion": + chat_completion_call = self._parse_openai_chat_completion_create(source) + return source.with_data(chat_completion_call) + elif source.name == "raw_gen_ai_request": + # Litellm request span + chat_completion_call = self._parse_litellm_request(source) + return source.with_data(chat_completion_call) + else: + # Not a chat completion call span. Do nothing + return source -class AnnotateChatCompletionCalls( - Adapter[Tuple[Sequence[ChatCompletionCall], Sequence[Annotation]], Sequence[AnnotatedChatCompletionCall]] -): +class AnnotateChatCompletionCalls(SequenceAdapter[AdaptingSpan, AdaptingSpan]): """Annotate chat completion calls with the given annotations. The intersection of "effective radius" of annotations and chat completion calls is used to determine @@ -125,7 +139,39 @@ class AnnotateChatCompletionCalls( If an annotation is not linked to any span, try to use `RepairMissingLinks` first to link it to spans. """ - def adapt( - self, - source: Tuple[Sequence[ChatCompletionCall], Sequence[Annotation]], - ) -> Sequence[AnnotatedChatCompletionCall]: ... + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: + annotation_spans = [span for span in source if isinstance(span.data, Annotation)] + span_id_to_updated_annotation: Dict[str, AnnotatedChatCompletionCall] = {} + for annotation_span in annotation_spans: + annotation = cast(Annotation, annotation_span.data) + for linked_span in query_linked_spans(source, annotation.links): + if isinstance(linked_span.container, Tree): + linked_spans = list(linked_span.container.traverse()) + else: + linked_spans = [linked_span] + + for linked_span in linked_spans: + if isinstance(linked_span.data, ChatCompletionCall): + existing_annotations: List[Annotation] = ( + list(linked_span.data.annotations) + if isinstance(linked_span.data, AnnotatedChatCompletionCall) + else [] + ) + # Annotate the chat completion call + annotated_call = AnnotatedChatCompletionCall( + request=linked_span.data.request, + response=linked_span.data.response, + malformed_fields=linked_span.data.malformed_fields, + annotations=existing_annotations + [annotation], + ) + # Update the linked span with the annotated call + span_id_to_updated_annotation[linked_span.span_id] = annotated_call + + def _update_span(span: AdaptingSpan) -> AdaptingSpan: + if span.span_id in span_id_to_updated_annotation: + annotated_call = span_id_to_updated_annotation[span.span_id] + return span.with_data(annotated_call, override="silent") # override is expected here + else: + return span + + return source.map(_update_span) diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 2d57f6c2e..207cc64d1 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -4,6 +4,8 @@ from __future__ import annotations +import logging +import weakref from typing import ( Any, Callable, @@ -12,7 +14,6 @@ Iterable, 
Iterator, Literal, - MutableSequence, Optional, Sequence, TypeVar, @@ -35,6 +36,8 @@ T = TypeVar("T") V = TypeVar("V") +logger = logging.getLogger(__name__) + class BaseAdaptingSequence(Sequence[T], Generic[T]): """Interface that makes adapter easier to work with sequences.""" @@ -89,11 +92,17 @@ def traverse(self) -> Iterable[T]: class Tree(BaseAdaptingSequence[T], Generic[T]): - """This is a generic tree data structure that can be used to represent the structure of a tree.""" + """This is a generic tree data structure that can be used to represent the structure of a tree. + + This data structure is immutable. + """ - def __init__(self, item: T, children: MutableSequence[Tree[T]]) -> None: + def __init__(self, item: T, children: Sequence[Tree[T]]) -> None: self._item = item self._children = children + self._parent: Optional[weakref.ReferenceType[Tree[T]]] = None + for child in self._children: + child._parent = weakref.ref(self) # type: ignore @property def item(self) -> T: @@ -103,6 +112,10 @@ def item(self) -> T: def children(self) -> Sequence[Tree[T]]: return self._children + @property + def parent(self) -> Optional[Tree[T]]: + return self._parent() if self._parent is not None else None + def traverse(self) -> Iterable[T]: yield self._item for child in self._children: @@ -118,9 +131,6 @@ def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: """ return list(self.traverse())[index] - def add(self, child: Tree[T]) -> None: - self._children.append(child) - def map(self, func: Callable[[T], V]) -> Tree[V]: """Map a function over all items in the tree.""" return Tree(func(self._item), [child.map(func) for child in self._children]) @@ -211,6 +221,32 @@ class AdaptingSpan(Span): data: Any """The data in the adapted format. Could be annotations, calls, or other structured data.""" + container: Optional[BaseAdaptingSequence[AdaptingSpan]] = None + """An optional container that holds multiple adapted data items.""" + + def with_data(self, data: Any, override: Literal["silent", "warning", "forbidden"] = "warning") -> AdaptingSpan: + """Create a new [`AdaptingSpan`][agentlightning.AdaptingSpan] with the same properties but different adapted data. + + Args: + data: The new adapted data. + + Returns: + An instance of [`AdaptingSpan`][agentlightning.AdaptingSpan] with the same properties as + the current span but with the provided adapted data. + """ + if self.data is not None: + if override == "forbidden": + raise ValueError( + "Overwriting existing data on AdaptingSpan is forbidden. " + f"Current data: {self.data}, New data: {data}" + ) + elif override == "warning": + logger.warning( + "Found annotation on an adapting span with existing data; overwriting the data. " + f"Current data: {self.data}, New data: {data}" + ) + return self.model_copy(update={"data": data}) + @classmethod def from_span(cls, span: Span, data: Any) -> AdaptingSpan: """Create an [`AdaptingSpan`][agentlightning.AdaptingSpan] from a base [`Span`][agentlightning.Span]. @@ -228,6 +264,25 @@ def from_span(cls, span: Span, data: Any) -> AdaptingSpan: else: return AdaptingSpan.model_validate(span.model_dump()).model_copy(update={"data": data}) + def children(self) -> Sequence[AdaptingSpan]: + """Get the child spans as [`AdaptingSpan`][agentlightning.AdaptingSpan] instances. + + Only applicable when the container is a [`Tree`][agentlightning.Tree]. 
+ """ + if self.container is None or not isinstance(self.container, Tree): + raise ValueError("AdaptingSpan.children() is only applicable when container is non-empty and a Tree.") + return [child.item for child in self.container.children] + + def parent_span(self) -> Optional[AdaptingSpan]: + """Get the parent span if available. + + Only applicable when the container is a [`Tree`][agentlightning.Tree]. + """ + if self.container is None or not isinstance(self.container, Tree): + raise ValueError("AdaptingSpan.parent() is only applicable when container is non-empty and a Tree.") + parent_tree = self.container.parent + return parent_tree.item if parent_tree is not None else None + # Annotation-related types From e27c2d7d743dd25ba9d7a46f5d0276511e9eafbc Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Thu, 8 Jan 2026 12:40:12 +0800 Subject: [PATCH 19/41] fix call support --- agentlightning/adapter/call.py | 125 ++++++++++++++++++++++++-------- agentlightning/types/adapter.py | 13 ++++ 2 files changed, 107 insertions(+), 31 deletions(-) diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index cf143ee0a..92df3df0c 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -4,13 +4,12 @@ from __future__ import annotations -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, TypeGuard, Union, cast +import ast +from typing import Any, Dict, List, Optional, cast from openai.types.chat import ( ChatCompletion, - ChatCompletionFunctionToolParam, ChatCompletionMessageFunctionToolCall, - ChatCompletionMessageParam, CompletionCreateParams, ) from pydantic import TypeAdapter @@ -32,20 +31,44 @@ class IdentifyChatCompletionCalls(SequenceAdapter[AdaptingSpan, AdaptingSpan]): """Curate the chat completion calls from the spans.""" - def _validate_metadata(self, metadata: Any) -> TypeGuard[Dict[str, Any]]: - if not isinstance(metadata, dict) or not all(isinstance(key, str) for key in metadata.keys()): # type: ignore - return False - return True + def _parse_request(self, span: AdaptingSpan) -> Dict[str, Any]: + request = filter_and_unflatten_attributes(span.attributes, "gen_ai.request") + if not isinstance(request, dict): + raise ValueError(f"Invalid request format in span attributes: {request}") + return request - def _validate_completion(self, completion: Any) -> TypeGuard[List[Dict[str, Any]]]: - if not isinstance(completion, list): - return False - for choice in cast(List[Any], completion): + def _parse_response(self, span: AdaptingSpan) -> Dict[str, Any]: + response = filter_and_unflatten_attributes(span.attributes, "gen_ai.response") + if not isinstance(response, dict): + raise ValueError(f"Invalid response format in span attributes: {response}") + return response + + def _parse_completion_choices(self, span: AdaptingSpan) -> List[Dict[str, Any]]: + completion_choices = filter_and_unflatten_attributes(span.attributes, "gen_ai.completion") + if not isinstance(completion_choices, list): + raise ValueError(f"Invalid completion choices format in span attributes: {completion_choices}") + for choice in completion_choices: if not isinstance(choice, dict): - return False + raise ValueError( + f"Invalid completion choice format in span attributes. Choice must be a dict: {choice}" + ) if "message" not in choice or not isinstance(choice["message"], dict): - return False - return True + raise ValueError( + f"Invalid completion choice format in span attributes. 
Choice must contain a 'message' dict: {choice}" + ) + return completion_choices + + def _parse_prompt_messages(self, span: AdaptingSpan) -> List[Dict[str, Any]]: + prompt_messages = filter_and_unflatten_attributes(span.attributes, "gen_ai.prompt") + if not isinstance(prompt_messages, list) or not all(isinstance(msg, dict) for msg in prompt_messages): + raise ValueError(f"Invalid prompt messages format in span attributes: {prompt_messages}") + return prompt_messages + + def _parse_usages(self, span: AdaptingSpan) -> Dict[str, Any]: + usages = filter_and_unflatten_attributes(span.attributes, "gen_ai.usage") + if not isinstance(usages, dict): + raise ValueError(f"Invalid usages format in span attributes: {usages}") + return usages def _parse_agentops_tool_calls(self, span: Span) -> Optional[ChatCompletionMessageFunctionToolCall]: if span.name.startswith("tool_call."): @@ -62,21 +85,11 @@ def _parse_agentops_tool_calls(self, span: Span) -> Optional[ChatCompletionMessa return None def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatCompletionCall: - prompt_messages = filter_and_unflatten_attributes(span.attributes, "gen_ai.prompt") - request_metadata = filter_and_unflatten_attributes(span.attributes, "gen_ai.request") - completion_choices = filter_and_unflatten_attributes(span.attributes, "gen_ai.completion") - usages = filter_and_unflatten_attributes(span.attributes, "gen_ai.usage") - response_metadata = filter_and_unflatten_attributes(span.attributes, "gen_ai.response") - - if not self._validate_metadata(request_metadata): - raise ValueError(f"Invalid request metadata format in span attributes: {request_metadata}") - if not self._validate_metadata(response_metadata): - raise ValueError(f"Invalid response metadata format in span attributes: {response_metadata}") - if not self._validate_completion(completion_choices): - raise ValueError( - "Invalid completion choices format in span attributes. Must be a list of dict, " - f"each containing a 'message' dict: {completion_choices}" - ) + prompt_messages = self._parse_prompt_messages(span) + request_metadata = self._parse_request(span) + completion_choices = self._parse_completion_choices(span) + usages = self._parse_usages(span) + response_metadata = self._parse_response(span) request_body = cast( CompletionCreateParams, @@ -115,13 +128,63 @@ def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatComple malformed_fields={}, # TODO: malformed fields ) - def _parse_litellm_request(self, span: Union[Span, Tree[Span]]) -> ChatCompletionCall: ... + def _augment_litellm_raw_gen_ai_request( + self, span: AdaptingSpan, request: Dict[str, Any], response: Dict[str, Any] + ) -> None: + """Augment the request/response with more rich info from the sibling raw_gen_ai_request span. + + The request and response are modified in place. 
+ """ + hosted_vllm = filter_and_unflatten_attributes(span.attributes, "llm.hosted_vllm") + if not hosted_vllm: + return + + if not isinstance(hosted_vllm, dict): + raise ValueError(f"Invalid hosted_vllm format in span attributes: {hosted_vllm}") + + if "choices" in hosted_vllm: + choices = ast.literal_eval(hosted_vllm["choices"]) + if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict): + raise ValueError(f"Invalid choices format in hosted_vllm: {choices}") + if "token_ids" in choices[0]: + response["choices"][0]["token_ids"] = choices[0]["token_ids"] + + if "prompt_token_ids" in hosted_vllm: + request["prompt_token_ids"] = ast.literal_eval(hosted_vllm["prompt_token_ids"]) + + def _parse_litellm_request(self, span: AdaptingSpan) -> ChatCompletionCall: + prompt_messages = self._parse_prompt_messages(span) + completion_choices = self._parse_completion_choices(span) + usages = self._parse_usages(span) + request_metadata = self._parse_request(span) + response_metadata = self._parse_response(span) + + request_body = {"messages": prompt_messages, **request_metadata} + response_body = { + **response_metadata, + "choices": completion_choices, + "usage": usages, + } + + # If the underlying backend is vllm, we have more rich info in sibling span. + for sibling in span.siblings(): + if sibling.name == "raw_gen_ai_request": + self._augment_litellm_raw_gen_ai_request(sibling, request_body, response_body) + + return ChatCompletionCall( + request=cast( + CompletionCreateParams, + TypeAdapter(CompletionCreateParams).validate_python(request_body), + ), + response=ChatCompletion.model_validate(response_body), + malformed_fields={}, # TODO: malformed fields + ) def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: if source.name == "openai.chat.completion.create" or source.name == "openai.chat.completion": chat_completion_call = self._parse_openai_chat_completion_create(source) return source.with_data(chat_completion_call) - elif source.name == "raw_gen_ai_request": + elif source.name == "litellm_request": # Litellm request span chat_completion_call = self._parse_litellm_request(source) return source.with_data(chat_completion_call) diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 207cc64d1..abe839855 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -218,6 +218,9 @@ class AdaptingSpan(Span): been converted to a different format by an adapter. """ + class Config: + arbitrary_types_allowed = True + data: Any """The data in the adapted format. Could be annotations, calls, or other structured data.""" @@ -273,6 +276,16 @@ def children(self) -> Sequence[AdaptingSpan]: raise ValueError("AdaptingSpan.children() is only applicable when container is non-empty and a Tree.") return [child.item for child in self.container.children] + def siblings(self) -> Sequence[AdaptingSpan]: + """Get the sibling spans as [`AdaptingSpan`][agentlightning.AdaptingSpan] instances. + + Only applicable when the container is a [`Tree`][agentlightning.Tree]. + """ + parent = self.parent_span() + if parent is None: + return [] + return [child for child in parent.children() if child != self] + def parent_span(self) -> Optional[AdaptingSpan]: """Get the parent span if available. 
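Note on the LiteLLM support added in the patch above: `adapt_one()` now also
recognizes `litellm_request` spans, and `_augment_litellm_raw_gen_ai_request()`
pulls token-level detail out of a sibling `raw_gen_ai_request` span when the
backend behind LiteLLM is vLLM. A rough sketch of the sibling attributes it
consumes (key names composed from the prefix and fields used in the diff;
values illustrative, stored as stringified Python literals, hence the
ast.literal_eval calls):

    sibling.attributes = {
        "llm.hosted_vllm.prompt_token_ids": "[1, 2, 3]",
        "llm.hosted_vllm.choices": "[{'token_ids': [4, 5, 6]}]",
    }
    # prompt_token_ids are grafted onto the request body and the first
    # choice's token_ids onto the response body, in place.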
From ad194387cac24556a9e2000c8636c696b414137a Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Thu, 8 Jan 2026 17:50:36 +0800 Subject: [PATCH 20/41] runnable --- agentlightning/adapter/base.py | 160 +++++++++++++++++++++++++-- agentlightning/adapter/preprocess.py | 109 +++++++++++------- agentlightning/types/adapter.py | 21 ++-- tests/adapter/test_call.py | 132 ++++++++++++++++++++++ 4 files changed, 364 insertions(+), 58 deletions(-) create mode 100644 tests/adapter/test_call.py diff --git a/agentlightning/adapter/base.py b/agentlightning/adapter/base.py index 6f9d1ad99..aa946f9ff 100644 --- a/agentlightning/adapter/base.py +++ b/agentlightning/adapter/base.py @@ -1,14 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any, Callable, Generic, Sequence, TypeVar +from typing import Any, Callable, Generic, Sequence, TypeVar, overload from opentelemetry.sdk.trace import ReadableSpan from agentlightning.types import Span from agentlightning.types.adapter import BaseAdaptingSequence -T_from = TypeVar("T_from") -T_to = TypeVar("T_to") +T_inv = TypeVar("T_inv") +T_from = TypeVar("T_from", contravariant=True) +T_to = TypeVar("T_to", covariant=True) class Adapter(Generic[T_from, T_to]): @@ -81,26 +82,167 @@ def adapt_one(self, source: T_from) -> T_to: raise NotImplementedError(f"{self.__class__.__name__}.adapt_one() is not implemented") -class Filter(Adapter[Sequence[T_from], Sequence[T_from]], Generic[T_from]): +class Filter(Adapter[Sequence[T_inv], Sequence[T_inv]], Generic[T_inv]): """Filter items of type T to items of type T based on a predicate.""" - def __init__(self, predicate: Callable[[T_from], bool]) -> None: + def __init__(self, predicate: Callable[[T_inv], bool]) -> None: self.predicate = predicate - def adapt(self, source: Sequence[T_from]) -> Sequence[T_from]: + def adapt(self, source: Sequence[T_inv]) -> Sequence[T_inv]: return [item for item in source if self.predicate(item)] -class Sort(Adapter[Sequence[T_from], Sequence[T_from]], Generic[T_from]): +class Sort(Adapter[Sequence[T_inv], Sequence[T_inv]], Generic[T_inv]): """Sort items of type T based on a key function.""" - def __init__(self, key: Callable[[T_from], Any]) -> None: + def __init__(self, key: Callable[[T_inv], Any]) -> None: self.key = key - def adapt(self, source: Sequence[T_from]) -> Sequence[T_from]: + def adapt(self, source: Sequence[T_inv]) -> Sequence[T_inv]: return sorted(source, key=self.key) +T_chain1 = TypeVar("T_chain1") +T_chain2 = TypeVar("T_chain2") +T_chain3 = TypeVar("T_chain3") +T_chain4 = TypeVar("T_chain4") +T_chain5 = TypeVar("T_chain5") +T_chain6 = TypeVar("T_chain6") +T_chain7 = TypeVar("T_chain7") +T_chain8 = TypeVar("T_chain8") +T_chain9 = TypeVar("T_chain9") + + +class Chain(Adapter[T_from, T_to]): + """Chain multiple adapters together to form a single adapter. + + The output of each adapter is passed as input to the next adapter in the chain. + """ + + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_to], + /, + ) -> None: ... + + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_chain2], + adapter3: Adapter[T_chain2, T_to], + /, + ) -> None: ... + + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_chain2], + adapter3: Adapter[T_chain2, T_chain3], + adapter4: Adapter[T_chain3, T_to], + /, + ) -> None: ... 
+ + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_chain2], + adapter3: Adapter[T_chain2, T_chain3], + adapter4: Adapter[T_chain3, T_chain4], + adapter5: Adapter[T_chain4, T_to], + /, + ) -> None: ... + + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_chain2], + adapter3: Adapter[T_chain2, T_chain3], + adapter4: Adapter[T_chain3, T_chain4], + adapter5: Adapter[T_chain4, T_chain5], + adapter6: Adapter[T_chain5, T_to], + /, + ) -> None: ... + + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_chain2], + adapter3: Adapter[T_chain2, T_chain3], + adapter4: Adapter[T_chain3, T_chain4], + adapter5: Adapter[T_chain4, T_chain5], + adapter6: Adapter[T_chain5, T_chain6], + adapter7: Adapter[T_chain6, T_to], + /, + ) -> None: ... + + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_chain2], + adapter3: Adapter[T_chain2, T_chain3], + adapter4: Adapter[T_chain3, T_chain4], + adapter5: Adapter[T_chain4, T_chain5], + adapter6: Adapter[T_chain5, T_chain6], + adapter7: Adapter[T_chain6, T_chain7], + adapter8: Adapter[T_chain7, T_to], + /, + ) -> None: ... + + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_chain2], + adapter3: Adapter[T_chain2, T_chain3], + adapter4: Adapter[T_chain3, T_chain4], + adapter5: Adapter[T_chain4, T_chain5], + adapter6: Adapter[T_chain5, T_chain6], + adapter7: Adapter[T_chain6, T_chain7], + adapter8: Adapter[T_chain7, T_chain8], + adapter9: Adapter[T_chain8, T_to], + /, + ) -> None: ... + + @overload + def __init__( + self, + adapter1: Adapter[T_from, T_chain1], + adapter2: Adapter[T_chain1, T_chain2], + adapter3: Adapter[T_chain2, T_chain3], + adapter4: Adapter[T_chain3, T_chain4], + adapter5: Adapter[T_chain4, T_chain5], + adapter6: Adapter[T_chain5, T_chain6], + adapter7: Adapter[T_chain6, T_chain7], + adapter8: Adapter[T_chain7, T_chain8], + adapter9: Adapter[T_chain8, T_chain9], + adapter10: Adapter[T_chain9, T_to], + /, + ) -> None: ... + + def __init__(self, adapter1: Adapter[Any, Any], *adapters: Adapter[Any, Any]) -> None: + # Enforce that a Chain always has at least one adapter. + self.adapters: tuple[Adapter[Any, Any], ...] = (adapter1, *adapters) + + def adapt(self, source: T_from) -> T_to: + result: Any = source + for idx, adapter in enumerate(self.adapters): + try: + result = adapter.adapt(result) + except Exception as exc: + raise RuntimeError( + f"Adapter chain failed at adapter index {idx} ({adapter.__class__.__name__}). See inner exception for details." + ) from exc + return result + + class OtelTraceAdapter(Adapter[Sequence[ReadableSpan], T_to], Generic[T_to]): """Base class for adapters that convert OpenTelemetry trace spans into other formats. 
diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index 3431675e7..0d279337c 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -10,11 +10,11 @@ from typing import Dict, List, Literal, Sequence, Set, Tuple, TypeVar from agentlightning.semconv import AGL_VIRTUAL -from agentlightning.types.adapter import Tree +from agentlightning.types.adapter import AdaptingSequence, AdaptingSpan, Tree from agentlightning.types.tracer import Span, SpanLike from agentlightning.utils.id import generate_id -from .base import Adapter, SequenceAdapter, Sort +from .base import Adapter, SequenceAdapter T_from = TypeVar("T_from") T_to = TypeVar("T_to") @@ -318,49 +318,38 @@ def adapt(self, source: Sequence[Span]) -> Tree[Span]: return _TreeLikeGraph.from_spans(source).to_tree(source) -class ToSortedSpans(Sort[Span]): - """Sort the spans with sequence ID as the primary key and start time as the secondary key.""" +class ToAdaptingSpans(Adapter[Sequence[Span], AdaptingSequence[AdaptingSpan]]): + """Sort the spans with sequence ID as the primary key and start time as the secondary key + and end time as the tertiary key.""" + + def adapt(self, source: Sequence[Span]) -> AdaptingSequence[AdaptingSpan]: + sorted_spans = sorted(source, key=default_span_order) + return AdaptingSequence([AdaptingSpan.from_span(span, None) for span in sorted_spans]) - def __init__(self) -> None: - super().__init__(key=lambda span: (span.sequence_id, span.start_time)) +class RepairMalformedSpans(Adapter[Sequence[Span], Sequence[Span]]): + """The adapter repairs multiple common issues with spans. -class RepairTime(Adapter[Sequence[Span], Sequence[Span]]): - """Repair the end time of the spans by: + 1. Repair the end time of the spans by: - 1. Ensuring the end time is greater than the start time. - 2. Fill the spans with no end time to be the maximum start/end time of all spans. + * Ensuring the end time is greater than the start time. + * Fill the spans with no end time to be the maximum start/end time of all spans. + + 2. Repair the parent ID. If a span has a parent ID that does not exist in the set of spans, + the parent ID will be set to None. 
""" - def __init__(self, ensure_positive_duration: bool = True, ensure_proper_nesting: bool = True) -> None: + def __init__( + self, + ensure_positive_duration: bool = True, + ensure_proper_nesting: bool = True, + ensure_valid_parent_ids: bool = True, + ) -> None: self.ensure_positive_duration = ensure_positive_duration self.ensure_proper_nesting = ensure_proper_nesting + self.ensure_valid_parent_ids = ensure_valid_parent_ids - def _repair_nesting(self, source: Sequence[Span]) -> Sequence[Span]: - graph = _TreeLikeGraph.from_spans(source, logs_invalid_parent=False) - spans = {span.span_id: span for span in source} - - def visit(node_id: str) -> Tuple[float, float]: - child_start_end_times: List[Tuple[float, float]] = [] - cur_start_time = spans[node_id].ensure_start_time() - cur_end_time = spans[node_id].ensure_end_time() - if graph.forward_graph.get(node_id): - for child_id in graph.forward_graph[node_id]: - child_start_end_times.append(visit(child_id)) - start_times, end_times = zip(*child_start_end_times) - start_time = min(cur_start_time, *start_times) - end_time = max(cur_end_time, *end_times) - if start_time != cur_start_time or end_time != cur_end_time: - spans[node_id] = spans[node_id].model_copy(update={"start_time": start_time, "end_time": end_time}) - - return spans[node_id].ensure_start_time(), spans[node_id].ensure_end_time() - - for root_id in graph.root_ids: - visit(root_id) - - return [spans[span.span_id] for span in source] - - def adapt(self, source: Sequence[Span]) -> Sequence[Span]: + def _repair_start_end_time(self, source: Sequence[Span]) -> List[Span]: times_set = set[float]() for span in source: if span.start_time is not None: @@ -394,7 +383,49 @@ def adapt(self, source: Sequence[Span]) -> Sequence[Span]: else: new_spans.append(span) + return new_spans + + def _repair_nesting(self, source: Sequence[Span]) -> List[Span]: + graph = _TreeLikeGraph.from_spans(source, logs_invalid_parent=False) + spans = {span.span_id: span for span in source} + + def visit(node_id: str) -> Tuple[float, float]: + child_start_end_times: List[Tuple[float, float]] = [] + cur_start_time = spans[node_id].ensure_start_time() + cur_end_time = spans[node_id].ensure_end_time() + if graph.forward_graph.get(node_id): + for child_id in graph.forward_graph[node_id]: + child_start_end_times.append(visit(child_id)) + start_times, end_times = zip(*child_start_end_times) + start_time = min(cur_start_time, *start_times) + end_time = max(cur_end_time, *end_times) + if start_time != cur_start_time or end_time != cur_end_time: + spans[node_id] = spans[node_id].model_copy(update={"start_time": start_time, "end_time": end_time}) + + return spans[node_id].ensure_start_time(), spans[node_id].ensure_end_time() + + for root_id in graph.root_ids: + visit(root_id) + + return [spans[span.span_id] for span in source] + + def _repair_invalid_parent_ids(self, source: Sequence[Span]) -> List[Span]: + valid_span_ids: Set[str] = set(span.span_id for span in source) + new_spans: List[Span] = [] + + for span in source: + if span.parent_id is not None and span.parent_id not in valid_span_ids: + new_spans.append(span.model_copy(update={"parent_id": None, "parent": None})) + else: + new_spans.append(span) + + return new_spans + + def adapt(self, source: Sequence[Span]) -> Sequence[Span]: + # This step is always performed first no matter whether the flags are set. 
+ new_spans = self._repair_start_end_time(source) if self.ensure_proper_nesting: - return self._repair_nesting(new_spans) - else: - return new_spans + new_spans = self._repair_nesting(new_spans) + if self.ensure_valid_parent_ids: + new_spans = self._repair_invalid_parent_ids(new_spans) + return new_spans diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index abe839855..1324d4ab8 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -33,46 +33,47 @@ from .tracer import Attributes, Span +T_co = TypeVar("T_co", covariant=True) T = TypeVar("T") V = TypeVar("V") logger = logging.getLogger(__name__) -class BaseAdaptingSequence(Sequence[T], Generic[T]): +class BaseAdaptingSequence(Sequence[T_co], Generic[T_co]): """Interface that makes adapter easier to work with sequences.""" @overload - def __getitem__(self, index: int) -> T: ... + def __getitem__(self, index: int) -> T_co: ... @overload - def __getitem__(self, index: slice) -> Sequence[T]: ... + def __getitem__(self, index: slice) -> Sequence[T_co]: ... - def __getitem__(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: + def __getitem__(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: return self.get(index) - def __iter__(self) -> Iterator[T]: + def __iter__(self) -> Iterator[T_co]: return iter(self.traverse()) def __len__(self) -> int: return self.size() - def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: + def get(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: """Get the index-th item in the sequence.""" raise NotImplementedError() - def map(self, func: Callable[[T], V]) -> BaseAdaptingSequence[V]: + def map(self, func: Callable[[T_co], V]) -> BaseAdaptingSequence[V]: """Map a function over all items in the sequence.""" raise NotImplementedError() - def retain(self, predicate: Callable[[T], bool]) -> BaseAdaptingSequence[T]: + def retain(self, predicate: Callable[[T_co], bool]) -> BaseAdaptingSequence[T_co]: """Filter items in the sequence by a predicate (true for items to be kept). Depending on the implementation, the returned sequence may contain more or less items than a standard filter. """ raise NotImplementedError() - def prune(self, predicate: Callable[[T], bool]) -> BaseAdaptingSequence[T]: + def prune(self, predicate: Callable[[T_co], bool]) -> BaseAdaptingSequence[T_co]: """Prune items in the sequence by a predicate (true for items to be pruned). Depending on the implementation, the returned sequence may contain more or less items than a standard prune. @@ -83,7 +84,7 @@ def size(self) -> int: """Get the size of the sequence.""" raise NotImplementedError() - def traverse(self) -> Iterable[T]: + def traverse(self) -> Iterable[T_co]: """Traverse all items in the sequence.""" raise NotImplementedError() diff --git a/tests/adapter/test_call.py b/tests/adapter/test_call.py new file mode 100644 index 000000000..9996e46c0 --- /dev/null +++ b/tests/adapter/test_call.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from typing import List + +from agentlightning.adapter.base import Chain +from agentlightning.adapter.call import IdentifyChatCompletionCalls +from agentlightning.adapter.preprocess import RepairMalformedSpans, ToAdaptingSpans +from agentlightning.types.tracer import Span + + +def test_openai_calls(): + spans: List[Span] = [ + Span.from_attributes( + attributes={ + "tool.name": "get_rooms_and_availability", + "tool.parameters": '{"date":"2025-10-13","time":"16:30","duration_min":60}', + "tool.call.id": "call_owd6...l9Gv", + "tool.call.type": "function", + }, + span_id="18edfd18ea659820", + parent_id="e858708413368c22", + sequence_id=0, + ), + Span.from_attributes( + attributes={ + "tool.name": "get_rooms_and_availability", + "tool.parameters": '{"date":"2025-10-13","time":"16:30","duration_min":60}', + "tool.call.id": "call_VKsy...5ULX", + "tool.call.type": "function", + }, + span_id="6c9ba649e512a7f8", + parent_id="e858708413368c22", + sequence_id=1, + ), + Span.from_attributes( + attributes={ + "gen_ai.request.type": "chat", + "gen_ai.system": "OpenAI", + "gen_ai.request.model": "gpt-4.1-nano", + "gen_ai.request.temperature": 0.0, + "gen_ai.request.streaming": False, + "gen_ai.prompt.0.role": "system", + "gen_ai.prompt.0.content": "You are a scheduling assistant.", + "gen_ai.prompt.1.role": "user", + "gen_ai.prompt.1.content": "Find a room 2025-10-13 16:30 for 60m, 12 attendees; needs confphone+whiteboard; accessible.", + "gen_ai.request.functions.0.name": "get_rooms_and_availability", + "gen_ai.request.functions.0.description": "Return rooms with capacity/equipment/accessibility/bookings.", + "gen_ai.request.functions.0.parameters": '{"type":"object","properties":{"date":{"type":"string"},"time":{"type":"string"},"duration_min":{"type":"integer"}},"required":["date","time","duration_min"]}', + "gen_ai.response.id": "chatcmpl-...1jC4", + "gen_ai.response.model": "gpt-4.1-nano-2025-04-14", + "gen_ai.openai.system_fingerprint": "fp_03e44fcc34", + "gen_ai.usage.total_tokens": 211, + "gen_ai.usage.prompt_tokens": 128, + "gen_ai.usage.completion_tokens": 83, + "gen_ai.completion.0.finish_reason": "tool_calls", + "gen_ai.completion.0.role": "assistant", + }, + span_id="e858708413368c22", + parent_id="3db86425087d211f", + sequence_id=2, + ), + Span.from_attributes( + attributes={ + "gen_ai.request.type": "chat", + "gen_ai.system": "OpenAI", + "gen_ai.request.model": "gpt-4.1-nano", + "gen_ai.request.temperature": 0.0, + "gen_ai.request.streaming": False, + "gen_ai.prompt.0.role": "system", + "gen_ai.prompt.0.content": "You are a scheduling assistant.", + "gen_ai.prompt.1.role": "user", + "gen_ai.prompt.1.content": "Find a room 2025-10-13 16:30 for 60m, 12 attendees; needs confphone+whiteboard; accessible.", + "gen_ai.prompt.2.role": "assistant", + "gen_ai.prompt.2.tool_calls.0.id": "call_owd6...l9Gv", + "gen_ai.prompt.2.tool_calls.0.name": "get_rooms_and_availability", + "gen_ai.prompt.2.tool_calls.0.arguments": '{"date":"2025-10-13","time":"16:30","duration_min":60}', + "gen_ai.prompt.2.tool_calls.1.id": "call_VKsy...5ULX", + "gen_ai.prompt.2.tool_calls.1.name": "get_rooms_and_availability", + "gen_ai.prompt.2.tool_calls.1.arguments": '{"date":"2025-10-13","time":"16:30","duration_min":60}', + "gen_ai.prompt.3.role": "tool", + "gen_ai.prompt.3.content": 
'{"rooms":[{"id":"Nova","capacity":12,"equipment":["whiteboard","confphone"],"accessible":true,"distance_m":45,"booked":[],"free":true},{"id":"Pulse","capacity":8,"equipment":["whiteboard","confphone"],"accessible":true,"booked":[["2025-10-13","16:30",30]],"free":false}]}', + "gen_ai.prompt.3.tool_call_id": "call_owd6...l9Gv", + "gen_ai.prompt.4.role": "tool", + "gen_ai.prompt.4.content": '{"rooms":[{"id":"Nova","capacity":12,"equipment":["whiteboard","confphone"],"accessible":true,"free":true},{"id":"Pulse","capacity":8,"equipment":["whiteboard","confphone"],"accessible":true,"free":false}]}', + "gen_ai.prompt.4.tool_call_id": "call_VKsy...5ULX", + "gen_ai.response.id": "chatcmpl-...Syso", + "gen_ai.response.model": "gpt-4.1-nano-2025-04-14", + "gen_ai.openai.system_fingerprint": "fp_03e44fcc34", + "gen_ai.usage.total_tokens": 1189, + "gen_ai.usage.prompt_tokens": 1082, + "gen_ai.usage.completion_tokens": 107, + "gen_ai.completion.0.finish_reason": "stop", + "gen_ai.completion.0.role": "assistant", + "gen_ai.completion.0.content": "Available rooms: Nova (cap 12, whiteboard+confphone, accessible); Lyra (cap 10...).", + }, + span_id="9a44818e0901d0a1", + parent_id="3db86425087d211f", + sequence_id=3, + ), + Span.from_attributes( + attributes={ + "agentops.span.kind": "session", + "operation.name": "ro-90201d0a24cb", + }, + span_id="3db86425087d211f", + parent_id=None, + sequence_id=4, + ), + Span.from_attributes( + attributes={ + "agentlightning.reward.0.name": "primary", + "agentlightning.reward.0.value": 1.0, + }, + span_id="dc5e3c27f4378b6e", + parent_id=None, + sequence_id=5, + ), + ] + + adapter = Chain( + RepairMalformedSpans(), + ToAdaptingSpans(), + IdentifyChatCompletionCalls(), + ) + + adapted_spans = adapter(spans) + + for span in adapted_spans: + print(span) + + +test_openai_calls() From 7150c320c4a253c8f6f23a9a5c9c26ff09998a36 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Thu, 8 Jan 2026 17:59:38 +0800 Subject: [PATCH 21/41] . --- agentlightning/types/adapter.py | 2 +- tests/adapter/test_call.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 1324d4ab8..4919b9085 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -266,7 +266,7 @@ def from_span(cls, span: Span, data: Any) -> AdaptingSpan: if isinstance(span, AdaptingSpan): return span.model_copy(update={"data": data}) else: - return AdaptingSpan.model_validate(span.model_dump()).model_copy(update={"data": data}) + return AdaptingSpan.model_validate({**span.model_dump(), "data": data}) def children(self) -> Sequence[AdaptingSpan]: """Get the child spans as [`AdaptingSpan`][agentlightning.AdaptingSpan] instances. 
diff --git a/tests/adapter/test_call.py b/tests/adapter/test_call.py index 9996e46c0..8bafaaa82 100644 --- a/tests/adapter/test_call.py +++ b/tests/adapter/test_call.py @@ -17,6 +17,7 @@ def test_openai_calls(): "tool.call.id": "call_owd6...l9Gv", "tool.call.type": "function", }, + name="tool_call.get_rooms_and_availability", span_id="18edfd18ea659820", parent_id="e858708413368c22", sequence_id=0, @@ -28,6 +29,7 @@ def test_openai_calls(): "tool.call.id": "call_VKsy...5ULX", "tool.call.type": "function", }, + name="tool_call.get_rooms_and_availability", span_id="6c9ba649e512a7f8", parent_id="e858708413368c22", sequence_id=1, @@ -55,6 +57,7 @@ def test_openai_calls(): "gen_ai.completion.0.finish_reason": "tool_calls", "gen_ai.completion.0.role": "assistant", }, + name="openai.chat.completion", span_id="e858708413368c22", parent_id="3db86425087d211f", sequence_id=2, @@ -93,6 +96,7 @@ def test_openai_calls(): "gen_ai.completion.0.role": "assistant", "gen_ai.completion.0.content": "Available rooms: Nova (cap 12, whiteboard+confphone, accessible); Lyra (cap 10...).", }, + name="openai.chat.completion", span_id="9a44818e0901d0a1", parent_id="3db86425087d211f", sequence_id=3, @@ -102,6 +106,7 @@ def test_openai_calls(): "agentops.span.kind": "session", "operation.name": "ro-90201d0a24cb", }, + name="ro-90201d0a24cb.session", span_id="3db86425087d211f", parent_id=None, sequence_id=4, @@ -111,6 +116,7 @@ def test_openai_calls(): "agentlightning.reward.0.name": "primary", "agentlightning.reward.0.value": 1.0, }, + name="agentlightning.annotation", span_id="dc5e3c27f4378b6e", parent_id=None, sequence_id=5, From e36d013d155cd9e73e53d6449d75020ed6b21b3c Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Thu, 8 Jan 2026 19:12:49 +0800 Subject: [PATCH 22/41] . --- agentlightning/adapter/call.py | 60 ++++++++++++++++------------------ 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index 92df3df0c..3f263a66d 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -27,6 +27,8 @@ from .base import SequenceAdapter +CompletionCreateParamsType: TypeAdapter[CompletionCreateParams] = TypeAdapter(CompletionCreateParams) + class IdentifyChatCompletionCalls(SequenceAdapter[AdaptingSpan, AdaptingSpan]): """Curate the chat completion calls from the spans.""" @@ -47,15 +49,28 @@ def _parse_completion_choices(self, span: AdaptingSpan) -> List[Dict[str, Any]]: completion_choices = filter_and_unflatten_attributes(span.attributes, "gen_ai.completion") if not isinstance(completion_choices, list): raise ValueError(f"Invalid completion choices format in span attributes: {completion_choices}") - for choice in completion_choices: + for index, choice in enumerate(completion_choices): if not isinstance(choice, dict): raise ValueError( f"Invalid completion choice format in span attributes. Choice must be a dict: {choice}" ) - if "message" not in choice or not isinstance(choice["message"], dict): - raise ValueError( - f"Invalid completion choice format in span attributes. 
Choice must contain a 'message' dict: {choice}" - ) + + choice["index"] = index + + # Uncover the message from the choice + message: Dict[str, Any] = { + "role": "assistant", + } + if "content" in choice: + message["content"] = cast(Dict[str, Any], choice).pop("content") + if isinstance(span.container, Tree): + # Get additional fields from child spans if any + for child in span.children(): + tool_call = self._parse_agentops_tool_calls(child) + if tool_call is not None: + message.setdefault("tool_calls", []).append(tool_call) + choice["message"] = message + return completion_choices def _parse_prompt_messages(self, span: AdaptingSpan) -> List[Dict[str, Any]]: @@ -70,7 +85,7 @@ def _parse_usages(self, span: AdaptingSpan) -> Dict[str, Any]: raise ValueError(f"Invalid usages format in span attributes: {usages}") return usages - def _parse_agentops_tool_calls(self, span: Span) -> Optional[ChatCompletionMessageFunctionToolCall]: + def _parse_agentops_tool_calls(self, span: Span) -> Optional[Dict[str, Any]]: if span.name.startswith("tool_call."): tool_call_data = filter_and_unflatten_attributes(span.attributes, "tool.") if isinstance(tool_call_data, dict) and "call" in tool_call_data: @@ -81,7 +96,7 @@ def _parse_agentops_tool_calls(self, span: Span) -> Optional[ChatCompletionMessa **tool_call_data["call"], "function": {k: v for k, v in tool_call_data.items() if k != "call"}, } - return ChatCompletionMessageFunctionToolCall.model_validate(tool_call_data) + return cast(Dict[str, Any], tool_call_data) return None def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatCompletionCall: @@ -91,36 +106,19 @@ def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatComple usages = self._parse_usages(span) response_metadata = self._parse_response(span) - request_body = cast( - CompletionCreateParams, - TypeAdapter(CompletionCreateParams).validate_python({"messages": prompt_messages, **request_metadata}), + validated_request_body = CompletionCreateParamsType.validate_python( + {"messages": prompt_messages, **request_metadata} ) - - if isinstance(span.container, Tree): - # Get additional tool calls from child spans - additional_tool_calls: List[ChatCompletionMessageFunctionToolCall] = [] - for child in span.children(): - tool_call = self._parse_agentops_tool_calls(child) - if tool_call is not None: - additional_tool_calls.append(tool_call) - - for choice in completion_choices: - tool_calls = choice["message"].get("tool_calls", []) - if isinstance(tool_calls, list): - cast(List[Any], tool_calls).extend(additional_tool_calls) - else: - raise ValueError( - f"Invalid tool_calls format in completion choice message. Must be a list: {completion_choices}" - ) - choice["message"]["tool_calls"] = tool_calls - # Only assign to the first choice. 
- break + normalized_request_body = CompletionCreateParamsType.dump_python(validated_request_body, mode="json") + print(normalized_request_body) return ChatCompletionCall( - request=request_body, + request=normalized_request_body, response=ChatCompletion.model_validate( { **response_metadata, + "object": "chat.completion", + "created": int(span.ensure_end_time()), "choices": completion_choices, "usage": usages, } From 907d178da064100517e229edd5e021bba12e3476 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 9 Jan 2026 12:40:58 +0800 Subject: [PATCH 23/41] fix call parsing --- agentlightning/adapter/call.py | 80 ++++++++++++++++++++++------ agentlightning/adapter/preprocess.py | 13 +++-- agentlightning/types/adapter.py | 28 +++++----- tests/adapter/test_call.py | 8 ++- 4 files changed, 91 insertions(+), 38 deletions(-) diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index 3f263a66d..60244f345 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -5,14 +5,14 @@ from __future__ import annotations import ast +import json from typing import Any, Dict, List, Optional, cast from openai.types.chat import ( ChatCompletion, - ChatCompletionMessageFunctionToolCall, CompletionCreateParams, ) -from pydantic import TypeAdapter +from pydantic import TypeAdapter, ValidationError from agentlightning.types.adapter import ( AdaptingSpan, @@ -37,6 +37,14 @@ def _parse_request(self, span: AdaptingSpan) -> Dict[str, Any]: request = filter_and_unflatten_attributes(span.attributes, "gen_ai.request") if not isinstance(request, dict): raise ValueError(f"Invalid request format in span attributes: {request}") + if "functions" in request: + if not isinstance(request["functions"], list): + raise ValueError(f"Invalid functions format in request. Must be a list: {request['functions']}") + for function in cast(List[Any], request["functions"]): + if not isinstance(function, dict): + raise ValueError(f"Invalid function format in request. Must be a dict: {function}") + if "parameters" in function and isinstance(function["parameters"], str): + function["parameters"] = json.loads(function["parameters"]) return request def _parse_response(self, span: AdaptingSpan) -> Dict[str, Any]: @@ -77,6 +85,24 @@ def _parse_prompt_messages(self, span: AdaptingSpan) -> List[Dict[str, Any]]: prompt_messages = filter_and_unflatten_attributes(span.attributes, "gen_ai.prompt") if not isinstance(prompt_messages, list) or not all(isinstance(msg, dict) for msg in prompt_messages): raise ValueError(f"Invalid prompt messages format in span attributes: {prompt_messages}") + + # Fix discrepency between OpenAI API and AGL span attributes. + for message in prompt_messages: + if not isinstance(message, dict): + raise ValueError(f"Invalid message format in prompt messages. Must be a dict: {message}") + if "tool_calls" in message: + if not isinstance(message["tool_calls"], list): + raise ValueError(f"Invalid tool calls format in message. Must be a list: {message['tool_calls']}") + for tool_call in cast(List[Any], message["tool_calls"]): + if not isinstance(tool_call, dict): + raise ValueError(f"Invalid tool call format in message. 
Must be a dict: {tool_call}") + if "type" not in tool_call: + tool_call["type"] = "function" + if "function" not in tool_call and "name" in tool_call and "arguments" in tool_call: + tool_call["function"] = { + "name": cast(Dict[str, Any], tool_call).pop("name"), + "arguments": cast(Dict[str, Any], tool_call).pop("arguments"), + } return prompt_messages def _parse_usages(self, span: AdaptingSpan) -> Dict[str, Any]: @@ -99,6 +125,29 @@ def _parse_agentops_tool_calls(self, span: Span) -> Optional[Dict[str, Any]]: return cast(Dict[str, Any], tool_call_data) return None + def _normalize_request(self, request_body: Dict[str, Any]) -> CompletionCreateParams: + validated_request = CompletionCreateParamsType.validate_python(request_body) + + def _to_plain_object(object: Any, path: List[str]) -> Any: + if type(object).__name__ == "ValidatorIterator": + try: + return [_to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)] + except ValidationError as exc: + raise ValueError( + "Failed to convert ValidatorIterator to list.\n" + "ValidatorIterator path (see below for subpath): " + ".".join(path) + "\nError: " + str(exc) + ) from exc + elif isinstance(object, dict): + return {k: _to_plain_object(v, path + [k]) for k, v in object.items()} # type: ignore + elif isinstance(object, list): + return [_to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)] # type: ignore + elif isinstance(object, tuple): + return tuple(_to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)) # type: ignore + else: + return object + + return _to_plain_object(validated_request, []) + def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatCompletionCall: prompt_messages = self._parse_prompt_messages(span) request_metadata = self._parse_request(span) @@ -106,23 +155,20 @@ def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatComple usages = self._parse_usages(span) response_metadata = self._parse_response(span) - validated_request_body = CompletionCreateParamsType.validate_python( - {"messages": prompt_messages, **request_metadata} + request = self._normalize_request({"messages": prompt_messages, **request_metadata}) + response = ChatCompletion.model_validate( + { + **response_metadata, + "object": "chat.completion", + "created": int(span.ensure_end_time()), + "choices": completion_choices, + "usage": usages, + } ) - normalized_request_body = CompletionCreateParamsType.dump_python(validated_request_body, mode="json") - print(normalized_request_body) - return ChatCompletionCall( - request=normalized_request_body, - response=ChatCompletion.model_validate( - { - **response_metadata, - "object": "chat.completion", - "created": int(span.ensure_end_time()), - "choices": completion_choices, - "usage": usages, - } - ), + return ChatCompletionCall.model_construct( + request=request, + response=response, malformed_fields={}, # TODO: malformed fields ) diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index 0d279337c..5a30a502b 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -18,6 +18,7 @@ T_from = TypeVar("T_from") T_to = TypeVar("T_to") +T_span = TypeVar("T_span", bound=Span, covariant=True) logger = logging.getLogger(__name__) @@ -91,10 +92,10 @@ def visit(node_id: str) -> None: return ancestors - def to_tree(self, spans: Sequence[Span]) -> Tree[Span]: + def to_tree(self, spans: Sequence[T_span]) -> Tree[T_span]: spans_dict = {span.span_id: span for span in 
spans} - def build_subtree(node_id: str) -> Tree[Span]: + def build_subtree(node_id: str) -> Tree[T_span]: children = [build_subtree(child_id) for child_id in self.forward_graph.get(node_id, [])] return Tree(spans_dict[node_id], sorted(children, key=lambda child: default_span_order(child.item))) @@ -151,7 +152,7 @@ def adapt_one(self, source: SpanLike) -> Span: ) -class ToTree(Adapter[Sequence[Span], Tree[Span]]): +class ToTree(Adapter[Sequence[Span], Tree[AdaptingSpan]]): def __init__( self, @@ -299,7 +300,7 @@ def _repair_multiple_roots(self, source: Sequence[Span]) -> Sequence[Span]: ] return [new_root_span, *updated_spans] - def adapt(self, source: Sequence[Span]) -> Tree[Span]: + def adapt(self, source: Sequence[Span]) -> Tree[AdaptingSpan]: if not isinstance(source, Sequence): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError(f"Expected a sequence of spans, but got {type(source)}") if not source: @@ -315,7 +316,9 @@ def adapt(self, source: Sequence[Span]) -> Tree[Span]: if self.repair_multiple_roots: source = self._repair_multiple_roots(source) - return _TreeLikeGraph.from_spans(source).to_tree(source) + graph = _TreeLikeGraph.from_spans(source) + adapting_spans = [AdaptingSpan.from_span(span, None) for span in source] + return graph.to_tree(adapting_spans) class ToAdaptingSpans(Adapter[Sequence[Span], AdaptingSequence[AdaptingSpan]]): diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 4919b9085..e26560015 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -92,32 +92,32 @@ def traverse(self) -> Iterable[T_co]: # General containers -class Tree(BaseAdaptingSequence[T], Generic[T]): +class Tree(BaseAdaptingSequence[T_co], Generic[T_co]): """This is a generic tree data structure that can be used to represent the structure of a tree. This data structure is immutable. """ - def __init__(self, item: T, children: Sequence[Tree[T]]) -> None: + def __init__(self, item: T_co, children: Sequence[Tree[T_co]]) -> None: self._item = item self._children = children - self._parent: Optional[weakref.ReferenceType[Tree[T]]] = None + self._parent: Optional[weakref.ReferenceType[Tree[T_co]]] = None for child in self._children: child._parent = weakref.ref(self) # type: ignore @property - def item(self) -> T: + def item(self) -> T_co: return self._item @property - def children(self) -> Sequence[Tree[T]]: + def children(self) -> Sequence[Tree[T_co]]: return self._children @property - def parent(self) -> Optional[Tree[T]]: + def parent(self) -> Optional[Tree[T_co]]: return self._parent() if self._parent is not None else None - def traverse(self) -> Iterable[T]: + def traverse(self) -> Iterable[T_co]: yield self._item for child in self._children: yield from child.traverse() @@ -125,18 +125,18 @@ def traverse(self) -> Iterable[T]: def size(self) -> int: return 1 + sum(child.size() for child in self._children) - def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: + def get(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: """Get the index-th item in the tree (O(n) time complexity). I think this is not efficient, but it's seldomly used. 
""" return list(self.traverse())[index] - def map(self, func: Callable[[T], V]) -> Tree[V]: + def map(self, func: Callable[[T_co], V]) -> Tree[V]: """Map a function over all items in the tree.""" return Tree(func(self._item), [child.map(func) for child in self._children]) - def _retain_subtree(self, predicate: Callable[[T], bool]) -> Optional[Tree[T]]: + def _retain_subtree(self, predicate: Callable[[T_co], bool]) -> Optional[Tree[T_co]]: if predicate(self._item): # If the current node satisfies the predicate, retain the subtree return self @@ -148,21 +148,21 @@ def _retain_subtree(self, predicate: Callable[[T], bool]) -> Optional[Tree[T]]: return Tree(self._item, [subtree for subtree in subtrees if subtree is not None]) - def retain(self, predicate: Callable[[T], bool]) -> Tree[T]: + def retain(self, predicate: Callable[[T_co], bool]) -> Tree[T_co]: """Prune the tree by retaining subtrees with root nodes that satisfy the predicate. The root node is always retained. """ return self._retain_subtree(predicate) or Tree(self._item, []) - def prune(self, predicate: Callable[[T], bool]) -> Tree[T]: + def prune(self, predicate: Callable[[T_co], bool]) -> Tree[T_co]: """Prune the tree by removing nodes that satisfy the predicate. The root node is always retained. """ return Tree(self._item, [child.prune(predicate) for child in self._children if not predicate(child._item)]) - def visualize(self, filename: str, item_to_str: Callable[[T], str]) -> None: + def visualize(self, filename: str, item_to_str: Callable[[T_co], str]) -> None: """Render the tree with Graphviz for debugging purposes. Args: @@ -177,7 +177,7 @@ def visualize(self, filename: str, item_to_str: Callable[[T], str]) -> None: dot = graphviz.Digraph(comment="Tree") - def visit(node: Tree[T]): + def visit(node: Tree[T_co]): dot.node(str(id(node)), item_to_str(node.item)) # type: ignore for child in node._children: visit(child) diff --git a/tests/adapter/test_call.py b/tests/adapter/test_call.py index 8bafaaa82..fec5876b4 100644 --- a/tests/adapter/test_call.py +++ b/tests/adapter/test_call.py @@ -4,7 +4,7 @@ from agentlightning.adapter.base import Chain from agentlightning.adapter.call import IdentifyChatCompletionCalls -from agentlightning.adapter.preprocess import RepairMalformedSpans, ToAdaptingSpans +from agentlightning.adapter.preprocess import RepairMalformedSpans, ToAdaptingSpans, ToTree from agentlightning.types.tracer import Span @@ -123,9 +123,13 @@ def test_openai_calls(): ), ] + # r1 = RepairMalformedSpans()(spans) + # r2 = ToTree()(r1) + # r2.visualize(filename="test_openai_calls", item_to_str=lambda span: span.name) + adapter = Chain( RepairMalformedSpans(), - ToAdaptingSpans(), + ToTree(), IdentifyChatCompletionCalls(), ) From 7beafd75110adaead9036f3dea801a061a27436a Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 9 Jan 2026 14:49:30 +0800 Subject: [PATCH 24/41] add tests for preprocess --- agentlightning/adapter/preprocess.py | 175 +++-- tests/adapter/test_preprocess.py | 928 +++++++++++++++++++++++++++ 2 files changed, 1058 insertions(+), 45 deletions(-) create mode 100644 tests/adapter/test_preprocess.py diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index 5a30a502b..731bda89b 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -24,6 +24,17 @@ def default_span_order(span: Span) -> tuple[int, float, float]: + """Return a tuple for sorting spans by sequence ID, then start time, then end time. 
+ + Args: + span: The span to extract ordering keys from. + + Returns: + A tuple of (sequence_id, start_time, end_time) for use as a sort key. + + Raises: + ValueError: If the span has no start or end time set. + """ return (span.sequence_id, span.ensure_start_time(), span.ensure_end_time()) @@ -118,10 +129,14 @@ def from_spans(spans: Sequence[Span], logs_invalid_parent: bool = True) -> _Tree graph.forward_graph[span.parent_id].append(span.span_id) graph.root_ids.discard(span.span_id) graph.parent_map[span.span_id] = span.parent_id - elif logs_invalid_parent: - logger.debug( - f'Span {span.span_id} has an invalid parent ID "{span.parent_id}". The parent will be ignored.' - ) + else: + # Span has invalid parent, treat as root + graph.root_ids.add(span.span_id) + if logs_invalid_parent: + logger.debug( + f'Span {span.span_id} has an invalid parent ID "{span.parent_id}". ' + "The parent will be ignored and the span will be treated as a root." + ) graph.validate_no_cycles() @@ -129,7 +144,17 @@ def from_spans(spans: Sequence[Span], logs_invalid_parent: bool = True) -> _Tree class ToSpans(SequenceAdapter[SpanLike, Span]): - """Normalize the span-like objects (e.g., OpenTelemetry `ReadableSpan`) to [spans][agentlightning.Span].""" + """Normalize span-like objects (e.g., OpenTelemetry `ReadableSpan`) to [Span][agentlightning.Span]. + + This adapter handles conversion from various span formats to the internal Span type. + Native Span objects pass through unchanged, while OpenTelemetry spans are converted + using the provided default values for rollout, attempt, and sequence identifiers. + + Args: + default_rollout_id: Default rollout ID for converted OpenTelemetry spans. + default_attempt_id: Default attempt ID for converted OpenTelemetry spans. + default_sequence_id: Default sequence ID for converted OpenTelemetry spans. + """ def __init__( self, @@ -142,6 +167,14 @@ def __init__( self.default_sequence_id = default_sequence_id def adapt_one(self, source: SpanLike) -> Span: + """Convert a single span-like object to a Span. + + Args: + source: A Span or OpenTelemetry ReadableSpan to convert. + + Returns: + The converted Span object. Native Spans pass through unchanged. + """ if isinstance(source, Span): return source return Span.from_opentelemetry( @@ -153,15 +186,38 @@ def adapt_one(self, source: SpanLike) -> Span: class ToTree(Adapter[Sequence[Span], Tree[AdaptingSpan]]): + """Convert a sequence of spans into a tree structure. + + This adapter organizes flat span sequences into a hierarchical tree based on parent-child + relationships. It can repair various structural issues in the span data: + + - **Bad hierarchy**: Spans that are incorrectly positioned (e.g., dangling spans without + proper parents despite being contained within other spans' time ranges). + - **Multiple roots**: Cases where more than one span has no parent. + + Note: + Spans with invalid parent IDs (referencing non-existent spans) will cause a ValueError. + Use [RepairMalformedSpans][agentlightning.adapter.preprocess.RepairMalformedSpans] with + `ensure_valid_parent_ids=True` before calling this adapter if you need to handle + invalid parent references. + + Args: + repair_bad_hierarchy: Controls hierarchy repair. `"dangling"` repairs only orphaned spans, + `"all"` re-evaluates all span placements, `"none"` skips hierarchy repair. + repair_multiple_roots: If True, creates a virtual root when multiple root spans exist. + + Raises: + TypeError: If the input is not a sequence. 
+ ValueError: If no spans are provided, if any span has an invalid parent ID, + or if the tree cannot be constructed. + """ def __init__( self, repair_bad_hierarchy: Literal["dangling", "all", "none"] = "dangling", - repair_missing_parents: bool = True, repair_multiple_roots: bool = True, ): self.repair_bad_hierarchy = repair_bad_hierarchy - self.repair_missing_parents = repair_missing_parents self.repair_multiple_roots = repair_multiple_roots def _find_eligible_parents( @@ -205,10 +261,10 @@ def _find_eligible_parents( continue spans_to_consider.append(candidate_parent) - # Sort the spans + # Sort the spans: (1) shortest to longest duration, (2) deeper to shallower (prefer more specific ancestors) return sorted( spans_to_consider, - key=lambda span: (span.ensure_end_time() - span.ensure_start_time(), cache_depths[span.span_id]), + key=lambda span: (span.ensure_end_time() - span.ensure_start_time(), -cache_depths[span.span_id]), ) def _repair_bad_hierarchy(self, source: Sequence[Span]) -> Sequence[Span]: @@ -216,6 +272,9 @@ def _repair_bad_hierarchy(self, source: Sequence[Span]) -> Sequence[Span]: This is based on the chronological relationships between start time and end time of spans. """ + if self.repair_bad_hierarchy == "none": + return source + graph = _TreeLikeGraph.from_spans(source) depths = graph.compute_depths() @@ -245,34 +304,24 @@ def _repair_bad_hierarchy(self, source: Sequence[Span]) -> Sequence[Span]: return scan_order - def _repair_missing_parents(self, source: Sequence[Span]) -> Sequence[Span]: + def _validate_parent_ids(self, source: Sequence[Span]) -> None: + """Validate that all parent IDs reference existing spans. + + Raises: + ValueError: If any span references a non-existent parent. + """ valid_span_ids: Set[str] = set(span.span_id for span in source) - parent_to_children: Dict[str, List[Span]] = defaultdict(list) + invalid_refs: List[str] = [] for span in source: if span.parent_id is not None and span.parent_id not in valid_span_ids: - parent_to_children[span.parent_id].append(span) - - created_spans: List[Span] = [] - for parent_id, children in parent_to_children.items(): - child = children[0] - # Create a virtual span for the missing parent - created_spans.append( - Span.from_attributes( - rollout_id=child.rollout_id, - attempt_id=child.attempt_id, - sequence_id=child.sequence_id, - trace_id=child.trace_id, - span_id=parent_id, - parent_id=None, - name=AGL_VIRTUAL, - attributes={}, - start_time=min(child.ensure_start_time() for child in children), - end_time=max(child.ensure_end_time() for child in children), - ) - ) + invalid_refs.append(f"{span.span_id} -> {span.parent_id}") - return [*source, *created_spans] + if invalid_refs: + raise ValueError( + f"Spans reference non-existent parent IDs: {', '.join(invalid_refs)}. " + "Use RepairMalformedSpans with ensure_valid_parent_ids=True to fix this." + ) def _repair_multiple_roots(self, source: Sequence[Span]) -> Sequence[Span]: root_spans = [span for span in source if span.parent_id is None] @@ -301,18 +350,28 @@ def _repair_multiple_roots(self, source: Sequence[Span]) -> Sequence[Span]: return [new_root_span, *updated_spans] def adapt(self, source: Sequence[Span]) -> Tree[AdaptingSpan]: + """Convert a sequence of spans into a tree of AdaptingSpan objects. + + Args: + source: A sequence of Span objects to organize into a tree. + + Returns: + A Tree with AdaptingSpan items representing the hierarchical structure. + + Raises: + TypeError: If source is not a sequence. 
+ ValueError: If source is empty, has invalid parent IDs, or cannot form a valid tree. + """ if not isinstance(source, Sequence): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError(f"Expected a sequence of spans, but got {type(source)}") if not source: raise ValueError("No spans provided to create Tree") + # Validate parent IDs before any processing + self._validate_parent_ids(source) + source = self._repair_bad_hierarchy(source) - # The other repairing steps should be done *after* repairing the bad hierarchy because - # some problems might be fixed during the bad hierarchy repairing. - if self.repair_missing_parents: - source = self._repair_missing_parents(source) - # repair missing parents might have introduced new roots. if self.repair_multiple_roots: source = self._repair_multiple_roots(source) @@ -322,24 +381,41 @@ def adapt(self, source: Sequence[Span]) -> Tree[AdaptingSpan]: class ToAdaptingSpans(Adapter[Sequence[Span], AdaptingSequence[AdaptingSpan]]): - """Sort the spans with sequence ID as the primary key and start time as the secondary key - and end time as the tertiary key.""" + """Convert spans to a sorted AdaptingSequence. + + Sorts spans by sequence ID (primary), start time (secondary), and end time (tertiary), + then wraps each in an AdaptingSpan for use in adaptation pipelines. + """ def adapt(self, source: Sequence[Span]) -> AdaptingSequence[AdaptingSpan]: + """Sort spans and convert to AdaptingSequence. + + Args: + source: A sequence of Span objects to sort and convert. + + Returns: + An AdaptingSequence containing sorted AdaptingSpan objects. + """ sorted_spans = sorted(source, key=default_span_order) return AdaptingSequence([AdaptingSpan.from_span(span, None) for span in sorted_spans]) class RepairMalformedSpans(Adapter[Sequence[Span], Sequence[Span]]): - """The adapter repairs multiple common issues with spans. + """Repair common structural issues in span data. - 1. Repair the end time of the spans by: + This adapter fixes several types of malformed span data: - * Ensuring the end time is greater than the start time. - * Fill the spans with no end time to be the maximum start/end time of all spans. + - **Missing times**: Fills in missing start/end times using the maximum known time. + - **Negative duration**: Adjusts end times that are earlier than start times. + - **Improper nesting**: Expands parent time ranges to contain all children. + - **Invalid parent IDs**: Removes references to non-existent parent spans. - 2. Repair the parent ID. If a span has a parent ID that does not exist in the set of spans, - the parent ID will be set to None. + Spans that don't require repair pass through unchanged (same object reference). + + Args: + ensure_positive_duration: If True, sets end_time = start_time when end < start. + ensure_proper_nesting: If True, expands parent spans to contain children's time ranges. + ensure_valid_parent_ids: If True, sets parent_id to None for orphaned spans. """ def __init__( @@ -425,6 +501,15 @@ def _repair_invalid_parent_ids(self, source: Sequence[Span]) -> List[Span]: return new_spans def adapt(self, source: Sequence[Span]) -> Sequence[Span]: + """Repair malformed spans according to the configured repair options. + + Args: + source: A sequence of Span objects to repair. + + Returns: + A sequence of repaired Span objects. Unmodified spans retain their original + object reference. + """ # This step is always performed first no matter whether the flags are set. 
new_spans = self._repair_start_end_time(source) if self.ensure_proper_nesting: diff --git a/tests/adapter/test_preprocess.py b/tests/adapter/test_preprocess.py new file mode 100644 index 000000000..a0c548e8c --- /dev/null +++ b/tests/adapter/test_preprocess.py @@ -0,0 +1,928 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for the preprocess module adapters.""" + +import itertools +from typing import Any, Dict, List, Optional + +import pytest + +from agentlightning.adapter.preprocess import _TreeLikeGraph # pyright: ignore[reportPrivateUsage] +from agentlightning.adapter.preprocess import ( + RepairMalformedSpans, + ToAdaptingSpans, + ToSpans, + ToTree, + default_span_order, +) +from agentlightning.semconv import AGL_VIRTUAL +from agentlightning.types import Span +from agentlightning.types.adapter import AdaptingSpan + +_SEQ = itertools.count() + + +def make_span( + span_id: str, + name: str, + *, + parent_id: Optional[str], + start_time: float, + end_time: float, + attributes: Optional[Dict[str, Any]] = None, + rollout_id: str = "rollout-1", + attempt_id: str = "attempt-1", + sequence_id: Optional[int] = None, +) -> Span: + """Create a test span with sensible defaults.""" + return Span.from_attributes( + rollout_id=rollout_id, + attempt_id=attempt_id, + sequence_id=sequence_id if sequence_id is not None else next(_SEQ), + trace_id="trace-1", + span_id=span_id, + parent_id=parent_id, + name=name, + attributes=attributes or {}, + start_time=start_time, + end_time=end_time, + ) + + +# Tests for default_span_order + + +def test_default_span_order_by_sequence_id(): + span1 = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0, sequence_id=2) + span2 = make_span("s2", "span", parent_id=None, start_time=0.0, end_time=1.0, sequence_id=1) + spans = [span1, span2] + sorted_spans = sorted(spans, key=default_span_order) + assert [s.span_id for s in sorted_spans] == ["s2", "s1"] + + +def test_default_span_order_by_start_time(): + span1 = make_span("s1", "span", parent_id=None, start_time=2.0, end_time=3.0, sequence_id=0) + span2 = make_span("s2", "span", parent_id=None, start_time=1.0, end_time=3.0, sequence_id=0) + spans = [span1, span2] + sorted_spans = sorted(spans, key=default_span_order) + assert [s.span_id for s in sorted_spans] == ["s2", "s1"] + + +def test_default_span_order_by_end_time(): + span1 = make_span("s1", "span", parent_id=None, start_time=1.0, end_time=3.0, sequence_id=0) + span2 = make_span("s2", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=0) + spans = [span1, span2] + sorted_spans = sorted(spans, key=default_span_order) + assert [s.span_id for s in sorted_spans] == ["s2", "s1"] + + +# Tests for _TreeLikeGraph + + +def test_tree_like_graph_from_spans_creates_correct_graph(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child1 = make_span("child1", "child", parent_id="root", start_time=1.0, end_time=5.0) + child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) + grandchild = make_span("grandchild", "grandchild", parent_id="child1", start_time=2.0, end_time=4.0) + + graph = _TreeLikeGraph.from_spans([root, child1, child2, grandchild]) + + assert graph.root_ids == {"root"} + assert set(graph.forward_graph["root"]) == {"child1", "child2"} + assert graph.forward_graph["child1"] == ["grandchild"] + assert graph.parent_map["child1"] == "root" + assert graph.parent_map["child2"] == "root" + assert graph.parent_map["grandchild"] == "child1" + + +def 
test_tree_like_graph_from_spans_handles_invalid_parent(): + """Spans with invalid parent IDs should be treated as roots.""" + orphan = make_span("orphan", "orphan", parent_id="nonexistent", start_time=0.0, end_time=1.0) + + graph = _TreeLikeGraph.from_spans([orphan]) + + assert graph.root_ids == {"orphan"} + assert "orphan" not in graph.parent_map + + +def test_tree_like_graph_from_spans_multiple_roots(): + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=5.0) + root2 = make_span("root2", "root", parent_id=None, start_time=5.0, end_time=10.0) + + graph = _TreeLikeGraph.from_spans([root1, root2]) + + assert graph.root_ids == {"root1", "root2"} + + +def test_tree_like_graph_compute_depths(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + grandchild = make_span("grandchild", "grandchild", parent_id="child", start_time=2.0, end_time=8.0) + + graph = _TreeLikeGraph.from_spans([root, child, grandchild]) + depths = graph.compute_depths() + + assert depths["root"] == 0 + assert depths["child"] == 1 + assert depths["grandchild"] == 2 + + +def test_tree_like_graph_compute_ancestors(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + grandchild = make_span("grandchild", "grandchild", parent_id="child", start_time=2.0, end_time=8.0) + + graph = _TreeLikeGraph.from_spans([root, child, grandchild]) + ancestors = graph.compute_ancestors() + + assert ancestors["root"] == set() + assert ancestors["child"] == {"root"} + assert ancestors["grandchild"] == {"root", "child"} + + +def test_tree_like_graph_move_subtree(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child1 = make_span("child1", "child", parent_id="root", start_time=1.0, end_time=5.0) + child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) + grandchild = make_span("grandchild", "grandchild", parent_id="child1", start_time=2.0, end_time=4.0) + + graph = _TreeLikeGraph.from_spans([root, child1, child2, grandchild]) + graph.move_subtree("grandchild", "child2") + + assert "grandchild" not in graph.forward_graph["child1"] + assert "grandchild" in graph.forward_graph["child2"] + assert graph.parent_map["grandchild"] == "child2" + + +def test_tree_like_graph_move_subtree_from_root(): + """Moving a root node should remove it from root_ids.""" + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=10.0) + root2 = make_span("root2", "root", parent_id=None, start_time=0.0, end_time=10.0) + + graph = _TreeLikeGraph.from_spans([root1, root2]) + assert "root2" in graph.root_ids + + graph.move_subtree("root2", "root1") + + assert "root2" not in graph.root_ids + assert graph.parent_map["root2"] == "root1" + + +def test_tree_like_graph_to_tree_single_root(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child1 = make_span("child1", "child", parent_id="root", start_time=1.0, end_time=5.0) + child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) + + graph = _TreeLikeGraph.from_spans([root, child1, child2]) + tree = graph.to_tree([root, child1, child2]) + + assert tree.item.span_id == "root" + assert len(tree.children) == 2 + child_ids = {child.item.span_id for child in tree.children} + assert child_ids == {"child1", "child2"} + + +def 
test_tree_like_graph_to_tree_multiple_roots_raises(): + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=5.0) + root2 = make_span("root2", "root", parent_id=None, start_time=5.0, end_time=10.0) + + graph = _TreeLikeGraph.from_spans([root1, root2]) + + with pytest.raises(ValueError, match="multiple or no roots"): + graph.to_tree([root1, root2]) + + +# Tests for ToSpans + + +def test_to_spans_pass_through_span(): + """Span objects should pass through unchanged.""" + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + adapter = ToSpans() + + result = adapter.adapt_one(span) + + assert result is span + + +def test_to_spans_default_values(): + """Adapter should use default values for rollout_id, attempt_id, and sequence_id.""" + adapter = ToSpans( + default_rollout_id="my-rollout", + default_attempt_id="my-attempt", + default_sequence_id=42, + ) + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + + result = adapter.adapt_one(span) + + # Span passes through unchanged, defaults only apply to OpenTelemetry spans + assert result is span + + +# Tests for ToTree + + +def test_to_tree_basic_creation(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + spans = [root, child] + + adapter = ToTree() + tree = adapter.adapt(spans) + + assert tree.item.span_id == "root" + assert len(tree.children) == 1 + assert tree.children[0].item.span_id == "child" + + +def test_to_tree_empty_spans_raises(): + adapter = ToTree() + + with pytest.raises(ValueError, match="No spans provided"): + adapter.adapt([]) + + +def test_to_tree_non_sequence_raises(): + adapter = ToTree() + + # String is technically a sequence but will fail when trying to access span attributes + with pytest.raises(AttributeError): + adapter.adapt("not a sequence") # type: ignore + + +def test_to_tree_repair_multiple_roots(): + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=5.0) + root2 = make_span("root2", "root", parent_id=None, start_time=5.0, end_time=10.0) + spans = [root1, root2] + + adapter = ToTree(repair_multiple_roots=True) + tree = adapter.adapt(spans) + + assert tree.item.name == AGL_VIRTUAL + assert len(tree.children) == 2 + child_ids = {child.item.span_id for child in tree.children} + assert child_ids == {"root1", "root2"} + + +def test_to_tree_repair_multiple_roots_disabled(): + root1 = make_span("root1", "root", parent_id=None, start_time=0.0, end_time=5.0) + root2 = make_span("root2", "root", parent_id=None, start_time=5.0, end_time=10.0) + spans = [root1, root2] + + adapter = ToTree(repair_multiple_roots=False) + + with pytest.raises(ValueError, match="multiple or no roots"): + adapter.adapt(spans) + + +def test_to_tree_invalid_parent_raises_error(): + """Spans with invalid parent IDs should raise ValueError.""" + orphan = make_span("orphan", "orphan", parent_id="missing-parent", start_time=0.0, end_time=1.0) + spans = [orphan] + + adapter = ToTree() + + with pytest.raises(ValueError, match="non-existent parent IDs"): + adapter.adapt(spans) + + +def test_to_tree_with_repaired_invalid_parents(): + """Using RepairMalformedSpans before ToTree should handle invalid parent IDs.""" + orphan = make_span("orphan", "orphan", parent_id="missing-parent", start_time=0.0, end_time=1.0) + spans = [orphan] + + # First repair invalid parent IDs + repair_adapter = RepairMalformedSpans(ensure_valid_parent_ids=True) + 
repaired = repair_adapter.adapt(spans) + + # Now ToTree should work + tree_adapter = ToTree() + tree = tree_adapter.adapt(repaired) + + # Orphan becomes root since parent_id was set to None + assert tree.item.span_id == "orphan" + + +def test_to_tree_repair_bad_hierarchy_dangling(): + """Dangling spans should be re-attached based on time containment.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + container = make_span("container", "container", parent_id="root", start_time=1.0, end_time=9.0) + # Dangling span (no parent) that should fit inside container + dangling = make_span("dangling", "dangling", parent_id=None, start_time=2.0, end_time=8.0) + spans = [root, container, dangling] + + adapter = ToTree(repair_bad_hierarchy="dangling") + tree = adapter.adapt(spans) + + # Dangling should be moved under container (best fit by time) + container_node = next(c for c in tree.children if c.item.span_id == "container") + dangling_in_container = any(c.item.span_id == "dangling" for c in container_node.children) + assert dangling_in_container + + +def test_to_tree_repair_bad_hierarchy_none(): + """When repair_bad_hierarchy is 'none', hierarchy is not repaired.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + dangling = make_span("dangling", "dangling", parent_id=None, start_time=2.0, end_time=8.0) + spans = [root, dangling] + + adapter = ToTree(repair_bad_hierarchy="none", repair_multiple_roots=True) + tree = adapter.adapt(spans) + + # Both should be roots under virtual root + assert tree.item.name == AGL_VIRTUAL + child_ids = {c.item.span_id for c in tree.children} + assert child_ids == {"root", "dangling"} + + +def test_to_tree_children_sorted_by_time(): + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child_late = make_span("child-late", "child", parent_id="root", start_time=5.0, end_time=9.0, sequence_id=0) + child_early = make_span("child-early", "child", parent_id="root", start_time=1.0, end_time=4.0, sequence_id=0) + spans = [root, child_late, child_early] + + adapter = ToTree() + tree = adapter.adapt(spans) + + child_ids = [c.item.span_id for c in tree.children] + assert child_ids == ["child-early", "child-late"] + + +def test_to_tree_adapting_span_properties(): + """Tree should contain AdaptingSpan instances with proper container references.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + spans = [root, child] + + adapter = ToTree() + tree = adapter.adapt(spans) + + assert isinstance(tree.item, AdaptingSpan) + assert isinstance(tree.children[0].item, AdaptingSpan) + + +# Tests for ToAdaptingSpans + + +def test_to_adapting_spans_sorts_by_default_order(): + span1 = make_span("s1", "span", parent_id=None, start_time=2.0, end_time=3.0, sequence_id=0) + span2 = make_span("s2", "span", parent_id=None, start_time=1.0, end_time=3.0, sequence_id=0) + spans = [span1, span2] + + adapter = ToAdaptingSpans() + result = adapter.adapt(spans) + + span_ids = [s.span_id for s in result] + assert span_ids == ["s2", "s1"] + + +def test_to_adapting_spans_returns_adapting_sequence(): + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + spans = [span] + + adapter = ToAdaptingSpans() + result = adapter.adapt(spans) + + assert len(result) == 1 + assert isinstance(result[0], AdaptingSpan) + + +# Tests for RepairMalformedSpans + + +def 
test_repair_malformed_spans_missing_start_time(): + span = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s1", + parent_id=None, + name="span", + attributes={}, + start_time=None, + end_time=5.0, + ) + spans = [span] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # Start time should be set to max of all times (5.0) + assert result[0].start_time == 5.0 + + +def test_repair_malformed_spans_missing_end_time(): + span = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s1", + parent_id=None, + name="span", + attributes={}, + start_time=1.0, + end_time=None, + ) + spans = [span] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # End time should be set to max of all times (1.0) + assert result[0].end_time == 1.0 + + +def test_repair_malformed_spans_both_missing_times(): + span = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s1", + parent_id=None, + name="span", + attributes={}, + start_time=None, + end_time=None, + ) + spans = [span] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # Both should be set to current time (they should be equal and non-None) + assert result[0].start_time is not None + assert result[0].end_time is not None + assert result[0].start_time == result[0].end_time + + +def test_repair_malformed_spans_negative_duration(): + """When end_time < start_time, end_time should be set to start_time.""" + span = make_span("s1", "span", parent_id=None, start_time=5.0, end_time=3.0) + spans = [span] + + adapter = RepairMalformedSpans(ensure_positive_duration=True) + result = adapter.adapt(spans) + + assert result[0].end_time == 5.0 + + +def test_repair_malformed_spans_no_repair_negative_duration_when_disabled(): + span = make_span("s1", "span", parent_id=None, start_time=5.0, end_time=3.0) + spans = [span] + + adapter = RepairMalformedSpans(ensure_positive_duration=False) + result = adapter.adapt(spans) + + assert result[0].end_time == 3.0 + + +def test_repair_malformed_spans_invalid_parent_ids(): + span = make_span("s1", "span", parent_id="nonexistent", start_time=0.0, end_time=1.0) + spans = [span] + + adapter = RepairMalformedSpans(ensure_valid_parent_ids=True) + result = adapter.adapt(spans) + + assert result[0].parent_id is None + + +def test_repair_malformed_spans_no_repair_invalid_parent_ids_when_disabled(): + span = make_span("s1", "span", parent_id="nonexistent", start_time=0.0, end_time=1.0) + spans = [span] + + adapter = RepairMalformedSpans(ensure_valid_parent_ids=False) + result = adapter.adapt(spans) + + assert result[0].parent_id == "nonexistent" + + +def test_repair_malformed_spans_proper_nesting(): + """Parent span's time range should be expanded to contain all children.""" + parent = make_span("parent", "parent", parent_id=None, start_time=2.0, end_time=8.0) + child = make_span("child", "child", parent_id="parent", start_time=1.0, end_time=9.0) + spans = [parent, child] + + adapter = RepairMalformedSpans(ensure_proper_nesting=True) + result = adapter.adapt(spans) + + parent_result = next(s for s in result if s.span_id == "parent") + # Parent's time should be expanded to contain child + assert parent_result.start_time == 1.0 + assert parent_result.end_time == 9.0 + + +def test_repair_malformed_spans_no_repair_proper_nesting_when_disabled(): + parent = make_span("parent", "parent", parent_id=None, start_time=2.0, end_time=8.0) + 
child = make_span("child", "child", parent_id="parent", start_time=1.0, end_time=9.0) + spans = [parent, child] + + adapter = RepairMalformedSpans(ensure_proper_nesting=False) + result = adapter.adapt(spans) + + parent_result = next(s for s in result if s.span_id == "parent") + assert parent_result.start_time == 2.0 + assert parent_result.end_time == 8.0 + + +def test_repair_malformed_spans_unchanged_pass_through(): + """Spans that don't need repair should not be modified.""" + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + spans = [span] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # The span object should be the same (not copied) + assert result[0] is span + + +# Integration tests + + +def test_integration_complex_tree_with_repairs(): + """Test a complex scenario with multiple issues that need repair.""" + # Root span + root = make_span("root", "session", parent_id=None, start_time=0.0, end_time=100.0) + + # Agent span with misaligned timing + agent = make_span("agent", "agent.node", parent_id="root", start_time=5.0, end_time=90.0) + + # LLM span that's a sibling of agent but should be under it (dangling repair) + llm = make_span("llm", "openai.chat", parent_id="root", start_time=10.0, end_time=20.0) + + # Orphan span with missing parent + orphan = make_span("orphan", "tool.call", parent_id="missing", start_time=30.0, end_time=40.0) + + spans = [root, agent, llm, orphan] + + # First repair invalid parent IDs + repair_adapter = RepairMalformedSpans(ensure_valid_parent_ids=True) + repaired = repair_adapter.adapt(spans) + + # Apply hierarchy repairs and convert to tree + tree_adapter = ToTree( + repair_bad_hierarchy="dangling", + repair_multiple_roots=True, + ) + tree = tree_adapter.adapt(repaired) + + # The tree should be properly structured + assert tree.size() >= 4 # At least the original spans + + +def test_integration_adapting_spans_after_tree(): + """AdaptingSpans should work correctly on tree output.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + spans = [root, child] + + tree_adapter = ToTree() + tree = tree_adapter.adapt(spans) + + # Traverse and verify all items are AdaptingSpans + for span in tree.traverse(): + assert isinstance(span, AdaptingSpan) + + +def test_integration_repair_then_tree(): + """RepairMalformedSpans followed by ToTree should work correctly.""" + parent = make_span("parent", "parent", parent_id=None, start_time=5.0, end_time=3.0) # Invalid: end < start + child = make_span("child", "child", parent_id="parent", start_time=1.0, end_time=2.0) + + # First repair the spans + repair_adapter = RepairMalformedSpans() + repaired = repair_adapter.adapt([parent, child]) + + # Then create tree + tree_adapter = ToTree() + tree = tree_adapter.adapt(repaired) + + # Parent should have repaired time and contain child + assert tree.item.start_time is not None + assert tree.item.end_time is not None + assert tree.item.end_time >= tree.item.start_time + + +# Edge case tests + + +def test_edge_case_single_span_tree(): + span = make_span("only", "only", parent_id=None, start_time=0.0, end_time=1.0) + + adapter = ToTree() + tree = adapter.adapt([span]) + + assert tree.item.span_id == "only" + assert len(tree.children) == 0 + + +def test_edge_case_deep_tree(): + """Test a deeply nested tree.""" + spans: List[Span] = [] + for i in range(10): + parent_id = f"span-{i-1}" if i > 0 else None + spans.append( 
+ make_span( + f"span-{i}", + f"level-{i}", + parent_id=parent_id, + start_time=float(i), + end_time=float(20 - i), + ) + ) + + adapter = ToTree() + tree = adapter.adapt(spans) + + assert tree.size() == 10 + + # Verify depth + current = tree + depth = 0 + while current.children: + depth += 1 + current = current.children[0] + assert depth == 9 + + +def test_edge_case_wide_tree(): + """Test a tree with many siblings.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=100.0) + children = [ + make_span(f"child-{i}", "child", parent_id="root", start_time=float(i), end_time=float(i + 1)) + for i in range(20) + ] + + adapter = ToTree() + tree = adapter.adapt([root] + children) + + assert tree.item.span_id == "root" + assert len(tree.children) == 20 + + +def test_edge_case_spans_with_same_times(): + """Test handling of spans with identical timestamps. + + When repair_bad_hierarchy is disabled, spans with same times become multiple roots. + """ + spans = [ + make_span(f"span-{i}", "span", parent_id=None, start_time=0.0, end_time=1.0, sequence_id=i) for i in range(5) + ] + + # With repair_bad_hierarchy="none", no hierarchy repair happens + adapter = ToTree(repair_bad_hierarchy="none", repair_multiple_roots=True) + tree = adapter.adapt(spans) + + # All should become children of virtual root + assert tree.item.name == AGL_VIRTUAL + assert len(tree.children) == 5 + + +def test_edge_case_repair_preserves_order(): + """Repaired spans should maintain their original order where possible.""" + spans = [ + make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0, sequence_id=0), + make_span("s2", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=1), + make_span("s3", "span", parent_id=None, start_time=2.0, end_time=3.0, sequence_id=2), + ] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # Order should be preserved + assert [s.span_id for s in result] == ["s1", "s2", "s3"] + + +# Additional corner case tests + + +def test_tree_like_graph_full_cycle_no_roots(): + """Graph where all nodes form a cycle (no roots) should raise ValueError.""" + # Create spans with a full cycle: A -> B -> C -> A (no roots) + span_a = make_span("A", "span", parent_id="C", start_time=0.0, end_time=10.0) + span_b = make_span("B", "span", parent_id="A", start_time=1.0, end_time=9.0) + span_c = make_span("C", "span", parent_id="B", start_time=2.0, end_time=8.0) + + with pytest.raises(ValueError, match="not reachable from the roots"): + _TreeLikeGraph.from_spans([span_a, span_b, span_c]) + + +def test_tree_like_graph_cycle_in_subtree(): + """Graph with a cycle in a subtree should raise ValueError.""" + # Create a graph with cycle: root -> A -> B -> A + # This tests the "Cycle detected" error path in validate_no_cycles + graph = _TreeLikeGraph() + graph.root_ids.add("root") + graph.add_edge("root", "A") + graph.add_edge("A", "B") + graph.add_edge("B", "A") # Creates cycle: A -> B -> A + + with pytest.raises(ValueError, match="Cycle detected"): + graph.validate_no_cycles() + + +def test_tree_like_graph_add_edge_direct(): + """Test add_edge method directly.""" + graph = _TreeLikeGraph() + graph.root_ids.add("root") + graph.add_edge("root", "child1") + graph.add_edge("root", "child2") + graph.add_edge("child1", "grandchild") + + assert "child1" in graph.forward_graph["root"] + assert "child2" in graph.forward_graph["root"] + assert "grandchild" in graph.forward_graph["child1"] + + +def test_to_tree_repair_bad_hierarchy_all(): + """Test repair_bad_hierarchy='all' 
mode re-evaluates all span placements.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=100.0) + # child1 is correctly placed under root + child1 = make_span("child1", "child", parent_id="root", start_time=10.0, end_time=90.0) + # child2 is incorrectly a sibling of child1 but should be inside child1 based on time + child2 = make_span("child2", "child", parent_id="root", start_time=20.0, end_time=80.0) + spans = [root, child1, child2] + + adapter = ToTree(repair_bad_hierarchy="all") + tree = adapter.adapt(spans) + + # With "all" mode, child2 should be moved under child1 (tighter fit) + child1_node = next(c for c in tree.children if c.item.span_id == "child1") + child2_in_child1 = any(c.item.span_id == "child2" for c in child1_node.children) + assert child2_in_child1 + + +def test_repair_malformed_spans_empty_sequence(): + """Empty sequence should pass through unchanged.""" + adapter = RepairMalformedSpans() + result = adapter.adapt([]) + assert result == [] + + +def test_to_tree_invalid_parent_error_message(): + """Test that invalid parent error message includes helpful guidance.""" + grandchild = make_span("grandchild", "gc", parent_id="missing-parent", start_time=2.0, end_time=3.0) + spans = [grandchild] + + adapter = ToTree() + + with pytest.raises(ValueError, match="RepairMalformedSpans"): + adapter.adapt(spans) + + +def test_default_span_order_with_tied_values(): + """Spans with identical ordering keys should maintain stable sort.""" + span1 = make_span("s1", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=0) + span2 = make_span("s2", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=0) + span3 = make_span("s3", "span", parent_id=None, start_time=1.0, end_time=2.0, sequence_id=0) + spans = [span1, span2, span3] + + # Default order should be deterministic + order1 = [default_span_order(s) for s in spans] + order2 = [default_span_order(s) for s in spans] + assert order1 == order2 + + +def test_to_tree_depth_based_parent_selection(): + """Deeper eligible parents should be preferred over shallower ones. + + This verifies the fix for the depth sorting in _find_eligible_parents. 
+ """ + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=100.0) + # Two nested containers with identical durations + container1 = make_span("container1", "c1", parent_id="root", start_time=10.0, end_time=90.0) + container2 = make_span("container2", "c2", parent_id="container1", start_time=10.0, end_time=90.0) + # Dangling span that fits in both containers + dangling = make_span("dangling", "d", parent_id=None, start_time=20.0, end_time=80.0) + spans = [root, container1, container2, dangling] + + adapter = ToTree(repair_bad_hierarchy="dangling") + tree = adapter.adapt(spans) + + # Dangling should be placed under container2 (deeper) rather than container1 + container1_node = next(c for c in tree.children if c.item.span_id == "container1") + container2_node = next(c for c in container1_node.children if c.item.span_id == "container2") + dangling_in_container2 = any(c.item.span_id == "dangling" for c in container2_node.children) + assert dangling_in_container2, "Dangling span should be placed under the deeper container" + + +def test_to_tree_repair_bad_hierarchy_prefers_shorter_duration(): + """When multiple parents are eligible, prefer shorter duration (tighter fit).""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=100.0) + # Wide container (duration 80) + wide = make_span("wide", "wide", parent_id="root", start_time=10.0, end_time=90.0) + # Narrow container (duration 40) - should be preferred + narrow = make_span("narrow", "narrow", parent_id="root", start_time=25.0, end_time=65.0) + # Dangling span that fits in both + dangling = make_span("dangling", "d", parent_id=None, start_time=30.0, end_time=60.0) + spans = [root, wide, narrow, dangling] + + adapter = ToTree(repair_bad_hierarchy="dangling") + tree = adapter.adapt(spans) + + # Dangling should be placed under narrow (shorter duration = tighter fit) + narrow_node = next(c for c in tree.children if c.item.span_id == "narrow") + dangling_in_narrow = any(c.item.span_id == "dangling" for c in narrow_node.children) + assert dangling_in_narrow, "Dangling span should be placed under the tighter-fitting container" + + +def test_repair_nesting_deep_hierarchy(): + """Test nesting repair with deeply nested hierarchy.""" + # Create hierarchy where child times exceed parent times at multiple levels + root = make_span("root", "root", parent_id=None, start_time=5.0, end_time=15.0) + child = make_span("child", "child", parent_id="root", start_time=4.0, end_time=16.0) + grandchild = make_span("grandchild", "gc", parent_id="child", start_time=3.0, end_time=17.0) + great_grandchild = make_span("great-gc", "ggc", parent_id="grandchild", start_time=2.0, end_time=18.0) + spans = [root, child, grandchild, great_grandchild] + + adapter = RepairMalformedSpans(ensure_proper_nesting=True) + result = adapter.adapt(spans) + + # All ancestors should be expanded to contain descendants + root_result = next(s for s in result if s.span_id == "root") + assert root_result.start_time == 2.0 + assert root_result.end_time == 18.0 + + +def test_to_adapting_spans_empty_sequence(): + """Empty sequence should return empty AdaptingSequence.""" + adapter = ToAdaptingSpans() + result = adapter.adapt([]) + assert len(result) == 0 + + +def test_tree_like_graph_move_subtree_to_same_parent(): + """Moving subtree to the same parent should be a no-op.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child = make_span("child", "child", parent_id="root", start_time=1.0, end_time=9.0) + + graph = 
_TreeLikeGraph.from_spans([root, child]) + + # Move child to root (its current parent) + graph.move_subtree("child", "root") + + # Should still have the same structure + assert graph.parent_map["child"] == "root" + assert "child" in graph.forward_graph["root"] + + +def test_repair_malformed_spans_uses_max_time_for_missing(): + """Missing times should be filled with max time from other spans.""" + span_with_times = make_span("s1", "span", parent_id=None, start_time=5.0, end_time=10.0) + span_missing_start = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s2", + parent_id=None, + name="span", + attributes={}, + start_time=None, + end_time=7.0, + ) + span_missing_end = Span.from_attributes( + rollout_id="r1", + attempt_id="a1", + sequence_id=0, + trace_id="t1", + span_id="s3", + parent_id=None, + name="span", + attributes={}, + start_time=3.0, + end_time=None, + ) + spans = [span_with_times, span_missing_start, span_missing_end] + + adapter = RepairMalformedSpans() + result = adapter.adapt(spans) + + # Max time across all spans is 10.0 + s2_result = next(s for s in result if s.span_id == "s2") + s3_result = next(s for s in result if s.span_id == "s3") + assert s2_result.start_time == 10.0 # Filled with max + assert s3_result.end_time == 10.0 # Filled with max + + +def test_to_tree_no_repair_needed(): + """Test that properly structured spans don't get modified.""" + root = make_span("root", "root", parent_id=None, start_time=0.0, end_time=10.0) + child1 = make_span("child1", "child", parent_id="root", start_time=1.0, end_time=4.0) + child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) + grandchild = make_span("grandchild", "gc", parent_id="child1", start_time=2.0, end_time=3.0) + spans = [root, child1, child2, grandchild] + + adapter = ToTree() + tree = adapter.adapt(spans) + + # Verify structure is preserved + assert tree.item.span_id == "root" + assert len(tree.children) == 2 + child1_node = next(c for c in tree.children if c.item.span_id == "child1") + assert len(child1_node.children) == 1 + assert child1_node.children[0].item.span_id == "grandchild" From 3767c6d2b129b41414d70c95d50e151f42323b44 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 9 Jan 2026 15:27:45 +0800 Subject: [PATCH 25/41] add annotation tests --- agentlightning/adapter/annotation.py | 6 +- tests/adapter/test_annotation.py | 1078 ++++++++++++++++++++++++++ 2 files changed, 1082 insertions(+), 2 deletions(-) create mode 100644 tests/adapter/test_annotation.py diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index 94c6a6611..f82611703 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -273,7 +273,8 @@ def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSeque if self.mode == "include": return source.retain(lambda span: span in linked_spans) else: - return source.prune(lambda span: span not in linked_spans) + # prune removes items where predicate is True, so we remove linked spans + return source.prune(lambda span: span in linked_spans) class RepairMissingLinks(SequenceAdapter[AdaptingSpan, AdaptingSpan]): @@ -341,7 +342,8 @@ def visit(node: Tree[AdaptingSpan]) -> Iterable[Sequence[AdaptingSpan]]: yield from visit(source) elif self.candidate_scope == "all": - return sorted(list(source), key=lambda span: default_span_order(span)) + # Return as a single group containing all spans sorted by default order + yield 
sorted(list(source), key=lambda span: default_span_order(span)) else: raise ValueError(f"Invalid candidate scope: {self.candidate_scope}") diff --git a/tests/adapter/test_annotation.py b/tests/adapter/test_annotation.py new file mode 100644 index 000000000..534bbe624 --- /dev/null +++ b/tests/adapter/test_annotation.py @@ -0,0 +1,1078 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for the annotation module adapters.""" + +import itertools +from typing import Any, Dict, Optional + +import pytest + +from agentlightning.adapter.annotation import ( + IdentifyAnnotations, + RepairMissingLinks, + SelectByAnnotation, +) +from agentlightning.semconv import ( + AGL_ANNOTATION, + AGL_EXCEPTION, + AGL_MESSAGE, + AGL_OBJECT, + AGL_OPERATION, + LightningSpanAttributes, + LinkPydanticModel, +) +from agentlightning.types import Span +from agentlightning.types.adapter import ( + AdaptingSequence, + AdaptingSpan, + AgentAnnotation, + ExceptionAnnotation, + GeneralAnnotation, + MessageAnnotation, + OperationAnnotation, + Tree, +) + +_SEQ = itertools.count() + + +def make_span( + span_id: str, + name: str, + *, + parent_id: Optional[str], + start_time: float, + end_time: float, + attributes: Optional[Dict[str, Any]] = None, + rollout_id: str = "rollout-1", + attempt_id: str = "attempt-1", + sequence_id: Optional[int] = None, +) -> Span: + """Create a test span with sensible defaults.""" + return Span.from_attributes( + rollout_id=rollout_id, + attempt_id=attempt_id, + sequence_id=sequence_id if sequence_id is not None else next(_SEQ), + trace_id="trace-1", + span_id=span_id, + parent_id=parent_id, + name=name, + attributes=attributes or {}, + start_time=start_time, + end_time=end_time, + ) + + +def make_adapting_span( + span_id: str, + name: str, + *, + parent_id: Optional[str] = None, + start_time: float = 0.0, + end_time: float = 1.0, + attributes: Optional[Dict[str, Any]] = None, + data: Any = None, +) -> AdaptingSpan: + """Create a test AdaptingSpan with sensible defaults.""" + span = make_span( + span_id=span_id, + name=name, + parent_id=parent_id, + start_time=start_time, + end_time=end_time, + attributes=attributes, + ) + return AdaptingSpan.from_span(span, data) + + +# Tests for IdentifyAnnotations._filter_custom_attributes + + +def test_filter_custom_attributes_filters_reserved_fields() -> None: + """Reserved fields should be filtered out.""" + adapter = IdentifyAnnotations() + attributes = { + LightningSpanAttributes.REWARD.value: 1.0, + LightningSpanAttributes.TAG.value: ["tag1"], + "custom.field": "value", + } + + result = adapter._filter_custom_attributes(attributes) # pyright: ignore[reportPrivateUsage] + + assert "custom.field" in result + assert LightningSpanAttributes.REWARD.value not in result + assert LightningSpanAttributes.TAG.value not in result + + +def test_filter_custom_attributes_filters_nested_reserved_fields() -> None: + """Nested reserved fields are only filtered when base field is present.""" + adapter = IdentifyAnnotations() + # When the base field IS present, nested fields are also filtered + attributes = { + LightningSpanAttributes.REWARD.value: "base", + f"{LightningSpanAttributes.REWARD.value}.0.value": 1.0, + f"{LightningSpanAttributes.REWARD.value}.0.name": "quality", + "custom.field": "value", + } + + result = adapter._filter_custom_attributes(attributes) # pyright: ignore[reportPrivateUsage] + + assert "custom.field" in result + assert len(result) == 1 + + +def test_filter_custom_attributes_nested_fields_not_filtered_without_base() -> None: + """Nested 
reserved fields pass through if base field is not present. + + Note: This is the current behavior - nested fields like 'agentlightning.reward.0.value' + are NOT filtered unless the base field 'agentlightning.reward' is also present. + """ + adapter = IdentifyAnnotations() + attributes = { + f"{LightningSpanAttributes.REWARD.value}.0.value": 1.0, + "custom.field": "value", + } + + result = adapter._filter_custom_attributes(attributes) # pyright: ignore[reportPrivateUsage] + + # Nested fields pass through since base field is not present + assert len(result) == 2 + + +def test_filter_custom_attributes_preserves_custom_fields() -> None: + """Custom fields not matching reserved prefixes should be preserved.""" + adapter = IdentifyAnnotations() + attributes = { + "my.custom.attribute": "value1", + "another_attribute": "value2", + } + + result = adapter._filter_custom_attributes(attributes) # pyright: ignore[reportPrivateUsage] + + assert result == attributes + + +# Tests for IdentifyAnnotations.extract_links + + +def test_extract_links_extracts_valid_links() -> None: + """Valid links should be extracted from span attributes.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "span", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.LINK.value}.0.key_match": "span_id", + f"{LightningSpanAttributes.LINK.value}.0.value_match": "target-span", + }, + ) + + links = adapter.extract_links(span) + + assert len(links) == 1 + assert links[0].key_match == "span_id" + assert links[0].value_match == "target-span" + + +def test_extract_links_returns_empty_for_no_links() -> None: + """Should return empty list when no links are present.""" + adapter = IdentifyAnnotations() + span = make_span("s1", "span", parent_id=None, start_time=0.0, end_time=1.0) + + links = adapter.extract_links(span) + + assert links == [] + + +def test_extract_links_handles_malformed_links() -> None: + """Malformed links should return empty list without raising.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "span", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.LINK.value}.0.invalid_field": "value", + }, + ) + + # Should not raise, returns empty list + links = adapter.extract_links(span) + assert links == [] + + +# Tests for IdentifyAnnotations.identify_general + + +def test_identify_general_creates_general_annotation() -> None: + """Should create GeneralAnnotation from annotation span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_ANNOTATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "custom.field": "value", + # Need to provide empty list for tags since extract_tags_from_attributes + # doesn't handle missing tags gracefully + f"{LightningSpanAttributes.TAG.value}.0": "test-tag", + }, + ) + + result = adapter.identify_general(span) + + assert result is not None + assert result.annotation_type == "general" + # custom.field is preserved, tag is filtered out + assert "custom.field" in result.custom_fields + + +def test_identify_general_extracts_rewards() -> None: + """Should extract rewards from annotation span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_ANNOTATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.REWARD.value}.0.name": "quality", + f"{LightningSpanAttributes.REWARD.value}.0.value": 0.9, + f"{LightningSpanAttributes.TAG.value}.0": "test-tag", + }, + ) + + result = 
adapter.identify_general(span) + + assert result is not None + assert len(result.rewards) == 1 + assert result.rewards[0].name == "quality" + assert result.rewards[0].value == 0.9 + assert result.primary_reward == 0.9 + + +def test_identify_general_extracts_tags() -> None: + """Should extract tags from annotation span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_ANNOTATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.TAG.value}.0": "important", + f"{LightningSpanAttributes.TAG.value}.1": "reviewed", + }, + ) + + result = adapter.identify_general(span) + + assert result is not None + assert "important" in result.tags + assert "reviewed" in result.tags + + +# Tests for IdentifyAnnotations.identify_message + + +def test_identify_message_creates_message_annotation() -> None: + """Should create MessageAnnotation from message span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_MESSAGE, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.MESSAGE_BODY.value: "Hello, world!", + }, + ) + + result = adapter.identify_message(span) + + assert result is not None + assert result.annotation_type == "message" + assert result.message == "Hello, world!" + + +def test_identify_message_returns_none_for_missing_body() -> None: + """Should return None when message body is missing.""" + adapter = IdentifyAnnotations() + span = make_span("s1", AGL_MESSAGE, parent_id=None, start_time=0.0, end_time=1.0) + + result = adapter.identify_message(span) + + assert result is None + + +# Tests for IdentifyAnnotations.identify_object + + +def test_identify_object_creates_object_annotation_from_json() -> None: + """Should create ObjectAnnotation from JSON object.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OBJECT, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OBJECT_JSON.value: '{"key": "value"}', + }, + ) + + result = adapter.identify_object(span) + + assert result is not None + assert result.annotation_type == "object" + assert result.object == {"key": "value"} + + +def test_identify_object_creates_object_annotation_from_literal() -> None: + """Should create ObjectAnnotation from literal value.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OBJECT, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OBJECT_LITERAL.value: "simple string", + }, + ) + + result = adapter.identify_object(span) + + assert result is not None + assert result.annotation_type == "object" + + +def test_identify_object_handles_invalid_json() -> None: + """Should handle invalid JSON gracefully.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OBJECT, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OBJECT_JSON.value: "invalid json {", + }, + ) + + # Should not raise, returns None due to error handling + result = adapter.identify_object(span) + assert result is None + + +# Tests for IdentifyAnnotations.identify_exception + + +def test_identify_exception_creates_exception_annotation() -> None: + """Should create ExceptionAnnotation from exception span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_EXCEPTION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "exception.type": "ValueError", + "exception.message": "Invalid input", + "exception.stacktrace": "Traceback...", 
+ }, + ) + + result = adapter.identify_exception(span) + + assert result is not None + assert result.annotation_type == "exception" + assert result.type == "ValueError" + assert result.message == "Invalid input" + assert result.stacktrace == "Traceback..." + + +def test_identify_exception_uses_defaults_for_missing_fields() -> None: + """Should use default values when exception fields are missing.""" + adapter = IdentifyAnnotations() + span = make_span("s1", AGL_EXCEPTION, parent_id=None, start_time=0.0, end_time=1.0) + + result = adapter.identify_exception(span) + + assert result is not None + assert result.type == "UnknownException" + assert result.message == "" + + +# Tests for IdentifyAnnotations.identify_operation + + +def test_identify_operation_creates_operation_annotation() -> None: + """Should create OperationAnnotation from operation span.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OPERATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OPERATION_NAME.value: "process_data", + LightningSpanAttributes.OPERATION_INPUT.value: "input_value", + LightningSpanAttributes.OPERATION_OUTPUT.value: "output_value", + }, + ) + + result = adapter.identify_operation(span) + + assert result is not None + assert result.annotation_type == "operation" + assert result.name == "process_data" + assert result.input == "input_value" + assert result.output == "output_value" + + +def test_identify_operation_extracts_nested_input_output() -> None: + """Should extract nested input/output from flattened attributes.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OPERATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OPERATION_NAME.value: "process_data", + f"{LightningSpanAttributes.OPERATION_INPUT.value}.arg1": "value1", + f"{LightningSpanAttributes.OPERATION_INPUT.value}.arg2": "value2", + }, + ) + + result = adapter.identify_operation(span) + + assert result is not None + assert result.input == {"arg1": "value1", "arg2": "value2"} + + +# Tests for IdentifyAnnotations.detect_agent_annotation + + +def test_detect_agent_annotation_detects_otel_agent_span() -> None: + """Should detect agent from OpenTelemetry agent spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "agent.run", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "agent.name": "MyAgent", + "agent.id": "agent-123", + "agent.description": "A helpful agent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.annotation_type == "agent" + assert result.name == "MyAgent" + assert result.id == "agent-123" + assert result.description == "A helpful agent" + + +def test_detect_agent_annotation_detects_agentops_agent() -> None: + """Should detect agent from AgentOps spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "some.operation", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "agentops.span.kind": "agent", + "operation.name": "AgentOpsAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "AgentOpsAgent" + + +def test_detect_agent_annotation_detects_autogen_agent() -> None: + """Should detect agent from Autogen spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "autogen.task", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "recipient_agent_type": 
"AssistantAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "AssistantAgent" + + +def test_detect_agent_annotation_detects_langgraph_agent() -> None: + """Should detect agent from LangGraph spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "langgraph.node", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "langchain.chain.type": "ReActAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "ReActAgent" + + +def test_detect_agent_annotation_detects_weave_agent() -> None: + """Should detect agent from Weave spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "weave.call", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "type": "agent", + "agentlightning.operation.input.name": "WeaveAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "WeaveAgent" + + +def test_detect_agent_annotation_detects_langchain_weave_agent() -> None: + """Should detect agent from LangChain + Weave spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "langchain.Chain.MyChain", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "lc_name": "LangChainAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "LangChainAgent" + + +def test_detect_agent_annotation_detects_agent_framework() -> None: + """Should detect agent from agent-framework spans.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "executor.run", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "executor.id": "ExecutorAgent", + }, + ) + + result = adapter.detect_agent_annotation(span) + + assert result is not None + assert result.name == "ExecutorAgent" + + +def test_detect_agent_annotation_returns_none_for_non_agent_span() -> None: + """Should return None for spans without agent indicators.""" + adapter = IdentifyAnnotations() + span = make_span("s1", "some.operation", parent_id=None, start_time=0.0, end_time=1.0) + + result = adapter.detect_agent_annotation(span) + + assert result is None + + +# Tests for IdentifyAnnotations.adapt_one + + +def test_adapt_one_identifies_annotation_span() -> None: + """Should identify AGL_ANNOTATION spans.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_ANNOTATION, + attributes={ + "custom": "value", + f"{LightningSpanAttributes.TAG.value}.0": "test-tag", + }, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, GeneralAnnotation) + + +def test_adapt_one_identifies_message_span() -> None: + """Should identify AGL_MESSAGE spans.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_MESSAGE, + attributes={LightningSpanAttributes.MESSAGE_BODY.value: "Hello"}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, MessageAnnotation) + + +def test_adapt_one_identifies_exception_span() -> None: + """Should identify AGL_EXCEPTION spans.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_EXCEPTION, + attributes={"exception.type": "Error"}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, ExceptionAnnotation) + + +def test_adapt_one_identifies_operation_span() -> None: + """Should identify AGL_OPERATION spans.""" + adapter = 
IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_OPERATION, + attributes={LightningSpanAttributes.OPERATION_NAME.value: "op"}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, OperationAnnotation) + + +def test_adapt_one_falls_back_to_agent_detection() -> None: + """Should fall back to agent detection for unknown span names.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + "agent.task", + attributes={"agent.name": "MyAgent"}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, AgentAnnotation) + + +def test_adapt_one_returns_unchanged_for_unrecognized_span() -> None: + """Should return unchanged span when no annotation is detected.""" + adapter = IdentifyAnnotations() + source = make_adapting_span("s1", "some.operation") + + result = adapter.adapt_one(source) + + assert result.data is None + + +# Tests for SelectByAnnotation + + +def test_select_by_annotation_include_mode_selects_annotated_spans() -> None: + """Include mode should select only annotated spans and their links.""" + adapter = SelectByAnnotation(mode="include") + + # Create spans with one annotation + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[LinkPydanticModel(key_match="span_id", value_match="linked")], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + linked_span = make_adapting_span("linked", "some.operation") + unrelated_span = make_adapting_span("unrelated", "other.operation") + + source = AdaptingSequence([annotation_span, linked_span, unrelated_span]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + assert "annotation" in span_ids + assert "linked" in span_ids + assert "unrelated" not in span_ids + + +def test_select_by_annotation_exclude_mode_removes_annotated_spans() -> None: + """Exclude mode should remove annotated spans and their links.""" + adapter = SelectByAnnotation(mode="exclude") + + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[LinkPydanticModel(key_match="span_id", value_match="linked")], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + linked_span = make_adapting_span("linked", "some.operation") + unrelated_span = make_adapting_span("unrelated", "other.operation") + + source = AdaptingSequence([annotation_span, linked_span, unrelated_span]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + assert "annotation" not in span_ids + assert "linked" not in span_ids + assert "unrelated" in span_ids + + +def test_select_by_annotation_include_mode_with_no_annotations() -> None: + """Include mode with no annotations should return empty result.""" + adapter = SelectByAnnotation(mode="include") + + span1 = make_adapting_span("s1", "operation") + span2 = make_adapting_span("s2", "operation") + + source = AdaptingSequence([span1, span2]) + result = adapter.adapt(source) + + assert len(list(result)) == 0 + + +def test_select_by_annotation_exclude_mode_with_no_annotations() -> None: + """Exclude mode with no annotations should return all spans.""" + adapter = SelectByAnnotation(mode="exclude") + + span1 = make_adapting_span("s1", "operation") + span2 = make_adapting_span("s2", "operation") + + source = AdaptingSequence([span1, span2]) + result = adapter.adapt(source) + + assert len(list(result)) == 2 + + +# Tests for 
RepairMissingLinks + + +def test_repair_missing_links_backward() -> None: + """Should repair missing links by searching backward.""" + adapter = RepairMissingLinks(scan_direction="backward") + + # Create sequence: candidate -> annotation (without link) + candidate = make_adapting_span("candidate", "some.operation", start_time=0.0, end_time=1.0) + annotation = make_adapting_span( + "annotation", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], # No links + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([candidate, annotation]) + result = adapter.adapt(source) + + # Find the annotation span in result + annotation_result = next(s for s in result if s.span_id == "annotation") + assert isinstance(annotation_result.data, GeneralAnnotation) + assert len(annotation_result.data.links) == 1 + assert annotation_result.data.links[0].value_match == "candidate" + + +def test_repair_missing_links_forward() -> None: + """Should repair missing links by searching forward.""" + adapter = RepairMissingLinks(scan_direction="forward") + + # Create sequence: annotation (without link) -> candidate + annotation = make_adapting_span( + "annotation", + AGL_ANNOTATION, + start_time=0.0, + end_time=1.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + candidate = make_adapting_span("candidate", "some.operation", start_time=1.0, end_time=2.0) + + source = AdaptingSequence([annotation, candidate]) + result = adapter.adapt(source) + + annotation_result = next(s for s in result if s.span_id == "annotation") + assert isinstance(annotation_result.data, GeneralAnnotation) + assert len(annotation_result.data.links) == 1 + assert annotation_result.data.links[0].value_match == "candidate" + + +def test_repair_missing_links_respects_candidate_predicate() -> None: + """Should only consider candidates matching the predicate.""" + adapter = RepairMissingLinks( + scan_direction="backward", + candidate_predicate=lambda span: span.name == "valid.candidate", + ) + + invalid_candidate = make_adapting_span("invalid", "invalid.operation", start_time=0.0, end_time=1.0) + valid_candidate = make_adapting_span("valid", "valid.candidate", start_time=1.0, end_time=2.0) + annotation = make_adapting_span( + "annotation", + AGL_ANNOTATION, + start_time=2.0, + end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([invalid_candidate, valid_candidate, annotation]) + result = adapter.adapt(source) + + annotation_result = next(s for s in result if s.span_id == "annotation") + assert annotation_result.data.links[0].value_match == "valid" + + +def test_repair_missing_links_does_not_reuse_linked_spans_by_default() -> None: + """By default, linked spans should not be reused for other annotations.""" + adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=False) + + candidate = make_adapting_span("candidate", "operation", start_time=0.0, end_time=1.0) + annotation1 = make_adapting_span( + "ann1", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + annotation2 = make_adapting_span( + "ann2", + AGL_ANNOTATION, + start_time=2.0, + 
end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([candidate, annotation1, annotation2]) + result = adapter.adapt(source) + + # Only one annotation should get the link + linked_annotations = [s for s in result if isinstance(s.data, GeneralAnnotation) and len(s.data.links) > 0] + assert len(linked_annotations) == 1 + + +def test_repair_missing_links_allows_reuse_linked_spans_when_enabled() -> None: + """When enabled, linked spans can be reused for multiple annotations.""" + adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=True) + + candidate = make_adapting_span("candidate", "operation", start_time=0.0, end_time=1.0) + annotation1 = make_adapting_span( + "ann1", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + annotation2 = make_adapting_span( + "ann2", + AGL_ANNOTATION, + start_time=2.0, + end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([candidate, annotation1, annotation2]) + result = adapter.adapt(source) + + # Both annotations should get links to the same candidate + linked_annotations = [s for s in result if isinstance(s.data, GeneralAnnotation) and len(s.data.links) > 0] + assert len(linked_annotations) == 2 + + +def test_repair_missing_links_preserves_existing_links() -> None: + """Annotations with existing links should not be modified.""" + adapter = RepairMissingLinks(scan_direction="backward") + + candidate = make_adapting_span("candidate", "operation", start_time=0.0, end_time=1.0) + annotation_with_link = make_adapting_span( + "annotation", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[LinkPydanticModel(key_match="span_id", value_match="existing-target")], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([candidate, annotation_with_link]) + result = adapter.adapt(source) + + annotation_result = next(s for s in result if s.span_id == "annotation") + # Original link should be preserved + assert annotation_result.data.links[0].value_match == "existing-target" + + +def test_repair_missing_links_siblings_scope_requires_tree() -> None: + """Siblings scope should raise error for non-tree sequences.""" + adapter = RepairMissingLinks(candidate_scope="siblings") + + span = make_adapting_span("s1", "operation") + source = AdaptingSequence([span]) + + with pytest.raises(ValueError, match="siblings.*only applicable to tree"): + adapter.adapt(source) + + +def test_repair_missing_links_siblings_scope_with_tree() -> None: + """Siblings scope should work correctly with tree sequences.""" + adapter = RepairMissingLinks(candidate_scope="siblings", scan_direction="backward") + + # Build a simple tree: root -> [child1, child2] + root = make_adapting_span("root", "root", start_time=0.0, end_time=10.0) + child1 = make_adapting_span("child1", "operation", start_time=1.0, end_time=4.0) + child2 = make_adapting_span( + "child2", + AGL_ANNOTATION, + start_time=5.0, + end_time=9.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + # 
Create tree structure + child1_tree = Tree(child1, []) + child2_tree = Tree(child2, []) + tree = Tree(root, [child1_tree, child2_tree]) + + result = adapter.adapt(tree) + + # child2 annotation should link to child1 (its sibling) + child2_result = next(s for s in result.traverse() if s.span_id == "child2") + assert isinstance(child2_result.data, GeneralAnnotation) + assert len(child2_result.data.links) == 1 + assert child2_result.data.links[0].value_match == "child1" + + +# Edge case tests + + +def test_identify_annotations_empty_sequence() -> None: + """Should handle empty sequence.""" + adapter = IdentifyAnnotations() + result = adapter.adapt(AdaptingSequence([])) + assert list(result) == [] + + +def test_select_by_annotation_empty_sequence() -> None: + """Should handle empty sequence.""" + adapter = SelectByAnnotation(mode="include") + result = adapter.adapt(AdaptingSequence([])) + assert len(list(result)) == 0 + + +def test_repair_missing_links_empty_sequence() -> None: + """Should handle empty sequence.""" + adapter = RepairMissingLinks() + result = adapter.adapt(AdaptingSequence([])) + assert len(list(result)) == 0 From 1db39a59bc6da9dca4ac96d9227af102298eb4b8 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 9 Jan 2026 15:59:47 +0800 Subject: [PATCH 26/41] fix annotation tests --- agentlightning/adapter/annotation.py | 161 +++++++++++- tests/adapter/test_annotation.py | 377 +++++++++++++++++++++++++++ 2 files changed, 537 insertions(+), 1 deletion(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index f82611703..941be4cac 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -50,7 +50,22 @@ class IdentifyAnnotations(SequenceAdapter[AdaptingSpan, AdaptingSpan]): - """Curate the annotations from the spans.""" + """Identify and parse annotation data from spans based on span name conventions. + + This adapter inspects each span's name to determine if it represents a known + annotation type (general, message, object, exception, operation) or an agent span. + When identified, the span's data field is populated with the corresponding annotation + model containing extracted attributes. + + Supported annotation types: + + - `AGL_ANNOTATION`: General annotations with rewards, tags, and custom fields. + - `AGL_MESSAGE`: Message annotations containing a message body. + - `AGL_OBJECT`: Object annotations containing serialized JSON or literal values. + - `AGL_EXCEPTION`: Exception annotations with type, message, and stacktrace. + - `AGL_OPERATION`: Operation annotations with name, input, and output. + - Agent spans: Detected via heuristics for various agent frameworks. + """ def _filter_custom_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]: reserved_fields = [attr.value for attr in LightningSpanAttributes if attr.value in attributes] @@ -65,6 +80,15 @@ def _filter_custom_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any } def extract_links(self, span: Span) -> Sequence[LinkPydanticModel]: + """Extract link specifications from span attributes. + + Args: + span: The span to extract links from. + + Returns: + A sequence of link models. Returns an empty list if no links are found + or if the link attributes are malformed. 
+ """ try: return extract_links_from_attributes(span.attributes) except Exception as exc: @@ -72,6 +96,16 @@ def extract_links(self, span: Span) -> Sequence[LinkPydanticModel]: return [] def identify_general(self, span: Span) -> Optional[GeneralAnnotation]: + """Parse a general annotation span into a `GeneralAnnotation` model. + + Extracts rewards, tags, links, and custom fields from the span attributes. + + Args: + span: A span with name `AGL_ANNOTATION`. + + Returns: + A `GeneralAnnotation` with extracted data, or None if parsing fails. + """ rewards = get_rewards_from_span(span) primary_reward = rewards[0].value if rewards else None return GeneralAnnotation( @@ -84,6 +118,15 @@ def identify_general(self, span: Span) -> Optional[GeneralAnnotation]: ) def identify_message(self, span: Span) -> Optional[MessageAnnotation]: + """Parse a message span into a `MessageAnnotation` model. + + Args: + span: A span with name `AGL_MESSAGE`. + + Returns: + A `MessageAnnotation` containing the message body, or None if the + message body attribute is missing. + """ msg_body = get_message_value(span) if msg_body is None: logger.warning(f"Message body is missing for message span {span.span_id}") @@ -96,6 +139,17 @@ def identify_message(self, span: Span) -> Optional[MessageAnnotation]: ) def identify_object(self, span: Span) -> Optional[ObjectAnnotation]: + """Parse an object span into an `ObjectAnnotation` model. + + Supports both JSON-serialized objects and literal values. + + Args: + span: A span with name `AGL_OBJECT`. + + Returns: + An `ObjectAnnotation` containing the deserialized object, or None if + deserialization fails. + """ try: obj_value = get_object_value(span) except Exception as exc: @@ -109,6 +163,17 @@ def identify_object(self, span: Span) -> Optional[ObjectAnnotation]: ) def identify_exception(self, span: Span) -> Optional[ExceptionAnnotation]: + """Parse an exception span into an `ExceptionAnnotation` model. + + Uses OpenTelemetry semantic conventions for exception attributes. + + Args: + span: A span with name `AGL_EXCEPTION`. + + Returns: + An `ExceptionAnnotation` containing exception type, message, and stacktrace. + Missing fields default to "UnknownException" for type and empty string for others. + """ exception_type = span.attributes.get(exception_attributes.EXCEPTION_TYPE, "UnknownException") exception_message = span.attributes.get(exception_attributes.EXCEPTION_MESSAGE, "") exception_stacktrace = span.attributes.get(exception_attributes.EXCEPTION_STACKTRACE, "") @@ -122,6 +187,18 @@ def identify_exception(self, span: Span) -> Optional[ExceptionAnnotation]: ) def identify_operation(self, span: Span) -> Optional[OperationAnnotation]: + """Parse an operation span into an `OperationAnnotation` model. + + Extracts operation name, input, and output. Input/output can be either + direct values or nested structures reconstructed from flattened attributes. + + Args: + span: A span with name `AGL_OPERATION`. + + Returns: + An `OperationAnnotation` containing operation details, or None if + attribute unpacking fails. + """ try: operation_name = span.attributes.get(LightningSpanAttributes.OPERATION_NAME.value, "UnknownOperation") if LightningSpanAttributes.OPERATION_INPUT.value in span.attributes: @@ -149,14 +226,47 @@ def identify_operation(self, span: Span) -> Optional[OperationAnnotation]: ) def extract_agent_id(self, span: Span) -> Optional[str]: + """Extract agent ID from span attributes. + + Args: + span: The span to extract the agent ID from. 
+ + Returns: + The agent ID if found, None otherwise. + """ # TODO: Support agent id in other formats return cast(Optional[str], span.attributes.get("agent.id")) def extract_agent_description(self, span: Span) -> Optional[str]: + """Extract agent description from span attributes. + + Args: + span: The span to extract the agent description from. + + Returns: + The agent description if found, None otherwise. + """ # TODO: Support agent description in other formats return cast(Optional[str], span.attributes.get("agent.description")) def extract_agent_name(self, span: Span) -> Optional[str]: + """Extract agent name from span attributes using framework-specific heuristics. + + Supports multiple agent frameworks by checking various attribute patterns: + 1. OpenTelemetry agent spans (`agent.name`) + 2. AgentOps decorated agents (`agentops.span.kind` + `operation.name`) + 3. Autogen teams (`recipient_agent_type`) + 4. LangGraph (`langchain.chain.type`) + 5. agent-framework (`executor.id`) + 6. Weave (`type` == "agent" + `agentlightning.operation.input.name`) + 7. Weave + LangChain (`langchain.Chain.*` span names + `lc_name`) + + Args: + span: The span to extract the agent name from. + + Returns: + The agent name if detected via any supported pattern, None otherwise. + """ # Case 1: OpenTelemetry Agent Spans agent_name = cast(Optional[str], span.attributes.get("agent.name")) if agent_name is not None: @@ -197,7 +307,20 @@ def extract_agent_name(self, span: Span) -> Optional[str]: if attributes_lc_name is not None: return attributes_lc_name + return None + def detect_agent_annotation(self, span: Span) -> Optional[AgentAnnotation]: + """Detect and create an agent annotation from span attributes. + + Uses heuristics to identify spans representing agent executions from + various frameworks (OpenTelemetry, AgentOps, Autogen, LangGraph, etc.). + + Args: + span: The span to check for agent indicators. + + Returns: + An `AgentAnnotation` if an agent is detected, None otherwise. + """ agent_id = self.extract_agent_id(span) agent_name = self.extract_agent_name(span) agent_description = self.extract_agent_description(span) @@ -213,6 +336,19 @@ def detect_agent_annotation(self, span: Span) -> Optional[AgentAnnotation]: return None def adapt_one(self, source: AdaptingSpan) -> AdaptingSpan: + """Process a single span to identify and attach annotation data. + + Checks the span name against known annotation types and parses the + corresponding annotation model. Falls back to agent detection for + unrecognized span names. + + Args: + source: The span to process. + + Returns: + The span with annotation data attached if identified, otherwise + the original span unchanged. + """ annotation: Optional[Annotation] = None if source.name == AGL_ANNOTATION: annotation = self.identify_general(source) @@ -269,6 +405,15 @@ def _filter_linked_spans(self, source: BaseAdaptingSequence[AdaptingSpan]) -> It # ignore the current span for now def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: + """Filter spans based on annotation membership and links. + + Args: + source: The span sequence to filter. + + Returns: + A filtered sequence containing only annotated spans and their linked + spans (include mode), or all spans except those (exclude mode). 
+ """ linked_spans = list(self._filter_linked_spans(source)) if self.mode == "include": return source.retain(lambda span: span in linked_spans) @@ -349,6 +494,20 @@ def visit(node: Tree[AdaptingSpan]) -> Iterable[Sequence[AdaptingSpan]]: raise ValueError(f"Invalid candidate scope: {self.candidate_scope}") def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSequence[AdaptingSpan]: + """Repair annotations that have no links by inferring targets from nearby spans. + + Scans the span sequence according to the configured direction and scope, + linking annotations without targets to the nearest eligible candidate spans. + + Args: + source: The span sequence containing annotations to repair. + + Returns: + A new sequence with repaired annotations containing inferred links. + + Raises: + ValueError: If `candidate_scope` is "siblings" but source is not a tree. + """ groups = list(self._search_groups(source)) span_id_to_link: Dict[str, LinkPydanticModel] = {} for group in groups: diff --git a/tests/adapter/test_annotation.py b/tests/adapter/test_annotation.py index 534bbe624..3779efea5 100644 --- a/tests/adapter/test_annotation.py +++ b/tests/adapter/test_annotation.py @@ -29,6 +29,7 @@ ExceptionAnnotation, GeneralAnnotation, MessageAnnotation, + ObjectAnnotation, OperationAnnotation, Tree, ) @@ -1076,3 +1077,379 @@ def test_repair_missing_links_empty_sequence() -> None: adapter = RepairMissingLinks() result = adapter.adapt(AdaptingSequence([])) assert len(list(result)) == 0 + + +# Additional corner case tests + + +def test_identify_object_returns_annotation_with_none_when_no_object_attributes() -> None: + """Should return ObjectAnnotation with None object when no JSON or literal present.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OBJECT, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={}, + ) + + result = adapter.identify_object(span) + + assert result is not None + assert result.annotation_type == "object" + assert result.object is None + + +def test_extract_agent_name_returns_none_for_agentops_without_operation_name() -> None: + """AgentOps span with kind=agent but no operation.name should return None.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "some.operation", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "agentops.span.kind": "agent", + # Missing operation.name + }, + ) + + result = adapter.extract_agent_name(span) + + assert result is None + + +def test_extract_agent_name_returns_none_for_weave_without_input_name() -> None: + """Weave span with type=agent but no input.name should return None.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "weave.call", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + "type": "agent", + # Missing agentlightning.operation.input.name + }, + ) + + result = adapter.extract_agent_name(span) + + assert result is None + + +def test_extract_agent_name_returns_none_for_langchain_chain_without_lc_name() -> None: + """LangChain chain span without lc_name should return None.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + "langchain.Chain.MyChain", + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + # Missing lc_name + }, + ) + + result = adapter.extract_agent_name(span) + + assert result is None + + +def test_identify_annotations_adapt_processes_multiple_spans() -> None: + """Should process multiple spans in a sequence.""" + adapter = IdentifyAnnotations() + + span1 = 
make_adapting_span( + "s1", + AGL_ANNOTATION, + attributes={f"{LightningSpanAttributes.TAG.value}.0": "tag1"}, + ) + span2 = make_adapting_span( + "s2", + AGL_MESSAGE, + attributes={LightningSpanAttributes.MESSAGE_BODY.value: "Hello"}, + ) + span3 = make_adapting_span("s3", "regular.span") + + source = AdaptingSequence([span1, span2, span3]) + result = adapter.adapt(source) + + result_list = list(result) + assert len(result_list) == 3 + assert isinstance(result_list[0].data, GeneralAnnotation) + assert isinstance(result_list[1].data, MessageAnnotation) + assert result_list[2].data is None + + +def test_select_by_annotation_links_use_and_logic() -> None: + """Multiple links in an annotation use AND logic - span must match ALL links.""" + adapter = SelectByAnnotation(mode="include") + + # Links specify span_id AND a custom attribute must both match + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[ + LinkPydanticModel(key_match="span_id", value_match="target"), + LinkPydanticModel(key_match="custom.tag", value_match="important"), + ], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + # This span matches both constraints + target = make_adapting_span("target", "target.operation", attributes={"custom.tag": "important"}) + # This span only matches span_id but not the custom attribute + partial_match = make_adapting_span("partial", "other.operation", attributes={"custom.tag": "important"}) + unrelated = make_adapting_span("unrelated", "operation3") + + source = AdaptingSequence([annotation_span, target, partial_match, unrelated]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + assert "annotation" in span_ids + assert "target" in span_ids + # partial_match doesn't match span_id constraint + assert "partial" not in span_ids + assert "unrelated" not in span_ids + + +def test_select_by_annotation_multiple_annotations_link_different_spans() -> None: + """Multiple annotations can each link to different spans.""" + adapter = SelectByAnnotation(mode="include") + + ann1 = make_adapting_span( + "ann1", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[LinkPydanticModel(key_match="span_id", value_match="linked1")], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + ann2 = make_adapting_span( + "ann2", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[LinkPydanticModel(key_match="span_id", value_match="linked2")], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + linked1 = make_adapting_span("linked1", "operation1") + linked2 = make_adapting_span("linked2", "operation2") + unrelated = make_adapting_span("unrelated", "operation3") + + source = AdaptingSequence([ann1, ann2, linked1, linked2, unrelated]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + assert "ann1" in span_ids + assert "ann2" in span_ids + assert "linked1" in span_ids + assert "linked2" in span_ids + assert "unrelated" not in span_ids + + +def test_repair_missing_links_multiple_annotations_single_candidate_no_reuse() -> None: + """Without reuse, only one annotation gets linked to a candidate.""" + adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=False) + + candidate = make_adapting_span("candidate", "operation", start_time=0.0, end_time=1.0) + ann1 = make_adapting_span( + "ann1", + AGL_ANNOTATION, + start_time=1.0, + 
end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + ann2 = make_adapting_span( + "ann2", + AGL_ANNOTATION, + start_time=2.0, + end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + ann3 = make_adapting_span( + "ann3", + AGL_ANNOTATION, + start_time=3.0, + end_time=4.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([candidate, ann1, ann2, ann3]) + result = adapter.adapt(source) + + # Only one annotation should get linked + linked = [s for s in result if isinstance(s.data, GeneralAnnotation) and len(s.data.links) > 0] + assert len(linked) == 1 + + +def test_repair_missing_links_skips_annotation_spans_as_candidates() -> None: + """Annotation spans should not be considered as link candidates.""" + adapter = RepairMissingLinks(scan_direction="backward") + + # Two annotations in sequence - second should not link to first + ann1 = make_adapting_span( + "ann1", + AGL_ANNOTATION, + start_time=0.0, + end_time=1.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + ann2 = make_adapting_span( + "ann2", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([ann1, ann2]) + result = adapter.adapt(source) + + # Neither annotation should get links since there are no candidates + for span in result: + if isinstance(span.data, GeneralAnnotation): + assert len(span.data.links) == 0 + + +def test_identify_operation_with_no_input_output() -> None: + """Should handle operation spans with no input or output.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_OPERATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + LightningSpanAttributes.OPERATION_NAME.value: "simple_op", + }, + ) + + result = adapter.identify_operation(span) + + assert result is not None + assert result.name == "simple_op" + assert result.input == {} + assert result.output == {} + + +def test_identify_general_with_multiple_rewards() -> None: + """Should extract multiple rewards and use first as primary.""" + adapter = IdentifyAnnotations() + span = make_span( + "s1", + AGL_ANNOTATION, + parent_id=None, + start_time=0.0, + end_time=1.0, + attributes={ + f"{LightningSpanAttributes.REWARD.value}.0.name": "quality", + f"{LightningSpanAttributes.REWARD.value}.0.value": 0.9, + f"{LightningSpanAttributes.REWARD.value}.1.name": "relevance", + f"{LightningSpanAttributes.REWARD.value}.1.value": 0.7, + f"{LightningSpanAttributes.TAG.value}.0": "test", + }, + ) + + result = adapter.identify_general(span) + + assert result is not None + assert len(result.rewards) == 2 + assert result.primary_reward == 0.9 + + +def test_adapt_one_with_object_span() -> None: + """Should identify AGL_OBJECT spans via adapt_one.""" + adapter = IdentifyAnnotations() + source = make_adapting_span( + "s1", + AGL_OBJECT, + attributes={LightningSpanAttributes.OBJECT_JSON.value: '{"key": "value"}'}, + ) + + result = adapter.adapt_one(source) + + assert isinstance(result.data, ObjectAnnotation) + assert result.data.object == {"key": "value"} + + 
+def test_select_by_annotation_with_link_using_attribute_match() -> None: + """Should support links that match by custom attribute.""" + adapter = SelectByAnnotation(mode="include") + + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[LinkPydanticModel(key_match="custom.tag", value_match="important")], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + target = make_adapting_span("target", "target.operation", attributes={"custom.tag": "important"}) + other = make_adapting_span("other", "other.operation", attributes={"custom.tag": "other"}) + + source = AdaptingSequence([annotation_span, target, other]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + assert "annotation" in span_ids + assert "target" in span_ids + assert "other" not in span_ids From dda4308875d85f0939c7fd32808edca4f59d4ac7 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 12 Jan 2026 13:05:51 +0800 Subject: [PATCH 27/41] find and fix two bugs --- agentlightning/adapter/annotation.py | 25 ++- tests/adapter/test_annotation.py | 292 +++++++++++++++++++++++++++ 2 files changed, 306 insertions(+), 11 deletions(-) diff --git a/agentlightning/adapter/annotation.py b/agentlightning/adapter/annotation.py index 941be4cac..b65e9cdbc 100644 --- a/agentlightning/adapter/annotation.py +++ b/agentlightning/adapter/annotation.py @@ -253,13 +253,14 @@ def extract_agent_name(self, span: Span) -> Optional[str]: """Extract agent name from span attributes using framework-specific heuristics. Supports multiple agent frameworks by checking various attribute patterns: - 1. OpenTelemetry agent spans (`agent.name`) - 2. AgentOps decorated agents (`agentops.span.kind` + `operation.name`) - 3. Autogen teams (`recipient_agent_type`) - 4. LangGraph (`langchain.chain.type`) - 5. agent-framework (`executor.id`) - 6. Weave (`type` == "agent" + `agentlightning.operation.input.name`) - 7. Weave + LangChain (`langchain.Chain.*` span names + `lc_name`) + + 1. OpenTelemetry agent spans (`agent.name`) + 2. AgentOps decorated agents (`agentops.span.kind` + `operation.name`) + 3. Autogen teams (`recipient_agent_type`) + 4. LangGraph (`langchain.chain.type`) + 5. agent-framework (`executor.id`) + 6. Weave (`type` == "agent" + `agentlightning.operation.input.name`) + 7. Weave + LangChain (`langchain.Chain.*` span names + `lc_name`) Args: span: The span to extract the agent name from. @@ -387,7 +388,6 @@ class SelectByAnnotation(SequenceAdapter[AdaptingSpan, AdaptingSpan]): Args: mode: "include" to select spans within the annotations; "exclude" to exclude them. 
- """ def __init__(self, mode: Literal["include", "exclude"]) -> None: @@ -400,7 +400,9 @@ def _filter_linked_spans(self, source: BaseAdaptingSequence[AdaptingSpan]) -> It for span in source: if span.span_id in annotation_span_ids: yield span - elif any(check_linked_span(span, links) for links in annotation_links): + # Only check non-empty link lists; empty links means the annotation applies only to itself + # (check_linked_span returns True for empty links, which would incorrectly match all spans) + elif any(links and check_linked_span(span, links) for links in annotation_links): yield span # ignore the current span for now @@ -526,8 +528,9 @@ def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> BaseAdaptingSeque # The span is a candidate if self.candidate_predicate(span): while len(annotations_to_fill) > 0: - # Fill the link - annotation_span = annotations_to_fill.pop(-1) + # Fill the link with the earliest-encountered annotation first (FIFO order) + # This ensures each annotation links to its nearest candidate in the scan direction + annotation_span = annotations_to_fill.pop(0) span_id_to_link[annotation_span.span_id] = LinkPydanticModel( key_match="span_id", value_match=span.span_id ) diff --git a/tests/adapter/test_annotation.py b/tests/adapter/test_annotation.py index 3779efea5..b3b347f0b 100644 --- a/tests/adapter/test_annotation.py +++ b/tests/adapter/test_annotation.py @@ -788,6 +788,72 @@ def test_select_by_annotation_include_mode_with_no_annotations() -> None: assert len(list(result)) == 0 +def test_select_by_annotation_include_mode_with_empty_links() -> None: + """Include mode with annotation having empty links should only include the annotation itself. + + An annotation with links=[] should apply only to itself, not to all spans. + This tests that check_linked_span returning True for empty links doesn't + cause all spans to be incorrectly selected. + """ + adapter = SelectByAnnotation(mode="include") + + # Annotation with no links - should only include itself + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[], # Empty links! + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + unrelated_span1 = make_adapting_span("unrelated1", "some.operation") + unrelated_span2 = make_adapting_span("unrelated2", "other.operation") + + source = AdaptingSequence([annotation_span, unrelated_span1, unrelated_span2]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + # Only the annotation span should be selected, not the unrelated spans + assert span_ids == ["annotation"] + + +def test_select_by_annotation_exclude_mode_with_empty_links() -> None: + """Exclude mode with annotation having empty links should only exclude the annotation itself. + + An annotation with links=[] should apply only to itself, so only the annotation + span should be excluded, not all spans. + """ + adapter = SelectByAnnotation(mode="exclude") + + annotation_span = make_adapting_span( + "annotation", + AGL_ANNOTATION, + data=GeneralAnnotation( + annotation_type="general", + links=[], # Empty links! 
+ rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + unrelated_span1 = make_adapting_span("unrelated1", "some.operation") + unrelated_span2 = make_adapting_span("unrelated2", "other.operation") + + source = AdaptingSequence([annotation_span, unrelated_span1, unrelated_span2]) + result = adapter.adapt(source) + + span_ids = [s.span_id for s in result] + # Only the annotation span should be excluded, unrelated spans should remain + assert "annotation" not in span_ids + assert "unrelated1" in span_ids + assert "unrelated2" in span_ids + + def test_select_by_annotation_exclude_mode_with_no_annotations() -> None: """Exclude mode with no annotations should return all spans.""" adapter = SelectByAnnotation(mode="exclude") @@ -1322,6 +1388,232 @@ def test_repair_missing_links_multiple_annotations_single_candidate_no_reuse() - assert len(linked) == 1 +def test_repair_missing_links_backward_links_nearest_candidate() -> None: + """Backward scan should link each annotation to its nearest preceding candidate. + + When multiple annotations appear before candidates (in chronological order), + the annotation closest to the candidate should get that candidate. + + Chronological order: [C1, C2, A1, A2] + Expected: A2->C2 (nearest), A1->C1 (remaining) + """ + adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=False) + + c1 = make_adapting_span("c1", "operation", start_time=0.0, end_time=1.0) + c2 = make_adapting_span("c2", "operation", start_time=1.0, end_time=2.0) + a1 = make_adapting_span( + "a1", + AGL_ANNOTATION, + start_time=2.0, + end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + a2 = make_adapting_span( + "a2", + AGL_ANNOTATION, + start_time=3.0, + end_time=4.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([c1, c2, a1, a2]) + result = adapter.adapt(source) + + # A2 (latest annotation) should link to C2 (nearest candidate before it) + a2_result = next(s for s in result if s.span_id == "a2") + assert isinstance(a2_result.data, GeneralAnnotation) + assert len(a2_result.data.links) == 1 + assert a2_result.data.links[0].value_match == "c2" + + # A1 should link to C1 (remaining candidate) + a1_result = next(s for s in result if s.span_id == "a1") + assert isinstance(a1_result.data, GeneralAnnotation) + assert len(a1_result.data.links) == 1 + assert a1_result.data.links[0].value_match == "c1" + + +def test_repair_missing_links_forward_links_nearest_candidate() -> None: + """Forward scan should link each annotation to its nearest following candidate. + + When multiple annotations appear before candidates (in chronological order), + the annotation closest to the candidate should get that candidate. 
+ + Chronological order: [A1, A2, C1, C2] + Expected: A1->C1 (nearest), A2->C2 (remaining) + """ + adapter = RepairMissingLinks(scan_direction="forward", allow_reuse_linked_spans=False) + + a1 = make_adapting_span( + "a1", + AGL_ANNOTATION, + start_time=0.0, + end_time=1.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + a2 = make_adapting_span( + "a2", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + c1 = make_adapting_span("c1", "operation", start_time=2.0, end_time=3.0) + c2 = make_adapting_span("c2", "operation", start_time=3.0, end_time=4.0) + + source = AdaptingSequence([a1, a2, c1, c2]) + result = adapter.adapt(source) + + # A1 (earliest annotation) should link to C1 (nearest candidate after it) + a1_result = next(s for s in result if s.span_id == "a1") + assert isinstance(a1_result.data, GeneralAnnotation) + assert len(a1_result.data.links) == 1 + assert a1_result.data.links[0].value_match == "c1" + + # A2 should link to C2 (remaining candidate) + a2_result = next(s for s in result if s.span_id == "a2") + assert isinstance(a2_result.data, GeneralAnnotation) + assert len(a2_result.data.links) == 1 + assert a2_result.data.links[0].value_match == "c2" + + +def test_repair_missing_links_backward_interleaved() -> None: + """Backward scan with interleaved annotations and candidates. + + Chronological order: [C1, A1, C2, A2] + Expected: A1->C1 (nearest before A1), A2->C2 (nearest before A2) + """ + adapter = RepairMissingLinks(scan_direction="backward", allow_reuse_linked_spans=False) + + c1 = make_adapting_span("c1", "operation", start_time=0.0, end_time=1.0) + a1 = make_adapting_span( + "a1", + AGL_ANNOTATION, + start_time=1.0, + end_time=2.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + c2 = make_adapting_span("c2", "operation", start_time=2.0, end_time=3.0) + a2 = make_adapting_span( + "a2", + AGL_ANNOTATION, + start_time=3.0, + end_time=4.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + + source = AdaptingSequence([c1, a1, c2, a2]) + result = adapter.adapt(source) + + # A2 should link to C2 (nearest candidate before A2) + a2_result = next(s for s in result if s.span_id == "a2") + assert isinstance(a2_result.data, GeneralAnnotation) + assert len(a2_result.data.links) == 1 + assert a2_result.data.links[0].value_match == "c2" + + # A1 should link to C1 (nearest candidate before A1) + a1_result = next(s for s in result if s.span_id == "a1") + assert isinstance(a1_result.data, GeneralAnnotation) + assert len(a1_result.data.links) == 1 + assert a1_result.data.links[0].value_match == "c1" + + +def test_repair_missing_links_forward_interleaved() -> None: + """Forward scan with interleaved annotations and candidates. 
+ + Chronological order: [A1, C1, A2, C2] + Expected: A1->C1 (nearest after A1), A2->C2 (nearest after A2) + """ + adapter = RepairMissingLinks(scan_direction="forward", allow_reuse_linked_spans=False) + + a1 = make_adapting_span( + "a1", + AGL_ANNOTATION, + start_time=0.0, + end_time=1.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + c1 = make_adapting_span("c1", "operation", start_time=1.0, end_time=2.0) + a2 = make_adapting_span( + "a2", + AGL_ANNOTATION, + start_time=2.0, + end_time=3.0, + data=GeneralAnnotation( + annotation_type="general", + links=[], + rewards=[], + primary_reward=None, + tags=[], + custom_fields={}, + ), + ) + c2 = make_adapting_span("c2", "operation", start_time=3.0, end_time=4.0) + + source = AdaptingSequence([a1, c1, a2, c2]) + result = adapter.adapt(source) + + # A1 should link to C1 (nearest candidate after A1) + a1_result = next(s for s in result if s.span_id == "a1") + assert isinstance(a1_result.data, GeneralAnnotation) + assert len(a1_result.data.links) == 1 + assert a1_result.data.links[0].value_match == "c1" + + # A2 should link to C2 (nearest candidate after A2) + a2_result = next(s for s in result if s.span_id == "a2") + assert isinstance(a2_result.data, GeneralAnnotation) + assert len(a2_result.data.links) == 1 + assert a2_result.data.links[0].value_match == "c2" + + def test_repair_missing_links_skips_annotation_spans_as_candidates() -> None: """Annotation spans should not be considered as link candidates.""" adapter = RepairMissingLinks(scan_direction="backward") From 3dea061eadf125805d4e5a145d065990bc5ea1a4 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 12 Jan 2026 13:25:36 +0800 Subject: [PATCH 28/41] update test_call --- agentlightning/adapter/call.py | 12 ++- agentlightning/types/adapter.py | 3 + tests/adapter/test_call.py | 146 +++++++++++++++++++++++++++++++- 3 files changed, 157 insertions(+), 4 deletions(-) diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index 60244f345..6600a7212 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -51,6 +51,10 @@ def _parse_response(self, span: AdaptingSpan) -> Dict[str, Any]: response = filter_and_unflatten_attributes(span.attributes, "gen_ai.response") if not isinstance(response, dict): raise ValueError(f"Invalid response format in span attributes: {response}") + if "created" not in response: + response["created"] = int(span.ensure_end_time()) + if "object" not in response: + response["object"] = "chat.completion" return response def _parse_completion_choices(self, span: AdaptingSpan) -> List[Dict[str, Any]]: @@ -109,6 +113,12 @@ def _parse_usages(self, span: AdaptingSpan) -> Dict[str, Any]: usages = filter_and_unflatten_attributes(span.attributes, "gen_ai.usage") if not isinstance(usages, dict): raise ValueError(f"Invalid usages format in span attributes: {usages}") + if "prompt_tokens" not in usages: + usages["prompt_tokens"] = 0 + if "completion_tokens" not in usages: + usages["completion_tokens"] = 0 + if "total_tokens" not in usages: + usages["total_tokens"] = usages["prompt_tokens"] + usages["completion_tokens"] return usages def _parse_agentops_tool_calls(self, span: Span) -> Optional[Dict[str, Any]]: @@ -159,8 +169,6 @@ def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatComple response = ChatCompletion.model_validate( { **response_metadata, - "object": "chat.completion", - "created": 
int(span.ensure_end_time()), "choices": completion_choices, "usage": usages, } diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index e26560015..395ddec2d 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -104,6 +104,9 @@ def __init__(self, item: T_co, children: Sequence[Tree[T_co]]) -> None: self._parent: Optional[weakref.ReferenceType[Tree[T_co]]] = None for child in self._children: child._parent = weakref.ref(self) # type: ignore + # Set container reference on the item if it supports it (e.g., AdaptingSpan) + if hasattr(item, "container"): + object.__setattr__(item, "container", self) @property def item(self) -> T_co: diff --git a/tests/adapter/test_call.py b/tests/adapter/test_call.py index fec5876b4..fe1ca0636 100644 --- a/tests/adapter/test_call.py +++ b/tests/adapter/test_call.py @@ -4,7 +4,7 @@ from agentlightning.adapter.base import Chain from agentlightning.adapter.call import IdentifyChatCompletionCalls -from agentlightning.adapter.preprocess import RepairMalformedSpans, ToAdaptingSpans, ToTree +from agentlightning.adapter.preprocess import RepairMalformedSpans, ToTree from agentlightning.types.tracer import Span @@ -139,4 +139,146 @@ def test_openai_calls(): print(span) -test_openai_calls() +def test_litellm_call(): + """Test LiteLLM proxy spans with chat completion calls.""" + spans: List[Span] = [ + # proxy_pre_call span + Span.from_attributes( + attributes={ + "call_type": "add_litellm_data_to_request", + "service": "proxy_pre_call", + }, + name="proxy_pre_call", + span_id="a82547704417abf4", + parent_id="9e1058bdd104a886", + sequence_id=0, + ), + # router span (async_get_available_deployment) + Span.from_attributes( + attributes={ + "call_type": "async_get_available_deployment", + "service": "router", + }, + name="router", + span_id="48086befab5d70cd", + parent_id="9e1058bdd104a886", + sequence_id=1, + ), + # self span (make_openai_chat_completion_request) + Span.from_attributes( + attributes={ + "call_type": "make_openai_chat_completion_request <- track_llm_api_timing", + "service": "self", + }, + name="self", + span_id="7b2c9b9d544c2107", + parent_id="9e1058bdd104a886", + sequence_id=2, + ), + # router span (acompletion) + Span.from_attributes( + attributes={ + "call_type": "acompletion", + "service": "router", + }, + name="router", + span_id="44f9efc7cd957922", + parent_id="9e1058bdd104a886", + sequence_id=3, + ), + # litellm_request span - main LLM call span with gen_ai attributes + Span.from_attributes( + attributes={ + "metadata.user_api_key_hash": "", + "metadata.user_api_key_alias": "", + "metadata.user_api_key_spend": 0.0, + "metadata.user_api_key_max_budget": "", + "metadata.user_api_key_budget_reset_at": "", + "metadata.user_api_key_team_id": "", + "metadata.user_api_key_org_id": "", + "metadata.user_api_key_user_id": "", + "metadata.user_api_key_team_alias": "", + "metadata.user_api_key_user_email": "", + "metadata.user_api_key_end_user_id": "", + "metadata.user_api_key_request_route": "/v1/chat/completions", + "metadata.spend_logs_metadata": "", + "metadata.requester_ip_address": "", + "metadata.requester_metadata": "{}", + "metadata.prompt_management_metadata": "", + "metadata.applied_guardrails": "[]", + "metadata.mcp_tool_call_metadata": "", + "metadata.vector_store_request_metadata": "", + "metadata.usage_object": "{'completion_tokens': 48, 'prompt_tokens': 36, 'total_tokens': 84}", + "metadata.requester_custom_headers": "{'x-rollout-id': 'ro-0b4d59a7d478'}", + 
"metadata.cold_storage_object_key": "", + "metadata.user_api_key_auth_metadata": "{}", + "hidden_params": '{"model_id": "f6746e78...b010", "api_base": "http://localhost:45177/v1"}', + "gen_ai.cost.input_cost": 0, + "gen_ai.cost.output_cost": 0, + "gen_ai.cost.total_cost": 0.0, + "gen_ai.cost.tool_usage_cost": 0.0, + "gen_ai.cost.original_cost": 0.0, + "gen_ai.cost.discount_percent": 0.0, + "gen_ai.cost.discount_amount": 0.0, + "gen_ai.request.model": "Qwen/Qwen2.5-0.5B-Instruct", + "llm.request.type": "acompletion", + "gen_ai.system": "hosted_vllm", + "llm.is_streaming": "False", + "gen_ai.response.id": "chatcmpl-8b25...3e6e", + "gen_ai.response.model": "hosted_vllm/Qwen/Qwen2.5-0.5B-Instruct", + "llm.usage.total_tokens": 84, + "gen_ai.usage.completion_tokens": 48, + "gen_ai.usage.prompt_tokens": 36, + "gen_ai.prompt.0.role": "user", + "gen_ai.prompt.0.content": "Hello, what's your name?", + "gen_ai.completion.0.finish_reason": "stop", + "gen_ai.completion.0.role": "assistant", + "gen_ai.completion.0.content": "Hello! I am Qwen, an AI assistant created by Alibaba Cloud.", + }, + name="litellm_request", + span_id="c43d325d68344786", + parent_id="9e1058bdd104a886", + sequence_id=4, + ), + # raw_gen_ai_request span - sibling of litellm_request (same parent) + Span.from_attributes( + attributes={ + "llm.hosted_vllm.messages": '[{"role": "user", "content": "Hello, what\'s your name?"}]', + "llm.hosted_vllm.extra_body": "{'return_token_ids': True}", + "llm.hosted_vllm.id": "chatcmpl-8b25...3e6e", + "llm.hosted_vllm.choices": '[{"finish_reason": "stop", "index": 0, "message": {"content": "Hello! I am Qwen...", "role": "assistant"}}]', + "llm.hosted_vllm.created": 1767842789, + "llm.hosted_vllm.model": "Qwen/Qwen2.5-0.5B-Instruct", + "llm.hosted_vllm.object": "chat.completion", + "llm.hosted_vllm.service_tier": "", + "llm.hosted_vllm.system_fingerprint": "", + "llm.hosted_vllm.usage": "{'completion_tokens': 48, 'prompt_tokens': 36, 'total_tokens': 84}", + "llm.hosted_vllm.prompt_logprobs": "", + "llm.hosted_vllm.prompt_token_ids": "[151644, 8948, 198, ...]", + "llm.hosted_vllm.kv_transfer_params": "", + }, + name="raw_gen_ai_request", + span_id="bb5a15dd04c9e74b", + parent_id="9e1058bdd104a886", # Same parent as litellm_request - they are siblings + sequence_id=5, + ), + # Root span - Received Proxy Server Request + Span.from_attributes( + attributes={}, + name="Received Proxy Server Request", + span_id="9e1058bdd104a886", + parent_id=None, + sequence_id=6, + ), + ] + + adapter = Chain( + RepairMalformedSpans(), + ToTree(), + IdentifyChatCompletionCalls(), + ) + + adapted_spans = adapter(spans) + + for span in adapted_spans: + print(span) From c8b96251d717476324c7d62319dd1cab83f9bb07 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 12 Jan 2026 13:51:49 +0800 Subject: [PATCH 29/41] fix adapter types --- agentlightning/types/adapter.py | 30 +- tests/types/test_adapter.py | 643 ++++++++++++++++++++++++++++++++ 2 files changed, 663 insertions(+), 10 deletions(-) create mode 100644 tests/types/test_adapter.py diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 395ddec2d..29dbe7988 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -99,14 +99,15 @@ class Tree(BaseAdaptingSequence[T_co], Generic[T_co]): """ def __init__(self, item: T_co, children: Sequence[Tree[T_co]]) -> None: - self._item = item self._children = children self._parent: Optional[weakref.ReferenceType[Tree[T_co]]] = None for child in self._children: child._parent = 
weakref.ref(self) # type: ignore - # Set container reference on the item if it supports it (e.g., AdaptingSpan) - if hasattr(item, "container"): - object.__setattr__(item, "container", self) + # Set container on item if it's an AdaptingSpan + if isinstance(item, AdaptingSpan): + self._item: T_co = item.model_copy(update={"container": self}) # type: ignore + else: + self._item = item @property def item(self) -> T_co: @@ -194,7 +195,14 @@ class AdaptingSequence(BaseAdaptingSequence[T], Generic[T]): """A simple list implementation of AdaptingSequence.""" def __init__(self, items: Sequence[T]) -> None: - self._items = list(items) + # Set container on items if they are AdaptingSpan instances + processed: list[T] = [] + for item in items: + if isinstance(item, AdaptingSpan): + processed.append(item.model_copy(update={"container": self})) # type: ignore + else: + processed.append(item) + self._items = processed def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: return self._items[index] @@ -285,10 +293,12 @@ def siblings(self) -> Sequence[AdaptingSpan]: Only applicable when the container is a [`Tree`][agentlightning.Tree]. """ - parent = self.parent_span() - if parent is None: + if self.container is None or not isinstance(self.container, Tree): + raise ValueError("AdaptingSpan.siblings() is only applicable when container is non-empty and a Tree.") + parent_tree = self.container.parent + if parent_tree is None: return [] - return [child for child in parent.children() if child != self] + return [child.item for child in parent_tree.children if child is not self.container] def parent_span(self) -> Optional[AdaptingSpan]: """Get the parent span if available. @@ -317,7 +327,7 @@ class Annotation(BaseModel): annotation_type: AnnotationType """Type of the annotation.""" - links: Sequence[LinkPydanticModel] = Field(default_factory=list[LinkPydanticModel]) + links: Sequence[LinkPydanticModel] = Field(default_factory=list) # type: ignore """Links to other spans or objects.""" @@ -343,7 +353,7 @@ class GeneralAnnotation(Annotation): annotation_type: AnnotationType = "general" """Type of the annotation.""" - rewards: Sequence[RewardPydanticModel] = Field(default_factory=list[RewardPydanticModel]) + rewards: Sequence[RewardPydanticModel] = Field(default_factory=list) # type: ignore """Reward dimensions and values.""" primary_reward: Optional[float] = None diff --git a/tests/types/test_adapter.py b/tests/types/test_adapter.py new file mode 100644 index 000000000..0f9659819 --- /dev/null +++ b/tests/types/test_adapter.py @@ -0,0 +1,643 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for Tree, AdaptingSequence, and AdaptingSpan data structures.""" + +import logging +from typing import Any, Dict, Optional + +import pytest + +from agentlightning.types import OtelResource, Span, TraceStatus +from agentlightning.types.adapter import AdaptingSequence, AdaptingSpan, Tree + + +def make_span( + name: str, + attributes: Dict[str, Any], + sequence_id: int, + *, + parent_id: Optional[str] = None, +) -> Span: + """Create a test span with minimal required fields.""" + return Span( + rollout_id="rollout-id", + attempt_id="attempt-id", + sequence_id=sequence_id, + trace_id=f"trace-{sequence_id}", + span_id=f"span-{sequence_id}", + parent_id=parent_id, + name=name, + status=TraceStatus(status_code="OK"), + attributes=attributes, + events=[], + links=[], + start_time=None, + end_time=None, + context=None, + parent=None, + resource=OtelResource(attributes={}, schema_url=""), + ) + + +# ============================================================================ +# Tree Tests +# ============================================================================ + + +def test_tree_single_node(): + tree = Tree("root", []) + assert tree.item == "root" + assert tree.children == [] + assert tree.parent is None + + +def test_tree_with_children(): + child1 = Tree("child1", []) + child2 = Tree("child2", []) + root = Tree("root", [child1, child2]) + + assert root.item == "root" + assert len(root.children) == 2 + assert root.children[0].item == "child1" + assert root.children[1].item == "child2" + + +def test_tree_parent_reference(): + grandchild = Tree("grandchild", []) + child = Tree("child", [grandchild]) + root = Tree("root", [child]) + + assert root.parent is None + assert child.parent is root + assert grandchild.parent is child + + +def test_tree_len(): + child1 = Tree("child1", []) + child2 = Tree("child2", []) + root = Tree("root", [child1, child2]) + assert len(root) == 3 + + +def test_tree_len_deep(): + grandchild = Tree("grandchild", []) + child = Tree("child", [grandchild]) + root = Tree("root", [child]) + assert len(root) == 3 + + +def test_tree_getitem_single(): + root = Tree("root", []) + assert root[0] == "root" + + +def test_tree_getitem_with_children(): + child1 = Tree("child1", []) + child2 = Tree("child2", []) + root = Tree("root", [child1, child2]) + # DFS order: root, child1, child2 + assert root[0] == "root" + assert root[1] == "child1" + assert root[2] == "child2" + + +def test_tree_getitem_slice(): + child1 = Tree("child1", []) + child2 = Tree("child2", []) + root = Tree("root", [child1, child2]) + assert root[1:] == ["child1", "child2"] + + +def test_tree_iter(): + child1 = Tree("child1", []) + child2 = Tree("child2", []) + root = Tree("root", [child1, child2]) + assert list(root) == ["root", "child1", "child2"] + + +def test_tree_traverse_single_node(): + tree = Tree("root", []) + assert list(tree.traverse()) == ["root"] + + +def test_tree_traverse_dfs_order(): + # root + # / \ + # child1 child2 + # | + # grandchild + grandchild = Tree("grandchild", []) + child1 = Tree("child1", [grandchild]) + child2 = Tree("child2", []) + root = Tree("root", [child1, child2]) + + # DFS order: root, child1, grandchild, child2 + assert list(root.traverse()) == ["root", "child1", "grandchild", "child2"] + + +def test_tree_size_single_node(): + tree = Tree("root", []) + assert tree.size() == 1 + + +def test_tree_size_with_children(): + grandchild = Tree("grandchild", []) + child1 = Tree("child1", [grandchild]) + child2 = Tree("child2", []) + root = Tree("root", [child1, child2]) + 
assert root.size() == 4 + + +def test_tree_map_single_node(): + tree = Tree(1, []) + mapped = tree.map(lambda x: x * 2) + assert mapped.item == 2 + assert mapped.children == [] + + +def test_tree_map_with_children(): + child1 = Tree(2, []) + child2 = Tree(3, []) + root = Tree(1, [child1, child2]) + + mapped = root.map(lambda x: x * 10) + assert mapped.item == 10 + assert mapped.children[0].item == 20 + assert mapped.children[1].item == 30 + + +def test_tree_map_preserves_structure(): + grandchild = Tree("gc", []) + child = Tree("c", [grandchild]) + root = Tree("r", [child]) + + mapped = root.map(str.upper) + assert mapped.item == "R" + assert mapped.children[0].item == "C" + assert mapped.children[0].children[0].item == "GC" + + +def test_tree_retain_keeps_root_always(): + tree = Tree("root", []) + retained = tree.retain(lambda x: False) + assert retained.item == "root" + assert retained.children == [] + + +def test_tree_retain_keeps_matching_subtrees(): + # root + # / \ + # keep1 drop1 + # | + # drop2 + drop2 = Tree("drop2", []) + keep1 = Tree("keep1", [drop2]) + drop1 = Tree("drop1", []) + root = Tree("root", [keep1, drop1]) + + # Retain subtrees rooted at nodes containing "keep" + retained = root.retain(lambda x: "keep" in x) + + assert retained.item == "root" + assert len(retained.children) == 1 + # keep1 is retained along with its entire subtree + assert retained.children[0].item == "keep1" + + +def test_tree_retain_removes_branches_without_matches(): + # root + # / \ + # drop1 drop2 + # | + # keep + keep = Tree("keep", []) + drop1 = Tree("drop1", [keep]) + drop2 = Tree("drop2", []) + root = Tree("root", [drop1, drop2]) + + retained = root.retain(lambda x: x == "keep") + + assert retained.item == "root" + # drop1 branch is kept because it leads to "keep" + assert len(retained.children) == 1 + assert retained.children[0].item == "drop1" + + +def test_tree_retain_deep_tree(): + # root + # | + # a + # | + # b (keep) + # | + # c + c = Tree("c", []) + b = Tree("b", [c]) + a = Tree("a", [b]) + root = Tree("root", [a]) + + retained = root.retain(lambda x: x == "b") + # When b matches, the entire subtree rooted at b (including c) is retained + assert list(retained.traverse()) == ["root", "a", "b", "c"] + + +def test_tree_prune_does_not_remove_root(): + tree = Tree("root", []) + pruned = tree.prune(lambda x: x == "root") + assert pruned.item == "root" + + +def test_tree_prune_removes_matching_children(): + child1 = Tree("remove_me", []) + child2 = Tree("keep_me", []) + root = Tree("root", [child1, child2]) + + pruned = root.prune(lambda x: x == "remove_me") + + assert pruned.item == "root" + assert len(pruned.children) == 1 + assert pruned.children[0].item == "keep_me" + + +def test_tree_prune_removes_subtrees(): + # root + # / \ + # remove keep + # | + # child_of_remove + child_of_remove = Tree("child_of_remove", []) + remove = Tree("remove", [child_of_remove]) + keep = Tree("keep", []) + root = Tree("root", [remove, keep]) + + pruned = root.prune(lambda x: x == "remove") + + assert pruned.item == "root" + assert len(pruned.children) == 1 + assert pruned.children[0].item == "keep" + + +def test_tree_prune_recursive(): + # root + # | + # keep + # / \ + # remove keep2 + keep2 = Tree("keep2", []) + remove = Tree("remove", []) + keep = Tree("keep", [remove, keep2]) + root = Tree("root", [keep]) + + pruned = root.prune(lambda x: x == "remove") + + assert pruned.item == "root" + assert pruned.children[0].item == "keep" + assert len(pruned.children[0].children) == 1 + assert 
pruned.children[0].children[0].item == "keep2" + + +# ============================================================================ +# AdaptingSequence Tests +# ============================================================================ + + +def test_adapting_sequence_empty(): + seq = AdaptingSequence([]) + assert len(seq) == 0 + assert list(seq) == [] + + +def test_adapting_sequence_with_items(): + seq = AdaptingSequence([1, 2, 3]) + assert len(seq) == 3 + assert list(seq) == [1, 2, 3] + + +def test_adapting_sequence_getitem_single(): + seq = AdaptingSequence(["a", "b", "c"]) + assert seq[0] == "a" + assert seq[1] == "b" + assert seq[2] == "c" + + +def test_adapting_sequence_getitem_negative_index(): + seq = AdaptingSequence(["a", "b", "c"]) + assert seq[-1] == "c" + + +def test_adapting_sequence_getitem_slice(): + seq = AdaptingSequence(["a", "b", "c", "d"]) + assert seq[1:3] == ["b", "c"] + + +def test_adapting_sequence_iter(): + seq = AdaptingSequence([1, 2, 3]) + result = [] + for item in seq: + result.append(item) + assert result == [1, 2, 3] + + +def test_adapting_sequence_traverse(): + seq = AdaptingSequence([1, 2, 3]) + assert list(seq.traverse()) == [1, 2, 3] + + +def test_adapting_sequence_size(): + seq = AdaptingSequence([1, 2, 3, 4]) + assert seq.size() == 4 + + +def test_adapting_sequence_get(): + seq = AdaptingSequence(["x", "y", "z"]) + assert seq.get(0) == "x" + assert seq.get(1) == "y" + + +def test_adapting_sequence_map_empty(): + seq = AdaptingSequence([]) + mapped = seq.map(lambda x: x * 2) + assert list(mapped) == [] + + +def test_adapting_sequence_map_integers(): + seq = AdaptingSequence([1, 2, 3]) + mapped = seq.map(lambda x: x * 2) + assert list(mapped) == [2, 4, 6] + + +def test_adapting_sequence_map_strings(): + seq = AdaptingSequence(["a", "b", "c"]) + mapped = seq.map(str.upper) + assert list(mapped) == ["A", "B", "C"] + + +def test_adapting_sequence_map_returns_adapting_sequence(): + seq = AdaptingSequence([1, 2, 3]) + mapped = seq.map(lambda x: x) + assert isinstance(mapped, AdaptingSequence) + + +def test_adapting_sequence_retain_all(): + seq = AdaptingSequence([1, 2, 3]) + retained = seq.retain(lambda x: True) + assert list(retained) == [1, 2, 3] + + +def test_adapting_sequence_retain_none(): + seq = AdaptingSequence([1, 2, 3]) + retained = seq.retain(lambda x: False) + assert list(retained) == [] + + +def test_adapting_sequence_retain_some(): + seq = AdaptingSequence([1, 2, 3, 4, 5]) + retained = seq.retain(lambda x: x % 2 == 0) + assert list(retained) == [2, 4] + + +def test_adapting_sequence_retain_returns_adapting_sequence(): + seq = AdaptingSequence([1, 2, 3]) + retained = seq.retain(lambda x: True) + assert isinstance(retained, AdaptingSequence) + + +def test_adapting_sequence_prune_none(): + seq = AdaptingSequence([1, 2, 3]) + pruned = seq.prune(lambda x: False) + assert list(pruned) == [1, 2, 3] + + +def test_adapting_sequence_prune_all(): + seq = AdaptingSequence([1, 2, 3]) + pruned = seq.prune(lambda x: True) + assert list(pruned) == [] + + +def test_adapting_sequence_prune_some(): + seq = AdaptingSequence([1, 2, 3, 4, 5]) + pruned = seq.prune(lambda x: x % 2 == 0) + assert list(pruned) == [1, 3, 5] + + +def test_adapting_sequence_prune_returns_adapting_sequence(): + seq = AdaptingSequence([1, 2, 3]) + pruned = seq.prune(lambda x: False) + assert isinstance(pruned, AdaptingSequence) + + +# ============================================================================ +# AdaptingSpan Tests +# 
============================================================================ + + +def test_adapting_span_from_span_creates_adapting_span(): + span = make_span("test-span", {"key": "value"}, 0) + adapting_span = AdaptingSpan.from_span(span, data="test-data") + + assert adapting_span.name == "test-span" + assert adapting_span.attributes == {"key": "value"} + assert adapting_span.data == "test-data" + + +def test_adapting_span_from_span_preserves_fields(): + span = make_span("my-span", {"attr": 123}, 5, parent_id="parent-span") + adapting_span = AdaptingSpan.from_span(span, data={"nested": "data"}) + + assert adapting_span.rollout_id == "rollout-id" + assert adapting_span.attempt_id == "attempt-id" + assert adapting_span.sequence_id == 5 + assert adapting_span.parent_id == "parent-span" + assert adapting_span.data == {"nested": "data"} + + +def test_adapting_span_from_adapting_span_updates_data(): + span = make_span("test-span", {}, 0) + adapting_span1 = AdaptingSpan.from_span(span, data="original") + adapting_span2 = AdaptingSpan.from_span(adapting_span1, data="updated") + + assert adapting_span2.data == "updated" + + +def test_adapting_span_with_data_creates_copy(): + span = make_span("test-span", {"key": "value"}, 0) + adapting_span = AdaptingSpan.from_span(span, data=None) + new_span = adapting_span.with_data("new-data", override="silent") + + assert new_span.data == "new-data" + assert new_span is not adapting_span + + +def test_adapting_span_with_data_preserves_other_fields(): + span = make_span("test-span", {"attr": 42}, 3) + adapting_span = AdaptingSpan.from_span(span, data=None) + new_span = adapting_span.with_data("new-data", override="silent") + + assert new_span.name == "test-span" + assert new_span.attributes == {"attr": 42} + assert new_span.sequence_id == 3 + + +def test_adapting_span_with_data_silent_override(): + span = make_span("test-span", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data="original") + new_span = adapting_span.with_data("updated", override="silent") + + assert new_span.data == "updated" + + +def test_adapting_span_with_data_warning_override(caplog: pytest.LogCaptureFixture) -> None: + span = make_span("test-span", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data="original") + + with caplog.at_level(logging.WARNING): + new_span = adapting_span.with_data("updated", override="warning") + + assert new_span.data == "updated" + assert "overwriting" in caplog.text.lower() + + +def test_adapting_span_with_data_forbidden_override(): + span = make_span("test-span", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data="original") + + with pytest.raises(ValueError, match="forbidden"): + adapting_span.with_data("updated", override="forbidden") + + +def test_adapting_span_with_data_none_does_not_warn(caplog: pytest.LogCaptureFixture) -> None: + span = make_span("test-span", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data=None) + + with caplog.at_level(logging.WARNING): + new_span = adapting_span.with_data("updated", override="warning") + + assert new_span.data == "updated" + assert "overwriting" not in caplog.text.lower() + + +def test_adapting_span_container_default_none(): + span = make_span("test", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data=None) + assert adapting_span.container is None + + +def test_adapting_span_container_can_be_set(): + span = make_span("test", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data=None) + seq = AdaptingSequence([adapting_span]) + adapting_span = 
adapting_span.model_copy(update={"container": seq}) + + assert adapting_span.container is seq + + +@pytest.fixture +def tree_of_adapting_spans(): + """Create a tree structure of AdaptingSpans for testing. + + Structure: + root + / \\ + child1 child2 + | + grandchild + """ + root_span = make_span("root", {}, 0) + child1_span = make_span("child1", {}, 1, parent_id="span-0") + child2_span = make_span("child2", {}, 2, parent_id="span-0") + grandchild_span = make_span("grandchild", {}, 3, parent_id="span-1") + + grandchild_tree: Tree[AdaptingSpan] = Tree( + AdaptingSpan.from_span(grandchild_span, data="gc-data"), + [], + ) + child1_tree: Tree[AdaptingSpan] = Tree( + AdaptingSpan.from_span(child1_span, data="c1-data"), + [grandchild_tree], + ) + child2_tree: Tree[AdaptingSpan] = Tree( + AdaptingSpan.from_span(child2_span, data="c2-data"), + [], + ) + root_tree: Tree[AdaptingSpan] = Tree( + AdaptingSpan.from_span(root_span, data="root-data"), + [child1_tree, child2_tree], + ) + + # Set container references + root_adapting = root_tree.item.model_copy(update={"container": root_tree}) + child1_adapting = child1_tree.item.model_copy(update={"container": child1_tree}) + child2_adapting = child2_tree.item.model_copy(update={"container": child2_tree}) + grandchild_adapting = grandchild_tree.item.model_copy(update={"container": grandchild_tree}) + + return root_adapting, child1_adapting, child2_adapting, grandchild_adapting + + +def test_adapting_span_children_returns_child_spans(tree_of_adapting_spans: Tree[AdaptingSpan]): + root, child1, child2, grandchild = tree_of_adapting_spans # type: ignore + children = root.children() + + assert len(children) == 2 + assert children[0].name == "child1" + assert children[1].name == "child2" + + +def test_adapting_span_children_leaf_node_empty(tree_of_adapting_spans: Tree[AdaptingSpan]): + root, child1, child2, grandchild = tree_of_adapting_spans # type: ignore + assert child2.children() == [] + + +def test_adapting_span_children_raises_without_tree_container(): + span = make_span("test", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data=None) + + with pytest.raises(ValueError, match="container"): + adapting_span.children() + + +def test_adapting_span_children_raises_with_non_tree_container(): + span = make_span("test", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data=None) + adapting_span = adapting_span.model_copy(update={"container": AdaptingSequence([])}) + + with pytest.raises(ValueError, match="Tree"): + adapting_span.children() + + +def test_adapting_span_parent_span_returns_parent(tree_of_adapting_spans: Tree[AdaptingSpan]): + root, child1, child2, grandchild = tree_of_adapting_spans # type: ignore + + parent = child1.parent_span() + assert parent is not None + assert parent.name == "root" + + +def test_adapting_span_parent_span_root_returns_none(tree_of_adapting_spans: Tree[AdaptingSpan]): + root, child1, child2, grandchild = tree_of_adapting_spans # type: ignore + assert root.parent_span() is None + + +def test_adapting_span_parent_span_raises_without_tree_container(): + span = make_span("test", {}, 0) + adapting_span = AdaptingSpan.from_span(span, data=None) + + with pytest.raises(ValueError, match="container"): + adapting_span.parent_span() + + +def test_adapting_span_siblings_returns_sibling_spans(tree_of_adapting_spans: Tree[AdaptingSpan]): + root, child1, child2, grandchild = tree_of_adapting_spans # type: ignore + + siblings = child1.siblings() + assert len(siblings) == 1 + assert siblings[0].name == "child2" + + +def 
test_adapting_span_siblings_only_child_returns_empty(tree_of_adapting_spans: Tree[AdaptingSpan]): + root, child1, child2, grandchild = tree_of_adapting_spans # type: ignore + assert grandchild.siblings() == [] + + +def test_adapting_span_siblings_root_returns_empty(tree_of_adapting_spans: Tree[AdaptingSpan]): + root, child1, child2, grandchild = tree_of_adapting_spans # type: ignore + assert root.siblings() == [] From 45c6a66ef8034f08274606dde1c9b72aa85236f5 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 12 Jan 2026 15:08:48 +0800 Subject: [PATCH 30/41] update adapter impl --- agentlightning/types/adapter.py | 55 ++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 29dbe7988..f129b58ab 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -15,6 +15,7 @@ Iterator, Literal, Optional, + Protocol, Sequence, TypeVar, Union, @@ -28,18 +29,28 @@ CompletionCreateParams, ) from pydantic import BaseModel, Field +from typing_extensions import Self from agentlightning.semconv import LinkPydanticModel, RewardPydanticModel from .tracer import Attributes, Span -T_co = TypeVar("T_co", covariant=True) T = TypeVar("T") -V = TypeVar("V") logger = logging.getLogger(__name__) +class BaseAdaptingSequenceItem(Protocol): + + def with_container(self, container: BaseAdaptingSequence[Self]) -> Self: + """Create a new item with the same properties but different container.""" + raise NotImplementedError() + + +T_co = TypeVar("T_co", covariant=True, bound=BaseAdaptingSequenceItem) +V_co = TypeVar("V_co", covariant=True, bound=BaseAdaptingSequenceItem) + + class BaseAdaptingSequence(Sequence[T_co], Generic[T_co]): """Interface that makes adapter easier to work with sequences.""" @@ -62,7 +73,7 @@ def get(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: """Get the index-th item in the sequence.""" raise NotImplementedError() - def map(self, func: Callable[[T_co], V]) -> BaseAdaptingSequence[V]: + def map(self, func: Callable[[T_co], V_co]) -> BaseAdaptingSequence[V_co]: """Map a function over all items in the sequence.""" raise NotImplementedError() @@ -103,11 +114,7 @@ def __init__(self, item: T_co, children: Sequence[Tree[T_co]]) -> None: self._parent: Optional[weakref.ReferenceType[Tree[T_co]]] = None for child in self._children: child._parent = weakref.ref(self) # type: ignore - # Set container on item if it's an AdaptingSpan - if isinstance(item, AdaptingSpan): - self._item: T_co = item.model_copy(update={"container": self}) # type: ignore - else: - self._item = item + self._item = item.with_container(self) @property def item(self) -> T_co: @@ -136,7 +143,7 @@ def get(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: """ return list(self.traverse())[index] - def map(self, func: Callable[[T_co], V]) -> Tree[V]: + def map(self, func: Callable[[T_co], V_co]) -> Tree[V_co]: """Map a function over all items in the tree.""" return Tree(func(self._item), [child.map(func) for child in self._children]) @@ -191,39 +198,33 @@ def visit(node: Tree[T_co]): dot.render(filename, format="png", cleanup=True) # type: ignore -class AdaptingSequence(BaseAdaptingSequence[T], Generic[T]): +class AdaptingSequence(BaseAdaptingSequence[T_co], Generic[T_co]): """A simple list implementation of AdaptingSequence.""" - def __init__(self, items: Sequence[T]) -> None: + def __init__(self, items: Sequence[T_co]) -> None: # Set container on items if they are 
AdaptingSpan instances - processed: list[T] = [] - for item in items: - if isinstance(item, AdaptingSpan): - processed.append(item.model_copy(update={"container": self})) # type: ignore - else: - processed.append(item) - self._items = processed - - def get(self, index: Union[int, slice]) -> Union[T, Sequence[T]]: + self._items = [item.with_container(self) for item in items] + + def get(self, index: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: return self._items[index] - def traverse(self) -> Iterable[T]: + def traverse(self) -> Iterable[T_co]: return iter(self._items) def size(self) -> int: return len(self._items) - def map(self, func: Callable[[T], V]) -> AdaptingSequence[V]: + def map(self, func: Callable[[T_co], V_co]) -> AdaptingSequence[V_co]: return AdaptingSequence([func(item) for item in self._items]) - def retain(self, predicate: Callable[[T], bool]) -> AdaptingSequence[T]: + def retain(self, predicate: Callable[[T_co], bool]) -> AdaptingSequence[T_co]: return AdaptingSequence([item for item in self._items if predicate(item)]) - def prune(self, predicate: Callable[[T], bool]) -> AdaptingSequence[T]: + def prune(self, predicate: Callable[[T_co], bool]) -> AdaptingSequence[T_co]: return AdaptingSequence([item for item in self._items if not predicate(item)]) -class AdaptingSpan(Span): +class AdaptingSpan(BaseAdaptingSequenceItem, Span): """A span that has been adapted to a different format. This class extends the base [`Span`][agentlightning.Span] class to represent spans that have @@ -262,6 +263,10 @@ def with_data(self, data: Any, override: Literal["silent", "warning", "forbidden ) return self.model_copy(update={"data": data}) + def with_container(self, container: BaseAdaptingSequence[AdaptingSpan]) -> AdaptingSpan: + """Create a new [`AdaptingSpan`][agentlightning.AdaptingSpan] with the same properties but different container.""" + return self.model_copy(update={"container": container}) + @classmethod def from_span(cls, span: Span, data: Any) -> AdaptingSpan: """Create an [`AdaptingSpan`][agentlightning.AdaptingSpan] from a base [`Span`][agentlightning.Span]. 
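
The `with_container` hook introduced in the patch above is easiest to see with a toy item type. The sketch below is illustrative only and is not part of the patch series: `LabeledItem` is a hypothetical class whose `with_container` method mirrors the `BaseAdaptingSequenceItem` protocol, and the usage assumes the patched `agentlightning.types.adapter.Tree` is importable.

# Illustrative sketch (hypothetical LabeledItem; only with_container mirrors the diff above).
from dataclasses import dataclass, replace
from typing import Any, Optional

from agentlightning.types.adapter import Tree


@dataclass(frozen=True)
class LabeledItem:
    label: str
    container: Optional[Any] = None

    def with_container(self, container: Any) -> "LabeledItem":
        # Return a rebound copy, analogous to AdaptingSpan.with_container's
        # model_copy(update={"container": container}).
        return replace(self, container=container)


# Constructing a Tree rebinds each item via with_container, so the stored
# item carries a back-reference to the node that owns it.
child = Tree(LabeledItem("child"), [])
root = Tree(LabeledItem("root"), [child])
assert root.item.container is root
assert child.item.container is child
assert root.children[0].parent is root
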
From 24697accf2b82844667b5658f5926b84d8667327 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 12 Jan 2026 15:54:53 +0800 Subject: [PATCH 31/41] fix type hints --- agentlightning/adapter/preprocess.py | 16 +++++++++------- agentlightning/types/adapter.py | 2 +- tests/adapter/test_preprocess.py | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index 731bda89b..39bed647b 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -48,6 +48,7 @@ def __init__(self) -> None: self.forward_graph: Dict[str, List[str]] = defaultdict(list) self.parent_map: Dict[str, str] = {} self.root_ids: Set[str] = set() + self.spans_dict: Dict[str, Span] = {} def add_edge(self, from_node: str, to_node: str) -> None: self.forward_graph[from_node].append(to_node) @@ -103,12 +104,11 @@ def visit(node_id: str) -> None: return ancestors - def to_tree(self, spans: Sequence[T_span]) -> Tree[T_span]: - spans_dict = {span.span_id: span for span in spans} - - def build_subtree(node_id: str) -> Tree[T_span]: + def to_tree(self) -> Tree[AdaptingSpan]: + def build_subtree(node_id: str) -> Tree[AdaptingSpan]: children = [build_subtree(child_id) for child_id in self.forward_graph.get(node_id, [])] - return Tree(spans_dict[node_id], sorted(children, key=lambda child: default_span_order(child.item))) + item = AdaptingSpan.from_span(self.spans_dict[node_id], None) + return Tree(item, sorted(children, key=lambda child: default_span_order(child.item))) if len(self.root_ids) != 1: raise ValueError( @@ -123,12 +123,15 @@ def from_spans(spans: Sequence[Span], logs_invalid_parent: bool = True) -> _Tree valid_span_ids = set(span.span_id for span in spans) graph.root_ids = set(span.span_id for span in spans if span.parent_id is None) + graph.spans_dict = {span.span_id: span for span in spans} for span in spans: if span.parent_id is not None: if span.parent_id in valid_span_ids: graph.forward_graph[span.parent_id].append(span.span_id) graph.root_ids.discard(span.span_id) graph.parent_map[span.span_id] = span.parent_id + # We don't care about the wrong parent id in the spans dict + # This fix here is only for graph construction else: # Span has invalid parent, treat as root graph.root_ids.add(span.span_id) @@ -376,8 +379,7 @@ def adapt(self, source: Sequence[Span]) -> Tree[AdaptingSpan]: source = self._repair_multiple_roots(source) graph = _TreeLikeGraph.from_spans(source) - adapting_spans = [AdaptingSpan.from_span(span, None) for span in source] - return graph.to_tree(adapting_spans) + return graph.to_tree() class ToAdaptingSpans(Adapter[Sequence[Span], AdaptingSequence[AdaptingSpan]]): diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index f129b58ab..b093c7b7d 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -224,7 +224,7 @@ def prune(self, predicate: Callable[[T_co], bool]) -> AdaptingSequence[T_co]: return AdaptingSequence([item for item in self._items if not predicate(item)]) -class AdaptingSpan(BaseAdaptingSequenceItem, Span): +class AdaptingSpan(Span): """A span that has been adapted to a different format. 
This class extends the base [`Span`][agentlightning.Span] class to represent spans that have diff --git a/tests/adapter/test_preprocess.py b/tests/adapter/test_preprocess.py index a0c548e8c..4caea5932 100644 --- a/tests/adapter/test_preprocess.py +++ b/tests/adapter/test_preprocess.py @@ -174,7 +174,7 @@ def test_tree_like_graph_to_tree_single_root(): child2 = make_span("child2", "child", parent_id="root", start_time=5.0, end_time=9.0) graph = _TreeLikeGraph.from_spans([root, child1, child2]) - tree = graph.to_tree([root, child1, child2]) + tree = graph.to_tree() assert tree.item.span_id == "root" assert len(tree.children) == 2 @@ -189,7 +189,7 @@ def test_tree_like_graph_to_tree_multiple_roots_raises(): graph = _TreeLikeGraph.from_spans([root1, root2]) with pytest.raises(ValueError, match="multiple or no roots"): - graph.to_tree([root1, root2]) + graph.to_tree() # Tests for ToSpans From 70fdf7df2955030416ab8efcbf3c8c06e5c950bb Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 12 Jan 2026 16:07:06 +0800 Subject: [PATCH 32/41] update postprocess --- agentlightning/adapter/postprocess.py | 26 ++- tests/types/test_adapter.py | 226 +++++++++++++++----------- 2 files changed, 150 insertions(+), 102 deletions(-) diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py index 0fdc6c461..090b1c0af 100644 --- a/agentlightning/adapter/postprocess.py +++ b/agentlightning/adapter/postprocess.py @@ -9,26 +9,36 @@ from agentlightning.types.adapter import ( AccumulatedMessages, AccumulatedTokenSequence, + AdaptingSpan, AnnotatedChatCompletionCall, + BaseAdaptingSequence, TokenInputOutputTriplet, ) from .base import Adapter -class AccumulateTokenSequence(Adapter[Sequence[TokenInputOutputTriplet], Sequence[AccumulatedTokenSequence]]): +class ToTokensTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokenTriplet]]): + """Convert adapting spans to token input-output triplets.""" + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokenInputOutputTriplet]: ... + + +class ToTokensAccumulations(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensAccumulation]]): """Assemble multiple token input-output triplets into accumulated token sequences.""" - def adapt(self, source: Sequence[TokenInputOutputTriplet]) -> Sequence[AccumulatedTokenSequence]: ... + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[AccumulatedTokenSequence]: ... -class AccumulateMessages(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[AccumulatedMessages]]): - """Assemble multiple token input-output triplets into accumulated chat messages.""" +class ToPromptCompletionTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[PromptCompletionTriplet]]): + """Convert annotated chat completion calls to prompt-completion triplets.""" - def adapt(self, source: Sequence[AnnotatedChatCompletionCall]) -> Sequence[AccumulatedMessages]: ... + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[PromptCompletionTriplet]: ... 
-class ToTokenInputOutputTriplet(Adapter[Sequence[AnnotatedChatCompletionCall], Sequence[TokenInputOutputTriplet]]): - """Convert annotated chat completion calls to token input-output triplets.""" +class ToPromptCompletionAccumulations( + Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[PromptCompletionAccumulation]] +): + """Assemble multiple prompt-completion triplets into accumulated prompt-completion pairs.""" - def adapt(self, source: Sequence[AnnotatedChatCompletionCall]) -> Sequence[TokenInputOutputTriplet]: ... + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[PromptCompletionAccumulation]: ... diff --git a/tests/types/test_adapter.py b/tests/types/test_adapter.py index 0f9659819..d11a39d62 100644 --- a/tests/types/test_adapter.py +++ b/tests/types/test_adapter.py @@ -3,7 +3,8 @@ """Tests for Tree, AdaptingSequence, and AdaptingSpan data structures.""" import logging -from typing import Any, Dict, Optional +from collections import UserString +from typing import Any, Dict, List, Optional import pytest @@ -11,6 +12,43 @@ from agentlightning.types.adapter import AdaptingSequence, AdaptingSpan, Tree +class SequenceTestString(UserString): + """Simple string wrapper that implements with_container for BaseAdaptingSequenceItem.""" + + def with_container(self, container: Any) -> "SequenceTestString": + return type(self)(self.data) + + +class SequenceTestInt(int): + """Simple int wrapper that implements with_container for BaseAdaptingSequenceItem.""" + + def __new__(cls, value: int) -> "SequenceTestInt": + return int.__new__(cls, value) + + def with_container(self, container: Any) -> "SequenceTestInt": + return type(self)(int(self)) + + +def s(value: str) -> SequenceTestString: + """Helper to create SequenceTestString instances.""" + return SequenceTestString(value) + + +def i(value: int) -> SequenceTestInt: + """Helper to create SequenceTestInt instances.""" + return SequenceTestInt(value) + + +def strs(*values: str) -> list[SequenceTestString]: + """Helper to create lists of SequenceTestString instances.""" + return [s(value) for value in values] + + +def ints(*values: int) -> list[SequenceTestInt]: + """Helper to create lists of SequenceTestInt instances.""" + return [i(value) for value in values] + + def make_span( name: str, attributes: Dict[str, Any], @@ -45,16 +83,16 @@ def make_span( def test_tree_single_node(): - tree = Tree("root", []) + tree = Tree(s("root"), []) assert tree.item == "root" assert tree.children == [] assert tree.parent is None def test_tree_with_children(): - child1 = Tree("child1", []) - child2 = Tree("child2", []) - root = Tree("root", [child1, child2]) + child1 = Tree(s("child1"), []) + child2 = Tree(s("child2"), []) + root = Tree(s("root"), [child1, child2]) assert root.item == "root" assert len(root.children) == 2 @@ -63,9 +101,9 @@ def test_tree_with_children(): def test_tree_parent_reference(): - grandchild = Tree("grandchild", []) - child = Tree("child", [grandchild]) - root = Tree("root", [child]) + grandchild = Tree(s("grandchild"), []) + child = Tree(s("child"), [grandchild]) + root = Tree(s("root"), [child]) assert root.parent is None assert child.parent is root @@ -73,28 +111,28 @@ def test_tree_parent_reference(): def test_tree_len(): - child1 = Tree("child1", []) - child2 = Tree("child2", []) - root = Tree("root", [child1, child2]) + child1 = Tree(s("child1"), []) + child2 = Tree(s("child2"), []) + root = Tree(s("root"), [child1, child2]) assert len(root) == 3 def test_tree_len_deep(): - grandchild = Tree("grandchild", []) 
- child = Tree("child", [grandchild]) - root = Tree("root", [child]) + grandchild = Tree(s("grandchild"), []) + child = Tree(s("child"), [grandchild]) + root = Tree(s("root"), [child]) assert len(root) == 3 def test_tree_getitem_single(): - root = Tree("root", []) + root = Tree(s("root"), []) assert root[0] == "root" def test_tree_getitem_with_children(): - child1 = Tree("child1", []) - child2 = Tree("child2", []) - root = Tree("root", [child1, child2]) + child1 = Tree(s("child1"), []) + child2 = Tree(s("child2"), []) + root = Tree(s("root"), [child1, child2]) # DFS order: root, child1, child2 assert root[0] == "root" assert root[1] == "child1" @@ -102,21 +140,21 @@ def test_tree_getitem_with_children(): def test_tree_getitem_slice(): - child1 = Tree("child1", []) - child2 = Tree("child2", []) - root = Tree("root", [child1, child2]) + child1 = Tree(s("child1"), []) + child2 = Tree(s("child2"), []) + root = Tree(s("root"), [child1, child2]) assert root[1:] == ["child1", "child2"] def test_tree_iter(): - child1 = Tree("child1", []) - child2 = Tree("child2", []) - root = Tree("root", [child1, child2]) + child1 = Tree(s("child1"), []) + child2 = Tree(s("child2"), []) + root = Tree(s("root"), [child1, child2]) assert list(root) == ["root", "child1", "child2"] def test_tree_traverse_single_node(): - tree = Tree("root", []) + tree = Tree(s("root"), []) assert list(tree.traverse()) == ["root"] @@ -126,59 +164,59 @@ def test_tree_traverse_dfs_order(): # child1 child2 # | # grandchild - grandchild = Tree("grandchild", []) - child1 = Tree("child1", [grandchild]) - child2 = Tree("child2", []) - root = Tree("root", [child1, child2]) + grandchild = Tree(s("grandchild"), []) + child1 = Tree(s("child1"), [grandchild]) + child2 = Tree(s("child2"), []) + root = Tree(s("root"), [child1, child2]) # DFS order: root, child1, grandchild, child2 assert list(root.traverse()) == ["root", "child1", "grandchild", "child2"] def test_tree_size_single_node(): - tree = Tree("root", []) + tree = Tree(s("root"), []) assert tree.size() == 1 def test_tree_size_with_children(): - grandchild = Tree("grandchild", []) - child1 = Tree("child1", [grandchild]) - child2 = Tree("child2", []) - root = Tree("root", [child1, child2]) + grandchild = Tree(s("grandchild"), []) + child1 = Tree(s("child1"), [grandchild]) + child2 = Tree(s("child2"), []) + root = Tree(s("root"), [child1, child2]) assert root.size() == 4 def test_tree_map_single_node(): - tree = Tree(1, []) - mapped = tree.map(lambda x: x * 2) + tree = Tree(i(1), []) + mapped = tree.map(lambda x: type(x)(x * 2)) assert mapped.item == 2 assert mapped.children == [] def test_tree_map_with_children(): - child1 = Tree(2, []) - child2 = Tree(3, []) - root = Tree(1, [child1, child2]) + child1 = Tree(i(2), []) + child2 = Tree(i(3), []) + root = Tree(i(1), [child1, child2]) - mapped = root.map(lambda x: x * 10) + mapped = root.map(lambda x: type(x)(x * 10)) assert mapped.item == 10 assert mapped.children[0].item == 20 assert mapped.children[1].item == 30 def test_tree_map_preserves_structure(): - grandchild = Tree("gc", []) - child = Tree("c", [grandchild]) - root = Tree("r", [child]) + grandchild = Tree(s("gc"), []) + child = Tree(s("c"), [grandchild]) + root = Tree(s("r"), [child]) - mapped = root.map(str.upper) + mapped = root.map(lambda x: type(x)(x.upper())) assert mapped.item == "R" assert mapped.children[0].item == "C" assert mapped.children[0].children[0].item == "GC" def test_tree_retain_keeps_root_always(): - tree = Tree("root", []) + tree = Tree(s("root"), []) retained = 
tree.retain(lambda x: False) assert retained.item == "root" assert retained.children == [] @@ -190,10 +228,10 @@ def test_tree_retain_keeps_matching_subtrees(): # keep1 drop1 # | # drop2 - drop2 = Tree("drop2", []) - keep1 = Tree("keep1", [drop2]) - drop1 = Tree("drop1", []) - root = Tree("root", [keep1, drop1]) + drop2 = Tree(s("drop2"), []) + keep1 = Tree(s("keep1"), [drop2]) + drop1 = Tree(s("drop1"), []) + root = Tree(s("root"), [keep1, drop1]) # Retain subtrees rooted at nodes containing "keep" retained = root.retain(lambda x: "keep" in x) @@ -210,10 +248,10 @@ def test_tree_retain_removes_branches_without_matches(): # drop1 drop2 # | # keep - keep = Tree("keep", []) - drop1 = Tree("drop1", [keep]) - drop2 = Tree("drop2", []) - root = Tree("root", [drop1, drop2]) + keep = Tree(s("keep"), []) + drop1 = Tree(s("drop1"), [keep]) + drop2 = Tree(s("drop2"), []) + root = Tree(s("root"), [drop1, drop2]) retained = root.retain(lambda x: x == "keep") @@ -231,10 +269,10 @@ def test_tree_retain_deep_tree(): # b (keep) # | # c - c = Tree("c", []) - b = Tree("b", [c]) - a = Tree("a", [b]) - root = Tree("root", [a]) + c = Tree(s("c"), []) + b = Tree(s("b"), [c]) + a = Tree(s("a"), [b]) + root = Tree(s("root"), [a]) retained = root.retain(lambda x: x == "b") # When b matches, the entire subtree rooted at b (including c) is retained @@ -242,15 +280,15 @@ def test_tree_retain_deep_tree(): def test_tree_prune_does_not_remove_root(): - tree = Tree("root", []) + tree = Tree(s("root"), []) pruned = tree.prune(lambda x: x == "root") assert pruned.item == "root" def test_tree_prune_removes_matching_children(): - child1 = Tree("remove_me", []) - child2 = Tree("keep_me", []) - root = Tree("root", [child1, child2]) + child1 = Tree(s("remove_me"), []) + child2 = Tree(s("keep_me"), []) + root = Tree(s("root"), [child1, child2]) pruned = root.prune(lambda x: x == "remove_me") @@ -265,10 +303,10 @@ def test_tree_prune_removes_subtrees(): # remove keep # | # child_of_remove - child_of_remove = Tree("child_of_remove", []) - remove = Tree("remove", [child_of_remove]) - keep = Tree("keep", []) - root = Tree("root", [remove, keep]) + child_of_remove = Tree(s("child_of_remove"), []) + remove = Tree(s("remove"), [child_of_remove]) + keep = Tree(s("keep"), []) + root = Tree(s("root"), [remove, keep]) pruned = root.prune(lambda x: x == "remove") @@ -283,10 +321,10 @@ def test_tree_prune_recursive(): # keep # / \ # remove keep2 - keep2 = Tree("keep2", []) - remove = Tree("remove", []) - keep = Tree("keep", [remove, keep2]) - root = Tree("root", [keep]) + keep2 = Tree(s("keep2"), []) + remove = Tree(s("remove"), []) + keep = Tree(s("keep"), [remove, keep2]) + root = Tree(s("root"), [keep]) pruned = root.prune(lambda x: x == "remove") @@ -302,126 +340,126 @@ def test_tree_prune_recursive(): def test_adapting_sequence_empty(): - seq = AdaptingSequence([]) + seq = AdaptingSequence[Any]([]) assert len(seq) == 0 assert list(seq) == [] def test_adapting_sequence_with_items(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) assert len(seq) == 3 assert list(seq) == [1, 2, 3] def test_adapting_sequence_getitem_single(): - seq = AdaptingSequence(["a", "b", "c"]) + seq = AdaptingSequence(strs("a", "b", "c")) assert seq[0] == "a" assert seq[1] == "b" assert seq[2] == "c" def test_adapting_sequence_getitem_negative_index(): - seq = AdaptingSequence(["a", "b", "c"]) + seq = AdaptingSequence(strs("a", "b", "c")) assert seq[-1] == "c" def test_adapting_sequence_getitem_slice(): - seq = AdaptingSequence(["a", "b", 
"c", "d"]) + seq = AdaptingSequence(strs("a", "b", "c", "d")) assert seq[1:3] == ["b", "c"] def test_adapting_sequence_iter(): - seq = AdaptingSequence([1, 2, 3]) - result = [] + seq = AdaptingSequence(ints(1, 2, 3)) + result: List[Any] = [] for item in seq: result.append(item) assert result == [1, 2, 3] def test_adapting_sequence_traverse(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) assert list(seq.traverse()) == [1, 2, 3] def test_adapting_sequence_size(): - seq = AdaptingSequence([1, 2, 3, 4]) + seq = AdaptingSequence(ints(1, 2, 3, 4)) assert seq.size() == 4 def test_adapting_sequence_get(): - seq = AdaptingSequence(["x", "y", "z"]) + seq = AdaptingSequence(strs("x", "y", "z")) assert seq.get(0) == "x" assert seq.get(1) == "y" def test_adapting_sequence_map_empty(): - seq = AdaptingSequence([]) - mapped = seq.map(lambda x: x * 2) + seq = AdaptingSequence[Any]([]) + mapped = seq.map(lambda x: type(x)(x * 2)) assert list(mapped) == [] def test_adapting_sequence_map_integers(): - seq = AdaptingSequence([1, 2, 3]) - mapped = seq.map(lambda x: x * 2) + seq = AdaptingSequence(ints(1, 2, 3)) + mapped = seq.map(lambda x: type(x)(x * 2)) assert list(mapped) == [2, 4, 6] def test_adapting_sequence_map_strings(): - seq = AdaptingSequence(["a", "b", "c"]) - mapped = seq.map(str.upper) + seq = AdaptingSequence(strs("a", "b", "c")) + mapped = seq.map(lambda x: type(x)(x.upper())) assert list(mapped) == ["A", "B", "C"] def test_adapting_sequence_map_returns_adapting_sequence(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) mapped = seq.map(lambda x: x) assert isinstance(mapped, AdaptingSequence) def test_adapting_sequence_retain_all(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) retained = seq.retain(lambda x: True) assert list(retained) == [1, 2, 3] def test_adapting_sequence_retain_none(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) retained = seq.retain(lambda x: False) assert list(retained) == [] def test_adapting_sequence_retain_some(): - seq = AdaptingSequence([1, 2, 3, 4, 5]) + seq = AdaptingSequence(ints(1, 2, 3, 4, 5)) retained = seq.retain(lambda x: x % 2 == 0) assert list(retained) == [2, 4] def test_adapting_sequence_retain_returns_adapting_sequence(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) retained = seq.retain(lambda x: True) assert isinstance(retained, AdaptingSequence) def test_adapting_sequence_prune_none(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) pruned = seq.prune(lambda x: False) assert list(pruned) == [1, 2, 3] def test_adapting_sequence_prune_all(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) pruned = seq.prune(lambda x: True) assert list(pruned) == [] def test_adapting_sequence_prune_some(): - seq = AdaptingSequence([1, 2, 3, 4, 5]) + seq = AdaptingSequence(ints(1, 2, 3, 4, 5)) pruned = seq.prune(lambda x: x % 2 == 0) assert list(pruned) == [1, 3, 5] def test_adapting_sequence_prune_returns_adapting_sequence(): - seq = AdaptingSequence([1, 2, 3]) + seq = AdaptingSequence(ints(1, 2, 3)) pruned = seq.prune(lambda x: False) assert isinstance(pruned, AdaptingSequence) From fe2fb305a400749c4ad056e1016f969b32021aab Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 13 Jan 2026 11:08:44 +0800 Subject: [PATCH 33/41] update tokens/messages accumulation --- agentlightning/adapter/postprocess.py | 30 +++++++--- agentlightning/types/adapter.py | 82 
+++++++++++++++++---------- 2 files changed, 75 insertions(+), 37 deletions(-) diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py index 090b1c0af..46cb2876c 100644 --- a/agentlightning/adapter/postprocess.py +++ b/agentlightning/adapter/postprocess.py @@ -4,30 +4,35 @@ from __future__ import annotations -from typing import Sequence +from typing import Literal, Sequence, TypeVar, Union from agentlightning.types.adapter import ( - AccumulatedMessages, - AccumulatedTokenSequence, AdaptingSpan, - AnnotatedChatCompletionCall, BaseAdaptingSequence, - TokenInputOutputTriplet, + PromptCompletionAccumulation, + PromptCompletionTriplet, + TokensAccumulation, + TokensTriplet, ) from .base import Adapter +T_triplet_or_accumulation = TypeVar( + "T_triplet_or_accumulation", + bound=Union[TokensTriplet, TokensAccumulation, PromptCompletionTriplet, PromptCompletionAccumulation], +) + -class ToTokensTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokenTriplet]]): +class ToTokensTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensTriplet]]): """Convert adapting spans to token input-output triplets.""" - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokenInputOutputTriplet]: ... + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: ... class ToTokensAccumulations(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensAccumulation]]): """Assemble multiple token input-output triplets into accumulated token sequences.""" - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[AccumulatedTokenSequence]: ... + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensAccumulation]: ... class ToPromptCompletionTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[PromptCompletionTriplet]]): @@ -42,3 +47,12 @@ class ToPromptCompletionAccumulations( """Assemble multiple prompt-completion triplets into accumulated prompt-completion pairs.""" def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[PromptCompletionAccumulation]: ... + + +class PropagateRewards(Adapter[Sequence[T_triplet_or_accumulation], Sequence[T_triplet_or_accumulation]]): + """Propagate rewards forward or backward from one triplet or accumulation to the next.""" + + def __init__(self, direction: Literal["forward", "backward"]) -> None: + self.direction = direction + + def adapt(self, source: Sequence[T_triplet_or_accumulation]) -> Sequence[T_triplet_or_accumulation]: ... 
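The adapter classes in this file are still declared with stub bodies. A minimal sketch of the intended composition, assuming an upstream BaseAdaptingSequence[AdaptingSpan] (called spans below, which is not defined in this patch), could look like the following; the wiring is illustrative only, not part of the change.

    triplets = ToTokensTriplets().adapt(spans)            # spans -> Sequence[TokensTriplet]
    # With direction="backward", items lacking a reward would take the nearest later reward.
    triplets = PropagateRewards(direction="backward").adapt(triplets)
    accumulations = ToTokensAccumulations().adapt(spans)  # spans -> Sequence[TokensAccumulation]
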
diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index b093c7b7d..a825cdedb 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -452,6 +452,41 @@ class AnnotatedChatCompletionCall(ChatCompletionCall): # Algorithm-specific requirements +T_observation = TypeVar("T_observation") +T_action = TypeVar("T_action") + + +class Triplet(Generic[T_observation, T_action], BaseModel): + """A triplet of observation, action and reward.""" + + observation: T_observation + """Observation for the model input.""" + + completion: T_action + """Action from the model output.""" + + reward: Optional[float] + """Reward of the model input.""" + + done: bool + """Whether it's the end of the trajectory.""" + + raw_call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + """Raw chat completion call.""" + + +class Accumulation(BaseModel): + """Accumulation from multiple triplets.""" + + final_reward: Optional[float] + """Single reward value for the entire sequence. + + An accumulation can only have one single final reward value. + """ + + raw_calls: Sequence[Union[AnnotatedChatCompletionCall, ChatCompletionCall]] + """Raw chat completion calls. The order of the calls must be the same as the order of the token IDs.""" + class TokenInput(BaseModel): """Token-based model input.""" @@ -459,7 +494,7 @@ class TokenInput(BaseModel): token_ids: Sequence[int] """Token IDs of the model input.""" - image_urls: Any + image_urls: Sequence[str] """A list of image URLs. Could be pointers to local files or base64-encoded images.""" @@ -469,46 +504,41 @@ class TokenOutput(BaseModel): token_ids: Sequence[int] """Token IDs of the model output.""" + logprobs: Optional[Sequence[float]] + """Log probabilities of the model output.""" -class TokenInputOutputTriplet(BaseModel): + +class TokensTriplet(Triplet[TokenInput, TokenOutput]): """A triplet of token IDs for the input and output, useful for reinforcement learning. This is not a stable interface and the fields here highly depend on RL implementations. """ - observation: TokenInput - """Observation for the model input. Corresponding to prompt.""" - - action: TokenOutput - """Action, corresponding to completion result.""" - - reward: Optional[float] - """Reward of the model input.""" - - done: bool - """Whether it's the end of the trajectory.""" - - raw_call: AnnotatedChatCompletionCall - """Raw chat completion call.""" - -class AccumulatedTokenSequence(TokenInput): +class TokensAccumulation(Accumulation): """A sequence of token IDs that are accumulated from multiple model calls. Output is implied in the token IDs. """ + token_ids: Sequence[int] + """Token IDs of the model input and output.""" + + image_urls: Sequence[str] + """Image URLs of the model input and output.""" + + logprobs: Optional[Sequence[float]] + """Log probabilities of the model output.""" + response_mask: Sequence[int] """Mask for the response tokens. Must a sequence of 0s and 1s, with 1s for the completion tokens and 0s for the prompt tokens.""" - final_reward: Optional[float] - """Single reward value for the entire sequence.""" - raw_calls: Sequence[AnnotatedChatCompletionCall] - """Raw chat completion calls. 
The order of the calls must be the same as the order of the token IDs.""" +class PromptCompletionTriplet(Triplet[Sequence[ChatCompletionMessageParam], ChatCompletionCall]): + """A triplet of prompt and completion.""" -class AccumulatedMessages(BaseModel): +class PromptCompletionAccumulation(Accumulation): """A conversation that is accumulated from multiple model calls.""" messages: Sequence[ChatCompletionMessageParam] @@ -516,9 +546,3 @@ class AccumulatedMessages(BaseModel): tools: Optional[Sequence[ChatCompletionFunctionToolParam]] """Tools provided for the conversation.""" - - final_reward: Optional[float] - """Single reward value for the entire conversation.""" - - raw_calls: Sequence[AnnotatedChatCompletionCall] - """Raw chat completion calls. The order of the calls must be the same as the order of the messages.""" From fba55d043146f045b5f6341031b812bb2363089b Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 13 Jan 2026 12:08:00 +0800 Subject: [PATCH 34/41] implement postprocess --- agentlightning/adapter/call.py | 28 +++--- agentlightning/adapter/postprocess.py | 120 +++++++++++++++++++++++++- agentlightning/types/adapter.py | 11 ++- 3 files changed, 138 insertions(+), 21 deletions(-) diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index 6600a7212..067a247ac 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -165,19 +165,13 @@ def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatComple usages = self._parse_usages(span) response_metadata = self._parse_response(span) - request = self._normalize_request({"messages": prompt_messages, **request_metadata}) - response = ChatCompletion.model_validate( + return self._construct_chat_completion_call( + {"messages": prompt_messages, **request_metadata}, { **response_metadata, "choices": completion_choices, "usage": usages, - } - ) - - return ChatCompletionCall.model_construct( - request=request, - response=response, - malformed_fields={}, # TODO: malformed fields + }, ) def _augment_litellm_raw_gen_ai_request( @@ -223,12 +217,16 @@ def _parse_litellm_request(self, span: AdaptingSpan) -> ChatCompletionCall: if sibling.name == "raw_gen_ai_request": self._augment_litellm_raw_gen_ai_request(sibling, request_body, response_body) - return ChatCompletionCall( - request=cast( - CompletionCreateParams, - TypeAdapter(CompletionCreateParams).validate_python(request_body), - ), - response=ChatCompletion.model_validate(response_body), + return self._construct_chat_completion_call(request_body, response_body) + + def _construct_chat_completion_call( + self, request_body: Dict[str, Any], response_body: Dict[str, Any] + ) -> ChatCompletionCall: + request = self._normalize_request(request_body) + response = ChatCompletion.model_validate(response_body, extra="allow") + return ChatCompletionCall.model_construct( + request=request, + response=response, malformed_fields={}, # TODO: malformed fields ) diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py index 46cb2876c..eba1b0136 100644 --- a/agentlightning/adapter/postprocess.py +++ b/agentlightning/adapter/postprocess.py @@ -4,13 +4,18 @@ from __future__ import annotations -from typing import Literal, Sequence, TypeVar, Union +from typing import Any, List, Literal, Optional, Sequence, TypeVar, Union, cast from agentlightning.types.adapter import ( AdaptingSpan, + AnnotatedChatCompletionCall, BaseAdaptingSequence, + ChatCompletionCall, + GeneralAnnotation, PromptCompletionAccumulation, 
PromptCompletionTriplet, + TokenInput, + TokenOutput, TokensAccumulation, TokensTriplet, ) @@ -24,9 +29,116 @@ class ToTokensTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensTriplet]]): - """Convert adapting spans to token input-output triplets.""" - - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: ... + """Convert adapting spans to token input-output triplets. + + Args: + strict: Whether to raise an exception if the triplet cannot be created. + If False, the exception will be added to the list of exceptions and the triplet will be skipped. + The exceptions will also be raised when the resulting sequence is empty. + If True, the exception will be raised. + Default is False. + """ + + def __init__(self, strict: bool = False): + self.strict = strict + + def _get_prompt_token_ids(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Sequence[int]: + prompt = call.request + if "prompt_token_ids" in prompt: + prompt_token_ids = prompt["prompt_token_ids"] + else: + raise ValueError(f"Prompt token ids not found in call: {call}") + # Validate the prompt token ids + if not isinstance(prompt_token_ids, list) or not all(isinstance(x, int) for x in prompt_token_ids): # type: ignore + raise ValueError(f"Invalid prompt token ids. Must be a list of ints. Got: {prompt_token_ids}") + if len(prompt_token_ids) == 0: + raise ValueError("Prompt token ids is empty.") + return prompt_token_ids + + def _get_image_urls(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Sequence[str]: + image_urls: List[str] = [] + for message in call.request["messages"]: + if "content" not in message: + continue + content = message["content"] + if content is None: + continue + if isinstance(content, list): + for part in content: + if part["type"] == "image_url": + image_urls.append(part["image_url"]["url"]) + return image_urls + + def _get_completion_token_ids(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Sequence[int]: + completion_choice = call.response.choices[0] + if hasattr(completion_choice, "token_ids") and completion_choice.token_ids is not None: # type: ignore + response_token_ids = cast(Any, completion_choice.token_ids) # type: ignore + elif hasattr(completion_choice, "provider_specific_fields") and "token_ids" in completion_choice.provider_specific_fields: # type: ignore + response_token_ids = cast(Any, completion_choice.provider_specific_fields["token_ids"]) # type: ignore + else: + raise ValueError(f"Completion token ids not found in call: {call}") + if not isinstance(response_token_ids, list) or not all(isinstance(x, int) for x in response_token_ids): # type: ignore + raise ValueError(f"Invalid completion token ids. Must be a list of ints. 
Got: {response_token_ids}") + response_token_ids = cast(Sequence[int], response_token_ids) + if len(response_token_ids) == 0: + raise ValueError("Completion token ids is empty.") + return response_token_ids + + def _get_logprobs(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Optional[Sequence[float]]: + logprobs = call.response.choices[0].logprobs + if logprobs is not None: + content_logprobs = logprobs.content + if content_logprobs is not None: + return [logprob.logprob for logprob in content_logprobs] + return None + + def _get_reward(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Optional[float]: + if isinstance(call, AnnotatedChatCompletionCall): + for annotation in call.annotations: + # The first general annotation + if isinstance(annotation, GeneralAnnotation): + # The primary reward + return annotation.primary_reward + return None + + def _to_tokens_triplet( + self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + ) -> Union[TokensTriplet, BaseException]: + try: + return TokensTriplet( + observation=TokenInput( + token_ids=self._get_prompt_token_ids(call), image_urls=self._get_image_urls(call) + ), + completion=TokenOutput( + token_ids=self._get_completion_token_ids(call), logprobs=self._get_logprobs(call) + ), + reward=self._get_reward(call), + done=False, # False by now + raw_call=call, + ) + except Exception as exc: + if self.strict: + raise exc + return exc + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: + exceptions: List[BaseException] = [] + triplets: List[TokensTriplet] = [] + for span in source: + if isinstance(span.data, (AnnotatedChatCompletionCall, ChatCompletionCall)): + triplet = self._to_tokens_triplet(span.data) + if isinstance(triplet, BaseException): + exceptions.append(triplet) + else: + triplets.append(triplet) + if len(triplets) == 0: + error_msg = ( + f"{self.__class__.__name__} failed to create any triplets. " + "The adapter has raised {len(exceptions)} exceptions when processing the spans:\n" + + "\n".join([f" - {exc}" for exc in exceptions]) + ) + raise RuntimeError(error_msg) + return triplets class ToTokensAccumulations(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensAccumulation]]): diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index a825cdedb..aae2b9522 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -28,7 +28,7 @@ ChatCompletionMessageParam, CompletionCreateParams, ) -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from typing_extensions import Self from agentlightning.semconv import LinkPydanticModel, RewardPydanticModel @@ -507,6 +507,13 @@ class TokenOutput(BaseModel): logprobs: Optional[Sequence[float]] """Log probabilities of the model output.""" + @model_validator(mode="after") + def validate_logprobs(self) -> Self: + if self.logprobs is not None: + if len(self.logprobs) != len(self.token_ids): + raise ValueError("Log probabilities must be the same length as the token IDs.") + return self + class TokensTriplet(Triplet[TokenInput, TokenOutput]): """A triplet of token IDs for the input and output, useful for reinforcement learning. @@ -534,7 +541,7 @@ class TokensAccumulation(Accumulation): """Mask for the response tokens. 
Must a sequence of 0s and 1s, with 1s for the completion tokens and 0s for the prompt tokens.""" -class PromptCompletionTriplet(Triplet[Sequence[ChatCompletionMessageParam], ChatCompletionCall]): +class PromptCompletionTriplet(Triplet[Sequence[ChatCompletionMessageParam], ChatCompletion]): """A triplet of prompt and completion.""" From bbfb4a3d61619b1605c70f08ed386d1dff89821d Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 13 Jan 2026 15:11:25 +0800 Subject: [PATCH 35/41] postprocess --- agentlightning/adapter/postprocess.py | 184 +++++++++++++++++++++++--- agentlightning/types/adapter.py | 33 +++++ 2 files changed, 199 insertions(+), 18 deletions(-) diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py index eba1b0136..a220eeeec 100644 --- a/agentlightning/adapter/postprocess.py +++ b/agentlightning/adapter/postprocess.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Any, List, Literal, Optional, Sequence, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Sequence, TypeVar, Union, cast from agentlightning.types.adapter import ( AdaptingSpan, @@ -17,9 +17,13 @@ TokenInput, TokenOutput, TokensAccumulation, + TokensAccumulationDiagnosis, TokensTriplet, ) +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + from .base import Adapter T_triplet_or_accumulation = TypeVar( @@ -28,16 +32,7 @@ ) -class ToTokensTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensTriplet]]): - """Convert adapting spans to token input-output triplets. - - Args: - strict: Whether to raise an exception if the triplet cannot be created. - If False, the exception will be added to the list of exceptions and the triplet will be skipped. - The exceptions will also be raised when the resulting sequence is empty. - If True, the exception will be raised. - Default is False. - """ +class TokensTripletMixin: def __init__(self, strict: bool = False): self.strict = strict @@ -101,7 +96,7 @@ def _get_reward(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCal return annotation.primary_reward return None - def _to_tokens_triplet( + def to_triplet( self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] ) -> Union[TokensTriplet, BaseException]: try: @@ -121,12 +116,12 @@ def _to_tokens_triplet( raise exc return exc - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: + def to_triplets(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: exceptions: List[BaseException] = [] triplets: List[TokensTriplet] = [] for span in source: if isinstance(span.data, (AnnotatedChatCompletionCall, ChatCompletionCall)): - triplet = self._to_tokens_triplet(span.data) + triplet = self.to_triplet(span.data) if isinstance(triplet, BaseException): exceptions.append(triplet) else: @@ -134,17 +129,170 @@ def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTr if len(triplets) == 0: error_msg = ( f"{self.__class__.__name__} failed to create any triplets. 
" - "The adapter has raised {len(exceptions)} exceptions when processing the spans:\n" + f"The adapter has raised {len(exceptions)} exceptions when processing the spans:\n" + "\n".join([f" - {exc}" for exc in exceptions]) ) raise RuntimeError(error_msg) + triplets[-1] = triplets[-1].model_copy(update={"done": True}) return triplets -class ToTokensAccumulations(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensAccumulation]]): - """Assemble multiple token input-output triplets into accumulated token sequences.""" +class ToTokensTriplets(TokensTripletMixin, Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensTriplet]]): + """Convert adapting spans to token input-output triplets. + + Args: + strict: Whether to raise an exception if the triplet cannot be created. + If False, the exception will be added to the list of exceptions and the triplet will be skipped. + The exceptions will also be raised when the resulting sequence is empty. + If True, the exception will be raised. + Default is False. + """ + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: + return self.to_triplets(source) + + +class ToTokensAccumulations( + TokensTripletMixin, Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensAccumulation]] +): + """Assemble multiple token input-output triplets into accumulated token sequences. + + Args: + diagnosis: Whether to include diagnosis information in the resulting TokensAccumulation. + strict: Whether to raise an exception if the triplet cannot be created. + If False, the exception will be added to the list of exceptions and the triplet will be skipped. + The exceptions will also be raised when the resulting sequence is empty. + If True, the exception will be raised. + Default is False. + tokenizer: An optional tokenizer to decode token IDs to text for diagnosis. + """ + + def __init__(self, strict: bool = False, diagnosis: bool = False, tokenizer: Optional[PreTrainedTokenizer] = None): + super().__init__(strict=strict) + self.diagnosis = diagnosis + self.tokenizer = tokenizer + + def _triplet_to_accumulation( + self, triplet: TokensTriplet, diagnosis_info: Optional[TokensAccumulationDiagnosis] + ) -> TokensAccumulation: + if triplet.completion.logprobs is not None: + logprobs = [0.0] * len(triplet.observation.token_ids) + list(triplet.completion.logprobs) + else: + logprobs = None + + return TokensAccumulation( + token_ids=[*triplet.observation.token_ids, *triplet.completion.token_ids], + image_urls=triplet.observation.image_urls, + logprobs=logprobs, + response_mask=[0] * len(triplet.observation.token_ids) + [1] * len(triplet.completion.token_ids), + final_reward=triplet.reward, + raw_calls=[triplet.raw_call], + diagnosis_info=diagnosis_info, + ) + + def _special_token_sequence(self, ids: Sequence[int]) -> List[int]: + assert self.tokenizer is not None, "Tokenizer must be provided for special token sequence extraction." + return [id for id in ids if id in self.tokenizer.all_special_ids] + + def _non_special_token_sequence(self, ids: Sequence[int]) -> List[int]: + assert self.tokenizer is not None, "Tokenizer must be provided for non-special token sequence extraction." 
+ return [id for id in ids if id not in self.tokenizer.all_special_ids] + + def _diagnose_mismatch( + self, prev: TokensAccumulation, next: TokensTriplet + ) -> Optional[TokensAccumulationDiagnosis]: + if not self.diagnosis: + return None + + if self.tokenizer is None: + raise ValueError("Tokenizer must be provided for diagnosis.") + + image_urls_match = self.is_prefix(prev.image_urls, next.observation.image_urls) + + # Check whether the special tokens match + next_special_ids = self._special_token_sequence(next.observation.token_ids) + prev_special_ids = self._special_token_sequence(prev.token_ids) + special_tokens_match = self.is_prefix(prev_special_ids, next_special_ids) + + # Check whether the non-special tokens match + next_non_special_ids = self._non_special_token_sequence(next.observation.token_ids) + prev_non_special_ids = self._non_special_token_sequence(prev.token_ids) + non_special_tokens_match = self.is_prefix(prev_non_special_ids, next_non_special_ids) + + # Check whether the detokenized text matches + next_string = self.tokenizer.decode(next.observation.token_ids, skip_special_tokens=True) # type: ignore + prev_string = self.tokenizer.decode(prev.token_ids, skip_special_tokens=True) # type: ignore + detokenized_text_match = next_string.startswith(prev_string) + + return TokensAccumulationDiagnosis( + special_tokens_mismatch=not special_tokens_match, + non_special_tokens_mismatch=not non_special_tokens_match, + detokenized_text_mismatch=not detokenized_text_match, + image_urls_mismatch=not image_urls_match, + accumulation_prev=prev, + special_tokens_prev=prev_special_ids, + special_tokens_next=next_special_ids, + detokenized_text_prev=prev_string, + detokenized_text_next=next_string, + ) + + @staticmethod + def is_prefix(shorter: Iterable[Any], longer: Iterable[Any]) -> bool: + """Check if the shorter sequence is a prefix of the longer sequence.""" + longer_iter = iter(longer) + for item in shorter: + try: + expected_item = next(longer_iter) + if item != expected_item: + return False + except StopIteration: + return False + return True + + def _attempt_to_merge(self, prev: TokensAccumulation, next: TokensTriplet) -> List[TokensAccumulation]: + # Check if we can merge the next triplet into the previous accumulation + if not self.is_prefix(prev.image_urls, next.observation.image_urls): + return [prev, self._triplet_to_accumulation(next, self._diagnose_mismatch(prev, next))] + # Merge token ids + if not self.is_prefix(prev.token_ids, next.observation.token_ids): + return [prev, self._triplet_to_accumulation(next, self._diagnose_mismatch(prev, next))] + tokens_to_add = [*next.observation.token_ids[len(prev.token_ids) :], *next.completion.token_ids] + if prev.logprobs is not None and next.completion.logprobs is not None: + new_logprobs = list(prev.logprobs) + [0.0] * len(tokens_to_add) + list(next.completion.logprobs) + else: + new_logprobs = None + response_mask_to_add = [0] * (len(next.observation.token_ids) - len(prev.token_ids)) + [1] * len( + next.completion.token_ids + ) + new_reward = ( + (prev.final_reward or 0.0) + (next.reward or 0.0) + if next.reward is not None or prev.final_reward is not None + else None + ) + + return [ + TokensAccumulation( + token_ids=[*prev.token_ids, *tokens_to_add], + image_urls=next.observation.image_urls, + logprobs=new_logprobs, + response_mask=[*prev.response_mask, *response_mask_to_add], + final_reward=new_reward, + raw_calls=[*prev.raw_calls, next.raw_call], + diagnosis_info=None, + ) + ] + + def adapt(self, source: 
BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensAccumulation]: + triplets = self.to_triplets(source) - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensAccumulation]: ... + accumulations: List[TokensAccumulation] = [] + for triplet in triplets: + if not accumulations: + accumulations.append(self._triplet_to_accumulation(triplet, None)) + else: + last_accumulation = accumulations[-1] + accumulations = accumulations[:-1] + self._attempt_to_merge(last_accumulation, triplet) + return accumulations class ToPromptCompletionTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[PromptCompletionTriplet]]): diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index aae2b9522..990c55d69 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -522,6 +522,36 @@ class TokensTriplet(Triplet[TokenInput, TokenOutput]): """ +class TokensAccumulationDiagnosis(BaseModel): + + special_tokens_mismatch: bool + """Whether there is a mismatch in special tokens.""" + + non_special_tokens_mismatch: bool + """Whether there is a mismatch in non-special tokens.""" + + detokenized_text_mismatch: bool + """Whether there is a mismatch in detokenized text.""" + + image_urls_mismatch: bool + """Whether there is a mismatch in image URLs.""" + + accumulation_prev: TokensAccumulation + """The previous accumulation which this triplet fails to match.""" + + special_tokens_prev: Sequence[int] + """Special tokens in the previous accumulation.""" + + special_tokens_next: Sequence[int] + """Special tokens in the next sequence.""" + + detokenized_text_prev: str + """Detokenized text in the previous accumulation.""" + + detokenized_text_next: str + """Detokenized text in the next sequence.""" + + class TokensAccumulation(Accumulation): """A sequence of token IDs that are accumulated from multiple model calls. @@ -540,6 +570,9 @@ class TokensAccumulation(Accumulation): response_mask: Sequence[int] """Mask for the response tokens. 
Must a sequence of 0s and 1s, with 1s for the completion tokens and 0s for the prompt tokens.""" + diagnosis_info: Optional[TokensAccumulationDiagnosis] + """Diagnosis information for token accumulation mismatches.""" + class PromptCompletionTriplet(Triplet[Sequence[ChatCompletionMessageParam], ChatCompletion]): """A triplet of prompt and completion.""" From 0f278fa068f03d4ceea19c1704f094c1750f9fbc Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 13 Jan 2026 15:53:43 +0800 Subject: [PATCH 36/41] edit postprocess --- agentlightning/adapter/call.py | 24 +- agentlightning/adapter/postprocess.py | 316 +++++++++++++++++--------- agentlightning/types/adapter.py | 8 +- agentlightning/utils/pydantic.py | 26 +++ 4 files changed, 240 insertions(+), 134 deletions(-) create mode 100644 agentlightning/utils/pydantic.py diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index 067a247ac..156fc7bcb 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -12,7 +12,7 @@ ChatCompletion, CompletionCreateParams, ) -from pydantic import TypeAdapter, ValidationError +from pydantic import TypeAdapter from agentlightning.types.adapter import ( AdaptingSpan, @@ -24,6 +24,7 @@ ) from agentlightning.types.tracer import Span from agentlightning.utils.otel import filter_and_unflatten_attributes, query_linked_spans +from agentlightning.utils.pydantic import to_plain_object from .base import SequenceAdapter @@ -137,26 +138,7 @@ def _parse_agentops_tool_calls(self, span: Span) -> Optional[Dict[str, Any]]: def _normalize_request(self, request_body: Dict[str, Any]) -> CompletionCreateParams: validated_request = CompletionCreateParamsType.validate_python(request_body) - - def _to_plain_object(object: Any, path: List[str]) -> Any: - if type(object).__name__ == "ValidatorIterator": - try: - return [_to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)] - except ValidationError as exc: - raise ValueError( - "Failed to convert ValidatorIterator to list.\n" - "ValidatorIterator path (see below for subpath): " + ".".join(path) + "\nError: " + str(exc) - ) from exc - elif isinstance(object, dict): - return {k: _to_plain_object(v, path + [k]) for k, v in object.items()} # type: ignore - elif isinstance(object, list): - return [_to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)] # type: ignore - elif isinstance(object, tuple): - return tuple(_to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)) # type: ignore - else: - return object - - return _to_plain_object(validated_request, []) + return to_plain_object(validated_request, []) def _parse_openai_chat_completion_create(self, span: AdaptingSpan) -> ChatCompletionCall: prompt_messages = self._parse_prompt_messages(span) diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py index a220eeeec..de2820ab7 100644 --- a/agentlightning/adapter/postprocess.py +++ b/agentlightning/adapter/postprocess.py @@ -4,7 +4,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Sequence, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Generic, Iterable, List, Literal, Optional, Sequence, TypeVar, Union, cast + +from openai.types.chat import ChatCompletion, ChatCompletionAssistantMessageParam +from pydantic import TypeAdapter from agentlightning.types.adapter import ( AdaptingSpan, @@ -20,6 +23,7 @@ TokensAccumulationDiagnosis, TokensTriplet, ) +from agentlightning.utils.pydantic 
import to_plain_object if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -30,13 +34,94 @@ "T_triplet_or_accumulation", bound=Union[TokensTriplet, TokensAccumulation, PromptCompletionTriplet, PromptCompletionAccumulation], ) +T_triplet = TypeVar("T_triplet", bound=Union[TokensTriplet, PromptCompletionTriplet]) + + +def is_prefix(shorter: Iterable[Any], longer: Iterable[Any]) -> bool: + """Check if the shorter sequence is a prefix of the longer sequence.""" + longer_iter = iter(longer) + for item in shorter: + try: + expected_item = next(longer_iter) + if item != expected_item: + return False + except StopIteration: + return False + return True -class TokensTripletMixin: +class ToTripletMixin(Generic[T_triplet]): + """Mixin for adapters that convert chat completion calls to triplets.""" def __init__(self, strict: bool = False): self.strict = strict + def get_reward(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Optional[float]: + if isinstance(call, AnnotatedChatCompletionCall): + for annotation in call.annotations: + # The first general annotation + if isinstance(annotation, GeneralAnnotation): + # The primary reward + return annotation.primary_reward + return None + + def to_triplet( + self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + ) -> Union[T_triplet, BaseException]: + raise NotImplementedError() + + def to_triplets(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[T_triplet]: + exceptions: List[BaseException] = [] + triplets: List[T_triplet] = [] + for span in source: + if isinstance(span.data, (AnnotatedChatCompletionCall, ChatCompletionCall)): + triplet = self.to_triplet(span.data) + if isinstance(triplet, BaseException): + exceptions.append(triplet) + else: + triplets.append(triplet) + if len(triplets) == 0: + error_msg = ( + f"{self.__class__.__name__} failed to create any triplets. " + f"The adapter has raised {len(exceptions)} exceptions when processing the spans:\n" + + "\n".join([f" - {exc}" for exc in exceptions]) + ) + raise RuntimeError(error_msg) + triplets[-1] = triplets[-1].model_copy(update={"done": True}) + return triplets + + +class ToTokensTriplets( + ToTripletMixin[TokensTriplet], Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensTriplet]] +): + """Convert adapting spans to token input-output triplets. + + Args: + strict: Whether to raise an exception if the triplet cannot be created. + If False, the exception will be added to the list of exceptions and the triplet will be skipped. + The exceptions will also be raised when the resulting sequence is empty. + If True, the exception will be raised. + Default is False. 
+ """ + + def to_triplet( + self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + ) -> Union[TokensTriplet, BaseException]: + try: + return TokensTriplet.model_construct( + observation=TokenInput( + token_ids=self._get_prompt_token_ids(call), image_urls=self._get_image_urls(call) + ), + action=TokenOutput(token_ids=self._get_completion_token_ids(call), logprobs=self._get_logprobs(call)), + reward=self.get_reward(call), + done=False, # False by now + raw_call=call, + ) + except Exception as exc: + if self.strict: + raise exc + return exc + def _get_prompt_token_ids(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Sequence[int]: prompt = call.request if "prompt_token_ids" in prompt: @@ -87,104 +172,35 @@ def _get_logprobs(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionC return [logprob.logprob for logprob in content_logprobs] return None - def _get_reward(self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall]) -> Optional[float]: - if isinstance(call, AnnotatedChatCompletionCall): - for annotation in call.annotations: - # The first general annotation - if isinstance(annotation, GeneralAnnotation): - # The primary reward - return annotation.primary_reward - return None - - def to_triplet( - self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] - ) -> Union[TokensTriplet, BaseException]: - try: - return TokensTriplet( - observation=TokenInput( - token_ids=self._get_prompt_token_ids(call), image_urls=self._get_image_urls(call) - ), - completion=TokenOutput( - token_ids=self._get_completion_token_ids(call), logprobs=self._get_logprobs(call) - ), - reward=self._get_reward(call), - done=False, # False by now - raw_call=call, - ) - except Exception as exc: - if self.strict: - raise exc - return exc - - def to_triplets(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: - exceptions: List[BaseException] = [] - triplets: List[TokensTriplet] = [] - for span in source: - if isinstance(span.data, (AnnotatedChatCompletionCall, ChatCompletionCall)): - triplet = self.to_triplet(span.data) - if isinstance(triplet, BaseException): - exceptions.append(triplet) - else: - triplets.append(triplet) - if len(triplets) == 0: - error_msg = ( - f"{self.__class__.__name__} failed to create any triplets. " - f"The adapter has raised {len(exceptions)} exceptions when processing the spans:\n" - + "\n".join([f" - {exc}" for exc in exceptions]) - ) - raise RuntimeError(error_msg) - triplets[-1] = triplets[-1].model_copy(update={"done": True}) - return triplets - - -class ToTokensTriplets(TokensTripletMixin, Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensTriplet]]): - """Convert adapting spans to token input-output triplets. - - Args: - strict: Whether to raise an exception if the triplet cannot be created. - If False, the exception will be added to the list of exceptions and the triplet will be skipped. - The exceptions will also be raised when the resulting sequence is empty. - If True, the exception will be raised. - Default is False. - """ - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensTriplet]: return self.to_triplets(source) -class ToTokensAccumulations( - TokensTripletMixin, Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[TokensAccumulation]] -): +class ToTokensAccumulations(Adapter[Sequence[TokensTriplet], Sequence[TokensAccumulation]]): """Assemble multiple token input-output triplets into accumulated token sequences. 
Args: diagnosis: Whether to include diagnosis information in the resulting TokensAccumulation. - strict: Whether to raise an exception if the triplet cannot be created. - If False, the exception will be added to the list of exceptions and the triplet will be skipped. - The exceptions will also be raised when the resulting sequence is empty. - If True, the exception will be raised. - Default is False. tokenizer: An optional tokenizer to decode token IDs to text for diagnosis. """ - def __init__(self, strict: bool = False, diagnosis: bool = False, tokenizer: Optional[PreTrainedTokenizer] = None): - super().__init__(strict=strict) + def __init__(self, diagnosis: bool = False, tokenizer: Optional[PreTrainedTokenizer] = None): self.diagnosis = diagnosis self.tokenizer = tokenizer def _triplet_to_accumulation( self, triplet: TokensTriplet, diagnosis_info: Optional[TokensAccumulationDiagnosis] ) -> TokensAccumulation: - if triplet.completion.logprobs is not None: - logprobs = [0.0] * len(triplet.observation.token_ids) + list(triplet.completion.logprobs) + if triplet.action.logprobs is not None: + logprobs = [0.0] * len(triplet.observation.token_ids) + list(triplet.action.logprobs) else: logprobs = None return TokensAccumulation( - token_ids=[*triplet.observation.token_ids, *triplet.completion.token_ids], + token_ids=[*triplet.observation.token_ids, *triplet.action.token_ids], image_urls=triplet.observation.image_urls, logprobs=logprobs, - response_mask=[0] * len(triplet.observation.token_ids) + [1] * len(triplet.completion.token_ids), + response_mask=[0] * len(triplet.observation.token_ids) + [1] * len(triplet.action.token_ids), final_reward=triplet.reward, raw_calls=[triplet.raw_call], diagnosis_info=diagnosis_info, @@ -207,17 +223,17 @@ def _diagnose_mismatch( if self.tokenizer is None: raise ValueError("Tokenizer must be provided for diagnosis.") - image_urls_match = self.is_prefix(prev.image_urls, next.observation.image_urls) + image_urls_match = is_prefix(prev.image_urls, next.observation.image_urls) # Check whether the special tokens match next_special_ids = self._special_token_sequence(next.observation.token_ids) prev_special_ids = self._special_token_sequence(prev.token_ids) - special_tokens_match = self.is_prefix(prev_special_ids, next_special_ids) + special_tokens_match = is_prefix(prev_special_ids, next_special_ids) # Check whether the non-special tokens match next_non_special_ids = self._non_special_token_sequence(next.observation.token_ids) prev_non_special_ids = self._non_special_token_sequence(prev.token_ids) - non_special_tokens_match = self.is_prefix(prev_non_special_ids, next_non_special_ids) + non_special_tokens_match = is_prefix(prev_non_special_ids, next_non_special_ids) # Check whether the detokenized text matches next_string = self.tokenizer.decode(next.observation.token_ids, skip_special_tokens=True) # type: ignore @@ -236,33 +252,20 @@ def _diagnose_mismatch( detokenized_text_next=next_string, ) - @staticmethod - def is_prefix(shorter: Iterable[Any], longer: Iterable[Any]) -> bool: - """Check if the shorter sequence is a prefix of the longer sequence.""" - longer_iter = iter(longer) - for item in shorter: - try: - expected_item = next(longer_iter) - if item != expected_item: - return False - except StopIteration: - return False - return True - def _attempt_to_merge(self, prev: TokensAccumulation, next: TokensTriplet) -> List[TokensAccumulation]: # Check if we can merge the next triplet into the previous accumulation - if not self.is_prefix(prev.image_urls, 
next.observation.image_urls): + if not is_prefix(prev.image_urls, next.observation.image_urls): return [prev, self._triplet_to_accumulation(next, self._diagnose_mismatch(prev, next))] # Merge token ids - if not self.is_prefix(prev.token_ids, next.observation.token_ids): + if not is_prefix(prev.token_ids, next.observation.token_ids): return [prev, self._triplet_to_accumulation(next, self._diagnose_mismatch(prev, next))] - tokens_to_add = [*next.observation.token_ids[len(prev.token_ids) :], *next.completion.token_ids] - if prev.logprobs is not None and next.completion.logprobs is not None: - new_logprobs = list(prev.logprobs) + [0.0] * len(tokens_to_add) + list(next.completion.logprobs) + tokens_to_add = [*next.observation.token_ids[len(prev.token_ids) :], *next.action.token_ids] + if prev.logprobs is not None and next.action.logprobs is not None: + new_logprobs = list(prev.logprobs) + [0.0] * len(tokens_to_add) + list(next.action.logprobs) else: new_logprobs = None response_mask_to_add = [0] * (len(next.observation.token_ids) - len(prev.token_ids)) + [1] * len( - next.completion.token_ids + next.action.token_ids ) new_reward = ( (prev.final_reward or 0.0) + (next.reward or 0.0) @@ -282,11 +285,9 @@ def _attempt_to_merge(self, prev: TokensAccumulation, next: TokensTriplet) -> Li ) ] - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensAccumulation]: - triplets = self.to_triplets(source) - + def adapt(self, source: Sequence[TokensTriplet]) -> Sequence[TokensAccumulation]: accumulations: List[TokensAccumulation] = [] - for triplet in triplets: + for triplet in source: if not accumulations: accumulations.append(self._triplet_to_accumulation(triplet, None)) else: @@ -295,18 +296,95 @@ def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[TokensAc return accumulations -class ToPromptCompletionTriplets(Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[PromptCompletionTriplet]]): - """Convert annotated chat completion calls to prompt-completion triplets.""" +class ToPromptCompletionTriplets( + ToTripletMixin[PromptCompletionTriplet], + Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[PromptCompletionTriplet]], +): + """Convert annotated chat completion calls to prompt-completion triplets. - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[PromptCompletionTriplet]: ... + Args: + strict: Whether to raise an exception if the triplet cannot be created. + If False, the exception will be added to the list of exceptions and the triplet will be skipped. + The exceptions will also be raised when the resulting sequence is empty. + If True, the exception will be raised. + Default is False. 
+ """ + + def to_triplet( + self, call: Union[AnnotatedChatCompletionCall, ChatCompletionCall] + ) -> Union[PromptCompletionTriplet, BaseException]: + try: + return PromptCompletionTriplet.model_construct( + observation=call.request, + action=call.response, + reward=self.get_reward(call), + done=False, # False by now + raw_call=call, + ) + except Exception as exc: + if self.strict: + raise exc + return exc + + def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[PromptCompletionTriplet]: + return self.to_triplets(source) class ToPromptCompletionAccumulations( - Adapter[BaseAdaptingSequence[AdaptingSpan], Sequence[PromptCompletionAccumulation]] + Adapter[Sequence[PromptCompletionTriplet], Sequence[PromptCompletionAccumulation]] ): """Assemble multiple prompt-completion triplets into accumulated prompt-completion pairs.""" - def adapt(self, source: BaseAdaptingSequence[AdaptingSpan]) -> Sequence[PromptCompletionAccumulation]: ... + def _to_messages(self, completion: ChatCompletion) -> List[ChatCompletionAssistantMessageParam]: + ChatCompletionAssistantMessageParamType = TypeAdapter(ChatCompletionAssistantMessageParam) + validated_message = ChatCompletionAssistantMessageParamType.validate_python(completion.choices[0].message) + return [to_plain_object(validated_message, [])] + + def _to_accumulation(self, triplet: PromptCompletionTriplet) -> PromptCompletionAccumulation: + tools = [*triplet.observation["tools"]] if "tools" in triplet.observation else None + return PromptCompletionAccumulation.model_construct( + messages=[*triplet.observation["messages"], *self._to_messages(triplet.action)], + tools=tools, + final_reward=triplet.reward, + raw_calls=[triplet.raw_call], + ) + + def _attempt_to_merge( + self, prev: PromptCompletionAccumulation, next: PromptCompletionTriplet + ) -> List[PromptCompletionAccumulation]: + next_acc = self._to_accumulation(next) + # Check if we can merge the next triplet into the previous accumulation + if prev.tools != next_acc.tools: + # Cannot merge because tools are different + return [prev, next_acc] + if not is_prefix(prev.messages, next_acc.messages): + # Cannot merge because messages do not match + return [prev, next_acc] + # Merge messages + merged_messages = [*prev.messages, *next_acc.messages[len(prev.messages) :]] + new_reward = ( + (prev.final_reward or 0.0) + (next.reward or 0.0) + if next.reward is not None or prev.final_reward is not None + else None + ) + return [ + PromptCompletionAccumulation( + messages=merged_messages, + tools=prev.tools, + final_reward=new_reward, + raw_calls=[*prev.raw_calls, next.raw_call], + ) + ] + + def adapt(self, source: Sequence[PromptCompletionTriplet]) -> Sequence[PromptCompletionAccumulation]: + accumulations: List[PromptCompletionAccumulation] = [] + for triplet in source: + if not accumulations: + accumulations.append(self._to_accumulation(triplet)) + else: + last_accumulation = accumulations[-1] + accumulations = accumulations[:-1] + self._attempt_to_merge(last_accumulation, triplet) + return accumulations class PropagateRewards(Adapter[Sequence[T_triplet_or_accumulation], Sequence[T_triplet_or_accumulation]]): @@ -315,4 +393,24 @@ class PropagateRewards(Adapter[Sequence[T_triplet_or_accumulation], Sequence[T_t def __init__(self, direction: Literal["forward", "backward"]) -> None: self.direction = direction - def adapt(self, source: Sequence[T_triplet_or_accumulation]) -> Sequence[T_triplet_or_accumulation]: ... 
+ def adapt(self, source: Sequence[T_triplet_or_accumulation]) -> Sequence[T_triplet_or_accumulation]: + prev_reward: Optional[float] = None + if self.direction == "forward": + scan_order = list(range(len(source))) + else: + scan_order = list(reversed(range(len(source)))) + transformed_source = [*source] + for idx in scan_order: + item = source[idx] + if isinstance(item, PromptCompletionTriplet) or isinstance(item, TokensTriplet): + if item.reward is not None: + prev_reward = item.reward + else: + transformed_source[idx] = item.model_copy(update={"reward": prev_reward}) + else: + if item.final_reward is not None: + prev_reward = item.final_reward + else: + transformed_source[idx] = item.model_copy(update={"final_reward": prev_reward}) + + return transformed_source diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 990c55d69..4355117da 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -24,8 +24,8 @@ from openai.types.chat import ( ChatCompletion, - ChatCompletionFunctionToolParam, ChatCompletionMessageParam, + ChatCompletionToolUnionParam, CompletionCreateParams, ) from pydantic import BaseModel, Field, model_validator @@ -462,7 +462,7 @@ class Triplet(Generic[T_observation, T_action], BaseModel): observation: T_observation """Observation for the model input.""" - completion: T_action + action: T_action """Action from the model output.""" reward: Optional[float] @@ -574,7 +574,7 @@ class TokensAccumulation(Accumulation): """Diagnosis information for token accumulation mismatches.""" -class PromptCompletionTriplet(Triplet[Sequence[ChatCompletionMessageParam], ChatCompletion]): +class PromptCompletionTriplet(Triplet[CompletionCreateParams, ChatCompletion]): """A triplet of prompt and completion.""" @@ -584,5 +584,5 @@ class PromptCompletionAccumulation(Accumulation): messages: Sequence[ChatCompletionMessageParam] """Messages of the conversation.""" - tools: Optional[Sequence[ChatCompletionFunctionToolParam]] + tools: Optional[Sequence[ChatCompletionToolUnionParam]] """Tools provided for the conversation.""" diff --git a/agentlightning/utils/pydantic.py b/agentlightning/utils/pydantic.py new file mode 100644 index 000000000..b230d50af --- /dev/null +++ b/agentlightning/utils/pydantic.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Helper functions to make pydantic more useful.""" + +from typing import Any, List + +from pydantic import ValidationError + + +def to_plain_object(object: Any, path: List[str]) -> Any: + if type(object).__name__ == "ValidatorIterator": + try: + return [to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)] + except ValidationError as exc: + raise ValueError( + "Failed to convert ValidatorIterator to list.\n" + "ValidatorIterator path (see below for subpath): " + ".".join(path) + "\nError: " + str(exc) + ) from exc + elif isinstance(object, dict): + return {k: to_plain_object(v, path + [k]) for k, v in object.items()} # type: ignore + elif isinstance(object, list): + return [to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)] # type: ignore + elif isinstance(object, tuple): + return tuple(to_plain_object(item, path + [str(i)]) for i, item in enumerate(object)) # type: ignore + else: + return object From 4351a916fac6f1eb2fa61305b75ab1f1e190324e Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 14 Jan 2026 12:20:14 +0800 Subject: [PATCH 37/41] update postprocess tests --- agentlightning/adapter/postprocess.py | 8 +- tests/adapter/test_postprocess.py | 1257 +++++++++++++++++++++++++ 2 files changed, 1263 insertions(+), 2 deletions(-) create mode 100644 tests/adapter/test_postprocess.py diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py index de2820ab7..799a5b059 100644 --- a/agentlightning/adapter/postprocess.py +++ b/agentlightning/adapter/postprocess.py @@ -261,7 +261,9 @@ def _attempt_to_merge(self, prev: TokensAccumulation, next: TokensTriplet) -> Li return [prev, self._triplet_to_accumulation(next, self._diagnose_mismatch(prev, next))] tokens_to_add = [*next.observation.token_ids[len(prev.token_ids) :], *next.action.token_ids] if prev.logprobs is not None and next.action.logprobs is not None: - new_logprobs = list(prev.logprobs) + [0.0] * len(tokens_to_add) + list(next.action.logprobs) + # Add zeros only for observation extension tokens, not for action tokens + observation_extension_len = len(next.observation.token_ids) - len(prev.token_ids) + new_logprobs = list(prev.logprobs) + [0.0] * observation_extension_len + list(next.action.logprobs) else: new_logprobs = None response_mask_to_add = [0] * (len(next.observation.token_ids) - len(prev.token_ids)) + [1] * len( @@ -337,7 +339,9 @@ class ToPromptCompletionAccumulations( def _to_messages(self, completion: ChatCompletion) -> List[ChatCompletionAssistantMessageParam]: ChatCompletionAssistantMessageParamType = TypeAdapter(ChatCompletionAssistantMessageParam) - validated_message = ChatCompletionAssistantMessageParamType.validate_python(completion.choices[0].message) + # Convert message to dict first since TypeAdapter expects dict for TypedDict validation + message_dict = completion.choices[0].message.model_dump() + validated_message = ChatCompletionAssistantMessageParamType.validate_python(message_dict) return [to_plain_object(validated_message, [])] def _to_accumulation(self, triplet: PromptCompletionTriplet) -> PromptCompletionAccumulation: diff --git a/tests/adapter/test_postprocess.py b/tests/adapter/test_postprocess.py new file mode 100644 index 000000000..c8993bc26 --- /dev/null +++ b/tests/adapter/test_postprocess.py @@ -0,0 +1,1257 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for the postprocess module adapters.""" + +from typing import Any, Dict, List, Optional, Sequence +from unittest.mock import MagicMock + +import pytest +from openai.types.chat import ChatCompletion +from openai.types.chat.chat_completion import Choice, ChoiceLogprobs +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob + +from agentlightning.adapter.postprocess import ( + PropagateRewards, + ToPromptCompletionAccumulations, + ToPromptCompletionTriplets, + ToTokensAccumulations, + ToTokensTriplets, + is_prefix, +) +from agentlightning.types.adapter import ( + AdaptingSequence, + AdaptingSpan, + AnnotatedChatCompletionCall, + ChatCompletionCall, + GeneralAnnotation, + PromptCompletionAccumulation, + PromptCompletionTriplet, + TokenInput, + TokenOutput, + TokensAccumulation, + TokensTriplet, +) +from agentlightning.types.tracer import Span + +# ============================================================================== +# Helper functions to create mock objects +# ============================================================================== + + +def make_chat_completion( + *, + content: str = "Hello", + role: str = "assistant", + finish_reason: str = "stop", + model: str = "gpt-4", + completion_id: str = "chatcmpl-123", + token_ids: Optional[Sequence[int]] = None, + logprobs: Optional[List[float]] = None, + provider_specific_fields: Optional[Dict[str, Any]] = None, +) -> ChatCompletion: + """Create a mock ChatCompletion object.""" + choice_logprobs = None + if logprobs is not None: + choice_logprobs = ChoiceLogprobs( + content=[ + ChatCompletionTokenLogprob(token=f"tok{i}", bytes=None, logprob=lp, top_logprobs=[]) + for i, lp in enumerate(logprobs) + ], + refusal=None, + ) + + message = ChatCompletionMessage(content=content, role=role, refusal=None) # type: ignore + choice = Choice( + finish_reason=finish_reason, # type: ignore + index=0, + message=message, + logprobs=choice_logprobs, + ) + + # Add token_ids as an attribute if provided + if token_ids is not None: + choice.token_ids = token_ids # type: ignore + if provider_specific_fields is not None: + choice.provider_specific_fields = provider_specific_fields # type: ignore + + return ChatCompletion( + id=completion_id, + choices=[choice], + created=1234567890, + model=model, + object="chat.completion", + ) + + +def make_completion_request( + *, + messages: Optional[List[Dict[str, Any]]] = None, + model: str = "gpt-4", + prompt_token_ids: Optional[Sequence[int]] = None, + tools: Optional[List[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Create a mock CompletionCreateParams-like dict.""" + if messages is None: + messages = [{"role": "user", "content": "Hello"}] + + request: Dict[str, Any] = { + "messages": messages, + "model": model, + } + if prompt_token_ids is not None: + request["prompt_token_ids"] = prompt_token_ids + if tools is not None: + request["tools"] = tools + return request + + +def make_chat_completion_call( + *, + request: Optional[Dict[str, Any]] = None, + response: Optional[ChatCompletion] = None, + prompt_token_ids: Optional[Sequence[int]] = None, + completion_token_ids: Optional[Sequence[int]] = None, + logprobs: Optional[List[float]] = None, +) -> ChatCompletionCall: + """Create a mock ChatCompletionCall.""" + if request is None: + request = make_completion_request(prompt_token_ids=prompt_token_ids) + if response is None: + response = make_chat_completion(token_ids=completion_token_ids, 
logprobs=logprobs) + # Use model_construct to bypass Pydantic validation which strips unknown fields like prompt_token_ids + return ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + + +def make_annotated_call( + *, + request: Optional[Dict[str, Any]] = None, + response: Optional[ChatCompletion] = None, + reward: Optional[float] = None, + prompt_token_ids: Optional[Sequence[int]] = None, + completion_token_ids: Optional[Sequence[int]] = None, + logprobs: Optional[List[float]] = None, +) -> AnnotatedChatCompletionCall: + """Create a mock AnnotatedChatCompletionCall with optional reward.""" + if request is None: + request = make_completion_request(prompt_token_ids=prompt_token_ids) + if response is None: + response = make_chat_completion(token_ids=completion_token_ids, logprobs=logprobs) + + annotations = [] + if reward is not None: + annotations.append(GeneralAnnotation(primary_reward=reward)) + + # Use model_construct to bypass Pydantic validation which strips unknown fields like prompt_token_ids + return AnnotatedChatCompletionCall.model_construct( + request=request, + response=response, + malformed_fields={}, + annotations=annotations, + ) + + +def make_adapting_span(data: Any, span_id: str = "span-1") -> AdaptingSpan: + """Create an AdaptingSpan with the given data.""" + span = Span.from_attributes( + rollout_id="rollout-1", + attempt_id="attempt-1", + sequence_id=0, + trace_id="trace-1", + span_id=span_id, + parent_id=None, + name="test-span", + attributes={}, + start_time=0.0, + end_time=1.0, + ) + return AdaptingSpan.from_span(span, data=data) + + +def make_adapting_sequence(calls: Sequence[Any]) -> AdaptingSequence[AdaptingSpan]: + """Create an AdaptingSequence from a list of calls.""" + spans = [make_adapting_span(call, span_id=f"span-{i}") for i, call in enumerate(calls)] + return AdaptingSequence(spans) + + +# ============================================================================== +# Tests for is_prefix function +# ============================================================================== + + +def test_is_prefix_empty_shorter(): + """Empty sequence is a prefix of any sequence.""" + assert is_prefix([], [1, 2, 3]) is True + assert is_prefix([], []) is True + + +def test_is_prefix_identical_sequences(): + """Identical sequences should be prefixes of each other.""" + assert is_prefix([1, 2, 3], [1, 2, 3]) is True + + +def test_is_prefix_true_case(): + """Shorter sequence is a prefix of longer sequence.""" + assert is_prefix([1, 2], [1, 2, 3, 4]) is True + assert is_prefix(["a"], ["a", "b", "c"]) is True + + +def test_is_prefix_false_mismatch(): + """Sequences with mismatched elements are not prefixes.""" + assert is_prefix([1, 2, 3], [1, 2, 4]) is False + assert is_prefix([1, 3], [1, 2, 3]) is False + + +def test_is_prefix_longer_than_sequence(): + """Longer sequence cannot be a prefix of shorter sequence.""" + assert is_prefix([1, 2, 3, 4], [1, 2, 3]) is False + + +def test_is_prefix_with_iterables(): + """is_prefix should work with any iterables.""" + assert is_prefix(iter([1, 2]), iter([1, 2, 3])) is True + assert is_prefix(range(3), [0, 1, 2, 3, 4]) is True + + +# ============================================================================== +# Tests for ToTokensTriplets +# ============================================================================== + + +def test_to_tokens_triplets_basic(): + """Basic conversion from chat completion call to tokens triplet.""" + call = make_chat_completion_call( + prompt_token_ids=[1, 2, 3], 
+ completion_token_ids=[4, 5, 6], + ) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert len(triplets) == 1 + assert list(triplets[0].observation.token_ids) == [1, 2, 3] + assert list(triplets[0].action.token_ids) == [4, 5, 6] + assert triplets[0].done is True # Last triplet should be done + + +def test_to_tokens_triplets_with_reward(): + """Triplet should include reward from annotations.""" + call = make_annotated_call( + prompt_token_ids=[1, 2, 3], + completion_token_ids=[4, 5], + reward=0.75, + ) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert triplets[0].reward == 0.75 + + +def test_to_tokens_triplets_with_logprobs(): + """Triplet should include logprobs when available.""" + call = make_chat_completion_call( + prompt_token_ids=[1, 2, 3], + completion_token_ids=[4, 5], + logprobs=[-0.1, -0.2], + ) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert triplets[0].action.logprobs is not None + assert list(triplets[0].action.logprobs) == [-0.1, -0.2] + + +def test_to_tokens_triplets_with_image_urls(): + """Triplet should extract image URLs from messages.""" + request = make_completion_request( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": "http://example.com/img1.png"}}, + {"type": "image_url", "image_url": {"url": "http://example.com/img2.png"}}, + ], + } + ], + prompt_token_ids=[1, 2, 3], + ) + response = make_chat_completion(token_ids=[4, 5]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert list(triplets[0].observation.image_urls) == [ + "http://example.com/img1.png", + "http://example.com/img2.png", + ] + + +def test_to_tokens_triplets_missing_prompt_token_ids(): + """Should raise error when prompt_token_ids is missing.""" + request = make_completion_request() # No prompt_token_ids + response = make_chat_completion(token_ids=[4, 5]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + + with pytest.raises(RuntimeError, match="failed to create any triplets"): + adapter.adapt(source) + + +def test_to_tokens_triplets_missing_completion_token_ids(): + """Should raise error when completion token_ids is missing.""" + request = make_completion_request(prompt_token_ids=[1, 2, 3]) + response = make_chat_completion() # No token_ids + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + + with pytest.raises(RuntimeError, match="failed to create any triplets"): + adapter.adapt(source) + + +def test_to_tokens_triplets_empty_prompt_token_ids(): + """Should raise error when prompt_token_ids is empty.""" + request = make_completion_request(prompt_token_ids=[]) + response = make_chat_completion(token_ids=[4, 5]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + + with pytest.raises(RuntimeError, match="failed to create any triplets"): + adapter.adapt(source) + + 
+def test_to_tokens_triplets_empty_completion_token_ids(): + """Should raise error when completion token_ids is empty.""" + request = make_completion_request(prompt_token_ids=[1, 2, 3]) + response = make_chat_completion(token_ids=[]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + + with pytest.raises(RuntimeError, match="failed to create any triplets"): + adapter.adapt(source) + + +def test_to_tokens_triplets_strict_mode(): + """In strict mode, should raise immediately on error.""" + request = make_completion_request() # Missing prompt_token_ids + response = make_chat_completion(token_ids=[4, 5]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets(strict=True) + + with pytest.raises(ValueError, match="Prompt token ids not found"): + adapter.adapt(source) + + +def test_to_tokens_triplets_multiple_calls(): + """Multiple calls should create multiple triplets.""" + call1 = make_chat_completion_call(prompt_token_ids=[1, 2], completion_token_ids=[3, 4]) + call2 = make_chat_completion_call(prompt_token_ids=[5, 6], completion_token_ids=[7, 8]) + source = make_adapting_sequence([call1, call2]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert len(triplets) == 2 + assert triplets[0].done is False + assert triplets[1].done is True + + +def test_to_tokens_triplets_provider_specific_token_ids(): + """Should extract token_ids from provider_specific_fields.""" + request = make_completion_request(prompt_token_ids=[1, 2, 3]) + response = make_chat_completion(provider_specific_fields={"token_ids": [4, 5, 6]}) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert list(triplets[0].action.token_ids) == [4, 5, 6] + + +def test_to_tokens_triplets_skip_non_call_spans(): + """Should skip spans that don't contain chat completion calls.""" + call = make_chat_completion_call(prompt_token_ids=[1, 2], completion_token_ids=[3, 4]) + spans = [ + make_adapting_span("not a call", span_id="span-0"), + make_adapting_span(call, span_id="span-1"), + make_adapting_span(None, span_id="span-2"), + ] + source = AdaptingSequence(spans) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert len(triplets) == 1 + + +# ============================================================================== +# Tests for ToTokensAccumulations +# ============================================================================== + + +def test_to_tokens_accumulations_single_triplet(): + """Single triplet should be converted to accumulation.""" + raw_call = make_chat_completion_call() + triplet = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3], image_urls=[]), + action=TokenOutput(token_ids=[4, 5], logprobs=[-0.1, -0.2]), + reward=0.5, + done=True, + raw_call=raw_call, + ) + + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([triplet]) + + assert len(accumulations) == 1 + assert list(accumulations[0].token_ids) == [1, 2, 3, 4, 5] + assert accumulations[0].response_mask == [0, 0, 0, 1, 1] + assert accumulations[0].final_reward == 0.5 + + +def test_to_tokens_accumulations_merge_sequential(): + """Sequential triplets with matching prefixes should be 
merged.""" + triplet1 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3, 4], logprobs=None), + reward=0.3, + done=False, + raw_call=make_chat_completion_call(), + ) + triplet2 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3, 4, 5], image_urls=[]), + action=TokenOutput(token_ids=[6, 7], logprobs=None), + reward=0.4, + done=True, + raw_call=make_chat_completion_call(), + ) + + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 1 + # Final tokens: [1, 2, 3, 4] + [5] + [6, 7] + assert list(accumulations[0].token_ids) == [1, 2, 3, 4, 5, 6, 7] + # Response mask: [0, 0, 1, 1] + [0] + [1, 1] + assert accumulations[0].response_mask == [0, 0, 1, 1, 0, 1, 1] + # Rewards should be summed + assert accumulations[0].final_reward == 0.7 + + +def test_to_tokens_accumulations_no_merge_mismatch(): + """Triplets with mismatched prefixes should not merge.""" + triplet1 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3, 4], logprobs=None), + reward=0.3, + done=False, + raw_call=make_chat_completion_call(), + ) + triplet2 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[5, 6, 7], image_urls=[]), # Doesn't start with [1,2,3,4] + action=TokenOutput(token_ids=[8, 9], logprobs=None), + reward=0.4, + done=True, + raw_call=make_chat_completion_call(), + ) + + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 2 + + +def test_to_tokens_accumulations_image_url_mismatch(): + """Triplets with mismatched image URLs should not merge.""" + triplet1 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=["img1.png"]), + action=TokenOutput(token_ids=[3, 4], logprobs=None), + reward=None, + done=False, + raw_call=make_chat_completion_call(), + ) + triplet2 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3, 4, 5], image_urls=["img2.png"]), # Different image + action=TokenOutput(token_ids=[6, 7], logprobs=None), + reward=None, + done=True, + raw_call=make_chat_completion_call(), + ) + + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 2 + + +def test_to_tokens_accumulations_image_url_extension(): + """Image URLs can be extended when merging.""" + triplet1 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=["img1.png"]), + action=TokenOutput(token_ids=[3, 4], logprobs=None), + reward=None, + done=False, + raw_call=make_chat_completion_call(), + ) + triplet2 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3, 4, 5], image_urls=["img1.png", "img2.png"]), + action=TokenOutput(token_ids=[6, 7], logprobs=None), + reward=None, + done=True, + raw_call=make_chat_completion_call(), + ) + + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 1 + assert list(accumulations[0].image_urls) == ["img1.png", "img2.png"] + + +def test_to_tokens_accumulations_with_logprobs_merge(): + """Logprobs should be properly accumulated when merging.""" + triplet1 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3, 4], logprobs=[-0.1, -0.2]), + reward=None, + done=False, + 
raw_call=make_chat_completion_call(), + ) + triplet2 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3, 4, 5], image_urls=[]), + action=TokenOutput(token_ids=[6, 7], logprobs=[-0.3, -0.4]), + reward=None, + done=True, + raw_call=make_chat_completion_call(), + ) + + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 1 + # Final tokens: [1, 2, 3, 4] + [5] + [6, 7] = 7 tokens + assert list(accumulations[0].token_ids) == [1, 2, 3, 4, 5, 6, 7] + assert accumulations[0].logprobs is not None + # Logprobs should match token_ids length + # [0, 0, -0.1, -0.2] from first triplet + [0] for observation extension + [-0.3, -0.4] for action + expected_logprobs = [0.0, 0.0, -0.1, -0.2, 0.0, -0.3, -0.4] + assert list(accumulations[0].logprobs) == expected_logprobs + assert len(accumulations[0].logprobs) == len(accumulations[0].token_ids) + + +def test_to_tokens_accumulations_empty_input(): + """Empty input should return empty output.""" + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([]) + + assert len(accumulations) == 0 + + +def test_to_tokens_accumulations_none_reward_handling(): + """None rewards should be handled correctly.""" + triplet1 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3, 4], logprobs=None), + reward=None, + done=False, + raw_call=make_chat_completion_call(), + ) + triplet2 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3, 4], image_urls=[]), + action=TokenOutput(token_ids=[5, 6], logprobs=None), + reward=0.5, + done=True, + raw_call=make_chat_completion_call(), + ) + + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 1 + assert accumulations[0].final_reward == 0.5 + + +def test_to_tokens_accumulations_raw_calls_aggregated(): + """Raw calls from merged triplets should be aggregated.""" + raw_call1 = make_chat_completion_call(completion_token_ids=[1]) + raw_call2 = make_chat_completion_call(completion_token_ids=[2]) + triplet1 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3, 4], logprobs=None), + reward=None, + done=False, + raw_call=raw_call1, + ) + triplet2 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3, 4], image_urls=[]), + action=TokenOutput(token_ids=[5, 6], logprobs=None), + reward=None, + done=True, + raw_call=raw_call2, + ) + + adapter = ToTokensAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations[0].raw_calls) == 2 + # Verify they are the same objects (not copies) + assert accumulations[0].raw_calls[0].response.choices[0].token_ids == [1] # type: ignore + assert accumulations[0].raw_calls[1].response.choices[0].token_ids == [2] # type: ignore + + +# ============================================================================== +# Tests for ToPromptCompletionTriplets +# ============================================================================== + + +def test_to_prompt_completion_triplets_basic(): + """Basic conversion from chat completion call to prompt-completion triplet.""" + call = make_chat_completion_call() + source = make_adapting_sequence([call]) + + adapter = ToPromptCompletionTriplets() + triplets = adapter.adapt(source) + + assert len(triplets) == 1 + assert triplets[0].observation == call.request + assert triplets[0].action == 
call.response + assert triplets[0].done is True + + +def test_to_prompt_completion_triplets_with_reward(): + """Triplet should include reward from annotations.""" + call = make_annotated_call(reward=0.9) + source = make_adapting_sequence([call]) + + adapter = ToPromptCompletionTriplets() + triplets = adapter.adapt(source) + + assert triplets[0].reward == 0.9 + + +def test_to_prompt_completion_triplets_multiple(): + """Multiple calls should create multiple triplets with correct done flags.""" + call1 = make_chat_completion_call() + call2 = make_chat_completion_call() + call3 = make_chat_completion_call() + source = make_adapting_sequence([call1, call2, call3]) + + adapter = ToPromptCompletionTriplets() + triplets = adapter.adapt(source) + + assert len(triplets) == 3 + assert triplets[0].done is False + assert triplets[1].done is False + assert triplets[2].done is True + + +def test_to_prompt_completion_triplets_strict_mode(): + """In strict mode, errors should propagate immediately.""" + # Create a span with non-call data + span = make_adapting_span("not a call") + source = AdaptingSequence([span]) + + adapter = ToPromptCompletionTriplets() + + # Should raise because no valid triplets can be created + with pytest.raises(RuntimeError, match="failed to create any triplets"): + adapter.adapt(source) + + +# ============================================================================== +# Tests for ToPromptCompletionAccumulations +# ============================================================================== + + +def test_to_prompt_completion_accumulations_single(): + """Single triplet converts to single accumulation.""" + request = make_completion_request(messages=[{"role": "user", "content": "Hi"}]) + response = make_chat_completion(content="Hello!") + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + + triplet = PromptCompletionTriplet.model_construct( + observation=request, + action=response, + reward=0.5, + done=True, + raw_call=call, + ) + + adapter = ToPromptCompletionAccumulations() + accumulations = adapter.adapt([triplet]) + + assert len(accumulations) == 1 + # Should have user message + assistant response + assert len(accumulations[0].messages) == 2 + assert accumulations[0].messages[0]["role"] == "user" + assert accumulations[0].messages[1]["role"] == "assistant" + + +def test_to_prompt_completion_accumulations_merge(): + """Sequential triplets with matching messages should merge.""" + request1 = make_completion_request(messages=[{"role": "user", "content": "Hi"}]) + response1 = make_chat_completion(content="Hello!") + call1 = ChatCompletionCall.model_construct(request=request1, response=response1, malformed_fields={}) + + # Second request includes the previous conversation + request2 = make_completion_request( + messages=[ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "How are you?"}, + ] + ) + response2 = make_chat_completion(content="I'm good!") + call2 = ChatCompletionCall.model_construct(request=request2, response=response2, malformed_fields={}) + + triplet1 = PromptCompletionTriplet.model_construct( + observation=request1, + action=response1, + reward=0.3, + done=False, + raw_call=call1, + ) + triplet2 = PromptCompletionTriplet.model_construct( + observation=request2, + action=response2, + reward=0.4, + done=True, + raw_call=call2, + ) + + adapter = ToPromptCompletionAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert 
len(accumulations) == 1 + # Should have all 4 messages + assert len(accumulations[0].messages) == 4 + assert accumulations[0].final_reward == 0.7 + + +def test_to_prompt_completion_accumulations_no_merge_different_tools(): + """Triplets with different tools should not merge.""" + request1 = make_completion_request(tools=[{"type": "function", "function": {"name": "tool1"}}]) + response1 = make_chat_completion() + call1 = ChatCompletionCall.model_construct(request=request1, response=response1, malformed_fields={}) + + request2 = make_completion_request(tools=[{"type": "function", "function": {"name": "tool2"}}]) + response2 = make_chat_completion() + call2 = ChatCompletionCall.model_construct(request=request2, response=response2, malformed_fields={}) + + triplet1 = PromptCompletionTriplet.model_construct( + observation=request1, + action=response1, + reward=None, + done=False, + raw_call=call1, + ) + triplet2 = PromptCompletionTriplet.model_construct( + observation=request2, + action=response2, + reward=None, + done=True, + raw_call=call2, + ) + + adapter = ToPromptCompletionAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 2 + + +def test_to_prompt_completion_accumulations_no_merge_message_mismatch(): + """Triplets with mismatched message history should not merge.""" + request1 = make_completion_request(messages=[{"role": "user", "content": "Hi"}]) + response1 = make_chat_completion(content="Hello!") + call1 = ChatCompletionCall.model_construct(request=request1, response=response1, malformed_fields={}) + + # Second request has completely different history + request2 = make_completion_request(messages=[{"role": "user", "content": "Different question"}]) + response2 = make_chat_completion(content="Different answer") + call2 = ChatCompletionCall.model_construct(request=request2, response=response2, malformed_fields={}) + + triplet1 = PromptCompletionTriplet.model_construct( + observation=request1, + action=response1, + reward=None, + done=False, + raw_call=call1, + ) + triplet2 = PromptCompletionTriplet.model_construct( + observation=request2, + action=response2, + reward=None, + done=True, + raw_call=call2, + ) + + adapter = ToPromptCompletionAccumulations() + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 2 + + +def test_to_prompt_completion_accumulations_preserves_tools(): + """Accumulation should preserve tool definitions.""" + tools = [{"type": "function", "function": {"name": "my_tool", "parameters": {}}}] + request = make_completion_request(tools=tools) + response = make_chat_completion() + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + + triplet = PromptCompletionTriplet.model_construct( + observation=request, + action=response, + reward=None, + done=True, + raw_call=call, + ) + + adapter = ToPromptCompletionAccumulations() + accumulations = adapter.adapt([triplet]) + + assert accumulations[0].tools == tools + + +def test_to_prompt_completion_accumulations_empty(): + """Empty input should return empty output.""" + adapter = ToPromptCompletionAccumulations() + accumulations = adapter.adapt([]) + + assert len(accumulations) == 0 + + +# ============================================================================== +# Tests for PropagateRewards +# ============================================================================== + + +def test_propagate_rewards_forward_triplets(): + """Forward propagation fills None rewards with previous value.""" + 
triplets = [ + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1], image_urls=[]), + action=TokenOutput(token_ids=[2], logprobs=None), + reward=0.5, + done=False, + raw_call=make_chat_completion_call(), + ), + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3], logprobs=None), + reward=None, + done=False, + raw_call=make_chat_completion_call(), + ), + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3], image_urls=[]), + action=TokenOutput(token_ids=[4], logprobs=None), + reward=None, + done=True, + raw_call=make_chat_completion_call(), + ), + ] + + adapter = PropagateRewards(direction="forward") + result = adapter.adapt(triplets) + + assert result[0].reward == 0.5 + assert result[1].reward == 0.5 + assert result[2].reward == 0.5 + + +def test_propagate_rewards_backward_triplets(): + """Backward propagation fills None rewards with next value.""" + triplets = [ + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1], image_urls=[]), + action=TokenOutput(token_ids=[2], logprobs=None), + reward=None, + done=False, + raw_call=make_chat_completion_call(), + ), + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3], logprobs=None), + reward=None, + done=False, + raw_call=make_chat_completion_call(), + ), + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3], image_urls=[]), + action=TokenOutput(token_ids=[4], logprobs=None), + reward=0.8, + done=True, + raw_call=make_chat_completion_call(), + ), + ] + + adapter = PropagateRewards(direction="backward") + result = adapter.adapt(triplets) + + assert result[0].reward == 0.8 + assert result[1].reward == 0.8 + assert result[2].reward == 0.8 + + +def test_propagate_rewards_forward_accumulations(): + """Forward propagation works with accumulations (uses final_reward).""" + accumulations = [ + TokensAccumulation( + token_ids=[1, 2], + image_urls=[], + logprobs=None, + response_mask=[0, 1], + final_reward=0.3, + raw_calls=[], + diagnosis_info=None, + ), + TokensAccumulation( + token_ids=[3, 4], + image_urls=[], + logprobs=None, + response_mask=[0, 1], + final_reward=None, + raw_calls=[], + diagnosis_info=None, + ), + ] + + adapter = PropagateRewards(direction="forward") + result = adapter.adapt(accumulations) + + assert result[0].final_reward == 0.3 + assert result[1].final_reward == 0.3 + + +def test_propagate_rewards_backward_accumulations(): + """Backward propagation works with accumulations.""" + accumulations = [ + TokensAccumulation( + token_ids=[1, 2], + image_urls=[], + logprobs=None, + response_mask=[0, 1], + final_reward=None, + raw_calls=[], + diagnosis_info=None, + ), + TokensAccumulation( + token_ids=[3, 4], + image_urls=[], + logprobs=None, + response_mask=[0, 1], + final_reward=0.7, + raw_calls=[], + diagnosis_info=None, + ), + ] + + adapter = PropagateRewards(direction="backward") + result = adapter.adapt(accumulations) + + assert result[0].final_reward == 0.7 + assert result[1].final_reward == 0.7 + + +def test_propagate_rewards_preserves_existing(): + """Propagation should not overwrite existing rewards.""" + triplets = [ + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1], image_urls=[]), + action=TokenOutput(token_ids=[2], logprobs=None), + reward=0.1, + done=False, + raw_call=make_chat_completion_call(), + ), + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], 
image_urls=[]), + action=TokenOutput(token_ids=[3], logprobs=None), + reward=0.5, + done=False, + raw_call=make_chat_completion_call(), + ), + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2, 3], image_urls=[]), + action=TokenOutput(token_ids=[4], logprobs=None), + reward=None, + done=True, + raw_call=make_chat_completion_call(), + ), + ] + + adapter = PropagateRewards(direction="forward") + result = adapter.adapt(triplets) + + assert result[0].reward == 0.1 + assert result[1].reward == 0.5 + assert result[2].reward == 0.5 # Propagated from index 1 + + +def test_propagate_rewards_empty_input(): + """Empty input should return empty output.""" + adapter = PropagateRewards(direction="forward") + result = adapter.adapt([]) + + assert len(result) == 0 + + +def test_propagate_rewards_all_none(): + """All None rewards should remain None.""" + triplets = [ + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1], image_urls=[]), + action=TokenOutput(token_ids=[2], logprobs=None), + reward=None, + done=False, + raw_call=make_chat_completion_call(), + ), + TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3], logprobs=None), + reward=None, + done=True, + raw_call=make_chat_completion_call(), + ), + ] + + adapter = PropagateRewards(direction="forward") + result = adapter.adapt(triplets) + + assert result[0].reward is None + assert result[1].reward is None + + +def test_propagate_rewards_prompt_completion_triplets(): + """Propagation works with PromptCompletionTriplet.""" + request = make_completion_request() + response = make_chat_completion() + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + + triplets = [ + PromptCompletionTriplet.model_construct( + observation=request, + action=response, + reward=0.6, + done=False, + raw_call=call, + ), + PromptCompletionTriplet.model_construct( + observation=request, + action=response, + reward=None, + done=True, + raw_call=call, + ), + ] + + adapter = PropagateRewards(direction="forward") + result = adapter.adapt(triplets) + + assert result[0].reward == 0.6 + assert result[1].reward == 0.6 + + +def test_propagate_rewards_prompt_completion_accumulations(): + """Propagation works with PromptCompletionAccumulation.""" + accumulations = [ + PromptCompletionAccumulation( + messages=[{"role": "user", "content": "Hi"}], + tools=None, + final_reward=None, + raw_calls=[], + ), + PromptCompletionAccumulation( + messages=[{"role": "user", "content": "Bye"}], + tools=None, + final_reward=0.9, + raw_calls=[], + ), + ] + + adapter = PropagateRewards(direction="backward") + result = adapter.adapt(accumulations) + + assert result[0].final_reward == 0.9 + assert result[1].final_reward == 0.9 + + +# ============================================================================== +# Tests for ToTokensAccumulations diagnosis feature +# ============================================================================== + + +def test_to_tokens_accumulations_diagnosis_disabled_by_default(): + """Diagnosis info should be None when diagnosis=False (default).""" + triplet1 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[1, 2], image_urls=[]), + action=TokenOutput(token_ids=[3, 4], logprobs=None), + reward=None, + done=False, + raw_call=make_chat_completion_call(), + ) + triplet2 = TokensTriplet.model_construct( + observation=TokenInput(token_ids=[5, 6, 7], image_urls=[]), # Mismatch + 
action=TokenOutput(token_ids=[8, 9], logprobs=None), + reward=None, + done=True, + raw_call=make_chat_completion_call(), + ) + + adapter = ToTokensAccumulations(diagnosis=False) + accumulations = adapter.adapt([triplet1, triplet2]) + + assert len(accumulations) == 2 + assert accumulations[0].diagnosis_info is None + assert accumulations[1].diagnosis_info is None + + +# ============================================================================== +# Edge case tests +# ============================================================================== + + +def test_to_tokens_triplets_invalid_prompt_token_ids_type(): + """Should raise error when prompt_token_ids is not a list of ints.""" + request = make_completion_request() + request["prompt_token_ids"] = "not a list" + response = make_chat_completion(token_ids=[4, 5]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + + with pytest.raises(RuntimeError, match="failed to create any triplets"): + adapter.adapt(source) + + +def test_to_tokens_triplets_invalid_completion_token_ids_type(): + """Should raise error when completion token_ids is not a list of ints.""" + request = make_completion_request(prompt_token_ids=[1, 2, 3]) + response = make_chat_completion(token_ids=["not", "ints"]) # type: ignore + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + + with pytest.raises(RuntimeError, match="failed to create any triplets"): + adapter.adapt(source) + + +def test_to_tokens_triplets_message_without_content(): + """Messages without content should be handled gracefully.""" + request = make_completion_request( + messages=[{"role": "assistant"}], # No content field + prompt_token_ids=[1, 2, 3], + ) + response = make_chat_completion(token_ids=[4, 5]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + # Should succeed with empty image_urls + assert len(triplets) == 1 + assert list(triplets[0].observation.image_urls) == [] + + +def test_to_tokens_triplets_message_with_none_content(): + """Messages with None content should be handled gracefully.""" + request = make_completion_request( + messages=[{"role": "assistant", "content": None}], + prompt_token_ids=[1, 2, 3], + ) + response = make_chat_completion(token_ids=[4, 5]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert len(triplets) == 1 + assert list(triplets[0].observation.image_urls) == [] + + +def test_to_tokens_triplets_mixed_content_types(): + """Messages with mixed content types should extract only image URLs.""" + request = make_completion_request( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe these"}, + {"type": "image_url", "image_url": {"url": "http://img1.png"}}, + {"type": "text", "text": "Thanks"}, + ], + } + ], + prompt_token_ids=[1, 2, 3], + ) + response = make_chat_completion(token_ids=[4, 5]) + call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) + source = make_adapting_sequence([call]) + + adapter = ToTokensTriplets() + 
triplets = adapter.adapt(source) + + assert list(triplets[0].observation.image_urls) == ["http://img1.png"] + + +def test_get_reward_from_non_annotated_call(): + """get_reward should return None for non-annotated calls.""" + call = make_chat_completion_call() + adapter = ToTokensTriplets() + + reward = adapter.get_reward(call) + + assert reward is None + + +def test_get_reward_from_annotated_call_without_general_annotation(): + """get_reward should return None if no GeneralAnnotation exists.""" + call = make_annotated_call() # No reward specified + # Manually clear annotations + call = AnnotatedChatCompletionCall( + request=call.request, + response=call.response, + malformed_fields={}, + annotations=[], # Empty annotations + ) + + adapter = ToTokensTriplets() + reward = adapter.get_reward(call) + + assert reward is None + + +def test_to_triplets_sets_done_on_last(): + """to_triplets should set done=True only on the last triplet.""" + call1 = make_chat_completion_call(prompt_token_ids=[1], completion_token_ids=[2]) + call2 = make_chat_completion_call(prompt_token_ids=[3], completion_token_ids=[4]) + call3 = make_chat_completion_call(prompt_token_ids=[5], completion_token_ids=[6]) + source = make_adapting_sequence([call1, call2, call3]) + + adapter = ToTokensTriplets() + triplets = adapter.adapt(source) + + assert triplets[0].done is False + assert triplets[1].done is False + assert triplets[2].done is True From 8c8300cadb03b452a8c69d9f8a8b4d4e312b2dca Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 14 Jan 2026 18:05:23 +0800 Subject: [PATCH 38/41] fix postprocess tests --- agentlightning/adapter/postprocess.py | 8 +++++--- agentlightning/types/adapter.py | 7 +++---- tests/adapter/test_postprocess.py | 8 ++++---- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/agentlightning/adapter/postprocess.py b/agentlightning/adapter/postprocess.py index 799a5b059..c9b416c96 100644 --- a/agentlightning/adapter/postprocess.py +++ b/agentlightning/adapter/postprocess.py @@ -30,9 +30,10 @@ from .base import Adapter +TripletOrAccumulation = Union[TokensTriplet, TokensAccumulation, PromptCompletionTriplet, PromptCompletionAccumulation] T_triplet_or_accumulation = TypeVar( "T_triplet_or_accumulation", - bound=Union[TokensTriplet, TokensAccumulation, PromptCompletionTriplet, PromptCompletionAccumulation], + bound=TripletOrAccumulation, ) T_triplet = TypeVar("T_triplet", bound=Union[TokensTriplet, PromptCompletionTriplet]) @@ -340,7 +341,8 @@ class ToPromptCompletionAccumulations( def _to_messages(self, completion: ChatCompletion) -> List[ChatCompletionAssistantMessageParam]: ChatCompletionAssistantMessageParamType = TypeAdapter(ChatCompletionAssistantMessageParam) # Convert message to dict first since TypeAdapter expects dict for TypedDict validation - message_dict = completion.choices[0].message.model_dump() + # Exclude none to avoid issues like tools=None + message_dict = completion.choices[0].message.model_dump(exclude_none=True) validated_message = ChatCompletionAssistantMessageParamType.validate_python(message_dict) return [to_plain_object(validated_message, [])] @@ -391,7 +393,7 @@ def adapt(self, source: Sequence[PromptCompletionTriplet]) -> Sequence[PromptCom return accumulations -class PropagateRewards(Adapter[Sequence[T_triplet_or_accumulation], Sequence[T_triplet_or_accumulation]]): +class PropagateRewards(Adapter[Sequence[TripletOrAccumulation], Sequence[TripletOrAccumulation]]): """Propagate rewards forward or backward from one triplet or accumulation to the next.""" def 
__init__(self, direction: Literal["forward", "backward"]) -> None: diff --git a/agentlightning/types/adapter.py b/agentlightning/types/adapter.py index 4355117da..71f1d6664 100644 --- a/agentlightning/types/adapter.py +++ b/agentlightning/types/adapter.py @@ -28,7 +28,7 @@ ChatCompletionToolUnionParam, CompletionCreateParams, ) -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Self from agentlightning.semconv import LinkPydanticModel, RewardPydanticModel @@ -231,8 +231,7 @@ class AdaptingSpan(Span): been converted to a different format by an adapter. """ - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) data: Any """The data in the adapted format. Could be annotations, calls, or other structured data.""" @@ -456,7 +455,7 @@ class AnnotatedChatCompletionCall(ChatCompletionCall): T_action = TypeVar("T_action") -class Triplet(Generic[T_observation, T_action], BaseModel): +class Triplet(BaseModel, Generic[T_observation, T_action]): """A triplet of observation, action and reward.""" observation: T_observation diff --git a/tests/adapter/test_postprocess.py b/tests/adapter/test_postprocess.py index c8993bc26..5543dadc9 100644 --- a/tests/adapter/test_postprocess.py +++ b/tests/adapter/test_postprocess.py @@ -3,7 +3,6 @@ """Tests for the postprocess module adapters.""" from typing import Any, Dict, List, Optional, Sequence -from unittest.mock import MagicMock import pytest from openai.types.chat import ChatCompletion @@ -138,7 +137,7 @@ def make_annotated_call( if response is None: response = make_chat_completion(token_ids=completion_token_ids, logprobs=logprobs) - annotations = [] + annotations: List[GeneralAnnotation] = [] if reward is not None: annotations.append(GeneralAnnotation(primary_reward=reward)) @@ -811,7 +810,7 @@ def test_to_prompt_completion_accumulations_no_merge_message_mismatch(): def test_to_prompt_completion_accumulations_preserves_tools(): """Accumulation should preserve tool definitions.""" - tools = [{"type": "function", "function": {"name": "my_tool", "parameters": {}}}] + tools: List[Dict[str, Any]] = [{"type": "function", "function": {"name": "my_tool", "parameters": {}}}] request = make_completion_request(tools=tools) response = make_chat_completion() call = ChatCompletionCall.model_construct(request=request, response=response, malformed_fields={}) @@ -1008,7 +1007,8 @@ def test_propagate_rewards_preserves_existing(): def test_propagate_rewards_empty_input(): """Empty input should return empty output.""" adapter = PropagateRewards(direction="forward") - result = adapter.adapt([]) + triplets: List[TokensTriplet] = [] + result = adapter.adapt(triplets) assert len(result) == 0 From e2d511824af816aff6b1ba370ea02031aff940fa Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 14 Jan 2026 18:26:11 +0800 Subject: [PATCH 39/41] fix call converter tests --- agentlightning/adapter/call.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/agentlightning/adapter/call.py b/agentlightning/adapter/call.py index 156fc7bcb..e5ead5333 100644 --- a/agentlightning/adapter/call.py +++ b/agentlightning/adapter/call.py @@ -124,15 +124,23 @@ def _parse_usages(self, span: AdaptingSpan) -> Dict[str, Any]: def _parse_agentops_tool_calls(self, span: Span) -> Optional[Dict[str, Any]]: if span.name.startswith("tool_call."): - tool_call_data = filter_and_unflatten_attributes(span.attributes, "tool.") 
- if isinstance(tool_call_data, dict) and "call" in tool_call_data: - # Example tool_call_data: - # {'tool.name': 'get_rooms', 'tool.parameters': '{"date": ...}', - # 'tool.call.id': 'call_owd6', 'tool.call.type': 'function'} - tool_call_data = { - **tool_call_data["call"], - "function": {k: v for k, v in tool_call_data.items() if k != "call"}, - } + tool_call_data = filter_and_unflatten_attributes(span.attributes, "tool") + if isinstance(tool_call_data, dict): + if "call" in tool_call_data: + # Example tool_call_data: + # {'tool.name': 'get_rooms', 'tool.parameters': '{"date": ...}', + # 'tool.call.id': 'call_owd6', 'tool.call.type': 'function'} + tool_call_data = { + **tool_call_data["call"], + "function": {k: v for k, v in tool_call_data.items() if k != "call"}, + } + if ( + "function" in tool_call_data + and isinstance(tool_call_data["function"], dict) + and "parameters" in tool_call_data["function"] + ): + tool_call_data["function"]["arguments"] = tool_call_data["function"].pop("parameters") # type: ignore + return cast(Dict[str, Any], tool_call_data) return None From 081aebe689b35765a543a124ed1c514f9271a8a9 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 16 Jan 2026 17:30:56 +0800 Subject: [PATCH 40/41] fix type hints --- agentlightning/adapter/base.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/agentlightning/adapter/base.py b/agentlightning/adapter/base.py index aa946f9ff..0f2e78bfe 100644 --- a/agentlightning/adapter/base.py +++ b/agentlightning/adapter/base.py @@ -5,11 +5,13 @@ from opentelemetry.sdk.trace import ReadableSpan from agentlightning.types import Span -from agentlightning.types.adapter import BaseAdaptingSequence +from agentlightning.types.adapter import BaseAdaptingSequence, BaseAdaptingSequenceItem T_inv = TypeVar("T_inv") T_from = TypeVar("T_from", contravariant=True) T_to = TypeVar("T_to", covariant=True) +T_seq_from = TypeVar("T_seq_from", contravariant=True, bound=BaseAdaptingSequenceItem) +T_seq_to = TypeVar("T_seq_to", covariant=True, bound=BaseAdaptingSequenceItem) class Adapter(Generic[T_from, T_to]): @@ -68,17 +70,20 @@ def adapt(self, source: T_from, /) -> T_to: raise NotImplementedError("Adapter.adapt() is not implemented") -class SequenceAdapter(Adapter[BaseAdaptingSequence[T_from], BaseAdaptingSequence[T_to]], Generic[T_from, T_to]): +class SequenceAdapter( + Adapter[BaseAdaptingSequence[T_seq_from], BaseAdaptingSequence[T_seq_to]], + Generic[T_seq_from, T_seq_to], +): """Base class for adapters that convert adapting sequences of data from one format to another. This class specializes [`Adapter`][agentlightning.Adapter] for working with [`AdaptingSequence`][agentlightning.AdaptingSequence] instances. 
""" - def adapt(self, source: BaseAdaptingSequence[T_from]) -> BaseAdaptingSequence[T_to]: + def adapt(self, source: BaseAdaptingSequence[T_seq_from]) -> BaseAdaptingSequence[T_seq_to]: return source.map(self.adapt_one) - def adapt_one(self, source: T_from) -> T_to: + def adapt_one(self, source: T_seq_from) -> T_seq_to: raise NotImplementedError(f"{self.__class__.__name__}.adapt_one() is not implemented") From 3fc4603492da00e8fac55f3fdf3883cc32ac13d7 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 16 Jan 2026 17:52:15 +0800 Subject: [PATCH 41/41] fix ToSpans type --- agentlightning/adapter/preprocess.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/agentlightning/adapter/preprocess.py b/agentlightning/adapter/preprocess.py index 39bed647b..01da8c796 100644 --- a/agentlightning/adapter/preprocess.py +++ b/agentlightning/adapter/preprocess.py @@ -14,7 +14,7 @@ from agentlightning.types.tracer import Span, SpanLike from agentlightning.utils.id import generate_id -from .base import Adapter, SequenceAdapter +from .base import Adapter T_from = TypeVar("T_from") T_to = TypeVar("T_to") @@ -146,7 +146,7 @@ def from_spans(spans: Sequence[Span], logs_invalid_parent: bool = True) -> _Tree return graph -class ToSpans(SequenceAdapter[SpanLike, Span]): +class ToSpans(Adapter[Sequence[SpanLike], Sequence[Span]]): """Normalize span-like objects (e.g., OpenTelemetry `ReadableSpan`) to [Span][agentlightning.Span]. This adapter handles conversion from various span formats to the internal Span type. @@ -187,6 +187,14 @@ def adapt_one(self, source: SpanLike) -> Span: sequence_id=self.default_sequence_id, ) + def adapt(self, source: Sequence[SpanLike]) -> Sequence[Span]: + """Convert a sequence of span-like objects to Spans. + + Args: + source: A sequence of Span or OpenTelemetry ReadableSpan objects. + """ + return [self.adapt_one(span) for span in source] + class ToTree(Adapter[Sequence[Span], Tree[AdaptingSpan]]): """Convert a sequence of spans into a tree structure.