diff --git a/agentlightning/verl/config.yaml b/agentlightning/verl/config.yaml
index 0d0f68593..f7dc2ff9b 100644
--- a/agentlightning/verl/config.yaml
+++ b/agentlightning/verl/config.yaml
@@ -14,6 +14,14 @@ agentlightning:
     trajectory_max_response_length: 8192 # supported in trajectory level aggregation, suggest to set as maximum length for the cumulative agent responses in the full trajectory, i.e., n_turns * (max_response_length + max_prompt_length)
     debug: False # supported in trajectory level aggregation, enable to diagnose trace merging failures
     mismatch_log_dir: ./mismatch_cases # supported in trajectory level aggregation with debug=True, directory to store logs of mismatch cases
+    # =========================================================================
+    # Tool Call Filtering (Youtu-Agent style)
+    # When enabled, filters out "unexpected tool call" turns where the model
+    # continues generating after a tool call instead of stopping properly.
+    # This helps prevent entropy explosion during RL training.
+    # Reference: contrib/youtu-agent-lightning branch
+    # =========================================================================
+    filter_unexpected_tool_calls: False # set to True to enable filtering
 
 data:
   filter_overlong_prompts: false
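For orientation, a minimal sketch of how the new key is consumed at runtime (illustrative, not part of the patch): the daemon reads it from the trace_aggregator section with False as the default, so older configs without the key keep the current behaviour. The dict literal below is made up for the example.

# Illustrative only: the new flag is looked up with a False default.
trace_aggregator = {
    "trajectory_max_response_length": 8192,
    "debug": False,
    # "filter_unexpected_tool_calls" may be absent in older config files
}
filter_enabled = trace_aggregator.get("filter_unexpected_tool_calls", False)
print(f"tool call filtering enabled: {filter_enabled}")  # -> False
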
diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py
index 98c58f330..daa741729 100644
--- a/agentlightning/verl/daemon.py
+++ b/agentlightning/verl/daemon.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import json
+import logging
 import os
 import random
 import socket
@@ -19,12 +20,27 @@
 from tensordict import TensorDict
 from verl import DataProto
 
+# =============================================================================
+# Tool Call Filtering Support (for filtering unexpected tool call turns)
+# Reference: Youtu-Agent implementation in the contrib/youtu-agent-lightning branch
+# The ToolParser extracts tool calls from response tokens to detect cases where
+# the model continues generating after a tool call (hallucinated tool responses)
+# instead of properly stopping with <|im_end|>.
+# =============================================================================
+try:
+    from verl.experimental.agent_loop.tool_parser import ToolParser
+    TOOL_PARSER_AVAILABLE = True
+except ImportError:
+    TOOL_PARSER_AVAILABLE = False
+
 from agentlightning import LLM, AgentLightningServer, NamedResources, RolloutLegacy
 from agentlightning.adapter.triplet import TracerTraceToTriplet, TraceToTripletBase
 from agentlightning.llm_proxy import LLMProxy, ModelConfig
 from agentlightning.store.base import LightningStore
 from agentlightning.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task
 
+logger = logging.getLogger(__name__)
+
 __all__ = [
     "AgentModeDaemon",
     "get_left_padded_ids_and_attention_mask",
@@ -283,6 +299,11 @@ def __init__(
         self._proxy_thread: Optional[threading.Thread] = None
         self.is_train = True
 
+        # Tool Call Filtering Setup (config key: trace_aggregator["filter_unexpected_tool_calls"])
+        self.tool_parser = None
+        self.toolcall_candidate_token_last2_list = []
+        self._setup_tool_call_filter(train_information, tokenizer)
+
     def _internal_loop_runner(self):
         """Run the internal loop."""
         loop = asyncio.new_event_loop()
@@ -291,6 +312,112 @@ def _internal_loop_runner(self):
         loop.run_forever()
         loop.close()
 
+    # =========================================================================
+    # Tool Call Filtering Methods
+    # Reference: Youtu-Agent implementation (contrib/youtu-agent-lightning)
+    # Purpose: Filter out "unexpected tool call turns" where the model continues
+    # generating text after a tool call instead of stopping properly.
+    # =========================================================================
+
+    def _setup_tool_call_filter(self, train_information: Dict[str, Any], tokenizer: Any) -> None:
+        """Initialize the tool parser and valid ending-token patterns for filtering.
+
+        Uses apply_chat_template to auto-detect the correct tool call ending tokens
+        rather than hardcoding token IDs. Also builds variants with the eos/pad tokens
+        to allow different ending conditions and prevent over-filtering.
+
+        Args:
+            train_information: Training config containing 'format' for the toolcall format.
+            tokenizer: The tokenizer used for encoding/decoding.
+        """
+        if not TOOL_PARSER_AVAILABLE:
+            print("Warning: ToolParser not available, tool call filtering disabled.")
+            self.tool_parser = None
+            return
+
+        toolcall_format = train_information.get("format", "hermes")
+        self.tool_parser = ToolParser.get_tool_parser(toolcall_format, tokenizer)
+
+        # Use the chat template to detect the actual tool call ending token sequence.
+        # The example uses a calculator tool to match the calc-x example for consistency.
+        tools_examples = [{
+            "type": "function",
+            "name": "calculate",
+            "description": "Evaluate a mathematical expression",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "expression": {"type": "string", "description": "Math expression, e.g., '2 + 3 * 4'"},
+                },
+                "required": ["expression"],
+            },
+        }]
+        toolcall_message_examples = [
+            {"role": "user", "content": "What is 15 + 27?"},
+            {"role": "assistant", "content": "", "tool_calls": [{
+                "id": "call_001",
+                "type": "function",
+                "function": {"name": "calculate", "arguments": '{"expression":"15 + 27"}'},
+            }]},
+        ]
+        toolcall_example_chat_template = tokenizer.apply_chat_template(
+            toolcall_message_examples, tools=tools_examples,
+            add_generation_prompt=False, tokenize=False,
+        )
+        # Extract the last 2 tokens from the chat template output (e.g., <|im_end|>)
+        toolcall_example_token_last2 = tokenizer.encode(toolcall_example_chat_template.strip())[-2:]
+
+        eos_token_id = tokenizer.eos_token_id
+        pad_token_id = tokenizer.pad_token_id
+
+        # Build the candidate list: the detected ending plus variants with eos/pad.
+        # This allows various tool-call ending conditions and prevents over-filtering.
+        toolcall_candidate_token_last2_list = [toolcall_example_token_last2]
+        if toolcall_example_token_last2[-1] != eos_token_id:
+            toolcall_candidate_token_last2_list.append([toolcall_example_token_last2[0], eos_token_id])
+        if toolcall_example_token_last2[-1] != pad_token_id:
+            toolcall_candidate_token_last2_list.append([toolcall_example_token_last2[0], pad_token_id])
+
+        self.toolcall_candidate_token_last2_list = toolcall_candidate_token_last2_list
+        logger.info(
+            f"Tool call filter initialized: {eos_token_id=}, {pad_token_id=}, "
+            f"candidates={self.toolcall_candidate_token_last2_list}"
+        )
+
+
+    def _is_valid_tool_call_response(self, response_ids: List[int]) -> Tuple[bool, bool]:
+        """Check whether a response that contains tool calls ends with valid ending tokens.
+
+        Uses a strict last-2-token check (same as the youtu branch): the response must
+        end with one of the candidate token pairs (e.g., ending in <|im_end|> or
+        <|endoftext|>).
+
+        Args:
+            response_ids: List of token IDs from the model's response
+
+        Returns:
+            Tuple of (has_tool_calls, has_valid_ending):
+                - has_tool_calls: True if the response contains tool calls
+                - has_valid_ending: True if no tool calls, or tool calls with a proper ending
+        """
+        if self.tool_parser is None:
+            return False, True
+
+        _, tool_calls = asyncio.run(self.tool_parser.extract_tool_calls(response_ids))
+
+        if not tool_calls:
+            return False, True
+
+        if len(response_ids) < 2:
+            return True, False
+
+        # Strict last-2 check against all valid ending candidates
+        for candidate in self.toolcall_candidate_token_last2_list:
+            if response_ids[-2] == candidate[0] and response_ids[-1] == candidate[1]:
+                return True, True
+
+        return True, False
+
     # Multimodal utilities for M-RoPE position embeddings
     def _is_mrope_model(self) -> bool:
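As a sanity check for the two methods above, a standalone sketch of the same auto-detection and last-2-token test (illustrative, not part of the patch). It assumes a Qwen2.5-style tokenizer with tool support in its chat template; the model name, the example messages, and the "kept generating" continuation are placeholders, and the ToolParser step is deliberately skipped.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

tools = [{
    "type": "function",
    "name": "calculate",
    "description": "Evaluate a mathematical expression",
    "parameters": {
        "type": "object",
        "properties": {"expression": {"type": "string"}},
        "required": ["expression"],
    },
}]
messages = [
    {"role": "user", "content": "What is 15 + 27?"},
    {"role": "assistant", "content": "", "tool_calls": [{
        "id": "call_001",
        "type": "function",
        "function": {"name": "calculate", "arguments": '{"expression":"15 + 27"}'},
    }]},
]

# Render a complete tool-call turn and take its last two token ids: this is how a
# well-formed tool call should end (e.g., "</tool_call>" followed by "<|im_end|>").
rendered = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=False, tokenize=False)
detected_ending = tokenizer.encode(rendered.strip())[-2:]

# Accept the detected ending plus eos/pad variants so other stop conditions are not over-filtered.
candidates = [detected_ending]
for alt in (tokenizer.eos_token_id, tokenizer.pad_token_id):
    if alt is not None and detected_ending[-1] != alt:
        candidates.append([detected_ending[0], alt])

def ends_properly(response_ids):
    # True if the response ends with one of the accepted last-2-token pairs.
    return len(response_ids) >= 2 and any(response_ids[-2:] == c for c in candidates)

good = tokenizer.encode(rendered.strip())            # ends exactly like the template
bad = good + tokenizer.encode("\nTool output: 42")   # model kept generating after the call
print(ends_properly(good), ends_properly(bad))       # expected: True False
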
@@ -821,6 +948,12 @@ def get_train_data_batch(
         finished_id_to_sample_info: Dict[str, Dict[str, Any]] = {}
         finished_id_to_final_reward: Dict[str, float] = {}
         sample_with_reward_count = 0
+
+        # Tool call filtering metrics
+        n_total_turns_before_filter = 0
+        n_unexpected_tool_calls = 0
+        n_skipped_rollouts_by_filter = 0
+
         for rollout_id, rollout in self._completed_rollouts_v0.items():
             original_sample = self._task_id_to_original_sample[rollout_id]
             sample_with_reward_count += int(rollout.final_reward is not None)
@@ -842,6 +975,41 @@
                 }
                 for t in rollout.triplets
             ]
+
+            # Filter void/unexpected tool call turns (Youtu-Agent style)
+            # When config is OFF: only count for metrics, no filtering
+            # When config is ON: apply both void and unexpected tool call filtering
+            if self.tool_parser is not None:
+                n_total_turns_before_filter += len(trace_list)
+
+                # Count unexpected tool calls (always, for metrics)
+                for trace in trace_list:
+                    if len(trace["prompt_ids"]) and len(trace["response_ids"]):
+                        has_tool_calls, has_valid_ending = self._is_valid_tool_call_response(trace["response_ids"])
+                        if has_tool_calls and not has_valid_ending:
+                            n_unexpected_tool_calls += 1
+
+                # Apply filtering only when config is enabled
+                if self.trace_aggregator.get("filter_unexpected_tool_calls", False):
+                    # 1. Filter void turns (empty prompt or response)
+                    trace_list = [
+                        t for t in trace_list
+                        if len(t["prompt_ids"]) and len(t["response_ids"])
+                    ]
+                    # 2. Filter unexpected tool call turns
+                    trace_list_filtered = []
+                    for trace in trace_list:
+                        has_tool_calls, has_valid_ending = self._is_valid_tool_call_response(trace["response_ids"])
+                        if has_tool_calls and not has_valid_ending:
+                            continue  # Skip invalid turns
+                        trace_list_filtered.append(trace)
+                    # 3. Skip rollout if only 1 or fewer valid turns remain
+                    if len(trace_list_filtered) <= 1:
+                        n_skipped_rollouts_by_filter += 1
+                        finished_id_to_final_reward[rollout_id] = final_reward
+                        continue
+                    trace_list = trace_list_filtered
+
             info = {
                 "reward": final_reward,
                 "trace_list": trace_list,
@@ -1123,6 +1291,14 @@
                 and self.trace_aggregator.get("debug", False)
                 else {}
             ),
+            "training/n_unexpected_tool_calls": n_unexpected_tool_calls,
+            "training/n_total_turns_before_filter": n_total_turns_before_filter,
+            "training/unexpected_tool_call_ratio": (
+                n_unexpected_tool_calls / n_total_turns_before_filter if n_total_turns_before_filter > 0 else 0.0
+            ),
+            "training/n_skipped_rollouts_by_filter": n_skipped_rollouts_by_filter,
+            "training/filter_enabled": float(self.trace_aggregator.get("filter_unexpected_tool_calls", False)),
+            "training/reward_std": np.std(list(finished_id_to_final_reward.values())),
         }
 
         # Add non-tensor data for advantage calculation and logging
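To make the filtering rules above easier to follow in isolation, a toy, self-contained sketch (not part of the patch): the trace dicts and the is_valid stub are invented, and only the three rules mirror the change — drop void turns, drop turns whose tool call does not end properly, and skip the rollout when at most one turn survives.

from typing import Callable, Dict, List, Optional, Tuple

def filter_trace_list(
    trace_list: List[Dict[str, List[int]]],
    is_valid: Callable[[List[int]], Tuple[bool, bool]],
) -> Optional[List[Dict[str, List[int]]]]:
    # 1. Drop void turns (empty prompt or response).
    trace_list = [t for t in trace_list if t["prompt_ids"] and t["response_ids"]]
    # 2. Drop turns that contain tool calls but do not end properly.
    kept = []
    for t in trace_list:
        has_tool_calls, has_valid_ending = is_valid(t["response_ids"])
        if has_tool_calls and not has_valid_ending:
            continue
        kept.append(t)
    # 3. A rollout with one or zero surviving turns is skipped entirely (None here).
    return kept if len(kept) > 1 else None

def is_valid(response_ids: List[int]) -> Tuple[bool, bool]:
    # Stub: pretend token 0 marks a tool call, and the turn is well formed only
    # if that token is also the last one in the response.
    if 0 not in response_ids:
        return False, True
    return True, response_ids[-1] == 0

rollout = [
    {"prompt_ids": [1, 2], "response_ids": [3, 0]},  # tool call with a proper ending -> kept
    {"prompt_ids": [1, 2], "response_ids": []},      # void turn -> dropped
    {"prompt_ids": [1, 2], "response_ids": [0, 5]},  # kept generating after the call -> dropped
    {"prompt_ids": [1, 2], "response_ids": [6, 7]},  # no tool call -> kept
]
print(filter_trace_list(rollout, is_valid))  # two turns survive, so the rollout is kept
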
diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py
index 94b3276f8..62e762159 100644
--- a/agentlightning/verl/trainer.py
+++ b/agentlightning/verl/trainer.py
@@ -421,6 +421,7 @@ def _train_step(self, batch_dict: dict) -> dict:
                 self._dump_generations(
                     inputs=inputs,
                     outputs=outputs,
+                    gts=[""] * len(inputs),
                     scores=scores,
                     reward_extra_infos_dict=reward_extra_infos_dict,
                     dump_path=rollout_data_dir,
diff --git a/examples/calc_x/train_calc_agent.py b/examples/calc_x/train_calc_agent.py
index 5d73edeac..61048bad5 100644
--- a/examples/calc_x/train_calc_agent.py
+++ b/examples/calc_x/train_calc_agent.py
@@ -40,6 +40,12 @@
 import agentlightning as agl
 from agentlightning.env_var import LightningEnvVar, resolve_bool_env_var, resolve_str_env_var
 
+# Ensure venv bin is in PATH (needed for uvx/mcp-server-calculator in Ray workers)
+_script_dir = os.path.dirname(os.path.abspath(__file__))
+_venv_bin = os.path.join(_script_dir, "..", "..", ".venv", "bin")
+if os.path.isdir(_venv_bin):
+    os.environ["PATH"] = os.path.abspath(_venv_bin) + ":" + os.environ.get("PATH", "")
+
 
 def verl_default_config() -> Dict[str, Any]:
     config = {
@@ -123,6 +129,11 @@
     trajectory_level: bool = False,
     weave: bool,
     mongo_uri: Optional[str],
+    filter_unexpected_tool_calls: bool = False,
+    experiment_name: Optional[str] = None,
+    n_gpus: int = 1,
+    checkpoint_dir: str = "/home/jovyan/msra/experiments/checkpoints",
+    resume: bool = False,
 ):
     """The training entrypoint function for Calc-X agent with VERL algorithm.
 
@@ -141,6 +152,7 @@
         trajectory_level: Whether to enable trajectory level in trace aggregator.
         weave: Whether to enable Weave tracing.
         mongo_uri: MongoDB URI to use for the store.
+        experiment_name: Custom experiment name for W&B logging.
     """
     # Load datasets (respect CLI file paths)
     train_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet(train_file).to_list())  # type: ignore
@@ -156,6 +168,26 @@
     if model:
         config["actor_rollout_ref"]["model"]["path"] = model
 
+    # Override experiment name if provided (for W&B logging)
+    if experiment_name:
+        config["trainer"]["experiment_name"] = experiment_name
+        print(f"Using custom experiment name: {experiment_name}")
+
+    # Override n_gpus_per_node for multi-GPU training
+    if n_gpus > 1:
+        config["trainer"]["n_gpus_per_node"] = n_gpus
+        print(f"Multi-GPU training enabled: n_gpus_per_node={n_gpus}")
+
+    # Set checkpoint directory and conversation dump directory
+    config["trainer"]["default_local_dir"] = checkpoint_dir
+    config["trainer"]["resume_mode"] = "auto" if resume else "disable"
+    conversations_dir = checkpoint_dir.replace("checkpoints", "conversations")
+    config["trainer"]["rollout_data_dir"] = conversations_dir
+    os.makedirs(conversations_dir, exist_ok=True)
+    print(f"Checkpoint directory: {checkpoint_dir}")
+    print(f"Conversations directory: {conversations_dir}")
+    print(f"Resume mode: {config['trainer']['resume_mode']}")
+
     # Enable LoRA configuration if requested
     if lora:
         config["actor_rollout_ref"]["model"]["lora_rank"] = lora_rank
@@ -175,6 +207,19 @@
         }
         print("Trajectory level enabled in trace aggregator.")
 
+    # =========================================================================
+    # Tool Call Filtering (Youtu-Agent style)
+    # Filters out turns where the model generates unexpected content after
+    # a tool call (hallucinated tool responses). Helps prevent entropy explosion.
+    # =========================================================================
+    if filter_unexpected_tool_calls:
+        if "agentlightning" not in config:
+            config["agentlightning"] = {"trace_aggregator": {}}
+        if "trace_aggregator" not in config["agentlightning"]:
+            config["agentlightning"]["trace_aggregator"] = {}
+        config["agentlightning"]["trace_aggregator"]["filter_unexpected_tool_calls"] = True
+        print("Tool call filtering enabled (Youtu-Agent style).")
+
     # CI toggle keeps everything else the same but you can tweak the lightweight bits here if desired
     if ci or ci_fast:
         # Config the experiment name and project name so that they are available to CI
@@ -290,6 +335,35 @@
         default=None,
         help="MongoDB URI to use for the store.",
     )
+    parser.add_argument(
+        "--filter-unexpected-tool-calls",
+        action="store_true",
+        help="Enable Youtu-Agent style tool call filtering. "
+        "Filters out turns where the model generates unexpected content after a tool call.",
+    )
+    parser.add_argument(
+        "--experiment-name",
+        type=str,
+        default=None,
+        help="Custom experiment name for W&B logging (default: calc_x or auto-generated for CI)",
+    )
+    parser.add_argument(
+        "--n-gpus",
+        type=int,
+        default=1,
+        help="Number of GPUs per node for distributed training (default: 1)",
+    )
+    parser.add_argument(
+        "--checkpoint-dir",
+        type=str,
+        default="/home/jovyan/msra/experiments/checkpoints",
+        help="Directory to save checkpoints (default: /home/jovyan/msra/experiments/checkpoints)",
+    )
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help="Resume training from the latest checkpoint in checkpoint-dir",
+    )
 
     args = parser.parse_args()
 
@@ -321,6 +395,11 @@
         trajectory_level=args.trajectory_level,
         weave=args.weave,
         mongo_uri=args.mongo_uri,
+        filter_unexpected_tool_calls=args.filter_unexpected_tool_calls,
+        experiment_name=args.experiment_name,
+        n_gpus=args.n_gpus,
+        checkpoint_dir=args.checkpoint_dir,
+        resume=args.resume,
     )