Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@
},
"pnpm": {
"onlyBuiltDependencies": [
"esbuild"
"@parcel/watcher",
"esbuild",
"msgpackr-extract"
]
}
}
63 changes: 54 additions & 9 deletions src/lib/formatting.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export function formatConversationAsXml(messages: Message[]): string {
}

/** Escape special characters for XML */
function escapeXml(str: string): string {
export function escapeXml(str: string): string {
return str
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
Expand Down Expand Up @@ -59,8 +59,8 @@ export function formatNodesForPrompt(
.map(
(node) =>
`<node id="${escapeXml(node.tempId)}" type="${escapeXml(node.type)}" timestamp="${node.timestamp}">
<label>${node.label ?? ""}</label>
<description>${node.description || ""}</description>
<label>${escapeXml(node.label ?? "")}</label>
<description>${escapeXml(node.description || "")}</description>
</node>`,
)
.join("\n");
Expand All @@ -82,16 +82,15 @@ export function formatLabelDescList(

const xmlItems = items
.map(
(item) => `<item label="${escapeXml(item.label ?? "Unnamed")}"
>${item.description ?? ""}</item>`,
(item) =>
`<item label="${escapeXml(item.label ?? "Unnamed")}">${escapeXml(item.description ?? "")}</item>`,
)
.join("\n");
return `<items>
${xmlItems}
</items>`;
}


// Group definitions for reranked search results
type SearchGroups = {
similarNodes: NodeSearchResult;
Expand All @@ -107,8 +106,8 @@ export type SearchResults = RerankResult<SearchGroups>;
// Helpers for formatting individual result items
function formatSearchNode(node: NodeSearchResult): string {
return `<node type="${escapeXml(node.type)}" timestamp="${formatISO(node.timestamp)}">
<label>${node.label ?? ""}</label>
<description>${node.description ?? ""}</description>
<label>${escapeXml(node.label ?? "")}</label>
<description>${escapeXml(node.description ?? "")}</description>
</node>`;
}

Expand All @@ -124,10 +123,14 @@ function formatSearchConnection(conn: OneHopNode): string {
return `<edge from="${escapeXml(conn.sourceLabel ?? "")}" to="${escapeXml(
conn.targetLabel ?? "",
)}" type="${escapeXml(conn.edgeType)}" timestamp="${formatISO(conn.timestamp)}">
<description>${conn.description ?? ""}</description>
<description>${escapeXml(conn.description ?? "")}</description>
</edge>`;
}

/**
 * Exhaustiveness guard for `switch` statements over a union.
 *
 * Because `value` is typed `never`, the compiler rejects any call site where
 * a union member is still unhandled. If such a value reaches this function at
 * runtime anyway, an error carrying `message` is thrown.
 *
 * @param value - The supposedly impossible value; only its type is used.
 * @param message - Text of the error that is thrown.
 * @throws Error always.
 */
function assertNever(value: never, message: string): never {
  const unreachable = new Error(message);
  throw unreachable;
}

/**
* Formats reranked search results as an XML-like structure for LLM prompts.
* Items are ordered by descending relevance and tagged by their group.
Expand All @@ -143,9 +146,51 @@ export function formatSearchResultsAsXml(results: SearchResults): string {
return formatSearchEdge(r.item);
case "connections":
return formatSearchConnection(r.item);
default:
return assertNever(
r.group,
`[formatSearchResultsAsXml] Unhandled search result group: ${String(
r.group,
)}`,
);
}
})
.join("\n")
: "";
return body;
}

/** A single reranked search result paired with the temporary ID used to reference it in prompts. */
export type SearchResultWithId = SearchResults[number] & { tempId: string };

/**
 * Render search results as XML-like `<result id="…">` elements so the LLM can
 * refer to individual results by their temporary ID.
 *
 * Each result is formatted by the helper matching its `group` and wrapped in a
 * `<result>` element whose `id` attribute is the XML-escaped `tempId`.
 *
 * @param results - Search results, each carrying a `tempId`.
 * @returns The wrapped results joined by newlines, or an empty string when
 *   `results` is empty.
 * @throws Error if a result has a `group` none of the cases handle.
 */
export function formatSearchResultsWithIds(
  results: SearchResultWithId[],
): string {
  if (results.length === 0) {
    return "";
  }

  const wrapped: string[] = [];
  for (const entry of results) {
    let payload: string;
    switch (entry.group) {
      case "similarNodes":
        payload = formatSearchNode(entry.item);
        break;
      case "similarEdges":
        payload = formatSearchEdge(entry.item);
        break;
      case "connections":
        payload = formatSearchConnection(entry.item);
        break;
      default:
        // Unreachable while every group is handled above; the `never`
        // parameter turns a missing case into a compile error.
        payload = assertNever(
          entry.group,
          `[formatSearchResultsWithIds] Unhandled search result group: ${String(
            entry.group,
          )}`,
        );
    }
    wrapped.push(`<result id="${escapeXml(entry.tempId)}">${payload}</result>`);
  }

  return wrapped.join("\n");
}
154 changes: 147 additions & 7 deletions src/lib/jobs/deep-research.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import { performStructuredAnalysis } from "../ai";
import { storeDeepResearchResult } from "../cache/deep-research-cache";
import { generateEmbeddings } from "../embeddings";
import {
escapeXml,
formatSearchResultsWithIds,
type SearchResultWithId,
type SearchResults,
} from "../formatting";
import {
findOneHopNodes,
findSimilarEdges,
Expand All @@ -14,6 +20,7 @@ import {
DeepResearchJobInput,
DeepResearchResult,
} from "../schemas/deep-research";
import { TemporaryIdMapper } from "../temporary-id-mapper";
import { z } from "zod";
import { DrizzleDB } from "~/db";
import { useDatabase } from "~/utils/db";
Expand All @@ -28,6 +35,8 @@ type SearchGroups = {

// Default TTL for deep research results (24 hours)
const DEFAULT_TTL_SECONDS = 24 * 60 * 60;
// Maximum number of refinement loops
const MAX_SEARCH_LOOPS = 4;

/**
* Main job handler for deep research
Expand All @@ -42,8 +51,7 @@ export async function performDeepResearch(
console.log(`Starting deep research for conversation ${conversationId}`);

try {
// Get search queries based on recent conversation turns
// Filter to only include user and assistant messages
// Prepare initial queries based on recent conversation turns
const recentMessages = messages
.slice(-lastNMessages)
.filter((m) => m.role === "user" || m.role === "assistant");
Expand All @@ -54,11 +62,16 @@ export async function performDeepResearch(
return;
}

// Execute search queries and aggregate results
const searchResults = await executeDeepSearchQueries(db, userId, queries);
// Run iterative search/refine loop
const searchResults = await runIterativeSearch(
db,
userId,
recentMessages,
queries,
);

// Process results and cache them
await cacheDeepResearchResults(userId, conversationId, searchResults);
// Cache the combined results
await cacheDeepResearchResults(userId, conversationId, [searchResults]);

console.log(`Deep research completed for conversation ${conversationId}`);
} catch (error) {
Expand All @@ -82,7 +95,7 @@ async function generateSearchQueries(

// Format messages for context
const messageContext = messages
.map((m) => `<message role="${m.role}">${m.content}</message>`)
.map((m) => `<message role="${m.role}">${escapeXml(m.content)}</message>`)
.join("\n");

// Use structured analysis to generate tangential search queries
Expand Down Expand Up @@ -110,6 +123,133 @@ Come up with 1-5 search queries that explore adjacent or less obvious connection
}
}

/**
* Run iterative search with LLM refinement.
*/
async function runIterativeSearch(
db: DrizzleDB,
userId: string,
messages: DeepResearchJobInput["messages"],
initialQueries: string[],
): Promise<RerankResult<SearchGroups>> {
const queue = [...initialQueries];
const history: string[] = [];
let results: SearchResultWithId[] = [];
let tempIdCounter = 0;
const mapper = new TemporaryIdMapper<SearchResults[number], string>(
() => `r${++tempIdCounter}`,
);
Comment on lines 139 to 141
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current ID generation strategy for `TemporaryIdMapper`, ``(_item, idx) => `r${idx + 1}` ``, will cause issues. The `idx` is relative to the `dedup` array passed to `mapper.mapItems(dedup)` in each iteration of the `while` loop (line 169). If `mapItems` is called multiple times with new items (e.g., `dedup` has 2 items in loop 1, then 2 different items in loop 2), it will attempt to generate IDs like 'r1', 'r2' again for the new items. The `TemporaryIdMapper` will then correctly throw a "Duplicate temporary ID generated" error because these IDs would have already been mapped to items from previous iterations. This will crash the deep research process.

To fix this, the ID generation needs to ensure uniqueness across all items added to this specific mapper instance throughout its lifecycle. A simple counter scoped to the runIterativeSearch function and incremented for each new ID would work.

Suggested change
const mapper = new TemporaryIdMapper<SearchResults[number], string>(
(_item, idx) => `r${idx + 1}`,
);
let results: SearchResultWithId[] = [];
let tempIdCounter = 0; // Initialize a counter for unique IDs
const mapper = new TemporaryIdMapper<SearchResults[number], string>(
(_item, _idx) => `r${++tempIdCounter}`, // Use the incrementing counter
);

const seen = new Set<string>();
let loops = 0;

while (loops < MAX_SEARCH_LOOPS && queue.length > 0) {
const query = queue.shift()!;
history.push(query);

const embResp = await generateEmbeddings({
model: "jina-embeddings-v3",
task: "retrieval.query",
input: [query],
truncate: true,
});
const embedding = embResp.data[0]?.embedding;
if (embedding) {
const res = await executeSearchWithEmbedding(
db,
userId,
query,
embedding,
20,
);
if (res) {
const dedup = res.filter((r) => {
const key = `${r.group}:${r.item.id}`;
if (seen.has(key)) return false;
seen.add(key);
return true;
});
results.push(...mapper.mapItems(dedup));
}
}

loops++;
if (loops >= MAX_SEARCH_LOOPS) break;

const refinement = await refineSearchResults(
userId,
messages,
history,
results,
);
if (refinement.dropIds.length) {
const drop = new Set(refinement.dropIds);
results = results.filter((r) => !drop.has(r.tempId));
}
if (refinement.done) break;
if (refinement.nextQuery) queue.push(refinement.nextQuery);
}

return results.map(({ tempId, ...rest }) => rest);
}

/** Outcome of one LLM refinement pass over the accumulated search results. */
interface RefinementResult {
  /** Temporary IDs of results the model judged irrelevant; these are filtered out. */
  dropIds: string[];
  /** `true` when the model is satisfied and no further searching is needed. */
  done: boolean;
  /** Follow-up query to enqueue when more searching is needed. */
  nextQuery?: string;
}

/**
 * Ask the LLM to refine search results.
 *
 * Builds a prompt from the conversation, the queries run so far and the
 * current results (each tagged with its temporary ID), then requests a
 * structured verdict from `performStructuredAnalysis`.
 *
 * @param userId - Forwarded to `performStructuredAnalysis`.
 * @param messages - Conversation messages; their content is XML-escaped
 *   before being embedded in the prompt.
 * @param queries - Queries that have already been executed.
 * @param results - Current results, each carrying a temporary ID.
 * @returns The model's refinement. If the analysis call throws, the error is
 *   logged and `{ dropIds: [], done: true }` is returned instead.
 */
async function refineSearchResults(
  userId: string,
  messages: DeepResearchJobInput["messages"],
  queries: string[],
  results: SearchResultWithId[],
): Promise<RefinementResult> {
  const refinementSchema = z
    .object({
      dropIds: z.array(z.string()).default([]),
      done: z.boolean(),
      nextQuery: z.string().optional(),
    })
    .describe("DeepResearchRefinement");

  const conversationLines = messages.map(
    (msg) => `<message role="${msg.role}">${escapeXml(msg.content)}</message>`,
  );
  const queryLines = queries.map(
    (text) => `<query>${escapeXml(text)}</query>`,
  );
  const renderedResults = formatSearchResultsWithIds(results);

  // Assembled line by line; joining with "\n" yields the exact prompt text.
  const prompt = [
    "<conversation>",
    conversationLines.join("\n"),
    "</conversation>",
    "",
    "<queries>",
    queryLines.join("\n"),
    "</queries>",
    "",
    "<results>",
    renderedResults,
    "</results>",
    "",
    "<system:instruction>",
    "Remove irrelevant results by listing their ids in dropIds. If more searching is needed, set done=false and provide nextQuery. If satisfied, set done=true.",
    "</system:instruction>",
  ].join("\n");

  try {
    const refinement = await performStructuredAnalysis({
      userId,
      systemPrompt: "You refine background search results.",
      prompt,
      schema: refinementSchema,
    });
    return refinement;
  } catch (error) {
    // Best effort: a failed refinement keeps every result and reports done.
    console.error("Failed to refine deep search results:", error);
    return { dropIds: [], done: true };
  }
}

/**
* Execute multiple search queries in parallel with higher limits
* and return combined results
Expand Down
Loading