diff --git a/agent/base_agent.py b/agent/base_agent.py
index 5cd7b8d..f164dd9 100644
--- a/agent/base_agent.py
+++ b/agent/base_agent.py
@@ -417,6 +417,14 @@ async def initialize_tool_instances(self, question: str) -> None:
             question: The research question (passed to tool instances for context)
         """
         self.tool_instances = {}
+
+        # Deduplication tracking - prevents redundant tool calls within a research session
+        self._seen_queries = set()  # Queries already searched in this session
+        self._seen_urls = set()  # URLs already read in this session
+
+        # Tool call statistics - success/failure counts for debugging and insights
+        self._tool_call_stats = {"total": 0, "success": 0, "failed": 0, "by_tool": {}}
+
         for tool_name, tool in self.tools.items():
             instance_id, _ = await tool.create(
                 create_kwargs={
@@ -455,12 +463,60 @@ async def call_tool(self, tool_call_str: str) -> ToolResponse:
         tool_name = tool_call["name"]
         tool_args = tool_call["arguments"]
 
+        # Deduplication for web_search - drop queries already searched this session
+        if tool_name == "web_search" and "query_list" in tool_args:
+            original_queries = tool_args["query_list"]
+            tool_args["query_list"] = [q for q in original_queries if q not in self._seen_queries]
+
+            # If every query is a duplicate, return a placeholder response without calling the tool
+            if not tool_args["query_list"]:
+                logger.info(f"All {len(original_queries)} queries already seen, skipping search")
+                response_text = json.dumps([
+                    {"query": q, "search_results": [], "note": "Already searched"}
+                    for q in original_queries
+                ])
+                # Truncate the placeholder if needed, for consistency with real responses
+                if len(response_text) > self.max_tool_response_length:
+                    response_text = response_text[:self.max_tool_response_length] + "...(truncated)"
+                return ToolResponse(text=response_text)
+
+            # Mark queries as seen before execution
+            self._seen_queries.update(tool_args["query_list"])
+
+        # Deduplication for web_read - drop URLs already read this session
+        elif tool_name == "web_read" and "url_list" in tool_args:
+            original_urls = tool_args["url_list"]
+            tool_args["url_list"] = [u for u in original_urls if u not in self._seen_urls]
+
+            # If every URL is a duplicate, return a placeholder response without calling the tool
+            if not tool_args["url_list"]:
+                logger.info(f"All {len(original_urls)} URLs already seen, skipping read")
+                response_text = json.dumps([
+                    {"url": u, "content": "", "note": "Already read"}
+                    for u in original_urls
+                ])
+                if len(response_text) > self.max_tool_response_length:
+                    response_text = response_text[:self.max_tool_response_length] + "...(truncated)"
+                return ToolResponse(text=response_text)
+
+            # Mark URLs as seen before execution
+            self._seen_urls.update(tool_args["url_list"])
+
         # Execute tool
         tool = self.tools[tool_name]
         instance_id = self.tool_instances[tool_name]
 
         try:
             tool_response, _, _ = await tool.execute(instance_id, tool_args)
+
+            # Track successful tool call
+            self._tool_call_stats["total"] += 1
+            self._tool_call_stats["success"] += 1
+            self._tool_call_stats["by_tool"].setdefault(
+                tool_name, {"success": 0, "failed": 0}
+            )["success"] += 1
 
             # Truncate response if too long
             if (
@@ -474,6 +530,22 @@ async def call_tool(self, tool_call_str: str) -> ToolResponse:
             return tool_response
 
         except Exception as e:
+            # Track failed tool call and log the running session success rate
+            self._tool_call_stats["total"] += 1
+            self._tool_call_stats["failed"] += 1
+            self._tool_call_stats["by_tool"].setdefault(
+                tool_name, {"success": 0, "failed": 0}
+            )["failed"] += 1
+            # total was just incremented, so it is always >= 1 here
+            success_rate = self._tool_call_stats["success"] / self._tool_call_stats["total"] * 100
+            logger.warning(
+                f"Tool call failed: {tool_name} | Session success rate: {success_rate:.1f}% "
+                f"({self._tool_call_stats['success']}/{self._tool_call_stats['total']})"
+            )
             return ToolResponse(
                 text=f"Error executing tool: {type(e).__name__}: {str(e)}"
             )
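For reference, the session-level deduplication above boils down to this pattern. A minimal standalone sketch: the bare seen_queries set and the filter_new helper are illustrative stand-ins for self._seen_queries and the filtering in call_tool, not code from this diff.

seen_queries = set()

def filter_new(queries):
    # Keep only queries not yet searched this session, then mark them seen.
    fresh = [q for q in queries if q not in seen_queries]
    seen_queries.update(fresh)
    return fresh

assert filter_new(["a", "b"]) == ["a", "b"]  # first call: nothing filtered
assert filter_new(["b", "c"]) == ["c"]       # "b" was already searched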
diff --git a/main_rts.py b/main_rts.py
index 66322bf..d9e9c24 100644
--- a/main_rts.py
+++ b/main_rts.py
@@ -349,8 +349,22 @@ def main_task(config):
         else:
             assert False, "gemini_mbe_final_answer not found for data source: " + data_source
 
+    # Per-data-source metrics
     for data_source, gemini_mbe_list in data_source_to_gemini_mbe.items():
-        print(f"Average gemini_mbe for {data_source}: {np.mean(gemini_mbe_list)} (n={len(gemini_mbe_list)})")
+        avg_score = np.mean(gemini_mbe_list)
+
+        # Success rate: fraction of scores at or above 0.8
+        success_rate = np.mean([s >= 0.8 for s in gemini_mbe_list])
+
+        print(f"Data Source: {data_source} (n={len(gemini_mbe_list)})")
+        print(f"  - Avg Gemini MBE: {avg_score:.4f}")
+        print(f"  - Success Rate (>=0.8): {success_rate:.2%}")
+
+    # Coverage rate: fraction of answers that are not the explicit give-up string.
+    # Computed over all outputs, so it lives outside the per-data-source loop.
+    coverage = np.mean(["I have performed research but I can not find the answer" not in ans for ans in output_lst])
+    print(f"\nGlobal Coverage Rate: {coverage:.2%}")
 
 if __name__ == "__main__":
     main()
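As a sanity check on the two metrics above, here is a standalone sketch with made-up scores and outputs; the 0.8 threshold and the give-up string match the diff.

import numpy as np

scores = [0.95, 0.85, 0.60, 0.30]
success_rate = np.mean([s >= 0.8 for s in scores])  # booleans average to a fraction
print(f"Success Rate (>=0.8): {success_rate:.2%}")  # 50.00%

outputs = ["Paris", "I have performed research but I can not find the answer"]
coverage = np.mean([
    "I have performed research but I can not find the answer" not in ans
    for ans in outputs
])
print(f"Global Coverage Rate: {coverage:.2%}")  # 50.00%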
diff --git a/tool_server/search.py b/tool_server/search.py
index 302668f..58d84a5 100644
--- a/tool_server/search.py
+++ b/tool_server/search.py
@@ -17,6 +17,7 @@
 import json
 import os
 from typing import Any, Dict, List
+import urllib.parse
 
 import aiohttp
 from dotenv import load_dotenv
@@ -144,6 +145,13 @@ async def serper_search(query: str, timeout: int = 30, top_k: int = 10) -> SearchResult:
                 data = await response.json()
                 raw_response = await response.text()
                 url_items = _extract_organic_from_serper_response(data)
+                # Re-rank by domain boost. sorted() is stable (also with
+                # reverse=True), so equal-boost results keep their original
+                # search-engine order.
+                url_items = sorted(
+                    url_items,
+                    key=lambda item: boost_score(item.url),
+                    reverse=True,
+                )
 
                 logger.info(
                     f"Search successful for '{query}', found {len(url_items)} results"
@@ -313,3 +321,110 @@ async def search(self, query: str) -> SearchResult:
                 },
                 error=f"WebSearchAgent error: {str(e)[:200]}",
             )
+
+
+# Static domain-authority weights used by boost_score(). Higher is better.
+ACADEMIC_WEIGHTS = {
+    "wikipedia.org": 100,
+    "semanticscholar.org": 98,
+    "doi.org": 97,
+    "arxiv.org": 95,
+    "ncbi.nlm.nih.gov": 95,
+    "jstor.org": 94,
+    "dblp.org": 93,
+    "scholar.google.com": 100,
+}
+
+GOV_WEIGHTS = {
+    "nih.gov": 95,
+    "cdc.gov": 94,
+    "nasa.gov": 94,
+    "who.int": 93,
+    "un.org": 92,
+    "europa.eu": 91,
+}
+
+TECH_WEIGHTS = {
+    "w3.org": 90,
+    "ietf.org": 89,
+    "readthedocs.io": 89,
+    "developer.mozilla.org": 88,
+    "docs.python.org": 87,
+    "learn.microsoft.com": 86,
+    "unicode.org": 85,
+    "postgresql.org": 85,
+}
+
+OSS_WEIGHTS = {
+    "github.com": 84,
+    "gitlab.com": 83,
+    "apache.org": 82,
+    "python.org": 81,
+    "rust-lang.org": 81,
+    "linuxfoundation.org": 80,
+}
+
+COMMUNITY_WEIGHTS = {
+    "stackoverflow.com": 84,
+    "pypi.org": 83,
+    "npmjs.com": 82,
+    "crates.io": 81,
+}
+
+# Merged lookup table. Dict unpacking merges left to right, so a domain
+# defined in more than one table would take the weight from the last one.
+ALL_WEIGHTS = {
+    **ACADEMIC_WEIGHTS,
+    **GOV_WEIGHTS,
+    **TECH_WEIGHTS,
+    **OSS_WEIGHTS,
+    **COMMUNITY_WEIGHTS,
+}
+
+NEWS_DOMAINS = (
+    "reuters.com",
+    "apnews.com",
+    "pewresearch.org",
+    "nytimes.com",
+    "theguardian.com",
+    "npr.org",
+    "pbs.org",
+)
+
+TECH_ORG_TOKENS = (
+    "kernel",
+    "linux",
+    "python",
+    "apache",
+    "mozilla",
+    "gnome",
+    "kde",
+)
+
+
+def boost_score(url: str) -> int:
+    """Return a static authority boost for a URL's host (0 if unrecognized)."""
+    try:
+        host = urllib.parse.urlparse(url).netloc.lower().split(":", 1)[0]
+    except Exception:
+        return 0
+    if not host:
+        return 0
+
+    # Exact or subdomain match against the weighted tables. Insertion order
+    # matters for overlapping domains: "docs.python.org" (87) is reached
+    # before the broader "python.org" (81).
+    for domain, weight in ALL_WEIGHTS.items():
+        if host == domain or host.endswith("." + domain):
+            return weight
+
+    # TLD-level heuristics for institutional hosts
+    if host.endswith(".edu") or host.endswith(".gov"):
+        return 88
+    if host.endswith(".ac.uk") or host.endswith(".edu.au"):
+        return 87
+    if host.endswith(".mil") or host.endswith(".int"):
+        return 85
+
+    for domain in NEWS_DOMAINS:
+        if host == domain or host.endswith("." + domain):
+            return 72
+
+    # Generic .org hosts, with a bump for well-known tech-org name tokens
+    if host.endswith(".org"):
+        if any(token in host for token in TECH_ORG_TOKENS):
+            return 65
+        return 50
+
+    if host.endswith(".io") and "github" not in host:
+        return 55
+
+    return 0
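To illustrate the re-ranking end to end: recognized domains float to the top, and because sorted() is stable, equal-boost results keep the search engine's original order. A small sketch with invented URLs, assuming boost_score from the diff is in scope:

urls = [
    "https://blog.example.io/post",      # generic .io fallback -> 55
    "https://arxiv.org/abs/2101.00001",  # ACADEMIC_WEIGHTS     -> 95
    "https://docs.python.org/3/",        # TECH_WEIGHTS         -> 87
    "https://github.com/org/repo",       # OSS_WEIGHTS          -> 84
]
ranked = sorted(urls, key=boost_score, reverse=True)
# -> arxiv.org, docs.python.org, github.com, blog.example.io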