diff --git a/.github/workflows/examples-chartqa.yml b/.github/workflows/examples-chartqa.yml index 98be679b8..9c330655d 100644 --- a/.github/workflows/examples-chartqa.yml +++ b/.github/workflows/examples-chartqa.yml @@ -99,7 +99,7 @@ jobs: set -euo pipefail source .venv/bin/activate cd examples/chartqa - uv run --no-sync vllm serve Qwen/Qwen2-VL-2B-Instruct \ + uv run --no-sync vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ --gpu-memory-utilization 0.9 \ --max-model-len 4096 \ --allowed-local-media-path "$(pwd)/data" \ @@ -131,7 +131,7 @@ jobs: env: USE_LLM_PROXY: "1" OPENAI_API_BASE: http://localhost:8088/v1 - OPENAI_MODEL: Qwen/Qwen2-VL-2B-Instruct + OPENAI_MODEL: Qwen/Qwen2.5-VL-3B-Instruct - name: Stop vLLM Server run: | diff --git a/examples/chartqa/README.md b/examples/chartqa/README.md index 5bd905d1b..d49bae945 100644 --- a/examples/chartqa/README.md +++ b/examples/chartqa/README.md @@ -2,11 +2,14 @@ [![chartqa workflow status](https://github.com/microsoft/agent-lightning/actions/workflows/badge-chartqa.yml/badge.svg)](https://github.com/microsoft/agent-lightning/actions/workflows/examples-chartqa.yml) -This example demonstrates training a visual reasoning agent on the ChartQA dataset using Agent-Lightning with the VERL algorithm and LangGraph framework. The agent answers questions about charts through a multi-step workflow with self-refinement. +This example demonstrates training a visual reasoning agent on the ChartQA dataset using Agent-Lightning with the VERL algorithm. The agent uses a two-step pipeline to answer questions about charts: + +1. **Extract** [model + image]: Read all relevant values from the chart +2. **Compute** [model text-only]: Calculate the final answer from extracted data ## Requirements -This example requires a single node with at least one 40GB GPU. Install dependencies with: +This example requires a single node with at least 2x 40GB GPUs. 
Install dependencies with:
 
 ```bash
 uv sync --frozen \
@@ -41,11 +44,11 @@ This downloads the ChartQA dataset from HuggingFace (`HuggingFaceM4/ChartQA`), s
 
 | File/Directory | Description |
 |----------------|-------------|
-| `chartqa_agent.py` | Chart reasoning agent using LangGraph with multi-step workflow (observe → extract → calculate → check → refine) |
+| `chartqa_agent.py` | Two-step chart reasoning agent (extract → compute) |
 | `train_chartqa_agent.py` | Training script using VERL algorithm with configurable hyperparameters (debug, qwen) |
-| `debug_chartqa_agent.py` | Debugging script to test the agent with cloud APIs or a local vLLM proxy |
+| `debug_chartqa_agent.py` | Debugging script to test the agent with cloud APIs or a local vLLM server |
 | `prepare_data.py` | Script to download ChartQA dataset from HuggingFace and prepare parquet files |
-| `prompts.py` | Prompt templates for the agent workflow |
+| `prompts.py` | Prompt templates for extract and compute steps |
 | `multimodal_utils.py` | Utility functions for encoding images to base64 |
 | `env_var.py` | Environment variables and configurations |
 | `data/` | Directory containing images and parquet files after download |
@@ -69,43 +72,44 @@ export OPENAI_MODEL=gpt-4o
 python debug_chartqa_agent.py
 ```
 
-### Debugging with Local Model (LLMProxy)
+### Debugging with Local Model
 
-To test the agent with a local vLLM server and LLMProxy:
+To test the agent with a local vLLM server:
 
 ```bash
-# Start a vLLM server (specify image path for VLM)
+# Start a vLLM server (specify image path for vision model)
 export CHARTQA_DATA_DIR=
-vllm serve Qwen/Qwen2-VL-2B-Instruct \
+vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
   --gpu-memory-utilization 0.6 \
   --max-model-len 4096 \
   --allowed-local-media-path $CHARTQA_DATA_DIR \
   --enable-prefix-caching \
   --port 8088
 
-# Run the agent with LLMProxy
-USE_LLM_PROXY=1 \
-  OPENAI_API_BASE=http://localhost:8088/v1 \
-  OPENAI_MODEL=Qwen/Qwen2-VL-2B-Instruct \
+# Run the agent
+OPENAI_API_BASE=http://localhost:8088/v1 \
+  OPENAI_MODEL=Qwen/Qwen2.5-VL-3B-Instruct \
   python debug_chartqa_agent.py
 ```
 
 ### Training with Local Model
 
-```bash
-python train_chartqa_agent.py debug --n-runners 2
-```
+You can use an external store server (recommended for distributed setups). Start the store first:
 
-You can also use an external store server (recommended for distributed setups), first start the store:
 ```bash
 agl store --port 4747
 ```
+Then run the debug configuration (10 training steps) against the external store address:
+
+```bash
+AGL_MANAGED_STORE=0 python train_chartqa_agent.py debug --n-runners 32 --external-store-address http://localhost:4747
+```
 
-Then run the training script with the external store address:
+Run full training with:
 
 ```bash
-AGL_MANAGED_STORE=0 python train_chartqa_agent.py qwen --external-store-address http://localhost:4747
+AGL_MANAGED_STORE=0 python train_chartqa_agent.py qwen --n-runners 32 --external-store-address http://localhost:4747
 ```
 
 If you want to track experiments with Weights & Biases, set the `WANDB_API_KEY` environment variable before training.
diff --git a/examples/chartqa/chartqa_agent.py b/examples/chartqa/chartqa_agent.py
index 3b14381f5..4955c7c58 100644
--- a/examples/chartqa/chartqa_agent.py
+++ b/examples/chartqa/chartqa_agent.py
@@ -1,17 +1,15 @@
 # Copyright (c) Microsoft. All rights reserved.
 
-"""ChartQA agent demonstrating LangGraph-based visual reasoning with refinement.
+"""ChartQA agent with two-step pipeline.
-This module defines `ChartQAAgent` plus the supporting prompt utilities used by -`debug_chartqa_agent.py` and `train_chartqa_agent.py`. +This module implements a pipeline with minimal sequential steps +to reduce error propagation (exposure bias). -1. `analyze_chart` observes and summarizes the chart. -2. `extract_data` calls a text-only LLM to extract the requested values. -3. `calculate_answer` runs calculations grounded in prior steps. -4. `check_answer` verifies reasoning quality. -5. `refine_answer` conditionally patches mistakes before responding. +Two-step mode: +`extract_data` [model + image]: Read all relevant values from chart +`compute_answer` [model text-only]: Calculate final answer from extracted data -Example usage can be found in `debug_chartqa_agent.py` and `train_chartqa_agent.py`. +Model endpoint and parameters are configured automatically from resources during training. """ from __future__ import annotations @@ -19,21 +17,18 @@ import logging import os import re -from typing import Any, Dict, Literal, cast +from typing import Any, Dict, cast import env_var as chartqa_env_var import termcolor from langchain.chat_models import BaseChatModel, init_chat_model -from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage +from langchain_core.messages import AnyMessage, HumanMessage from langgraph.graph import END, START, MessagesState, StateGraph from langgraph.graph.state import CompiledStateGraph from multimodal_utils import encode_image_to_base64 from prompts import ( - ANALYZE_CHART_PROMPT, - CALCULATE_ANSWER_PROMPT, - CHECK_ANSWER_PROMPT, + COMPUTE_ANSWER_PROMPT, EXTRACT_DATA_PROMPT, - REFINE_ANSWER_PROMPT, ) import agentlightning as agl @@ -41,108 +36,140 @@ logger = logging.getLogger("chartqa_agent") -class ChartState(MessagesState): +def evaluate_answer(predicted: str, ground_truth: str, raise_on_error: bool = False) -> float: + """Evaluate answer accuracy. + + Returns: + 1.0 for exact match or numeric match within 2% + 0.5 for substring match + 0.0 otherwise + """ + try: + pred = predicted.lower().strip() + gt = ground_truth.lower().strip() + + # Exact match + if pred == gt: + return 1.0 + + # Try numeric comparison + try: + pred_num = float(pred.replace(",", "")) + gt_num = float(gt.replace(",", "")) + if abs(pred_num - gt_num) / max(abs(gt_num), 1e-9) < 0.02: + return 1.0 + except (ValueError, AttributeError): + pass + + # Partial credit for substring match + if pred in gt or gt in pred: + return 0.5 + + return 0.0 + except Exception as e: + if raise_on_error: + raise + logger.exception(f"Error evaluating answer: {e}") + return 0.0 + + +class ChartQAState(MessagesState): + """State for the ChartQA agent.""" + question: str image_path: str - observation: str + # Extraction step extracted_data: str + # Computation step calculation: str answer: str - feedback: str + # Tracking num_turns: int + num_model_calls: int messages: list[AnyMessage] class ChartQAAgent(agl.LitAgent[Dict[str, Any]]): - """LangGraph-powered ChartQA agent with multi-step reasoning and refinement. + """ChartQA agent with reduced exposure bias. - The implementation shares the same [`agl.LitAgent`][agentlightning.LitAgent] interface as - the Calc-X sample agent but augments it with image handling and LangGraph state tracking. + Uses a simplified 2-step pipeline to minimize error propagation. 
""" def __init__( self, - model_name: str | None = None, - max_turns: int = 3, + model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct", + endpoint: str = "http://localhost:8090/v1", + max_turns: int = 1, debug: bool = False, - endpoint: str | None = None, temperature: float = 0.0, use_base64_images: bool = False, ): - self.debug = debug - self.max_turns = max_turns - self.use_base64_images = use_base64_images + """Initialize the ChartQA agent. + + Args: + model_name: Model name for the vision-language model. + endpoint: API endpoint for the model. + max_turns: Max tool call iterations (default 1). + debug: Enable debug output. + temperature: Sampling temperature. + use_base64_images: Whether to encode images as base64. + """ self.model_name = model_name self.endpoint = endpoint + self.max_turns = max_turns + self.debug = debug self.temperature = temperature + self.use_base64_images = use_base64_images - self._llm: BaseChatModel | None = None - self._graph: CompiledStateGraph[ChartState] | None = None + self._model: BaseChatModel | None = None + self._graph: CompiledStateGraph[ChartQAState] | None = None - def _create_llm(self) -> BaseChatModel: - if self.model_name is None: - raise ValueError("model_name is required for creating LLM") + def _create_model(self) -> BaseChatModel: + """Create the model instance.""" return init_chat_model( self.model_name, model_provider="openai", openai_api_base=self.endpoint, openai_api_key=chartqa_env_var.OPENAI_API_KEY, temperature=self.temperature, - max_retries=2, - max_tokens=1024, - timeout=300, + max_retries=5, + max_tokens=2048, + timeout=1200, ) - def update_llm_config(self, model_name: str, endpoint: str | None, temperature: float | None) -> None: - """Update the LLM configuration. Re-create the LLM if the configuration is changed.""" - updated: bool = False - if model_name != self.model_name: + def _ensure_model(self) -> BaseChatModel: + """Ensure the model is created and cached.""" + if self._model is None: + self._model = self._create_model() + return self._model + + def update_llm_config( + self, + model_name: str | None = None, + endpoint: str | None = None, + temperature: float | None = None, + ) -> None: + """Update model configurations.""" + updated = False + + if model_name is not None and model_name != self.model_name: self.model_name = model_name updated = True - if endpoint != self.endpoint: + if endpoint is not None and endpoint != self.endpoint: self.endpoint = endpoint updated = True - if temperature != self.temperature: + if temperature is not None and temperature != self.temperature: self.temperature = temperature updated = True - if updated: - self._llm = self._create_llm() - - def _ensure_llm(self) -> BaseChatModel: - """Ensure the LLM is created and cached.""" - if self._llm is None: - self._llm = self._create_llm() - return self._llm - - def invoke_prompt(self, prompt: Any) -> AnyMessage: - """Invoke LLM with prompt.""" - if self.debug: - for message in prompt.messages: - termcolor.cprint(message.pretty_repr(), "blue") - - try: - result = self._ensure_llm().invoke(prompt) - except Exception as e: - logger.error(f"Failed to invoke prompt: {e}") - result = self._ensure_llm().invoke([HumanMessage(content="Please provide a reasonable answer.")]) - - if self.debug: - termcolor.cprint(result.pretty_repr(), "green") - - return result # type: ignore - def invoke_prompt_with_image(self, prompt_text: str, image_path: str) -> str: - """Invoke vision-language model with image. 
+        if updated:
+            self._model = self._create_model()
 
-        Handles both local vLLM (file:// URLs) and cloud APIs (base64 encoding).
-        Cloud APIs (OpenAI, Anthropic, Google, Azure, etc.) require base64 encoding.
-        """
-        # Determine image URL format based on endpoint
+    def invoke_with_image(self, prompt_text: str, image_path: str) -> str:
+        """Invoke the model with an image."""
         if self.use_base64_images:
-            # Cloud APIs require base64 encoding for local files
            image_url = encode_image_to_base64(image_path)
         else:
-            # Local vLLM supports file:// URLs
             if not image_path.startswith("file://"):
                 image_path = f"file://{os.path.realpath(image_path)}"
             image_url = image_path
@@ -158,193 +185,190 @@ def invoke_prompt_with_image(self, prompt_text: str, image_path: str) -> str:
         ]
 
         if self.debug:
-            termcolor.cprint(f"[VLM Call] {prompt_text[:100]}...", "blue")
+            termcolor.cprint(f"[Model Input] {prompt_text[:200]}...", "cyan")
 
         try:
-            result = self._ensure_llm().invoke(messages)
-            response = result.content if hasattr(result, "content") else str(result)  # type: ignore
+            result = self._ensure_model().invoke(messages)
+            response = result.content if hasattr(result, "content") else str(result)
         except Exception as e:
-            logger.error(f"Failed to invoke VLM: {e}")
-            response = "Unable to analyze chart"
+            logger.error(f"Failed to invoke model: {e}")
+            response = "error"
 
         if self.debug:
-            termcolor.cprint(f"[VLM Response] {response[:200]}...", "green")
+            termcolor.cprint(f"[Model Output] {response[:300]}...", "magenta")
 
         return response  # type: ignore
 
-    def extract_content(self, text: str, tag: str) -> str:
-        """Extract content between XML-style tags."""
-        match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
-        return match.group(1).strip() if match else ""
+    def invoke_text_only(self, prompt_text: str) -> str:
+        """Invoke the model with text only."""
+        messages = [{"role": "user", "content": prompt_text}]
 
-    def analyze_chart(self, state: ChartState) -> ChartState:
-        """Step 1: Observe and describe the chart."""
-        prompt: Any = ANALYZE_CHART_PROMPT.invoke({"question": state["question"]})  # type: ignore
-        prompt_text = prompt.messages[1].content
+        if self.debug:
+            termcolor.cprint(f"[Model Text Input] {prompt_text[:200]}...", "blue")
 
-        result_text = self.invoke_prompt_with_image(prompt_text, state["image_path"])
+        try:
+            result = self._ensure_model().invoke(messages)
+            response = result.content if hasattr(result, "content") else str(result)
+        except Exception as e:
+            logger.error(f"Failed to invoke model: {e}")
+            response = "error"
 
-        observation = self.extract_content(result_text, "observe")
-        if not observation:
-            observation = result_text
+        if self.debug:
+            termcolor.cprint(f"[Model Text Output] {response[:300]}...", "green")
 
-        return {  # type: ignore
-            **state,
-            "observation": observation,
-            "num_turns": 1,
-            "messages": [HumanMessage(content=result_text)],
-        }
+        return response  # type: ignore
 
-    def extract_data(self, state: ChartState) -> ChartState:
-        """Step 2: Extract specific data values."""
-        prompt: Any = EXTRACT_DATA_PROMPT.invoke(  # type: ignore
-            {
-                "observation": state["observation"],
-                "question": state["question"],
-            }
-        )
-        result = self.invoke_prompt(prompt)
+    def extract_content(self, text: str, tag: str) -> str:
+        """Extract content between XML-style tags."""
+        matches = re.findall(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
+        if not matches:
+            return ""
+        return "\n".join(match.strip() for match in matches)
+
+    def normalize_answer(self, answer: str) -> str:
+        """Normalize answer to match evaluation format.
+ + Cleans up common issues: + - Remove units (%, $, million, etc.) + - Remove extra punctuation + - Clean whitespace + - Format numbers consistently + """ + if not answer: + return "" - extracted_data = self.extract_content(result.content, "extract") # type: ignore - if not extracted_data: - extracted_data = result.content # type: ignore + ans = answer.strip() - return { # type: ignore - **state, - "extracted_data": extracted_data, # type: ignore - "messages": [*state.get("messages", []), result], - } + # Remove common prefixes/suffixes + prefixes = ["the answer is", "answer:", "approximately", "about", "around"] + for prefix in prefixes: + if ans.lower().startswith(prefix): + ans = ans[len(prefix) :].strip() - def calculate_answer(self, state: ChartState) -> ChartState: - """Step 3: Calculate and provide answer.""" - prompt: Any = CALCULATE_ANSWER_PROMPT.invoke( # type: ignore - { - "extracted_data": state["extracted_data"], - "question": state["question"], - } + # Remove trailing punctuation + ans = ans.rstrip(".,;:") + + # Remove units and symbols (but keep the number) + # Pattern: number followed by unit + unit_pattern = ( + r"^([-+]?\d*\.?\d+)\s*(%|percent|million|billion|thousand|USD|\$|€|£|dollars?|years?|km|miles?|kg|lbs?)\.?$" ) - result = self.invoke_prompt(prompt) + match = re.match(unit_pattern, ans, re.IGNORECASE) + if match: + ans = match.group(1) - calculation = self.extract_content(result.content, "calculate") # type: ignore - answer = self.extract_content(result.content, "answer") # type: ignore - if not answer: - answer = cast(str, result.content) # type: ignore + # Remove leading $ or currency + if ans.startswith("$"): + ans = ans[1:].strip() - return { # type: ignore - **state, - "calculation": calculation, - "answer": answer, - "messages": [*state.get("messages", []), result], - } + # Remove trailing % if present + if ans.endswith("%"): + ans = ans[:-1].strip() - def check_answer(self, state: ChartState) -> ChartState: - """Step 4: Verify answer quality.""" - prompt: Any = CHECK_ANSWER_PROMPT.invoke( # type: ignore - { - "observation": state["observation"], - "extracted_data": state["extracted_data"], - "question": state["question"], - "answer": state["answer"], - "calculation": state.get("calculation", "No calculation shown"), - } - ) - result = self.invoke_prompt(prompt) + # Clean up number formatting + # Remove commas from numbers + if re.match(r"^[-+]?\d{1,3}(,\d{3})*(\.\d+)?$", ans): + ans = ans.replace(",", "") + + # Try to parse as number and format consistently + try: + num = float(ans) + # If it's a whole number, format without decimals + if num == int(num): + ans = str(int(num)) + else: + # Round to 2 decimal places, remove trailing zeros + ans = f"{num:.2f}".rstrip("0").rstrip(".") + except ValueError: + # Not a number, keep as text + # Clean up quotes if present + ans = ans.strip("'\"") + + return ans + + # ========================================================================= + # Two-Step Pipeline Nodes + # ========================================================================= + + def extract_data(self, state: ChartQAState) -> ChartQAState: + """Step 1 [model + image]: Extract all relevant data from chart.""" + prompt: Any = EXTRACT_DATA_PROMPT.invoke({"question": state["question"]}) + prompt_text = "\n".join(msg.content for msg in prompt.messages) # type: ignore + + result_text = self.invoke_with_image(prompt_text, state["image_path"]) + + # Extract data JSON + extracted_data = self.extract_content(result_text, "data") + if not extracted_data: + 
extracted_data = result_text if self.debug: - termcolor.cprint(f"[Check] {result.content}", "yellow") # type: ignore + termcolor.cprint(f"[Extract] Data: {extracted_data[:300]}", "yellow") return { # type: ignore **state, - "feedback": result.content, # type: ignore - "messages": [*state.get("messages", []), *prompt.messages, result], + "extracted_data": extracted_data, + "num_model_calls": state.get("num_model_calls", 0) + 1, + "messages": [HumanMessage(content=result_text)], } - def refine_answer(self, state: ChartState) -> ChartState: - """Step 5: Refine answer based on feedback.""" - prompt: Any = REFINE_ANSWER_PROMPT.invoke( # type: ignore - { - "observation": state["observation"], - "extracted_data": state["extracted_data"], - "question": state["question"], - "answer": state["answer"], - "calculation": state.get("calculation", ""), - "feedback": state["feedback"], - } + def compute_answer(self, state: ChartQAState) -> ChartQAState: + """Step 2 [model text-only]: Compute answer from extracted data.""" + prompt: Any = COMPUTE_ANSWER_PROMPT.invoke( + {"question": state["question"], "extracted_data": state["extracted_data"]} ) - result = self.invoke_prompt(prompt) - content: str = result.content # type: ignore + prompt_text = "\n".join(msg.content for msg in prompt.messages) # type: ignore - new_extracted = self.extract_content(content, "extract") - extracted_data = new_extracted if new_extracted else state["extracted_data"] + result_text = self.invoke_text_only(prompt_text) - new_calculation = self.extract_content(content, "calculate") + # Extract answer and calculation + raw_answer = self.extract_content(result_text, "answer") + calculation = self.extract_content(result_text, "think") - new_answer = self.extract_content(content, "answer") - if not new_answer: - new_answer = content + if not raw_answer: + # Fallback: try to find any number or short text + raw_answer = result_text.strip().split("\n")[-1].strip() + + # Normalize answer for better evaluation matching + answer = self.normalize_answer(raw_answer) + + if self.debug: + termcolor.cprint(f"[Compute] Raw: {raw_answer} -> Normalized: {answer}", "yellow") return { # type: ignore **state, - "extracted_data": extracted_data, - "calculation": new_calculation, - "answer": new_answer, - "num_turns": state.get("num_turns", 0) + 1, - "messages": [*prompt.messages, result], + "calculation": calculation, + "answer": answer, + "num_turns": 1, + "num_model_calls": state.get("num_model_calls", 0) + 1, + "messages": [*state.get("messages", []), HumanMessage(content=result_text)], } - def should_continue(self, state: ChartState) -> Literal[END, "refine_answer"]: # type: ignore - """Determine if refinement is needed.""" - if state["messages"] and isinstance( - state["messages"][-1], BaseMessage - ): # pyright: ignore[reportUnnecessaryIsInstance] - last_message = state["messages"][-1] - if "THE ANSWER IS CORRECT" in last_message.content: # type: ignore - if "THE ANSWER IS INCORRECT" in last_message.content: # type: ignore - correct_index = last_message.content.rfind("THE ANSWER IS CORRECT") # type: ignore - incorrect_index = last_message.content.rfind("THE ANSWER IS INCORRECT") # type: ignore - if correct_index > incorrect_index: - return END - else: - return END - - if state.get("num_turns", 0) >= self.max_turns: - return END - - return "refine_answer" - - def graph(self) -> CompiledStateGraph[ChartState]: - """Build the workflow graph with refinement loop.""" - # Check if the graph is already built + def graph(self) -> 
CompiledStateGraph[ChartQAState]: + """Build the workflow graph. + + Two-step mode: + START -> extract_data -> compute_answer -> END + """ if self._graph is not None: return self._graph - builder = StateGraph(ChartState) - builder.add_node(self.analyze_chart) # type: ignore - builder.add_node(self.extract_data) # type: ignore - builder.add_node(self.calculate_answer) # type: ignore - builder.add_node(self.check_answer) # type: ignore - builder.add_node(self.refine_answer) # type: ignore - - builder.add_edge(START, "analyze_chart") - builder.add_edge("analyze_chart", "extract_data") - builder.add_edge("extract_data", "calculate_answer") - builder.add_edge("calculate_answer", "check_answer") - builder.add_conditional_edges( - "check_answer", - self.should_continue, # type: ignore - ) - builder.add_edge("refine_answer", "extract_data") + builder = StateGraph(ChartQAState) + + # Two-step: extract then compute + builder.add_node("extract_data", self.extract_data) + builder.add_node("compute_answer", self.compute_answer) + builder.add_edge(START, "extract_data") + builder.add_edge("extract_data", "compute_answer") + builder.add_edge("compute_answer", END) self._graph = builder.compile() # type: ignore return self._graph def rollout(self, task: Dict[str, Any], resources: agl.NamedResources, rollout: agl.Rollout) -> float | None: - """AgentLightning wrapper for ChartQA agent.""" - + """AgentLightning wrapper for the ChartQA agent.""" question = task["question"] - rollout = cast(agl.AttemptedRollout, rollout) - llm = cast(agl.LLM, resources["main_llm"]) image_path = os.path.join(chartqa_env_var.CHARTQA_DATA_DIR, task["image_path"]) ground_truth = task["answer"] @@ -353,58 +377,28 @@ def rollout(self, task: Dict[str, Any], resources: agl.NamedResources, rollout: logger.error(f"Image {image_path} does not exist. Skipping.") return None - # The new rollout could have a different endpoint or temperature. - # Update the LLM if necessary. 
- self.update_llm_config( - model_name=llm.model, - endpoint=llm.get_base_url(rollout.rollout_id, rollout.attempt.attempt_id), - temperature=llm.sampling_parameters.get("temperature", 0.0), - ) + # Update model configuration from resources + if "main_llm" in resources: + llm_resource = cast(agl.LLM, resources["main_llm"]) + llm_endpoint = llm_resource.get_base_url(rollout.rollout_id, rollout.attempt.attempt_id) + llm_temperature = llm_resource.sampling_parameters.get("temperature", 0.0) + + if llm_endpoint != self.endpoint or llm_temperature != self.temperature: + self.endpoint = llm_endpoint + self.temperature = llm_temperature + self._model = self._create_model() try: handler = self.tracer.get_langchain_handler() - result = self.graph().invoke( # type: ignore - {"question": question, "image_path": image_path}, # type: ignore + result = self.graph().invoke( + {"question": question, "image_path": image_path}, {"callbacks": [handler] if handler else [], "recursion_limit": 100}, ) except Exception as e: - error_msg = f"[Rollout {rollout.rollout_id}] Error during agent invocation: {e}" - logger.error(error_msg, exc_info=True) - # Return 0.0 as reward to indicate failure + logger.error(f"[Rollout {rollout.rollout_id}] Error: {e}", exc_info=True) return 0.0 predicted_answer = result["answer"] reward = evaluate_answer(predicted_answer, ground_truth, raise_on_error=False) return reward - - -def evaluate_answer(predicted: str, ground_truth: str, raise_on_error: bool = False) -> float: - """Evaluate answer accuracy.""" - try: - pred = predicted.lower().strip() - gt = ground_truth.lower().strip() - - # Exact match - if pred == gt: - return 1.0 - - # Try numeric comparison - try: - pred_num = float(pred.replace(",", "")) - gt_num = float(gt.replace(",", "")) - if abs(pred_num - gt_num) / max(abs(gt_num), 1e-9) < 0.02: - return 1.0 - except (ValueError, AttributeError): - pass - - # Partial credit for substring match - if pred in gt or gt in pred: - return 0.5 - - return 0.0 - except Exception as e: - if raise_on_error: - raise - logger.exception(f"Error evaluating answer: {e}") - return 0.0 diff --git a/examples/chartqa/debug_chartqa_agent.py b/examples/chartqa/debug_chartqa_agent.py index b8b2fb2b0..6d8c1ecf7 100644 --- a/examples/chartqa/debug_chartqa_agent.py +++ b/examples/chartqa/debug_chartqa_agent.py @@ -11,13 +11,13 @@ Example usage for self-hosted model. ``` -vllm serve Qwen/Qwen2-VL-2B-Instruct \ +vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ --gpu-memory-utilization 0.6 \ --max-model-len 4096 \ --allowed-local-media-path $CHARTQA_DATA_DIR \ --enable-prefix-caching \ --port 8088 -USE_LLM_PROXY=1 OPENAI_API_BASE=http://localhost:8088/v1 OPENAI_MODEL=Qwen/Qwen2-VL-2B-Instruct python debug_chartqa_agent.py +OPENAI_API_BASE=http://localhost:8088/v1 OPENAI_MODEL=Qwen/Qwen2.5-VL-3B-Instruct python debug_chartqa_agent.py ``` Ensure `CHARTQA_DATA_DIR` points to a directory with the prepared parquet file by running `python prepare_data.py` beforehand. @@ -38,42 +38,16 @@ logger = logging.getLogger("chartqa_agent") -def create_llm_proxy_for_chartqa(vllm_endpoint: str, port: int = 8081) -> agl.LLMProxy: - """Create an LLMProxy configured for ChartQA with token ID capture. +def is_local_endpoint(endpoint: str) -> bool: + """Check if the endpoint is a local vLLM server.""" + return "localhost" in endpoint or "127.0.0.1" in endpoint - Args: - vllm_endpoint: Base URL for the hosted vLLM server. - port: Local port where the proxy should listen. 
- - Returns: - An [`LLMProxy`][agentlightning.LLMProxy] instance launched in a thread. - """ - store = agl.LightningStoreThreaded(agl.InMemoryLightningStore()) - - llm_proxy = agl.LLMProxy( - port=port, - store=store, - model_list=[ - { - "model_name": "Qwen/Qwen2-VL-2B-Instruct", - "litellm_params": { - "model": "hosted_vllm/Qwen/Qwen2-VL-2B-Instruct", - "api_base": vllm_endpoint, - }, - } - ], - callbacks=["return_token_ids"], - launch_mode="thread", - ) - - return llm_proxy +def debug_chartqa_agent() -> None: + """Debug the ChartQA agent against cloud APIs or a local vLLM server. -def debug_chartqa_agent(use_llm_proxy: bool = False) -> None: - """Debug the ChartQA agent against cloud APIs or a local vLLM proxy. - - Args: - use_llm_proxy: When `True`, spin up an LLMProxy that points to a local vLLM endpoint. + Automatically detects local vs cloud based on the OPENAI_API_BASE endpoint. + For local vLLM, uses file:// paths. For cloud APIs, uses base64 encoding. Raises: FileNotFoundError: If the prepared ChartQA parquet file is missing. @@ -88,37 +62,30 @@ def debug_chartqa_agent(use_llm_proxy: bool = False) -> None: model = chartqa_env_var.OPENAI_MODEL endpoint = chartqa_env_var.OPENAI_API_BASE + api_key = chartqa_env_var.OPENAI_API_KEY + use_local = is_local_endpoint(endpoint) + logger.info( - "Debug data: %s samples, model: %s, endpoint: %s, llm_proxy=%s", + "Debug data: %s samples, model: %s, endpoint: %s, local=%s", len(test_data), model, endpoint, - use_llm_proxy, + use_local, ) - llm_endpoint = endpoint - trainer_kwargs: Dict[str, Any] = {} - - if use_llm_proxy: - proxy_port = 8089 - llm_proxy = create_llm_proxy_for_chartqa(endpoint, port=proxy_port) - trainer_kwargs["llm_proxy"] = llm_proxy - trainer_kwargs["n_workers"] = 2 - llm_endpoint = f"http://localhost:{proxy_port}/v1" - agent = ChartQAAgent() - else: - trainer_kwargs["n_workers"] = 1 - agent = ChartQAAgent(use_base64_images=True) + # For local vLLM, use file:// paths; for cloud APIs, use base64 encoding + agent = ChartQAAgent(use_base64_images=not use_local) trainer = agl.Trainer( initial_resources={ "main_llm": agl.LLM( - endpoint=llm_endpoint, + endpoint=endpoint, model=model, + api_key=api_key, sampling_parameters={"temperature": 0.0}, ) }, - **trainer_kwargs, + n_workers=1, ) trainer.dev(agent, test_data) @@ -126,4 +93,4 @@ def debug_chartqa_agent(use_llm_proxy: bool = False) -> None: if __name__ == "__main__": agl.setup_logging(apply_to=["chartqa_agent"]) - debug_chartqa_agent(use_llm_proxy=chartqa_env_var.USE_LLM_PROXY) + debug_chartqa_agent() diff --git a/examples/chartqa/env_var.py b/examples/chartqa/env_var.py index a8c880984..6bb1a8920 100644 --- a/examples/chartqa/env_var.py +++ b/examples/chartqa/env_var.py @@ -7,7 +7,6 @@ "CHARTQA_DATA_DIR", "CHARTQA_IMAGES_DIR", "USE_BASE64_IMAGES", - "USE_LLM_PROXY", "OPENAI_API_BASE", "OPENAI_API_KEY", "OPENAI_MODEL", @@ -21,10 +20,8 @@ USE_BASE64_IMAGES = os.getenv("USE_BASE64_IMAGES", "false").lower() in ("1", "true", "yes") -USE_LLM_PROXY = os.getenv("USE_LLM_PROXY", "false").lower() in ("1", "true", "yes") - OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "token-abc123") -OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4.1-mini") +OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") diff --git a/examples/chartqa/prompts.py b/examples/chartqa/prompts.py index 7059bd9e1..04cef3b5f 100644 --- a/examples/chartqa/prompts.py +++ b/examples/chartqa/prompts.py @@ -4,195 +4,70 @@ from langchain_core.prompts 
import ChatPromptTemplate
 
-ANALYZE_CHART_PROMPT = ChatPromptTemplate(
-    [
-        (
-            "system",
-            """
-You are a visual reasoning expert analyzing charts and graphs.
-Given a chart image and a question, first carefully observe and describe the chart.
-
-Instructions:
-- Identify the chart type (bar chart, line chart, pie chart, scatter plot, etc.)
-- Note the axes labels and units (if applicable)
-- Describe the data series or categories shown
-- Observe key patterns, trends, or noteworthy values
-- Pay attention to legends, titles, and annotations
-
-## Output Format ##
-
-Provide your observation inside <observe> and </observe> tags.
-
-Example:
-<observe>
-Bar chart showing GDP of 5 countries. X-axis shows country names, Y-axis shows GDP in trillions of USD.
-Data values: USA appears highest at around 25, China second at around 20, followed by India, UK, and France.
-</observe>
-""".strip(),
-        ),
-        ("user", "Question: {question}"),
-    ]
-)
-
 EXTRACT_DATA_PROMPT = ChatPromptTemplate(
     [
-        (
-            "system",
-            """
-Based on your observation of the chart, extract the specific data values needed to answer the question.
-
-Instructions:
-- Extract only the data relevant to the question
-- Be precise with values (read carefully from the chart)
-- Include labels/categories with each value
-- Use appropriate units
-
-## Output Format ##
-
-Provide extracted data inside <extract> and </extract> tags.
-Format: Label1: Value1, Label2: Value2, ...
-
-Example:
-<extract>
-USA: 25, China: 20, India: 15, UK: 10, France: 8
-</extract>
-""".strip(),
-        ),
         (
             "user",
-            """Observation: {observation}
+            """Analyze this chart to answer the question. Extract all relevant data.
 
 Question: {question}
 
-Please extract the relevant data values.""",
-        ),
-    ]
-)
-
-CALCULATE_ANSWER_PROMPT = ChatPromptTemplate(
-    [
-        (
-            "system",
-            """
-Using the extracted data, perform any necessary calculations to answer the question.
-
-Instructions:
-- Show your calculation steps clearly
-- Use correct mathematical operations
-- Pay attention to the question (average, sum, difference, maximum, etc.)
-- Provide a precise numerical answer if applicable
-- Keep the answer concise (typically 1-10 words)
-
-## Output Format ##
-
-Show calculation inside <calculate> and </calculate> tags (if needed).
-Provide final answer inside <answer> and </answer> tags.
-
-Example:
-<calculate>
-Average = (25 + 20 + 15 + 10 + 8) / 5 = 78 / 5 = 15.6
-</calculate>
-<answer>
-15.6
-</answer>
-""".strip(),
-        ),
-        (
-            "user",
-            """Extracted Data: {extracted_data}
-
-Question: {question}
-
-Please calculate and provide the answer.""",
+Instructions:
+1. Identify what data is needed to answer the question
+2. Read the values carefully from the chart
+3. Output the extracted data as JSON
+
+Output format:
+<data>
+{{
+  "chart_type": "bar/line/pie/...",
+  "relevant_values": {{
+    "label1": value1,
+    "label2": value2,
+    ...
+  }},
+  "notes": "any observations about the data"
+}}
+</data>
+
+Rules:
+- Extract ONLY values needed to answer the question
+- Use exact labels/categories from the chart
+- For numbers, give your best estimate if not labeled
+- Include units in notes if relevant (but not in values)""",
         ),
     ]
 )
 
-CHECK_ANSWER_PROMPT = ChatPromptTemplate(
-    [
-        (
-            "system",
-            """
-You are a chart analysis expert with strong attention to detail.
-Review the answer for potential mistakes.
-
-Common mistakes to check:
-- Incorrect data extraction from chart (misread values)
-- Arithmetic errors in calculations
-- Misunderstanding the question type (average vs. sum vs. difference)
-- Wrong number of data points counted
-- Incorrect units or scale interpretation
-- Off-by-one errors
-
-## Chart Information ##
-
-Observation: {observation}
-Extracted Data: {extracted_data}
-
-## Output Format ##
-If any mistakes are found, list each error clearly.
-After listing mistakes (if any), conclude with **ONE** of the following exact phrases in all caps:
-- If mistakes are found: `THE ANSWER IS INCORRECT.`
-- If no mistakes are found: `THE ANSWER IS CORRECT.`
-
-DO NOT write the corrected answer in this response. You only need to report mistakes.
-""".strip(),
-        ),
-        (
-            "user",
-            """Question: {question}
-
-Current Answer: {answer}
-
-Calculation shown:
-{calculation}
-
-Please review this answer for correctness.""",
-        ),
-    ]
-)
-
-REFINE_ANSWER_PROMPT = ChatPromptTemplate(
+COMPUTE_ANSWER_PROMPT = ChatPromptTemplate(
     [
-        (
-            "system",
-            """
-You are a chart analysis agent.
-The previous answer had errors. Based on the feedback, provide a corrected answer.
-
-Instructions:
-- Re-examine the chart observation carefully
-- Correct any data extraction errors by re-extracting if needed
-- Fix calculation mistakes
-- Address all points mentioned in the feedback
-
-## Chart Observation ##
-
-{observation}
-
-## Output Format ##
-
-If you need to re-extract data, provide it inside <extract> and </extract> tags.
-Show corrected calculation inside <calculate> and </calculate> tags.
-Provide corrected answer inside <answer> and </answer> tags.
-""".strip(),
-        ),
         (
             "user",
-            """Question: {question}
-
-## Previous Attempt ##
+            """Given the extracted chart data, answer the question.
 
-Extracted Data: {extracted_data}
-Calculation: {calculation}
-Answer: {answer}
-
-## Feedback ##
+Question: {question}
 
-{feedback}
+Extracted data:
+{extracted_data}
 
-Please provide the corrected answer.""",
+Instructions:
+1. Determine what calculation is needed (if any)
+2. Perform the calculation
+3. Give the final answer
+
+Output format:
+<think>
+[Your reasoning - what operation is needed?]
+</think>
+
+<answer>
+[final answer - number or short phrase, no units]
+</answer>
+
+Rules:
+- Answer must be concise: a number OR max 3 words
+- No units, %, or currency symbols in the answer
+- If the question asks "which/what", return the label/name
+- If the question asks "how many/much", return a number""",
         ),
     ]
 )
diff --git a/examples/chartqa/train_chartqa_agent.py b/examples/chartqa/train_chartqa_agent.py
index 88921f25f..c1226a14d 100644
--- a/examples/chartqa/train_chartqa_agent.py
+++ b/examples/chartqa/train_chartqa_agent.py
@@ -5,13 +5,13 @@
 Example usage:
 
 ```bash
-python train_chartqa_agent.py debug --n-runners 2
+python train_chartqa_agent.py debug --n-runners 32
 ```
 
 or:
 
 ```bash
-AGL_MANAGED_STORE=0 python train_chartqa_agent.py qwen --external-store-address http://localhost:9999
+AGL_MANAGED_STORE=0 python train_chartqa_agent.py qwen --external-store-address http://localhost:4747
 ```
 
 Make sure to run `python prepare_data.py` so the parquet files referenced here exist.
@@ -37,24 +37,24 @@ "algorithm": {"adv_estimator": "grpo", "use_kl_in_reward": False}, "data": { "image_base_dir": chartqa_env_var.CHARTQA_IMAGES_DIR, - "train_batch_size": 32, + "train_batch_size": 8, "max_prompt_length": 4096, "max_response_length": 1024, "truncation": "error", }, "actor_rollout_ref": { "rollout": { - "tensor_model_parallel_size": 1, - "n": 4, + "tensor_model_parallel_size": 2, + "n": 8, "log_prob_micro_batch_size_per_gpu": 1, "name": "vllm", - "gpu_memory_utilization": 0.8, + "gpu_memory_utilization": 0.4, "enable_prefix_caching": True, "engine_kwargs": {"vllm": {"allowed_local_media_path": chartqa_env_var.CHARTQA_IMAGES_DIR}}, }, "actor": { - "ppo_mini_batch_size": 32, - "ppo_micro_batch_size_per_gpu": 4, + "ppo_mini_batch_size": 8, + "ppo_micro_batch_size_per_gpu": 1, "optim": {"lr": 1e-6}, "use_kl_loss": False, "kl_loss_coef": 0.0, @@ -65,13 +65,13 @@ }, "ref": {"log_prob_micro_batch_size_per_gpu": 1, "fsdp_config": {"param_offload": True}}, "model": { - "path": "Qwen/Qwen2-VL-2B-Instruct", + "path": "Qwen/Qwen2.5-VL-3B-Instruct", "use_remove_padding": True, "enable_gradient_checkpointing": True, }, }, "trainer": { - "n_gpus_per_node": 1, + "n_gpus_per_node": 2, "val_before_train": False, "critic_warmup": 0, "logger": ["console", "wandb"], @@ -83,7 +83,7 @@ def config_ci() -> Dict[str, Any]: - """Return a CI-friendly RL config for ChartQA.""" + """Return a CI-friendly RL config for ChartQA agent.""" # For CI testing, we need to set the experiment name and project name so that # they are available to subsequent steps. timestamp = datetime.now().strftime("%Y%m%d%H%M%S") @@ -97,10 +97,7 @@ def config_ci() -> Dict[str, Any]: f.write(f"run_name={EXPERIMENT_NAME}\n") config = deepcopy(RL_CONFIG) - config["data"]["train_batch_size"] = 16 - config["trainer"]["n_gpus_per_node"] = 1 config["trainer"]["total_training_steps"] = 4 - config["trainer"]["val_before_train"] = True config["trainer"]["test_freq"] = 2 config["trainer"]["experiment_name"] = EXPERIMENT_NAME config["trainer"]["project_name"] = PROJECT_NAME @@ -110,7 +107,6 @@ def config_ci() -> Dict[str, Any]: def config_debug() -> Dict[str, Any]: """Return a short debugging config for smoke testing ChartQA training.""" config = deepcopy(RL_CONFIG) - config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.5 config["trainer"]["total_training_steps"] = 10 config["trainer"]["test_freq"] = 2 return config @@ -119,7 +115,6 @@ def config_debug() -> Dict[str, Any]: def config_qwen() -> Dict[str, Any]: """Return a Qwen-focused config with validation before each epoch.""" config = deepcopy(RL_CONFIG) - config["trainer"]["val_before_train"] = True config["trainer"]["n_gpus_per_node"] = 2 config["trainer"]["total_epochs"] = 2 config["trainer"]["test_freq"] = 32 @@ -130,7 +125,7 @@ def train( config: Dict[str, Any], train_data: agl.Dataset[Any], val_data: agl.Dataset[Any], - external_store_address: str, + external_store_address: Optional[str], n_runners: int, debug: bool, ) -> None: