diff --git a/scripts/dispatch_library_links_disambiguation_tasks.py b/scripts/dispatch_library_links_disambiguation_tasks.py
index bb4b10431b..071199633a 100644
--- a/scripts/dispatch_library_links_disambiguation_tasks.py
+++ b/scripts/dispatch_library_links_disambiguation_tasks.py
@@ -6,12 +6,18 @@
 2. Non-segment-level resolutions
 
 Set DEBUG_MODE = True at the top of the script to limit to 10 random docs for debugging.
+
+Examples:
+    python dispatch_library_links_disambiguation_tasks.py --ambiguous-start 565440 --non-segment-start 0
+    python dispatch_library_links_disambiguation_tasks.py --ambiguous-start skip --non-segment-start 0
 """
 import django
 django.setup()
 
 from collections import defaultdict
+import argparse
+from tqdm import tqdm
 from sefaria.model import Ref
 from sefaria.system.exceptions import InputError
 from sefaria.system.database import db
@@ -19,13 +25,23 @@
 from sefaria.celery_setup.app import app
 from dataclasses import asdict
 from sefaria.helper.linker.disambiguator import AmbiguousResolutionPayload, NonSegmentResolutionPayload
+from sefaria.helper.linker.tasks import _is_non_segment_or_perek_ref
 
 # Global flag for debug mode
-DEBUG_MODE = True  # True = sample a small random subset; False = process all matching LinkerOutput docs
-DEBUG_LIMIT = 500  # Number of random examples to fetch in debug mode
+DEBUG_MODE = False  # True = sample a small random subset; False = process all matching LinkerOutput docs
+DEBUG_LIMIT = 10  # Number of random examples to fetch in debug mode
 DEBUG_SEED = 6133  # Seed for reproducible random sampling
+
+
+def _parse_start_arg(value: str):
+    if value is None:
+        return 0
+    if value.lower() == "skip":
+        return "skip"
+    return int(value)
+
+
 def is_segment_level_ref(ref_str):
     """Check if a reference string is segment-level"""
     try:
@@ -147,8 +163,11 @@ def find_non_segment_level_resolutions():
                 "$elemMatch": {
                     "type": "citation",
                     "failed": {"$ne": True},
-                    "ambiguous": {"$ne": True},
-                    "ref": {"$exists": True}
+                    "ref": {"$exists": True},
+                    "$or": [
+                        {"ambiguous": {"$ne": True}},
+                        {"llm_ambiguous_option_valid": True},
+                    ],
                 }
             }
         }
@@ -170,16 +189,21 @@ def find_non_segment_level_resolutions():
         for span in raw_linker_output.get('spans', []):
             # Only look at successful citation resolutions
             if (span.get('type') != 'citation' or
-                    span.get('failed', False) or
-                    span.get('ambiguous', False)):
+                    span.get('failed', False)):
+                continue
+            if span.get('ambiguous', False) and not span.get('llm_ambiguous_option_valid'):
                 continue
 
             ref_str = span.get('ref')
+            if span.get('ambiguous', False) and span.get('llm_ambiguous_option_valid'):
+                amb_resolved_ref = span.get('llm_resolved_ref_ambiguous')
+                if amb_resolved_ref:
+                    ref_str = amb_resolved_ref
             if not ref_str:
                 continue
 
-            # Check if it's NOT segment level
-            if not is_segment_level_ref(ref_str):
+            # Check if it's NOT segment level (including perek/parasha treated as non-segment)
+            if _is_non_segment_or_perek_ref(ref_str):
                 try:
                     oref = Ref(ref_str)
                     ref_level = 'unknown'
@@ -219,6 +243,15 @@ def enqueue_bulk_disambiguation(payload: dict):
 def main():
     """Main execution function - find and dispatch tasks"""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ambiguous-start", default="0",
+                        help="Number to skip for ambiguous resolutions, or 'skip'")
+    parser.add_argument("--non-segment-start", default="0",
+                        help="Number to skip for non-segment resolutions, or 'skip'")
+    args = parser.parse_args()
+    ambiguous_start_from = _parse_start_arg(args.ambiguous_start)
+    non_segment_start_from = _parse_start_arg(args.non_segment_start)
+
     print("Starting Library Links Disambiguation Tasks Dispatcher")
     if DEBUG_MODE:
         print(f"DEBUG MODE: Limited to {DEBUG_LIMIT} documents")
@@ -237,17 +270,39 @@ def main():
         return
 
     # Find ambiguous resolutions
-    ambiguous_resolutions = find_ambiguous_resolutions()
+    ambiguous_resolutions = [] if ambiguous_start_from == "skip" else find_ambiguous_resolutions()
 
-    # Find non-segment-level resolutions
-    non_segment_resolutions = find_non_segment_level_resolutions()
-
-    # Dispatch bulk disambiguation tasks (single payload each)
-    print(f"Dispatching {len(ambiguous_resolutions) + len(non_segment_resolutions)} bulk disambiguation tasks...")
+    # Dispatch ambiguous first
+    print(f"Dispatching {len(ambiguous_resolutions)} ambiguous disambiguation tasks...")
     try:
-        for resolution in ambiguous_resolutions:
+        ambiguous_iter = (
+            ambiguous_resolutions[ambiguous_start_from:]
+            if isinstance(ambiguous_start_from, int) and ambiguous_start_from
+            else ambiguous_resolutions
+        )
+        for resolution in tqdm(
+            ambiguous_iter,
+            desc="Ambiguous resolutions",
+            initial=ambiguous_start_from if isinstance(ambiguous_start_from, int) else 0,
+            total=len(ambiguous_resolutions),
+        ):
             enqueue_bulk_disambiguation(asdict(resolution))
-        for resolution in non_segment_resolutions:
+
+        # Find non-segment-level resolutions AFTER ambiguous dispatch
+        non_segment_resolutions = [] if non_segment_start_from == "skip" else find_non_segment_level_resolutions()
+        print(f"Dispatching {len(non_segment_resolutions)} non-segment disambiguation tasks...")
+
+        non_segment_iter = (
+            non_segment_resolutions[non_segment_start_from:]
+            if isinstance(non_segment_start_from, int) and non_segment_start_from
+            else non_segment_resolutions
+        )
+        for resolution in tqdm(
+            non_segment_iter,
+            desc="Non-segment resolutions",
+            initial=non_segment_start_from if isinstance(non_segment_start_from, int) else 0,
+            total=len(non_segment_resolutions),
+        ):
             enqueue_bulk_disambiguation(asdict(resolution))
         print("Dispatched bulk disambiguation tasks")
     except Exception as e:
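Reviewer note: a minimal sketch (not part of the patch) of how the new resume arguments behave, using `_parse_start_arg` and the slicing logic added above. The import assumes the scripts directory is on sys.path; `resolutions` is illustrative data, not real LinkerOutput output.

    from dispatch_library_links_disambiguation_tasks import _parse_start_arg

    # "skip" disables a stage entirely; an integer resumes dispatch by slicing
    # the already-fetched list, while tqdm's `initial` keeps the bar aligned.
    resolutions = ["r0", "r1", "r2", "r3"]

    start = _parse_start_arg("2")       # -> 2
    todo = resolutions[start:] if isinstance(start, int) and start else resolutions
    assert todo == ["r2", "r3"]

    start = _parse_start_arg("skip")    # -> "skip"; the stage's find_*() query is never run
    assert ([] if start == "skip" else resolutions) == []

    assert _parse_start_arg("0") == 0   # the argparse default: process everything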
diff --git a/sefaria/helper/linker/disambiguator.py b/sefaria/helper/linker/disambiguator.py
index d9834ba20b..e76604a041 100644
--- a/sefaria/helper/linker/disambiguator.py
+++ b/sefaria/helper/linker/disambiguator.py
@@ -7,7 +7,7 @@
 import os
 import re
 import requests
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 from typing import Dict, Any, Optional, List, Tuple
 from html import unescape
@@ -19,17 +19,23 @@
 from sefaria.settings import SEARCH_URL
-
 from langchain_anthropic import ChatAnthropic
 from langchain_openai import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate
 from langsmith import traceable
 from sefaria.model.text import Ref
+from sefaria.utils.hebrew import strip_cantillation
 from sefaria.model.schema import AddressType
 
 logger = structlog.get_logger(__name__)
 
+
+class DictaAPIError(RuntimeError):
+    def __init__(self, info: Dict[str, Any]):
+        super().__init__("Dicta API returned non-200")
+        self.info = info
+
+
 @dataclass(frozen=True)
 class AmbiguousResolutionPayload:
     ref: str
@@ -52,28 +58,42 @@ class NonSegmentResolutionPayload:
 @dataclass(frozen=True)
 class AmbiguousResolutionResult:
-    resolved_ref: str
-    matched_segment: Optional[str]
-    method: str
+    resolved_ref: Optional[str] = None
+    matched_segment: Optional[str] = None
+    method: Optional[str] = None
+    llm_resolved_phrase: Optional[str] = None
 
 @dataclass(frozen=True)
 class NonSegmentResolutionResult:
-    resolved_ref: str
-    method: str
+    resolved_ref: Optional[str] = None
+    method: Optional[str] = None
+    llm_resolved_phrase: Optional[str] = None
+
 # Configuration
 DICTA_URL = os.getenv("DICTA_PARALLELS_URL", "https://parallels-3-0a.loadbalancer.dicta.org.il/parallels/api/findincorpus")
-SEFARIA_SEARCH_URL = f"{SEARCH_URL}/api/search/text/_search"
+SEFARIA_SEARCH_URL = f"{SEARCH_URL}/text/_search"
 MIN_THRESHOLD = 1.0
 MAX_DISTANCE = 10.0
 REQUEST_TIMEOUT = 30
 WINDOW_WORDS = 120
 
+
 def _get_llm():
     """Get configured primary LLM instance."""
-    model = os.getenv("ANTHROPIC_MODEL", "claude-3-5-haiku-20241022")
+    model = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5-20250929")
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    if not api_key:
+        raise RuntimeError("ANTHROPIC_API_KEY environment variable is required")
+
+    return ChatAnthropic(model=model, temperature=0, max_tokens=1024, api_key=api_key)
+
+
+def _get_confirmation_llm():
+    """Get LLM instance used for prior formation and candidate confirmation."""
+    model = os.getenv("ANTHROPIC_CONFIRM_MODEL", "claude-sonnet-4-5-20250929")
     api_key = os.getenv("ANTHROPIC_API_KEY")
     if not api_key:
         raise RuntimeError("ANTHROPIC_API_KEY environment variable is required")
@@ -101,6 +121,13 @@ def _escape_template_braces(text: str) -> str:
     return text.replace('{', '{{').replace('}', '}}')
 
+def _strip_nikud(text: Optional[str]) -> Optional[str]:
+    """Remove cantillation and vowels (nikud) from Hebrew text."""
+    if not text:
+        return text
+    return strip_cantillation(text, strip_vowels=True)
+
+
 def _get_ref_text(ref_str: str, lang: str = None, vtitle: str = None) -> Optional[str]:
     """Get text for a reference."""
     try:
@@ -193,7 +220,10 @@ def _mark_citation(text: str, span: dict) -> str:
 @traceable(run_type="tool", name="query_dicta")
-def _query_dicta(query_text: str, target_ref: str) -> List[Dict[str, Any]]:
+def _query_dicta(
+    query_text: str,
+    target_ref: str,
+) -> List[Dict[str, Any]]:
     """Query Dicta parallels API for matching segments."""
     params = {
         'minthreshold': int(MIN_THRESHOLD),
@@ -219,7 +249,16 @@ def _query_dicta(query_text: str, target_ref: str) -> List[Dict[str, Any]]:
         headers=headers,
         timeout=REQUEST_TIMEOUT
     )
-    resp.raise_for_status()
+    if resp.status_code != 200:
+        logger.warning(f"Dicta API request failed: {resp.status_code} for {resp.url}")
+        raise DictaAPIError({
+            "status_code": resp.status_code,
+            "url": resp.url,
+            "query_text": query_text,
+            "target_ref": target_ref,
+            "response_text": resp.text,
+        })
 
     # Handle UTF-8 BOM by decoding with utf-8-sig
     text = resp.content.decode('utf-8-sig')
@@ -278,14 +317,14 @@ def _normalize_dicta_url_to_ref(url: str) -> Optional[str]:
 @traceable(run_type="tool", name="query_sefaria_search")
-def _query_sefaria_search(query_text: str, target_ref: str, slop: int = 10) -> Optional[Dict[str, Any]]:
+def _query_sefaria_search(query_text: str, target_ref: str, slop: int = 20) -> List[Dict[str, Any]]:
     """Query Sefaria search API for matching segments."""
     try:
         target_oref = Ref(target_ref)
         path_regex = _path_regex_for_ref(target_ref)
     except Exception:
         logger.warning(f"Could not create Ref for target: {target_ref}")
-        return None
+        return []
 
     bool_query = {
         'must': {'match_phrase': {'naive_lemmatizer': {'query': query_text, 'slop': slop}}}
     }
@@ -323,10 +362,11 @@
         data = resp.json()
     except Exception as e:
logger.warning(f"Sefaria search API request failed: {e}") - return None + return [] hits = (data.get('hits') or {}).get('hits', []) + matches: List[Dict[str, Any]] = [] for entry in hits: normalized = _extract_ref_from_search_hit(entry) if not normalized: @@ -337,16 +377,17 @@ def _query_sefaria_search(query_text: str, target_ref: str, slop: int = 10) -> O if not cand_oref.is_segment_level(): continue if target_oref.contains(cand_oref): - return { + matches.append({ 'resolved_ref': normalized, 'source': 'sefaria_search', 'query': query_text, + 'queries': [query_text], 'raw': entry - } + }) except Exception: continue - return None + return matches def _extract_ref_from_search_hit(hit: Dict[str, Any]) -> Optional[str]: @@ -386,36 +427,45 @@ def _path_regex_for_ref(ref_str: str) -> Optional[str]: @traceable(run_type="llm", name="llm_form_search_query") def _llm_form_search_query(marked_text: str, base_ref: str = None, base_text: str = None) -> List[str]: """Use LLM to generate search queries from marked citing text.""" - llm = _get_keyword_llm() + llm = _get_confirmation_llm() # Create context with citation redacted context_redacted = re.sub(r'.*?', '[REDACTED]', marked_text, flags=re.DOTALL) base_block = "" if base_ref and base_text: - base_block = f"Base text being commented on ({base_ref}):\n{base_text[:1000]}\n\n" + base_block = f"Base text being commented on ({base_ref}):\n{_strip_nikud(base_text)}\n\n" + + prior = _llm_form_prior(marked_text, base_ref=base_ref, base_text=base_text) prompt = ChatPromptTemplate.from_messages([ - ("system", "You are extracting a concise citation phrase to search for parallels."), + ("system", "You extract concise search phrases that are likely to appear verbatim in the target text."), ("human", "Citing passage (citation wrapped in ):\n{citing}\n\n" "Context with citation redacted:\n{context}\n\n" "{base_block}" + "Prior expectations about the target (formed without seeing it):\n{prior}\n\n" "Return 5-6 short lexical search queries (<=6 words each), taken from surrounding context " "outside the citation span.\n" + "- Prefer phrases that you expect to appear verbatim in the target text.\n" "- If base text is provided, prefer keywords that appear verbatim in the base text.\n" + "- If the context contains distinctive Hebrew content words (especially nouns), prefer them verbatim.\n" + "- Do NOT translate Hebrew into English. Avoid paraphrases.\n" + "- Prefer specific/rare tokens over generic ones.\n" + "- Include at least one single-word query (preferably a distinctive Hebrew noun).\n" "- Include at least one 2-3 word query.\n" "- Do NOT copy words that appear inside ....\n" "Strict output: one per line, numbered 1) ... through 6) ... or a single line 'NONE'." 
-        )
     ])
 
     chain = prompt | llm
     try:
         response = chain.invoke({
-            "citing": _escape_template_braces(marked_text[:2000]),
-            "context": _escape_template_braces(context_redacted[:2000]),
-            "base_block": _escape_template_braces(base_block)
+            "citing": _escape_template_braces(_strip_nikud(marked_text)),
+            "context": _escape_template_braces(_strip_nikud(context_redacted)),
+            "base_block": _escape_template_braces(base_block),
+            "prior": _escape_template_braces(prior),
         })
 
         content = getattr(response, 'content', '')
@@ -446,13 +496,15 @@
 @traceable(run_type="llm", name="llm_confirm_candidate")
 def _llm_confirm_candidate(marked_text: str, candidate_ref: str, candidate_text: str,
                            base_ref: str = None, base_text: str = None) -> Tuple[bool, str]:
-    """Use LLM to confirm if a candidate is the correct resolution."""
+    """Use LLM to confirm if a candidate is the correct resolution, using a prior."""
-    llm = _get_llm()
+    llm = _get_confirmation_llm()
+
+    prior = _llm_form_prior(marked_text, base_ref=base_ref, base_text=base_text)
 
     base_block = ""
     if base_ref and base_text:
-        base_block = f"Base text ({base_ref}):\n{_escape_template_braces(base_text[:1000])}\n\n"
+        base_block = f"Base text ({base_ref}):\n{_escape_template_braces(_strip_nikud(base_text))}\n\n"
 
     prompt = ChatPromptTemplate.from_messages([
         (
@@ -465,6 +517,7 @@ def _llm_confirm_candidate(marked_text: str, candidate_ref: str, candidate_text:
             "Citing passage (the citation span is wrapped in <cite>...</cite>):\n"
             "{citing}\n\n"
             "{base_block}"
+            "Prior expectations (formed without seeing the candidate):\n{prior}\n\n"
             "Candidate segment ref (retrieved by similarity):\n{candidate_ref}\n\n"
             "Candidate segment text:\n{candidate_text}\n\n"
             "Determine whether the citing passage is actually referring to this candidate segment.\n"
@@ -478,10 +531,11 @@ def _llm_confirm_candidate(marked_text: str, candidate_ref: str, candidate_text:
     chain = prompt | llm
     try:
         response = chain.invoke({
-            "citing": _escape_template_braces(marked_text[:2000]),
+            "citing": _escape_template_braces(_strip_nikud(marked_text)),
             "base_block": base_block,
+            "prior": _escape_template_braces(prior),
            "candidate_ref": candidate_ref,
-            "candidate_text": _escape_template_braces(candidate_text[:500])
+            "candidate_text": _escape_template_braces(_strip_nikud(candidate_text))
         })
         content = getattr(response, 'content', '')
         verdict = "YES" if re.search(r'\bYES\b', content, re.IGNORECASE) else "NO"
         return False, str(e)
@@ -491,6 +545,93 @@
+@traceable(run_type="llm", name="llm_choose_base_vs_commentary")
+def _llm_choose_base_vs_commentary(
+    marked_text: str,
+    base_ref: str,
+    base_text: str,
+    commentary_ref: str,
+    commentary_text: str,
+) -> Optional[str]:
+    """Choose whether the citation refers to the base text or the commentary."""
+    llm = _get_llm()
+
+    prompt = ChatPromptTemplate.from_messages([
+        (
+            "system",
+            "You decide whether a citation is referring to the base text itself or to a commentary on that base text. "
+            "Be strict and choose the most likely target."
+        ),
+        (
+            "human",
+            "Citing passage (the citation span is wrapped in <cite>...</cite>):\n"
+            "{citing}\n\n"
+            "Option A (Base text): {base_ref}\n{base_text}\n\n"
+            "Option B (Commentary): {commentary_ref}\n{commentary_text}\n\n"
+            "Which is more likely being referred to? Answer in exactly two lines:\n"
+            "Reason: <short reason>\n"
+            "Choice: BASE or COMMENTARY",
+        ),
+    ])
+
+    chain = prompt | llm
+    try:
+        response = chain.invoke({
+            "citing": _escape_template_braces(_strip_nikud(marked_text)),
+            "base_ref": base_ref,
+            "base_text": _escape_template_braces(_strip_nikud(base_text)),
+            "commentary_ref": commentary_ref,
+            "commentary_text": _escape_template_braces(_strip_nikud(commentary_text)),
+        })
+        content = getattr(response, 'content', '')
+        if re.search(r"\bBASE\b", content, re.IGNORECASE):
+            return "BASE"
+        if re.search(r"\bCOMMENTARY\b", content, re.IGNORECASE):
+            return "COMMENTARY"
+        return None
+    except Exception as e:
+        logger.warning(f"LLM base vs commentary choice failed: {e}")
+        return None
+
+
+@traceable(run_type="llm", name="llm_form_prior")
+def _llm_form_prior(marked_text: str, base_ref: str = None, base_text: str = None) -> str:
+    """Use LLM to form a prior about what the target segment should contain."""
+    llm = _get_confirmation_llm()
+
+    base_block = ""
+    if base_ref and base_text:
+        base_block = f"Base text ({base_ref}):\n{_escape_template_braces(_strip_nikud(base_text))}\n\n"
+
+    prompt = ChatPromptTemplate.from_messages([
+        (
+            "system",
+            "You form a prior expectation about what the target text likely contains, "
+            "based only on the citing passage and any base text. Do NOT guess a specific ref."
+        ),
+        (
+            "human",
+            "Citing passage (the citation span is wrapped in <cite>...</cite>):\n"
+            "{citing}\n\n"
+            "{base_block}"
+            "Describe what the target segment should be about, key themes or phrases to expect, "
+            "and any constraints implied by the citation. Keep it concise and concrete.\n"
+            "Return 3-6 bullet points."
+        ),
+    ])
+
+    chain = prompt | llm
+    try:
+        response = chain.invoke({
+            "citing": _escape_template_braces(_strip_nikud(marked_text)),
+            "base_block": base_block,
+        })
+        content = getattr(response, 'content', '')
+        return content.strip()
+    except Exception as e:
+        logger.warning(f"LLM prior formation failed: {e}")
+        return ""
+
 @traceable(run_type="llm", name="llm_choose_best_candidate")
 def _llm_choose_best_candidate(
     marked_text: str,
@@ -535,18 +676,17 @@ def _llm_choose_best_candidate(
     for i, (ref, cand) in enumerate(unique.items(), 1):
         txt = _get_ref_text(ref, lang=lang)
-        preview = (txt or "").strip()[:400]
-        if txt and len(txt) > 400:
-            preview += "..."
+        preview = (txt or "").strip()
+        if preview:
+            preview = strip_cantillation(preview, strip_vowels=True)
 
-        score_str = f"(score={cand.get('score')})" if cand.get('score') is not None else ""
-        numbered.append(f"{i}) {ref} {score_str}\n{preview}")
+        numbered.append(f"{i}) {ref}\n{preview}")
         payloads.append((i, cand))
 
     # Build base text block if available
     base_block = ""
     if base_ref and base_text:
-        base_block = f"Base text of commentary target ({base_ref}):\n{_escape_template_braces(base_text[:2000])}\n\n"
+        base_block = f"Base text of commentary target ({base_ref}):\n{_escape_template_braces(_strip_nikud(base_text))}\n\n"
 
     # Create LLM prompt
     llm = _get_llm()
@@ -572,8 +712,8 @@ def _llm_choose_best_candidate(
     chain = prompt | llm
     try:
         resp = chain.invoke({
-            "citing": _escape_template_braces(marked_text[:6000]),
-            "candidates": _escape_template_braces("\n\n".join(numbered))
+            "citing": _escape_template_braces(_strip_nikud(marked_text)),
+            "candidates": _escape_template_braces("\n\n".join(numbered)),
         })
         content = getattr(resp, "content", "")
     except Exception as exc:
@@ -633,10 +773,44 @@ def _dedupe_candidates_by_ref(candidates: List[Dict[str, Any]]) -> List[Dict[str
             new_score = cand.get('score', 0)
             if new_score > old_score:
                 seen[ref] = cand
+            # Merge queries from duplicate hits
+            prev_queries = seen[ref].get("queries")
+            new_query = cand.get("query")
+            new_queries = cand.get("queries")
+            merged = []
+            if isinstance(prev_queries, list):
+                merged.extend(prev_queries)
+            if isinstance(new_queries, list):
+                merged.extend(new_queries)
+            if new_query:
+                merged.append(new_query)
+            if merged:
+                seen[ref]["queries"] = sorted({q for q in merged if q})
 
     return list(seen.values())
 
+
+def _resolution_phrase_from_candidate(candidate: Optional[Dict[str, Any]]) -> Optional[str]:
+    """Extract a key phrase used to resolve a candidate from Dicta/Search data."""
+    if not candidate:
+        return None
+    queries = candidate.get("queries")
+    if isinstance(queries, list) and queries:
+        unique = [q for q in dict.fromkeys([q for q in queries if q])]
+        return "; ".join(unique)
+    query = candidate.get("query")
+    if query:
+        return query
+    raw = candidate.get("raw", {})
+    if isinstance(raw, dict) and "raw" in raw and isinstance(raw.get("raw"), dict):
+        raw = raw.get("raw")
+    if isinstance(raw, dict):
+        base_matched = raw.get("baseMatchedText")
+        if base_matched:
+            return base_matched
+    return None
+
+
 def _fallback_search_pipeline(
     marked_citing_text: str,
     citing_text: str,
@@ -671,20 +845,20 @@ def run_queries(queries: List[str], label: str) -> None:
             searched.add(q)
             logger.info(f"Trying {label} query: '{q}'")
-            hit = _query_sefaria_search(q, non_segment_ref)
+            hits = _query_sefaria_search(q, non_segment_ref)
-            if hit:
-                logger.info(f"Sefaria search {label} succeeded: '{q}' -> {hit.get('resolved_ref')}")
-                candidates.append(hit)
+            if hits:
+                logger.info(f"Sefaria search {label} succeeded: '{q}' -> {len(hits)} hits")
+                candidates.extend(hits)
                 continue
 
             # One retry for failed queries
             logger.info(f"Sefaria search {label} failed: '{q}', retrying once...")
-            retry = _query_sefaria_search(q, non_segment_ref)
+            retry_hits = _query_sefaria_search(q, non_segment_ref)
-            if retry:
-                logger.info(f"Sefaria search {label} retry succeeded: '{q}' -> {retry.get('resolved_ref')}")
-                candidates.append(retry)
+            if retry_hits:
+                logger.info(f"Sefaria search {label} retry succeeded: '{q}' -> {len(retry_hits)} hits")
+                candidates.extend(retry_hits)
 
     # A) Normal window queries (text-only)
     logger.info("Stage A: Normal window text-only queries")
@@ -744,7 +918,9 @@ def run_queries(queries: List[str], label: str) -> None:
 @traceable(run_type="chain", name="disambiguate_non_segment_ref")
-def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -> Optional[NonSegmentResolutionResult]:
+def disambiguate_non_segment_ref(
+    resolution_data: NonSegmentResolutionPayload,
+) -> Optional[NonSegmentResolutionResult]:
     """
     Disambiguate a non-segment-level reference to a specific segment.
@@ -772,6 +948,7 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
     """
     try:
+        logger.info("Non-segment payload", payload=asdict(resolution_data))
         citing_ref = resolution_data.ref
         citing_text_snippet = resolution_data.text
         citing_lang = resolution_data.language
@@ -801,6 +978,7 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
             return NonSegmentResolutionResult(
                 resolved_ref=resolved_ref,
                 method='auto_single_segment',
+                llm_resolved_phrase=None,
             )
 
         # Case 2: 2-3 segments - use LLM to pick directly
@@ -809,7 +987,7 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
             for i, seg_ref in enumerate(segment_refs, 1):
                 seg_text = _get_ref_text(seg_ref.normal(), lang="he") or _get_ref_text(seg_ref.normal(), lang="en")
                 if seg_text:
-                    preview = seg_text[:300] + ("..." if len(seg_text) > 300 else "")
+                    preview = _strip_nikud(seg_text)
                     candidates.append({
                         'index': i,
                         'resolved_ref': seg_ref.normal(),
@@ -856,6 +1034,7 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
                     return NonSegmentResolutionResult(
                         resolved_ref=cand['resolved_ref'],
                         method='llm_small_range',
+                        llm_resolved_phrase=None,
                     )
 
             logger.warning(f"Could not parse LLM response: {content}")
@@ -898,6 +1077,7 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
                 return NonSegmentResolutionResult(
                     resolved_ref=resolved_ref,
                     method='dicta_auto_approved',
+                    llm_resolved_phrase=_resolution_phrase_from_candidate(candidate),
                 )
 
             candidate_text = _get_ref_text(resolved_ref, citing_lang)
@@ -909,6 +1089,7 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
                 return NonSegmentResolutionResult(
                     resolved_ref=resolved_ref,
                     method='dicta_llm_confirmed',
+                    llm_resolved_phrase=_resolution_phrase_from_candidate(candidate),
                 )
             else:
                 logger.info(f"Dicta candidate {resolved_ref} rejected by LLM: {reason}")
@@ -939,6 +1120,7 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
                 return NonSegmentResolutionResult(
                     resolved_ref=resolved_ref,
                     method='search_auto_approved',
+                    llm_resolved_phrase=_resolution_phrase_from_candidate(search_result),
                 )
 
             candidate_text = _get_ref_text(resolved_ref, citing_lang)
@@ -950,6 +1132,7 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
                 return NonSegmentResolutionResult(
                     resolved_ref=resolved_ref,
                     method='search_llm_confirmed',
+                    llm_resolved_phrase=_resolution_phrase_from_candidate(search_result),
                 )
             else:
                 logger.info(f"Search candidate {resolved_ref} rejected by LLM: {reason}")
@@ -957,13 +1140,17 @@ def disambiguate_non_segment_ref(resolution_data: NonSegmentResolutionPayload) -
         logger.info("No resolution found via Dicta or Search")
         return None
 
+    except DictaAPIError:
+        raise
     except Exception as e:
         logger.error(f"Error in disambiguate_non_segment_ref: {e}", exc_info=True)
         return None
 
 @traceable(run_type="chain", name="disambiguate_ambiguous_ref")
-def disambiguate_ambiguous_ref(resolution_data: AmbiguousResolutionPayload) -> Optional[AmbiguousResolutionResult]:
+def disambiguate_ambiguous_ref(
+    resolution_data: AmbiguousResolutionPayload,
+) -> Optional[AmbiguousResolutionResult]:
     """
     Disambiguate between multiple possible reference resolutions.
@@ -989,6 +1176,7 @@ def disambiguate_ambiguous_ref(resolution_data: AmbiguousResolutionPayload) -> O
     """
     try:
+        logger.info("Ambiguous payload", payload=asdict(resolution_data))
         citing_ref = resolution_data.ref
         citing_text_snippet = resolution_data.text
         citing_lang = resolution_data.language
@@ -1039,6 +1227,62 @@ def disambiguate_ambiguous_ref(resolution_data: AmbiguousResolutionPayload) -> O
         # Get base context if commentary
         base_ref, base_text = _get_commentary_base_context(citing_ref)
 
+        # Special case: two options, base text vs commentary on base text, citing ref is that commentary
+        if _is_base_vs_commentary_ambiguous(citing_ref, base_ref, valid_candidates):
+            logger.info(
+                "Detected ambiguous base-text vs commentary case",
+                citing_ref=citing_ref,
+                base_ref=base_ref,
+                options=[c["ref"] for c in valid_candidates],
+            )
+
+            try:
+                base_index = Ref(base_ref).index.title
+            except Exception:
+                base_index = None
+            try:
+                citing_index = Ref(citing_ref).index.title
+            except Exception:
+                citing_index = None
+
+            base_cand = None
+            comm_cand = None
+            for cand in valid_candidates:
+                try:
+                    idx_title = Ref(cand["ref"]).index.title
+                except Exception:
+                    continue
+                if base_index and idx_title == base_index:
+                    base_cand = cand
+                if citing_index and idx_title == citing_index:
+                    comm_cand = cand
+
+            if base_cand and comm_cand:
+                base_text_full = _get_ref_text(base_cand["ref"], citing_lang)
+                comm_text_full = _get_ref_text(comm_cand["ref"], citing_lang)
+                if base_text_full and comm_text_full:
+                    choice = _llm_choose_base_vs_commentary(
+                        marked_text,
+                        base_cand["ref"],
+                        base_text_full,
+                        comm_cand["ref"],
+                        comm_text_full,
+                    )
+                    if choice == "BASE":
+                        return AmbiguousResolutionResult(
+                            resolved_ref=base_cand["ref"],
+                            matched_segment=None,
+                            method="llm_base_vs_commentary",
+                            llm_resolved_phrase=None,
+                        )
+                    if choice == "COMMENTARY":
+                        return AmbiguousResolutionResult(
+                            resolved_ref=comm_cand["ref"],
+                            matched_segment=None,
+                            method="llm_base_vs_commentary",
+                            llm_resolved_phrase=None,
+                        )
+
         # Step 1: Try Dicta to find match among candidates
         logger.info("Trying Dicta to find match among ambiguous candidates...")
         dicta_match = _try_dicta_for_candidates(
@@ -1061,8 +1305,9 @@ def disambiguate_ambiguous_ref(resolution_data: AmbiguousResolutionPayload) -> O
                 logger.info(f"LLM confirmed Dicta match: {match_ref}")
                 return AmbiguousResolutionResult(
                     resolved_ref=dicta_match['ref'],
-                    matched_segment=match_ref if match_ref != dicta_match['ref'] else None,
+                    matched_segment=match_ref,
                     method='dicta_llm_confirmed',
+                    llm_resolved_phrase=_resolution_phrase_from_candidate(dicta_match),
                 )
             else:
                 logger.info(f"LLM rejected Dicta match: {reason}")
@@ -1082,8 +1327,9 @@ def disambiguate_ambiguous_ref(resolution_data: AmbiguousResolutionPayload) -> O
                 logger.info(f"LLM confirmed search match: {match_ref}")
                 return AmbiguousResolutionResult(
                     resolved_ref=search_match['ref'],
-                    matched_segment=match_ref if match_ref != search_match['ref'] else None,
+                    matched_segment=match_ref,
                     method='search_llm_confirmed',
+                    llm_resolved_phrase=_resolution_phrase_from_candidate(search_match),
                 )
             else:
                 logger.info(f"LLM rejected search match: {reason}")
@@ -1091,6 +1337,8 @@ def disambiguate_ambiguous_ref(resolution_data: AmbiguousResolutionPayload) -> O
         logger.info("Could not find valid match among ambiguous candidates")
candidates") return None + except DictaAPIError: + raise except Exception as e: logger.error(f"Error in disambiguate_ambiguous_ref: {e}", exc_info=True) return None @@ -1120,6 +1368,36 @@ def _get_commentary_base_context(citing_ref: Optional[str]) -> Tuple[Optional[st return None, None +def _is_base_vs_commentary_ambiguous( + citing_ref: str, + base_ref: Optional[str], + valid_candidates: List[Dict[str, Any]], +) -> bool: + """Detect base-text vs commentary ambiguity when citing ref is the commentary.""" + if not base_ref or len(valid_candidates) != 2: + return False + try: + base_index = Ref(base_ref).index.title + except Exception: + base_index = None + try: + citing_index = Ref(citing_ref).index.title + except Exception: + citing_index = None + + if not base_index or not citing_index: + return False + + cand_indexes = [] + for cand in valid_candidates: + try: + cand_indexes.append(Ref(cand["ref"]).index.title) + except Exception: + cand_indexes.append(None) + + return base_index in cand_indexes and citing_index in cand_indexes + + def _try_dicta_for_candidates( query_text: str, candidates: List[Dict[str, Any]], @@ -1193,7 +1471,9 @@ def _try_dicta_for_candidates( @traceable(run_type="tool", name="query_dicta_raw") -def _query_dicta_raw(query_text: str) -> List[Dict[str, Any]]: +def _query_dicta_raw( + query_text: str, +) -> List[Dict[str, Any]]: """Query Dicta and return all results (not filtered by target ref).""" params = { 'minthreshold': int(MIN_THRESHOLD), @@ -1213,7 +1493,16 @@ def _query_dicta_raw(query_text: str) -> List[Dict[str, Any]]: headers=headers, timeout=REQUEST_TIMEOUT ) - resp.raise_for_status() + if resp.status_code != 200: + raise DictaAPIError({ + "status_code": resp.status_code, + "url": resp.url, + "query_text": query_text, + "target_ref": None, + "response_text": resp.text, + }) + logger.warning(f"Dicta API request failed: {resp.status_code} for {resp.url}") + return [] # Handle UTF-8 BOM by decoding with utf-8-sig text = resp.content.decode('utf-8-sig') @@ -1276,40 +1565,42 @@ def _try_search_for_candidates(marked_text: str, candidates: List[Dict[str, Any] for query in queries: # Query search filtered by candidate books - result = _query_sefaria_search_with_books(query, list(candidate_books) if candidate_books else None) - if not result: + results = _query_sefaria_search_with_books(query, list(candidate_books) if candidate_books else None) + if not results: continue - search_ref = result['resolved_ref'] - if search_ref in seen_refs: - continue + for result in results: + search_ref = result['resolved_ref'] + if search_ref in seen_refs: + continue - try: - result_oref = Ref(search_ref) + try: + result_oref = Ref(search_ref) - if not result_oref.is_segment_level(): - continue + if not result_oref.is_segment_level(): + continue - # Check if this result matches any candidate - for cand in candidates: - cand_oref = cand['oref'] - if cand_oref.contains(result_oref): - logger.info( - "Search result %s matches candidate %s for query: %s", - search_ref, - cand["ref"], - query, - ) - seen_refs.add(search_ref) - matching_candidates.append({ - 'ref': cand['ref'], # The candidate ref - 'resolved_ref': search_ref, # The specific segment from search - 'query': query, - 'raw': result - }) - break - except Exception: - continue + # Check if this result matches any candidate + for cand in candidates: + cand_oref = cand['oref'] + if cand_oref.contains(result_oref): + logger.info( + "Search result %s matches candidate %s for query: %s", + search_ref, + cand["ref"], + query, + ) + 
+                        matching_candidates.append({
+                            'ref': cand['ref'],  # The candidate ref
+                            'resolved_ref': search_ref,  # The specific segment from search
+                            'query': query,
+                            'queries': [query],
+                            'raw': result
+                        })
+                        break
+            except Exception:
+                continue
 
     if not matching_candidates:
         logger.info("Search found no matches among candidates")
@@ -1321,6 +1612,17 @@ def _try_search_for_candidates(marked_text: str, candidates: List[Dict[str, Any]
         segment_ref = match['resolved_ref']
         if segment_ref not in deduped:
             deduped[segment_ref] = match
+        else:
+            prev = deduped[segment_ref]
+            merged = []
+            if isinstance(prev.get("queries"), list):
+                merged.extend(prev["queries"])
+            if isinstance(match.get("queries"), list):
+                merged.extend(match["queries"])
+            if match.get("query"):
+                merged.append(match["query"])
+            if merged:
+                prev["queries"] = sorted({q for q in merged if q})
 
     deduped_matches = list(deduped.values())
@@ -1404,7 +1706,7 @@ def _query_sefaria_search_raw(query_text: str, slop: int = 10) -> Optional[Dict[
 @traceable(run_type="tool", name="query_sefaria_search_with_books")
-def _query_sefaria_search_with_books(query_text: str, books: Optional[List[str]] = None, slop: int = 10) -> Optional[Dict[str, Any]]:
+def _query_sefaria_search_with_books(query_text: str, books: Optional[List[str]] = None, slop: int = 10) -> List[Dict[str, Any]]:
     """Query Sefaria search with optional filtering by list of books."""
     bool_query = {
         'must': {'match_phrase': {'naive_lemmatizer': {'query': query_text, 'slop': slop}}}
     }
@@ -1445,10 +1747,11 @@ def _query_sefaria_search_with_books(query_text: str, books: Optional[List[str]]
         data = resp.json()
     except Exception as e:
         logger.warning(f"Sefaria search API request failed: {e}")
-        return None
+        return []
 
     hits = (data.get('hits') or {}).get('hits', [])
+    matches: List[Dict[str, Any]] = []
     for entry in hits:
         normalized = _extract_ref_from_search_hit(entry)
         if not normalized:
@@ -1457,11 +1760,11 @@ def _query_sefaria_search_with_books(query_text: str, books: Optional[List[str]]
         try:
             cand_oref = Ref(normalized)
             if cand_oref.is_segment_level():
-                return {
+                matches.append({
                     'resolved_ref': normalized,
                     'raw': entry
-                }
+                })
         except Exception:
             continue
 
-    return None
+    return matches
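Reviewer note: a small sketch (not part of the patch) of how the new `queries` bookkeeping feeds the `llm_resolved_phrase` audit fields, using `_resolution_phrase_from_candidate` as defined above; the candidate dicts are made up.

    from sefaria.helper.linker.disambiguator import _resolution_phrase_from_candidate

    # Search-style candidate: deduped query lists are joined with "; " in order,
    # so the recorded phrase preserves which queries actually hit the segment.
    search_cand = {
        "resolved_ref": "Gittin 37a:12",
        "queries": ["query b", "query a", "query a"],
    }
    assert _resolution_phrase_from_candidate(search_cand) == "query b; query a"

    # Dicta-style candidate: with no queries, the phrase falls back to the
    # matched base text reported by the Dicta API.
    dicta_cand = {"raw": {"baseMatchedText": "hypothetical matched text"}}
    assert _resolution_phrase_from_candidate(dicta_cand) == "hypothetical matched text"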
diff --git a/sefaria/helper/linker/tasks.py b/sefaria/helper/linker/tasks.py
index d0771211d5..5a1382b2ed 100644
--- a/sefaria/helper/linker/tasks.py
+++ b/sefaria/helper/linker/tasks.py
@@ -23,6 +23,7 @@
     NonSegmentResolutionPayload,
     AmbiguousResolutionResult,
     NonSegmentResolutionResult,
+    DictaAPIError,
 )
 from dataclasses import dataclass, field, asdict
 from bson import ObjectId
@@ -346,6 +347,7 @@ def _apply_non_segment_resolution(payload: NonSegmentResolutionPayload, result:
     )
     _create_link_for_resolution(citing_ref, resolved_ref)
+    _update_linker_output_resolution_fields(payload, result)
 
 def _apply_ambiguous_resolution(payload: AmbiguousResolutionPayload, result: Optional[AmbiguousResolutionResult]) -> None:
@@ -368,6 +370,26 @@ def _apply_ambiguous_resolution(payload: AmbiguousResolutionPayload, result: Opt
     )
     _create_link_for_resolution(citing_ref, resolved_ref)
+    if result.matched_segment:
+        try:
+            matched_oref = Ref(result.matched_segment)
+        except Exception:
+            matched_oref = None
+        if matched_oref is not None and matched_oref.is_segment_level():
+            _upsert_mutc_span(
+                ref=payload.ref,
+                version_title=payload.versionTitle,
+                language=payload.language,
+                char_range=payload.charRange,
+                text=payload.text,
+                resolved_ref=result.matched_segment,
+            )
+            _create_or_update_link_for_non_segment_resolution(
+                citing_ref=citing_ref,
+                non_segment_ref=resolved_ref,
+                resolved_ref=result.matched_segment,
+            )
+    _update_linker_output_resolution_fields(payload, result)
 
 def _apply_non_segment_resolution_with_record(payload: NonSegmentResolutionPayload, result: Optional[NonSegmentResolutionResult]) -> None:
@@ -398,6 +420,9 @@ def _apply_non_segment_resolution_with_record(payload: NonSegmentResolutionPaylo
         "ref": payload.ref,
         "versionTitle": payload.versionTitle,
         "language": payload.language,
+        "llm_resolved_ref_non_segment": result.resolved_ref,
+        "llm_resolved_method_non_segment": result.method,
+        "llm_resolved_phrase_non_segment": result.llm_resolved_phrase,
     })
 
     link_obj, action = _create_or_update_link_for_non_segment_resolution(
@@ -417,7 +442,11 @@ def _apply_non_segment_resolution_with_record(payload: NonSegmentResolutionPaylo
             "language": payload.language,
             "previous_ref": payload.resolved_non_segment_ref,
             "resolved_ref": resolved_ref,
+            "llm_resolved_ref_non_segment": result.resolved_ref,
+            "llm_resolved_method_non_segment": result.method,
+            "llm_resolved_phrase_non_segment": result.llm_resolved_phrase,
         })
+    _update_linker_output_resolution_fields(payload, result)
 
 def _apply_ambiguous_resolution_with_record(payload: AmbiguousResolutionPayload, result: Optional[AmbiguousResolutionResult]) -> None:
@@ -448,6 +477,10 @@ def _apply_ambiguous_resolution_with_record(payload: AmbiguousResolutionPayload,
         "ref": payload.ref,
         "versionTitle": payload.versionTitle,
         "language": payload.language,
+        "llm_resolved_ref_ambiguous": result.matched_segment or result.resolved_ref,
+        "llm_resolved_method_ambiguous": result.method,
+        "llm_resolved_phrase_ambiguous": result.llm_resolved_phrase,
+        "llm_ambiguous_option_valid": True,
     })
 
     link_obj = _create_link_for_resolution(citing_ref, resolved_ref)
@@ -459,8 +492,92 @@ def _apply_ambiguous_resolution_with_record(payload: AmbiguousResolutionPayload,
             "ref": payload.ref,
             "versionTitle": payload.versionTitle,
             "language": payload.language,
+            "llm_resolved_ref_ambiguous": result.matched_segment or result.resolved_ref,
+            "llm_resolved_method_ambiguous": result.method,
+            "llm_resolved_phrase_ambiguous": result.llm_resolved_phrase,
+            "llm_ambiguous_option_valid": True,
         })
 
+    if result.matched_segment:
+        try:
+            matched_oref = Ref(result.matched_segment)
+        except Exception:
+            matched_oref = None
+        if matched_oref is not None and matched_oref.is_segment_level():
+            _upsert_mutc_span(
+                ref=payload.ref,
+                version_title=payload.versionTitle,
+                language=payload.language,
+                char_range=payload.charRange,
+                text=payload.text,
+                resolved_ref=result.matched_segment,
+            )
+            link_obj, action = _create_or_update_link_for_non_segment_resolution(
+                citing_ref=citing_ref,
+                non_segment_ref=resolved_ref,
+                resolved_ref=result.matched_segment,
+            )
+            if link_obj is not None:
+                _record_disambiguated_link({
+                    "id": link_obj._id,
+                    "type": "link",
+                    "action": action,
+                    "link": link_obj.contents(),
+                    "resolution_type": "ambiguous",
+                    "ref": payload.ref,
+                    "versionTitle": payload.versionTitle,
+                    "language": payload.language,
+                    "previous_ref": resolved_ref,
+                    "resolved_ref": result.matched_segment,
+                    "llm_resolved_ref_ambiguous": result.matched_segment or result.resolved_ref,
+                    "llm_resolved_method_ambiguous": result.method,
+                    "llm_resolved_phrase_ambiguous": result.llm_resolved_phrase,
+                    "llm_ambiguous_option_valid": True,
+                })
+    _update_linker_output_resolution_fields(payload, result)
+
+
+def _update_linker_output_resolution_fields(payload: object, result: object) -> None:
+    """Persist resolution metadata onto LinkerOutput spans by charRange."""
+    try:
+        query = {
+            "ref": payload.ref,
+            "versionTitle": payload.versionTitle,
+            "language": payload.language,
+        }
+    except Exception:
+        return
+
+    linker_output = LinkerOutput().load(query)
+    if not linker_output:
+        return
+
+    updated = False
+    is_ambiguous = hasattr(payload, "ambiguous_refs")
+    for span in linker_output.spans:
+        if span.get("type") != MUTCSpanType.CITATION.value:
+            continue
+        if span.get("charRange") != payload.charRange:
+            continue
+        if is_ambiguous:
+            is_valid = (span.get("ref") == getattr(result, "resolved_ref", None))
+            span["llm_ambiguous_option_valid"] = is_valid
+            if is_valid:
+                span["llm_resolved_ref_ambiguous"] = result.matched_segment or result.resolved_ref
+                span["llm_resolved_method_ambiguous"] = result.method
+                span["llm_resolved_phrase_ambiguous"] = result.llm_resolved_phrase
+        else:
+            if span.get("ambiguous"):
+                if not span.get("llm_ambiguous_option_valid"):
+                    continue
+            span["llm_resolved_ref_non_segment"] = result.resolved_ref
+            span["llm_resolved_method_non_segment"] = result.method
+            span["llm_resolved_phrase_non_segment"] = result.llm_resolved_phrase
+        updated = True
+
+    if updated:
+        linker_output.save()
+
 def _record_disambiguated_mutc(payload: dict) -> None:
     """
@@ -489,6 +606,36 @@ def _record_disambiguated_link(payload: dict) -> None:
     except Exception:
         logger.exception("Failed recording disambiguated link", payload=doc)
 
+
+def _record_dicta_failure(payload: dict) -> None:
+    doc = dict(payload)
+    doc["created_at"] = datetime.utcnow()
+    try:
+        db.linker_dicta_failures_tmp.insert_one(doc)
+        logger.info("Recorded dicta failure", payload=doc)
+    except Exception:
+        logger.exception("Failed recording dicta failure", payload=doc)
+
+
+def _dicta_error_payload(info: dict, payload_obj: object) -> dict:
+    payload_doc = None
+    payload_type = None
+    try:
+        payload_doc = asdict(payload_obj)
+        payload_type = type(payload_obj).__name__
+    except Exception:
+        payload_doc = None
+    return {
+        "type": "dicta_non_200",
+        "status_code": info.get("status_code"),
+        "url": info.get("url"),
+        "target_ref": info.get("target_ref"),
+        "query_text": (info.get("query_text") or "")[:4000],
+        "response_text": (info.get("response_text") or "")[:2000],
+        "payload": payload_doc,
+        "payload_type": payload_type,
+    }
+
 def _extract_resolved_spans(resolved_refs):
     spans = []
     for resolved_ref in resolved_refs:
@@ -806,6 +953,8 @@ def process_ambiguous_resolution(resolution_data: dict) -> None:
         print(f"Ambiguous Options: {payload.ambiguous_refs}")
         print(f"→ RESOLVED TO: {resolved_ref}")
         print(f" Method: {result.method}")
+        if getattr(result, "llm_resolved_phrase", None):
+            print(f" Phrase: {result.llm_resolved_phrase}")
         if result.matched_segment:
             print(f" Matched Segment: {result.matched_segment}")
         print(f"{'='*80}\n")
@@ -872,6 +1021,8 @@ def process_non_segment_resolution(resolution_data: dict) -> None:
         print(f"Original Non-Segment Ref: {payload.resolved_non_segment_ref}")
         print(f"→ RESOLVED TO SEGMENT: {resolved_ref}")
         print(f" Method: {result.method}")
+        if getattr(result, "llm_resolved_phrase", None):
+            print(f" Phrase: {result.llm_resolved_phrase}")
         print(f"{'='*80}\n")
         logger.info(f"✓ Resolved to segment: {resolved_ref} (method: {result.method})")
@@ -906,13 +1057,19 @@ def cauldron_routine_disambiguation(payload: dict) -> dict:
     logger.info("=== Processing Bulk Disambiguation (single) ===")
     if "ambiguous_refs" in payload:
         amb_payload = AmbiguousResolutionPayload(**payload)
-        result = disambiguate_ambiguous_ref(amb_payload)
-        if result and result.resolved_ref:
-            _apply_ambiguous_resolution_with_record(amb_payload, result)
+        try:
+            result = disambiguate_ambiguous_ref(amb_payload)
+            if result and result.resolved_ref:
+                _apply_ambiguous_resolution_with_record(amb_payload, result)
+        except DictaAPIError as e:
+            _record_dicta_failure(_dicta_error_payload(e.info, amb_payload))
         return None
 
     ns_payload = NonSegmentResolutionPayload(**payload)
-    result = disambiguate_non_segment_ref(ns_payload)
-    if result and result.resolved_ref:
-        _apply_non_segment_resolution_with_record(ns_payload, result)
+    try:
+        result = disambiguate_non_segment_ref(ns_payload)
+        if result and result.resolved_ref:
+            _apply_non_segment_resolution_with_record(ns_payload, result)
+    except DictaAPIError as e:
+        _record_dicta_failure(_dicta_error_payload(e.info, ns_payload))
     return None
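Reviewer note: a sketch (not part of the patch) of the new Dicta failure path, using only names defined in this diff; the payload values are copied from the test data below, not real task input.

    from sefaria.helper.linker.disambiguator import (
        DictaAPIError,
        NonSegmentResolutionPayload,
        disambiguate_non_segment_ref,
    )
    from sefaria.helper.linker.tasks import _dicta_error_payload, _record_dicta_failure

    payload = NonSegmentResolutionPayload(
        ref="Ben Yehoyada on Kiddushin 70a:5",
        versionTitle="Senlake edition 2019 based on Ben Yehoyada, Jerusalem, 1897",
        language="he",
        charRange=[727, 734],
        text="מכות ג:",
        resolved_non_segment_ref="Makkot 3b",
    )
    try:
        result = disambiguate_non_segment_ref(payload)
    except DictaAPIError as e:
        # Mirrors cauldron_routine_disambiguation: non-200 Dicta responses are
        # persisted to db.linker_dicta_failures_tmp and no resolution is applied.
        _record_dicta_failure(_dicta_error_payload(e.info, payload))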
diff --git a/sefaria/helper/linker/tests/ambiguous_disambiguator_test.py b/sefaria/helper/linker/tests/ambiguous_disambiguator_test.py
new file mode 100644
index 0000000000..a74cb1713e
--- /dev/null
+++ b/sefaria/helper/linker/tests/ambiguous_disambiguator_test.py
@@ -0,0 +1,93 @@
+import os
+
+import pytest
+
+from sefaria.helper.linker.disambiguator import (
+    AmbiguousResolutionPayload,
+    disambiguate_ambiguous_ref,
+)
+
+
+TEST_CASES = [
+    # {
+    #     "id": "example_case",
+    #     "payload": {
+    #         "ref": "Some Commentary 1:1",
+    #         "versionTitle": "Some Version",
+    #         "language": "he",
+    #         "charRange": [10, 25],
+    #         "text": "ציטוט לדוגמה",
+    #         "ambiguous_refs": ["Genesis 1:1-3", "Exodus 2:1-2"],
+    #     },
+    #     "expected_resolutions": ["Genesis 1:1-3"],
+    #     "expected_matched_segments": ["Genesis 1:2"],
+    # },
+    {
+        "id": "mishnah_oholot_9_3_ikar_tosafot_yom_tov_5_6_2",
+        "payload": {
+            "ref": "Ikar Tosafot Yom Tov on Mishnah Oholot 5:6:2",
+            "versionTitle": "On Your Way",
+            "language": "he",
+            "charRange": [139, 154],
+            "text": "בפרק ט' משנה ג'",
+            "ambiguous_refs": ["Mishnah Oholot 9:3", "Ikar Tosafot Yom Tov on Mishnah Oholot 9:3"],
+        },
+        "expected_resolutions": ["Mishnah Oholot 9:3"],
+    },
+    {
+        "id": "isaiah_24_4_malbim_beur_hamilot_34_1_2",
+        "payload": {
+            "ref": "Malbim Beur Hamilot on Isaiah 34:1:2",
+            "versionTitle": "On Your Way",
+            "language": "he",
+            "charRange": [72, 77],
+            "text": "כד ד'",
+            "ambiguous_refs": ["Isaiah 24:4", "Malbim Beur Hamilot on Isaiah 24:4"],
+        },
+        "expected_resolutions": ["Malbim Beur Hamilot on Isaiah 24:4"],
+    },
+]
+
+
+def _missing_api_keys():
+    missing = []
+    if not os.getenv("ANTHROPIC_API_KEY"):
+        missing.append("ANTHROPIC_API_KEY")
+    if not os.getenv("OPENAI_API_KEY"):
+        missing.append("OPENAI_API_KEY")
+    return missing
+
+
+@pytest.mark.deep
+@pytest.mark.parametrize("case", TEST_CASES, ids=[c["id"] for c in TEST_CASES])
+def test_ambiguous_disambiguator_integration(case):
+    missing_keys = _missing_api_keys()
+    if missing_keys:
+        pytest.skip(f"Missing API keys for integration test: {', '.join(missing_keys)}")
+
+    payload = AmbiguousResolutionPayload(**case["payload"])
+    expected = case.get("expected_resolutions", [])
+    expected_matched = case.get("expected_matched_segments", [])
+
+    result = disambiguate_ambiguous_ref(payload)
+
+    if not expected:
+        assert result is None, f"Expected no resolution for case {case['id']}, got {result}"
+        return
+
+    if result is None:
+        assert None in expected, (
+            f"Expected one of {expected} for case {case['id']}, got None"
+        )
+        return
+
+    assert result.resolved_ref in expected, (
+        f"Unexpected resolution for case {case['id']}: {result.resolved_ref} "
+        f"(expected one of {expected})"
+    )
f"(expected one of {expected})" + ) + + if expected_matched: + assert result.matched_segment in expected_matched, ( + f"Unexpected matched segment for case {case['id']}: {result.matched_segment} " + f"(expected one of {expected_matched})" + ) diff --git a/sefaria/helper/linker/tests/non_segment_disambiguator_test.py b/sefaria/helper/linker/tests/non_segment_disambiguator_test.py new file mode 100644 index 0000000000..0ededa95bd --- /dev/null +++ b/sefaria/helper/linker/tests/non_segment_disambiguator_test.py @@ -0,0 +1,146 @@ +import os +from dataclasses import asdict + +import pytest + +from sefaria.helper.linker.disambiguator import ( + NonSegmentResolutionPayload, + disambiguate_non_segment_ref, +) + + +TEST_CASES = [ + { + "id": "jt_ketubot_2_siftei_kohen_cm_46_12_1", + "payload": { + "charRange": [245, 262], + "language": "he", + "ref": "Siftei Kohen on Shulchan Arukh, Choshen Mishpat 46:12:1", + "resolved_non_segment_ref": "Jerusalem Talmud Ketubot 2", + "text": "בירו' פ\"ב דכתובות", + "versionTitle": "Shulhan Arukh, Hoshen ha-Mishpat; Lemberg, 1898", + }, + "expected_resolutions": ["Jerusalem Talmud Ketubot 2:3:2"], + }, + # { + # "id": "shevuot_16_tzafnat_paneach_fwcn_6_8_1", + # "payload": { + # "charRange": [802, 814], + # "language": "he", + # "ref": "Tzafnat Pa'neach on Mishneh Torah, Foreign Worship and Customs of the Nations 6:8:1", + # "resolved_non_segment_ref": "Shevuot 16", + # "text": "דשבועות דט\"ז", + # "versionTitle": "Tzafnat Pa'neach on Mishneh Torah, Warsaw-Piotrków, 1903-1908", + # }, + # "expected_resolutions": ["Shevuot 16b:9:5, Shevuot 16b:9:6, Shevuot 16b:9:7", None], ## discuss with noah - i don't think we can expect it to succeed here + # }, + { + "id": "makkot_3b_ben_yehoyada_kiddushin_70a_5", + "payload": { + "charRange": [727, 734], + "language": "he", + "ref": "Ben Yehoyada on Kiddushin 70a:5", + "resolved_non_segment_ref": "Makkot 3b", + "text": "מכות ג:", + "versionTitle": "Senlake edition 2019 based on Ben Yehoyada, Jerusalem, 1897", + }, + "expected_resolutions": ["Makkot 3b:11", "Makkot 3b:12"] ## discuss noah - both are possible even though Makkot 3b:11 is better + }, + { + "id": "berakhot_19b_masoret_hatosefta_2_11_2", + "payload": { + "charRange": [70, 85], + "language": "he", + "ref": "Masoret HaTosefta on Berakhot 2:11:2", + "resolved_non_segment_ref": "Berakhot 19b", + "text": "בבלי כאן י\"ט ב'", + "versionTitle": "The Tosefta according to to codex Vienna. Third Augmented Edition, JTS 2001", + }, + "expected_resolutions": ["Berakhot 19b:1", None], ## discuss noah - search fails so none is the least evil + }, + { + "id": "jt_berakhot_3_2_masoret_hatosefta_2_11_2", + "payload": { + "charRange": [22, 43], + "language": "he", + "ref": "Masoret HaTosefta on Berakhot 2:11:2", + "resolved_non_segment_ref": "Jerusalem Talmud Berakhot 3:2", + "text": "ירוש' פ\"ג ה\"ב, ו' ע\"ב", + "versionTitle": "The Tosefta according to to codex Vienna. 
+    {
+        "id": "jt_berakhot_3_2_masoret_hatosefta_2_11_2",
+        "payload": {
+            "charRange": [22, 43],
+            "language": "he",
+            "ref": "Masoret HaTosefta on Berakhot 2:11:2",
+            "resolved_non_segment_ref": "Jerusalem Talmud Berakhot 3:2",
+            "text": "ירוש' פ\"ג ה\"ב, ו' ע\"ב",
+            "versionTitle": "The Tosefta according to to codex Vienna. Third Augmented Edition, JTS 2001",
+        },
+        "expected_resolutions": ["Jerusalem Talmud Berakhot 3:2:5"],
+    },
+    {
+        "id": "gittin_37_petach_einayim_sheviit_10_1_2",
+        "payload": {
+            "charRange": [206, 218],
+            "language": "he",
+            "ref": "Petach Einayim on Mishnah Sheviit 10:1:2",
+            "resolved_non_segment_ref": "Gittin 37",
+            "text": "גיטין דף ל\"ז",
+            "versionTitle": "Petach Einayim, Jerusalem 1959",
+        },
+        "expected_resolutions": ["Gittin 37a:12"],
+    },
+    {
+        "id": "menachot_63a_otzar_laazei_rashi_45",
+        "payload": {
+            "charRange": [8, 17],
+            "language": "he",
+            "ref": "Otzar La'azei Rashi, Talmud, Menachot 45",
+            "resolved_non_segment_ref": "Menachot 63a",
+            "text": "מנחות סג.",
+            "versionTitle": "Otzar Laazei Rashi, Jerusalem, 1988",
+        },
+        "expected_resolutions": ["Menachot 63a:9"],
+    },
+    {
+        "id": "mt_ownerless_property_8_ketzot_hachoshen_cm_252_1_1",
+        "payload": {
+            "charRange": [47, 63],
+            "language": "he",
+            "ref": "Ketzot HaChoshen on Shulchan Arukh, Choshen Mishpat 252:1:1",
+            "resolved_non_segment_ref": "Mishneh Torah, Ownerless Property and Gifts 8",
+            "text": "הרמב\"ם פ\"ח מזכיה",
+            "versionTitle": "Shulhan Arukh, Hoshen ha-Mishpat; Lemberg, 1898",
+        },
+        "expected_resolutions": ["Mishneh Torah, Ownerless Property and Gifts 8:9"],
+    },
+]
+
+
+def _missing_api_keys():
+    missing = []
+    if not os.getenv("ANTHROPIC_API_KEY"):
+        missing.append("ANTHROPIC_API_KEY")
+    if not os.getenv("OPENAI_API_KEY"):
+        missing.append("OPENAI_API_KEY")
+    return missing
+
+
+@pytest.mark.deep
+@pytest.mark.parametrize("case", TEST_CASES, ids=[c["id"] for c in TEST_CASES])
+def test_non_segment_disambiguator_integration(case):
+    missing_keys = _missing_api_keys()
+    if missing_keys:
+        pytest.skip(f"Missing API keys for integration test: {', '.join(missing_keys)}")
+
+    payload = NonSegmentResolutionPayload(**case["payload"])
+    expected = case.get("expected_resolutions", [])
+
+    result = disambiguate_non_segment_ref(payload)
+
+    if not expected:
+        assert result is None, f"Expected no resolution for case {case['id']}, got {result}"
+        return
+
+    if result is None:
+        assert None in expected, (
+            f"Expected one of {expected} for case {case['id']}, got None"
+        )
+        return
+
+    assert result.resolved_ref in expected, (
+        f"Unexpected resolution for case {case['id']}: {result.resolved_ref} "
+        f"(expected one of {expected})"
+    )
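Reviewer note: the schema additions in the final hunk below are what let the dispatcher re-select LLM-validated ambiguous spans. A sketch of that query — the filter is copied from the dispatcher above, but the Mongo collection name is an assumption (use whatever LinkerOutput actually persists to):

    from sefaria.system.database import db

    validated_spans_query = {
        "spans": {
            "$elemMatch": {
                "type": "citation",
                "failed": {"$ne": True},
                "ref": {"$exists": True},
                "$or": [
                    {"ambiguous": {"$ne": True}},
                    {"llm_ambiguous_option_valid": True},
                ],
            }
        }
    }
    # Collection name is assumed here; not shown in this diff.
    count = db.linker_output.count_documents(validated_spans_query)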
"boolean", "required": True}, **{k: {"type": "list", "schema": {"type": "string"}, "required": False, "nullable": True} for k in optional_list_str_schema_keys}