diff --git a/package.json b/package.json index 047a0ae..473e736 100644 --- a/package.json +++ b/package.json @@ -70,7 +70,9 @@ }, "pnpm": { "onlyBuiltDependencies": [ - "esbuild" + "@parcel/watcher", + "esbuild", + "msgpackr-extract" ] } } diff --git a/src/lib/formatting.ts b/src/lib/formatting.ts index f57dca3..f562ece 100644 --- a/src/lib/formatting.ts +++ b/src/lib/formatting.ts @@ -29,7 +29,7 @@ export function formatConversationAsXml(messages: Message[]): string { } /** Escape special characters for XML */ -function escapeXml(str: string): string { +export function escapeXml(str: string): string { return str .replace(/&/g, "&") .replace(/ ` - - ${node.description || ""} + + ${escapeXml(node.description || "")} `, ) .join("\n"); @@ -82,8 +82,8 @@ export function formatLabelDescList( const xmlItems = items .map( - (item) => `${item.description ?? ""}`, + (item) => + `${escapeXml(item.description ?? "")}`, ) .join("\n"); return ` @@ -91,7 +91,6 @@ ${xmlItems} `; } - // Group definitions for reranked search results type SearchGroups = { similarNodes: NodeSearchResult; @@ -107,8 +106,8 @@ export type SearchResults = RerankResult; // Helpers for formatting individual result items function formatSearchNode(node: NodeSearchResult): string { return ` - - ${node.description ?? ""} + + ${escapeXml(node.description ?? "")} `; } @@ -124,10 +123,14 @@ function formatSearchConnection(conn: OneHopNode): string { return ` - ${conn.description ?? ""} + ${escapeXml(conn.description ?? "")} `; } +function assertNever(value: never, message: string): never { + throw new Error(message); +} + /** * Formats reranked search results as an XML-like structure for LLM prompts. * Items are ordered by descending relevance and tagged by their group. @@ -143,9 +146,51 @@ export function formatSearchResultsAsXml(results: SearchResults): string { return formatSearchEdge(r.item); case "connections": return formatSearchConnection(r.item); + default: + return assertNever( + r.group, + `[formatSearchResultsAsXml] Unhandled search result group: ${String( + r.group, + )}`, + ); } }) .join("\n") : ""; return body; } + +export type SearchResultWithId = SearchResults[number] & { tempId: string }; + +/** + * Format search results with temporary IDs so the LLM can reference them. + */ +export function formatSearchResultsWithIds( + results: SearchResultWithId[], +): string { + const body = results.length + ? results + .map((r) => { + const inner = (() => { + switch (r.group) { + case "similarNodes": + return formatSearchNode(r.item); + case "similarEdges": + return formatSearchEdge(r.item); + case "connections": + return formatSearchConnection(r.item); + default: + return assertNever( + r.group, + `[formatSearchResultsWithIds] Unhandled search result group: ${String( + r.group, + )}`, + ); + } + })(); + return `${inner}`; + }) + .join("\n") + : ""; + return body; +} diff --git a/src/lib/jobs/deep-research.ts b/src/lib/jobs/deep-research.ts index 78a82e3..106d181 100644 --- a/src/lib/jobs/deep-research.ts +++ b/src/lib/jobs/deep-research.ts @@ -1,6 +1,12 @@ import { performStructuredAnalysis } from "../ai"; import { storeDeepResearchResult } from "../cache/deep-research-cache"; import { generateEmbeddings } from "../embeddings"; +import { + escapeXml, + formatSearchResultsWithIds, + type SearchResultWithId, + type SearchResults, +} from "../formatting"; import { findOneHopNodes, findSimilarEdges, @@ -14,6 +20,7 @@ import { DeepResearchJobInput, DeepResearchResult, } from "../schemas/deep-research"; +import { TemporaryIdMapper } from "../temporary-id-mapper"; import { z } from "zod"; import { DrizzleDB } from "~/db"; import { useDatabase } from "~/utils/db"; @@ -28,6 +35,8 @@ type SearchGroups = { // Default TTL for deep research results (24 hours) const DEFAULT_TTL_SECONDS = 24 * 60 * 60; +// Maximum number of refinement loops +const MAX_SEARCH_LOOPS = 4; /** * Main job handler for deep research @@ -42,8 +51,7 @@ export async function performDeepResearch( console.log(`Starting deep research for conversation ${conversationId}`); try { - // Get search queries based on recent conversation turns - // Filter to only include user and assistant messages + // Prepare initial queries based on recent conversation turns const recentMessages = messages .slice(-lastNMessages) .filter((m) => m.role === "user" || m.role === "assistant"); @@ -54,11 +62,16 @@ export async function performDeepResearch( return; } - // Execute search queries and aggregate results - const searchResults = await executeDeepSearchQueries(db, userId, queries); + // Run iterative search/refine loop + const searchResults = await runIterativeSearch( + db, + userId, + recentMessages, + queries, + ); - // Process results and cache them - await cacheDeepResearchResults(userId, conversationId, searchResults); + // Cache the combined results + await cacheDeepResearchResults(userId, conversationId, [searchResults]); console.log(`Deep research completed for conversation ${conversationId}`); } catch (error) { @@ -82,7 +95,7 @@ async function generateSearchQueries( // Format messages for context const messageContext = messages - .map((m) => `${m.content}`) + .map((m) => `${escapeXml(m.content)}`) .join("\n"); // Use structured analysis to generate tangential search queries @@ -110,6 +123,133 @@ Come up with 1-5 search queries that explore adjacent or less obvious connection } } +/** + * Run iterative search with LLM refinement. + */ +async function runIterativeSearch( + db: DrizzleDB, + userId: string, + messages: DeepResearchJobInput["messages"], + initialQueries: string[], +): Promise> { + const queue = [...initialQueries]; + const history: string[] = []; + let results: SearchResultWithId[] = []; + let tempIdCounter = 0; + const mapper = new TemporaryIdMapper( + () => `r${++tempIdCounter}`, + ); + const seen = new Set(); + let loops = 0; + + while (loops < MAX_SEARCH_LOOPS && queue.length > 0) { + const query = queue.shift()!; + history.push(query); + + const embResp = await generateEmbeddings({ + model: "jina-embeddings-v3", + task: "retrieval.query", + input: [query], + truncate: true, + }); + const embedding = embResp.data[0]?.embedding; + if (embedding) { + const res = await executeSearchWithEmbedding( + db, + userId, + query, + embedding, + 20, + ); + if (res) { + const dedup = res.filter((r) => { + const key = `${r.group}:${r.item.id}`; + if (seen.has(key)) return false; + seen.add(key); + return true; + }); + results.push(...mapper.mapItems(dedup)); + } + } + + loops++; + if (loops >= MAX_SEARCH_LOOPS) break; + + const refinement = await refineSearchResults( + userId, + messages, + history, + results, + ); + if (refinement.dropIds.length) { + const drop = new Set(refinement.dropIds); + results = results.filter((r) => !drop.has(r.tempId)); + } + if (refinement.done) break; + if (refinement.nextQuery) queue.push(refinement.nextQuery); + } + + return results.map(({ tempId, ...rest }) => rest); +} + +interface RefinementResult { + dropIds: string[]; + done: boolean; + nextQuery?: string; +} + +/** + * Ask the LLM to refine search results. + */ +async function refineSearchResults( + userId: string, + messages: DeepResearchJobInput["messages"], + queries: string[], + results: SearchResultWithId[], +): Promise { + const schema = z + .object({ + dropIds: z.array(z.string()).default([]), + done: z.boolean(), + nextQuery: z.string().optional(), + }) + .describe("DeepResearchRefinement"); + + const messageContext = messages + .map((m) => `${escapeXml(m.content)}`) + .join("\n"); + const queriesXml = queries + .map((q) => `${escapeXml(q)}`) + .join("\n"); + const resultsXml = formatSearchResultsWithIds(results); + + try { + return await performStructuredAnalysis({ + userId, + systemPrompt: "You refine background search results.", + prompt: ` +${messageContext} + + + +${queriesXml} + + + +${resultsXml} + + + +Remove irrelevant results by listing their ids in dropIds. If more searching is needed, set done=false and provide nextQuery. If satisfied, set done=true. +`, + schema, + }); + } catch (error) { + console.error("Failed to refine deep search results:", error); + return { dropIds: [], done: true }; + } +} + /** * Execute multiple search queries in parallel with higher limits * and return combined results