72 changes: 72 additions & 0 deletions agent/base_agent.py
@@ -417,6 +417,14 @@ async def initialize_tool_instances(self, question: str) -> None:
question: The research question (passed to tool instances for context)
"""
self.tool_instances = {}

# Deduplication tracking - prevents redundant tool calls within a research session
self._seen_queries = set() # Track searched queries to prevent duplicate searches
self._seen_urls = set() # Track read URLs to prevent duplicate reads

# Tool call statistics - tracks success/failure rates for debugging and insights
self._tool_call_stats = {"total": 0, "success": 0, "failed": 0, "by_tool": {}}

for tool_name, tool in self.tools.items():
instance_id, _ = await tool.create(
create_kwargs={
@@ -455,12 +463,60 @@ async def call_tool(self, tool_call_str: str) -> ToolResponse:
tool_name = tool_call["name"]
tool_args = tool_call["arguments"]

# Deduplication for web_search - prevents redundant searches
if tool_name == "web_search" and "query_list" in tool_args:
original_queries = tool_args["query_list"]
tool_args["query_list"] = [q for q in original_queries if q not in self._seen_queries]

# If every query is a duplicate, short-circuit with a placeholder response instead of calling the tool
if not tool_args["query_list"]:
logger.info(f"All {len(original_queries)} queries already seen, skipping search")
response_text = json.dumps([
{"query": q, "search_results": [], "note": "Already searched"}
for q in original_queries
])
# Truncate the placeholder response if needed, for consistency with real tool output
if len(response_text) > self.max_tool_response_length:
response_text = response_text[:self.max_tool_response_length] + "...(truncated)"
return ToolResponse(text=response_text)

# Mark queries as seen before execution
self._seen_queries.update(tool_args["query_list"])

# Deduplication for web_read - prevents redundant URL reads
elif tool_name == "web_read" and "url_list" in tool_args:
original_urls = tool_args["url_list"]
tool_args["url_list"] = [u for u in original_urls if u not in self._seen_urls]

# If every URL is a duplicate, short-circuit with a placeholder response instead of calling the tool
if not tool_args["url_list"]:
logger.info(f"All {len(original_urls)} URLs already seen, skipping read")
response_text = json.dumps([
{"url": u, "content": "", "note": "Already read"}
for u in original_urls
])
# Truncate the placeholder response if needed, for consistency with real tool output
if len(response_text) > self.max_tool_response_length:
response_text = response_text[:self.max_tool_response_length] + "...(truncated)"
return ToolResponse(text=response_text)

# Mark URLs as seen before execution
self._seen_urls.update(tool_args["url_list"])

# Execute tool
tool = self.tools[tool_name]
instance_id = self.tool_instances[tool_name]

try:
tool_response, _, _ = await tool.execute(instance_id, tool_args)

# Track successful tool call
self._tool_call_stats["total"] += 1
self._tool_call_stats["success"] += 1
self._tool_call_stats["by_tool"][tool_name] = self._tool_call_stats["by_tool"].get(
tool_name, {"success": 0, "failed": 0}
)
self._tool_call_stats["by_tool"][tool_name]["success"] += 1

# Truncate response if too long
if (
@@ -474,6 +530,22 @@ async def call_tool(self, tool_call_str: str) -> ToolResponse:

return tool_response
except Exception as e:
# Track failed tool call with statistics
self._tool_call_stats["total"] += 1
self._tool_call_stats["failed"] += 1
self._tool_call_stats["by_tool"][tool_name] = self._tool_call_stats["by_tool"].get(
tool_name, {"success": 0, "failed": 0}
)
self._tool_call_stats["by_tool"][tool_name]["failed"] += 1
success_rate = (
(self._tool_call_stats["success"] / self._tool_call_stats["total"] * 100)
if self._tool_call_stats["total"] > 0
else 0
)
logger.warning(
f"Tool call failed: {tool_name} | Session success rate: {success_rate:.1f}% "
f"({self._tool_call_stats['success']}/{self._tool_call_stats['total']})"
)
return ToolResponse(
text=f"Error executing tool: {type(e).__name__}: {str(e)}"
)
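A minimal sketch of the deduplication flow added above, isolated from the rest of the agent. The DedupTracker class and its method names are hypothetical, introduced only to illustrate the filter-then-short-circuit pattern; in the PR the same logic lives inline in call_tool. The web_read path works the same way with URLs in place of queries.

import json

# Hypothetical stand-in for the agent's session-level bookkeeping; names are illustrative only.
class DedupTracker:
    def __init__(self):
        self.seen_queries = set()

    def filter_queries(self, query_list):
        # Keep only queries not yet searched this session, then mark them as seen.
        fresh = [q for q in query_list if q not in self.seen_queries]
        self.seen_queries.update(fresh)
        return fresh

    def placeholder_response(self, query_list):
        # Mirrors the short-circuit payload returned when every query is a duplicate.
        return json.dumps(
            [{"query": q, "search_results": [], "note": "Already searched"} for q in query_list]
        )

tracker = DedupTracker()
print(tracker.filter_queries(["rust borrow checker", "python asyncio"]))  # both fresh -> kept
print(tracker.filter_queries(["python asyncio"]))                         # [] -> tool call skipped
print(tracker.placeholder_response(["python asyncio"]))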
16 changes: 15 additions & 1 deletion main_rts.py
@@ -349,8 +349,22 @@ def main_task(config):
else:
assert False, "gemini_mbe_final_answer not found for data source: " + data_source

# Enhanced per-data-source metrics
for data_source, gemini_mbe_list in data_source_to_gemini_mbe.items():
print(f"Average gemini_mbe for {data_source}: {np.mean(gemini_mbe_list)} (n={len(gemini_mbe_list)})")
avg_score = np.mean(gemini_mbe_list)

# --- Metric 21 (Success Rate) ---
# Fraction of examples with a score of at least 0.8 (reported below as a percentage)
success_rate = np.mean([s >= 0.8 for s in gemini_mbe_list])

print(f"Data Source: {data_source}")
print(f" - Avg Gemini MBE: {avg_score:.4f}")
print(f" - Success Rate (>=0.8): {success_rate:.2%}")

# --- Metric 25 (Coverage Rate) ---
# Fraction of outputs that do not contain the explicit give-up string
coverage = np.mean(["I have performed research but I can not find the answer" not in ans for ans in output_lst])
print(f"\nGlobal Coverage Rate: {coverage:.2%}")

if __name__ == "__main__":
main()
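A rough worked example of the metrics above, using five invented scores and outputs (all values are made up purely for illustration):

import numpy as np

# Toy data, purely illustrative.
gemini_mbe_list = [0.95, 0.40, 0.85, 0.80, 0.10]
output_lst = [
    "The answer is 42.",
    "I have performed research but I can not find the answer",
    "The capital is Paris.",
    "Released in 2019.",
    "I have performed research but I can not find the answer",
]

avg_score = np.mean(gemini_mbe_list)                         # 0.62
success_rate = np.mean([s >= 0.8 for s in gemini_mbe_list])  # 3 of 5 scores >= 0.8 -> 0.60
coverage = np.mean(
    ["I have performed research but I can not find the answer" not in ans for ans in output_lst]
)                                                            # 3 of 5 outputs are real answers -> 0.60

print(f" - Avg Gemini MBE: {avg_score:.4f}")           # 0.6200
print(f" - Success Rate (>=0.8): {success_rate:.2%}")  # 60.00%
print(f"Global Coverage Rate: {coverage:.2%}")         # 60.00%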
115 changes: 115 additions & 0 deletions tool_server/search.py
@@ -17,6 +17,7 @@
import json
import os
from typing import Any, Dict, List
import urllib.parse

import aiohttp
from dotenv import load_dotenv
@@ -144,6 +145,13 @@ async def serper_search(query: str, timeout: int = 30, top_k: int = 10) -> Searc
data = await response.json()
raw_response = await response.text()
url_items = _extract_organic_from_serper_response(data)
# Re-rank results by domain boost score; sorted() is stable, so ties keep their original order
url_items = sorted(url_items, key=lambda item: boost_score(item.url), reverse=True)

logger.info(
f"Search successful for '{query}', found {len(url_items)} results"
@@ -313,3 +321,110 @@ async def search(self, query: str) -> SearchResult:
},
error=f"WebSearchAgent error: {str(e)[:200]}",
)
ACADEMIC_WEIGHTS = {
"wikipedia.org": 100,
"semanticscholar.org": 98,
"doi.org": 97,
"arxiv.org": 95,
"ncbi.nlm.nih.gov": 95,
"jstor.org": 94,
"dblp.org": 93,
"scholar.google.com": 100,
}

GOV_WEIGHTS = {
"nih.gov": 95,
"cdc.gov": 94,
"nasa.gov": 94,
"who.int": 93,
"un.org": 92,
"europa.eu": 91,
}

TECH_WEIGHTS = {
"w3.org": 90,
"ietf.org": 89,
"readthedocs.io": 89,
"developer.mozilla.org": 88,
"docs.python.org": 87,
"learn.microsoft.com": 86,
"unicode.org": 85,
"postgresql.org": 85,
}

OSS_WEIGHTS = {
"github.com": 84,
"gitlab.com": 83,
"apache.org": 82,
"python.org": 81,
"rust-lang.org": 81,
"linuxfoundation.org": 80,
}

COMMUNITY_WEIGHTS = {
"stackoverflow.com": 84,
"pypi.org": 83,
"npmjs.com": 82,
"crates.io": 81,
}

ALL_WEIGHTS = {
**ACADEMIC_WEIGHTS,
**GOV_WEIGHTS,
**TECH_WEIGHTS,
**OSS_WEIGHTS,
**COMMUNITY_WEIGHTS,
}

NEWS_DOMAINS = (
"reuters.com",
"apnews.com",
"pewresearch.org",
"nytimes.com",
"theguardian.com",
"npr.org",
"pbs.org",
)

TECH_ORG_TOKENS = (
"kernel",
"linux",
"python",
"apache",
"mozilla",
"gnome",
"kde",
)

def boost_score(url: str) -> int:
try:
host = urllib.parse.urlparse(url).netloc.lower().split(":", 1)[0]
if not host:
return 0
except Exception:
return 0

for domain, weight in ALL_WEIGHTS.items():
if host == domain or host.endswith("." + domain):
return weight

if host.endswith(".edu") or host.endswith(".gov"):
return 88
if host.endswith(".mil") or host.endswith(".int"):
return 85
if host.endswith(".ac.uk") or host.endswith(".edu.au"):
return 87

for domain in NEWS_DOMAINS:
if host == domain or host.endswith("." + domain):
return 72

if host.endswith(".org") and host.count(".") >= 1:
if any(token in host for token in TECH_ORG_TOKENS):
return 65
return 50

if host.endswith(".io") and "github" not in host:
return 55

return 0
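
For a sense of how the weight tables and TLD fallbacks combine, a quick pass of boost_score over a handful of sample URLs (expected values follow directly from the tables and rules above; the URLs themselves are arbitrary examples):

sample_urls = [
    "https://en.wikipedia.org/wiki/Rust_(programming_language)",  # wikipedia.org -> 100
    "https://arxiv.org/abs/1706.03762",                           # arxiv.org -> 95
    "https://www.cs.cmu.edu/~15213/",                             # .edu fallback -> 88
    "https://github.com/rust-lang/rust",                          # github.com -> 84
    "https://www.reuters.com/technology/",                        # news domain -> 72
    "https://blog.example.io/post",                               # .io, not GitHub -> 55
    "https://example.com/page",                                   # no match -> 0
]

for url in sample_urls:
    print(f"{boost_score(url):>3}  {url}")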